import torch
import gradio as gr
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput

# ==========================================
# 1. SETUP: Load Model (Global Scope)
# ==========================================
model_name = "facebook/bart-large"  # alternative: "facebook/bart-base"
print(f"Loading {model_name}...")
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)
model.eval()  # Set to evaluation mode


# ==========================================
# 2. CORE LOGIC FUNCTIONS
# ==========================================
def text_to_embedding(text: str) -> torch.Tensor:
    """Encode text into the BART latent space.

    Args:
        text: The sentence to encode.

    Returns:
        The encoder's ``last_hidden_state`` for the tokenized input
        (indexed elsewhere in this file as batch, tokens, dimensions).
    """
    inputs = tokenizer(text, return_tensors="pt")
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        encoder_outputs = model.model.encoder(**inputs)
    return encoder_outputs.last_hidden_state


def embedding_to_text(embedding_tensor: torch.Tensor) -> str:
    """Decode encoder hidden states back into text.

    Args:
        embedding_tensor: Hidden states such as those returned by
            ``text_to_embedding``.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    # generate() takes precomputed encoder states via `encoder_outputs`,
    # so wrap the raw tensor in the output container it expects.
    encoder_outputs_wrapped = BaseModelOutput(last_hidden_state=embedding_tensor)
    with torch.no_grad():
        generated_ids = model.generate(
            encoder_outputs=encoder_outputs_wrapped,
            max_length=50,
            num_beams=4,
            early_stopping=True,
        )
    decoded_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return decoded_text


# ==========================================
# 3. GRADIO INTERFACE FUNCTIONS
# ==========================================
def run_reconstruction(text):
    """Encode a sentence, decode it again, and report vector statistics.

    Args:
        text: The sentence typed by the user; may be empty or ``None``.

    Returns:
        A ``(reconstructed_text, debug_info)`` pair. For empty or
        whitespace-only input the pair is ``("", "Please enter text.")``.
    """
    # Whitespace-only input carries no content; treat it like empty input.
    if not text or not text.strip():
        return "", "Please enter text."

    # 1. Encode
    vector = text_to_embedding(text)

    # 2. Decode
    reconstructed = embedding_to_text(vector)

    # 3. Get Stats
    shape_info = f"Vector Shape: {vector.shape} (Batch, Tokens, Dimensions)"
    # First five values of the first token's vector in the first batch entry.
    preview = f"First 5 values: {vector[0, 0, :5].tolist()}"
    debug_info = f"{shape_info}\n{preview}"
    return reconstructed, debug_info


def run_mixing(text1, text2):
    """Average the encodings of two sentences and decode the result.

    Args:
        text1: First sentence; may be empty or ``None``.
        text2: Second sentence; may be empty or ``None``.

    Returns:
        The text decoded from the averaged vectors, or a prompt message
        when either input is empty or whitespace-only.
    """
    if not text1 or not text2 or not text1.strip() or not text2.strip():
        return "Please enter two sentences."

    # 1. Get vectors
    v1 = text_to_embedding(text1)
    v2 = text_to_embedding(text2)

    # 2. Align lengths (Truncate to minimum length)
    # Note: In a production app, you might want to pad instead of truncate,
    # but for this specific "averaging" demo, truncation prevents dimension
    # mismatch errors. Token positions beyond the shorter input are dropped.
    min_len = min(v1.shape[1], v2.shape[1])
    v1_cut = v1[:, :min_len, :]
    v2_cut = v2[:, :min_len, :]

    # 3. Math: Average the vectors
    v_mixed = (v1_cut + v2_cut) / 2.0

    # 4. Decode
    mixed_text = embedding_to_text(v_mixed)
    return mixed_text


# ==========================================
# 4. BUILD UI
# ==========================================
with gr.Blocks(title="BART Latent Space Explorer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 BART Latent Space Explorer")
    # Interpolate the checkpoint name so the description always matches the
    # model that was actually loaded above.
    gr.Markdown(
        f"This tool uses `{model_name}` to convert text into mathematical "
        "vectors (Embeddings) and back."
    )

    with gr.Tabs():
        # --- TAB 1: RECONSTRUCTION ---
        with gr.TabItem("1. Auto-Encoder Test"):
            gr.Markdown(
                "Type a sentence. The model will turn it into numbers, "
                "then turn those numbers back into text."
            )
            with gr.Row():
                with gr.Column():
                    input_text = gr.Textbox(
                        label="Original Sentence",
                        value="The cat sat on the mat.",
                    )
                    btn_recon = gr.Button("Encode & Decode", variant="primary")
                with gr.Column():
                    output_recon = gr.Textbox(label="Reconstructed Text")
                    output_debug = gr.Code(label="Vector Stats", language="json")
            btn_recon.click(
                run_reconstruction,
                inputs=input_text,
                outputs=[output_recon, output_debug],
            )

        # --- TAB 2: VECTOR MIXING ---
        with gr.TabItem("2. Vector Mixing (Math)"):
            gr.Markdown(
                "Type two different sentences. We will average their "
                "mathematical representations. Results may be surreal!"
            )
            with gr.Row():
                with gr.Column():
                    mix_in_1 = gr.Textbox(label="Sentence A", value="The weather is sunny.")
                    mix_in_2 = gr.Textbox(label="Sentence B", value="The weather is rainy.")
                    btn_mix = gr.Button("Calculate Average Meaning", variant="primary")
                with gr.Column():
                    mix_out = gr.Textbox(label="The AI's 'Middle Ground' Thought", lines=4)
            btn_mix.click(run_mixing, inputs=[mix_in_1, mix_in_2], outputs=mix_out)

if __name__ == "__main__":
    demo.launch()