dwmk commited on
Commit
4a2e309
·
verified ·
1 Parent(s): 24be874

new ver - experimental text reply generation added

Browse files
Files changed (1) hide show
  1. app.py +157 -50
app.py CHANGED
@@ -12,6 +12,7 @@ from sklearn.linear_model import LogisticRegression
12
  from sklearn.preprocessing import LabelEncoder
13
  import kagglehub
14
  import warnings
 
15
 
16
  # Suppress sklearn warnings for cleaner logs
17
  warnings.filterwarnings("ignore")
@@ -23,30 +24,53 @@ class EpisodicMemory:
23
  def __init__(self, capacity=2000):
24
  self.memory_x = []
25
  self.memory_y = []
 
26
  self.capacity = capacity
27
 
28
- def store(self, x, y):
29
  # Store on CPU to save GPU VRAM
30
  curr_x = x.detach().cpu()
31
  curr_y = y.detach().cpu()
32
- for i in range(curr_x.size(0)):
 
 
 
 
 
 
 
 
 
 
33
  if len(self.memory_x) >= self.capacity:
34
  self.memory_x.pop(0)
35
  self.memory_y.pop(0)
 
 
36
  self.memory_x.append(curr_x[i])
37
  self.memory_y.append(curr_y[i])
 
 
 
38
 
39
  def retrieve(self, query_x, k=5):
40
  if not self.memory_x:
41
- return None
42
  mem_tensor = torch.stack(self.memory_x).to(query_x.device)
43
  distances = torch.cdist(query_x, mem_tensor)
44
- top_k_indices = torch.topk(distances, k, largest=False).indices
 
45
 
46
  # Gather labels
47
  retrieved_y = [torch.stack([self.memory_y[idx] for idx in sample_indices])
48
- for sample_indices in top_k_indices]
49
- return torch.stack(retrieved_y).to(query_x.device)
 
 
 
 
 
 
50
 
51
  class H3MOS(nn.Module):
52
  def __init__(self, input_dim, hidden_dim, output_dim):
@@ -71,12 +95,12 @@ class H3MOS(nn.Module):
71
 
72
  # Fast Path (Training or Empty Memory)
73
  if training_mode or len(self.hippocampus.memory_x) < 10:
74
- return raw_logits
75
 
76
  # Memory Retrieval & Integration
77
- past_labels = self.hippocampus.retrieve(x, k=5)
78
  if past_labels is None:
79
- return raw_logits
80
 
81
  mem_votes = torch.zeros_like(raw_logits)
82
  for i in range(x.size(0)):
@@ -86,7 +110,9 @@ class H3MOS(nn.Module):
86
  mem_probs = F.softmax(mem_votes, dim=1)
87
 
88
  # Dynamic Gating: 80% Neural, 20% Memory
89
- return (0.8 * raw_logits) + (0.2 * mem_probs * 5.0)
 
 
90
 
91
  # --- 2. DATA SETUP & TRAINING PIPELINE ---
92
 
@@ -100,7 +126,7 @@ try:
100
  except Exception as e:
101
  print("Error loading data:", e)
102
  # Fallback dummy data if kaggle fails (for testing)
103
- df = pd.DataFrame({'content': ['test'], 'emoji': ['👍']})
104
 
105
  # Mappings
106
  sent_map = {'❤️':'Positive', '👍':'Positive', '😂':'Positive', '💯':'Positive', '😢':'Negative', '😭':'Negative', '😮':'Neutral'}
@@ -111,13 +137,24 @@ tfidf = TfidfVectorizer(max_features=600, stop_words='english')
111
  X_sparse = tfidf.fit_transform(df['content'])
112
  X_dense = torch.FloatTensor(X_sparse.toarray()).to(device)
113
 
 
 
 
 
 
 
 
 
 
 
 
114
  # Model Zoo Containers
115
  tasks = ['emoji', 'sentiment', 'intent']
116
  model_names = ['DISTIL', 'RandomForest', 'SVM', 'NaiveBayes', 'LogReg', 'GradBoost']
117
  zoo = {task: {} for task in tasks}
118
  encoders = {}
119
 
120
- print("🧠 Training Models... (This may take a moment)")
121
 
122
  for task in tasks:
123
  # Prepare Labels
@@ -140,18 +177,23 @@ for task in tasks:
140
  optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
141
 
142
  model.train()
143
- # Short training loop for demo speed
144
  for epoch in range(25):
145
  optimizer.zero_grad()
146
- out = model(X_dense, training_mode=True)
147
  loss = F.cross_entropy(out, y_tensor)
148
  loss.backward()
149
  optimizer.step()
150
- # Populate memory occasionally
 
 
151
  if epoch % 5 == 0:
152
  with torch.no_grad():
153
- idx = torch.randperm(X_dense.size(0))[:50]
154
- model.hippocampus.store(X_dense[idx], y_tensor[idx])
 
 
 
155
 
156
  model.eval()
157
  zoo[task]['DISTIL'] = model
@@ -165,7 +207,24 @@ for task in tasks:
165
 
166
  print("✅ Training Complete.")
167
 
168
- # --- 3. INFERENCE LOGIC ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
  def get_predictions(text):
171
  """Runs all models on the text."""
@@ -174,13 +233,41 @@ def get_predictions(text):
174
 
175
  results = {name: {} for name in model_names}
176
 
177
- for task in tasks:
178
- le = encoders[task]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
 
 
 
 
 
 
 
 
 
 
 
180
  for name in model_names:
181
  if name == 'DISTIL':
182
  with torch.no_grad():
183
- logits = zoo[task][name](vec_t)
184
  pred_idx = torch.argmax(logits, dim=1).item()
185
  pred_label = le.inverse_transform([pred_idx])[0]
186
  else:
@@ -193,9 +280,6 @@ def get_predictions(text):
193
 
194
  # --- 4. UI STYLING & INTERFACE ---
195
 
196
- def get_avatar_url(seed):
197
- return f"https://api.dicebear.com/7.x/bottts/svg?seed={seed}&backgroundColor=transparent&size=128"
198
-
199
  CSS = """
200
  .chat-window { font-family: 'Segoe UI', sans-serif; }
201
 
@@ -221,7 +305,8 @@ CSS = """
221
 
222
  .model-card {
223
  background: white;
224
- min-width: 140px;
 
225
  border-radius: 12px;
226
  padding: 12px;
227
  box-shadow: 0 4px 12px rgba(0,0,0,0.08);
@@ -230,40 +315,60 @@ CSS = """
230
  align-items: center;
231
  border: 1px solid #eee;
232
  transition: transform 0.2s;
 
233
  }
234
- .model-card:hover { transform: translateY(-3px); }
235
 
236
  .card-name {
237
- font-size: 11px;
238
- font-weight: 700;
239
  text-transform: uppercase;
240
- color: #888;
241
  margin-bottom: 4px;
 
242
  }
243
 
244
  .card-emoji {
245
- font-size: 28px;
246
- margin: 4px 0;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  }
248
 
249
  .card-badge {
250
- font-size: 10px;
251
- padding: 2px 8px;
252
- border-radius: 10px;
253
- margin-top: 4px;
254
- font-weight: 600;
 
255
  }
256
 
257
- .bg-Pos { background-color: #e6fffa; color: #2c7a7b; }
258
- .bg-Neg { background-color: #fff5f5; color: #c53030; }
259
- .bg-Neu { background-color: #f7fafc; color: #4a5568; }
260
 
261
  .intent-row {
262
- font-size: 10px;
263
- color: #666;
264
- margin-top: 6px;
265
- border-top: 1px dashed #eee;
266
- padding-top: 4px;
267
  width: 100%;
268
  text-align: center;
269
  }
@@ -275,25 +380,26 @@ def chat_logic(message, history):
275
 
276
  preds = get_predictions(message)
277
 
278
- # 1. Create User Message HTML (with Emoji Reaction Bar)
279
- # Order: DISTIL, RF, SVM, NB, LR, GB
280
- reaction_string = "".join([preds[m]['emoji'] for m in model_names])
 
281
 
282
  user_html = f"""
283
  <div>
284
  {message}
285
- <div class="user-reactions" title="Consensus: {reaction_string}">{reaction_string}</div>
286
  </div>
287
  """
288
  history.append({"role": "user", "content": user_html})
289
 
290
- # 2. Create Single Bot Reply HTML (Horizontal Scroll Cards)
291
  cards_html = '<div class="model-scroll-container">'
292
 
293
  for name in model_names:
294
  p = preds[name]
295
 
296
- # Color coding for sentiment
297
  sent_cls = "bg-Neu"
298
  if "Pos" in p['sentiment']: sent_cls = "bg-Pos"
299
  elif "Neg" in p['sentiment']: sent_cls = "bg-Neg"
@@ -302,6 +408,7 @@ def chat_logic(message, history):
302
  <div class="model-card">
303
  <div class="card-name">{name}</div>
304
  <div class="card-emoji">{p['emoji']}</div>
 
305
  <div class="card-badge {sent_cls}">{p['sentiment']}</div>
306
  <div class="intent-row">{p['intent']}</div>
307
  </div>
 
12
  from sklearn.preprocessing import LabelEncoder
13
  import kagglehub
14
  import warnings
15
+ import random
16
 
17
  # Suppress sklearn warnings for cleaner logs
18
  warnings.filterwarnings("ignore")
 
24
  def __init__(self, capacity=2000):
25
  self.memory_x = []
26
  self.memory_y = []
27
+ self.memory_text = [] # New: Store raw text for replies
28
  self.capacity = capacity
29
 
30
+ def store(self, x, y, text_content):
31
  # Store on CPU to save GPU VRAM
32
  curr_x = x.detach().cpu()
33
  curr_y = y.detach().cpu()
34
+
35
+ # Handle batch or single item
36
+ if len(curr_x.shape) > 1:
37
+ batch_size = curr_x.size(0)
38
+ else:
39
+ batch_size = 1
40
+ curr_x = curr_x.unsqueeze(0)
41
+ curr_y = curr_y.unsqueeze(0)
42
+ text_content = [text_content]
43
+
44
+ for i in range(batch_size):
45
  if len(self.memory_x) >= self.capacity:
46
  self.memory_x.pop(0)
47
  self.memory_y.pop(0)
48
+ self.memory_text.pop(0)
49
+
50
  self.memory_x.append(curr_x[i])
51
  self.memory_y.append(curr_y[i])
52
+ # Store corresponding text (handle potential index mismatch in loops)
53
+ txt = text_content[i] if isinstance(text_content, list) else text_content
54
+ self.memory_text.append(txt)
55
 
56
  def retrieve(self, query_x, k=5):
57
  if not self.memory_x:
58
+ return None, None
59
  mem_tensor = torch.stack(self.memory_x).to(query_x.device)
60
  distances = torch.cdist(query_x, mem_tensor)
61
+ top_k = torch.topk(distances, k, largest=False)
62
+ indices = top_k.indices
63
 
64
  # Gather labels
65
  retrieved_y = [torch.stack([self.memory_y[idx] for idx in sample_indices])
66
+ for sample_indices in indices]
67
+
68
+ # Gather text for the "Best Match" (closest neighbor)
69
+ # We take the nearest neighbor (index 0 of top k) for the reply
70
+ closest_indices = indices[:, 0].cpu().tolist()
71
+ retrieved_text = [self.memory_text[idx] for idx in closest_indices]
72
+
73
+ return torch.stack(retrieved_y).to(query_x.device), retrieved_text
74
 
75
  class H3MOS(nn.Module):
76
  def __init__(self, input_dim, hidden_dim, output_dim):
 
95
 
96
  # Fast Path (Training or Empty Memory)
97
  if training_mode or len(self.hippocampus.memory_x) < 10:
98
+ return raw_logits, None
99
 
100
  # Memory Retrieval & Integration
101
+ past_labels, retrieved_texts = self.hippocampus.retrieve(x, k=5)
102
  if past_labels is None:
103
+ return raw_logits, None
104
 
105
  mem_votes = torch.zeros_like(raw_logits)
106
  for i in range(x.size(0)):
 
110
  mem_probs = F.softmax(mem_votes, dim=1)
111
 
112
  # Dynamic Gating: 80% Neural, 20% Memory
113
+ final_logits = (0.8 * raw_logits) + (0.2 * mem_probs * 5.0)
114
+
115
+ return final_logits, retrieved_texts
116
 
117
  # --- 2. DATA SETUP & TRAINING PIPELINE ---
118
 
 
126
  except Exception as e:
127
  print("Error loading data:", e)
128
  # Fallback dummy data if kaggle fails (for testing)
129
+ df = pd.DataFrame({'content': ['test', 'good job', 'bad day'], 'emoji': ['👍', '❤️', '😭']})
130
 
131
  # Mappings
132
  sent_map = {'❤️':'Positive', '👍':'Positive', '😂':'Positive', '💯':'Positive', '😢':'Negative', '😭':'Negative', '😮':'Neutral'}
 
137
  X_sparse = tfidf.fit_transform(df['content'])
138
  X_dense = torch.FloatTensor(X_sparse.toarray()).to(device)
139
 
140
+ # Reply Bank Construction (For non-neural models)
141
+ # We organize valid "utterances" by their emoji label to simulate responses
142
+ reply_bank = {}
143
+ unique_emojis = df['emoji'].unique()
144
+ for emo in unique_emojis:
145
+ # Filter messages that resulted in this emoji
146
+ msgs = df[df['emoji'] == emo]['content'].tolist()
147
+ # Keep short, punchy replies
148
+ msgs = [m for m in msgs if len(m.split()) < 15]
149
+ reply_bank[emo] = msgs if msgs else ["Interesting."]
150
+
151
  # Model Zoo Containers
152
  tasks = ['emoji', 'sentiment', 'intent']
153
  model_names = ['DISTIL', 'RandomForest', 'SVM', 'NaiveBayes', 'LogReg', 'GradBoost']
154
  zoo = {task: {} for task in tasks}
155
  encoders = {}
156
 
157
+ print("🧠 Training Models & Encoding Memories... (This may take a moment)")
158
 
159
  for task in tasks:
160
  # Prepare Labels
 
177
  optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
178
 
179
  model.train()
180
+ # Short training loop
181
  for epoch in range(25):
182
  optimizer.zero_grad()
183
+ out, _ = model(X_dense, training_mode=True)
184
  loss = F.cross_entropy(out, y_tensor)
185
  loss.backward()
186
  optimizer.step()
187
+
188
+ # Populate memory: DISTIL learns by storing training examples
189
+ # We store 10% of data per epoch to build the "brain"
190
  if epoch % 5 == 0:
191
  with torch.no_grad():
192
+ # Random sample indices
193
+ idx = torch.randperm(X_dense.size(0))[:100]
194
+ # Store Vector + Label + Actual Text Content
195
+ batch_text = df.iloc[idx.cpu().numpy()]['content'].tolist()
196
+ model.hippocampus.store(X_dense[idx], y_tensor[idx], batch_text)
197
 
198
  model.eval()
199
  zoo[task]['DISTIL'] = model
 
207
 
208
  print("✅ Training Complete.")
209
 
210
+ # --- 3. INFERENCE & GENERATION LOGIC ---
211
+
212
+ def generate_reply(model_name, predicted_emoji, distil_retrieved_text=None):
213
+ """
214
+ Generates a text reply.
215
+ - DISTIL uses Associative Recall (nearest neighbor text).
216
+ - Others use Random Sampling from the Reply Bank based on their prediction.
217
+ """
218
+ try:
219
+ if model_name == 'DISTIL' and distil_retrieved_text:
220
+ # H3MOS echoes a memory that feels "associatively related"
221
+ return f"\"{distil_retrieved_text}\""
222
+
223
+ # Standard models pick a vibe-matched message from the dataset
224
+ candidates = reply_bank.get(predicted_emoji, ["I don't know what to say."])
225
+ return f"\"{random.choice(candidates)}\""
226
+ except:
227
+ return "..."
228
 
229
  def get_predictions(text):
230
  """Runs all models on the text."""
 
233
 
234
  results = {name: {} for name in model_names}
235
 
236
+ # 1. First, get Emoji predictions (Primary task for replies)
237
+ emoji_preds = {}
238
+ distil_text_memory = None
239
+
240
+ # Run Emoji Task first to determine the reply "Vibe"
241
+ task = 'emoji'
242
+ le = encoders[task]
243
+
244
+ for name in model_names:
245
+ if name == 'DISTIL':
246
+ with torch.no_grad():
247
+ logits, mem_texts = zoo[task][name](vec_t)
248
+ pred_idx = torch.argmax(logits, dim=1).item()
249
+ pred_label = le.inverse_transform([pred_idx])[0]
250
+ # Capture the memory text for DISTIL
251
+ if mem_texts: distil_text_memory = mem_texts[0]
252
+ else:
253
+ pred_idx = zoo[task][name].predict(vec_s)[0]
254
+ pred_label = le.inverse_transform([pred_idx])[0]
255
 
256
+ emoji_preds[name] = pred_label
257
+ results[name]['emoji'] = pred_label
258
+
259
+ # GENERATE TEXT REPLY
260
+ # We pass the memory text if it's DISTIL, otherwise None
261
+ mem_txt = distil_text_memory if name == 'DISTIL' else None
262
+ results[name]['reply'] = generate_reply(name, pred_label, mem_txt)
263
+
264
+ # 2. Run other tasks (Sentiment/Intent) just for labels
265
+ for task in ['sentiment', 'intent']:
266
+ le = encoders[task]
267
  for name in model_names:
268
  if name == 'DISTIL':
269
  with torch.no_grad():
270
+ logits, _ = zoo[task][name](vec_t)
271
  pred_idx = torch.argmax(logits, dim=1).item()
272
  pred_label = le.inverse_transform([pred_idx])[0]
273
  else:
 
280
 
281
  # --- 4. UI STYLING & INTERFACE ---
282
 
 
 
 
283
  CSS = """
284
  .chat-window { font-family: 'Segoe UI', sans-serif; }
285
 
 
305
 
306
  .model-card {
307
  background: white;
308
+ min-width: 160px; /* Wider to fit text */
309
+ max-width: 160px;
310
  border-radius: 12px;
311
  padding: 12px;
312
  box-shadow: 0 4px 12px rgba(0,0,0,0.08);
 
315
  align-items: center;
316
  border: 1px solid #eee;
317
  transition: transform 0.2s;
318
+ position: relative;
319
  }
320
+ .model-card:hover { transform: translateY(-3px); border-color: #cbd5e0; }
321
 
322
  .card-name {
323
+ font-size: 10px;
324
+ font-weight: 800;
325
  text-transform: uppercase;
326
+ color: #a0aec0;
327
  margin-bottom: 4px;
328
+ letter-spacing: 1px;
329
  }
330
 
331
  .card-emoji {
332
+ font-size: 32px;
333
+ margin: 2px 0;
334
+ line-height: 1;
335
+ }
336
+
337
+ /* The generated reply bubble */
338
+ .card-reply {
339
+ font-size: 11px;
340
+ color: #2d3748;
341
+ background: #edf2f7;
342
+ padding: 6px 8px;
343
+ border-radius: 8px;
344
+ margin: 8px 0;
345
+ text-align: center;
346
+ font-style: italic;
347
+ min-height: 40px;
348
+ display: flex;
349
+ align-items: center;
350
+ justify-content: center;
351
+ line-height: 1.2;
352
+ width: 100%;
353
  }
354
 
355
  .card-badge {
356
+ font-size: 9px;
357
+ padding: 2px 6px;
358
+ border-radius: 4px;
359
+ margin-top: auto; /* Push to bottom */
360
+ font-weight: 700;
361
+ text-transform: uppercase;
362
  }
363
 
364
+ .bg-Pos { background-color: #c6f6d5; color: #22543d; }
365
+ .bg-Neg { background-color: #fed7d7; color: #742a2a; }
366
+ .bg-Neu { background-color: #e2e8f0; color: #4a5568; }
367
 
368
  .intent-row {
369
+ font-size: 9px;
370
+ color: #718096;
371
+ margin-top: 4px;
 
 
372
  width: 100%;
373
  text-align: center;
374
  }
 
380
 
381
  preds = get_predictions(message)
382
 
383
+ # 1. Create User Message HTML (with Emoji Consensus)
384
+ # Simple majority voting for the "Consensus" bar
385
+ emojis = [preds[m]['emoji'] for m in model_names]
386
+ reaction_string = "".join(emojis)
387
 
388
  user_html = f"""
389
  <div>
390
  {message}
391
+ <div class="user-reactions" title="Consensus">{reaction_string}</div>
392
  </div>
393
  """
394
  history.append({"role": "user", "content": user_html})
395
 
396
+ # 2. Create Scrollable Bot Reply HTML
397
  cards_html = '<div class="model-scroll-container">'
398
 
399
  for name in model_names:
400
  p = preds[name]
401
 
402
+ # Color coding
403
  sent_cls = "bg-Neu"
404
  if "Pos" in p['sentiment']: sent_cls = "bg-Pos"
405
  elif "Neg" in p['sentiment']: sent_cls = "bg-Neg"
 
408
  <div class="model-card">
409
  <div class="card-name">{name}</div>
410
  <div class="card-emoji">{p['emoji']}</div>
411
+ <div class="card-reply">{p['reply']}</div>
412
  <div class="card-badge {sent_cls}">{p['sentiment']}</div>
413
  <div class="intent-row">{p['intent']}</div>
414
  </div>