dwmk commited on
Commit
ab45c90
ยท
verified ยท
1 Parent(s): f04c010

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -113
app.py CHANGED
@@ -6,203 +6,169 @@ import torch.nn as nn
6
  import torch.nn.functional as F
7
  from sklearn.feature_extraction.text import TfidfVectorizer
8
  from sklearn.ensemble import RandomForestClassifier
9
- from sklearn.linear_model import LogisticRegression
10
- from sklearn.svm import SVC
11
- from sklearn.naive_bayes import MultinomialNB
12
  from sklearn.preprocessing import LabelEncoder
13
  import kagglehub
14
  import time
15
  import random
16
 
17
- # --- 1. CORE MODEL LOGIC ---
18
- # Using the H3MOS and EpisodicMemory architecture provided in the benchmarks
19
 
20
  class EpisodicMemory:
21
- """Mimics Hippocampal retention and retrieval"""
22
  def __init__(self, capacity=2000):
23
  self.memory_x, self.memory_y = [], []
24
  self.capacity = capacity
25
-
26
  def store(self, x, y):
27
  curr_x, curr_y = x.detach().cpu(), y.detach().cpu()
28
  for i in range(curr_x.size(0)):
29
  if len(self.memory_x) >= self.capacity:
30
  self.memory_x.pop(0); self.memory_y.pop(0)
31
  self.memory_x.append(curr_x[i]); self.memory_y.append(curr_y[i])
32
-
33
  def retrieve(self, query_x, k=5):
34
- if len(self.memory_x) < k: return None
35
  mem_tensor = torch.stack(self.memory_x).to(query_x.device)
36
  distances = torch.cdist(query_x, mem_tensor)
37
  top_k_indices = torch.topk(distances, k, largest=False).indices
38
- retrieved_y = [torch.stack([self.memory_y[idx] for idx in sample_indices]) for sample_indices in top_k_indices]
39
- return torch.stack(retrieved_y).to(query_x.device)
40
-
41
- class ExecutiveCore(nn.Module):
42
- def __init__(self, input_dim, hidden_dim):
43
- super().__init__()
44
- self.net = nn.Sequential(
45
- nn.Linear(input_dim, hidden_dim),
46
- nn.LayerNorm(hidden_dim),
47
- nn.GELU(),
48
- nn.Dropout(0.2),
49
- nn.Linear(hidden_dim, hidden_dim),
50
- nn.GELU()
51
- )
52
- def forward(self, x): return self.net(x)
53
-
54
- class MotorPolicy(nn.Module):
55
- def __init__(self, hidden_dim, output_dim):
56
- super().__init__()
57
- self.fc = nn.Linear(hidden_dim, output_dim)
58
- def forward(self, x): return self.fc(x)
59
 
60
  class H3MOS(nn.Module):
61
- """The DISTIL-H3MOS model architecture"""
62
  def __init__(self, input_dim, hidden_dim, output_dim):
63
  super().__init__()
64
- self.executive = ExecutiveCore(input_dim, hidden_dim)
65
- self.motor = MotorPolicy(hidden_dim, output_dim)
66
  self.hippocampus = EpisodicMemory()
67
-
68
  def forward(self, x, training_mode=False):
69
  z = self.executive(x)
70
- if training_mode or len(self.hippocampus.memory_x) < 10:
71
- return self.motor(z)
72
-
73
- past_labels = self.hippocampus.retrieve(x, k=5)
74
  raw_logits = self.motor(z)
75
-
 
76
  if past_labels is None: return raw_logits
77
-
78
  mem_votes = torch.zeros_like(raw_logits)
79
  for i in range(x.size(0)):
80
  votes = torch.bincount(past_labels[i], minlength=raw_logits.size(1)).float()
81
  mem_votes[i] = votes
82
-
83
  return (0.8 * raw_logits) + (0.2 * F.softmax(mem_votes, dim=1) * 5.0)
84
 
85
- # --- 2. DATA & TRAINING ---
 
86
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
87
- print(f"Loading data on {device}...")
88
 
89
  path = kagglehub.dataset_download('dewanmukto/social-messages-and-emoji-reactions')
90
  df = pd.read_csv(path+"/messages_emojis.csv").dropna(subset=['content'])
91
 
92
- # Mapping logic from the original benchmark
93
  sent_map = {'โค๏ธ':'Pos', '๐Ÿ‘':'Pos', '๐Ÿ˜‚':'Pos', '๐Ÿ’ฏ':'Pos', '๐Ÿ˜ข':'Neg', '๐Ÿ˜ญ':'Neg', '๐Ÿ˜ฎ':'Neu'}
94
  intent_map = {'โค๏ธ':'Emotion', '๐Ÿ‘':'Agreement', '๐Ÿ˜‚':'Emotion', '๐Ÿ˜ฎ':'Surprise'}
95
 
96
- tfidf = TfidfVectorizer(max_features=1000, stop_words='english')
97
  X_sparse = tfidf.fit_transform(df['content'])
98
  X_dense = torch.FloatTensor(X_sparse.toarray()).to(device)
99
 
100
- tasks = {
101
- 'emoji': df['emoji'].values,
102
- 'sentiment': df['emoji'].apply(lambda x: sent_map.get(x, 'Neutral')).values,
103
- 'intent': df['emoji'].apply(lambda x: intent_map.get(x, 'Other')).values
104
- }
105
-
106
  model_zoo = {}
107
  encoders = {}
108
 
109
- for task, y_labels in tasks.items():
 
 
110
  le = LabelEncoder()
111
  y_enc = torch.LongTensor(le.fit_transform(y_labels)).to(device)
112
  encoders[task] = le
113
 
114
- # Train H3MOS
115
- h3_model = H3MOS(X_dense.shape[1], 64, len(le.classes_)).to(device)
116
- opt = torch.optim.Adam(h3_model.parameters(), lr=0.01)
117
- for _ in range(30):
118
- opt.zero_grad()
119
- loss = F.cross_entropy(h3_model(X_dense, True), y_enc)
120
- loss.backward(); opt.step()
121
 
122
- # Store Sklearn models for comparison
123
- rf = RandomForestClassifier(n_estimators=50).fit(X_sparse, y_labels)
124
-
125
- model_zoo[task] = {"H3MOS": h3_model, "RandomForest": rf}
126
-
127
- # --- 3. CHAT INTERFACE ---
128
 
129
- def get_avatar(seed):
130
- return f"https://api.dicebear.com/7.x/adventurer/svg?seed={seed}"
131
 
132
  CSS = """
133
- .reaction-pill {
134
- background: rgba(0, 0, 0, 0.05);
135
- border-radius: 12px;
136
- padding: 2px 10px;
137
- font-size: 16px;
138
- margin-top: 8px;
139
- display: inline-block;
140
- border: 1px solid #ddd;
141
  }
 
 
 
142
  """
143
 
144
- def predict_all(text):
 
 
 
145
  vec_s = tfidf.transform([text])
146
  vec_t = torch.FloatTensor(vec_s.toarray()).to(device)
147
-
148
- results = {}
149
  for task in ['emoji', 'sentiment', 'intent']:
150
- # H3MOS Inference
151
  with torch.no_grad():
152
- h3_out = model_zoo[task]["H3MOS"](vec_t)
153
- h3_pred = encoders[task].inverse_transform([torch.argmax(h3_out).item()])[0]
154
-
155
- # RF Inference
156
- rf_pred = model_zoo[task]["RandomForest"].predict(vec_s)[0]
157
- results[task] = {"H3MOS": h3_pred, "RandomForest": rf_pred}
158
- return results
159
 
160
- def chat_fn(message, history):
161
  if not message: return "", history
162
 
163
- preds = predict_all(message)
164
 
165
- # Create Reaction HTML (Hover shows model breakdown)
166
- emoji_h3 = preds['emoji']['H3MOS']
167
- emoji_rf = preds['emoji']['RandomForest']
168
- reaction_html = f"<div class='reaction-pill' title='H3MOS: {emoji_h3} | RF: {emoji_rf}'>{emoji_h3} ๐Ÿค–</div>"
 
169
 
170
  # 1. Add User Message
171
  history.append({"role": "user", "content": f"{message}<br>{reaction_html}"})
172
  yield history
173
 
174
- # 2. Sequential Bot Replies
175
- bots = [("DISTIL-H3MOS", "H3MOS"), ("RandomForest", "RandomForest")]
176
- for bot_name, key in bots:
177
- time.sleep(random.uniform(0.8, 1.5)) # Simulation delay
 
 
 
 
 
 
 
178
 
179
- sent = preds['sentiment'][key]
180
- intent = preds['intent'][key]
 
 
 
 
 
 
 
 
 
181
 
182
- bot_msg = f"**Sentiment:** {sent} \n**Intent:** {intent}"
183
- history.append({
184
- "role": "assistant",
185
- "content": bot_msg,
186
- "metadata": {"title": bot_name}
187
- })
188
  yield history
189
 
190
  with gr.Blocks() as demo:
191
- gr.Markdown("## ๐Ÿ’ฌ Social AI Group Chat")
192
- gr.Markdown("Type a message. Models will react with emojis and reply with their analysis.")
193
 
194
  chatbot = gr.Chatbot(
195
  elem_id="chat-window",
196
- avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=User"),
197
- bubble_full_width=False
 
198
  )
199
 
200
  with gr.Row():
201
- msg_input = gr.Textbox(placeholder="Say something...", show_label=False, scale=4)
202
- submit = gr.Button("Send", variant="primary")
203
 
204
- msg_input.submit(chat_fn, [msg_input, chatbot], [msg_input, chatbot])
205
- submit.click(chat_fn, [msg_input, chatbot], [msg_input, chatbot])
206
 
207
- # Corrected launch for Gradio 6.0 compatibility
208
  demo.launch(css=CSS)
 
6
  import torch.nn.functional as F
7
  from sklearn.feature_extraction.text import TfidfVectorizer
8
  from sklearn.ensemble import RandomForestClassifier
 
 
 
9
  from sklearn.preprocessing import LabelEncoder
10
  import kagglehub
11
  import time
12
  import random
13
 
14
+ # --- 1. ARCHITECTURE (From your benchmark) ---
 
15
 
16
  class EpisodicMemory:
 
17
  def __init__(self, capacity=2000):
18
  self.memory_x, self.memory_y = [], []
19
  self.capacity = capacity
 
20
  def store(self, x, y):
21
  curr_x, curr_y = x.detach().cpu(), y.detach().cpu()
22
  for i in range(curr_x.size(0)):
23
  if len(self.memory_x) >= self.capacity:
24
  self.memory_x.pop(0); self.memory_y.pop(0)
25
  self.memory_x.append(curr_x[i]); self.memory_y.append(curr_y[i])
 
26
  def retrieve(self, query_x, k=5):
27
+ if not self.memory_x: return None
28
  mem_tensor = torch.stack(self.memory_x).to(query_x.device)
29
  distances = torch.cdist(query_x, mem_tensor)
30
  top_k_indices = torch.topk(distances, k, largest=False).indices
31
+ return torch.stack([torch.stack([self.memory_y[idx] for idx in s_idx]) for s_idx in top_k_indices]).to(query_x.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  class H3MOS(nn.Module):
 
34
  def __init__(self, input_dim, hidden_dim, output_dim):
35
  super().__init__()
36
+ self.executive = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU())
37
+ self.motor = nn.Linear(hidden_dim, output_dim)
38
  self.hippocampus = EpisodicMemory()
 
39
  def forward(self, x, training_mode=False):
40
  z = self.executive(x)
 
 
 
 
41
  raw_logits = self.motor(z)
42
+ if training_mode: return raw_logits
43
+ past_labels = self.hippocampus.retrieve(x, k=5)
44
  if past_labels is None: return raw_logits
 
45
  mem_votes = torch.zeros_like(raw_logits)
46
  for i in range(x.size(0)):
47
  votes = torch.bincount(past_labels[i], minlength=raw_logits.size(1)).float()
48
  mem_votes[i] = votes
 
49
  return (0.8 * raw_logits) + (0.2 * F.softmax(mem_votes, dim=1) * 5.0)
50
 
51
+ # --- 2. DATA LOAD & QUICK TRAIN ---
52
+
53
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
+ print(f"Initializing models on {device}...")
55
 
56
  path = kagglehub.dataset_download('dewanmukto/social-messages-and-emoji-reactions')
57
  df = pd.read_csv(path+"/messages_emojis.csv").dropna(subset=['content'])
58
 
 
59
  sent_map = {'โค๏ธ':'Pos', '๐Ÿ‘':'Pos', '๐Ÿ˜‚':'Pos', '๐Ÿ’ฏ':'Pos', '๐Ÿ˜ข':'Neg', '๐Ÿ˜ญ':'Neg', '๐Ÿ˜ฎ':'Neu'}
60
  intent_map = {'โค๏ธ':'Emotion', '๐Ÿ‘':'Agreement', '๐Ÿ˜‚':'Emotion', '๐Ÿ˜ฎ':'Surprise'}
61
 
62
+ tfidf = TfidfVectorizer(max_features=500, stop_words='english')
63
  X_sparse = tfidf.fit_transform(df['content'])
64
  X_dense = torch.FloatTensor(X_sparse.toarray()).to(device)
65
 
 
 
 
 
 
 
66
  model_zoo = {}
67
  encoders = {}
68
 
69
+ # We train H3MOS and a RandomForest for variety
70
+ for task in ['emoji', 'sentiment', 'intent']:
71
+ y_labels = df['emoji'].values if task == 'emoji' else df['emoji'].apply(lambda x: sent_map.get(x, 'Neutral') if task == 'sentiment' else intent_map.get(x, 'Other')).values
72
  le = LabelEncoder()
73
  y_enc = torch.LongTensor(le.fit_transform(y_labels)).to(device)
74
  encoders[task] = le
75
 
76
+ # H3MOS
77
+ h3 = H3MOS(X_dense.shape[1], 64, len(le.classes_)).to(device)
78
+ opt = torch.optim.Adam(h3.parameters(), lr=0.01)
79
+ for _ in range(20):
80
+ opt.zero_grad(); F.cross_entropy(h3(X_dense, True), y_enc).backward(); opt.step()
 
 
81
 
82
+ # RandomForest
83
+ rf = RandomForestClassifier(n_estimators=20).fit(X_sparse, y_labels)
84
+ model_zoo[task] = {"H3MOS": h3, "RF": rf}
 
 
 
85
 
86
+ # --- 3. THE UI & LOGIC ---
 
87
 
88
  CSS = """
89
+ .reaction-btn {
90
+ background: #f0f2f5; border: 1px solid #ddd; border-radius: 15px;
91
+ padding: 2px 8px; font-size: 14px; cursor: pointer; margin-top: 5px;
 
 
 
 
 
92
  }
93
+ .bot-header { display: flex; align-items: center; margin-bottom: 5px; }
94
+ .bot-avatar { width: 28px; height: 28px; border-radius: 50%; margin-right: 8px; border: 1px solid #eee; }
95
+ .bot-name { font-weight: bold; font-size: 0.9em; color: #555; }
96
  """
97
 
98
+ def get_avatar_url(name):
99
+ return f"https://api.dicebear.com/7.x/adventurer/svg?seed={name}"
100
+
101
+ def predict(text):
102
  vec_s = tfidf.transform([text])
103
  vec_t = torch.FloatTensor(vec_s.toarray()).to(device)
104
+ res = {}
 
105
  for task in ['emoji', 'sentiment', 'intent']:
 
106
  with torch.no_grad():
107
+ h3_idx = torch.argmax(model_zoo[task]["H3MOS"](vec_t)).item()
108
+ h3_p = encoders[task].inverse_transform([h3_idx])[0]
109
+ rf_p = model_zoo[task]["RF"].predict(vec_s)[0]
110
+ res[task] = {"DISTIL-H3MOS": h3_p, "RandomForest": rf_p}
111
+ return res
 
 
112
 
113
+ def chat_interface(message, history):
114
  if not message: return "", history
115
 
116
+ preds = predict(message)
117
 
118
+ # Reaction Logic
119
+ h3_emoji = preds['emoji']['DISTIL-H3MOS']
120
+ rf_emoji = preds['emoji']['RandomForest']
121
+ details = f"DISTIL-H3MOS: {h3_emoji} | RandomForest: {rf_emoji}"
122
+ reaction_html = f"<button class='reaction-btn' title='{details}'>{h3_emoji} ๐Ÿค–</button>"
123
 
124
  # 1. Add User Message
125
  history.append({"role": "user", "content": f"{message}<br>{reaction_html}"})
126
  yield history
127
 
128
+ # 2. Simulate Group Members Replying
129
+ bots = ["DISTIL-H3MOS", "RandomForest"]
130
+ random.shuffle(bots)
131
+
132
+ for bot in bots:
133
+ # Simulate "Typing..."
134
+ time.sleep(random.uniform(0.5, 1.2))
135
+
136
+ sent = preds['sentiment'][bot]
137
+ intent = preds['intent'][bot]
138
+ avatar = get_avatar_url(bot)
139
 
140
+ # Format as a social media message
141
+ bot_content = f"""
142
+ <div class="bot-header">
143
+ <img src="{avatar}" class="bot-avatar">
144
+ <span class="bot-name">{bot}</span>
145
+ </div>
146
+ <div style="padding-left: 36px;">
147
+ <b>Sentiment:</b> {sent}<br>
148
+ <b>Intent:</b> {intent}
149
+ </div>
150
+ """
151
 
152
+ history.append({"role": "assistant", "content": bot_content})
 
 
 
 
 
153
  yield history
154
 
155
  with gr.Blocks() as demo:
156
+ gr.Markdown("### ๐Ÿ“ฑ Model Group Chat")
157
+ gr.Markdown("The models below analyze your message for sentiment and intent in real-time.")
158
 
159
  chatbot = gr.Chatbot(
160
  elem_id="chat-window",
161
+ # avatar_images takes a tuple: (user_avatar, bot_placeholder)
162
+ avatar_images=(get_avatar_url("User"), None),
163
+ height=500
164
  )
165
 
166
  with gr.Row():
167
+ txt = gr.Textbox(placeholder="Type a message...", show_label=False, scale=4)
168
+ btn = gr.Button("Send", variant="primary")
169
 
170
+ txt.submit(chat_interface, [txt, chatbot], [txt, chatbot])
171
+ btn.click(chat_interface, [txt, chatbot], [txt, chatbot])
172
 
173
+ # Launch with the CSS injection
174
  demo.launch(css=CSS)