# code-19 / app.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import os
import random
import numpy as np
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
ACTIONS = ["TRIP", "GITHUB", "MAIL"]
NUM_ACTIONS = len(ACTIONS)
DATASET_PATH = os.path.join(os.path.dirname(__file__), "dataset.jsonl")
# Confidence threshold - below this returns NONE
CONFIDENCE_THRESHOLD = 0.6
# Cosine-similarity threshold for outlier detection - below this returns NONE
DISTANCE_THRESHOLD = 0.93
app = FastAPI()
model_state = {
"ready": False,
"agent": None,
"tokenizer": None,
"encoder": None,
"class_centroids": None, # Mean embeddings per class
}
class MessageRequest(BaseModel):
message: str
class ActionResponse(BaseModel):
action: str
score: float
class PolicyNetwork(nn.Module):
"""Policy network that outputs action probabilities."""
def __init__(self, state_dim, num_actions, hidden_dim=128):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, num_actions)
)
# Initialize last layer with small weights for balanced initial policy
nn.init.xavier_uniform_(self.net[-1].weight, gain=0.01)
nn.init.zeros_(self.net[-1].bias)
def forward(self, state):
return self.net(state)
def get_action_probs(self, state):
logits = self.forward(state)
return F.softmax(logits, dim=-1)
def get_action(self, state, deterministic=False, temperature=1.0):
logits = self.forward(state)
# Apply temperature for exploration control
scaled_logits = logits / temperature
probs = F.softmax(scaled_logits, dim=-1)
if deterministic:
action = torch.argmax(probs, dim=-1)
else:
dist = torch.distributions.Categorical(probs)
action = dist.sample()
return action, probs
class QNetwork(nn.Module):
"""Q-Network for action-value estimation."""
def __init__(self, state_dim, num_actions, hidden_dim=128):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, num_actions)
)
def forward(self, state):
return self.net(state)
class RLAgent:
"""
RL Agent using Double DQN with proper exploration.
"""
def __init__(self, state_dim, num_actions, lr=1e-3, gamma=0.95):
self.state_dim = state_dim
self.num_actions = num_actions
self.gamma = gamma
# Q-Networks (Double DQN)
self.q_net = QNetwork(state_dim, num_actions)
self.target_q_net = QNetwork(state_dim, num_actions)
self.target_q_net.load_state_dict(self.q_net.state_dict())
# Policy network
self.policy_net = PolicyNetwork(state_dim, num_actions)
self.q_optimizer = torch.optim.AdamW(self.q_net.parameters(), lr=lr, weight_decay=1e-4)
self.policy_optimizer = torch.optim.AdamW(self.policy_net.parameters(), lr=lr, weight_decay=1e-4)
# Exploration parameters
self.epsilon = 1.0
self.epsilon_min = 0.05
self.epsilon_decay = 0.995
self.temperature = 1.0
def select_action(self, state, deterministic=True):
"""Select action given state."""
with torch.no_grad():
if deterministic:
# Use policy network for inference
action, probs = self.policy_net.get_action(state, deterministic=True)
action_idx = action.item()
# Use entropy-based confidence: high entropy = low confidence
entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=-1).item()
max_entropy = np.log(self.num_actions) # Maximum possible entropy
# Confidence based on how certain the distribution is
# Low entropy = high confidence, high entropy = low confidence
confidence = 1.0 - (entropy / max_entropy)
# Also factor in the raw probability
raw_prob = probs[0, action_idx].item()
confidence = confidence * raw_prob
else:
# Epsilon-greedy for training
if random.random() < self.epsilon:
action_idx = random.randint(0, self.num_actions - 1)
confidence = 1.0 / self.num_actions
else:
action, probs = self.policy_net.get_action(state, deterministic=False, temperature=self.temperature)
action_idx = action.item()
confidence = probs[0, action_idx].item()
return action_idx, confidence
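# Worked example of the entropy-based confidence above (illustrative numbers):
# for probs = [0.8, 0.1, 0.1] over the 3 actions,
#   entropy     = -(0.8*ln(0.8) + 2*0.1*ln(0.1)) ~= 0.639
#   max_entropy = ln(3)                          ~= 1.099
#   confidence  = (1 - 0.639/1.099) * 0.8        ~= 0.33
# Only sharply peaked distributions (e.g. [0.95, 0.025, 0.025] -> ~0.75)
# clear the 0.6 CONFIDENCE_THRESHOLD applied in predict().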
def update_q(self, states, actions, rewards, next_states, dones):
"""Update Q-network using TD learning."""
# Current Q values
q_values = self.q_net(states)
q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Target Q values (Double DQN)
with torch.no_grad():
# Select best action using online network
next_q_online = self.q_net(next_states)
best_actions = next_q_online.argmax(dim=1)
# Evaluate using target network
next_q_target = self.target_q_net(next_states)
next_q_values = next_q_target.gather(1, best_actions.unsqueeze(1)).squeeze(1)
target_q_values = rewards + self.gamma * next_q_values * (1 - dones)
# Q-network loss
q_loss = F.smooth_l1_loss(q_values, target_q_values)
self.q_optimizer.zero_grad()
q_loss.backward()
torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), 1.0)
self.q_optimizer.step()
return q_loss.item()
def update_policy(self, states, actions):
"""Update policy network to match Q-values (actor-critic style)."""
# Get Q-values for actions
with torch.no_grad():
q_values = self.q_net(states)
# Advantage = Q(s,a) - V(s), where V(s) = E[Q(s,a)]
v_values = q_values.mean(dim=1, keepdim=True)
advantages = q_values - v_values
# Policy logits
logits = self.policy_net(states)
log_probs = F.log_softmax(logits, dim=-1)
# Policy loss: maximize advantage-weighted log probability
action_log_probs = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
action_advantages = advantages.gather(1, actions.unsqueeze(1)).squeeze(1)
# Add entropy bonus for exploration
probs = F.softmax(logits, dim=-1)
entropy = -(probs * log_probs).sum(dim=-1).mean()
policy_loss = -(action_log_probs * action_advantages.detach()).mean() - 0.05 * entropy
self.policy_optimizer.zero_grad()
policy_loss.backward()
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
self.policy_optimizer.step()
return policy_loss.item()
def update_target_network(self, tau=0.005):
"""Soft update target network."""
for target_param, param in zip(self.target_q_net.parameters(), self.q_net.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
def decay_exploration(self):
"""Decay exploration parameters."""
self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
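# Illustrative sketch (not called anywhere in the app): one offline update step
# with random tensors, showing the shapes RLAgent expects. The batch size of 4
# and the helper name _example_offline_step are assumptions made for this
# example only; train_rl_agent() below is the real training loop.
def _example_offline_step():
    agent = RLAgent(state_dim=768, num_actions=NUM_ACTIONS)
    states = torch.randn(4, 768)                   # batch of encoded messages
    actions = torch.randint(0, NUM_ACTIONS, (4,))  # "expert" action labels
    rewards = torch.ones(4)                        # +1 for correct demonstrations
    next_states = torch.randn(4, 768)              # placeholder next states
    dones = torch.zeros(4)                         # single-step, non-terminal
    q_loss = agent.update_q(states, actions, rewards, next_states, dones)
    policy_loss = agent.update_policy(states, actions)
    agent.update_target_network(tau=0.005)
    return q_loss, policy_loss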
def load_dataset():
"""Load and parse the dataset."""
data = []
with open(DATASET_PATH, "r") as f:
for line in f:
item = json.loads(line)
user_msg = item["messages"][1]["content"]
label = item["messages"][2]["content"]
if label in ACTIONS:
data.append((user_msg, ACTIONS.index(label)))
random.shuffle(data)
return data
def encode_texts(texts, tokenizer, encoder):
"""Batch encode texts to state representations."""
inputs = tokenizer(texts, return_tensors="pt", truncation=True, max_length=64, padding=True)
with torch.no_grad():
hidden = encoder(**inputs).last_hidden_state[:, 0, :]
return hidden
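# Shape note (illustrative strings): encode_texts(["book a flight", "open a PR"],
# tokenizer, encoder) returns a (2, 768) tensor - one DistilBERT [CLS] embedding
# per message.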
def train_rl_agent(tokenizer, encoder, data, num_epochs=50, batch_size=64):
"""
Train RL agent using offline RL on dataset.
Uses the dataset as demonstration data:
- States: encoded text messages
- Actions: correct labels from dataset (expert demonstrations)
- Rewards: +1 for correct, -1 for incorrect
"""
state_dim = 768 # DistilBERT hidden size
agent = RLAgent(state_dim, NUM_ACTIONS, lr=3e-4)
print("Encoding all dataset examples...")
# Pre-encode all texts for efficiency
all_texts = [text for text, _ in data]
all_labels = [label for _, label in data]
# Encode in batches
all_states = []
for i in range(0, len(all_texts), batch_size):
batch_texts = all_texts[i:i+batch_size]
batch_states = encode_texts(batch_texts, tokenizer, encoder)
all_states.append(batch_states)
all_states = torch.cat(all_states, dim=0)
all_labels = torch.tensor(all_labels, dtype=torch.long)
print(f"Encoded {len(all_states)} examples")
# Print class distribution
for i, action_name in enumerate(ACTIONS):
count = (all_labels == i).sum().item()
print(f" {action_name}: {count} examples")
# Create placeholder next states by randomly permuting the encoded states (each transition is effectively single-step)
indices = torch.randperm(len(all_states))
next_states = all_states[indices]
print("Starting RL training...")
for epoch in range(num_epochs):
# Shuffle data each epoch
perm = torch.randperm(len(all_states))
states_shuffled = all_states[perm]
labels_shuffled = all_labels[perm]
next_states_shuffled = next_states[perm]
epoch_q_loss = 0
epoch_policy_loss = 0
num_batches = 0
for i in range(0, len(states_shuffled), batch_size):
batch_states = states_shuffled[i:i+batch_size]
batch_labels = labels_shuffled[i:i+batch_size]
batch_next_states = next_states_shuffled[i:i+batch_size]
# Simple rewards: +1 for correct, -1 for wrong
batch_rewards = torch.ones(len(batch_labels), dtype=torch.float32)
batch_dones = torch.zeros(len(batch_labels), dtype=torch.float32)
# Add negative examples (wrong actions with negative reward)
wrong_actions_list = []
for label in batch_labels:
wrong = (label.item() + random.randint(1, NUM_ACTIONS - 1)) % NUM_ACTIONS
wrong_actions_list.append(wrong)
wrong_actions = torch.tensor(wrong_actions_list, dtype=torch.long)
wrong_rewards = -torch.ones(len(batch_labels), dtype=torch.float32)
# Combine correct and incorrect transitions
combined_states = torch.cat([batch_states, batch_states], dim=0)
combined_actions = torch.cat([batch_labels, wrong_actions], dim=0)
combined_rewards = torch.cat([batch_rewards, wrong_rewards], dim=0)
combined_next_states = torch.cat([batch_next_states, batch_next_states], dim=0)
combined_dones = torch.cat([batch_dones, batch_dones], dim=0)
# Update Q-network
q_loss = agent.update_q(
combined_states, combined_actions, combined_rewards,
combined_next_states, combined_dones
)
# Update policy (only on correct examples)
policy_loss = agent.update_policy(batch_states, batch_labels)
# Soft update target
agent.update_target_network(tau=0.005)
epoch_q_loss += q_loss
epoch_policy_loss += policy_loss
num_batches += 1
agent.decay_exploration()
if (epoch + 1) % 10 == 0:
# Evaluate
with torch.no_grad():
_, probs = agent.policy_net.get_action(all_states, deterministic=True)
predictions = probs.argmax(dim=-1)
accuracy = (predictions == all_labels).float().mean().item() * 100
# Check policy entropy (diversity)
avg_entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=-1).mean().item()
print(f"Epoch {epoch + 1}/{num_epochs} | "
f"Q-Loss: {epoch_q_loss/num_batches:.4f} | "
f"Policy-Loss: {epoch_policy_loss/num_batches:.4f} | "
f"Accuracy: {accuracy:.1f}% | "
f"Entropy: {avg_entropy:.3f} | "
f"Epsilon: {agent.epsilon:.3f}")
# Set networks to eval mode (disables dropout for deterministic inference)
agent.policy_net.eval()
agent.q_net.eval()
# Final evaluation
print("\nFinal Evaluation:")
with torch.no_grad():
_, probs = agent.policy_net.get_action(all_states, deterministic=True)
predictions = probs.argmax(dim=-1)
for i, action_name in enumerate(ACTIONS):
mask = all_labels == i
if mask.sum() > 0:
action_acc = (predictions[mask] == i).float().mean().item() * 100
print(f" {action_name}: {action_acc:.1f}% ({mask.sum().item()} samples)")
overall_acc = (predictions == all_labels).float().mean().item() * 100
print(f" Overall: {overall_acc:.1f}%")
# Compute class centroids for outlier detection
print("\nComputing class centroids...")
centroids = []
for i in range(NUM_ACTIONS):
mask = all_labels == i
class_states = all_states[mask]
centroid = class_states.mean(dim=0)
centroids.append(centroid)
class_centroids = torch.stack(centroids)
return agent, class_centroids
def load_model():
"""Load encoder and train RL agent."""
print("Loading tokenizer and encoder...")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
encoder = AutoModel.from_pretrained("distilbert-base-uncased")
encoder.eval()
print("Loading dataset...")
data = load_dataset()
print(f"Dataset size: {len(data)} examples")
print("Training RL agent...")
agent, class_centroids = train_rl_agent(tokenizer, encoder, data)
return tokenizer, encoder, agent, class_centroids
def predict(text, tokenizer, encoder, agent, class_centroids):
"""Use trained RL agent to predict action for given text."""
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=64)
with torch.no_grad():
hidden = encoder(**inputs).last_hidden_state[:, 0, :]
action_idx, confidence = agent.select_action(hidden, deterministic=True)
# Compute cosine similarity to closest class centroid
hidden_norm = hidden / hidden.norm(dim=-1, keepdim=True)
centroids_norm = class_centroids / class_centroids.norm(dim=-1, keepdim=True)
similarities = torch.mm(hidden_norm, centroids_norm.t()).squeeze(0)
max_similarity = similarities.max().item()
# Return NONE if similarity is too low OR confidence is too low
if max_similarity < DISTANCE_THRESHOLD or confidence < CONFIDENCE_THRESHOLD:
return "NONE", confidence
return ACTIONS[action_idx], confidence
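# Gating example (illustrative numbers): with confidence 0.72 and a best centroid
# similarity of 0.91, predict() returns ("NONE", 0.72) because 0.91 is below
# DISTANCE_THRESHOLD (0.93); both gates must pass for an action to be returned.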
@app.get("/health")
def health():
return {"status": "ok", "model_ready": model_state["ready"]}
@app.on_event("startup")
async def startup_event():
import threading
def load_in_background():
tokenizer, encoder, agent, class_centroids = load_model()
model_state["tokenizer"] = tokenizer
model_state["encoder"] = encoder
model_state["agent"] = agent
model_state["class_centroids"] = class_centroids
model_state["ready"] = True
print("RL Agent loaded and ready!")
thread = threading.Thread(target=load_in_background)
thread.start()
@app.post("/action", response_model=ActionResponse)
def action(request: MessageRequest):
if not model_state["ready"]:
from fastapi import HTTPException
raise HTTPException(status_code=503, detail="Model is still loading, please wait")
action_name, score = predict(
request.message,
model_state["tokenizer"],
model_state["encoder"],
model_state["agent"],
model_state["class_centroids"]
)
return ActionResponse(action=action_name, score=round(score, 4))
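# Example client call once the model is ready (illustrative; the URL assumes the
# default Hugging Face Spaces port 7860 and the `requests` package, neither of
# which is part of this app):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/action",
#       json={"message": "book me a flight to Berlin"},
#   )
#   print(resp.json())  # e.g. {"action": "TRIP", "score": 0.7123} (score made up)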