Naseer-010 committed on
Commit
09fcbfb
·
1 Parent(s): b54ab02

fixing the reward triage

Browse files
Files changed (2) hide show
  1. inference.py +50 -18
  2. train_grpo_unsloth.py +219 -160
inference.py CHANGED
@@ -90,33 +90,37 @@ Available commands:
90
  CRITICAL INCIDENT TRIAGE TREE (Follow strictly in order):
91
  1. OOM IMMINENT (Memory Leak): IF ANY 'mem_utilizations' > 0.92:
92
  IMMEDIATELY output: kubectl delete pod node-5 (or whichever node is leaking. Scaling does NOT fix memory leaks!)
93
-
94
- 2. SPLIT-BRAIN (Disk I/O Bottleneck): IF node_0 'io_wait' > 0.80:
 
 
 
 
95
  Output: kubectl throttle ingress --rate=0.5 (Do NOT scale up; more workers will lock the DB disk further).
96
 
97
- 3. HOT SHARD (Load Balancer Skew): IF one worker's CPU > 0.90 but the cluster average is low:
98
  Output: kubectl exec -it istio-proxy -- traffic shift --from=<high_cpu_node> --to=<low_cpu_node>
99
 
100
- 4. RETRY STORM / THUNDERING HERD: IF 'p99_latency' > 100.0 AND traffic is spiking:
101
  Output: kubectl throttle ingress --rate=0.4 (Break the exponential retry loop).
102
 
103
- 5. CONNECTION DEADLOCK (Zombie Node): IF a worker's CPU is incredibly low (< 0.10) BUT 'p99_latency' is huge:
104
  Output: kubectl exec -it istio-proxy -- traffic shift --from=<zombie_node> --to=<healthy_node>
105
 
106
- 6. BLACK SWAN (Multi-Node Death): IF multiple nodes are in 'failed_nodes':
107
  Output: kubectl throttle ingress --rate=0.3 (Shed load to protect survivors while you recover).
108
 
109
- 7. DATABASE SURVIVAL: IF node-0 (DB) cpu_load > 0.80:
110
  Output: kubectl throttle ingress --rate=0.7
111
 
112
- 8. SAFE SCALING: IF avg worker CPU > 0.75 AND 'error_budget' > 20:
113
  Output: kubectl scale deployment frontend --replicas=10
114
 
115
- 9. HEALTHY / FLAPPING TRAP: If metrics are stable or oscillating slightly:
116
  Output: no_op
117
 
118
  Respond using the following STRICT format. You must include the XML reasoning tags:
119
- <reasoning>Diagnose the telemetry. Identify which of the 9 Triage rules applies.</reasoning>
120
  <action>
121
  {"command": "your_kubectl_command_or_no_op_here"}
122
  </action>"""
@@ -649,16 +653,30 @@ def _get_direct_env():
649
  global _direct_env
650
  if _direct_env is None:
651
  from server.environment import DistributedInfraEnvironment
 
652
  _direct_env = DistributedInfraEnvironment()
653
  return _direct_env
654
 
655
 
656
  def _infraobs_to_dict(obs) -> dict:
657
  keys = [
658
- "cpu_loads", "mem_utilizations", "queue_lengths", "failed_nodes",
659
- "latency_ms", "request_rate", "io_wait", "p99_latency", "error_budget",
660
- "step", "task_hint", "action_errors", "cloud_budget",
661
- "task_score", "done", "reward", "uptime_pct",
 
 
 
 
 
 
 
 
 
 
 
 
 
662
  ]
663
  return {k: getattr(obs, k) for k in keys if hasattr(obs, k)}
664
 
@@ -669,6 +687,7 @@ def env_reset_direct(task_id: str) -> dict:
669
 
670
  def env_step_direct(action_dict: dict) -> dict:
671
  from server.models import InfraAction
 
672
  env = _get_direct_env()
673
  raw_cmd = action_dict.get("raw_command")
674
  if raw_cmd and raw_cmd != "no_op":
@@ -676,7 +695,10 @@ def env_step_direct(action_dict: dict) -> dict:
676
  else:
677
  act_type = action_dict.get("action_type", "no_op")
678
  kwargs: dict = {"action_type": act_type}
679
- if act_type in ("restart_node", "query_logs") and action_dict.get("target") is not None:
 
 
 
680
  kwargs["target"] = int(action_dict["target"])
681
  elif act_type == "reroute_traffic":
682
  kwargs["from_node"] = int(action_dict.get("from_node", 0))
@@ -689,7 +711,13 @@ def env_step_direct(action_dict: dict) -> dict:
689
  action = InfraAction(action_type="no_op")
690
  obs = env.step(action)
691
  obs_dict = _infraobs_to_dict(obs)
692
- return {"data": {"observation": obs_dict, "reward": getattr(obs, "reward", 0.0), "done": getattr(obs, "done", False)}}
 
 
 
 
 
 
693
 
694
 
695
  # ---------------------------------------------------------------------------
@@ -764,7 +792,11 @@ def run_task(
764
  done = False
765
 
766
  try:
767
- result = env_step_direct(backend_action) if use_direct else env_step(env_url, backend_action)
 
 
 
 
768
  data_block = result.get("data", result)
769
 
770
  if "observation" in data_block and isinstance(
@@ -923,7 +955,7 @@ def main():
923
  print(f" STRUCTURED LOGS: {log_dir}")
924
  print("==================================================")
925
 
926
- use_direct = (mode == "local")
927
  for task_id in tasks:
928
  stats = run_task(
929
  task_id,
 
90
  CRITICAL INCIDENT TRIAGE TREE (Follow strictly in order):
91
  1. OOM IMMINENT (Memory Leak): IF ANY 'mem_utilizations' > 0.92:
92
  IMMEDIATELY output: kubectl delete pod node-5 (or whichever node is leaking. Scaling does NOT fix memory leaks!)
93
+
94
+ 2. DB RECOVERY: IF node-0 is in 'failed_nodes':
95
+ Output: kubectl delete pod node-0
96
+ (The DB is a SPOF. If it's dead, ALL other actions are futile until it restarts.)
97
+
98
+ 3. SPLIT-BRAIN (Disk I/O Bottleneck): IF node_0 'io_wait' > 0.80:
99
  Output: kubectl throttle ingress --rate=0.5 (Do NOT scale up; more workers will lock the DB disk further).
100
 
101
+ 4. HOT SHARD (Load Balancer Skew): IF one worker's CPU > 0.90 but the cluster average is low:
102
  Output: kubectl exec -it istio-proxy -- traffic shift --from=<high_cpu_node> --to=<low_cpu_node>
103
 
104
+ 5. RETRY STORM / THUNDERING HERD: IF 'p99_latency' > 100.0 AND traffic is spiking:
105
  Output: kubectl throttle ingress --rate=0.4 (Break the exponential retry loop).
106
 
107
+ 6. CONNECTION DEADLOCK (Zombie Node): IF a worker's CPU is incredibly low (< 0.10) BUT 'p99_latency' is huge:
108
  Output: kubectl exec -it istio-proxy -- traffic shift --from=<zombie_node> --to=<healthy_node>
109
 
110
+ 7. BLACK SWAN (Multi-Node Death): IF multiple nodes are in 'failed_nodes' (but DB is alive):
111
  Output: kubectl throttle ingress --rate=0.3 (Shed load to protect survivors while you recover).
112
 
113
+ 8. DATABASE SURVIVAL: IF node-0 (DB) cpu_load > 0.80:
114
  Output: kubectl throttle ingress --rate=0.7
115
 
116
+ 9. SAFE SCALING: IF avg worker CPU > 0.75 AND 'error_budget' > 20:
117
  Output: kubectl scale deployment frontend --replicas=10
118
 
119
+ 10. HEALTHY / FLAPPING TRAP: If metrics are stable or oscillating slightly:
120
  Output: no_op
121
 
122
  Respond using the following STRICT format. You must include the XML reasoning tags:
123
+ <reasoning>Diagnose the telemetry. Identify which of the 10 Triage rules applies.</reasoning>
124
  <action>
125
  {"command": "your_kubectl_command_or_no_op_here"}
126
  </action>"""
 
653
  global _direct_env
654
  if _direct_env is None:
655
  from server.environment import DistributedInfraEnvironment
656
+
657
  _direct_env = DistributedInfraEnvironment()
658
  return _direct_env
659
 
660
 
661
  def _infraobs_to_dict(obs) -> dict:
662
  keys = [
663
+ "cpu_loads",
664
+ "mem_utilizations",
665
+ "queue_lengths",
666
+ "failed_nodes",
667
+ "latency_ms",
668
+ "request_rate",
669
+ "io_wait",
670
+ "p99_latency",
671
+ "error_budget",
672
+ "step",
673
+ "task_hint",
674
+ "action_errors",
675
+ "cloud_budget",
676
+ "task_score",
677
+ "done",
678
+ "reward",
679
+ "uptime_pct",
680
  ]
681
  return {k: getattr(obs, k) for k in keys if hasattr(obs, k)}
682
 
 
687
 
688
  def env_step_direct(action_dict: dict) -> dict:
689
  from server.models import InfraAction
690
+
691
  env = _get_direct_env()
692
  raw_cmd = action_dict.get("raw_command")
693
  if raw_cmd and raw_cmd != "no_op":
 
695
  else:
696
  act_type = action_dict.get("action_type", "no_op")
697
  kwargs: dict = {"action_type": act_type}
698
+ if (
699
+ act_type in ("restart_node", "query_logs")
700
+ and action_dict.get("target") is not None
701
+ ):
702
  kwargs["target"] = int(action_dict["target"])
703
  elif act_type == "reroute_traffic":
704
  kwargs["from_node"] = int(action_dict.get("from_node", 0))
 
711
  action = InfraAction(action_type="no_op")
712
  obs = env.step(action)
713
  obs_dict = _infraobs_to_dict(obs)
714
+ return {
715
+ "data": {
716
+ "observation": obs_dict,
717
+ "reward": getattr(obs, "reward", 0.0),
718
+ "done": getattr(obs, "done", False),
719
+ }
720
+ }
721
 
722
 
723
  # ---------------------------------------------------------------------------
 
792
  done = False
793
 
794
  try:
795
+ result = (
796
+ env_step_direct(backend_action)
797
+ if use_direct
798
+ else env_step(env_url, backend_action)
799
+ )
800
  data_block = result.get("data", result)
801
 
802
  if "observation" in data_block and isinstance(
 
955
  print(f" STRUCTURED LOGS: {log_dir}")
956
  print("==================================================")
957
 
958
+ use_direct = mode == "local"
959
  for task_id in tasks:
960
  stats = run_task(
961
  task_id,
train_grpo_unsloth.py CHANGED
@@ -26,6 +26,7 @@ Post-training benchmark:
26
 
27
  # UNSLOTH_VLLM_STANDBY must be set before unsloth is imported.
28
  import os
 
29
  os.environ["UNSLOTH_VLLM_STANDBY"] = "1"
30
 
31
  # ---------------------------------------------------------------------------
@@ -64,6 +65,7 @@ from server.models import InfraAction, InfraObservation
64
  from server.command_parser import parse_command, CommandParseError
65
  from server.rubrics import calculate_step_reward as _calculate_step_reward
66
 
 
67
  def _probe_rubrics() -> bool:
68
  """Return True if rubrics returns bounded rewards (main branch), False if -1000 (nithish)."""
69
  try:
@@ -75,33 +77,47 @@ def _probe_rubrics() -> bool:
75
  except Exception:
76
  return False
77
 
 
78
  _RUBRICS_BOUNDED = _probe_rubrics()
79
- print(f"[GRPO] rubrics version: {'main (bounded [-5,+5])' if _RUBRICS_BOUNDED else 'nithish (WARNING: -1000 cliff — using fallback)'}")
 
 
80
 
81
  # ---------------------------------------------------------------------------
82
  # Config
83
  # ---------------------------------------------------------------------------
84
 
85
- MODEL_NAME = "unsloth/Qwen3-8B"
86
- MAX_SEQ_LENGTH = 2048
87
- LORA_RANK = 32
88
- OUTPUT_DIR = "checkpoints/qwen3_grpo_unsloth"
89
 
90
- DATASET_EPISODES = 500 # env rollouts to build the training dataset
91
- MAX_STEPS = 300 # GRPOTrainer update steps
92
- NUM_GENERATIONS = 4 # G — completions per prompt; reward_env is CPU-bound, keep small
93
- MAX_COMPLETION_LENGTH = 512 # Qwen3 no-think response is ~60 tokens; 512 is a safe ceiling
94
- SAVE_STEPS = 100
 
 
95
 
96
  ALL_TASKS = [
97
- "traffic_spike", "node_failure", "cascading_failure", "flash_crowd",
98
- "thundering_herd", "zombie_node", "hot_shard_skew", "memory_leak_slow_burn",
99
- "split_brain_io_bottleneck", "black_swan_az_failure",
100
- "retry_storm", "connection_pool_deadlock", "autoscaler_flapping_trap",
 
 
 
 
 
 
 
 
 
101
  ]
102
 
103
  # ---------------------------------------------------------------------------
104
- # System prompt — identical to inference.py + rule 8 (DB recovery)
105
  # ---------------------------------------------------------------------------
106
 
107
  SYSTEM_PROMPT = """You are an expert Site Reliability Engineer (SRE) managing a highly volatile Kubernetes cluster.
@@ -124,15 +140,16 @@ Available commands:
124
 
125
  CRITICAL INCIDENT TRIAGE TREE (Follow strictly in order):
126
  1. OOM IMMINENT: IF ANY 'mem_utilizations' > 0.92 → kubectl delete pod node-<leaking_node>
127
- 2. SPLIT-BRAIN: IF node_0 'io_wait' > 0.80 → kubectl throttle ingress --rate=0.5
128
- 3. HOT SHARD: IF one worker CPU > 0.90 but cluster average is low
 
 
129
  → kubectl exec -it istio-proxy -- traffic shift --from=<hot> --to=<cold>
130
- 4. RETRY STORM: IF 'p99_latency' > 100.0 AND traffic spiking → kubectl throttle ingress --rate=0.4
131
- 5. ZOMBIE NODE: IF worker CPU < 0.10 BUT 'p99_latency' is huge
132
  → kubectl exec -it istio-proxy -- traffic shift --from=<zombie> --to=<healthy>
133
- 6. BLACK SWAN: IF multiple nodes in 'failed_nodes' → kubectl throttle ingress --rate=0.3
134
- 7. DATABASE SURVIVAL: IF node-0 cpu_load > 0.80 → kubectl throttle ingress --rate=0.7
135
- 8. DB RECOVERY: IF node-0 is in 'failed_nodes' → kubectl delete pod node-0
136
  9. SAFE SCALING: IF avg worker CPU > 0.75 AND 'error_budget' > 20
137
  → kubectl scale deployment frontend --replicas=10
138
  10. HEALTHY: If metrics are stable → no_op
@@ -147,26 +164,40 @@ Respond in this exact format:
147
  # Triage oracle — deterministic expected action for any observation
148
  # ---------------------------------------------------------------------------
149
 
 
150
  def _get_expected_action(obs: dict) -> str:
151
- """Return the kubectl command the triage tree mandates, or 'no_op'."""
152
- cpu = obs.get("cpu_loads", [0.3] * 8)
153
- mem = obs.get("mem_utilizations", [0.2] * 8)
154
- fail = set(obs.get("failed_nodes", []))
155
- io = float(obs.get("io_wait", 0.0))
156
- p99 = float(obs.get("p99_latency", 0.0))
157
- rr = float(obs.get("request_rate", 100.0))
158
- bud = float(obs.get("error_budget", 100.0))
159
-
160
- # Rule 1: OOM
 
 
 
 
 
 
 
 
 
161
  for i, m in enumerate(mem):
162
  if float(m) > 0.92:
163
  return f"kubectl delete pod node-{i}"
164
 
165
- # Rule 2: Split-brain
 
 
 
 
166
  if io > 0.80:
167
  return "kubectl throttle ingress --rate=0.5"
168
 
169
- # Rule 3: Hot shard
170
  workers = [(i, float(c)) for i, c in enumerate(cpu[1:], 1) if float(c) >= 0]
171
  if workers:
172
  avg = sum(c for _, c in workers) / len(workers)
@@ -180,11 +211,11 @@ def _get_expected_action(obs: dict) -> str:
180
  if dst is not None:
181
  return f"kubectl exec -it istio-proxy -- traffic shift --from={i} --to={dst}"
182
 
183
- # Rule 4: Retry storm
184
  if p99 > 100.0 and rr > 150:
185
  return "kubectl throttle ingress --rate=0.4"
186
 
187
- # Rule 5: Zombie node
188
  for i, c in workers:
189
  if 0 <= c < 0.10 and p99 > 100.0:
190
  dst = next(
@@ -194,34 +225,42 @@ def _get_expected_action(obs: dict) -> str:
194
  if dst is not None:
195
  return f"kubectl exec -it istio-proxy -- traffic shift --from={i} --to={dst}"
196
 
197
- # Rule 6: Black swan
198
  if len(fail) >= 2:
199
  return "kubectl throttle ingress --rate=0.3"
200
 
201
- # Rule 7: DB survival
202
  db_cpu = float(cpu[0]) if cpu and float(cpu[0]) >= 0 else 0.0
203
  if db_cpu > 0.80:
204
  return "kubectl throttle ingress --rate=0.7"
205
 
206
- # Rule 8: DB recovery ← this rule was MISSING from inference.py
207
- if 0 in fail:
208
- return "kubectl delete pod node-0"
209
-
210
  # Rule 9: Safe scaling
211
  if workers and sum(c for _, c in workers) / len(workers) > 0.75 and bud > 20:
212
  return "kubectl scale deployment frontend --replicas=10"
213
 
214
  return "no_op"
215
 
 
216
  # ---------------------------------------------------------------------------
217
  # Dataset collection
218
  # ---------------------------------------------------------------------------
219
 
 
220
  def _obs_to_dict(obs: InfraObservation) -> dict:
221
  keys = [
222
- "cpu_loads", "mem_utilizations", "queue_lengths", "failed_nodes",
223
- "latency_ms", "request_rate", "io_wait", "p99_latency", "error_budget",
224
- "step", "task_hint", "action_errors", "cloud_budget",
 
 
 
 
 
 
 
 
 
 
225
  ]
226
  return {k: getattr(obs, k) for k in keys if hasattr(obs, k)}
227
 
@@ -231,11 +270,14 @@ def _heuristic_action(obs: InfraObservation) -> InfraAction:
231
  if random.random() < 0.30:
232
  atype = random.choice(["no_op", "restart_node", "throttle", "scale_up"])
233
  if atype == "restart_node":
234
- return InfraAction(action_type="restart_node",
235
- target=random.randint(0, min(7, len(obs.cpu_loads) - 1)))
 
 
236
  if atype == "throttle":
237
- return InfraAction(action_type="throttle",
238
- rate=random.choice([0.3, 0.5, 0.7]))
 
239
  if atype == "scale_up":
240
  return InfraAction(action_type="scale_up")
241
  return InfraAction(action_type="no_op")
@@ -270,21 +312,23 @@ def collect_dataset(n_episodes: int, tasks: List[str]) -> Dataset:
270
 
271
  for _ in range(20):
272
  d = _obs_to_dict(obs)
273
- rows.append({
274
- "prompt": [
275
- {"role": "system", "content": SYSTEM_PROMPT},
276
- {
277
- "role": "user",
278
- "content": (
279
- "/no_think\n" # suppress Qwen3 <think> block — response is ~60 tokens, not ~600
280
- f"Current system state:\n{json.dumps(d)}\n"
281
- "Respond with the required XML and JSON format."
282
- ),
283
- },
284
- ],
285
- "obs_json": json.dumps(d),
286
- "task": task,
287
- })
 
 
288
 
289
  action = _heuristic_action(obs)
290
  try:
@@ -295,15 +339,17 @@ def collect_dataset(n_episodes: int, tasks: List[str]) -> Dataset:
295
  break
296
 
297
  if (ep + 1) % 50 == 0:
298
- print(f" [dataset] episode {ep+1}/{n_episodes} → {len(rows)} rows")
299
 
300
  print(f" [dataset] collected {len(rows)} total rows")
301
  return Dataset.from_list(rows)
302
 
 
303
  # ---------------------------------------------------------------------------
304
  # Helpers shared by reward functions
305
  # ---------------------------------------------------------------------------
306
 
 
307
  def _get_completion_text(comp) -> str:
308
  """
309
  Extract completion text from TRL GRPOTrainer's format.
@@ -341,7 +387,7 @@ def _restore_env_state(env: DistributedInfraEnvironment, obs: dict) -> None:
341
  sim = env.sim
342
  cpu_l = obs.get("cpu_loads", [])
343
  mem_l = obs.get("mem_utilizations", [])
344
- q_l = obs.get("queue_lengths", [])
345
 
346
  for i in range(min(len(cpu_l), len(sim.nodes))):
347
  if float(cpu_l[i]) >= 0:
@@ -355,10 +401,11 @@ def _restore_env_state(env: DistributedInfraEnvironment, obs: dict) -> None:
355
  if 0 <= idx < len(sim.nodes):
356
  sim.nodes[idx].is_failed = True
357
 
358
- sim.latency_ms = float(obs.get("latency_ms", 20.0))
359
- sim.error_budget = float(obs.get("error_budget", 100.0))
360
- sim.last_trace_p99_latency = float(obs.get("p99_latency", 0.0))
361
- sim.last_trace_node_0_io = float(obs.get("io_wait", 0.0))
 
362
 
363
  # ---------------------------------------------------------------------------
364
  # Reward functions (TRL GRPOTrainer signature)
@@ -371,6 +418,7 @@ def _restore_env_state(env: DistributedInfraEnvironment, obs: dict) -> None:
371
  # len(obs_json) == G (same value repeated G times by TRL)
372
  # ---------------------------------------------------------------------------
373
 
 
374
  def reward_format(completions: List, **kwargs) -> List[float]:
375
  """
376
  Reward XML structure compliance (mirrors Unsloth's match_format_exactly /
@@ -383,18 +431,18 @@ def reward_format(completions: List, **kwargs) -> List[float]:
383
  scores = []
384
  for comp in completions:
385
  text = _get_completion_text(comp)
386
- n_re = text.count("<reasoning>")
387
  n_re_ = text.count("</reasoning>")
388
- n_ac = text.count("<action>")
389
  n_ac_ = text.count("</action>")
390
 
391
  if n_re == 1 and n_re_ == 1 and n_ac == 1 and n_ac_ == 1:
392
  scores.append(3.0)
393
  else:
394
  s = 0.0
395
- s += 0.5 if n_re_ == 1 else -1.0
396
- s += 0.5 if n_ac == 1 else -1.0
397
- s += 0.5 if n_ac_ == 1 else -1.0
398
  scores.append(s)
399
  return scores
400
 
@@ -431,26 +479,25 @@ def reward_env(
431
  **kwargs,
432
  ) -> List[float]:
433
  """
434
- Environment simulation reward using the production-grade SRE reward function.
435
 
436
- Uses calculate_step_reward() from server/rubrics.py (friend's improved version):
437
  - 7 components: uptime, DB CPU, memory cliff, p99 latency, load shedding,
438
  action efficiency, temporal friction
439
  - Bounded to [-5.0, +5.0] — no -1000 cliff, gradients always flow
440
- - Action-aware: penalises unnecessary throttling and no-ops under load
441
 
442
- Requires the updated rubrics.py (main branch) where calculate_step_reward
443
- returns -5.0 for DB failure instead of -1000.
444
 
445
- Range: [−5.0, +5.0]
446
  """
447
  scores = []
448
  for i, comp in enumerate(completions):
449
  try:
450
- obs_data = json.loads(obs_json[i]) if obs_json else {}
451
  task_name = task[i] if task else "traffic_spike"
452
  except (TypeError, IndexError, json.JSONDecodeError):
453
- scores.append(-5.0)
454
  continue
455
 
456
  env = DistributedInfraEnvironment()
@@ -472,17 +519,17 @@ def reward_env(
472
  pass
473
 
474
  if _RUBRICS_BOUNDED:
475
- # Main branch: full 7-component bounded reward [-5.0, +5.0]
476
- scores.append(_calculate_step_reward(env.sim))
477
  else:
478
- # Nithish branch fallback: simple 3-component formula [-2.5, +0.5]
479
- sim = env.sim
480
  nodes = sim.nodes
481
  alive = sum(1 for n in nodes if not n.is_failed)
482
- r_up = 0.5 * (alive / max(len(nodes), 1))
483
  r_lat = -0.5 * min((max(0.0, sim.latency_ms - 50.0) / 100.0) ** 2, 1.0)
484
- r_db = -2.0 if (nodes and nodes[0].is_failed) else 0.0
485
- scores.append(r_up + r_lat + r_db)
486
 
487
  return scores
488
 
@@ -493,61 +540,62 @@ def reward_triage(
493
  **kwargs,
494
  ) -> List[float]:
495
  """
496
- Triage oracle reward (mirrors Unsloth's check_answer / check_numbers pattern).
497
 
498
  Compares the model's action against the deterministic triage tree output.
 
499
 
500
- +5.0 — exact command match with expected action
501
- +2.0 — same action_type but different parameters (e.g., throttle at wrong rate)
502
- 0.0 — no_op when a specific action is expected
503
- -2.0 — completely wrong action type
504
- -1.0 — unnecessary action when system is healthy (expected no_op)
505
-
506
- This reward is intentionally strong to bootstrap correct rule-following.
507
  """
508
  scores = []
509
  for i, comp in enumerate(completions):
510
  try:
511
  obs_data = json.loads(obs_json[i]) if obs_json else {}
512
  except (TypeError, IndexError, json.JSONDecodeError):
513
- scores.append(-2.0)
514
  continue
515
 
516
- expected = _get_expected_action(obs_data)
517
  predicted = _extract_command(_get_completion_text(comp))
518
 
519
  if predicted is None:
520
- scores.append(-2.0)
521
  continue
522
 
523
  if predicted.strip() == expected.strip():
524
- scores.append(5.0)
525
  continue
526
 
527
  if expected == "no_op":
528
- # Healthy system — penalise unnecessary intervention
529
- scores.append(-1.0 if predicted != "no_op" else 0.0)
530
  continue
531
 
532
  if predicted == "no_op":
533
- # Missed a required action
534
- scores.append(-2.0)
535
  continue
536
 
537
  # Same action type, wrong parameters?
538
  try:
539
  act_p = parse_command(predicted)
540
  act_e = parse_command(expected)
541
- scores.append(2.0 if act_p.action_type == act_e.action_type else -2.0)
542
  except CommandParseError:
543
- scores.append(-2.0)
544
 
545
  return scores
546
 
 
547
  # ---------------------------------------------------------------------------
548
  # Main
549
  # ---------------------------------------------------------------------------
550
 
 
551
  def main() -> None:
552
  # ---- Load model (Unsloth FastLanguageModel + LoRA + FP8) ----
553
  print(f"[GRPO] Loading {MODEL_NAME} ...")
@@ -555,45 +603,56 @@ def main() -> None:
555
  # compilation_config=0 → basic CUDA graphs only; skips piecewise graph-split that
556
  # crashes on A100 SM 8.0 (vLLM bug in _decompose_size_nodes)
557
  model, tokenizer = FastLanguageModel.from_pretrained(
558
- model_name = MODEL_NAME,
559
- max_seq_length = MAX_SEQ_LENGTH,
560
- load_in_4bit = False,
561
- fast_inference = True,
562
- max_lora_rank = LORA_RANK,
563
- load_in_fp8 = False, # FP8 requires compute capability 8.9+; A100 is 8.0
564
- compilation_config = 0, # avoid piecewise graph-split crash; still uses CUDA graphs
565
  )
566
 
567
  model = FastLanguageModel.get_peft_model(
568
  model,
569
- r = LORA_RANK,
570
- lora_alpha = LORA_RANK * 2, # 2× alpha speeds up training
571
- target_modules = [
572
- "q_proj", "k_proj", "v_proj", "o_proj",
573
- "gate_proj", "up_proj", "down_proj", # include MLP for DIME reasoning
 
 
 
 
 
574
  ],
575
- use_gradient_checkpointing = "unsloth", # 30% memory reduction
576
- random_state = 3407,
577
  )
578
 
579
  # ---- Collect dataset ----
580
- print(f"\n[GRPO] Collecting dataset ({DATASET_EPISODES} episodes, {len(ALL_TASKS)} tasks)...")
 
 
581
  dataset = collect_dataset(DATASET_EPISODES, ALL_TASKS)
582
 
583
  # Filter prompts that exceed 90th-percentile token length (avoids outlier OOM)
584
  print("[GRPO] Filtering dataset by prompt length...")
585
  prompt_lens = [
586
- len(tokenizer.apply_chat_template(
587
- row["prompt"], add_generation_prompt=True, tokenize=True
588
- ))
 
 
589
  for row in dataset
590
  ]
591
- max_prompt_len = int(np.quantile(prompt_lens, 0.90)) + 1
592
- max_comp_len = MAX_SEQ_LENGTH - max_prompt_len
593
- keep_idx = [i for i, L in enumerate(prompt_lens) if L <= max_prompt_len]
594
- dataset = dataset.select(keep_idx)
595
- print(f"[GRPO] Final dataset: {len(dataset)} rows | "
596
- f"max_prompt={max_prompt_len} max_completion={max_comp_len}")
 
 
597
 
598
  # ---- Sleep vLLM engine if available (frees VRAM during training) ----
599
  if hasattr(model, "vllm_engine") and model.vllm_engine is not None:
@@ -607,40 +666,40 @@ def main() -> None:
607
  # PatchFastRL above also re-adds vllm_sampling_params as an alias, but the
608
  # native params are clearer and forward-compatible.
609
  training_args = TRLGRPOConfig(
610
- temperature = 1.0,
611
- top_k = 50,
612
- top_p = 0.95,
613
- min_p = 0.1,
614
- learning_rate = 5e-6,
615
- weight_decay = 0.01,
616
- warmup_ratio = 0.1,
617
- lr_scheduler_type = "cosine",
618
- optim = "adamw_8bit",
619
- logging_steps = 5,
620
- per_device_train_batch_size = 1, # keep small: reward_env is CPU-bound, not GPU-bound
621
- gradient_accumulation_steps = 4, # effective batch = 4
622
- num_generations = NUM_GENERATIONS,
623
- vllm_gpu_memory_utilization = 0.7, # 70% of remaining VRAM for KV cache → faster generation
624
- max_prompt_length = max_prompt_len,
625
- max_completion_length = MAX_COMPLETION_LENGTH, # hard cap: prevents Qwen3 think-block bloat
626
- max_steps = MAX_STEPS,
627
- save_steps = SAVE_STEPS,
628
- output_dir = OUTPUT_DIR,
629
- report_to = "none",
630
  )
631
 
632
  # ---- Trainer ----
633
  trainer = GRPOTrainer(
634
- model = model,
635
- processing_class = tokenizer,
636
- reward_funcs = [
637
- reward_format, # structural: <reasoning><action> tags ← early signal
638
  reward_validity, # syntactic: command parses without error ← anti-hallucination
639
- reward_env, # semantic: env simulation, uptime+latency ← main SRE signal
640
- reward_triage, # oracle: matches triage tree expected action ← strong supervision
641
  ],
642
- args = training_args,
643
- train_dataset = dataset,
644
  )
645
 
646
  print("\n[GRPO] Training starts...")
 
26
 
27
  # UNSLOTH_VLLM_STANDBY must be set before unsloth is imported.
28
  import os
29
+
30
  os.environ["UNSLOTH_VLLM_STANDBY"] = "1"
31
 
32
  # ---------------------------------------------------------------------------
 
65
  from server.command_parser import parse_command, CommandParseError
66
  from server.rubrics import calculate_step_reward as _calculate_step_reward
67
 
68
+
69
  def _probe_rubrics() -> bool:
70
  """Return True if rubrics returns bounded rewards (main branch), False if -1000 (nithish)."""
71
  try:
 
77
  except Exception:
78
  return False
79
 
80
+
81
  _RUBRICS_BOUNDED = _probe_rubrics()
82
+ print(
83
+ f"[GRPO] rubrics version: {'main (bounded [-5,+5])' if _RUBRICS_BOUNDED else 'nithish (WARNING: -1000 cliff — using fallback)'}"
84
+ )
85
 
86
  # ---------------------------------------------------------------------------
87
  # Config
88
  # ---------------------------------------------------------------------------
89
 
90
+ MODEL_NAME = "unsloth/Qwen3-8B"
91
+ MAX_SEQ_LENGTH = 2048
92
+ LORA_RANK = 32
93
+ OUTPUT_DIR = "checkpoints/qwen3_grpo_unsloth"
94
 
95
+ DATASET_EPISODES = 500 # env rollouts to build the training dataset
96
+ MAX_STEPS = 300 # GRPOTrainer update steps
97
+ NUM_GENERATIONS = 4 # G — completions per prompt; reward_env is CPU-bound, keep small
98
+ MAX_COMPLETION_LENGTH = (
99
+ 512 # Qwen3 no-think response is ~60 tokens; 512 is a safe ceiling
100
+ )
101
+ SAVE_STEPS = 100
102
 
103
  ALL_TASKS = [
104
+ "traffic_spike",
105
+ "node_failure",
106
+ "cascading_failure",
107
+ "flash_crowd",
108
+ "thundering_herd",
109
+ "zombie_node",
110
+ "hot_shard_skew",
111
+ "memory_leak_slow_burn",
112
+ "split_brain_io_bottleneck",
113
+ "black_swan_az_failure",
114
+ "retry_storm",
115
+ "connection_pool_deadlock",
116
+ "autoscaler_flapping_trap",
117
  ]
118
 
119
  # ---------------------------------------------------------------------------
120
+ # System prompt — shared with inference.py
121
  # ---------------------------------------------------------------------------
122
 
123
  SYSTEM_PROMPT = """You are an expert Site Reliability Engineer (SRE) managing a highly volatile Kubernetes cluster.
 
140
 
141
  CRITICAL INCIDENT TRIAGE TREE (Follow strictly in order):
142
  1. OOM IMMINENT: IF ANY 'mem_utilizations' > 0.92 → kubectl delete pod node-<leaking_node>
143
+ 2. DB RECOVERY: IF node-0 is in 'failed_nodes' → kubectl delete pod node-0
144
+ (The DB is a SPOF. If it's dead, ALL other actions are futile until it restarts.)
145
+ 3. SPLIT-BRAIN: IF node_0 'io_wait' > 0.80 → kubectl throttle ingress --rate=0.5
146
+ 4. HOT SHARD: IF one worker CPU > 0.90 but cluster average is low
147
  → kubectl exec -it istio-proxy -- traffic shift --from=<hot> --to=<cold>
148
+ 5. RETRY STORM: IF 'p99_latency' > 100.0 AND traffic spiking → kubectl throttle ingress --rate=0.4
149
+ 6. ZOMBIE NODE: IF worker CPU < 0.10 BUT 'p99_latency' is huge
150
  → kubectl exec -it istio-proxy -- traffic shift --from=<zombie> --to=<healthy>
151
+ 7. BLACK SWAN: IF multiple nodes in 'failed_nodes' (but DB is alive) → kubectl throttle ingress --rate=0.3
152
+ 8. DATABASE SURVIVAL: IF node-0 cpu_load > 0.80 → kubectl throttle ingress --rate=0.7
 
153
  9. SAFE SCALING: IF avg worker CPU > 0.75 AND 'error_budget' > 20
154
  → kubectl scale deployment frontend --replicas=10
155
  10. HEALTHY: If metrics are stable → no_op
 
164
  # Triage oracle — deterministic expected action for any observation
165
  # ---------------------------------------------------------------------------
166
 
167
+
168
  def _get_expected_action(obs: dict) -> str:
169
+ """Return the kubectl command the triage tree mandates, or 'no_op'.
170
+
171
+ Rule ordering is critical for RL convergence:
172
+ 1. OOM — immediate life-or-death
173
+ 2. DB Recovery — SPOF must be restored before anything else
174
+ 3-6. Network/traffic rules
175
+ 7. Black Swan — only fires if DB is alive
176
+ 8-9. Proactive scaling
177
+ 10. Healthy
178
+ """
179
+ cpu = obs.get("cpu_loads", [0.3] * 8)
180
+ mem = obs.get("mem_utilizations", [0.2] * 8)
181
+ fail = set(obs.get("failed_nodes", []))
182
+ io = float(obs.get("io_wait", 0.0))
183
+ p99 = float(obs.get("p99_latency", 0.0))
184
+ rr = float(obs.get("request_rate", 100.0))
185
+ bud = float(obs.get("error_budget", 100.0))
186
+
187
+ # Rule 1: OOM — instant kill prevention
188
  for i, m in enumerate(mem):
189
  if float(m) > 0.92:
190
  return f"kubectl delete pod node-{i}"
191
 
192
+ # Rule 2: DB RECOVERY — the DB is a SPOF; if it's dead, nothing else matters
193
+ if 0 in fail:
194
+ return "kubectl delete pod node-0"
195
+
196
+ # Rule 3: Split-brain
197
  if io > 0.80:
198
  return "kubectl throttle ingress --rate=0.5"
199
 
200
+ # Rule 4: Hot shard
201
  workers = [(i, float(c)) for i, c in enumerate(cpu[1:], 1) if float(c) >= 0]
202
  if workers:
203
  avg = sum(c for _, c in workers) / len(workers)
 
211
  if dst is not None:
212
  return f"kubectl exec -it istio-proxy -- traffic shift --from={i} --to={dst}"
213
 
214
+ # Rule 5: Retry storm
215
  if p99 > 100.0 and rr > 150:
216
  return "kubectl throttle ingress --rate=0.4"
217
 
218
+ # Rule 6: Zombie node
219
  for i, c in workers:
220
  if 0 <= c < 0.10 and p99 > 100.0:
221
  dst = next(
 
225
  if dst is not None:
226
  return f"kubectl exec -it istio-proxy -- traffic shift --from={i} --to={dst}"
227
 
228
+ # Rule 7: Black swan (only fires when DB is alive — DB recovery is above)
229
  if len(fail) >= 2:
230
  return "kubectl throttle ingress --rate=0.3"
231
 
232
+ # Rule 8: DB survival (protect a living DB under load)
233
  db_cpu = float(cpu[0]) if cpu and float(cpu[0]) >= 0 else 0.0
234
  if db_cpu > 0.80:
235
  return "kubectl throttle ingress --rate=0.7"
236
 
 
 
 
 
237
  # Rule 9: Safe scaling
238
  if workers and sum(c for _, c in workers) / len(workers) > 0.75 and bud > 20:
239
  return "kubectl scale deployment frontend --replicas=10"
240
 
241
  return "no_op"
242
 
243
+
244
  # ---------------------------------------------------------------------------
245
  # Dataset collection
246
  # ---------------------------------------------------------------------------
247
 
248
+
249
  def _obs_to_dict(obs: InfraObservation) -> dict:
250
  keys = [
251
+ "cpu_loads",
252
+ "mem_utilizations",
253
+ "queue_lengths",
254
+ "failed_nodes",
255
+ "latency_ms",
256
+ "request_rate",
257
+ "io_wait",
258
+ "p99_latency",
259
+ "error_budget",
260
+ "step",
261
+ "task_hint",
262
+ "action_errors",
263
+ "cloud_budget",
264
  ]
265
  return {k: getattr(obs, k) for k in keys if hasattr(obs, k)}
266
 
 
270
  if random.random() < 0.30:
271
  atype = random.choice(["no_op", "restart_node", "throttle", "scale_up"])
272
  if atype == "restart_node":
273
+ return InfraAction(
274
+ action_type="restart_node",
275
+ target=random.randint(0, min(7, len(obs.cpu_loads) - 1)),
276
+ )
277
  if atype == "throttle":
278
+ return InfraAction(
279
+ action_type="throttle", rate=random.choice([0.3, 0.5, 0.7])
280
+ )
281
  if atype == "scale_up":
282
  return InfraAction(action_type="scale_up")
283
  return InfraAction(action_type="no_op")
 
312
 
313
  for _ in range(20):
314
  d = _obs_to_dict(obs)
315
+ rows.append(
316
+ {
317
+ "prompt": [
318
+ {"role": "system", "content": SYSTEM_PROMPT},
319
+ {
320
+ "role": "user",
321
+ "content": (
322
+ "/no_think\n" # suppress Qwen3 <think> block β€” response is ~60 tokens, not ~600
323
+ f"Current system state:\n{json.dumps(d)}\n"
324
+ "Respond with the required XML and JSON format."
325
+ ),
326
+ },
327
+ ],
328
+ "obs_json": json.dumps(d),
329
+ "task": task,
330
+ }
331
+ )
332
 
333
  action = _heuristic_action(obs)
334
  try:
 
339
  break
340
 
341
  if (ep + 1) % 50 == 0:
342
+ print(f" [dataset] episode {ep + 1}/{n_episodes} → {len(rows)} rows")
343
 
344
  print(f" [dataset] collected {len(rows)} total rows")
345
  return Dataset.from_list(rows)
346
 
347
+
348
  # ---------------------------------------------------------------------------
349
  # Helpers shared by reward functions
350
  # ---------------------------------------------------------------------------
351
 
352
+
353
  def _get_completion_text(comp) -> str:
354
  """
355
  Extract completion text from TRL GRPOTrainer's format.
 
387
  sim = env.sim
388
  cpu_l = obs.get("cpu_loads", [])
389
  mem_l = obs.get("mem_utilizations", [])
390
+ q_l = obs.get("queue_lengths", [])
391
 
392
  for i in range(min(len(cpu_l), len(sim.nodes))):
393
  if float(cpu_l[i]) >= 0:
 
401
  if 0 <= idx < len(sim.nodes):
402
  sim.nodes[idx].is_failed = True
403
 
404
+ sim.latency_ms = float(obs.get("latency_ms", 20.0))
405
+ sim.error_budget = float(obs.get("error_budget", 100.0))
406
+ sim.last_trace_p99_latency = float(obs.get("p99_latency", 0.0))
407
+ sim.last_trace_node_0_io = float(obs.get("io_wait", 0.0))
408
+
409
 
410
  # ---------------------------------------------------------------------------
411
  # Reward functions (TRL GRPOTrainer signature)
 
418
  # len(obs_json) == G (same value repeated G times by TRL)
419
  # ---------------------------------------------------------------------------
420
 
421
+
422
  def reward_format(completions: List, **kwargs) -> List[float]:
423
  """
424
  Reward XML structure compliance (mirrors Unsloth's match_format_exactly /
 
431
  scores = []
432
  for comp in completions:
433
  text = _get_completion_text(comp)
434
+ n_re = text.count("<reasoning>")
435
  n_re_ = text.count("</reasoning>")
436
+ n_ac = text.count("<action>")
437
  n_ac_ = text.count("</action>")
438
 
439
  if n_re == 1 and n_re_ == 1 and n_ac == 1 and n_ac_ == 1:
440
  scores.append(3.0)
441
  else:
442
  s = 0.0
443
+ s += 0.5 if n_re_ == 1 else -1.0
444
+ s += 0.5 if n_ac == 1 else -1.0
445
+ s += 0.5 if n_ac_ == 1 else -1.0
446
  scores.append(s)
447
  return scores
448
 
 
479
  **kwargs,
480
  ) -> List[float]:
481
  """
482
+ Environment simulation reward — the PRIMARY training signal.
483
 
484
+ Uses calculate_step_reward() from server/rubrics.py:
485
  - 7 components: uptime, DB CPU, memory cliff, p99 latency, load shedding,
486
  action efficiency, temporal friction
487
  - Bounded to [-5.0, +5.0] — no -1000 cliff, gradients always flow
 
488
 
489
+ Output is scaled by 2× so the environment physics dominates over the
490
+ oracle (reward_triage) in the total reward signal.
491
 
492
+ Range: [−10.0, +10.0] (2× the raw [-5, +5])
493
  """
494
  scores = []
495
  for i, comp in enumerate(completions):
496
  try:
497
+ obs_data = json.loads(obs_json[i]) if obs_json else {}
498
  task_name = task[i] if task else "traffic_spike"
499
  except (TypeError, IndexError, json.JSONDecodeError):
500
+ scores.append(-10.0)
501
  continue
502
 
503
  env = DistributedInfraEnvironment()
 
519
  pass
520
 
521
  if _RUBRICS_BOUNDED:
522
+ # Main branch: 2× scaled to dominate over oracle reward
523
+ scores.append(2.0 * _calculate_step_reward(env.sim))
524
  else:
525
+ # Nithish branch fallback: simple 3-component formula [-5.0, +1.0]
526
+ sim = env.sim
527
  nodes = sim.nodes
528
  alive = sum(1 for n in nodes if not n.is_failed)
529
+ r_up = 0.5 * (alive / max(len(nodes), 1))
530
  r_lat = -0.5 * min((max(0.0, sim.latency_ms - 50.0) / 100.0) ** 2, 1.0)
531
+ r_db = -2.0 if (nodes and nodes[0].is_failed) else 0.0
532
+ scores.append(2.0 * (r_up + r_lat + r_db))
533
 
534
  return scores
535
 
 
540
  **kwargs,
541
  ) -> List[float]:
542
  """
543
+ Triage oracle reward — gentle guidance, NOT the primary teacher.
544
 
545
  Compares the model's action against the deterministic triage tree output.
546
+ Kept intentionally weak so reward_env (physics) dominates learning.
547
 
548
+ +1.0 — exact command match with expected action
549
+ +0.5 — same action_type but different parameters
550
+ 0.0 — no_op when a specific action is expected, or healthy system
551
+ -0.5 — completely wrong action type
552
+ -0.5 — unnecessary action when system is healthy (expected no_op)
 
 
553
  """
554
  scores = []
555
  for i, comp in enumerate(completions):
556
  try:
557
  obs_data = json.loads(obs_json[i]) if obs_json else {}
558
  except (TypeError, IndexError, json.JSONDecodeError):
559
+ scores.append(-0.5)
560
  continue
561
 
562
+ expected = _get_expected_action(obs_data)
563
  predicted = _extract_command(_get_completion_text(comp))
564
 
565
  if predicted is None:
566
+ scores.append(-0.5)
567
  continue
568
 
569
  if predicted.strip() == expected.strip():
570
+ scores.append(1.0)
571
  continue
572
 
573
  if expected == "no_op":
574
+ # Healthy system — mild penalty for unnecessary intervention
575
+ scores.append(-0.5 if predicted != "no_op" else 0.0)
576
  continue
577
 
578
  if predicted == "no_op":
579
+ # Missed a required action — neutral 0.0 (no positive reward, but no penalty)
580
+ scores.append(0.0)
581
  continue
582
 
583
  # Same action type, wrong parameters?
584
  try:
585
  act_p = parse_command(predicted)
586
  act_e = parse_command(expected)
587
+ scores.append(0.5 if act_p.action_type == act_e.action_type else -0.5)
588
  except CommandParseError:
589
+ scores.append(-0.5)
590
 
591
  return scores
592
 
593
+
594
  # ---------------------------------------------------------------------------
595
  # Main
596
  # ---------------------------------------------------------------------------
597
 
598
+
599
  def main() -> None:
600
  # ---- Load model (Unsloth FastLanguageModel + LoRA; FP8 disabled, unsupported on A100) ----
601
  print(f"[GRPO] Loading {MODEL_NAME} ...")
 
603
  # compilation_config=0 → basic CUDA graphs only; skips piecewise graph-split that
604
  # crashes on A100 SM 8.0 (vLLM bug in _decompose_size_nodes)
605
  model, tokenizer = FastLanguageModel.from_pretrained(
606
+ model_name=MODEL_NAME,
607
+ max_seq_length=MAX_SEQ_LENGTH,
608
+ load_in_4bit=False,
609
+ fast_inference=True,
610
+ max_lora_rank=LORA_RANK,
611
+ load_in_fp8=False, # FP8 requires compute capability 8.9+; A100 is 8.0
612
+ compilation_config=0, # avoid piecewise graph-split crash; still uses CUDA graphs
613
  )
614
 
615
  model = FastLanguageModel.get_peft_model(
616
  model,
617
+ r=LORA_RANK,
618
+ lora_alpha=LORA_RANK * 2, # 2× alpha speeds up training
619
+ target_modules=[
620
+ "q_proj",
621
+ "k_proj",
622
+ "v_proj",
623
+ "o_proj",
624
+ "gate_proj",
625
+ "up_proj",
626
+ "down_proj", # include MLP for DIME reasoning
627
  ],
628
+ use_gradient_checkpointing="unsloth", # 30% memory reduction
629
+ random_state=3407,
630
  )
631
 
632
  # ---- Collect dataset ----
633
+ print(
634
+ f"\n[GRPO] Collecting dataset ({DATASET_EPISODES} episodes, {len(ALL_TASKS)} tasks)..."
635
+ )
636
  dataset = collect_dataset(DATASET_EPISODES, ALL_TASKS)
637
 
638
  # Filter prompts that exceed 90th-percentile token length (avoids outlier OOM)
639
  print("[GRPO] Filtering dataset by prompt length...")
640
  prompt_lens = [
641
+ len(
642
+ tokenizer.apply_chat_template(
643
+ row["prompt"], add_generation_prompt=True, tokenize=True
644
+ )
645
+ )
646
  for row in dataset
647
  ]
648
+ max_prompt_len = int(np.quantile(prompt_lens, 0.90)) + 1
649
+ max_comp_len = MAX_SEQ_LENGTH - max_prompt_len
650
+ keep_idx = [i for i, L in enumerate(prompt_lens) if L <= max_prompt_len]
651
+ dataset = dataset.select(keep_idx)
652
+ print(
653
+ f"[GRPO] Final dataset: {len(dataset)} rows | "
654
+ f"max_prompt={max_prompt_len} max_completion={max_comp_len}"
655
+ )
656
 
657
  # ---- Sleep vLLM engine if available (frees VRAM during training) ----
658
  if hasattr(model, "vllm_engine") and model.vllm_engine is not None:
 
666
  # PatchFastRL above also re-adds vllm_sampling_params as an alias, but the
667
  # native params are clearer and forward-compatible.
668
  training_args = TRLGRPOConfig(
669
+ temperature=1.0,
670
+ top_k=50,
671
+ top_p=0.95,
672
+ min_p=0.1,
673
+ learning_rate=5e-6,
674
+ weight_decay=0.01,
675
+ warmup_ratio=0.1,
676
+ lr_scheduler_type="cosine",
677
+ optim="adamw_8bit",
678
+ logging_steps=5,
679
+ per_device_train_batch_size=1, # keep small: reward_env is CPU-bound, not GPU-bound
680
+ gradient_accumulation_steps=4, # effective batch = 4
681
+ num_generations=NUM_GENERATIONS,
682
+ vllm_gpu_memory_utilization=0.7, # 70% of remaining VRAM for KV cache → faster generation
683
+ max_prompt_length=max_prompt_len,
684
+ max_completion_length=MAX_COMPLETION_LENGTH, # hard cap: prevents Qwen3 think-block bloat
685
+ max_steps=MAX_STEPS,
686
+ save_steps=SAVE_STEPS,
687
+ output_dir=OUTPUT_DIR,
688
+ report_to="none",
689
  )
690
 
691
  # ---- Trainer ----
692
  trainer = GRPOTrainer(
693
+ model=model,
694
+ processing_class=tokenizer,
695
+ reward_funcs=[
696
+ reward_format, # structural: <reasoning><action> tags ← early signal
697
  reward_validity, # syntactic: command parses without error ← anti-hallucination
698
+ reward_env, # semantic: env simulation, uptime+latency ← main SRE signal
699
+ reward_triage, # oracle: matches triage tree expected action ← gentle guidance (intentionally weak)
700
  ],
701
+ args=training_args,
702
+ train_dataset=dataset,
703
  )
704
 
705
  print("\n[GRPO] Training starts...")