Spaces:
Running
Running
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier | |
| from sklearn.svm import SVC | |
| from sklearn.naive_bayes import MultinomialNB | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.preprocessing import LabelEncoder | |
| import kagglehub | |
| import warnings | |
| # Suppress sklearn warnings for cleaner logs | |
| warnings.filterwarnings("ignore") | |
| # --- 1. ARCHITECTURE: H3MOS (Hippocampal Memory & Executive Core) --- | |
class EpisodicMemory:
    """Mimics Hippocampal retention and retrieval of recent experiences.

    A bounded FIFO buffer of (feature, label) pairs kept on CPU. `store`
    appends batch samples (evicting the oldest past `capacity`); `retrieve`
    returns the labels of the k nearest stored features for each query row.
    """

    def __init__(self, capacity=2000):
        # Parallel FIFO buffers of per-sample CPU tensors.
        self.memory_x = []
        self.memory_y = []
        self.capacity = capacity

    def store(self, x, y):
        """Store each sample of the batch (x, y), evicting oldest first.

        Tensors are detached and moved to CPU so the buffer holds neither
        GPU memory nor autograd graph references.
        """
        curr_x = x.detach().cpu()
        curr_y = y.detach().cpu()
        for i in range(curr_x.size(0)):
            if len(self.memory_x) >= self.capacity:
                self.memory_x.pop(0)
                self.memory_y.pop(0)
            self.memory_x.append(curr_x[i])
            self.memory_y.append(curr_y[i])

    def retrieve(self, query_x, k=5):
        """Return labels of the k nearest stored samples per query row.

        Returns a (batch, min(k, buffer size)) tensor on query_x's device,
        or None when the memory is empty.
        """
        if not self.memory_x:
            return None
        # Fix: clamp k to the buffer size — the original raised in
        # torch.topk whenever fewer than k samples had been stored.
        k = min(k, len(self.memory_x))
        mem_tensor = torch.stack(self.memory_x).to(query_x.device)
        distances = torch.cdist(query_x, mem_tensor)
        top_k_indices = torch.topk(distances, k, largest=False).indices
        # Gather labels for each query's neighbour indices.
        retrieved_y = [torch.stack([self.memory_y[idx] for idx in sample_indices])
                       for sample_indices in top_k_indices]
        return torch.stack(retrieved_y).to(query_x.device)
class H3MOS(nn.Module):
    """Executive MLP classifier with an episodic-memory side channel.

    During inference (once the hippocampus holds enough samples) the raw
    class logits are blended with a soft vote distribution computed from
    the labels of the k nearest remembered inputs.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Executive Core: two-layer MLP with LayerNorm / GELU / Dropout.
        self.executive = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
        )
        # Motor Policy: hidden features -> class logits.
        self.motor = nn.Linear(hidden_dim, output_dim)
        # Hippocampus: episodic store of past (input, label) pairs.
        self.hippocampus = EpisodicMemory(capacity=2000)

    def forward(self, x, training_mode=False):
        hidden = self.executive(x)
        logits = self.motor(hidden)

        # Fast Path: pure neural output while training or memory is sparse.
        if training_mode or len(self.hippocampus.memory_x) < 10:
            return logits

        # Memory Retrieval & Integration.
        recalled = self.hippocampus.retrieve(x, k=5)
        if recalled is None:
            return logits

        # Tally neighbour-label votes, one count vector per batch row.
        num_classes = logits.size(1)
        vote_rows = [
            torch.bincount(recalled[row], minlength=num_classes).float()
            for row in range(x.size(0))
        ]
        vote_probs = F.softmax(torch.stack(vote_rows).to(logits), dim=1)
        # Dynamic Gating: 80% Neural, 20% Memory.
        return (0.8 * logits) + (0.2 * vote_probs * 5.0)
# --- 2. DATA SETUP & TRAINING PIPELINE ---
# Runs at import time: downloads the dataset, builds TF-IDF features, and
# trains one model per (task, algorithm) pair into the `zoo` lookup table.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): the emoji in this and later prints appear mojibake'd
# (UTF-8 decoded through a legacy codepage) — confirm source file encoding.
print(f"๐ Initializing System on {device}...")
# Load Data
try:
    path = kagglehub.dataset_download('dewanmukto/social-messages-and-emoji-reactions')
    df = pd.read_csv(path+"/messages_emojis.csv").dropna(subset=['content'])
except Exception as e:
    print("Error loading data:", e)
    # Fallback dummy data if kaggle fails (for testing)
    df = pd.DataFrame({'content': ['test'], 'emoji': ['๐']})
# Mappings from dataset emoji to coarse sentiment / intent labels.
# NOTE(review): these keys also look mojibake'd; if they don't match the
# CSV's emoji values, every lookup falls through to the defaults
# ('Neutral' / 'Other') — verify against the raw dataset.
sent_map = {'โค๏ธ':'Positive', '๐':'Positive', '๐':'Positive', '๐ฏ':'Positive', '๐ข':'Negative', '๐ญ':'Negative', '๐ฎ':'Neutral'}
intent_map = {'โค๏ธ':'Emotion', '๐':'Agreement', '๐':'Emotion', '๐ฎ':'Surprise'}
# Vectorization: bag-of-words TF-IDF over the message text (shared by all models).
tfidf = TfidfVectorizer(max_features=600, stop_words='english')
X_sparse = tfidf.fit_transform(df['content'])
X_dense = torch.FloatTensor(X_sparse.toarray()).to(device)
# Model Zoo Containers
tasks = ['emoji', 'sentiment', 'intent']
model_names = ['DISTIL', 'RandomForest', 'SVM', 'NaiveBayes', 'LogReg', 'GradBoost']
zoo = {task: {} for task in tasks}  # zoo[task][model_name] -> fitted model
encoders = {}  # task -> fitted LabelEncoder (for decoding predictions)
print("๐ง Training Models... (This may take a moment)")
for task in tasks:
    # Prepare Labels: emoji uses raw values; the other two derive from the maps.
    if task == 'emoji':
        raw_y = df['emoji'].values
    elif task == 'sentiment':
        raw_y = df['emoji'].apply(lambda x: sent_map.get(x, 'Neutral')).values
    else:
        raw_y = df['emoji'].apply(lambda x: intent_map.get(x, 'Other')).values
    le = LabelEncoder()
    y_nums = le.fit_transform(raw_y)
    encoders[task] = le
    # 1. Train DISTIL-H3MOS (PyTorch)
    y_tensor = torch.LongTensor(y_nums).to(device)
    output_dim = len(le.classes_)
    model = H3MOS(X_dense.shape[1], 64, output_dim).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
    model.train()
    # Short training loop for demo speed (full-batch, no train/test split).
    for epoch in range(25):
        optimizer.zero_grad()
        out = model(X_dense, training_mode=True)
        loss = F.cross_entropy(out, y_tensor)
        loss.backward()
        optimizer.step()
        # Populate memory occasionally (random 50-sample subset every 5 epochs).
        if epoch % 5 == 0:
            with torch.no_grad():
                idx = torch.randperm(X_dense.size(0))[:50]
                model.hippocampus.store(X_dense[idx], y_tensor[idx])
    model.eval()
    zoo[task]['DISTIL'] = model
    # 2. Train Sklearn Models (all fit on the same sparse TF-IDF matrix).
    zoo[task]['RandomForest'] = RandomForestClassifier(n_estimators=50).fit(X_sparse, y_nums)
    zoo[task]['SVM'] = SVC(kernel='linear').fit(X_sparse, y_nums)
    zoo[task]['NaiveBayes'] = MultinomialNB().fit(X_sparse, y_nums)
    zoo[task]['LogReg'] = LogisticRegression(max_iter=500).fit(X_sparse, y_nums)
    zoo[task]['GradBoost'] = GradientBoostingClassifier(n_estimators=30).fit(X_sparse, y_nums)
print("โ Training Complete.")
# --- 3. INFERENCE LOGIC ---
def get_predictions(text):
    """Run every model in the zoo on `text`.

    Returns a nested dict: {model_name: {task: predicted_label}}.
    """
    sparse_vec = tfidf.transform([text])
    dense_vec = torch.FloatTensor(sparse_vec.toarray()).to(device)
    results = {name: {} for name in model_names}
    for task in tasks:
        le = encoders[task]
        for name in model_names:
            if name == 'DISTIL':
                # Neural model: argmax over logits, inference only.
                with torch.no_grad():
                    logits = zoo[task][name](dense_vec)
                class_idx = torch.argmax(logits, dim=1).item()
            else:
                # Sklearn models predict straight from the sparse vector.
                class_idx = zoo[task][name].predict(sparse_vec)[0]
            # Decode the numeric class back to its original label.
            results[name][task] = le.inverse_transform([class_idx])[0]
    return results
# --- 4. UI STYLING & INTERFACE ---
def get_avatar_url(seed):
    """Build a DiceBear 'bottts' avatar URL for the given seed."""
    base = "https://api.dicebear.com/7.x/bottts/svg"
    query = f"seed={seed}&backgroundColor=transparent&size=128"
    return f"{base}?{query}"
# Custom CSS injected into the Gradio app: styles the per-message emoji
# reaction bar and the horizontally scrolling model result cards built in
# chat_logic. The bg-Pos/bg-Neg/bg-Neu classes color the sentiment badge.
CSS = """
.chat-window { font-family: 'Segoe UI', sans-serif; }
/* User Message Styling */
.user-reactions {
margin-top: 8px;
padding-top: 6px;
border-top: 1px solid rgba(255,255,255,0.3);
font-size: 1.2em;
letter-spacing: 4px;
text-align: right;
opacity: 0.9;
}
/* Bot Reply Container */
.model-scroll-container {
display: flex;
gap: 12px;
overflow-x: auto;
padding: 10px 4px;
scrollbar-width: thin;
}
.model-card {
background: white;
min-width: 140px;
border-radius: 12px;
padding: 12px;
box-shadow: 0 4px 12px rgba(0,0,0,0.08);
display: flex;
flex-direction: column;
align-items: center;
border: 1px solid #eee;
transition: transform 0.2s;
}
.model-card:hover { transform: translateY(-3px); }
.card-name {
font-size: 11px;
font-weight: 700;
text-transform: uppercase;
color: #888;
margin-bottom: 4px;
}
.card-emoji {
font-size: 28px;
margin: 4px 0;
}
.card-badge {
font-size: 10px;
padding: 2px 8px;
border-radius: 10px;
margin-top: 4px;
font-weight: 600;
}
.bg-Pos { background-color: #e6fffa; color: #2c7a7b; }
.bg-Neg { background-color: #fff5f5; color: #c53030; }
.bg-Neu { background-color: #f7fafc; color: #4a5568; }
.intent-row {
font-size: 10px;
color: #666;
margin-top: 6px;
border-top: 1px dashed #eee;
padding-top: 4px;
width: 100%;
text-align: center;
}
"""
def chat_logic(message, history):
    """Handle one chat turn: run all models, append user + bot messages.

    Returns ("", history) so the textbox is cleared and the Chatbot shows
    the new user message (with a reaction bar) plus one assistant message
    containing a scrollable card per model.
    """
    import html  # local import: only needed here, for escaping user text

    if not message:
        return "", history
    preds = get_predictions(message)
    # Security fix: the Chatbot renders message content as HTML, so the raw
    # user text must be escaped before interpolation to prevent HTML/script
    # injection (XSS). The original embedded `message` unescaped.
    safe_message = html.escape(message)
    # 1. Create User Message HTML (with Emoji Reaction Bar)
    # Order: DISTIL, RF, SVM, NB, LR, GB
    reaction_string = "".join([preds[m]['emoji'] for m in model_names])
    user_html = f"""
    <div>
        {safe_message}
        <div class="user-reactions" title="Consensus: {reaction_string}">{reaction_string}</div>
    </div>
    """
    history.append({"role": "user", "content": user_html})
    # 2. Create Single Bot Reply HTML (Horizontal Scroll Cards)
    cards_html = '<div class="model-scroll-container">'
    for name in model_names:
        p = preds[name]
        # Color coding for sentiment (matches .bg-* classes in the CSS)
        sent_cls = "bg-Neu"
        if "Pos" in p['sentiment']:
            sent_cls = "bg-Pos"
        elif "Neg" in p['sentiment']:
            sent_cls = "bg-Neg"
        cards_html += f"""
        <div class="model-card">
            <div class="card-name">{name}</div>
            <div class="card-emoji">{p['emoji']}</div>
            <div class="card-badge {sent_cls}">{p['sentiment']}</div>
            <div class="intent-row">{p['intent']}</div>
        </div>
        """
    cards_html += "</div>"
    history.append({"role": "assistant", "content": cards_html})
    return "", history
# --- 5. LAUNCH APP ---
# Build the Gradio UI: a chatbot window plus a textbox/button row, both
# wired to chat_logic.
with gr.Blocks(css=CSS, title="SentiChat") as demo:
    # NOTE(review): heading emoji appears mojibake'd — confirm file encoding.
    gr.Markdown("### ๐ค Message Analysis")
    gr.Markdown("Type a message to see how different AI/ML architectures interpret it. They were trained on [this dataset](https://www.kaggle.com/datasets/dewanmukto/social-messages-and-emoji-reactions).")
    # NOTE(review): chat_logic appends {"role": ..., "content": ...} dicts;
    # confirm the pinned Gradio version defaults Chatbot to the "messages"
    # format (older versions expect tuple pairs unless type="messages").
    chatbot = gr.Chatbot(
        elem_id="chat-window",
        avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=Admin"),
        height=600,
        render_markdown=False # Important to render our custom HTML
    )
    with gr.Row():
        txt = gr.Textbox(
            placeholder="Type a message (e.g., 'I cant believe you did that!')",
            scale=4,
            show_label=False,
            container=False
        )
        btn = gr.Button("Send โถ Analyze", variant="primary", scale=1)
    # Event bindings: Enter in the textbox and the button click both run
    # chat_logic, which clears the textbox and updates the chat history.
    txt.submit(chat_logic, [txt, chatbot], [txt, chatbot])
    btn.click(chat_logic, [txt, chatbot], [txt, chatbot])
if __name__ == "__main__":
    demo.launch()