everydaytok committed on
Commit 7ec13e1 · verified · 1 Parent(s): da60d06

Update app.py

Files changed (1)
  1. app.py +92 -65
app.py CHANGED
@@ -1,99 +1,126 @@
 import torch
 from transformers import BartTokenizer, BartForConditionalGeneration
 from transformers.modeling_outputs import BaseModelOutput

-# 1. Load the Pre-trained Model and Tokenizer
 model_name = "facebook/bart-base"
 print(f"Loading {model_name}...")
 tokenizer = BartTokenizer.from_pretrained(model_name)
 model = BartForConditionalGeneration.from_pretrained(model_name)

-# Ensure model is in eval mode (turns off dropout for consistent results)
-model.eval()

-# --- FUNCTION 1: ENCODE (Text -> Embedding) ---
 def text_to_embedding(text):
-    print(f"\n--- Encoding: '{text}' ---")
-
-    # Tokenize input
     inputs = tokenizer(text, return_tensors="pt")
-
-    # Run ONLY the Encoder part of BART
-    # We access the internal 'model' and then its 'encoder'
     with torch.no_grad():
         encoder_outputs = model.model.encoder(**inputs)
-
-    # This is the "Embedding": A tensor of shape (Batch_Size, Seq_Length, 768)
-    embedding = encoder_outputs.last_hidden_state
-
-    print(f"Generated Vector Shape: {embedding.shape}")
-    # Shape explanation: [1, 8, 768] means 1 sentence, 8 tokens long, 768 dimensions per token
-    return embedding

-# --- FUNCTION 2: DECODE (Embedding -> Text) ---
 def embedding_to_text(embedding_tensor):
-    print("--- Decoding Vector back to Text ---")
-
-    # We must wrap the tensor in a specific class so the Generator understands it
-    # The generator expects an object that has a .last_hidden_state attribute
     encoder_outputs_wrapped = BaseModelOutput(last_hidden_state=embedding_tensor)
-
-    # Run the Generator
-    # We tell it: "Don't encode anything new, use these 'encoder_outputs' I gave you."
     with torch.no_grad():
         generated_ids = model.generate(
             encoder_outputs=encoder_outputs_wrapped,
-            max_length=20,
-            num_beams=4  # Use beam search for better quality
         )
-
-    # Decode the result IDs back to strings
     decoded_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
     return decoded_text

 # ==========================================
-# TEST RUN
 # ==========================================

-# 1. Original Text
-original_sentence = "The cat sat on the mat."

-# 2. Convert to Vector
-vector_representation = text_to_embedding(original_sentence)

-# 3. (Optional) Simulate "Math" or "Transmission"
-# Let's verify the vectors are real numbers by printing a tiny slice
-print(f"First 5 values of vector: {vector_representation[0][0][:5].numpy()}")

-# 4. Convert back to Text
-reconstructed_text = embedding_to_text(vector_representation)

-print(f"\nOriginal: {original_sentence}")
-print(f"Reconstructed: {reconstructed_text}")

 # ==========================================
-# EXPERIMENT: MIXING VECTORS
-# Let's try to 'average' two sentences and see what BART dreams up
 # ==========================================
-print("\n--- The Mixing Experiment ---")
-s1 = "The weather is sunny."
-s2 = "The weather is rainy."
-
-# Get vectors
-v1 = text_to_embedding(s1)
-v2 = text_to_embedding(s2)
-
-# To average them, they must be the same length (padding is usually handled by tokenizer,
-# but here we'll just cut to the minimum length for the demo hack)
-min_len = min(v1.shape[1], v2.shape[1])
-v1 = v1[:, :min_len, :]
-v2 = v2[:, :min_len, :]
-
-# Calculate the mean vector
-v_mixed = (v1 + v2) / 2.0
-
-# Decode the mixed thought
-mixed_text = embedding_to_text(v_mixed)
-print(f"Sentence A: {s1}")
-print(f"Sentence B: {s2}")
-print(f"Mixed Result: {mixed_text}")

 import torch
+import gradio as gr
 from transformers import BartTokenizer, BartForConditionalGeneration
 from transformers.modeling_outputs import BaseModelOutput

+# ==========================================
+# 1. SETUP: Load Model (Global Scope)
+# ==========================================
 model_name = "facebook/bart-base"
 print(f"Loading {model_name}...")
 tokenizer = BartTokenizer.from_pretrained(model_name)
 model = BartForConditionalGeneration.from_pretrained(model_name)
+model.eval()  # Set to evaluation mode

+# ==========================================
+# 2. CORE LOGIC FUNCTIONS
+# ==========================================

 def text_to_embedding(text):
+    """Encodes text into the BART Latent Space (Vectors)."""
     inputs = tokenizer(text, return_tensors="pt")
     with torch.no_grad():
         encoder_outputs = model.model.encoder(**inputs)
+    return encoder_outputs.last_hidden_state

 def embedding_to_text(embedding_tensor):
+    """Decodes a Vector back into Text."""
     encoder_outputs_wrapped = BaseModelOutput(last_hidden_state=embedding_tensor)
     with torch.no_grad():
         generated_ids = model.generate(
             encoder_outputs=encoder_outputs_wrapped,
+            max_length=50,
+            num_beams=4,
+            early_stopping=True
         )
     decoded_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
     return decoded_text

 # ==========================================
+# 3. GRADIO INTERFACE FUNCTIONS
 # ==========================================

+def run_reconstruction(text):
+    if not text:
+        return "", "Please enter text."
+
+    # 1. Encode
+    vector = text_to_embedding(text)
+
+    # 2. Decode
+    reconstructed = embedding_to_text(vector)
+
+    # 3. Get Stats
+    shape_info = f"Vector Shape: {vector.shape} (Batch, Tokens, Dimensions)"
+    preview = f"First 5 values: {vector[0][0][:5].numpy().tolist()}"
+
+    debug_info = f"{shape_info}\n{preview}"
+
+    return reconstructed, debug_info
+
+def run_mixing(text1, text2):
+    if not text1 or not text2:
+        return "Please enter two sentences."

+    # 1. Get vectors
+    v1 = text_to_embedding(text1)
+    v2 = text_to_embedding(text2)

+    # 2. Align lengths (Truncate to minimum length)
+    # Note: In a production app, you might want to pad instead of truncate,
+    # but for this specific "averaging" demo, truncation prevents dimension mismatch errors.
+    min_len = min(v1.shape[1], v2.shape[1])
+
+    v1_cut = v1[:, :min_len, :]
+    v2_cut = v2[:, :min_len, :]

+    # 3. Math: Average the vectors
+    v_mixed = (v1_cut + v2_cut) / 2.0

+    # 4. Decode
+    mixed_text = embedding_to_text(v_mixed)
+
+    return mixed_text

 # ==========================================
+# 4. BUILD UI
 # ==========================================
+
+with gr.Blocks(title="BART Latent Space Explorer", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# 🧠 BART Latent Space Explorer")
+    gr.Markdown("This tool uses `facebook/bart-base` to convert text into mathematical vectors (Embeddings) and back.")
+
+    with gr.Tabs():
+
+        # --- TAB 1: RECONSTRUCTION ---
+        with gr.TabItem("1. Auto-Encoder Test"):
+            gr.Markdown("Type a sentence. The model will turn it into numbers, then turn those numbers back into text.")
+
+            with gr.Row():
+                with gr.Column():
+                    input_text = gr.Textbox(label="Original Sentence", value="The cat sat on the mat.")
+                    btn_recon = gr.Button("Encode & Decode", variant="primary")
+
+                with gr.Column():
+                    output_recon = gr.Textbox(label="Reconstructed Text")
+                    output_debug = gr.Code(label="Vector Stats", language="json")
+
+            btn_recon.click(run_reconstruction, inputs=input_text, outputs=[output_recon, output_debug])
+
+        # --- TAB 2: VECTOR MIXING ---
+        with gr.TabItem("2. Vector Mixing (Math)"):
+            gr.Markdown("Type two different sentences. We will average their mathematical representations. Results may be surreal!")
+
+            with gr.Row():
+                with gr.Column():
+                    mix_in_1 = gr.Textbox(label="Sentence A", value="The weather is sunny.")
+                    mix_in_2 = gr.Textbox(label="Sentence B", value="The weather is rainy.")
+                    btn_mix = gr.Button("Calculate Average Meaning", variant="primary")
+
+                with gr.Column():
+                    mix_out = gr.Textbox(label="The AI's 'Middle Ground' Thought", lines=4)
+
+            btn_mix.click(run_mixing, inputs=[mix_in_1, mix_in_2], outputs=mix_out)
+
+if __name__ == "__main__":
+    demo.launch()
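
For anyone who wants to poke at the latent space without launching the Gradio UI, here is a minimal sketch that reuses the two core helpers from the new app.py and generalizes the 50/50 average in run_mixing to an interpolation weight alpha. The `from app import ...` line and the alpha sweep are illustrative assumptions for a local experiment, not part of this commit.

# Minimal sketch (not part of the commit): interpolate between two sentences
# in BART's latent space using the helpers defined in app.py above.
from app import text_to_embedding, embedding_to_text  # assumes app.py is importable as a module

s1 = "The weather is sunny."
s2 = "The weather is rainy."

v1 = text_to_embedding(s1)
v2 = text_to_embedding(s2)

# Truncate to a common token length so the tensors can be mixed elementwise,
# mirroring the truncation done in run_mixing().
min_len = min(v1.shape[1], v2.shape[1])
v1, v2 = v1[:, :min_len, :], v2[:, :min_len, :]

for alpha in (0.0, 0.25, 0.5, 0.75, 1.0):
    # alpha=0.0 uses only sentence A's (truncated) vector, alpha=1.0 only sentence B's,
    # and alpha=0.5 matches the "average meaning" computed by the app.
    v_mixed = (1.0 - alpha) * v1 + alpha * v2
    print(f"alpha={alpha}: {embedding_to_text(v_mixed)}")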