File size: 3,827 Bytes
3a82207
7ec13e1
004a76d
da60d06
08c1bd3
7ec13e1
509cc45
7ec13e1
509cc45
004a76d
da60d06
004a76d
 
 
63b82b4
7ec13e1
509cc45
7ec13e1
08c1bd3
da60d06
 
 
004a76d
7ec13e1
3a82207
da60d06
 
 
 
 
509cc45
004a76d
509cc45
7ec13e1
da60d06
509cc45
c9308b7
7ec13e1
509cc45
 
7ec13e1
 
509cc45
 
7ec13e1
004a76d
 
509cc45
 
004a76d
 
7ec13e1
c9308b7
509cc45
c9308b7
 
 
 
 
 
 
 
509cc45
c9308b7
 
 
509cc45
 
 
c9308b7
004a76d
 
509cc45
c9308b7
509cc45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ec13e1
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput

# ==========================================
# 1. SETUP
# ==========================================
# Hugging Face checkpoint identifier; the UI header below also displays it.
model_name = "google/flan-t5-large"

print(f"Loading {model_name}...")
tokenizer = T5Tokenizer.from_pretrained(model_name)
# nn.Module.eval() returns the module itself, so the load and the switch to
# inference mode can be expressed as a single assignment.
model = T5ForConditionalGeneration.from_pretrained(model_name).eval()

# ==========================================
# 2. LOGIC
# ==========================================

def text_to_embedding(text):
    """Encode *text* with the module-level T5 encoder.

    The text is tokenized into PyTorch tensors and passed through
    ``model.encoder`` with gradient tracking disabled.

    Returns:
        The encoder's ``last_hidden_state`` tensor. Callers in this file
        index it as ``[batch, tokens, features]``.
    """
    encoded = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        return model.encoder(**encoded).last_hidden_state

def embedding_to_text(embedding_tensor):
    """Decode an encoder hidden-state tensor back into a string.

    The tensor is wrapped in a ``BaseModelOutput`` and handed to
    ``model.generate`` as ``encoder_outputs``, so generation starts from
    the supplied hidden states instead of from token ids.

    Args:
        embedding_tensor: Hidden states to decode from, e.g. the value
            returned by ``text_to_embedding`` or a blend of such values.

    Returns:
        The first decoded sequence, with special tokens stripped.
    """
    # Beam-search settings, gathered in one place for readability.
    generation_options = {
        "max_length": 100,
        "num_beams": 5,
        "repetition_penalty": 2.5,
        "early_stopping": True,
    }
    wrapped_states = BaseModelOutput(last_hidden_state=embedding_tensor)
    with torch.no_grad():
        token_ids = model.generate(
            encoder_outputs=wrapped_states,
            **generation_options,
        )
    decoded = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0]

def run_mixing(text1, text2):
    """Blend two sentences 50/50 in encoder space and decode the result.

    Args:
        text1: First sentence.
        text2: Second sentence.

    Returns:
        The text decoded from the averaged hidden states, or a prompt
        message when either input is empty/falsy.
    """
    if not (text1 and text2):
        return "Please enter two sentences."

    emb_a, emb_b = (text_to_embedding(sentence) for sentence in (text1, text2))

    # The two encodings can differ in token count; keep only the leading
    # positions both of them have so the tensors can be added elementwise.
    shared_len = min(emb_a.size(1), emb_b.size(1))
    emb_a = emb_a[:, :shared_len]
    emb_b = emb_b[:, :shared_len]

    # Equal-weight average of the two encodings.
    blended = (emb_a + emb_b) / 2.0
    return embedding_to_text(blended)

def run_weighted_mixing(text1, text2, mix_ratio):
    """Interpolate between two sentences in encoder space and decode.

    Args:
        text1: Sentence that receives weight ``1 - mix_ratio``.
        text2: Sentence that receives weight ``mix_ratio``.
        mix_ratio: Interpolation weight; the UI slider supplies values
            between 0.0 and 1.0.

    Returns:
        The text decoded from the weighted blend, or a prompt message
        when either input is empty/falsy.
    """
    if not (text1 and text2):
        return "Please enter two sentences."

    start_emb, end_emb = (
        text_to_embedding(sentence) for sentence in (text1, text2)
    )

    # Trim both encodings to the shorter token count so they line up.
    shared_len = min(start_emb.size(1), end_emb.size(1))
    start_emb = start_emb[:, :shared_len]
    end_emb = end_emb[:, :shared_len]

    # Weighted average: ratio 0 keeps only text1, ratio 1 keeps only text2.
    start_weight = 1 - mix_ratio
    blended = (start_emb * start_weight) + (end_emb * mix_ratio)
    return embedding_to_text(blended)

# ==========================================
# 3. GRADIO UI (FIXED STRUCTURE)
# ==========================================

# Build the Gradio app. Components are placed in the layout in the order they
# are created inside the nested `with` contexts, so statement order here
# defines the on-screen structure.
with gr.Blocks(title="FLAN-T5 Latent Explorer", theme=gr.themes.Soft()) as demo:
    # Page header: title plus the name of the checkpoint in use.
    gr.Markdown("# 🧠 FLAN-T5 Latent Space Mixer")
    gr.Markdown(f"Running `{model_name}`.")

    # TabItems must be created inside a Tabs container.
    with gr.Tabs():

        # --- TAB 1: 50/50 MIX ---
        # Two input sentences are averaged with equal weight via run_mixing.
        with gr.TabItem("1. Simple Mix (50/50)"):
            with gr.Row():
                # First column: the two inputs and the trigger button.
                with gr.Column():
                    t1_simple = gr.Textbox(label="Concept A", value="The King is powerful.")
                    t2_simple = gr.Textbox(label="Concept B", value="The woman is beautiful.")
                    btn_simple = gr.Button("Mix Vectors", variant="primary")
                # Second column: the decoded result.
                with gr.Column():
                    out_simple = gr.Textbox(label="Result", lines=2)

            # On click, pass both textbox values to run_mixing and show its
            # return value in the result box.
            btn_simple.click(run_mixing, inputs=[t1_simple, t2_simple], outputs=out_simple)

        # --- TAB 2: WEIGHTED SLIDER ---
        # Same idea, but the blend weight comes from a slider.
        with gr.TabItem("2. Weighted Morph (Slider)"):
            gr.Markdown("Slide between the two sentences to see how the meaning shifts.")
            with gr.Row():
                t1_morph = gr.Textbox(label="Start Sentence", value="The dog is happy.")
                t2_morph = gr.Textbox(label="End Sentence", value="The cat is angry.")

            # Slider value is forwarded as mix_ratio: 0.0 weights only the
            # start sentence, 1.0 weights only the end sentence.
            slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.1, label="Mixing Ratio (Left = Start, Right = End)")

            btn_morph = gr.Button("Morph", variant="primary")
            out_morph = gr.Textbox(label="Result")

            # On click, pass both sentences and the slider value to
            # run_weighted_mixing and show its return value.
            btn_morph.click(run_weighted_mixing, inputs=[t1_morph, t2_morph, slider], outputs=out_morph)

# Start the Gradio app only when this file is executed as a script, so that
# importing the module does not launch a server.
if __name__ == "__main__":
    demo.launch()