everydaytok committed on
Commit
04701d7
·
verified ·
1 Parent(s): 7ec13e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -80
app.py CHANGED
@@ -4,123 +4,79 @@ from transformers import BartTokenizer, BartForConditionalGeneration
4
  from transformers.modeling_outputs import BaseModelOutput
5
 
6
  # ==========================================
7
- # 1. SETUP: Load Model (Global Scope)
8
  # ==========================================
9
- model_name = "facebook/bart-base"
 
 
10
  print(f"Loading {model_name}...")
11
  tokenizer = BartTokenizer.from_pretrained(model_name)
12
  model = BartForConditionalGeneration.from_pretrained(model_name)
13
- model.eval() # Set to evaluation mode
14
 
15
  # ==========================================
16
- # 2. CORE LOGIC FUNCTIONS
17
  # ==========================================
18
 
19
  def text_to_embedding(text):
20
- """Encodes text into the BART Latent Space (Vectors)."""
21
  inputs = tokenizer(text, return_tensors="pt")
22
  with torch.no_grad():
23
  encoder_outputs = model.model.encoder(**inputs)
24
  return encoder_outputs.last_hidden_state
25
 
26
  def embedding_to_text(embedding_tensor):
27
- """Decodes a Vector back into Text."""
28
  encoder_outputs_wrapped = BaseModelOutput(last_hidden_state=embedding_tensor)
 
29
  with torch.no_grad():
30
  generated_ids = model.generate(
31
  encoder_outputs=encoder_outputs_wrapped,
 
 
32
  max_length=50,
33
- num_beams=4,
34
- early_stopping=True
 
 
35
  )
 
36
  decoded_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
37
  return decoded_text
38
 
39
- # ==========================================
40
- # 3. GRADIO INTERFACE FUNCTIONS
41
- # ==========================================
42
-
43
- def run_reconstruction(text):
44
- if not text:
45
- return "", "Please enter text."
46
-
47
- # 1. Encode
48
- vector = text_to_embedding(text)
49
-
50
- # 2. Decode
51
- reconstructed = embedding_to_text(vector)
52
-
53
- # 3. Get Stats
54
- shape_info = f"Vector Shape: {vector.shape} (Batch, Tokens, Dimensions)"
55
- preview = f"First 5 values: {vector[0][0][:5].numpy().tolist()}"
56
-
57
- debug_info = f"{shape_info}\n{preview}"
58
-
59
- return reconstructed, debug_info
60
-
61
- def run_mixing(text1, text2):
62
- if not text1 or not text2:
63
- return "Please enter two sentences."
64
 
65
- # 1. Get vectors
66
  v1 = text_to_embedding(text1)
67
  v2 = text_to_embedding(text2)
68
 
69
- # 2. Align lengths (Truncate to minimum length)
70
- # Note: In a production app, you might want to pad instead of truncate,
71
- # but for this specific "averaging" demo, truncation prevents dimension mismatch errors.
72
  min_len = min(v1.shape[1], v2.shape[1])
73
-
74
- v1_cut = v1[:, :min_len, :]
75
- v2_cut = v2[:, :min_len, :]
76
 
77
- # 3. Math: Average the vectors
78
- v_mixed = (v1_cut + v2_cut) / 2.0
79
-
80
- # 4. Decode
81
- mixed_text = embedding_to_text(v_mixed)
82
 
83
- return mixed_text
84
 
85
  # ==========================================
86
- # 4. BUILD UI
87
  # ==========================================
88
 
89
- with gr.Blocks(title="BART Latent Space Explorer", theme=gr.themes.Soft()) as demo:
90
- gr.Markdown("# 🧠 BART Latent Space Explorer")
91
- gr.Markdown("This tool uses `facebook/bart-base` to convert text into mathematical vectors (Embeddings) and back.")
92
-
93
- with gr.Tabs():
94
-
95
- # --- TAB 1: RECONSTRUCTION ---
96
- with gr.TabItem("1. Auto-Encoder Test"):
97
- gr.Markdown("Type a sentence. The model will turn it into numbers, then turn those numbers back into text.")
98
-
99
- with gr.Row():
100
- with gr.Column():
101
- input_text = gr.Textbox(label="Original Sentence", value="The cat sat on the mat.")
102
- btn_recon = gr.Button("Encode & Decode", variant="primary")
103
-
104
- with gr.Column():
105
- output_recon = gr.Textbox(label="Reconstructed Text")
106
- output_debug = gr.Code(label="Vector Stats", language="json")
107
-
108
- btn_recon.click(run_reconstruction, inputs=input_text, outputs=[output_recon, output_debug])
109
 
110
- # --- TAB 2: VECTOR MIXING ---
111
- with gr.TabItem("2. Vector Mixing (Math)"):
112
- gr.Markdown("Type two different sentences. We will average their mathematical representations. Results may be surreal!")
113
-
114
- with gr.Row():
115
- with gr.Column():
116
- mix_in_1 = gr.Textbox(label="Sentence A", value="The weather is sunny.")
117
- mix_in_2 = gr.Textbox(label="Sentence B", value="The weather is rainy.")
118
- btn_mix = gr.Button("Calculate Average Meaning", variant="primary")
119
-
120
- with gr.Column():
121
- mix_out = gr.Textbox(label="The AI's 'Middle Ground' Thought", lines=4)
122
 
123
- btn_mix.click(run_mixing, inputs=[mix_in_1, mix_in_2], outputs=mix_out)
124
 
125
  if __name__ == "__main__":
126
  demo.launch()
 
4
  from transformers.modeling_outputs import BaseModelOutput
5
 
6
  # ==========================================
7
+ # 1. SETUP: Use BART-Large (The Best "Parrot")
8
  # ==========================================
9
+ # BART is an auto-encoder. Its job is to reconstruct inputs, not chat.
10
model_name = "facebook/bart-large"

print(f"Loading {model_name}...")

# Tokenizer and model live at module level so the helper functions in this
# file can reuse a single loaded copy.
tokenizer = BartTokenizer.from_pretrained(model_name)

# nn.Module.eval() returns the module itself, so loading and switching to
# evaluation mode can be chained into one assignment.
model = BartForConditionalGeneration.from_pretrained(model_name).eval()
16
 
17
  # ==========================================
18
+ # 2. STRICT LOGIC
19
  # ==========================================
20
 
21
  def text_to_embedding(text):
 
22
  inputs = tokenizer(text, return_tensors="pt")
23
  with torch.no_grad():
24
  encoder_outputs = model.model.encoder(**inputs)
25
  return encoder_outputs.last_hidden_state
26
 
27
  def embedding_to_text(embedding_tensor):
 
28
  encoder_outputs_wrapped = BaseModelOutput(last_hidden_state=embedding_tensor)
29
+
30
  with torch.no_grad():
31
  generated_ids = model.generate(
32
  encoder_outputs=encoder_outputs_wrapped,
33
+
34
+ # --- STRICT TRANSCRIPTION SETTINGS ---
35
  max_length=50,
36
+ num_beams=1, # Greedy Search (No creative exploring)
37
+ do_sample=False, # Deterministic (No randomness)
38
+ temperature=1.0, # Standard probability curve
39
+ repetition_penalty=1.0 # Don't punish repeating words (we want exact copies)
40
  )
41
+
42
  decoded_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
43
  return decoded_text
44
 
45
def run_weighted_mixing(text1, text2, mix_ratio):
    """Blend the encoder vectors of two sentences and decode the result.

    Args:
        text1: The "start" sentence (weight ``1 - mix_ratio``).
        text2: The "end" sentence (weight ``mix_ratio``).
        mix_ratio: Blend weight; 0 selects ``text1`` only, 1 selects
            ``text2`` only.

    Returns:
        The decoded text of the blended vector, or the literal
        ``"Enter sentences."`` when either input is empty.
    """
    if not text1 or not text2:
        return "Enter sentences."

    # At the exact endpoints only one sentence contributes, so decode it at
    # its full length. Truncating it to the other sentence's length would
    # break the "0 = Start, 1 = End" contract whenever the other sentence is
    # shorter.
    if mix_ratio == 0:
        return embedding_to_text(text_to_embedding(text1))
    if mix_ratio == 1:
        return embedding_to_text(text_to_embedding(text2))

    v1 = text_to_embedding(text1)
    v2 = text_to_embedding(text2)

    # The two sequences must have the same token length to be combined
    # element-wise, so both are cut to the shorter one.
    min_len = min(v1.shape[1], v2.shape[1])
    v1 = v1[:, :min_len, :]
    v2 = v2[:, :min_len, :]

    # Weighted average of the two aligned vectors.
    v_mixed = (v1 * (1 - mix_ratio)) + (v2 * mix_ratio)

    return embedding_to_text(v_mixed)
60
 
61
  # ==========================================
62
+ # 3. UI
63
  # ==========================================
64
 
65
with gr.Blocks(title="BART-Large Vector Decoder", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🦜 BART Strict Vector Decoder")
    gr.Markdown("This version uses `bart-large` with **Greedy Search** to force direct transcription instead of creative generation.")

    # The two sentences to blend sit side by side.
    with gr.Row():
        start_box = gr.Textbox(label="Start Sentence", value="The dog is happy.")
        end_box = gr.Textbox(label="End Sentence", value="The cat is angry.")

    # Ratio 0.0 weights the start sentence fully; 1.0 weights the end
    # sentence fully.
    ratio_slider = gr.Slider(
        minimum=0.0,
        maximum=1.0,
        value=0.0,
        step=0.1,
        label="Ratio (0 = Start, 1 = End)",
    )

    decode_button = gr.Button("Decode Vector")
    decoded_box = gr.Textbox(label="Decoded Text")

    # Clicking the button blends both sentences at the chosen ratio and
    # shows the decoded text.
    decode_button.click(
        fn=run_weighted_mixing,
        inputs=[start_box, end_box, ratio_slider],
        outputs=decoded_box,
    )

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()