File size: 4,728 Bytes
3a82207
7ec13e1
da60d06
 
08c1bd3
7ec13e1
 
 
51fab22
da60d06
 
 
7ec13e1
63b82b4
7ec13e1
 
 
08c1bd3
da60d06
7ec13e1
da60d06
 
 
7ec13e1
3a82207
da60d06
7ec13e1
da60d06
 
 
 
7ec13e1
 
 
da60d06
 
 
3592f5f
da60d06
7ec13e1
da60d06
3a82207
7ec13e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a82207
7ec13e1
 
 
3a82207
7ec13e1
 
 
 
 
 
 
63b82b4
7ec13e1
 
da60d06
7ec13e1
 
 
 
da60d06
 
7ec13e1
da60d06
7ec13e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch
import gradio as gr
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput

# ==========================================
# 1. SETUP: Load Model (Global Scope)
# ==========================================
# Loaded once at module import so both Gradio callbacks share a single
# tokenizer/model pair. First run downloads weights from the Hugging Face
# hub (network access required).
model_name = "facebook/bart-large" # "facebook/bart-base"
print(f"Loading {model_name}...")
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)
model.eval() # Set to evaluation mode

# ==========================================
# 2. CORE LOGIC FUNCTIONS
# ==========================================

def text_to_embedding(text):
    """Run *text* through the BART encoder and return its latent vectors.

    Returns the encoder's last hidden state: a float tensor of shape
    (batch, tokens, dimensions).
    """
    token_batch = tokenizer(text, return_tensors="pt")
    # Inference only — no gradients needed for this demo.
    with torch.no_grad():
        enc_out = model.model.encoder(**token_batch)
    return enc_out.last_hidden_state

def embedding_to_text(embedding_tensor):
    """Generate text from a precomputed encoder hidden-state tensor.

    The tensor is wrapped in a BaseModelOutput so model.generate() treats it
    as if it came straight from the encoder, skipping re-encoding.
    """
    wrapped = BaseModelOutput(last_hidden_state=embedding_tensor)
    with torch.no_grad():
        ids = model.generate(
            encoder_outputs=wrapped,
            max_length=50,
            num_beams=4,
            early_stopping=True,
        )
    return tokenizer.batch_decode(ids, skip_special_tokens=True)[0]

# ==========================================
# 3. GRADIO INTERFACE FUNCTIONS
# ==========================================

def run_reconstruction(text):
    """Gradio callback: round-trip *text* through the latent space.

    Returns (reconstructed_text, debug_info) for the two output widgets.
    """
    # Guard clause for an empty textbox.
    if not text:
        return "", "Please enter text."

    latent = text_to_embedding(text)
    round_trip = embedding_to_text(latent)

    # Surface the tensor shape and a small sample of raw values for the UI.
    stats = [
        f"Vector Shape: {latent.shape} (Batch, Tokens, Dimensions)",
        f"First 5 values: {latent[0][0][:5].numpy().tolist()}",
    ]
    return round_trip, "\n".join(stats)

def run_mixing(text1, text2):
    """Gradio callback: decode the element-wise average of two sentence latents."""
    # Both inputs are required for the averaging to mean anything.
    if not (text1 and text2):
        return "Please enter two sentences."

    vec_a = text_to_embedding(text1)
    vec_b = text_to_embedding(text2)

    # The two sentences usually tokenize to different lengths; truncate both
    # to the shorter one so the element-wise math lines up. (Padding would
    # also work, but truncation keeps this averaging demo simple.)
    shared_len = min(vec_a.shape[1], vec_b.shape[1])
    averaged = (vec_a[:, :shared_len, :] + vec_b[:, :shared_len, :]) / 2.0

    return embedding_to_text(averaged)

# ==========================================
# 4. BUILD UI
# ==========================================

with gr.Blocks(title="BART Latent Space Explorer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 BART Latent Space Explorer")
    # BUG FIX: this line previously hard-coded `facebook/bart-base`, but the
    # model actually loaded above is `model_name` ("facebook/bart-large").
    # Interpolate the variable so the description always matches reality.
    gr.Markdown(f"This tool uses `{model_name}` to convert text into mathematical vectors (Embeddings) and back.")

    with gr.Tabs():

        # --- TAB 1: RECONSTRUCTION ---
        with gr.TabItem("1. Auto-Encoder Test"):
            gr.Markdown("Type a sentence. The model will turn it into numbers, then turn those numbers back into text.")

            with gr.Row():
                with gr.Column():
                    input_text = gr.Textbox(label="Original Sentence", value="The cat sat on the mat.")
                    btn_recon = gr.Button("Encode & Decode", variant="primary")

                with gr.Column():
                    output_recon = gr.Textbox(label="Reconstructed Text")
                    output_debug = gr.Code(label="Vector Stats", language="json")

            # Wire the button to the encode/decode round-trip callback.
            btn_recon.click(run_reconstruction, inputs=input_text, outputs=[output_recon, output_debug])

        # --- TAB 2: VECTOR MIXING ---
        with gr.TabItem("2. Vector Mixing (Math)"):
            gr.Markdown("Type two different sentences. We will average their mathematical representations. Results may be surreal!")

            with gr.Row():
                with gr.Column():
                    mix_in_1 = gr.Textbox(label="Sentence A", value="The weather is sunny.")
                    mix_in_2 = gr.Textbox(label="Sentence B", value="The weather is rainy.")
                    btn_mix = gr.Button("Calculate Average Meaning", variant="primary")

                with gr.Column():
                    mix_out = gr.Textbox(label="The AI's 'Middle Ground' Thought", lines=4)

            # Wire the button to the latent-averaging callback.
            btn_mix.click(run_mixing, inputs=[mix_in_1, mix_in_2], outputs=mix_out)

# Launch the web UI only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()