# everydaytok's picture
# Update app.py
# 04701d7 verified
import torch
import gradio as gr
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
# ==========================================
# 1. SETUP: Use BART-Large (The Best "Parrot")
# ==========================================
# BART is pre-trained as a denoising auto-encoder: it learns to rebuild its
# input rather than to chat, which is the behavior this demo relies on.
model_name = "facebook/bart-large"
print(f"Loading {model_name}...")

tokenizer = BartTokenizer.from_pretrained(model_name)
# nn.Module.eval() returns the module itself, so switching to inference mode
# (dropout off) can be chained directly onto loading.
model = BartForConditionalGeneration.from_pretrained(model_name).eval()
# ==========================================
# 2. STRICT LOGIC
# ==========================================
def text_to_embedding(text):
    """Encode *text* with the BART encoder and return its hidden states.

    Args:
        text: The sentence to encode.

    Returns:
        The encoder's ``last_hidden_state`` tensor for the tokenized input.
        A single string is tokenized with ``return_tensors="pt"``, so the
        leading dimension is presumably a batch of 1 (TODO confirm); axis 1
        is indexed as the token axis by callers in this file.
    """
    # truncation=True clips the input to the tokenizer's model_max_length, so an
    # over-long textbox entry no longer overruns the encoder's learned position
    # embeddings (which makes the forward pass fail).
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        encoder_outputs = model.model.encoder(**inputs)
    return encoder_outputs.last_hidden_state
def embedding_to_text(embedding_tensor, max_length=50):
    """Decode encoder hidden states back into text using greedy BART decoding.

    Args:
        embedding_tensor: Encoder hidden states to decode from, e.g. the output
            of ``text_to_embedding`` or a blend of such outputs.
        max_length: Upper bound, in tokens, on the generated sequence
            (default 50, the value previously hard-coded).

    Returns:
        The first decoded sequence as a string, with special tokens removed.
    """
    # Wrap the raw tensor so generate() can consume it as precomputed encoder
    # output (read via .last_hidden_state) instead of running the encoder.
    encoder_outputs_wrapped = BaseModelOutput(last_hidden_state=embedding_tensor)
    with torch.no_grad():
        generated_ids = model.generate(
            encoder_outputs=encoder_outputs_wrapped,
            # --- STRICT TRANSCRIPTION SETTINGS ---
            max_length=max_length,
            num_beams=1,             # Greedy search (no creative exploring)
            do_sample=False,         # Deterministic (no randomness)
            temperature=1.0,         # Sampling-only knob; neutral under greedy decoding
            repetition_penalty=1.0,  # Don't punish repeating words (we want exact copies)
            # Generation settings not passed here are inherited from the
            # checkpoint config. NOTE(review): facebook/bart-large's config is
            # believed to set no_repeat_ngram_size=3, which would forbid the
            # repeated phrases an exact copy may need, so pin it off explicitly.
            no_repeat_ngram_size=0,
        )
    # Only the first sequence of the batch is returned.
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
def run_weighted_mixing(text1, text2, mix_ratio):
    """Blend the encodings of two sentences and decode the mixture.

    Args:
        text1: Start sentence; weighted by ``1 - mix_ratio``.
        text2: End sentence; weighted by ``mix_ratio``.
        mix_ratio: Blend weight. 0 means only ``text1`` and 1 means only
            ``text2``; values outside [0, 1] are treated as the nearest endpoint.

    Returns:
        The decoded text, or ``"Enter sentences."`` if either input is missing,
        empty, or whitespace-only.
    """
    if not (text1 and text1.strip()) or not (text2 and text2.strip()):
        return "Enter sentences."

    # At the endpoints, decode the untouched encoding. The truncation below
    # would otherwise cut off the tail of a longer sentence (presumably
    # including its end-of-sequence token), so "100% End" would not reproduce
    # the End sentence.
    if mix_ratio <= 0:
        return embedding_to_text(text_to_embedding(text1))
    if mix_ratio >= 1:
        return embedding_to_text(text_to_embedding(text2))

    v1 = text_to_embedding(text1)
    v2 = text_to_embedding(text2)

    # An element-wise blend needs aligned token positions, so keep only the
    # positions (axis 1) that both encodings share.
    min_len = min(v1.shape[1], v2.shape[1])
    v1 = v1[:, :min_len, :]
    v2 = v2[:, :min_len, :]

    # Weighted average
    v_mixed = v1 * (1 - mix_ratio) + v2 * mix_ratio
    return embedding_to_text(v_mixed)
# ==========================================
# 3. UI
# ==========================================
# Layout: two sentence boxes side by side, a blend slider, a decode button and
# an output box. Clicking the button sends (start, end, ratio) to
# run_weighted_mixing and shows its result in the output box.
with gr.Blocks(title="BART-Large Vector Decoder", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🦜 BART Strict Vector Decoder")
    gr.Markdown("This version uses `bart-large` with **Greedy Search** to force direct transcription instead of creative generation.")

    with gr.Row():
        t1 = gr.Textbox(label="Start Sentence", value="The dog is happy.")
        t2 = gr.Textbox(label="End Sentence", value="The cat is angry.")

    # Slider semantics: 0.0 -> 100% Start Sentence, 1.0 -> 100% End Sentence.
    slider = gr.Slider(
        minimum=0.0,
        maximum=1.0,
        value=0.0,
        step=0.1,
        label="Ratio (0 = Start, 1 = End)",
    )
    btn_mix = gr.Button("Decode Vector")
    out = gr.Textbox(label="Decoded Text")

    btn_mix.click(fn=run_weighted_mixing, inputs=[t1, t2, slider], outputs=out)


if __name__ == "__main__":
    demo.launch()