import torch import torch.nn as nn import torch.nn.functional as F import json import os import random import numpy as np from collections import deque from fastapi import FastAPI from pydantic import BaseModel from transformers import AutoTokenizer, AutoModel ACTIONS = ["TRIP", "GITHUB", "MAIL"] NUM_ACTIONS = len(ACTIONS) DATASET_PATH = os.path.join(os.path.dirname(__file__), "dataset.jsonl") # Confidence threshold - below this returns NONE CONFIDENCE_THRESHOLD = 0.6 # Distance threshold for outlier detection (cosine similarity) DISTANCE_THRESHOLD = 0.93 app = FastAPI() model_state = { "ready": False, "agent": None, "tokenizer": None, "encoder": None, "class_centroids": None, # Mean embeddings per class } class MessageRequest(BaseModel): message: str class ActionResponse(BaseModel): action: str score: float class PolicyNetwork(nn.Module): """Policy network that outputs action probabilities.""" def __init__(self, state_dim, num_actions, hidden_dim=128): super().__init__() self.net = nn.Sequential( nn.Linear(state_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(), nn.Dropout(0.1), nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(), nn.Dropout(0.1), nn.Linear(hidden_dim, num_actions) ) # Initialize last layer with small weights for balanced initial policy nn.init.xavier_uniform_(self.net[-1].weight, gain=0.01) nn.init.zeros_(self.net[-1].bias) def forward(self, state): return self.net(state) def get_action_probs(self, state): logits = self.forward(state) return F.softmax(logits, dim=-1) def get_action(self, state, deterministic=False, temperature=1.0): logits = self.forward(state) # Apply temperature for exploration control scaled_logits = logits / temperature probs = F.softmax(scaled_logits, dim=-1) if deterministic: action = torch.argmax(probs, dim=-1) else: dist = torch.distributions.Categorical(probs) action = dist.sample() return action, probs class QNetwork(nn.Module): """Q-Network for action-value estimation.""" def __init__(self, state_dim, num_actions, hidden_dim=128): super().__init__() self.net = nn.Sequential( nn.Linear(state_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, num_actions) ) def forward(self, state): return self.net(state) class RLAgent: """ RL Agent using Double DQN with proper exploration. """ def __init__(self, state_dim, num_actions, lr=1e-3, gamma=0.95): self.state_dim = state_dim self.num_actions = num_actions self.gamma = gamma # Q-Networks (Double DQN) self.q_net = QNetwork(state_dim, num_actions) self.target_q_net = QNetwork(state_dim, num_actions) self.target_q_net.load_state_dict(self.q_net.state_dict()) # Policy network self.policy_net = PolicyNetwork(state_dim, num_actions) self.q_optimizer = torch.optim.AdamW(self.q_net.parameters(), lr=lr, weight_decay=1e-4) self.policy_optimizer = torch.optim.AdamW(self.policy_net.parameters(), lr=lr, weight_decay=1e-4) # Exploration parameters self.epsilon = 1.0 self.epsilon_min = 0.05 self.epsilon_decay = 0.995 self.temperature = 1.0 def select_action(self, state, deterministic=True): """Select action given state.""" with torch.no_grad(): if deterministic: # Use policy network for inference action, probs = self.policy_net.get_action(state, deterministic=True) action_idx = action.item() # Use entropy-based confidence: high entropy = low confidence entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=-1).item() max_entropy = np.log(self.num_actions) # Maximum possible entropy # Confidence based on how certain the distribution is # Low entropy = high confidence, high entropy = low confidence confidence = 1.0 - (entropy / max_entropy) # Also factor in the raw probability raw_prob = probs[0, action_idx].item() confidence = confidence * raw_prob else: # Epsilon-greedy for training if random.random() < self.epsilon: action_idx = random.randint(0, self.num_actions - 1) confidence = 1.0 / self.num_actions else: action, probs = self.policy_net.get_action(state, deterministic=False, temperature=self.temperature) action_idx = action.item() confidence = probs[0, action_idx].item() return action_idx, confidence def update_q(self, states, actions, rewards, next_states, dones): """Update Q-network using TD learning.""" # Current Q values q_values = self.q_net(states) q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1) # Target Q values (Double DQN) with torch.no_grad(): # Select best action using online network next_q_online = self.q_net(next_states) best_actions = next_q_online.argmax(dim=1) # Evaluate using target network next_q_target = self.target_q_net(next_states) next_q_values = next_q_target.gather(1, best_actions.unsqueeze(1)).squeeze(1) target_q_values = rewards + self.gamma * next_q_values * (1 - dones) # Q-network loss q_loss = F.smooth_l1_loss(q_values, target_q_values) self.q_optimizer.zero_grad() q_loss.backward() torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), 1.0) self.q_optimizer.step() return q_loss.item() def update_policy(self, states, actions): """Update policy network to match Q-values (actor-critic style).""" # Get Q-values for actions with torch.no_grad(): q_values = self.q_net(states) # Advantage = Q(s,a) - V(s), where V(s) = E[Q(s,a)] v_values = q_values.mean(dim=1, keepdim=True) advantages = q_values - v_values # Policy logits logits = self.policy_net(states) log_probs = F.log_softmax(logits, dim=-1) # Policy loss: maximize advantage-weighted log probability action_log_probs = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1) action_advantages = advantages.gather(1, actions.unsqueeze(1)).squeeze(1) # Add entropy bonus for exploration probs = F.softmax(logits, dim=-1) entropy = -(probs * log_probs).sum(dim=-1).mean() policy_loss = -(action_log_probs * action_advantages.detach()).mean() - 0.05 * entropy self.policy_optimizer.zero_grad() policy_loss.backward() torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0) self.policy_optimizer.step() return policy_loss.item() def update_target_network(self, tau=0.005): """Soft update target network.""" for target_param, param in zip(self.target_q_net.parameters(), self.q_net.parameters()): target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) def decay_exploration(self): """Decay exploration parameters.""" self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay) def load_dataset(): """Load and parse the dataset.""" data = [] with open(DATASET_PATH, "r") as f: for line in f: item = json.loads(line) user_msg = item["messages"][1]["content"] label = item["messages"][2]["content"] if label in ACTIONS: data.append((user_msg, ACTIONS.index(label))) random.shuffle(data) return data def encode_texts(texts, tokenizer, encoder): """Batch encode texts to state representations.""" inputs = tokenizer(texts, return_tensors="pt", truncation=True, max_length=64, padding=True) with torch.no_grad(): hidden = encoder(**inputs).last_hidden_state[:, 0, :] return hidden def train_rl_agent(tokenizer, encoder, data, num_epochs=50, batch_size=64): """ Train RL agent using offline RL on dataset. Uses the dataset as demonstration data: - States: encoded text messages - Actions: correct labels from dataset (expert demonstrations) - Rewards: +1 for correct, -1 for incorrect """ state_dim = 768 # DistilBERT hidden size agent = RLAgent(state_dim, NUM_ACTIONS, lr=3e-4) print("Encoding all dataset examples...") # Pre-encode all texts for efficiency all_texts = [text for text, _ in data] all_labels = [label for _, label in data] # Encode in batches all_states = [] for i in range(0, len(all_texts), batch_size): batch_texts = all_texts[i:i+batch_size] batch_states = encode_texts(batch_texts, tokenizer, encoder) all_states.append(batch_states) all_states = torch.cat(all_states, dim=0) all_labels = torch.tensor(all_labels, dtype=torch.long) print(f"Encoded {len(all_states)} examples") # Print class distribution for i, action_name in enumerate(ACTIONS): count = (all_labels == i).sum().item() print(f" {action_name}: {count} examples") # Create next states (shifted by 1, with wraparound) indices = torch.randperm(len(all_states)) next_states = all_states[indices] print("Starting RL training...") for epoch in range(num_epochs): # Shuffle data each epoch perm = torch.randperm(len(all_states)) states_shuffled = all_states[perm] labels_shuffled = all_labels[perm] next_states_shuffled = next_states[perm] epoch_q_loss = 0 epoch_policy_loss = 0 num_batches = 0 for i in range(0, len(states_shuffled), batch_size): batch_states = states_shuffled[i:i+batch_size] batch_labels = labels_shuffled[i:i+batch_size] batch_next_states = next_states_shuffled[i:i+batch_size] # Simple rewards: +1 for correct, -1 for wrong batch_rewards = torch.ones(len(batch_labels), dtype=torch.float32) batch_dones = torch.zeros(len(batch_labels), dtype=torch.float32) # Add negative examples (wrong actions with negative reward) wrong_actions_list = [] for label in batch_labels: wrong = (label.item() + random.randint(1, NUM_ACTIONS - 1)) % NUM_ACTIONS wrong_actions_list.append(wrong) wrong_actions = torch.tensor(wrong_actions_list, dtype=torch.long) wrong_rewards = -torch.ones(len(batch_labels), dtype=torch.float32) # Combine correct and incorrect transitions combined_states = torch.cat([batch_states, batch_states], dim=0) combined_actions = torch.cat([batch_labels, wrong_actions], dim=0) combined_rewards = torch.cat([batch_rewards, wrong_rewards], dim=0) combined_next_states = torch.cat([batch_next_states, batch_next_states], dim=0) combined_dones = torch.cat([batch_dones, batch_dones], dim=0) # Update Q-network q_loss = agent.update_q( combined_states, combined_actions, combined_rewards, combined_next_states, combined_dones ) # Update policy (only on correct examples) policy_loss = agent.update_policy(batch_states, batch_labels) # Soft update target agent.update_target_network(tau=0.005) epoch_q_loss += q_loss epoch_policy_loss += policy_loss num_batches += 1 agent.decay_exploration() if (epoch + 1) % 10 == 0: # Evaluate with torch.no_grad(): _, probs = agent.policy_net.get_action(all_states, deterministic=True) predictions = probs.argmax(dim=-1) accuracy = (predictions == all_labels).float().mean().item() * 100 # Check policy entropy (diversity) avg_entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=-1).mean().item() print(f"Epoch {epoch + 1}/{num_epochs} | " f"Q-Loss: {epoch_q_loss/num_batches:.4f} | " f"Policy-Loss: {epoch_policy_loss/num_batches:.4f} | " f"Accuracy: {accuracy:.1f}% | " f"Entropy: {avg_entropy:.3f} | " f"Epsilon: {agent.epsilon:.3f}") # Set networks to eval mode (disables dropout for deterministic inference) agent.policy_net.eval() agent.q_net.eval() # Final evaluation print("\nFinal Evaluation:") with torch.no_grad(): _, probs = agent.policy_net.get_action(all_states, deterministic=True) predictions = probs.argmax(dim=-1) for i, action_name in enumerate(ACTIONS): mask = all_labels == i if mask.sum() > 0: action_acc = (predictions[mask] == i).float().mean().item() * 100 print(f" {action_name}: {action_acc:.1f}% ({mask.sum().item()} samples)") overall_acc = (predictions == all_labels).float().mean().item() * 100 print(f" Overall: {overall_acc:.1f}%") # Compute class centroids for outlier detection print("\nComputing class centroids...") centroids = [] for i in range(NUM_ACTIONS): mask = all_labels == i class_states = all_states[mask] centroid = class_states.mean(dim=0) centroids.append(centroid) class_centroids = torch.stack(centroids) return agent, class_centroids def load_model(): """Load encoder and train RL agent.""" print("Loading tokenizer and encoder...") tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased") encoder = AutoModel.from_pretrained("distilbert-base-uncased") encoder.eval() print("Loading dataset...") data = load_dataset() print(f"Dataset size: {len(data)} examples") print("Training RL agent...") agent, class_centroids = train_rl_agent(tokenizer, encoder, data) return tokenizer, encoder, agent, class_centroids def predict(text, tokenizer, encoder, agent, class_centroids): """Use trained RL agent to predict action for given text.""" inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=64) with torch.no_grad(): hidden = encoder(**inputs).last_hidden_state[:, 0, :] action_idx, confidence = agent.select_action(hidden, deterministic=True) # Compute cosine similarity to closest class centroid hidden_norm = hidden / hidden.norm(dim=-1, keepdim=True) centroids_norm = class_centroids / class_centroids.norm(dim=-1, keepdim=True) similarities = torch.mm(hidden_norm, centroids_norm.t()).squeeze(0) max_similarity = similarities.max().item() # Return NONE if similarity is too low OR confidence is too low if max_similarity < DISTANCE_THRESHOLD or confidence < CONFIDENCE_THRESHOLD: return "NONE", confidence return ACTIONS[action_idx], confidence @app.get("/health") def health(): return {"status": "ok", "model_ready": model_state["ready"]} @app.on_event("startup") async def startup_event(): import threading def load_in_background(): tokenizer, encoder, agent, class_centroids = load_model() model_state["tokenizer"] = tokenizer model_state["encoder"] = encoder model_state["agent"] = agent model_state["class_centroids"] = class_centroids model_state["ready"] = True print("RL Agent loaded and ready!") thread = threading.Thread(target=load_in_background) thread.start() @app.post("/action", response_model=ActionResponse) def action(request: MessageRequest): if not model_state["ready"]: from fastapi import HTTPException raise HTTPException(status_code=503, detail="Model is still loading, please wait") action_name, score = predict( request.message, model_state["tokenizer"], model_state["encoder"], model_state["agent"], model_state["class_centroids"] ) return ActionResponse(action=action_name, score=round(score, 4))