# code-19 / app.py
# Hugging Face Spaces viewer metadata (kept as comments so the file is valid
# Python): author PiotrPasztor, commit "fix" (176390f), raw / history blame,
# file size 3.74 kB.
import json
import os

import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
# Simple RL Classifier using Transformer
# Closed set of action labels; list order defines the class indices used by
# the policy head, so do not reorder without retraining.
ACTIONS = ["TRIP", "NONE", "GITHUB", "MAIL"]
# Training data is expected next to this file, one JSON object per line.
DATASET_PATH = os.path.join(os.path.dirname(__file__), "dataset.jsonl")
app = FastAPI()
# Global model state - loaded lazily
# Mutated by the background loader thread started at app startup; "ready" is
# flipped to True only after the other three slots are populated.
model_state = {"ready": False, "tokenizer": None, "encoder": None, "policy_head": None}
class MessageRequest(BaseModel):
    """Request body for POST /action: the raw user message to classify."""
    message: str
class ActionResponse(BaseModel):
    """Response for POST /action: the chosen action name and its softmax
    probability (the model's confidence in that action)."""
    action: str
    score: float
@app.get("/health")
def health():
    """Liveness probe; also reports whether the model has finished loading."""
    ready = model_state["ready"]
    return {"status": "ok", "model_ready": ready}
def load_model():
    """Load DistilBERT, train a small policy head on dataset.jsonl, and
    return (tokenizer, encoder, policy_head).

    Training is a simplified REINFORCE loop: sample an action from the
    softmax policy, reward +1 for correct / -1 for incorrect, and ascend the
    reward-weighted log-probability. The encoder is frozen; only the linear
    head is trained.
    """
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    encoder = AutoModel.from_pretrained("distilbert-base-uncased")
    # Linear policy head over the first-token embedding (768-dim hidden size).
    policy_head = nn.Linear(768, len(ACTIONS))

    # Load (message, label-index) pairs from the JSONL dataset.
    # Expected record shape: {"messages": [system, user, assistant]} where the
    # assistant content is one of ACTIONS — TODO confirm against dataset.jsonl.
    data = []
    with open(DATASET_PATH, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines in the JSONL file
            item = json.loads(line)
            user_msg = item["messages"][1]["content"]
            label = item["messages"][2]["content"]
            if label not in ACTIONS:
                continue  # skip unknown labels instead of crashing on .index()
            data.append((user_msg, ACTIONS.index(label)))

    # Quick RL-style training (simplified policy gradient / REINFORCE).
    optimizer = torch.optim.Adam(policy_head.parameters(), lr=1e-3)
    encoder.eval()  # encoder stays frozen; gradients flow only into the head
    for _ in range(3):  # a few epochs are enough for this small linear head
        for text, label in data[:100]:  # use subset for speed
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=64)
            with torch.no_grad():
                hidden = encoder(**inputs).last_hidden_state[:, 0, :]  # first token
            logits = policy_head(hidden)
            probs = torch.softmax(logits, dim=-1)
            # Sample an action from the current policy (RL-style exploration).
            action = torch.multinomial(probs, 1).item()
            # Reward: +1 if correct, -1 if wrong.
            reward = 1.0 if action == label else -1.0
            # Policy-gradient update: maximize reward-weighted log-probability.
            log_prob = torch.log(probs[0, action])
            loss = -log_prob * reward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return tokenizer, encoder, policy_head
def predict(text, tokenizer, encoder, policy_head):
    """Classify *text* into one of ACTIONS.

    Returns (action_name, score) where score is the softmax probability of
    the chosen action. Inference is greedy (argmax), unlike the sampling
    used during training.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=64)
    # No gradients are needed at inference time; keeping the policy-head
    # forward inside no_grad avoids building an unnecessary autograd graph.
    with torch.no_grad():
        hidden = encoder(**inputs).last_hidden_state[:, 0, :]  # first token
        logits = policy_head(hidden)
        probs = torch.softmax(logits, dim=-1)
    action_idx = torch.argmax(probs, dim=-1).item()
    score = probs[0, action_idx].item()
    return ACTIONS[action_idx], score
@app.on_event("startup")
async def startup_event():
    """Kick off model loading in a background thread at server startup.

    The server starts answering requests immediately; /action returns 503
    until the loader flips model_state["ready"]. NOTE(review): @app.on_event
    is deprecated in recent FastAPI in favor of lifespan handlers — consider
    migrating.
    """
    import threading

    def load_in_background():
        tokenizer, encoder, policy_head = load_model()
        model_state["tokenizer"] = tokenizer
        model_state["encoder"] = encoder
        model_state["policy_head"] = policy_head
        # Flip "ready" last so /action never sees a half-populated state.
        model_state["ready"] = True
        print("Model loaded and ready!")

    # daemon=True: a hung model download must not block interpreter shutdown.
    thread = threading.Thread(target=load_in_background, daemon=True)
    thread.start()
@app.post("/action", response_model=ActionResponse)
def action(request: MessageRequest):
    """Classify the request message and return the chosen action + confidence.

    Raises:
        HTTPException(503): while the model is still loading in the background.
    """
    if not model_state["ready"]:
        # HTTPException comes from the module-level fastapi import; no need
        # for a function-local import on every failed request.
        raise HTTPException(status_code=503, detail="Model is still loading, please wait")
    action_name, score = predict(
        request.message,
        model_state["tokenizer"],
        model_state["encoder"],
        model_state["policy_head"],
    )
    # Round for a compact JSON payload; score is a softmax probability.
    return ActionResponse(action=action_name, score=round(score, 4))