import json
import os
import threading

import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel

# Simple RL classifier using a transformer: a frozen DistilBERT encoder
# produces sentence embeddings, and a small linear policy head is trained
# with a simplified policy-gradient (REINFORCE-style) objective.
ACTIONS = ["TRIP", "NONE", "GITHUB", "MAIL"]
DATASET_PATH = os.path.join(os.path.dirname(__file__), "dataset.jsonl")

app = FastAPI()

# Global model state - loaded lazily in a background thread at startup.
model_state = {"ready": False, "tokenizer": None, "encoder": None, "policy_head": None}


class MessageRequest(BaseModel):
    message: str


class ActionResponse(BaseModel):
    action: str
    score: float


@app.get("/health")
def health():
    return {"status": "ok", "model_ready": model_state["ready"]}


def load_model():
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    encoder = AutoModel.from_pretrained("distilbert-base-uncased")
    # Simple policy head: maps the 768-dim DistilBERT [CLS] embedding to action logits.
    policy_head = nn.Linear(768, len(ACTIONS))

    # Load the dataset for training. Each JSONL record is chat-shaped:
    # messages[1] holds the user message, messages[2] the gold action label.
    data = []
    with open(DATASET_PATH, "r") as f:
        for line in f:
            item = json.loads(line)
            user_msg = item["messages"][1]["content"]
            label = item["messages"][2]["content"]
            data.append((user_msg, ACTIONS.index(label)))

    # Quick RL-style training (policy gradient, simplified). The encoder stays
    # frozen; only the policy head receives gradients.
    optimizer = torch.optim.Adam(policy_head.parameters(), lr=1e-3)
    encoder.eval()
    for epoch in range(3):
        total_reward = 0.0
        for text, label in data[:100]:  # use a subset for speed
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=64)
            with torch.no_grad():
                hidden = encoder(**inputs).last_hidden_state[:, 0, :]
            logits = policy_head(hidden)
            probs = torch.softmax(logits, dim=-1)

            # Sample an action from the policy (RL style).
            action = torch.multinomial(probs, 1).item()

            # Reward: +1 if correct, -1 if wrong.
            reward = 1.0 if action == label else -1.0
            total_reward += reward

            # Policy gradient update: minimize -log pi(action) * reward.
            log_prob = torch.log(probs[0, action])
            loss = -log_prob * reward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"epoch {epoch}: total reward {total_reward}")
    return tokenizer, encoder, policy_head


def predict(text, tokenizer, encoder, policy_head):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=64)
    with torch.no_grad():
        hidden = encoder(**inputs).last_hidden_state[:, 0, :]
        logits = policy_head(hidden)
        probs = torch.softmax(logits, dim=-1)
    # At inference time act greedily: take the highest-probability action.
    action_idx = torch.argmax(probs, dim=-1).item()
    score = probs[0, action_idx].item()
    return ACTIONS[action_idx], score


@app.on_event("startup")
async def startup_event():
    def load_in_background():
        tokenizer, encoder, policy_head = load_model()
        model_state["tokenizer"] = tokenizer
        model_state["encoder"] = encoder
        model_state["policy_head"] = policy_head
        model_state["ready"] = True
        print("Model loaded and ready!")

    # Load the model in a background thread so the server can respond immediately.
    thread = threading.Thread(target=load_in_background, daemon=True)
    thread.start()


@app.post("/action", response_model=ActionResponse)
def action(request: MessageRequest):
    if not model_state["ready"]:
        raise HTTPException(status_code=503, detail="Model is still loading, please wait")
    action_name, score = predict(
        request.message,
        model_state["tokenizer"],
        model_state["encoder"],
        model_state["policy_head"],
    )
    return ActionResponse(action=action_name, score=round(score, 4))
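

# ---------------------------------------------------------------------------
# Usage sketch (assumptions, not part of the service above). The dataset.jsonl
# line below is hypothetical example content inferred from the indexing in
# load_model(); only the messages[1]/messages[2] positions are implied by the
# code. "uvicorn" is assumed as the ASGI server, and "app.py" as the filename.
#
# Example dataset.jsonl line:
#   {"messages": [{"role": "system", "content": "..."},
#                 {"role": "user", "content": "book me a flight to Paris"},
#                 {"role": "assistant", "content": "TRIP"}]}
#
# Run and query the service:
#   uvicorn app:app --port 8000
#   curl -X POST localhost:8000/action \
#        -H "Content-Type: application/json" \
#        -d '{"message": "book me a flight to Paris"}'

if __name__ == "__main__":
    # Minimal alternative to the uvicorn CLI; assumes uvicorn is installed.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)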