Spaces:
Sleeping
Sleeping
import json
import os

import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
# Simple RL Classifier using Transformer

# Closed set of actions the classifier can emit; label strings in the dataset
# must match these exactly (see load_model), and the policy head has one
# output logit per entry.
ACTIONS = ["TRIP", "NONE", "GITHUB", "MAIL"]

# Training data lives next to this file, one JSON object per line (JSONL).
DATASET_PATH = os.path.join(os.path.dirname(__file__), "dataset.jsonl")

app = FastAPI()

# Global model state - loaded lazily by a background thread at startup;
# "ready" flips to True only after tokenizer/encoder/policy_head are all set.
model_state = {"ready": False, "tokenizer": None, "encoder": None, "policy_head": None}
class MessageRequest(BaseModel):
    """Request body for classification: the raw user message text."""

    message: str
class ActionResponse(BaseModel):
    """Response body: the predicted action name and its softmax confidence."""

    action: str
    score: float
def health():
    """Health probe: always reports "ok" plus whether the model is loaded.

    NOTE(review): no @app route decorator is visible here — presumably lost
    in formatting or registered elsewhere; confirm the endpoint is wired up.
    """
    payload = {"status": "ok"}
    payload["model_ready"] = model_state["ready"]
    return payload
def load_model():
    """Build the classifier: a frozen DistilBERT encoder plus a trained head.

    Downloads/loads the pretrained tokenizer and encoder, reads labeled
    (user_message, action_label) pairs from DATASET_PATH, and runs a short
    REINFORCE-style loop that trains only the linear policy head.

    Returns:
        (tokenizer, encoder, policy_head)
    """
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    encoder = AutoModel.from_pretrained("distilbert-base-uncased")
    # Linear policy head mapping the 768-dim [CLS] embedding to the action set.
    policy_head = nn.Linear(768, len(ACTIONS))

    data = _load_dataset()

    # Quick RL-style training (simplified policy gradient / REINFORCE).
    optimizer = torch.optim.Adam(policy_head.parameters(), lr=1e-3)
    encoder.eval()  # encoder stays frozen; gradients flow only into the head
    for epoch in range(3):
        total_reward = 0.0
        for text, label in data[:100]:  # use subset for speed
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=64)
            with torch.no_grad():
                # [CLS] token embedding as the sentence representation.
                hidden = encoder(**inputs).last_hidden_state[:, 0, :]
            logits = policy_head(hidden)
            # log_softmax is numerically safer than log(softmax(...)).
            log_probs = torch.log_softmax(logits, dim=-1)
            probs = log_probs.exp()
            # Sample an action from the current policy (RL-style exploration).
            action = torch.multinomial(probs, 1).item()
            # Reward: +1 if correct, -1 if wrong
            reward = 1.0 if action == label else -1.0
            total_reward += reward
            # REINFORCE update: maximize reward-weighted log-probability.
            loss = -log_probs[0, action] * reward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return tokenizer, encoder, policy_head


def _load_dataset():
    """Read (user_message, label_index) pairs from DATASET_PATH (JSONL).

    Each line is expected to hold {"messages": [system, user, assistant]}
    with the assistant content being a label from ACTIONS. Malformed lines
    and unknown labels are skipped (with a warning) instead of aborting
    server startup on a single bad record.
    """
    data = []
    with open(DATASET_PATH, "r", encoding="utf-8") as f:
        for line in f:
            try:
                item = json.loads(line)
                user_msg = item["messages"][1]["content"]
                label = item["messages"][2]["content"]
                # ValueError here covers both bad JSON and labels not in ACTIONS.
                data.append((user_msg, ACTIONS.index(label)))
            except (ValueError, KeyError, IndexError, TypeError) as exc:
                print(f"Skipping bad dataset line: {exc!r}")
    return data
def predict(text, tokenizer, encoder, policy_head):
    """Classify *text* into one of ACTIONS.

    Encodes the message, takes the encoder's first ([CLS]) token embedding,
    scores it with the policy head, and returns the argmax action together
    with its softmax probability.

    Returns:
        (action_name, score) — a string from ACTIONS and a float in [0, 1].
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=64)
    with torch.no_grad():
        cls_vector = encoder(**encoded).last_hidden_state[:, 0, :]
    scores = policy_head(cls_vector)
    distribution = torch.softmax(scores, dim=-1)
    best_idx = torch.argmax(distribution, dim=-1).item()
    confidence = distribution[0, best_idx].item()
    return ACTIONS[best_idx], confidence
async def startup_event():
    """Start loading the model on a background thread at server startup.

    The HTTP server becomes responsive immediately; model_state["ready"] is
    flipped by the worker once the tokenizer/encoder/head are in place.
    """
    import threading

    def _background_load():
        tok, enc, head = load_model()
        model_state["tokenizer"] = tok
        model_state["encoder"] = enc
        model_state["policy_head"] = head
        model_state["ready"] = True
        print("Model loaded and ready!")

    # Load model in background thread so server can respond immediately.
    threading.Thread(target=_background_load).start()
def action(request: MessageRequest):
    """Classify an incoming message and return the predicted action.

    Returns a 503 while the background model load is still in progress;
    otherwise runs predict() and rounds the confidence to 4 decimals.

    NOTE(review): no @app route decorator is visible here — presumably lost
    in formatting or registered elsewhere; confirm the endpoint is wired up.
    """
    if not model_state["ready"]:
        # HTTPException is imported at module level with the other fastapi names.
        raise HTTPException(status_code=503, detail="Model is still loading, please wait")
    action_name, score = predict(
        request.message,
        model_state["tokenizer"],
        model_state["encoder"],
        model_state["policy_head"],
    )
    return ActionResponse(action=action_name, score=round(score, 4))