everydaytok's picture
Update app.py
509cc45 verified
import torch
import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
# ==========================================
# 1. SETUP
# ==========================================
# Checkpoint used for both tokenizer and model; also shown in the UI header.
model_name = "google/flan-t5-large"
print("Loading {}...".format(model_name))
tokenizer = T5Tokenizer.from_pretrained(model_name)
# nn.Module.eval() returns the module itself, so loading and switching to
# inference mode (dropout off) can be done in a single expression.
model = T5ForConditionalGeneration.from_pretrained(model_name).eval()
# ==========================================
# 2. LOGIC
# ==========================================
def text_to_embedding(text):
    """Run *text* through the T5 encoder and return its hidden states.

    The result is the encoder's ``last_hidden_state`` for a single string.
    The mixing functions index it as (batch, token, feature) and slice along
    axis 1 to align two sentences token by token.
    """
    encoded = tokenizer(text, return_tensors="pt")
    # Pure inference: no gradients are needed for the returned states.
    with torch.no_grad():
        hidden = model.encoder(**encoded).last_hidden_state
    return hidden
def embedding_to_text(
    embedding_tensor,
    max_length=100,
    num_beams=5,
    repetition_penalty=2.5,
    early_stopping=True,
):
    """Decode encoder hidden states back into text with the T5 decoder.

    Args:
        embedding_tensor: Encoder hidden states, e.g. the output of
            ``text_to_embedding`` or a blend of two such tensors. Only the
            first generated sequence is decoded and returned.
        max_length: Upper bound on the generated sequence length.
        num_beams: Beam-search width passed to ``model.generate``.
        repetition_penalty: Values > 1 discourage repeating tokens.
        early_stopping: Forwarded to ``model.generate`` to control when beam
            search stops.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens removed.

    The defaults reproduce the previously hard-coded generation settings, so
    existing callers see identical behavior.
    """
    # generate() normally runs the encoder itself; wrapping the pre-computed
    # (possibly mixed) states in a BaseModelOutput lets us bypass that step.
    encoder_outputs_wrapped = BaseModelOutput(last_hidden_state=embedding_tensor)
    with torch.no_grad():
        generated_ids = model.generate(
            encoder_outputs=encoder_outputs_wrapped,
            max_length=max_length,
            num_beams=num_beams,
            repetition_penalty=repetition_penalty,
            early_stopping=early_stopping,
        )
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
def run_mixing(text1, text2):
    """Blend two sentences 50/50 in encoder space and decode the result.

    This is ``run_weighted_mixing`` with a ratio of 0.5. Delegating keeps a
    single truncate-and-mix implementation for both UI tabs. With a ratio of
    0.5 the weighted formula reduces to ``v1 * 0.5 + v2 * 0.5``, i.e. the plain
    average of the two encodings.

    Args:
        text1: First sentence.
        text2: Second sentence.

    Returns:
        The decoded blend, or a prompt message when either input is missing,
        empty, or whitespace-only.
    """
    # Whitespace-only text would still be encoded into (nearly) no content
    # tokens, and the min-length truncation would then cut the other sentence
    # down to match, so treat it the same as empty input.
    if not (text1 and text1.strip()) or not (text2 and text2.strip()):
        return "Please enter two sentences."
    return run_weighted_mixing(text1, text2, 0.5)
def run_weighted_mixing(text1, text2, mix_ratio):
    """Linearly interpolate two sentences in encoder space and decode the result.

    Both sentences are encoded and then truncated to the shorter token count
    so the tensors line up position by position. They are then mixed as
    ``v1 * (1 - mix_ratio) + v2 * mix_ratio``.

    Args:
        text1: Start sentence (weight ``1 - mix_ratio``).
        text2: End sentence (weight ``mix_ratio``).
        mix_ratio: Interpolation weight. 0 keeps only the (truncated) start
            sentence and 1 keeps only the (truncated) end sentence. The UI
            slider limits it to [0, 1].

    Returns:
        The decoded blend, or a prompt message when either input is missing,
        empty, or whitespace-only.
    """
    # Whitespace-only text would be encoded into (nearly) no content tokens
    # and then shrink the other sentence via truncation; reject it like "".
    if not (text1 and text1.strip()) or not (text2 and text2.strip()):
        return "Please enter two sentences."
    v1 = text_to_embedding(text1)
    v2 = text_to_embedding(text2)
    # Position-wise mixing needs equal sequence lengths. NOTE: truncating to
    # the shorter length discards the tail of the longer sentence.
    min_len = min(v1.shape[1], v2.shape[1])
    v1 = v1[:, :min_len, :]
    v2 = v2[:, :min_len, :]
    v_mixed = (v1 * (1 - mix_ratio)) + (v2 * mix_ratio)
    return embedding_to_text(v_mixed)
# ==========================================
# 3. GRADIO UI
# ==========================================
# Two tabs: a fixed 50/50 blend (run_mixing) and a slider-controlled blend
# (run_weighted_mixing). Gradio lays components out in the order they are
# created inside the nested `with` blocks, so statement order defines layout.
with gr.Blocks(title="FLAN-T5 Latent Explorer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 FLAN-T5 Latent Space Mixer")
    gr.Markdown(f"Running `{model_name}`.")
    # TabItems must be placed inside a Tabs container.
    with gr.Tabs():
        # --- TAB 1: 50/50 MIX ---
        with gr.TabItem("1. Simple Mix (50/50)"):
            with gr.Row():
                # Left column: the two input sentences and the trigger button.
                with gr.Column():
                    t1_simple = gr.Textbox(label="Concept A", value="The King is powerful.")
                    t2_simple = gr.Textbox(label="Concept B", value="The woman is beautiful.")
                    btn_simple = gr.Button("Mix Vectors", variant="primary")
                # Right column: the decoded blend.
                with gr.Column():
                    out_simple = gr.Textbox(label="Result", lines=2)
            btn_simple.click(run_mixing, inputs=[t1_simple, t2_simple], outputs=out_simple)
        # --- TAB 2: WEIGHTED SLIDER ---
        with gr.TabItem("2. Weighted Morph (Slider)"):
            gr.Markdown("Slide between the two sentences to see how the meaning shifts.")
            with gr.Row():
                t1_morph = gr.Textbox(label="Start Sentence", value="The dog is happy.")
                t2_morph = gr.Textbox(label="End Sentence", value="The cat is angry.")
            # Passed as mix_ratio: 0.0 weights only the Start sentence, 1.0 only the End sentence.
            slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.1, label="Mixing Ratio (Left = Start, Right = End)")
            btn_morph = gr.Button("Morph", variant="primary")
            out_morph = gr.Textbox(label="Result")
            # Only the button triggers decoding; moving the slider alone does not.
            btn_morph.click(run_weighted_mixing, inputs=[t1_morph, t2_morph, slider], outputs=out_morph)

# Launch the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    demo.launch()