Spaces:
Running
Running
| import torch | |
| import gradio as gr | |
| from transformers import BartTokenizer, BartForConditionalGeneration | |
| from transformers.modeling_outputs import BaseModelOutput | |
# ==========================================
# 1. SETUP: Load Model (Global Scope)
# ==========================================
# Checkpoint to load. Swap in "facebook/bart-base" for a smaller, faster model.
model_name = "facebook/bart-large" # "facebook/bart-base"
print(f"Loading {model_name}...")
# Tokenizer converts text <-> token ids; model is the full seq2seq BART
# (encoder + decoder). Both are module-level globals used by the functions below.
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)
model.eval() # Set to evaluation mode (disables dropout) — inference only.
| # ========================================== | |
| # 2. CORE LOGIC FUNCTIONS | |
| # ========================================== | |
def text_to_embedding(text):
    """Map *text* into BART's latent space.

    Runs only the encoder half of the model and returns its last hidden
    state: a float tensor of shape (1, num_tokens, hidden_dim).
    """
    token_batch = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        encoded = model.model.encoder(**token_batch)
    return encoded.last_hidden_state
def embedding_to_text(embedding_tensor):
    """Generate text from an encoder hidden-state tensor.

    The tensor is wrapped in a BaseModelOutput so `model.generate` treats
    it as precomputed encoder output and runs only the decoder (beam
    search, width 4, up to 50 tokens).
    """
    wrapped = BaseModelOutput(last_hidden_state=embedding_tensor)
    with torch.no_grad():
        ids = model.generate(
            encoder_outputs=wrapped,
            max_length=50,
            num_beams=4,
            early_stopping=True,
        )
    return tokenizer.batch_decode(ids, skip_special_tokens=True)[0]
| # ========================================== | |
| # 3. GRADIO INTERFACE FUNCTIONS | |
| # ========================================== | |
def run_reconstruction(text):
    """Round-trip *text* through the encoder and decoder.

    Returns a (reconstructed_text, debug_info) pair for the two Gradio
    output components; on empty input, prompts the user instead.
    """
    if not text:
        return "", "Please enter text."

    vector = text_to_embedding(text)            # text -> latent tensor
    reconstructed = embedding_to_text(vector)   # latent tensor -> text

    # Debug panel: tensor shape plus a small sample of raw values.
    debug_info = "\n".join([
        f"Vector Shape: {vector.shape} (Batch, Tokens, Dimensions)",
        f"First 5 values: {vector[0][0][:5].numpy().tolist()}",
    ])
    return reconstructed, debug_info
def run_mixing(text1, text2):
    """Average the latent representations of two sentences and decode.

    Returns the decoded "middle ground" sentence, or a prompt if either
    input is empty.
    """
    if not text1 or not text2:
        return "Please enter two sentences."

    v1 = text_to_embedding(text1)
    v2 = text_to_embedding(text2)

    # The two sentences usually tokenize to different lengths; truncate
    # both to the shorter one so the elementwise average is well-defined.
    # (Padding would be an alternative, but truncation keeps this demo
    # free of dimension-mismatch errors.)
    shared_len = min(v1.shape[1], v2.shape[1])
    midpoint = (v1[:, :shared_len, :] + v2[:, :shared_len, :]) / 2.0

    return embedding_to_text(midpoint)
# ==========================================
# 4. BUILD UI
# ==========================================
with gr.Blocks(title="BART Latent Space Explorer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 BART Latent Space Explorer")
    # BUG FIX: this description previously hard-coded `facebook/bart-base`
    # even though `model_name` is "facebook/bart-large". Interpolating the
    # actual checkpoint name keeps the UI text from drifting out of sync.
    gr.Markdown(f"This tool uses `{model_name}` to convert text into mathematical vectors (Embeddings) and back.")

    with gr.Tabs():
        # --- TAB 1: RECONSTRUCTION ---
        with gr.TabItem("1. Auto-Encoder Test"):
            gr.Markdown("Type a sentence. The model will turn it into numbers, then turn those numbers back into text.")
            with gr.Row():
                with gr.Column():
                    input_text = gr.Textbox(label="Original Sentence", value="The cat sat on the mat.")
                    btn_recon = gr.Button("Encode & Decode", variant="primary")
                with gr.Column():
                    output_recon = gr.Textbox(label="Reconstructed Text")
                    output_debug = gr.Code(label="Vector Stats", language="json")
            # One text in -> reconstructed text + vector stats out.
            btn_recon.click(run_reconstruction, inputs=input_text, outputs=[output_recon, output_debug])

        # --- TAB 2: VECTOR MIXING ---
        with gr.TabItem("2. Vector Mixing (Math)"):
            gr.Markdown("Type two different sentences. We will average their mathematical representations. Results may be surreal!")
            with gr.Row():
                with gr.Column():
                    mix_in_1 = gr.Textbox(label="Sentence A", value="The weather is sunny.")
                    mix_in_2 = gr.Textbox(label="Sentence B", value="The weather is rainy.")
                    btn_mix = gr.Button("Calculate Average Meaning", variant="primary")
                with gr.Column():
                    mix_out = gr.Textbox(label="The AI's 'Middle Ground' Thought", lines=4)
            # Two texts in -> one averaged/decoded text out.
            btn_mix.click(run_mixing, inputs=[mix_in_1, mix_in_2], outputs=mix_out)

if __name__ == "__main__":
    demo.launch()