# everydaytok's picture
# Update app.py
# 51fab22 verified
import torch
import gradio as gr
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
# ==========================================
# 1. SETUP: Load the tokenizer and model once, at import time
# ==========================================
# A smaller alternative checkpoint is "facebook/bart-base".
model_name = "facebook/bart-large"
print(f"Loading {model_name}...")
tokenizer = BartTokenizer.from_pretrained(model_name)
# nn.Module.eval() returns the module itself, so it can be chained onto the
# load; it switches off training-only behaviour such as dropout.
model = BartForConditionalGeneration.from_pretrained(model_name).eval()
# ==========================================
# 2. CORE LOGIC FUNCTIONS
# ==========================================
def text_to_embedding(text):
    """Encode text into the BART encoder's hidden states.

    Args:
        text: Input string to encode.

    Returns:
        The encoder's ``last_hidden_state`` tensor for a batch of one input
        (the UI labels its shape as (Batch, Tokens, Dimensions)).

    Inputs longer than the tokenizer's maximum length are truncated instead of
    overflowing the model's position embeddings and raising at encode time.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        # Run only the encoder half of the seq2seq model.
        encoder_outputs = model.model.encoder(**inputs)
    return encoder_outputs.last_hidden_state
def embedding_to_text(embedding_tensor, max_length=50, num_beams=4):
    """Decode encoder hidden states back into text using beam search.

    Args:
        embedding_tensor: Encoder hidden states, e.g. as returned by
            ``text_to_embedding`` (assumed batch-first — confirm for other
            sources).
        max_length: Maximum length of the generated sequence, in tokens.
            Longer reconstructions are cut off at this length.
        num_beams: Beam width used by ``model.generate``.

    Returns:
        The decoded text of the first sequence in the batch, with special
        tokens removed.
    """
    # Hand generate() precomputed encoder outputs (wrapped in the ModelOutput
    # container it expects) so decoding starts from our vectors.
    encoder_outputs_wrapped = BaseModelOutput(last_hidden_state=embedding_tensor)
    with torch.no_grad():
        generated_ids = model.generate(
            encoder_outputs=encoder_outputs_wrapped,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
# ==========================================
# 3. GRADIO INTERFACE FUNCTIONS
# ==========================================
def run_reconstruction(text):
    """Encode ``text`` to vectors and decode it back (auto-encoder tab).

    Args:
        text: Sentence entered by the user.

    Returns:
        A ``(reconstructed_text, debug_info)`` pair. ``debug_info`` reports the
        vector shape and the first five values of the first token's vector.
        For empty or whitespace-only input, returns ``("", "Please enter text.")``.
    """
    # Treat whitespace-only input as empty: there is no sentence to reconstruct.
    if not text or not text.strip():
        return "", "Please enter text."

    vector = text_to_embedding(text)
    reconstructed = embedding_to_text(vector)

    shape_info = f"Vector Shape: {vector.shape} (Batch, Tokens, Dimensions)"
    # Tensor.tolist() yields the same Python floats as .numpy().tolist()
    # without requiring the tensor to live on the CPU.
    preview = f"First 5 values: {vector[0][0][:5].tolist()}"
    debug_info = f"{shape_info}\n{preview}"
    return reconstructed, debug_info
def run_mixing(text1, text2, weight=0.5):
    """Blend the encodings of two sentences and decode the result.

    Args:
        text1: First sentence (Sentence A).
        text2: Second sentence (Sentence B).
        weight: Share of ``text2`` in the blend. 0.5 (the default) is a plain
            average; 0.0 uses only ``text1``, 1.0 only ``text2``.

    Returns:
        The decoded text of the blended vectors, or ``"Please enter two
        sentences."`` if either input is empty or whitespace-only.
    """
    if not text1 or not text1.strip() or not text2 or not text2.strip():
        return "Please enter two sentences."

    v1 = text_to_embedding(text1)
    v2 = text_to_embedding(text2)

    # The encodings are per-token, so they must have the same token count to be
    # combined position-wise. Truncating to the shorter one avoids a shape
    # mismatch; padding the shorter one would be the alternative.
    min_len = min(v1.shape[1], v2.shape[1])
    v1_cut = v1[:, :min_len, :]
    v2_cut = v2[:, :min_len, :]

    # Weighted interpolation; weight=0.5 reproduces (v1 + v2) / 2.
    v_mixed = (1.0 - weight) * v1_cut + weight * v2_cut
    return embedding_to_text(v_mixed)
# ==========================================
# 4. BUILD UI
# ==========================================
# Two-tab Gradio UI: tab 1 round-trips a sentence through the encoder/decoder,
# tab 2 blends the encodings of two sentences.
with gr.Blocks(title="BART Latent Space Explorer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 BART Latent Space Explorer")
    # Build the description from model_name so it always names the checkpoint
    # that was actually loaded (the text previously said bart-base while
    # bart-large was loaded).
    gr.Markdown(f"This tool uses `{model_name}` to convert text into mathematical vectors (Embeddings) and back.")
    with gr.Tabs():
        # --- TAB 1: RECONSTRUCTION ---
        with gr.TabItem("1. Auto-Encoder Test"):
            gr.Markdown("Type a sentence. The model will turn it into numbers, then turn those numbers back into text.")
            with gr.Row():
                with gr.Column():
                    input_text = gr.Textbox(label="Original Sentence", value="The cat sat on the mat.")
                    btn_recon = gr.Button("Encode & Decode", variant="primary")
                with gr.Column():
                    output_recon = gr.Textbox(label="Reconstructed Text")
                    output_debug = gr.Code(label="Vector Stats", language="json")
            btn_recon.click(run_reconstruction, inputs=input_text, outputs=[output_recon, output_debug])
        # --- TAB 2: VECTOR MIXING ---
        with gr.TabItem("2. Vector Mixing (Math)"):
            gr.Markdown("Type two different sentences. We will average their mathematical representations. Results may be surreal!")
            with gr.Row():
                with gr.Column():
                    mix_in_1 = gr.Textbox(label="Sentence A", value="The weather is sunny.")
                    mix_in_2 = gr.Textbox(label="Sentence B", value="The weather is rainy.")
                    btn_mix = gr.Button("Calculate Average Meaning", variant="primary")
                with gr.Column():
                    mix_out = gr.Textbox(label="The AI's 'Middle Ground' Thought", lines=4)
            btn_mix.click(run_mixing, inputs=[mix_in_1, mix_in_2], outputs=mix_out)

if __name__ == "__main__":
    demo.launch()