import torch
import gradio as gr
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput

# ==========================================
# 1. SETUP: Use BART-Large (The Best "Parrot")
# ==========================================
# BART is an auto-encoder. Its job is to reconstruct inputs, not chat.
model_name = "facebook/bart-large"
print(f"Loading {model_name}...")
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)
model.eval()


# ==========================================
# 2. STRICT LOGIC
# ==========================================
def text_to_embedding(text: str) -> torch.Tensor:
    """Encode *text* with BART's encoder and return its hidden states.

    Returns the encoder's ``last_hidden_state`` — a tensor of shape
    (1, seq_len, hidden_dim), one row per input token (including the
    special tokens the tokenizer adds).
    """
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        encoder_outputs = model.model.encoder(**inputs)
    return encoder_outputs.last_hidden_state


def embedding_to_text(embedding_tensor: torch.Tensor) -> str:
    """Decode encoder hidden states back into text with greedy decoding.

    The tensor is wrapped in a ``BaseModelOutput`` so that ``generate``
    treats it as precomputed encoder output and only runs the decoder.
    """
    encoder_outputs_wrapped = BaseModelOutput(last_hidden_state=embedding_tensor)
    with torch.no_grad():
        generated_ids = model.generate(
            encoder_outputs=encoder_outputs_wrapped,
            # --- STRICT TRANSCRIPTION SETTINGS ---
            # Greedy search (num_beams=1) and no sampling: deterministic,
            # direct transcription of the vector rather than creative
            # generation. Sampling-only flags (temperature,
            # repetition_penalty) are deliberately omitted — they were at
            # their defaults anyway, and recent transformers versions warn
            # when they are passed together with do_sample=False.
            max_length=50,
            num_beams=1,
            do_sample=False,
        )
    decoded_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return decoded_text


def run_weighted_mixing(text1: str, text2: str, mix_ratio: float) -> str:
    """Linearly interpolate two sentence embeddings and decode the blend.

    mix_ratio=0.0 reproduces (approximately) *text1*; 1.0 reproduces
    *text2*. NOTE: sequences of unequal token length are truncated to the
    shorter one, so the tail of the longer sentence is silently dropped.
    """
    if not text1 or not text2:
        return "Enter sentences."

    v1 = text_to_embedding(text1)
    v2 = text_to_embedding(text2)

    # Truncate to min length so the two tensors are shape-compatible.
    min_len = min(v1.shape[1], v2.shape[1])
    v1 = v1[:, :min_len, :]
    v2 = v2[:, :min_len, :]

    # Weighted average of the two token-embedding sequences.
    v_mixed = (v1 * (1 - mix_ratio)) + (v2 * mix_ratio)
    return embedding_to_text(v_mixed)
# ==========================================
# 3. UI
# ==========================================
with gr.Blocks(title="BART-Large Vector Decoder", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🦜 BART Strict Vector Decoder")
    gr.Markdown("This version uses `bart-large` with **Greedy Search** to force direct transcription instead of creative generation.")

    with gr.Row():
        t1 = gr.Textbox(label="Start Sentence", value="The dog is happy.")
        t2 = gr.Textbox(label="End Sentence", value="The cat is angry.")

    # 0.0 means 100% Start Sentence. 1.0 means 100% End Sentence.
    slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Ratio (0 = Start, 1 = End)")
    btn_mix = gr.Button("Decode Vector")
    out = gr.Textbox(label="Decoded Text")

    # Wire the button to the mixing routine defined above.
    btn_mix.click(run_weighted_mixing, inputs=[t1, t2, slider], outputs=out)

if __name__ == "__main__":
    demo.launch()