Arijit-07 committed on
Commit
5e31bf4
·
verified ·
1 Parent(s): 5e0397b

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +303 -358
train_model.py CHANGED
@@ -1,88 +1,91 @@
1
- # ── Cell 1: Install dependencies ──────────────────────────────────────────────
2
- import sys, os, json, time, random
3
 
4
- # Set logits flag BEFORE any import
5
  os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
6
 
7
-
8
- # Clear stale module cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  for mod in list(sys.modules.keys()):
10
  if any(x in mod for x in ['trl','unsloth','transformers','peft']):
11
  del sys.modules[mod]
12
 
13
- # Verify β€” unsloth must be imported first
14
  import unsloth
15
  from unsloth import FastLanguageModel
16
  import transformers, peft, torch
17
 
18
- print(f'βœ… unsloth {unsloth.__version__}')
19
- print(f'βœ… transformers {transformers.__version__}')
20
- print(f'βœ… torch {torch.__version__} | CUDA: {torch.cuda.is_available()}')
21
- print(f'βœ… UNSLOTH_RETURN_LOGITS = {os.environ["UNSLOTH_RETURN_LOGITS"]}')
22
- print('βœ… All dependencies installed')
23
-
24
- # ── Cell 2: Config β€” SET YOUR HF TOKEN HERE ───────────────────────────────────
25
- import os
26
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
27
-
28
- HF_TOKEN = os.environ.get('HF_TOKEN') # set as env var or paste here
29
- if HF_TOKEN:
30
- os.environ['HF_TOKEN'] = HF_TOKEN
31
 
32
- from huggingface_hub import login
 
33
  if HF_TOKEN:
 
34
  login(token=HF_TOKEN, add_to_git_credential=False)
35
- print('Logged in to HuggingFace')
36
  else:
37
- print('HF_TOKEN not set; continuing without HuggingFace login')
38
 
 
39
  CONFIG = {
40
- # Model β€” 8B for A100
41
  'model_name': 'unsloth/Meta-Llama-3.1-8B-Instruct',
42
- 'max_seq_length': 3072,
43
- 'load_in_4bit': True,
44
 
45
- # Environment
46
  'env_url': 'https://arijit-07-devops-incident-response.hf.space',
47
  'tasks': ['easy', 'medium', 'hard', 'bonus'],
48
  'episodes_per_task': 40,
49
- 'max_steps_per_episode': 12, # reduced from 20 β€” tighter episodes
50
 
51
- # Training β€” conservative to prevent catastrophic forgetting
52
- 'learning_rate': 5e-6, # FIXED: was 1e-5, caused degradation
53
  'grpo_group_size': 4,
54
  'lora_rank': 32,
55
  'lora_alpha': 64,
56
  'max_grad_norm': 0.5,
57
- 'kl_coeff': 0.05, # NEW: prevents catastrophic forgetting
58
 
59
- # Output
60
  'hf_repo': 'Arijit-07/aria-devops-llama8b',
61
  'output_dir': '/data/outputs',
62
  'save_every_n_episodes': 20,
63
  }
64
 
65
- import torch
66
- print(f'GPU: {torch.cuda.get_device_name(0)}')
67
- print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
68
- print(f'Model: {CONFIG["model_name"]} | Tasks: {CONFIG["tasks"]}')
69
- print(f'LR: {CONFIG["learning_rate"]} | KL: {CONFIG["kl_coeff"]}')
70
 
71
- # ── Cell 3: Environment Client ──────────────────────────��─────────────────────
72
- import requests, json, time, random
73
 
74
  BASE_URL = CONFIG['env_url']
75
 
76
  def env_reset(task_id, seed=None):
77
  payload = {'task_id': task_id}
78
- if seed is not None: payload['seed'] = seed
 
79
  for attempt in range(3):
80
  try:
81
  r = requests.post(f'{BASE_URL}/reset', json=payload, timeout=30)
82
  r.raise_for_status()
83
  return r.json()
84
- except:
85
- if attempt == 2: raise
 
86
  time.sleep(5)
87
 
88
  def env_step(action):
@@ -91,8 +94,9 @@ def env_step(action):
91
  r = requests.post(f'{BASE_URL}/step', json=action, timeout=30)
92
  r.raise_for_status()
93
  return r.json()
94
- except:
95
- if attempt == 2: raise
 
96
  time.sleep(5)
97
 
98
  def env_state():
@@ -108,119 +112,150 @@ VALID_ACTIONS = {
108
  }
109
 
110
  def sanitize_action(action):
111
- DEFAULT_SERVICE = "order-service" # safe fallback
112
-
113
  if not isinstance(action, dict):
114
  return {"action_type": "read_logs", "service": DEFAULT_SERVICE}
115
-
116
  action_type = action.get("action_type", "").lower()
117
-
118
- # Fix common mistakes
119
- if action_type == "read_service_logs":
120
- action_type = "read_logs"
121
-
122
  if action_type not in VALID_ACTIONS:
123
  action_type = "read_logs"
124
-
125
- # ALWAYS ensure service exists
126
  service = action.get("service") or action.get("service_name") or DEFAULT_SERVICE
127
-
128
- clean = {
129
- "action_type": action_type,
130
- "service": service
131
- }
132
-
133
- # Optional fields (safe add)
134
  for key in ["root_cause", "runbook", "version", "reason",
135
  "query", "ip_range", "table", "column", "target_region"]:
136
  if key in action and isinstance(action[key], str):
137
  clean[key] = action[key]
138
-
139
  return clean
140
-
141
 
 
142
  health = requests.get(f'{BASE_URL}/health', timeout=15).json()
143
- print(f'βœ… Environment: {health}')
144
- test_obs = env_reset('easy', seed=0)
145
- print(f'βœ… Reset OK. Services: {len(test_obs.get("services", []))}')
146
-
147
- # ── Cell 4: System Prompt + Observation Formatter ─────────────────────────────
148
-
149
- print('Config loaded:')
150
- for k, v in CONFIG.items():
151
- print(f' {k}: {v}')
152
-
153
- SYSTEM_PROMPT = """
154
- You are an autonomous DevOps agent.
155
 
 
 
156
  You MUST return ONLY valid JSON.
157
 
158
- STRICT RULES:
159
- - action_type MUST be one of:
160
- diagnose, read_logs, read_metrics, read_runbook, search_logs,
161
- restart_service, rollback, scale_up, alert_oncall, acknowledge,
162
- noop, block_ip_range, create_index, failover
163
-
164
- - Use EXACT parameter names:
165
- service (NOT service_name)
166
- root_cause
167
- runbook
168
- version
169
- reason
170
- query
171
- ip_range
172
- table
173
- column
174
- target_region
175
-
176
- - DO NOT invent new fields
177
- - DO NOT change names
178
- - DO NOT use service_name
179
- - Always output valid JSON only
180
-
181
- Example:
182
- {
183
- "action_type": "read_logs",
184
- "service": "order-service"
185
- }
186
- """
187
 
188
  def observation_to_prompt(obs, task_id):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  return (
190
- f"Task: {task_id}\n"
191
- "Current environment observation:\n"
192
- f"{json.dumps(obs, indent=2, sort_keys=True)}\n"
193
- "Choose the next valid action as JSON."
 
194
  )
195
 
196
- # ── Cell 5: Load Llama-3.1-8B with Unsloth ───────────────────────────────────
197
- from unsloth import FastLanguageModel
198
- import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
- print(f'Loading {CONFIG["model_name"]}...')
 
201
 
 
202
  checkpoint_path = f"{CONFIG['output_dir']}/latest"
203
- resuming_from_checkpoint = os.path.exists(checkpoint_path)
204
 
205
- if resuming_from_checkpoint:
206
- print("πŸ” Resuming from checkpoint...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  model, tokenizer = FastLanguageModel.from_pretrained(
208
  model_name=checkpoint_path,
209
  max_seq_length=CONFIG['max_seq_length'],
210
  dtype=None,
211
- load_in_4bit=False, # πŸ”₯ FORCE DISABLE
 
212
  )
213
  else:
214
- print("πŸ†• Starting fresh model...")
215
  model, tokenizer = FastLanguageModel.from_pretrained(
216
  model_name=CONFIG['model_name'],
217
  max_seq_length=CONFIG['max_seq_length'],
218
  dtype=None,
219
- load_in_4bit=False, # πŸ”₯ FORCE DISABLE
220
- token=HF_TOKEN,
221
  )
222
-
223
- if not resuming_from_checkpoint:
224
  model = FastLanguageModel.get_peft_model(
225
  model,
226
  r=CONFIG['lora_rank'],
@@ -235,146 +270,66 @@ if not resuming_from_checkpoint:
235
 
236
  trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
237
  total = sum(p.numel() for p in model.parameters())
238
- print(f'βœ… Model loaded')
239
- import copy
240
 
241
- model_device = next(model.parameters()).device
242
-
243
- ref_model = copy.deepcopy(model).to(model_device)
244
  ref_model.eval()
245
-
246
  for p in ref_model.parameters():
247
  p.requires_grad = False
248
-
249
  print("βœ… Reference model frozen for KL penalty")
250
- # πŸ” Load training state if exists
251
- state_path = f"{CONFIG['output_dir']}/state.json"
252
-
253
- if os.path.exists(state_path):
254
- print("πŸ” Restoring training state...")
255
- with open(state_path, "r") as f:
256
- state = json.load(f)
257
-
258
- global_ep = state.get("global_ep", 0)
259
- training_log = state.get("training_log", [])
260
- episode_scores = state.get("episode_scores", {t: [] for t in CONFIG['tasks']})
261
-
262
- print(f"βœ… Resumed from episode {global_ep}")
263
- print(f"πŸš€ Continuing training from episode {global_ep}")
264
- else:
265
- print("πŸ†• Starting fresh training state")
266
- training_log = []
267
- episode_scores = {t: [] for t in CONFIG['tasks']}
268
- global_ep = 0
269
- print(f' Trainable: {trainable:,} ({100*trainable/total:.2f}%)')
270
- print(f' VRAM: {torch.cuda.memory_allocated()/1e9:.2f} GB used')
271
-
272
- # ── Cell 6: Action Parser + Episode Runner ────────────────────────────────────
273
- import re
274
-
275
- def parse_action(text):
276
- text = text.strip()
277
- for pattern in [
278
- r'```json\s*({.*?})\s*```',
279
- r'```\s*({.*?})\s*```',
280
- r'({\s*"action_type"[^}]+})',
281
- ]:
282
- match = re.search(pattern, text, re.DOTALL)
283
- if match:
284
- try: return json.loads(match.group(1))
285
- except: continue
286
- try: return json.loads(text)
287
- except: return {'action_type': 'noop'}
288
-
289
- def generate_action(obs, task_id, temperature=0.7):
290
- messages = [
291
- {'role': 'system', 'content': SYSTEM_PROMPT},
292
- {'role': 'user', 'content': observation_to_prompt(obs, task_id)}
293
- ]
294
-
295
- input_ids = tokenizer.apply_chat_template(
296
- messages,
297
- tokenize=True,
298
- add_generation_prompt=True,
299
- return_tensors='pt'
300
- )
301
-
302
- # πŸ”₯ FIX: truncate
303
- input_ids = input_ids[:, -CONFIG['max_seq_length']:].to('cuda')
304
-
305
- FastLanguageModel.for_inference(model)
306
-
307
- with torch.no_grad():
308
- out = model.generate(
309
- input_ids,
310
- max_new_tokens=150,
311
- temperature=temperature,
312
- do_sample=True,
313
- pad_token_id=tokenizer.eos_token_id,
314
- )
315
-
316
- generated = tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)
317
-
318
- return parse_action(generated), generated
319
 
 
320
  def run_episode(task_id, seed=None, verbose=False):
321
  obs = env_reset(task_id, seed=seed)
322
  total_reward = 0.0
323
  done = False
 
324
 
325
  for step in range(CONFIG['max_steps_per_episode']):
326
  if done:
327
  break
328
-
329
- action, _ = generate_action(obs, task_id)
330
-
331
- action = sanitize_action(action)
332
-
 
 
 
 
 
 
 
 
 
 
 
333
  if verbose:
334
- print(f' Step {step+1}: {action}')
335
-
336
- print("πŸ” Sending action:", action)
337
-
338
  result = env_step(action)
339
-
340
  total_reward += result.get('reward', 0.0)
341
  obs = result.get('observation', obs)
342
  done = result.get('done', False)
 
343
 
344
- state = env_state()
345
- return state.get('current_score', total_reward)
346
-
347
- print('βœ… Episode runner ready')
348
- print('Testing one episode...')
349
- test_score = run_episode('easy', seed=99, verbose=True)
350
- print(f'Test score: {test_score:.3f}')
351
-
352
- # ── Cell 7: Pre-Training Baseline ────────────────────────────────────────────
353
- print('Running pre-training baseline (8 episodes per task)...')
354
  baseline_scores = {}
355
-
356
  for task_id in CONFIG['tasks']:
357
- scores = [run_episode(task_id, seed=i*7+3) for i in range(8)]
358
  avg = sum(scores) / len(scores)
359
  baseline_scores[task_id] = {'scores': scores, 'avg': avg}
360
- print(f' [{task_id}] baseline: {avg:.3f} (min={min(scores):.3f} max={max(scores):.3f})')
361
-
362
- print('\nβœ… Baseline done. Starting training...')
363
-
364
-
365
-
366
- import torch
367
- assert torch.cuda.is_available(), "GPU NOT DETECTED!"
368
- print("Using GPU:", torch.cuda.get_device_name(0))
369
-
370
-
371
- # ── GRPO Training Helpers ─────────────────────────────────────────────
372
 
 
373
  def run_episode_collect(task_id, seed):
 
374
  obs = env_reset(task_id, seed=seed)
375
  trajectory = []
376
  done = False
377
-
378
  FastLanguageModel.for_inference(model)
379
 
380
  for step in range(CONFIG['max_steps_per_episode']):
@@ -385,59 +340,52 @@ def run_episode_collect(task_id, seed):
385
  {'role': 'system', 'content': SYSTEM_PROMPT},
386
  {'role': 'user', 'content': observation_to_prompt(obs, task_id)}
387
  ]
388
-
389
  input_ids = tokenizer.apply_chat_template(
390
- messages,
391
- tokenize=True,
392
- add_generation_prompt=True,
393
  return_tensors='pt'
394
  )
395
-
396
  input_ids = input_ids[:, -CONFIG['max_seq_length']:].to('cuda')
397
 
398
- group_completions, group_rewards = [], []
399
-
400
  for _ in range(CONFIG['grpo_group_size']):
401
  with torch.no_grad():
402
  out = model.generate(
403
- input_ids,
404
- max_new_tokens=128,
405
- temperature=0.9,
406
- do_sample=True,
407
- pad_token_id=tokenizer.eos_token_id,
408
  )
409
-
410
  gen_ids = out[0][input_ids.shape[1]:]
411
- gen_text = tokenizer.decode(gen_ids, skip_special_tokens=True)
 
412
 
 
 
 
413
  action = sanitize_action(parse_action(gen_text))
414
-
415
  try:
416
- env_reset(task_id, seed=seed)
417
  res = env_step(action)
418
- reward = res.get('reward', 0.0)
419
  except:
420
- reward = 0.0
421
-
422
- group_completions.append(gen_ids)
423
- group_rewards.append(reward)
424
 
 
425
  best_idx = group_rewards.index(max(group_rewards))
426
- best_action = sanitize_action(parse_action(
427
- tokenizer.decode(group_completions[best_idx], skip_special_tokens=True)
428
- ))
429
-
430
  try:
431
- res = env_step(best_action)
432
- obs = res.get('observation', obs)
433
- done = res.get('done', False)
434
  except:
435
  done = True
436
 
437
  trajectory.append({
438
  'input_ids': input_ids,
439
  'completions': group_completions,
440
- 'rewards': group_rewards
441
  })
442
 
443
  total_reward = sum(max(s['rewards']) for s in trajectory) if trajectory else 0.0
@@ -445,12 +393,12 @@ def run_episode_collect(task_id, seed):
445
 
446
 
447
  def update_from_trajectory(trajectory):
 
448
  if not trajectory:
449
  return 0.0
450
 
451
  device = next(model.parameters()).device
452
  FastLanguageModel.for_training(model)
453
-
454
  model.train()
455
  optimizer.zero_grad()
456
 
@@ -458,12 +406,10 @@ def update_from_trajectory(trajectory):
458
 
459
  for step_data in trajectory:
460
  input_ids = step_data['input_ids'].to(device)
461
- rewards = step_data['rewards']
462
  completions = step_data['completions']
 
463
 
464
- rewards_t = torch.tensor(rewards, device=device)
465
-
466
- # βœ… Stable advantage normalization
467
  if rewards_t.std() > 1e-8:
468
  advantages = (rewards_t - rewards_t.mean()) / (rewards_t.std() + 1e-8)
469
  else:
@@ -471,74 +417,66 @@ def update_from_trajectory(trajectory):
471
 
472
  best_idx = rewards.index(max(rewards))
473
  best_ids = completions[best_idx].to(device)
 
474
 
475
  full_ids = torch.cat([input_ids[0], best_ids]).unsqueeze(0)
476
  labels = full_ids.clone()
477
  labels[0, :input_ids.shape[1]] = -100
478
 
479
  outputs = model(full_ids, labels=labels)
480
- policy_loss = outputs.loss * (-advantages[best_idx])
481
 
482
- # βœ… KL REGULARIZATION (CRITICAL FIX)
483
  with torch.no_grad():
484
- ref_logits = ref_model(full_ids).logits
485
-
 
486
  kl = torch.nn.functional.kl_div(
487
- torch.log_softmax(outputs.logits, dim=-1),
488
  torch.softmax(ref_logits, dim=-1),
489
  reduction='batchmean'
490
  )
491
-
492
- loss = policy_loss + CONFIG['kl_coeff'] * kl
493
-
494
- total_loss += loss
495
 
496
  total_loss = total_loss / len(trajectory)
497
  total_loss.backward()
498
-
499
  torch.nn.utils.clip_grad_norm_(
500
  [p for p in model.parameters() if p.requires_grad],
501
  CONFIG['max_grad_norm']
502
  )
503
-
504
  optimizer.step()
505
  scheduler.step()
506
-
507
  return total_loss.item()
508
 
509
-
510
- # ── Optimizer & Scheduler ─────────────────────────────────────────
511
-
512
  from torch.optim import AdamW
513
  from transformers import get_cosine_schedule_with_warmup
514
 
515
  optimizer = AdamW(
516
  [p for p in model.parameters() if p.requires_grad],
517
- lr=CONFIG['learning_rate'],
518
- weight_decay=0.01
519
  )
520
-
521
  total_eps = CONFIG['episodes_per_task'] * len(CONFIG['tasks'])
522
-
523
  scheduler = get_cosine_schedule_with_warmup(
524
  optimizer,
525
  num_warmup_steps=max(1, total_eps // 10),
526
  num_training_steps=total_eps
527
  )
528
 
529
-
530
-
531
  def run_training():
532
  global global_ep
533
  start_time = time.time()
534
 
535
- best_score = -1e9
536
- no_improve_count = 0
537
- PATIENCE = 15
 
 
538
 
539
  for task_id in CONFIG['tasks']:
540
- print(f'\nπŸ“‹ Task: {task_id.upper()} | Baseline: {baseline_scores[task_id]["avg"]:.3f}')
541
- print('-' * 40)
542
 
543
  for ep in range(CONFIG['episodes_per_task']):
544
  seed = random.randint(0, 9999)
@@ -547,109 +485,116 @@ def run_training():
547
  loss = update_from_trajectory(trajectory)
548
 
549
  episode_scores[task_id].append(final_score)
550
-
551
- global_ep += 1 # βœ… FIXED
552
-
553
  elapsed = (time.time() - start_time) / 60
554
  recent = episode_scores[task_id][-10:]
555
- rolling = sum(recent) / len(recent) if recent else final_score
556
-
557
- if rolling > best_score:
558
- best_score = rolling
559
- no_improve_count = 0
560
- else:
561
- no_improve_count += 1
562
-
563
- if no_improve_count >= PATIENCE:
564
- print("πŸ›‘ Early stopping triggered β€” no improvement")
565
- break
566
 
567
  training_log.append({
568
- 'episode': global_ep,
569
- 'task_id': task_id,
570
- 'score': final_score,
571
- 'rolling_avg': rolling,
572
- 'loss': loss,
573
- 'elapsed_min': round(elapsed, 1)
574
  })
575
 
 
576
  try:
577
  latest_ckpt = f"{CONFIG['output_dir']}/latest"
578
  model.save_pretrained(latest_ckpt)
579
  tokenizer.save_pretrained(latest_ckpt)
580
-
581
  state = {
582
- "global_ep": global_ep,
583
- "training_log": training_log,
584
- "episode_scores": episode_scores
585
  }
586
-
587
  tmp = f"{CONFIG['output_dir']}/state_tmp.json"
588
- final = f"{CONFIG['output_dir']}/state.json"
589
-
590
- with open(tmp, "w") as f:
591
  json.dump(state, f)
592
-
593
- os.replace(tmp, final)
594
-
595
  except Exception as e:
596
- print("⚠️ Checkpoint save failed:", e)
597
 
598
  if (ep + 1) % 5 == 0:
599
  delta = rolling - baseline_scores[task_id]['avg']
600
  trend = 'πŸ“ˆ' if delta > 0.02 else 'πŸ“‰' if delta < -0.02 else '➑️'
601
  print(
602
- f' {trend} Ep {ep+1:3d}/{CONFIG["episodes_per_task"]} | '
603
- f'Score: {final_score:.3f} | Roll-10: {rolling:.3f} | '
604
- f'vs baseline: {delta:+.3f} | Loss: {loss:.4f} | {elapsed:.0f}m'
605
  )
606
 
607
- print("\nπŸŽ‰ Training complete!")
 
 
 
 
608
 
609
- # =========================
610
- # πŸ”₯ POST EVAL (MOVED INSIDE)
611
- # =========================
612
- FastLanguageModel.for_inference(model)
613
 
614
- print("Post-training evaluation...")
615
- post_scores = {}
616
 
 
 
 
617
  for task_id in CONFIG['tasks']:
618
- scores = [run_episode(task_id, seed=i*13+7) for i in range(8)]
619
  avg = sum(scores) / len(scores)
620
- post_scores[task_id] = avg
621
- print(f"{task_id}: {avg:.3f}")
622
-
623
- print("Zero-shot:")
624
- for task_id in ['security', 'database', 'failover']:
625
- scores = []
626
- for i in range(5):
627
- try:
628
- scores.append(run_episode(task_id, seed=i*17+5))
629
- except:
630
- scores.append(0.0)
631
- print(f"{task_id}: {sum(scores)/len(scores):.3f}")
632
-
633
- print("βœ… Training + evaluation done")
634
-
 
 
 
 
 
 
635
 
636
- # =========================
637
- # πŸš€ ENTRY POINT (CRITICAL)
638
- # =========================
639
 
 
640
  import threading
641
  import gradio as gr
642
 
643
  def alive():
644
- return "Training is running..."
 
 
 
 
 
 
 
 
 
 
 
 
 
645
 
646
  if __name__ == "__main__":
647
  print("πŸš€ Starting training in background thread...")
648
-
649
- thread = threading.Thread(target=run_training)
650
  thread.start()
651
 
652
- print("🌐 Launching keep-alive server...")
653
-
654
- demo = gr.Interface(fn=alive, inputs=[], outputs="text")
 
 
 
 
 
655
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ import sys, os, json, time, random, re, copy
 
2
 
3
+ # ── CRITICAL: Set before ANY import ──────────────────────────────────────────
4
  os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
5
 
6
+ # ── Install dependencies ──────────────────────────────────────────────────────
7
+ import subprocess
8
+ subprocess.run([
9
+ 'pip', 'install', '-q',
10
+ 'unsloth==2025.7.7',
11
+ 'transformers==4.51.3',
12
+ 'accelerate==0.34.2',
13
+ 'peft==0.13.2',
14
+ 'trl==0.14.0',
15
+ 'requests',
16
+ 'matplotlib',
17
+ 'scipy',
18
+ 'gradio==4.44.0',
19
+ 'huggingface_hub',
20
+ ], capture_output=True)
21
+
22
+ # ── Clear stale module cache ──────────────────────────────────────────────────
23
  for mod in list(sys.modules.keys()):
24
  if any(x in mod for x in ['trl','unsloth','transformers','peft']):
25
  del sys.modules[mod]
26
 
27
+ # ── Verify imports ────────────────────────────────────────────────────────────
28
  import unsloth
29
  from unsloth import FastLanguageModel
30
  import transformers, peft, torch
31
 
32
+ print(f"βœ… unsloth {unsloth.__version__}")
33
+ print(f"βœ… transformers {transformers.__version__}")
34
+ print(f"βœ… torch {torch.__version__} | CUDA: {torch.cuda.is_available()}")
35
+ print(f"βœ… UNSLOTH_RETURN_LOGITS = {os.environ['UNSLOTH_RETURN_LOGITS']}")
 
 
 
 
 
 
 
 
 
36
 
37
+ # ── Auth ──────────────────────────────────────────────────────────────────────
38
+ HF_TOKEN = os.environ.get('HF_TOKEN', '')
39
  if HF_TOKEN:
40
+ from huggingface_hub import login
41
  login(token=HF_TOKEN, add_to_git_credential=False)
42
+ print("βœ… Logged in to HuggingFace")
43
  else:
44
+ print("⚠️ HF_TOKEN not set β€” will not push to Hub")
45
 
46
+ # ── Config ────────────────────────────────────────────────────────────────────
47
  CONFIG = {
 
48
  'model_name': 'unsloth/Meta-Llama-3.1-8B-Instruct',
49
+ 'max_seq_length': 2048, # reduced from 3072 β€” safer on L4
50
+ 'load_in_4bit': True, # ALWAYS 4bit β€” L4 has 23.7GB
51
 
 
52
  'env_url': 'https://arijit-07-devops-incident-response.hf.space',
53
  'tasks': ['easy', 'medium', 'hard', 'bonus'],
54
  'episodes_per_task': 40,
55
+ 'max_steps_per_episode': 12,
56
 
57
+ 'learning_rate': 5e-6,
 
58
  'grpo_group_size': 4,
59
  'lora_rank': 32,
60
  'lora_alpha': 64,
61
  'max_grad_norm': 0.5,
62
+ 'kl_coeff': 0.05,
63
 
 
64
  'hf_repo': 'Arijit-07/aria-devops-llama8b',
65
  'output_dir': '/data/outputs',
66
  'save_every_n_episodes': 20,
67
  }
68
 
69
+ print(f"GPU: {torch.cuda.get_device_name(0)}")
70
+ print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
 
 
 
71
 
72
+ # ── Environment Client ────────────────────────────────────────────────────────
73
+ import requests
74
 
75
  BASE_URL = CONFIG['env_url']
76
 
77
  def env_reset(task_id, seed=None):
78
  payload = {'task_id': task_id}
79
+ if seed is not None:
80
+ payload['seed'] = seed
81
  for attempt in range(3):
82
  try:
83
  r = requests.post(f'{BASE_URL}/reset', json=payload, timeout=30)
84
  r.raise_for_status()
85
  return r.json()
86
+ except Exception as e:
87
+ if attempt == 2:
88
+ raise
89
  time.sleep(5)
90
 
91
  def env_step(action):
 
94
  r = requests.post(f'{BASE_URL}/step', json=action, timeout=30)
95
  r.raise_for_status()
96
  return r.json()
97
+ except Exception as e:
98
+ if attempt == 2:
99
+ raise
100
  time.sleep(5)
101
 
102
  def env_state():
 
112
  }
113
 
114
  def sanitize_action(action):
115
+ DEFAULT_SERVICE = "order-service"
 
116
  if not isinstance(action, dict):
117
  return {"action_type": "read_logs", "service": DEFAULT_SERVICE}
 
118
  action_type = action.get("action_type", "").lower()
 
 
 
 
 
119
  if action_type not in VALID_ACTIONS:
120
  action_type = "read_logs"
 
 
121
  service = action.get("service") or action.get("service_name") or DEFAULT_SERVICE
122
+ clean = {"action_type": action_type, "service": service}
 
 
 
 
 
 
123
  for key in ["root_cause", "runbook", "version", "reason",
124
  "query", "ip_range", "table", "column", "target_region"]:
125
  if key in action and isinstance(action[key], str):
126
  clean[key] = action[key]
 
127
  return clean
 
128
 
129
+ # Test connection
130
  health = requests.get(f'{BASE_URL}/health', timeout=15).json()
131
+ print(f"βœ… Environment: {health}")
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+ # ── System Prompt ─────────────────────────────────────────────────────────────
134
+ SYSTEM_PROMPT = """You are an autonomous DevOps agent.
135
  You MUST return ONLY valid JSON.
136
 
137
+ action_type MUST be one of:
138
+ diagnose, read_logs, read_metrics, read_runbook, search_logs,
139
+ restart_service, rollback, scale_up, alert_oncall, acknowledge,
140
+ noop, block_ip_range, create_index, failover
141
+
142
+ Always include "service" field. Use exact parameter names.
143
+ Output valid JSON only. Example:
144
+ {"action_type": "read_logs", "service": "order-service"}"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  def observation_to_prompt(obs, task_id):
147
+ # Compact representation to save tokens
148
+ services = obs.get('services', [])
149
+ alerts = obs.get('active_alerts', [])
150
+ evidence = obs.get('evidence_log', [])
151
+
152
+ svc_lines = []
153
+ for s in sorted(services, key=lambda x: x.get('error_rate', 0), reverse=True)[:6]:
154
+ svc_lines.append(f" {s.get('name','')}: {s.get('status','')} err={s.get('error_rate',0):.3f} mem={s.get('memory',0):.1f}%")
155
+
156
+ alert_lines = []
157
+ for a in alerts[:4]:
158
+ alert_lines.append(f" [{a.get('severity','').upper()}] {a.get('service','')}: {a.get('message','')}")
159
+
160
+ ev_lines = []
161
+ for e in evidence[-3:]:
162
+ ev_lines.append(f" [{e.get('action_type','').upper()}] {e.get('content','')[:100]}")
163
+
164
  return (
165
+ f"Task: {task_id} | Step {obs.get('step',0)}/{obs.get('max_steps',15)}\n"
166
+ f"Services:\n" + "\n".join(svc_lines) + "\n"
167
+ f"Alerts:\n" + "\n".join(alert_lines) + "\n"
168
+ + (f"Evidence:\n" + "\n".join(ev_lines) if ev_lines else "")
169
+ + "\nChoose next action as JSON:"
170
  )
171
 
172
+ # ── Action Parser ─────────────────────────────────────────────────────────────
173
# Candidate locations of the JSON payload, tried in order: a ```json fence,
# any ``` fence, then a bare object whose first key is "action_type".
_ACTION_PATTERNS = tuple(
    re.compile(pattern, re.DOTALL)
    for pattern in (
        r'''```json\s*({.*?})\s*```''',
        r'''```\s*({.*?})\s*```''',
        r'''({\s*"action_type"[^}]+})''',
    )
)


def parse_action(text):
    """Extract a JSON action from raw model output.

    Each pattern is searched in turn and the first capture that decodes as
    JSON is returned. If none does, the whole stripped text is decoded.

    Args:
        text: Generated text, possibly wrapped in prose or code fences.

    Returns:
        The decoded JSON value (normally a dict, but not guaranteed to be
        one), or ``{'action_type': 'noop'}`` when nothing decodes.
    """
    text = text.strip()
    for pattern in _ACTION_PATTERNS:
        match = pattern.search(text)
        if match is None:
            continue
        try:
            return json.loads(match.group(1))
        except ValueError:
            # json.JSONDecodeError is a ValueError; try the next pattern.
            continue
    try:
        return json.loads(text)
    except ValueError:
        return {'action_type': 'noop'}
190
 
191
+ # ── Load Model ────────────────────────────────────────────────────────────────
192
+ os.makedirs(CONFIG['output_dir'], exist_ok=True)
193
 
194
+ # FIX: Delete bad checkpoint if it exists but is incompatible
195
  checkpoint_path = f"{CONFIG['output_dir']}/latest"
196
+ state_path = f"{CONFIG['output_dir']}/state.json"
197
 
198
def is_valid_checkpoint(path):
    """Return True when *path* looks like a checkpoint that can be resumed.

    A directory qualifies when it contains ``config.json`` or
    ``adapter_config.json``. If ``adapter_config.json`` is present it must
    be readable, decode to a JSON object, and must not contain the
    ``alora_invocation_tokens`` key.

    Args:
        path: Checkpoint directory to inspect.

    Returns:
        False when the directory should not be resumed from, else True.
    """
    config_file = os.path.join(path, 'config.json')
    adapter_file = os.path.join(path, 'adapter_config.json')
    if not os.path.exists(config_file) and not os.path.exists(adapter_file):
        return False

    if os.path.exists(adapter_file):
        try:
            with open(adapter_file, encoding='utf-8') as f:
                cfg = json.load(f)
        except (OSError, ValueError):
            # Unreadable, badly encoded or malformed JSON: treat as unusable.
            return False
        if not isinstance(cfg, dict):
            return False
        # NOTE(review): this key is presumably written by a peft release
        # other than the one this script installs; confirm which versions.
        if 'alora_invocation_tokens' in cfg:
            print("⚠️ Checkpoint has incompatible peft config field 'alora_invocation_tokens'")
            return False
    return True
216
+
217
+ resuming = False
218
+ training_log = []
219
+ episode_scores = {t: [] for t in CONFIG['tasks']}
220
+ global_ep = 0
221
+
222
+ if os.path.exists(checkpoint_path):
223
+ if is_valid_checkpoint(checkpoint_path):
224
+ print("πŸ” Valid checkpoint found β€” resuming...")
225
+ resuming = True
226
+ if os.path.exists(state_path):
227
+ with open(state_path) as f:
228
+ state = json.load(f)
229
+ global_ep = state.get('global_ep', 0)
230
+ training_log = state.get('training_log', [])
231
+ episode_scores = state.get('episode_scores', {t: [] for t in CONFIG['tasks']})
232
+ print(f"βœ… Resumed from episode {global_ep}")
233
+ else:
234
+ print("⚠️ Incompatible checkpoint found β€” deleting and starting fresh")
235
+ import shutil
236
+ shutil.rmtree(checkpoint_path, ignore_errors=True)
237
+ if os.path.exists(state_path):
238
+ os.remove(state_path)
239
+ resuming = False
240
+
241
+ print(f"Loading model: {CONFIG['model_name']} ({'resuming' if resuming else 'fresh'})")
242
+
243
+ if resuming:
244
  model, tokenizer = FastLanguageModel.from_pretrained(
245
  model_name=checkpoint_path,
246
  max_seq_length=CONFIG['max_seq_length'],
247
  dtype=None,
248
+ load_in_4bit=CONFIG['load_in_4bit'], # ALWAYS 4bit
249
+ token=HF_TOKEN if HF_TOKEN else None,
250
  )
251
  else:
 
252
  model, tokenizer = FastLanguageModel.from_pretrained(
253
  model_name=CONFIG['model_name'],
254
  max_seq_length=CONFIG['max_seq_length'],
255
  dtype=None,
256
+ load_in_4bit=CONFIG['load_in_4bit'], # ALWAYS 4bit
257
+ token=HF_TOKEN if HF_TOKEN else None,
258
  )
 
 
259
  model = FastLanguageModel.get_peft_model(
260
  model,
261
  r=CONFIG['lora_rank'],
 
270
 
271
  trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
272
  total = sum(p.numel() for p in model.parameters())
273
+ print(f"βœ… Model loaded | Trainable: {trainable:,} ({100*trainable/total:.2f}%)")
274
+ print(f" VRAM: {torch.cuda.memory_allocated()/1e9:.2f} GB used")
275
 
276
+ # Frozen reference model for KL penalty
277
+ ref_model = copy.deepcopy(model)
 
278
  ref_model.eval()
 
279
  for p in ref_model.parameters():
280
  p.requires_grad = False
 
281
  print("βœ… Reference model frozen for KL penalty")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
+ # ── Episode Runner (for baseline) ─────────────────────────────────────────────
284
  def run_episode(task_id, seed=None, verbose=False):
285
  obs = env_reset(task_id, seed=seed)
286
  total_reward = 0.0
287
  done = False
288
+ FastLanguageModel.for_inference(model)
289
 
290
  for step in range(CONFIG['max_steps_per_episode']):
291
  if done:
292
  break
293
+ messages = [
294
+ {'role': 'system', 'content': SYSTEM_PROMPT},
295
+ {'role': 'user', 'content': observation_to_prompt(obs, task_id)}
296
+ ]
297
+ input_ids = tokenizer.apply_chat_template(
298
+ messages, tokenize=True, add_generation_prompt=True,
299
+ return_tensors='pt'
300
+ )
301
+ input_ids = input_ids[:, -CONFIG['max_seq_length']:].to('cuda')
302
+ with torch.no_grad():
303
+ out = model.generate(
304
+ input_ids, max_new_tokens=100, temperature=0.7,
305
+ do_sample=True, pad_token_id=tokenizer.eos_token_id,
306
+ )
307
+ text = tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)
308
+ action = sanitize_action(parse_action(text))
309
  if verbose:
310
+ print(f" Step {step+1}: {action}")
 
 
 
311
  result = env_step(action)
 
312
  total_reward += result.get('reward', 0.0)
313
  obs = result.get('observation', obs)
314
  done = result.get('done', False)
315
+ return total_reward
316
 
317
+ # ── Pre-Training Baseline ─────────────────────────────────────────────────────
318
+ print("\nRunning pre-training baseline (5 episodes per task)...")
 
 
 
 
 
 
 
 
319
  baseline_scores = {}
 
320
  for task_id in CONFIG['tasks']:
321
+ scores = [run_episode(task_id, seed=i*7+3) for i in range(5)]
322
  avg = sum(scores) / len(scores)
323
  baseline_scores[task_id] = {'scores': scores, 'avg': avg}
324
+ print(f" [{task_id}] baseline: {avg:.3f}")
325
+ print("βœ… Baseline done")
 
 
 
 
 
 
 
 
 
 
326
 
327
+ # ── GRPO Training Functions ───────────────────────────────────────────────────
328
  def run_episode_collect(task_id, seed):
329
+ """FIXED: Score completions on fresh env snapshots β€” no reward gate burn."""
330
  obs = env_reset(task_id, seed=seed)
331
  trajectory = []
332
  done = False
 
333
  FastLanguageModel.for_inference(model)
334
 
335
  for step in range(CONFIG['max_steps_per_episode']):
 
340
  {'role': 'system', 'content': SYSTEM_PROMPT},
341
  {'role': 'user', 'content': observation_to_prompt(obs, task_id)}
342
  ]
 
343
  input_ids = tokenizer.apply_chat_template(
344
+ messages, tokenize=True, add_generation_prompt=True,
 
 
345
  return_tensors='pt'
346
  )
 
347
  input_ids = input_ids[:, -CONFIG['max_seq_length']:].to('cuda')
348
 
349
+ # Generate all completions first β€” no env calls yet
350
+ group_completions, group_texts = [], []
351
  for _ in range(CONFIG['grpo_group_size']):
352
  with torch.no_grad():
353
  out = model.generate(
354
+ input_ids, max_new_tokens=100, temperature=0.9,
355
+ do_sample=True, pad_token_id=tokenizer.eos_token_id,
 
 
 
356
  )
 
357
  gen_ids = out[0][input_ids.shape[1]:]
358
+ group_completions.append(gen_ids)
359
+ group_texts.append(tokenizer.decode(gen_ids, skip_special_tokens=True))
360
 
361
+ # Score each on a FRESH env snapshot
362
+ group_rewards = []
363
+ for gen_text in group_texts:
364
  action = sanitize_action(parse_action(gen_text))
 
365
  try:
366
+ env_reset(task_id, seed=seed) # fresh snapshot
367
  res = env_step(action)
368
+ r = res.get('reward', 0.0)
369
  except:
370
+ r = 0.0
371
+ if action.get('action_type', 'noop') != 'noop':
372
+ r += 0.02 # exploration bonus
373
+ group_rewards.append(r)
374
 
375
+ # Advance main episode with best action
376
  best_idx = group_rewards.index(max(group_rewards))
377
+ best_action = sanitize_action(parse_action(group_texts[best_idx]))
 
 
 
378
  try:
379
+ adv_res = env_step(best_action)
380
+ obs = adv_res.get('observation', obs)
381
+ done = adv_res.get('done', False)
382
  except:
383
  done = True
384
 
385
  trajectory.append({
386
  'input_ids': input_ids,
387
  'completions': group_completions,
388
+ 'rewards': group_rewards,
389
  })
390
 
391
  total_reward = sum(max(s['rewards']) for s in trajectory) if trajectory else 0.0
 
393
 
394
 
395
  def update_from_trajectory(trajectory):
396
+ """Single model update from full episode + KL penalty."""
397
  if not trajectory:
398
  return 0.0
399
 
400
  device = next(model.parameters()).device
401
  FastLanguageModel.for_training(model)
 
402
  model.train()
403
  optimizer.zero_grad()
404
 
 
406
 
407
  for step_data in trajectory:
408
  input_ids = step_data['input_ids'].to(device)
 
409
  completions = step_data['completions']
410
+ rewards = step_data['rewards']
411
 
412
+ rewards_t = torch.tensor(rewards, dtype=torch.float32, device=device)
 
 
413
  if rewards_t.std() > 1e-8:
414
  advantages = (rewards_t - rewards_t.mean()) / (rewards_t.std() + 1e-8)
415
  else:
 
417
 
418
  best_idx = rewards.index(max(rewards))
419
  best_ids = completions[best_idx].to(device)
420
+ best_adv = advantages[best_idx]
421
 
422
  full_ids = torch.cat([input_ids[0], best_ids]).unsqueeze(0)
423
  labels = full_ids.clone()
424
  labels[0, :input_ids.shape[1]] = -100
425
 
426
  outputs = model(full_ids, labels=labels)
427
+ policy_loss = outputs.loss * (-best_adv)
428
 
429
+ # KL penalty
430
  with torch.no_grad():
431
+ ref_out = ref_model(full_ids)
432
+ ref_logits = ref_out.logits[:, input_ids.shape[1]-1:-1, :]
433
+ pol_logits = outputs.logits[:, input_ids.shape[1]-1:-1, :]
434
  kl = torch.nn.functional.kl_div(
435
+ torch.log_softmax(pol_logits, dim=-1),
436
  torch.softmax(ref_logits, dim=-1),
437
  reduction='batchmean'
438
  )
439
+ total_loss = total_loss + policy_loss + CONFIG['kl_coeff'] * kl
 
 
 
440
 
441
  total_loss = total_loss / len(trajectory)
442
  total_loss.backward()
 
443
  torch.nn.utils.clip_grad_norm_(
444
  [p for p in model.parameters() if p.requires_grad],
445
  CONFIG['max_grad_norm']
446
  )
 
447
  optimizer.step()
448
  scheduler.step()
 
449
  return total_loss.item()
450
 
451
+ # ── Optimizer ─────────────────────────────────────────────────────────────────
 
 
452
  from torch.optim import AdamW
453
  from transformers import get_cosine_schedule_with_warmup
454
 
455
  optimizer = AdamW(
456
  [p for p in model.parameters() if p.requires_grad],
457
+ lr=CONFIG['learning_rate'], weight_decay=0.01
 
458
  )
 
459
  total_eps = CONFIG['episodes_per_task'] * len(CONFIG['tasks'])
 
460
  scheduler = get_cosine_schedule_with_warmup(
461
  optimizer,
462
  num_warmup_steps=max(1, total_eps // 10),
463
  num_training_steps=total_eps
464
  )
465
 
466
+ # ── Training Loop ─────────────────────────────────────────────────────────────
 
467
  def run_training():
468
  global global_ep
469
  start_time = time.time()
470
 
471
+ print("=" * 65)
472
+ print("ARIA GRPO TRAINING β€” Llama-3.1-8B")
473
+ print(f"LR={CONFIG['learning_rate']} | KL={CONFIG['kl_coeff']} | Groups={CONFIG['grpo_group_size']}")
474
+ print(f"Strategy: fresh env per completion β†’ episode-level update")
475
+ print("=" * 65)
476
 
477
  for task_id in CONFIG['tasks']:
478
+ print(f"\nπŸ“‹ Task: {task_id.upper()} | Baseline: {baseline_scores[task_id]['avg']:.3f}")
479
+ print("-" * 40)
480
 
481
  for ep in range(CONFIG['episodes_per_task']):
482
  seed = random.randint(0, 9999)
 
485
  loss = update_from_trajectory(trajectory)
486
 
487
  episode_scores[task_id].append(final_score)
488
+ global_ep += 1
 
 
489
  elapsed = (time.time() - start_time) / 60
490
  recent = episode_scores[task_id][-10:]
491
+ rolling = sum(recent) / len(recent) if recent else 0.0
 
 
 
 
 
 
 
 
 
 
492
 
493
  training_log.append({
494
+ 'episode': global_ep, 'task_id': task_id,
495
+ 'score': final_score, 'rolling_avg': rolling,
496
+ 'loss': loss, 'elapsed_min': round(elapsed, 1)
 
 
 
497
  })
498
 
499
+ # Save checkpoint every episode (atomic write)
500
  try:
501
  latest_ckpt = f"{CONFIG['output_dir']}/latest"
502
  model.save_pretrained(latest_ckpt)
503
  tokenizer.save_pretrained(latest_ckpt)
 
504
  state = {
505
+ 'global_ep': global_ep,
506
+ 'training_log': training_log,
507
+ 'episode_scores': episode_scores
508
  }
 
509
  tmp = f"{CONFIG['output_dir']}/state_tmp.json"
510
+ final_path = f"{CONFIG['output_dir']}/state.json"
511
+ with open(tmp, 'w') as f:
 
512
  json.dump(state, f)
513
+ os.replace(tmp, final_path)
 
 
514
  except Exception as e:
515
+ print(f"⚠️ Checkpoint save failed: {e}")
516
 
517
  if (ep + 1) % 5 == 0:
518
  delta = rolling - baseline_scores[task_id]['avg']
519
  trend = 'πŸ“ˆ' if delta > 0.02 else 'πŸ“‰' if delta < -0.02 else '➑️'
520
  print(
521
+ f" {trend} Ep {ep+1:3d}/{CONFIG['episodes_per_task']} | "
522
+ f"Score: {final_score:.3f} | Roll-10: {rolling:.3f} | "
523
+ f"vs baseline: {delta:+.3f} | Loss: {loss:.4f} | {elapsed:.0f}m"
524
  )
525
 
526
+ task_avg = sum(episode_scores[task_id]) / len(episode_scores[task_id])
527
+ base_avg = baseline_scores[task_id]['avg']
528
+ delta = task_avg - base_avg
529
+ result = 'βœ… IMPROVED' if delta > 0.02 else '⚠️ FLAT' if delta > -0.02 else '❌ DEGRADED'
530
+ print(f"\n{result} {task_id}: {base_avg:.3f} β†’ {task_avg:.3f} ({delta:+.3f})")
531
 
532
+ # Save training log after each task
533
+ with open(f"{CONFIG['output_dir']}/training_log.json", 'w') as f:
534
+ json.dump(training_log, f, indent=2)
 
535
 
536
+ print(f"\nπŸŽ‰ Training complete! {(time.time()-start_time)/60:.0f} minutes")
 
537
 
538
+ # Post-training eval
539
+ FastLanguageModel.for_inference(model)
540
+ print("\nPost-training evaluation...")
541
  for task_id in CONFIG['tasks']:
542
+ scores = [run_episode(task_id, seed=i*13+7) for i in range(5)]
543
  avg = sum(scores) / len(scores)
544
+ print(f" [{task_id}] {baseline_scores[task_id]['avg']:.3f} β†’ {avg:.3f} ({avg-baseline_scores[task_id]['avg']:+.3f})")
545
+
546
+ # Push to Hub
547
+ if HF_TOKEN:
548
+ print(f"\nPushing to {CONFIG['hf_repo']}...")
549
+ model.push_to_hub_merged(
550
+ CONFIG['hf_repo'], tokenizer,
551
+ save_method='merged_16bit', token=HF_TOKEN,
552
+ )
553
+ from huggingface_hub import HfApi
554
+ api = HfApi()
555
+ for fname in ['training_log.json']:
556
+ fpath = f"{CONFIG['output_dir']}/{fname}"
557
+ if os.path.exists(fpath):
558
+ api.upload_file(
559
+ path_or_fileobj=fpath,
560
+ path_in_repo=fname,
561
+ repo_id=CONFIG['hf_repo'],
562
+ token=HF_TOKEN,
563
+ )
564
+ print(f"βœ… Model live: https://huggingface.co/{CONFIG['hf_repo']}")
565
 
 
 
 
566
 
567
+ # ── Entry Point ───────────────────────────────────────────────────────────────
568
  import threading
569
  import gradio as gr
570
 
571
  def alive():
572
+ if os.path.exists(f"{CONFIG['output_dir']}/state.json"):
573
+ with open(f"{CONFIG['output_dir']}/state.json") as f:
574
+ state = json.load(f)
575
+ ep = state.get('global_ep', 0)
576
+ log = state.get('training_log', [])
577
+ last = log[-1] if log else {}
578
+ return (
579
+ f"Training running... Episode {ep}/{total_eps}\n"
580
+ f"Last: task={last.get('task_id','-')} "
581
+ f"score={last.get('score',0):.3f} "
582
+ f"roll10={last.get('rolling_avg',0):.3f} "
583
+ f"elapsed={last.get('elapsed_min',0):.0f}m"
584
+ )
585
+ return "Starting up..."
586
 
587
  if __name__ == "__main__":
588
  print("πŸš€ Starting training in background thread...")
589
+ thread = threading.Thread(target=run_training, daemon=True)
 
590
  thread.start()
591
 
592
+ print("🌐 Launching keep-alive server on port 7860...")
593
+ demo = gr.Interface(
594
+ fn=alive,
595
+ inputs=[],
596
+ outputs="text",
597
+ title="ARIA Training Status",
598
+ description="Live training progress for ARIA GRPO fine-tuning"
599
+ )
600
  demo.launch(server_name="0.0.0.0", server_port=7860)