PiotrPasztor committed on
Commit a40c8da · 1 Parent(s): bdb9784
Files changed (3)
  1. __pycache__/app.cpython-314.pyc +0 -0
  2. app.py +396 -47
  3. dataset.jsonl +0 -0
__pycache__/app.cpython-314.pyc ADDED
Binary file (25.5 kB)
 
app.py CHANGED
@@ -1,19 +1,34 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import json
 import os
+import random
+import numpy as np
+from collections import deque
 from fastapi import FastAPI
 from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModel
 
-# Simple RL Classifier using Transformer
-ACTIONS = ["TRIP", "NONE", "GITHUB", "MAIL"]
+ACTIONS = ["TRIP", "GITHUB", "MAIL"]
+NUM_ACTIONS = len(ACTIONS)
 DATASET_PATH = os.path.join(os.path.dirname(__file__), "dataset.jsonl")
 
+# Confidence threshold - below this returns NONE
+CONFIDENCE_THRESHOLD = 0.6
+
+# Distance threshold for outlier detection (cosine similarity)
+DISTANCE_THRESHOLD = 0.93
+
 app = FastAPI()
 
-# Global model state - loaded lazily
-model_state = {"ready": False, "tokenizer": None, "encoder": None, "policy_head": None}
+model_state = {
+    "ready": False,
+    "agent": None,
+    "tokenizer": None,
+    "encoder": None,
+    "class_centroids": None,  # mean embedding per class
+}
 
 
 class MessageRequest(BaseModel):
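
Note that NONE is dropped from ACTIONS here: instead of being a trained class, it is now synthesized at inference time from the two thresholds above. A minimal sketch of that gate, mirroring the check in predict() in the next hunk (the helper name is illustrative, not part of the commit):

    # Reject an input when it is far from every class centroid
    # or when the policy itself is unsure.
    def should_return_none(max_similarity: float, confidence: float) -> bool:
        return max_similarity < DISTANCE_THRESHOLD or confidence < CONFIDENCE_THRESHOLD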
@@ -25,69 +40,402 @@ class ActionResponse(BaseModel):
     score: float
 
 
-@app.get("/health")
-def health():
-    return {"status": "ok", "model_ready": model_state["ready"]}
-
-
-def load_model():
-    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
-    encoder = AutoModel.from_pretrained("distilbert-base-uncased")
-
-    # Simple policy head
-    policy_head = nn.Linear(768, len(ACTIONS))
-
-    # Load dataset for training
-    data = []
-    with open(DATASET_PATH, "r") as f:
-        for line in f:
-            item = json.loads(line)
-            user_msg = item["messages"][1]["content"]
-            label = item["messages"][2]["content"]
-            data.append((user_msg, ACTIONS.index(label)))
-
-    # Quick RL-style training (simplified policy gradient)
-    optimizer = torch.optim.Adam(policy_head.parameters(), lr=1e-3)
-    encoder.eval()
-
-    for epoch in range(3):
-        total_reward = 0
-        for text, label in data[:100]:  # use a subset for speed
-            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=64)
-            with torch.no_grad():
-                hidden = encoder(**inputs).last_hidden_state[:, 0, :]
-
-            logits = policy_head(hidden)
-            probs = torch.softmax(logits, dim=-1)
-
-            # Sample an action (RL style)
-            action = torch.multinomial(probs, 1).item()
-
-            # Reward: +1 if correct, -1 if wrong
-            reward = 1.0 if action == label else -1.0
-            total_reward += reward
-
-            # Policy gradient update
-            log_prob = torch.log(probs[0, action])
-            loss = -log_prob * reward
-
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-
-    return tokenizer, encoder, policy_head
-
-
-def predict(text, tokenizer, encoder, policy_head):
-    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=64)
-    with torch.no_grad():
-        hidden = encoder(**inputs).last_hidden_state[:, 0, :]
-        logits = policy_head(hidden)
-        probs = torch.softmax(logits, dim=-1)
-        action_idx = torch.argmax(probs, dim=-1).item()
-        score = probs[0, action_idx].item()
-
-    return ACTIONS[action_idx], score
+class PolicyNetwork(nn.Module):
+    """Policy network that outputs action probabilities."""
+
+    def __init__(self, state_dim, num_actions, hidden_dim=128):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(state_dim, hidden_dim),
+            nn.LayerNorm(hidden_dim),
+            nn.ReLU(),
+            nn.Dropout(0.1),
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.LayerNorm(hidden_dim),
+            nn.ReLU(),
+            nn.Dropout(0.1),
+            nn.Linear(hidden_dim, num_actions)
+        )
+
+        # Initialize last layer with small weights for a balanced initial policy
+        nn.init.xavier_uniform_(self.net[-1].weight, gain=0.01)
+        nn.init.zeros_(self.net[-1].bias)
+
+    def forward(self, state):
+        return self.net(state)
+
+    def get_action_probs(self, state):
+        logits = self.forward(state)
+        return F.softmax(logits, dim=-1)
+
+    def get_action(self, state, deterministic=False, temperature=1.0):
+        logits = self.forward(state)
+
+        # Apply temperature for exploration control
+        scaled_logits = logits / temperature
+        probs = F.softmax(scaled_logits, dim=-1)
+
+        if deterministic:
+            action = torch.argmax(probs, dim=-1)
+        else:
+            dist = torch.distributions.Categorical(probs)
+            action = dist.sample()
+
+        return action, probs
+
+
+class QNetwork(nn.Module):
+    """Q-network for action-value estimation."""
+
+    def __init__(self, state_dim, num_actions, hidden_dim=128):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(state_dim, hidden_dim),
+            nn.LayerNorm(hidden_dim),
+            nn.ReLU(),
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.LayerNorm(hidden_dim),
+            nn.ReLU(),
+            nn.Linear(hidden_dim, num_actions)
+        )
+
+    def forward(self, state):
+        return self.net(state)
+
+
+class RLAgent:
+    """RL agent using Double DQN with proper exploration."""
+
+    def __init__(self, state_dim, num_actions, lr=1e-3, gamma=0.95):
+        self.state_dim = state_dim
+        self.num_actions = num_actions
+        self.gamma = gamma
+
+        # Q-networks (Double DQN)
+        self.q_net = QNetwork(state_dim, num_actions)
+        self.target_q_net = QNetwork(state_dim, num_actions)
+        self.target_q_net.load_state_dict(self.q_net.state_dict())
+
+        # Policy network
+        self.policy_net = PolicyNetwork(state_dim, num_actions)
+
+        self.q_optimizer = torch.optim.AdamW(self.q_net.parameters(), lr=lr, weight_decay=1e-4)
+        self.policy_optimizer = torch.optim.AdamW(self.policy_net.parameters(), lr=lr, weight_decay=1e-4)
+
+        # Exploration parameters
+        self.epsilon = 1.0
+        self.epsilon_min = 0.05
+        self.epsilon_decay = 0.995
+        self.temperature = 1.0
+
+    def select_action(self, state, deterministic=True):
+        """Select an action for the given state."""
+        with torch.no_grad():
+            if deterministic:
+                # Use the policy network for inference
+                action, probs = self.policy_net.get_action(state, deterministic=True)
+                action_idx = action.item()
+
+                # Entropy-based confidence: high entropy = low confidence
+                entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=-1).item()
+                max_entropy = np.log(self.num_actions)  # maximum possible entropy
+
+                # Confidence reflects how certain the distribution is:
+                # low entropy = high confidence, high entropy = low confidence
+                confidence = 1.0 - (entropy / max_entropy)
+
+                # Also factor in the raw probability
+                raw_prob = probs[0, action_idx].item()
+                confidence = confidence * raw_prob
+            else:
+                # Epsilon-greedy for training
+                if random.random() < self.epsilon:
+                    action_idx = random.randint(0, self.num_actions - 1)
+                    confidence = 1.0 / self.num_actions
+                else:
+                    action, probs = self.policy_net.get_action(state, deterministic=False, temperature=self.temperature)
+                    action_idx = action.item()
+                    confidence = probs[0, action_idx].item()
+
+        return action_idx, confidence
+
+    def update_q(self, states, actions, rewards, next_states, dones):
+        """Update the Q-network using TD learning."""
+        # Current Q-values
+        q_values = self.q_net(states)
+        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
+
+        # Target Q-values (Double DQN)
+        with torch.no_grad():
+            # Select the best action using the online network
+            next_q_online = self.q_net(next_states)
+            best_actions = next_q_online.argmax(dim=1)
+
+            # Evaluate it using the target network
+            next_q_target = self.target_q_net(next_states)
+            next_q_values = next_q_target.gather(1, best_actions.unsqueeze(1)).squeeze(1)
+
+            target_q_values = rewards + self.gamma * next_q_values * (1 - dones)
+
+        # Q-network loss
+        q_loss = F.smooth_l1_loss(q_values, target_q_values)
+
+        self.q_optimizer.zero_grad()
+        q_loss.backward()
+        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), 1.0)
+        self.q_optimizer.step()
+
+        return q_loss.item()
+
+    def update_policy(self, states, actions):
+        """Update the policy network to match Q-values (actor-critic style)."""
+        # Q-values for all actions
+        with torch.no_grad():
+            q_values = self.q_net(states)
+            # Advantage = Q(s,a) - V(s), where V(s) = E[Q(s,a)]
+            v_values = q_values.mean(dim=1, keepdim=True)
+            advantages = q_values - v_values
+
+        # Policy logits
+        logits = self.policy_net(states)
+        log_probs = F.log_softmax(logits, dim=-1)
+
+        # Policy loss: maximize advantage-weighted log probability
+        action_log_probs = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
+        action_advantages = advantages.gather(1, actions.unsqueeze(1)).squeeze(1)
+
+        # Add an entropy bonus for exploration
+        probs = F.softmax(logits, dim=-1)
+        entropy = -(probs * log_probs).sum(dim=-1).mean()
+
+        policy_loss = -(action_log_probs * action_advantages.detach()).mean() - 0.05 * entropy
+
+        self.policy_optimizer.zero_grad()
+        policy_loss.backward()
+        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
+        self.policy_optimizer.step()
+
+        return policy_loss.item()
+
+    def update_target_network(self, tau=0.005):
+        """Soft-update the target network."""
+        for target_param, param in zip(self.target_q_net.parameters(), self.q_net.parameters()):
+            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
+
+    def decay_exploration(self):
+        """Decay the exploration rate."""
+        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
+
+
+def load_dataset():
+    """Load and parse the dataset."""
+    data = []
+
+    with open(DATASET_PATH, "r") as f:
+        for line in f:
+            item = json.loads(line)
+            user_msg = item["messages"][1]["content"]
+            label = item["messages"][2]["content"]
+            if label in ACTIONS:
+                data.append((user_msg, ACTIONS.index(label)))
+
+    random.shuffle(data)
+    return data
+
+
+def encode_texts(texts, tokenizer, encoder):
+    """Batch-encode texts into state representations."""
+    inputs = tokenizer(texts, return_tensors="pt", truncation=True, max_length=64, padding=True)
+    with torch.no_grad():
+        hidden = encoder(**inputs).last_hidden_state[:, 0, :]
+    return hidden
+
+
+def train_rl_agent(tokenizer, encoder, data, num_epochs=50, batch_size=64):
+    """
+    Train the RL agent with offline RL on the dataset.
+
+    Uses the dataset as demonstration data:
+    - States: encoded text messages
+    - Actions: correct labels from the dataset (expert demonstrations)
+    - Rewards: +1 for correct, -1 for incorrect
+    """
+    state_dim = 768  # DistilBERT hidden size
+    agent = RLAgent(state_dim, NUM_ACTIONS, lr=3e-4)
+
+    print("Encoding all dataset examples...")
+
+    # Pre-encode all texts for efficiency
+    all_texts = [text for text, _ in data]
+    all_labels = [label for _, label in data]
+
+    # Encode in batches
+    all_states = []
+    for i in range(0, len(all_texts), batch_size):
+        batch_texts = all_texts[i:i+batch_size]
+        batch_states = encode_texts(batch_texts, tokenizer, encoder)
+        all_states.append(batch_states)
+
+    all_states = torch.cat(all_states, dim=0)
+    all_labels = torch.tensor(all_labels, dtype=torch.long)
+
+    print(f"Encoded {len(all_states)} examples")
+
+    # Print class distribution
+    for i, action_name in enumerate(ACTIONS):
+        count = (all_labels == i).sum().item()
+        print(f"  {action_name}: {count} examples")
+
+    # Create next states (a random permutation; the dataset has no real episode structure)
+    indices = torch.randperm(len(all_states))
+    next_states = all_states[indices]
+
+    print("Starting RL training...")
+
+    for epoch in range(num_epochs):
+        # Shuffle data each epoch
+        perm = torch.randperm(len(all_states))
+        states_shuffled = all_states[perm]
+        labels_shuffled = all_labels[perm]
+        next_states_shuffled = next_states[perm]
+
+        epoch_q_loss = 0
+        epoch_policy_loss = 0
+        num_batches = 0
+
+        for i in range(0, len(states_shuffled), batch_size):
+            batch_states = states_shuffled[i:i+batch_size]
+            batch_labels = labels_shuffled[i:i+batch_size]
+            batch_next_states = next_states_shuffled[i:i+batch_size]
+
+            # Rewards for correct actions: +1
+            batch_rewards = torch.ones(len(batch_labels), dtype=torch.float32)
+            batch_dones = torch.zeros(len(batch_labels), dtype=torch.float32)
+
+            # Add negative examples (wrong actions with reward -1)
+            wrong_actions_list = []
+            for label in batch_labels:
+                wrong = (label.item() + random.randint(1, NUM_ACTIONS - 1)) % NUM_ACTIONS
+                wrong_actions_list.append(wrong)
+            wrong_actions = torch.tensor(wrong_actions_list, dtype=torch.long)
+            wrong_rewards = -torch.ones(len(batch_labels), dtype=torch.float32)
+
+            # Combine correct and incorrect transitions
+            combined_states = torch.cat([batch_states, batch_states], dim=0)
+            combined_actions = torch.cat([batch_labels, wrong_actions], dim=0)
+            combined_rewards = torch.cat([batch_rewards, wrong_rewards], dim=0)
+            combined_next_states = torch.cat([batch_next_states, batch_next_states], dim=0)
+            combined_dones = torch.cat([batch_dones, batch_dones], dim=0)
+
+            # Update Q-network
+            q_loss = agent.update_q(
+                combined_states, combined_actions, combined_rewards,
+                combined_next_states, combined_dones
+            )
+
+            # Update policy (only on correct examples)
+            policy_loss = agent.update_policy(batch_states, batch_labels)
+
+            # Soft-update target network
+            agent.update_target_network(tau=0.005)
+
+            epoch_q_loss += q_loss
+            epoch_policy_loss += policy_loss
+            num_batches += 1
+
+        agent.decay_exploration()
+
+        if (epoch + 1) % 10 == 0:
+            # Evaluate
+            with torch.no_grad():
+                _, probs = agent.policy_net.get_action(all_states, deterministic=True)
+                predictions = probs.argmax(dim=-1)
+                accuracy = (predictions == all_labels).float().mean().item() * 100
+
+                # Check policy entropy (diversity)
+                avg_entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=-1).mean().item()
+
+            print(f"Epoch {epoch + 1}/{num_epochs} | "
+                  f"Q-Loss: {epoch_q_loss/num_batches:.4f} | "
+                  f"Policy-Loss: {epoch_policy_loss/num_batches:.4f} | "
+                  f"Accuracy: {accuracy:.1f}% | "
+                  f"Entropy: {avg_entropy:.3f} | "
+                  f"Epsilon: {agent.epsilon:.3f}")
+
+    # Set networks to eval mode (disables dropout for deterministic inference)
+    agent.policy_net.eval()
+    agent.q_net.eval()
+
+    # Final evaluation
+    print("\nFinal Evaluation:")
+    with torch.no_grad():
+        _, probs = agent.policy_net.get_action(all_states, deterministic=True)
+        predictions = probs.argmax(dim=-1)
+
+        for i, action_name in enumerate(ACTIONS):
+            mask = all_labels == i
+            if mask.sum() > 0:
+                action_acc = (predictions[mask] == i).float().mean().item() * 100
+                print(f"  {action_name}: {action_acc:.1f}% ({mask.sum().item()} samples)")
+
+        overall_acc = (predictions == all_labels).float().mean().item() * 100
+        print(f"  Overall: {overall_acc:.1f}%")
+
+    # Compute class centroids for outlier detection
+    print("\nComputing class centroids...")
+    centroids = []
+    for i in range(NUM_ACTIONS):
+        mask = all_labels == i
+        class_states = all_states[mask]
+        centroid = class_states.mean(dim=0)
+        centroids.append(centroid)
+    class_centroids = torch.stack(centroids)
+
+    return agent, class_centroids
+
+
+def load_model():
+    """Load the encoder and train the RL agent."""
+    print("Loading tokenizer and encoder...")
+    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+    encoder = AutoModel.from_pretrained("distilbert-base-uncased")
+    encoder.eval()
+
+    print("Loading dataset...")
+    data = load_dataset()
+    print(f"Dataset size: {len(data)} examples")
+
+    print("Training RL agent...")
+    agent, class_centroids = train_rl_agent(tokenizer, encoder, data)
+
+    return tokenizer, encoder, agent, class_centroids
+
+
+def predict(text, tokenizer, encoder, agent, class_centroids):
+    """Use the trained RL agent to predict an action for the given text."""
+    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=64)
+    with torch.no_grad():
+        hidden = encoder(**inputs).last_hidden_state[:, 0, :]
+    action_idx, confidence = agent.select_action(hidden, deterministic=True)
+
+    # Cosine similarity to the closest class centroid
+    hidden_norm = hidden / hidden.norm(dim=-1, keepdim=True)
+    centroids_norm = class_centroids / class_centroids.norm(dim=-1, keepdim=True)
+    similarities = torch.mm(hidden_norm, centroids_norm.t()).squeeze(0)
+    max_similarity = similarities.max().item()
+
+    # Return NONE if the similarity is too low OR the confidence is too low
+    if max_similarity < DISTANCE_THRESHOLD or confidence < CONFIDENCE_THRESHOLD:
+        return "NONE", confidence
+
+    return ACTIONS[action_idx], confidence
+
+
+@app.get("/health")
+def health():
+    return {"status": "ok", "model_ready": model_state["ready"]}
 
 
 @app.on_event("startup")
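
The confidence returned by select_action() multiplies an entropy-based certainty term by the raw probability of the chosen action, so it is noticeably stricter than a bare softmax score. A self-contained sketch of the same arithmetic on a toy distribution (the numbers are illustrative only):

    import numpy as np

    probs = np.array([0.80, 0.15, 0.05])              # toy policy output over the 3 actions
    entropy = -(probs * np.log(probs + 1e-8)).sum()   # ~0.613
    max_entropy = np.log(len(probs))                  # ln(3) ~ 1.099
    certainty = 1.0 - entropy / max_entropy           # ~0.442
    confidence = certainty * probs.max()              # ~0.354, below CONFIDENCE_THRESHOLD = 0.6

Even a fairly peaked distribution can land under the 0.6 threshold, which makes the NONE gate deliberately conservative.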
@@ -95,14 +443,14 @@ async def startup_event():
     import threading
 
     def load_in_background():
-        tokenizer, encoder, policy_head = load_model()
+        tokenizer, encoder, agent, class_centroids = load_model()
         model_state["tokenizer"] = tokenizer
         model_state["encoder"] = encoder
-        model_state["policy_head"] = policy_head
+        model_state["agent"] = agent
+        model_state["class_centroids"] = class_centroids
         model_state["ready"] = True
-        print("Model loaded and ready!")
+        print("RL Agent loaded and ready!")
 
-    # Load model in background thread so server can respond immediately
     thread = threading.Thread(target=load_in_background)
     thread.start()
 
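Since training now happens inside the startup thread, the server accepts requests long before model_ready flips to true. A client can poll the /health route added in the previous hunk; a minimal sketch, assuming the app is served at http://localhost:8000 (host and port are assumptions, not part of the commit):

    import time
    import requests  # third-party HTTP client, assumed to be installed

    # Wait until the background thread finishes training the agent
    while not requests.get("http://localhost:8000/health").json()["model_ready"]:
        time.sleep(5)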
@@ -117,6 +465,7 @@ def action(request: MessageRequest):
         request.message,
         model_state["tokenizer"],
         model_state["encoder"],
-        model_state["policy_head"]
+        model_state["agent"],
+        model_state["class_centroids"]
     )
     return ActionResponse(action=action_name, score=round(score, 4))
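
Once ready, the endpoint takes a MessageRequest and returns an ActionResponse with the action name and rounded score. A hedged usage sketch; the /action path is an assumption, since the route decorator sits outside these hunks:

    import requests

    resp = requests.post(
        "http://localhost:8000/action",  # path assumed; the decorator is not shown in the diff
        json={"message": "create a new branch on my repo"},  # MessageRequest.message
    )
    print(resp.json())  # shape: {"action": "<TRIP|GITHUB|MAIL|NONE>", "score": <float>}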
 
dataset.jsonl CHANGED
The diff for this file is too large to render. See raw diff
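
The dataset diff is not rendered, but the loader in app.py pins down the record shape: each JSONL line carries a messages array whose second entry is the user text and whose third is the label (one of TRIP, GITHUB, MAIL). An inferred example record; the roles and content here are assumptions for illustration:

    {"messages": [{"role": "system", "content": "..."},
                  {"role": "user", "content": "book me a train to Berlin"},
                  {"role": "assistant", "content": "TRIP"}]}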