import threading
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset

ACTIONS = ["GITHUB", "MAIL", "CALENDAR"]
NUM_ACTIONS = len(ACTIONS)
HF_DATASET = "iteratehack/code19-dataset"

# Confidence threshold - predictions below this return NONE
CONFIDENCE_THRESHOLD = 0.6

# Cosine-similarity threshold for outlier detection - inputs whose closest
# class centroid is less similar than this are treated as out-of-distribution
DISTANCE_THRESHOLD = 0.93

app = FastAPI()

model_state = {
    "ready": False,
    "agent": None,
    "tokenizer": None,
    "encoder": None,
    "class_centroids": None,  # Mean embedding per class, for outlier detection
}


class MessageRequest(BaseModel):
    message: str


class ActionResponse(BaseModel):
    action: str
    score: float


class PolicyNetwork(nn.Module):
    """Policy network that outputs action probabilities."""

    def __init__(self, state_dim, num_actions, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, num_actions),
        )
        # Initialize the last layer with small weights so the initial policy
        # is close to uniform over actions
        nn.init.xavier_uniform_(self.net[-1].weight, gain=0.01)
        nn.init.zeros_(self.net[-1].bias)

    def forward(self, state):
        return self.net(state)

    def get_action_probs(self, state):
        logits = self.forward(state)
        return F.softmax(logits, dim=-1)

    def get_action(self, state, deterministic=False, temperature=1.0):
        logits = self.forward(state)
        # Apply temperature scaling for exploration control
        scaled_logits = logits / temperature
        probs = F.softmax(scaled_logits, dim=-1)
        if deterministic:
            action = torch.argmax(probs, dim=-1)
        else:
            dist = torch.distributions.Categorical(probs)
            action = dist.sample()
        return action, probs
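

# Illustrative sketch, not part of the service: shows how the temperature
# argument of PolicyNetwork.get_action reshapes the sampling distribution.
# The helper name and the state dimension are assumptions for this demo only.
def _demo_temperature_effect():
    torch.manual_seed(0)
    net = PolicyNetwork(state_dim=8, num_actions=NUM_ACTIONS)
    state = torch.randn(1, 8)
    for t in (0.5, 1.0, 2.0):
        _, probs = net.get_action(state, deterministic=False, temperature=t)
        # Higher temperature -> probabilities closer to uniform
        print(f"T={t}: {probs.squeeze(0).tolist()}")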


class QNetwork(nn.Module):
    """Q-network for action-value estimation."""

    def __init__(self, state_dim, num_actions, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions),
        )

    def forward(self, state):
        return self.net(state)


class RLAgent:
    """RL agent combining Double DQN with a separately trained policy head."""

    def __init__(self, state_dim, num_actions, lr=1e-3, gamma=0.95):
        self.state_dim = state_dim
        self.num_actions = num_actions
        self.gamma = gamma
        # Q-networks (Double DQN: the online network selects actions,
        # the target network evaluates them)
        self.q_net = QNetwork(state_dim, num_actions)
        self.target_q_net = QNetwork(state_dim, num_actions)
        self.target_q_net.load_state_dict(self.q_net.state_dict())
        # Policy network used at inference time
        self.policy_net = PolicyNetwork(state_dim, num_actions)
        self.q_optimizer = torch.optim.AdamW(self.q_net.parameters(), lr=lr, weight_decay=1e-4)
        self.policy_optimizer = torch.optim.AdamW(self.policy_net.parameters(), lr=lr, weight_decay=1e-4)
        # Exploration parameters (epsilon-greedy with multiplicative decay)
        self.epsilon = 1.0
        self.epsilon_min = 0.05
        self.epsilon_decay = 0.995
        self.temperature = 1.0

    def select_action(self, state, deterministic=True):
        """Select an action for the given state; returns (action_idx, confidence)."""
        with torch.no_grad():
            if deterministic:
                # Use the policy network for inference
                action, probs = self.policy_net.get_action(state, deterministic=True)
                action_idx = action.item()
                # Entropy-based confidence: low entropy = high confidence,
                # high entropy = low confidence
                entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=-1).item()
                max_entropy = np.log(self.num_actions)  # Entropy of the uniform distribution
                confidence = 1.0 - (entropy / max_entropy)
                # Also factor in the raw probability of the chosen action
                raw_prob = probs[0, action_idx].item()
                confidence = confidence * raw_prob
            else:
                # Epsilon-greedy for training
                if random.random() < self.epsilon:
                    action_idx = random.randint(0, self.num_actions - 1)
                    confidence = 1.0 / self.num_actions
                else:
                    action, probs = self.policy_net.get_action(
                        state, deterministic=False, temperature=self.temperature
                    )
                    action_idx = action.item()
                    confidence = probs[0, action_idx].item()
        return action_idx, confidence
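
    # Worked example of the confidence formula above; a hand-picked sketch
    # (assumed helper, not in the source) with illustrative numbers.
    @staticmethod
    def _demo_confidence():
        probs = torch.tensor([[0.8, 0.1, 0.1]])
        entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=-1).item()  # ~0.639
        max_entropy = np.log(3)  # ~1.099
        # (1 - 0.639 / 1.099) * 0.8 ~= 0.335
        confidence = (1.0 - entropy / max_entropy) * probs[0, 0].item()
        print(f"entropy={entropy:.3f}, confidence={confidence:.3f}")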

    def update_q(self, states, actions, rewards, next_states, dones):
        """Update the Q-network with a TD step."""
        # Current Q-values for the taken actions
        q_values = self.q_net(states)
        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
        # Target Q-values (Double DQN)
        with torch.no_grad():
            # Select the best next action with the online network...
            next_q_online = self.q_net(next_states)
            best_actions = next_q_online.argmax(dim=1)
            # ...but evaluate it with the target network
            next_q_target = self.target_q_net(next_states)
            next_q_values = next_q_target.gather(1, best_actions.unsqueeze(1)).squeeze(1)
            target_q_values = rewards + self.gamma * next_q_values * (1 - dones)
        # Huber loss is more robust to outlier targets than MSE
        q_loss = F.smooth_l1_loss(q_values, target_q_values)
        self.q_optimizer.zero_grad()
        q_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), 1.0)
        self.q_optimizer.step()
        return q_loss.item()

    def update_policy(self, states, actions):
        """Update the policy network to match Q-values (actor-critic style)."""
        # The Q-network acts as the critic; it is not updated here
        with torch.no_grad():
            q_values = self.q_net(states)
            # Advantage = Q(s,a) - V(s), with V(s) approximated as the mean over actions
            v_values = q_values.mean(dim=1, keepdim=True)
            advantages = q_values - v_values
        logits = self.policy_net(states)
        log_probs = F.log_softmax(logits, dim=-1)
        # Policy loss: maximize advantage-weighted log-probability of the taken actions
        action_log_probs = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
        action_advantages = advantages.gather(1, actions.unsqueeze(1)).squeeze(1)
        # Entropy bonus keeps the policy from collapsing prematurely
        probs = F.softmax(logits, dim=-1)
        entropy = -(probs * log_probs).sum(dim=-1).mean()
        policy_loss = -(action_log_probs * action_advantages.detach()).mean() - 0.05 * entropy
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.policy_optimizer.step()
        return policy_loss.item()

    def update_target_network(self, tau=0.005):
        """Soft (Polyak) update of the target network."""
        for target_param, param in zip(self.target_q_net.parameters(), self.q_net.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    def decay_exploration(self):
        """Decay epsilon toward its floor."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
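

# Illustrative smoke test (assumed helper, not in the source): runs one Q-update,
# one policy update, and one target sync on random tensors to check that the
# agent's pieces fit together. All sizes here are arbitrary.
def _demo_agent_smoke_test():
    torch.manual_seed(0)
    agent = RLAgent(state_dim=16, num_actions=NUM_ACTIONS)
    states = torch.randn(8, 16)
    actions = torch.randint(0, NUM_ACTIONS, (8,))
    rewards = torch.randn(8)
    next_states = torch.randn(8, 16)
    dones = torch.zeros(8)
    q_loss = agent.update_q(states, actions, rewards, next_states, dones)
    policy_loss = agent.update_policy(states, actions)
    agent.update_target_network(tau=0.005)
    print(f"q_loss={q_loss:.4f}, policy_loss={policy_loss:.4f}")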


def load_hf_dataset():
    """Load the labeled dataset and return shuffled (text, action_index) pairs."""
    dataset = load_dataset(HF_DATASET, split="train")
    data = []
    for item in dataset:
        # Expected chat format: messages[1] is the user turn, messages[2] the label
        user_msg = item["messages"][1]["content"]
        label = item["messages"][2]["content"]
        if label in ACTIONS:
            data.append((user_msg, ACTIONS.index(label)))
    random.shuffle(data)
    return data


def encode_texts(texts, tokenizer, encoder):
    """Batch-encode texts to state representations ([CLS] token embeddings)."""
    inputs = tokenizer(texts, return_tensors="pt", truncation=True, max_length=64, padding=True)
    with torch.no_grad():
        hidden = encoder(**inputs).last_hidden_state[:, 0, :]
    return hidden
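

# Illustrative usage of encode_texts (assumed example, not in the source): with
# distilbert-base-uncased, two input strings yield a (2, 768) state tensor.
def _demo_encode_texts(tokenizer, encoder):
    states = encode_texts(["open a pull request", "schedule a meeting"], tokenizer, encoder)
    print(states.shape)  # expected: torch.Size([2, 768])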


def train_rl_agent(tokenizer, encoder, data, num_epochs=50, batch_size=64):
    """
    Train the RL agent with offline RL on the dataset.

    The dataset serves as demonstration data:
    - States: encoded text messages
    - Actions: correct labels from the dataset (expert demonstrations)
    - Rewards: +1 for the correct action, -1 for an incorrect one
    """
    state_dim = 768  # DistilBERT hidden size
    agent = RLAgent(state_dim, NUM_ACTIONS, lr=3e-4)
    print("Encoding all dataset examples...")
    # Pre-encode all texts for efficiency
    all_texts = [text for text, _ in data]
    all_labels = [label for _, label in data]
    # Encode in batches
    all_states = []
    for i in range(0, len(all_texts), batch_size):
        batch_texts = all_texts[i:i + batch_size]
        batch_states = encode_texts(batch_texts, tokenizer, encoder)
        all_states.append(batch_states)
    all_states = torch.cat(all_states, dim=0)
    all_labels = torch.tensor(all_labels, dtype=torch.long)
    print(f"Encoded {len(all_states)} examples")
    # Print class distribution
    for i, action_name in enumerate(ACTIONS):
        count = (all_labels == i).sum().item()
        print(f"  {action_name}: {count} examples")
    # Create synthetic next states by randomly permuting the encoded pool
    # (transitions are one-step, so next states only feed the bootstrapped target)
    indices = torch.randperm(len(all_states))
    next_states = all_states[indices]
| print("Starting RL training...") | |
| for epoch in range(num_epochs): | |
| # Shuffle data each epoch | |
| perm = torch.randperm(len(all_states)) | |
| states_shuffled = all_states[perm] | |
| labels_shuffled = all_labels[perm] | |
| next_states_shuffled = next_states[perm] | |
| epoch_q_loss = 0 | |
| epoch_policy_loss = 0 | |
| num_batches = 0 | |
| for i in range(0, len(states_shuffled), batch_size): | |
| batch_states = states_shuffled[i:i+batch_size] | |
| batch_labels = labels_shuffled[i:i+batch_size] | |
| batch_next_states = next_states_shuffled[i:i+batch_size] | |
| # Simple rewards: +1 for correct, -1 for wrong | |
| batch_rewards = torch.ones(len(batch_labels), dtype=torch.float32) | |
| batch_dones = torch.zeros(len(batch_labels), dtype=torch.float32) | |
| # Add negative examples (wrong actions with negative reward) | |
| wrong_actions_list = [] | |
| for label in batch_labels: | |
| wrong = (label.item() + random.randint(1, NUM_ACTIONS - 1)) % NUM_ACTIONS | |
| wrong_actions_list.append(wrong) | |
| wrong_actions = torch.tensor(wrong_actions_list, dtype=torch.long) | |
| wrong_rewards = -torch.ones(len(batch_labels), dtype=torch.float32) | |
| # Combine correct and incorrect transitions | |
| combined_states = torch.cat([batch_states, batch_states], dim=0) | |
| combined_actions = torch.cat([batch_labels, wrong_actions], dim=0) | |
| combined_rewards = torch.cat([batch_rewards, wrong_rewards], dim=0) | |
| combined_next_states = torch.cat([batch_next_states, batch_next_states], dim=0) | |
| combined_dones = torch.cat([batch_dones, batch_dones], dim=0) | |
| # Update Q-network | |
| q_loss = agent.update_q( | |
| combined_states, combined_actions, combined_rewards, | |
| combined_next_states, combined_dones | |
| ) | |
| # Update policy (only on correct examples) | |
| policy_loss = agent.update_policy(batch_states, batch_labels) | |
| # Soft update target | |
| agent.update_target_network(tau=0.005) | |
| epoch_q_loss += q_loss | |
| epoch_policy_loss += policy_loss | |
| num_batches += 1 | |
| agent.decay_exploration() | |
        if (epoch + 1) % 10 == 0:
            # Evaluate on the full training set
            with torch.no_grad():
                _, probs = agent.policy_net.get_action(all_states, deterministic=True)
                predictions = probs.argmax(dim=-1)
                accuracy = (predictions == all_labels).float().mean().item() * 100
                # Average policy entropy as a diversity check
                avg_entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=-1).mean().item()
            print(f"Epoch {epoch + 1}/{num_epochs} | "
                  f"Q-Loss: {epoch_q_loss / num_batches:.4f} | "
                  f"Policy-Loss: {epoch_policy_loss / num_batches:.4f} | "
                  f"Accuracy: {accuracy:.1f}% | "
                  f"Entropy: {avg_entropy:.3f} | "
                  f"Epsilon: {agent.epsilon:.3f}")
    # Set networks to eval mode (disables dropout for deterministic inference)
    agent.policy_net.eval()
    agent.q_net.eval()

    # Final evaluation
    print("\nFinal Evaluation:")
    with torch.no_grad():
        _, probs = agent.policy_net.get_action(all_states, deterministic=True)
        predictions = probs.argmax(dim=-1)
        for i, action_name in enumerate(ACTIONS):
            mask = all_labels == i
            if mask.sum() > 0:
                action_acc = (predictions[mask] == i).float().mean().item() * 100
                print(f"  {action_name}: {action_acc:.1f}% ({mask.sum().item()} samples)")
        overall_acc = (predictions == all_labels).float().mean().item() * 100
        print(f"  Overall: {overall_acc:.1f}%")

    # Compute per-class centroid embeddings for outlier detection
    print("\nComputing class centroids...")
    centroids = []
    for i in range(NUM_ACTIONS):
        mask = all_labels == i
        class_states = all_states[mask]
        centroid = class_states.mean(dim=0)
        centroids.append(centroid)
    class_centroids = torch.stack(centroids)
    return agent, class_centroids


def load_model():
    """Load encoder and train RL agent."""
    print("Loading tokenizer and encoder...")
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    encoder = AutoModel.from_pretrained("distilbert-base-uncased")
    encoder.eval()
    print("Loading dataset...")
    data = load_hf_dataset()
    print(f"Dataset size: {len(data)} examples")
    print("Training RL agent...")
    agent, class_centroids = train_rl_agent(tokenizer, encoder, data)
    return tokenizer, encoder, agent, class_centroids


def predict(text, tokenizer, encoder, agent, class_centroids):
    """Use the trained RL agent to predict an action for the given text."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=64)
    with torch.no_grad():
        hidden = encoder(**inputs).last_hidden_state[:, 0, :]
    action_idx, confidence = agent.select_action(hidden, deterministic=True)
    # Cosine similarity to the closest class centroid
    hidden_norm = hidden / hidden.norm(dim=-1, keepdim=True)
    centroids_norm = class_centroids / class_centroids.norm(dim=-1, keepdim=True)
    similarities = torch.mm(hidden_norm, centroids_norm.t()).squeeze(0)
    max_similarity = similarities.max().item()
    # Return NONE if the input looks out-of-distribution OR the policy is unsure
    if max_similarity < DISTANCE_THRESHOLD or confidence < CONFIDENCE_THRESHOLD:
        return "NONE", confidence
    return ACTIONS[action_idx], confidence
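

# Illustrative usage of predict (assumed example, not in the source): once
# load_model() has run, an in-scope message should map to one of ACTIONS while
# an unrelated one should fall below the thresholds and map to "NONE".
def _demo_predict(tokenizer, encoder, agent, class_centroids):
    for msg in ("review my pull request", "what's the weather like?"):
        action_name, score = predict(msg, tokenizer, encoder, agent, class_centroids)
        print(f"{msg!r} -> {action_name} ({score:.3f})")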


# Route path assumed; the source shows the handler without its decorator
@app.get("/health")
def health():
    return {"status": "ok", "model_ready": model_state["ready"]}


# Startup hook (decorator assumed; the source shows the handler without it).
# Training runs in a background thread so the server can answer /health
# immediately while the model loads.
@app.on_event("startup")
async def startup_event():
    def load_in_background():
        tokenizer, encoder, agent, class_centroids = load_model()
        model_state["tokenizer"] = tokenizer
        model_state["encoder"] = encoder
        model_state["agent"] = agent
        model_state["class_centroids"] = class_centroids
        model_state["ready"] = True
        print("RL Agent loaded and ready!")

    thread = threading.Thread(target=load_in_background)
    thread.start()


# Route path assumed; the source shows the handler without its decorator
@app.post("/action")
def action(request: MessageRequest):
    if not model_state["ready"]:
        raise HTTPException(status_code=503, detail="Model is still loading, please wait")
    action_name, score = predict(
        request.message,
        model_state["tokenizer"],
        model_state["encoder"],
        model_state["agent"],
        model_state["class_centroids"],
    )
    return ActionResponse(action=action_name, score=round(score, 4))
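

# Local entry point (assumed, not in the source): serve the API with uvicorn.
# Assumes uvicorn is installed; 7860 is the conventional Hugging Face Spaces port.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)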