Spaces:
Running
Running
File size: 4,728 Bytes
3a82207 7ec13e1 da60d06 08c1bd3 7ec13e1 51fab22 da60d06 7ec13e1 63b82b4 7ec13e1 08c1bd3 da60d06 7ec13e1 da60d06 7ec13e1 3a82207 da60d06 7ec13e1 da60d06 7ec13e1 da60d06 3592f5f da60d06 7ec13e1 da60d06 3a82207 7ec13e1 3a82207 7ec13e1 3a82207 7ec13e1 63b82b4 7ec13e1 da60d06 7ec13e1 da60d06 7ec13e1 da60d06 7ec13e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import torch
import gradio as gr
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
# ==========================================
# 1. SETUP: Load Model (Global Scope)
# ==========================================
# Checkpoint to load. bart-large reconstructs text more faithfully;
# swap in the commented alternative for a smaller/faster download.
model_name = "facebook/bart-large" # "facebook/bart-base"
print(f"Loading {model_name}...")
# Tokenizer and seq2seq model are created once at import time so every
# Gradio callback reuses the same weights (downloads on first run).
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)
model.eval() # Set to evaluation mode
# ==========================================
# 2. CORE LOGIC FUNCTIONS
# ==========================================
def text_to_embedding(text):
    """Map a string to its BART encoder representation.

    Returns the encoder's last hidden state: a float tensor of shape
    (1, num_tokens, hidden_dim).
    """
    tokenized = tokenizer(text, return_tensors="pt")
    # Inference only — no gradients needed.
    with torch.no_grad():
        enc_out = model.model.encoder(**tokenized)
    return enc_out.last_hidden_state
def embedding_to_text(embedding_tensor):
    """Generate text from a BART encoder hidden-state tensor."""
    # generate() expects encoder outputs as a model-output object,
    # not a bare tensor, so wrap it first.
    wrapped = BaseModelOutput(last_hidden_state=embedding_tensor)
    with torch.no_grad():
        output_ids = model.generate(
            encoder_outputs=wrapped,
            max_length=50,
            num_beams=4,
            early_stopping=True,
        )
    # Single-item batch: decode and return the first (only) string.
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
# ==========================================
# 3. GRADIO INTERFACE FUNCTIONS
# ==========================================
def run_reconstruction(text):
    """Gradio callback: encode *text* into the latent space and decode it back.

    Returns a (reconstructed_text, debug_info) pair for the two output
    widgets. Fix: whitespace-only input previously slipped past the
    empty-string guard and was fed to the encoder; it is now rejected too.
    """
    if not text or not text.strip():
        return "", "Please enter text."
    # 1. Encode
    vector = text_to_embedding(text)
    # 2. Decode
    reconstructed = embedding_to_text(vector)
    # 3. Get Stats — show the tensor shape and a small slice of raw values
    # so users can "see" the vector. Safe to call .numpy(): the tensor was
    # produced under torch.no_grad() on CPU.
    shape_info = f"Vector Shape: {vector.shape} (Batch, Tokens, Dimensions)"
    preview = f"First 5 values: {vector[0][0][:5].numpy().tolist()}"
    debug_info = f"{shape_info}\n{preview}"
    return reconstructed, debug_info
def run_mixing(text1, text2):
    """Gradio callback: decode the element-wise average of two sentence embeddings."""
    if not text1 or not text2:
        return "Please enter two sentences."
    # Embed both sentences independently.
    emb_a = text_to_embedding(text1)
    emb_b = text_to_embedding(text2)
    # Truncate both to the shorter token length so the tensors line up.
    # (Padding would be the production-grade alternative; truncation is
    # enough for this averaging demo and avoids dimension mismatches.)
    common_len = min(emb_a.shape[1], emb_b.shape[1])
    trimmed_a = emb_a[:, :common_len, :]
    trimmed_b = emb_b[:, :common_len, :]
    # Midpoint in latent space, then decode it back into a sentence.
    averaged = (trimmed_a + trimmed_b) / 2.0
    return embedding_to_text(averaged)
# ==========================================
# 4. BUILD UI
# ==========================================
with gr.Blocks(title="BART Latent Space Explorer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 BART Latent Space Explorer")
    # Fix: the description previously hard-coded `facebook/bart-base` even
    # though the loaded checkpoint is model_name ("facebook/bart-large");
    # interpolate the actual checkpoint so UI and code stay in sync.
    gr.Markdown(f"This tool uses `{model_name}` to convert text into mathematical vectors (Embeddings) and back.")
    with gr.Tabs():
        # --- TAB 1: RECONSTRUCTION ---
        with gr.TabItem("1. Auto-Encoder Test"):
            gr.Markdown("Type a sentence. The model will turn it into numbers, then turn those numbers back into text.")
            with gr.Row():
                with gr.Column():
                    input_text = gr.Textbox(label="Original Sentence", value="The cat sat on the mat.")
                    btn_recon = gr.Button("Encode & Decode", variant="primary")
                with gr.Column():
                    output_recon = gr.Textbox(label="Reconstructed Text")
                    output_debug = gr.Code(label="Vector Stats", language="json")
            btn_recon.click(run_reconstruction, inputs=input_text, outputs=[output_recon, output_debug])
        # --- TAB 2: VECTOR MIXING ---
        with gr.TabItem("2. Vector Mixing (Math)"):
            gr.Markdown("Type two different sentences. We will average their mathematical representations. Results may be surreal!")
            with gr.Row():
                with gr.Column():
                    mix_in_1 = gr.Textbox(label="Sentence A", value="The weather is sunny.")
                    mix_in_2 = gr.Textbox(label="Sentence B", value="The weather is rainy.")
                    btn_mix = gr.Button("Calculate Average Meaning", variant="primary")
                with gr.Column():
                    mix_out = gr.Textbox(label="The AI's 'Middle Ground' Thought", lines=4)
            btn_mix.click(run_mixing, inputs=[mix_in_1, mix_in_2], outputs=mix_out)
if __name__ == "__main__":
    demo.launch()