dwmk committed on
Commit
56f231a
·
verified ·
1 Parent(s): 4a2e309

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -157
app.py CHANGED
@@ -12,7 +12,6 @@ from sklearn.linear_model import LogisticRegression
12
  from sklearn.preprocessing import LabelEncoder
13
  import kagglehub
14
  import warnings
15
- import random
16
 
17
  # Suppress sklearn warnings for cleaner logs
18
  warnings.filterwarnings("ignore")
@@ -24,53 +23,30 @@ class EpisodicMemory:
24
  def __init__(self, capacity=2000):
25
  self.memory_x = []
26
  self.memory_y = []
27
- self.memory_text = [] # New: Store raw text for replies
28
  self.capacity = capacity
29
 
30
- def store(self, x, y, text_content):
31
  # Store on CPU to save GPU VRAM
32
  curr_x = x.detach().cpu()
33
  curr_y = y.detach().cpu()
34
-
35
- # Handle batch or single item
36
- if len(curr_x.shape) > 1:
37
- batch_size = curr_x.size(0)
38
- else:
39
- batch_size = 1
40
- curr_x = curr_x.unsqueeze(0)
41
- curr_y = curr_y.unsqueeze(0)
42
- text_content = [text_content]
43
-
44
- for i in range(batch_size):
45
  if len(self.memory_x) >= self.capacity:
46
  self.memory_x.pop(0)
47
  self.memory_y.pop(0)
48
- self.memory_text.pop(0)
49
-
50
  self.memory_x.append(curr_x[i])
51
  self.memory_y.append(curr_y[i])
52
- # Store corresponding text (handle potential index mismatch in loops)
53
- txt = text_content[i] if isinstance(text_content, list) else text_content
54
- self.memory_text.append(txt)
55
 
56
  def retrieve(self, query_x, k=5):
57
  if not self.memory_x:
58
- return None, None
59
  mem_tensor = torch.stack(self.memory_x).to(query_x.device)
60
  distances = torch.cdist(query_x, mem_tensor)
61
- top_k = torch.topk(distances, k, largest=False)
62
- indices = top_k.indices
63
 
64
  # Gather labels
65
  retrieved_y = [torch.stack([self.memory_y[idx] for idx in sample_indices])
66
- for sample_indices in indices]
67
-
68
- # Gather text for the "Best Match" (closest neighbor)
69
- # We take the nearest neighbor (index 0 of top k) for the reply
70
- closest_indices = indices[:, 0].cpu().tolist()
71
- retrieved_text = [self.memory_text[idx] for idx in closest_indices]
72
-
73
- return torch.stack(retrieved_y).to(query_x.device), retrieved_text
74
 
75
  class H3MOS(nn.Module):
76
  def __init__(self, input_dim, hidden_dim, output_dim):
@@ -95,12 +71,12 @@ class H3MOS(nn.Module):
95
 
96
  # Fast Path (Training or Empty Memory)
97
  if training_mode or len(self.hippocampus.memory_x) < 10:
98
- return raw_logits, None
99
 
100
  # Memory Retrieval & Integration
101
- past_labels, retrieved_texts = self.hippocampus.retrieve(x, k=5)
102
  if past_labels is None:
103
- return raw_logits, None
104
 
105
  mem_votes = torch.zeros_like(raw_logits)
106
  for i in range(x.size(0)):
@@ -110,9 +86,7 @@ class H3MOS(nn.Module):
110
  mem_probs = F.softmax(mem_votes, dim=1)
111
 
112
  # Dynamic Gating: 80% Neural, 20% Memory
113
- final_logits = (0.8 * raw_logits) + (0.2 * mem_probs * 5.0)
114
-
115
- return final_logits, retrieved_texts
116
 
117
  # --- 2. DATA SETUP & TRAINING PIPELINE ---
118
 
@@ -126,7 +100,7 @@ try:
126
  except Exception as e:
127
  print("Error loading data:", e)
128
  # Fallback dummy data if kaggle fails (for testing)
129
- df = pd.DataFrame({'content': ['test', 'good job', 'bad day'], 'emoji': ['👍', '❤️', '😭']})
130
 
131
  # Mappings
132
  sent_map = {'❤️':'Positive', '👍':'Positive', '😂':'Positive', '💯':'Positive', '😢':'Negative', '😭':'Negative', '😮':'Neutral'}
@@ -137,24 +111,13 @@ tfidf = TfidfVectorizer(max_features=600, stop_words='english')
137
  X_sparse = tfidf.fit_transform(df['content'])
138
  X_dense = torch.FloatTensor(X_sparse.toarray()).to(device)
139
 
140
- # Reply Bank Construction (For non-neural models)
141
- # We organize valid "utterances" by their emoji label to simulate responses
142
- reply_bank = {}
143
- unique_emojis = df['emoji'].unique()
144
- for emo in unique_emojis:
145
- # Filter messages that resulted in this emoji
146
- msgs = df[df['emoji'] == emo]['content'].tolist()
147
- # Keep short, punchy replies
148
- msgs = [m for m in msgs if len(m.split()) < 15]
149
- reply_bank[emo] = msgs if msgs else ["Interesting."]
150
-
151
  # Model Zoo Containers
152
  tasks = ['emoji', 'sentiment', 'intent']
153
  model_names = ['DISTIL', 'RandomForest', 'SVM', 'NaiveBayes', 'LogReg', 'GradBoost']
154
  zoo = {task: {} for task in tasks}
155
  encoders = {}
156
 
157
- print("🧠 Training Models & Encoding Memories... (This may take a moment)")
158
 
159
  for task in tasks:
160
  # Prepare Labels
@@ -177,23 +140,18 @@ for task in tasks:
177
  optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
178
 
179
  model.train()
180
- # Short training loop
181
  for epoch in range(25):
182
  optimizer.zero_grad()
183
- out, _ = model(X_dense, training_mode=True)
184
  loss = F.cross_entropy(out, y_tensor)
185
  loss.backward()
186
  optimizer.step()
187
-
188
- # Populate memory: DISTIL learns by storing training examples
189
- # We store 10% of data per epoch to build the "brain"
190
  if epoch % 5 == 0:
191
  with torch.no_grad():
192
- # Random sample indices
193
- idx = torch.randperm(X_dense.size(0))[:100]
194
- # Store Vector + Label + Actual Text Content
195
- batch_text = df.iloc[idx.cpu().numpy()]['content'].tolist()
196
- model.hippocampus.store(X_dense[idx], y_tensor[idx], batch_text)
197
 
198
  model.eval()
199
  zoo[task]['DISTIL'] = model
@@ -207,24 +165,7 @@ for task in tasks:
207
 
208
  print("✅ Training Complete.")
209
 
210
- # --- 3. INFERENCE & GENERATION LOGIC ---
211
-
212
- def generate_reply(model_name, predicted_emoji, distil_retrieved_text=None):
213
- """
214
- Generates a text reply.
215
- - DISTIL uses Associative Recall (nearest neighbor text).
216
- - Others use Random Sampling from the Reply Bank based on their prediction.
217
- """
218
- try:
219
- if model_name == 'DISTIL' and distil_retrieved_text:
220
- # H3MOS echoes a memory that feels "associatively related"
221
- return f"\"{distil_retrieved_text}\""
222
-
223
- # Standard models pick a vibe-matched message from the dataset
224
- candidates = reply_bank.get(predicted_emoji, ["I don't know what to say."])
225
- return f"\"{random.choice(candidates)}\""
226
- except:
227
- return "..."
228
 
229
  def get_predictions(text):
230
  """Runs all models on the text."""
@@ -233,41 +174,13 @@ def get_predictions(text):
233
 
234
  results = {name: {} for name in model_names}
235
 
236
- # 1. First, get Emoji predictions (Primary task for replies)
237
- emoji_preds = {}
238
- distil_text_memory = None
239
-
240
- # Run Emoji Task first to determine the reply "Vibe"
241
- task = 'emoji'
242
- le = encoders[task]
243
-
244
- for name in model_names:
245
- if name == 'DISTIL':
246
- with torch.no_grad():
247
- logits, mem_texts = zoo[task][name](vec_t)
248
- pred_idx = torch.argmax(logits, dim=1).item()
249
- pred_label = le.inverse_transform([pred_idx])[0]
250
- # Capture the memory text for DISTIL
251
- if mem_texts: distil_text_memory = mem_texts[0]
252
- else:
253
- pred_idx = zoo[task][name].predict(vec_s)[0]
254
- pred_label = le.inverse_transform([pred_idx])[0]
255
-
256
- emoji_preds[name] = pred_label
257
- results[name]['emoji'] = pred_label
258
-
259
- # GENERATE TEXT REPLY
260
- # We pass the memory text if it's DISTIL, otherwise None
261
- mem_txt = distil_text_memory if name == 'DISTIL' else None
262
- results[name]['reply'] = generate_reply(name, pred_label, mem_txt)
263
-
264
- # 2. Run other tasks (Sentiment/Intent) just for labels
265
- for task in ['sentiment', 'intent']:
266
  le = encoders[task]
 
267
  for name in model_names:
268
  if name == 'DISTIL':
269
  with torch.no_grad():
270
- logits, _ = zoo[task][name](vec_t)
271
  pred_idx = torch.argmax(logits, dim=1).item()
272
  pred_label = le.inverse_transform([pred_idx])[0]
273
  else:
@@ -280,6 +193,9 @@ def get_predictions(text):
280
 
281
  # --- 4. UI STYLING & INTERFACE ---
282
 
 
 
 
283
  CSS = """
284
  .chat-window { font-family: 'Segoe UI', sans-serif; }
285
 
@@ -305,8 +221,7 @@ CSS = """
305
 
306
  .model-card {
307
  background: white;
308
- min-width: 160px; /* Wider to fit text */
309
- max-width: 160px;
310
  border-radius: 12px;
311
  padding: 12px;
312
  box-shadow: 0 4px 12px rgba(0,0,0,0.08);
@@ -315,60 +230,40 @@ CSS = """
315
  align-items: center;
316
  border: 1px solid #eee;
317
  transition: transform 0.2s;
318
- position: relative;
319
  }
320
- .model-card:hover { transform: translateY(-3px); border-color: #cbd5e0; }
321
 
322
  .card-name {
323
- font-size: 10px;
324
- font-weight: 800;
325
  text-transform: uppercase;
326
- color: #a0aec0;
327
  margin-bottom: 4px;
328
- letter-spacing: 1px;
329
  }
330
 
331
  .card-emoji {
332
- font-size: 32px;
333
- margin: 2px 0;
334
- line-height: 1;
335
- }
336
-
337
- /* The generated reply bubble */
338
- .card-reply {
339
- font-size: 11px;
340
- color: #2d3748;
341
- background: #edf2f7;
342
- padding: 6px 8px;
343
- border-radius: 8px;
344
- margin: 8px 0;
345
- text-align: center;
346
- font-style: italic;
347
- min-height: 40px;
348
- display: flex;
349
- align-items: center;
350
- justify-content: center;
351
- line-height: 1.2;
352
- width: 100%;
353
  }
354
 
355
  .card-badge {
356
- font-size: 9px;
357
- padding: 2px 6px;
358
- border-radius: 4px;
359
- margin-top: auto; /* Push to bottom */
360
- font-weight: 700;
361
- text-transform: uppercase;
362
  }
363
 
364
- .bg-Pos { background-color: #c6f6d5; color: #22543d; }
365
- .bg-Neg { background-color: #fed7d7; color: #742a2a; }
366
- .bg-Neu { background-color: #e2e8f0; color: #4a5568; }
367
 
368
  .intent-row {
369
- font-size: 9px;
370
- color: #718096;
371
- margin-top: 4px;
 
 
372
  width: 100%;
373
  text-align: center;
374
  }
@@ -380,26 +275,25 @@ def chat_logic(message, history):
380
 
381
  preds = get_predictions(message)
382
 
383
- # 1. Create User Message HTML (with Emoji Consensus)
384
- # Simple majority voting for the "Consensus" bar
385
- emojis = [preds[m]['emoji'] for m in model_names]
386
- reaction_string = "".join(emojis)
387
 
388
  user_html = f"""
389
  <div>
390
  {message}
391
- <div class="user-reactions" title="Consensus">{reaction_string}</div>
392
  </div>
393
  """
394
  history.append({"role": "user", "content": user_html})
395
 
396
- # 2. Create Scrollable Bot Reply HTML
397
  cards_html = '<div class="model-scroll-container">'
398
 
399
  for name in model_names:
400
  p = preds[name]
401
 
402
- # Color coding
403
  sent_cls = "bg-Neu"
404
  if "Pos" in p['sentiment']: sent_cls = "bg-Pos"
405
  elif "Neg" in p['sentiment']: sent_cls = "bg-Neg"
@@ -408,7 +302,6 @@ def chat_logic(message, history):
408
  <div class="model-card">
409
  <div class="card-name">{name}</div>
410
  <div class="card-emoji">{p['emoji']}</div>
411
- <div class="card-reply">{p['reply']}</div>
412
  <div class="card-badge {sent_cls}">{p['sentiment']}</div>
413
  <div class="intent-row">{p['intent']}</div>
414
  </div>
 
12
  from sklearn.preprocessing import LabelEncoder
13
  import kagglehub
14
  import warnings
 
15
 
16
  # Suppress sklearn warnings for cleaner logs
17
  warnings.filterwarnings("ignore")
 
23
  def __init__(self, capacity=2000):
24
  self.memory_x = []
25
  self.memory_y = []
 
26
  self.capacity = capacity
27
 
28
+ def store(self, x, y):
29
  # Store on CPU to save GPU VRAM
30
  curr_x = x.detach().cpu()
31
  curr_y = y.detach().cpu()
32
+ for i in range(curr_x.size(0)):
 
 
 
 
 
 
 
 
 
 
33
  if len(self.memory_x) >= self.capacity:
34
  self.memory_x.pop(0)
35
  self.memory_y.pop(0)
 
 
36
  self.memory_x.append(curr_x[i])
37
  self.memory_y.append(curr_y[i])
 
 
 
38
 
39
  def retrieve(self, query_x, k=5):
40
  if not self.memory_x:
41
+ return None
42
  mem_tensor = torch.stack(self.memory_x).to(query_x.device)
43
  distances = torch.cdist(query_x, mem_tensor)
44
+ top_k_indices = torch.topk(distances, k, largest=False).indices
 
45
 
46
  # Gather labels
47
  retrieved_y = [torch.stack([self.memory_y[idx] for idx in sample_indices])
48
+ for sample_indices in top_k_indices]
49
+ return torch.stack(retrieved_y).to(query_x.device)
 
 
 
 
 
 
50
 
51
  class H3MOS(nn.Module):
52
  def __init__(self, input_dim, hidden_dim, output_dim):
 
71
 
72
  # Fast Path (Training or Empty Memory)
73
  if training_mode or len(self.hippocampus.memory_x) < 10:
74
+ return raw_logits
75
 
76
  # Memory Retrieval & Integration
77
+ past_labels = self.hippocampus.retrieve(x, k=5)
78
  if past_labels is None:
79
+ return raw_logits
80
 
81
  mem_votes = torch.zeros_like(raw_logits)
82
  for i in range(x.size(0)):
 
86
  mem_probs = F.softmax(mem_votes, dim=1)
87
 
88
  # Dynamic Gating: 80% Neural, 20% Memory
89
+ return (0.8 * raw_logits) + (0.2 * mem_probs * 5.0)
 
 
90
 
91
  # --- 2. DATA SETUP & TRAINING PIPELINE ---
92
 
 
100
  except Exception as e:
101
  print("Error loading data:", e)
102
  # Fallback dummy data if kaggle fails (for testing)
103
+ df = pd.DataFrame({'content': ['test'], 'emoji': ['👍']})
104
 
105
  # Mappings
106
  sent_map = {'❤️':'Positive', '👍':'Positive', '😂':'Positive', '💯':'Positive', '😢':'Negative', '😭':'Negative', '😮':'Neutral'}
 
111
  X_sparse = tfidf.fit_transform(df['content'])
112
  X_dense = torch.FloatTensor(X_sparse.toarray()).to(device)
113
 
 
 
 
 
 
 
 
 
 
 
 
114
  # Model Zoo Containers
115
  tasks = ['emoji', 'sentiment', 'intent']
116
  model_names = ['DISTIL', 'RandomForest', 'SVM', 'NaiveBayes', 'LogReg', 'GradBoost']
117
  zoo = {task: {} for task in tasks}
118
  encoders = {}
119
 
120
+ print("🧠 Training Models... (This may take a moment)")
121
 
122
  for task in tasks:
123
  # Prepare Labels
 
140
  optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
141
 
142
  model.train()
143
+ # Short training loop for demo speed
144
  for epoch in range(25):
145
  optimizer.zero_grad()
146
+ out = model(X_dense, training_mode=True)
147
  loss = F.cross_entropy(out, y_tensor)
148
  loss.backward()
149
  optimizer.step()
150
+ # Populate memory occasionally
 
 
151
  if epoch % 5 == 0:
152
  with torch.no_grad():
153
+ idx = torch.randperm(X_dense.size(0))[:50]
154
+ model.hippocampus.store(X_dense[idx], y_tensor[idx])
 
 
 
155
 
156
  model.eval()
157
  zoo[task]['DISTIL'] = model
 
165
 
166
  print("✅ Training Complete.")
167
 
168
+ # --- 3. INFERENCE LOGIC ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
  def get_predictions(text):
171
  """Runs all models on the text."""
 
174
 
175
  results = {name: {} for name in model_names}
176
 
177
+ for task in tasks:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  le = encoders[task]
179
+
180
  for name in model_names:
181
  if name == 'DISTIL':
182
  with torch.no_grad():
183
+ logits = zoo[task][name](vec_t)
184
  pred_idx = torch.argmax(logits, dim=1).item()
185
  pred_label = le.inverse_transform([pred_idx])[0]
186
  else:
 
193
 
194
  # --- 4. UI STYLING & INTERFACE ---
195
 
196
+ def get_avatar_url(seed):
197
+ return f"https://api.dicebear.com/7.x/bottts/svg?seed={seed}&backgroundColor=transparent&size=128"
198
+
199
  CSS = """
200
  .chat-window { font-family: 'Segoe UI', sans-serif; }
201
 
 
221
 
222
  .model-card {
223
  background: white;
224
+ min-width: 140px;
 
225
  border-radius: 12px;
226
  padding: 12px;
227
  box-shadow: 0 4px 12px rgba(0,0,0,0.08);
 
230
  align-items: center;
231
  border: 1px solid #eee;
232
  transition: transform 0.2s;
 
233
  }
234
+ .model-card:hover { transform: translateY(-3px); }
235
 
236
  .card-name {
237
+ font-size: 11px;
238
+ font-weight: 700;
239
  text-transform: uppercase;
240
+ color: #888;
241
  margin-bottom: 4px;
 
242
  }
243
 
244
  .card-emoji {
245
+ font-size: 28px;
246
+ margin: 4px 0;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  }
248
 
249
  .card-badge {
250
+ font-size: 10px;
251
+ padding: 2px 8px;
252
+ border-radius: 10px;
253
+ margin-top: 4px;
254
+ font-weight: 600;
 
255
  }
256
 
257
+ .bg-Pos { background-color: #e6fffa; color: #2c7a7b; }
258
+ .bg-Neg { background-color: #fff5f5; color: #c53030; }
259
+ .bg-Neu { background-color: #f7fafc; color: #4a5568; }
260
 
261
  .intent-row {
262
+ font-size: 10px;
263
+ color: #666;
264
+ margin-top: 6px;
265
+ border-top: 1px dashed #eee;
266
+ padding-top: 4px;
267
  width: 100%;
268
  text-align: center;
269
  }
 
275
 
276
  preds = get_predictions(message)
277
 
278
+ # 1. Create User Message HTML (with Emoji Reaction Bar)
279
+ # Order: DISTIL, RF, SVM, NB, LR, GB
280
+ reaction_string = "".join([preds[m]['emoji'] for m in model_names])
 
281
 
282
  user_html = f"""
283
  <div>
284
  {message}
285
+ <div class="user-reactions" title="Consensus: {reaction_string}">{reaction_string}</div>
286
  </div>
287
  """
288
  history.append({"role": "user", "content": user_html})
289
 
290
+ # 2. Create Single Bot Reply HTML (Horizontal Scroll Cards)
291
  cards_html = '<div class="model-scroll-container">'
292
 
293
  for name in model_names:
294
  p = preds[name]
295
 
296
+ # Color coding for sentiment
297
  sent_cls = "bg-Neu"
298
  if "Pos" in p['sentiment']: sent_cls = "bg-Pos"
299
  elif "Neg" in p['sentiment']: sent_cls = "bg-Neg"
 
302
  <div class="model-card">
303
  <div class="card-name">{name}</div>
304
  <div class="card-emoji">{p['emoji']}</div>
 
305
  <div class="card-badge {sent_cls}">{p['sentiment']}</div>
306
  <div class="intent-row">{p['intent']}</div>
307
  </div>