Spaces:
Sleeping
Sleeping
File size: 3,738 Bytes
73c62ee 176390f 73c62ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import json
import os

import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
# Simple RL Classifier using Transformer
# Closed set of routing labels the policy head chooses between; the dataset's
# assistant messages must be exactly one of these strings.
ACTIONS = ["TRIP", "NONE", "GITHUB", "MAIL"]
# Training data lives next to this file, one JSON chat transcript per line.
DATASET_PATH = os.path.join(os.path.dirname(__file__), "dataset.jsonl")
app = FastAPI()
# Global model state - loaded lazily
# Populated by the startup background thread; "ready" is flipped to True only
# after tokenizer/encoder/policy_head are all set.
model_state = {"ready": False, "tokenizer": None, "encoder": None, "policy_head": None}
class MessageRequest(BaseModel):
    """Request body for POST /action: the raw user message to classify."""

    message: str
class ActionResponse(BaseModel):
    """Response body for POST /action: chosen action plus its softmax score."""

    action: str
    score: float
@app.get("/health")
def health():
    """Liveness probe that also reports whether the model finished loading."""
    return dict(status="ok", model_ready=model_state["ready"])
def load_model():
    """Build and quickly train the classifier on top of a frozen DistilBERT.

    Reads (user message, action label) pairs from DATASET_PATH — each JSONL
    line is a chat transcript assumed to be [system, user, assistant-label] —
    then trains a linear policy head with a simplified per-sample REINFORCE
    loop: sample an action, reward +1/-1 for correct/incorrect, and take one
    policy-gradient step.

    Returns:
        Tuple of (tokenizer, encoder, policy_head) ready for predict().

    Raises:
        ValueError: if a dataset label is not one of ACTIONS (bad rows fail
            loudly rather than being silently mislabeled).
    """
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    encoder = AutoModel.from_pretrained("distilbert-base-uncased")
    # Linear policy head over the [CLS] embedding (DistilBERT hidden size is 768).
    policy_head = nn.Linear(768, len(ACTIONS))

    # Load dataset for training.
    data = []
    with open(DATASET_PATH, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            user_msg = item["messages"][1]["content"]
            label = item["messages"][2]["content"]
            data.append((user_msg, ACTIONS.index(label)))

    # Quick RL-style training (policy gradient simplified).
    optimizer = torch.optim.Adam(policy_head.parameters(), lr=1e-3)
    encoder.eval()  # encoder stays frozen; only the head is trained
    for _epoch in range(3):
        for text, label in data[:100]:  # use subset for speed
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=64)
            with torch.no_grad():  # no gradients through the frozen encoder
                hidden = encoder(**inputs).last_hidden_state[:, 0, :]
            logits = policy_head(hidden)
            # log_softmax is numerically stabler than log(softmax(...)).
            log_probs = torch.log_softmax(logits, dim=-1)
            # Sample an action from the current policy (RL-style exploration).
            action = torch.multinomial(log_probs.exp(), 1).item()
            # Reward: +1 if correct, -1 if wrong.
            reward = 1.0 if action == label else -1.0
            # REINFORCE update: maximize reward-weighted log-probability.
            loss = -log_probs[0, action] * reward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return tokenizer, encoder, policy_head
def predict(text, tokenizer, encoder, policy_head):
    """Classify `text` into one of ACTIONS.

    Returns:
        (action_name, score): the argmax action label and its softmax
        probability as a float.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=64)
    with torch.no_grad():
        cls_embedding = encoder(**encoded).last_hidden_state[:, 0, :]
        distribution = torch.softmax(policy_head(cls_embedding), dim=-1)
    best = int(distribution.argmax(dim=-1))
    return ACTIONS[best], distribution[0, best].item()
@app.on_event("startup")
async def startup_event():
    """Kick off model loading in a background thread at server startup.

    Training takes a while, so the server answers /health immediately and
    /action returns 503 until model_state["ready"] becomes True.
    """
    import threading

    def load_in_background():
        tokenizer, encoder, policy_head = load_model()
        model_state["tokenizer"] = tokenizer
        model_state["encoder"] = encoder
        model_state["policy_head"] = policy_head
        # Flip the ready flag last so readers never see a half-populated state.
        model_state["ready"] = True
        print("Model loaded and ready!")

    # daemon=True: a still-loading model must not block interpreter shutdown
    # if the server process exits mid-load.
    thread = threading.Thread(target=load_in_background, daemon=True)
    thread.start()
@app.post("/action", response_model=ActionResponse)
def action(request: MessageRequest):
    """Classify the incoming message into one of ACTIONS.

    Raises:
        HTTPException: 503 while the model is still loading in the background.
    """
    if not model_state["ready"]:
        raise HTTPException(status_code=503, detail="Model is still loading, please wait")
    action_name, score = predict(
        request.message,
        model_state["tokenizer"],
        model_state["encoder"],
        model_state["policy_head"],
    )
    return ActionResponse(action=action_name, score=round(score, 4))
|