import json
import os
import threading

import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel

# Simple RL-trained classifier: frozen DistilBERT encoder + linear policy head
ACTIONS = ["TRIP", "NONE", "GITHUB", "MAIL"]
DATASET_PATH = os.path.join(os.path.dirname(__file__), "dataset.jsonl")
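
# dataset.jsonl is assumed (based on the parsing in load_model) to hold one
# chat-style record per line: messages[1] is the user text and messages[2]
# the label (one of ACTIONS), e.g.:
# {"messages": [{"role": "system", "content": "..."},
#               {"role": "user", "content": "Book me a flight to Berlin"},
#               {"role": "assistant", "content": "TRIP"}]}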

app = FastAPI()

# Global model state - loaded lazily
model_state = {"ready": False, "tokenizer": None, "encoder": None, "policy_head": None}


class MessageRequest(BaseModel):
    message: str


class ActionResponse(BaseModel):
    action: str
    score: float


@app.get("/health")
def health():
    return {"status": "ok", "model_ready": model_state["ready"]}


def load_model():
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    encoder = AutoModel.from_pretrained("distilbert-base-uncased")

    # Simple policy head: maps DistilBERT's 768-dim [CLS] state to action logits
    policy_head = nn.Linear(768, len(ACTIONS))

    # Load dataset for training
    data = []
    with open(DATASET_PATH, "r") as f:
        for line in f:
            item = json.loads(line)
            user_msg = item["messages"][1]["content"]
            label = item["messages"][2]["content"]
            data.append((user_msg, ACTIONS.index(label)))

    # Quick RL-style training: a simplified REINFORCE loop. Each step samples
    # an action from the policy, assigns a +/-1 reward, and applies the
    # policy-gradient loss -log pi(action) * reward. The encoder stays in eval
    # mode and runs under no_grad, so only the linear head is trained.
    optimizer = torch.optim.Adam(policy_head.parameters(), lr=1e-3)
    encoder.eval()

    for epoch in range(3):
        total_reward = 0
        for text, label in data[:100]:  # use subset for speed
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=64)
            with torch.no_grad():
                # Use the [CLS] token embedding as the sentence representation
                hidden = encoder(**inputs).last_hidden_state[:, 0, :]

            logits = policy_head(hidden)
            probs = torch.softmax(logits, dim=-1)

            # Sample action (RL style)
            action = torch.multinomial(probs, 1).item()

            # Reward: +1 if correct, -1 if wrong
            reward = 1.0 if action == label else -1.0
            total_reward += reward

            # Policy gradient update
            log_prob = torch.log(probs[0, action])
            loss = -log_prob * reward

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch}: total reward {total_reward}")

    return tokenizer, encoder, policy_head


def predict(text, tokenizer, encoder, policy_head):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=64)
    with torch.no_grad():
        hidden = encoder(**inputs).last_hidden_state[:, 0, :]
        logits = policy_head(hidden)
        probs = torch.softmax(logits, dim=-1)
        # Greedy at inference time (argmax), unlike the sampling used in training
        action_idx = torch.argmax(probs, dim=-1).item()
        score = probs[0, action_idx].item()

    return ACTIONS[action_idx], score


# NOTE: @app.on_event is deprecated in recent FastAPI releases in favor of
# lifespan handlers; it still works and is kept here for brevity.
@app.on_event("startup")
async def startup_event():
    def load_in_background():
        tokenizer, encoder, policy_head = load_model()
        model_state["tokenizer"] = tokenizer
        model_state["encoder"] = encoder
        model_state["policy_head"] = policy_head
        model_state["ready"] = True
        print("Model loaded and ready!")

    # Load model in background thread so server can respond immediately
    thread = threading.Thread(target=load_in_background)
    thread.start()


@app.post("/action", response_model=ActionResponse)
def action(request: MessageRequest):
    if not model_state["ready"]:
        raise HTTPException(status_code=503, detail="Model is still loading, please wait")

    action_name, score = predict(
        request.message,
        model_state["tokenizer"],
        model_state["encoder"],
        model_state["policy_head"]
    )
    return ActionResponse(action=action_name, score=round(score, 4))
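

# Minimal local-run sketch (assumes this file is saved as app.py and that
# uvicorn is installed):
#   uvicorn app:app --port 8000
#
# Example request once /health reports model_ready=true (the returned action
# and score depend on training):
#   curl -X POST http://localhost:8000/action \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Book me a flight to Paris"}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)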