Arijit-07 committed on
Commit
6f67154
·
verified ·
1 Parent(s): 057fab7

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +592 -592
train_model.py CHANGED
@@ -1,113 +1,113 @@
1
- # ── Cell 1: Install dependencies ──────────────────────────────────────────────
2
- import sys, os, json
3
-
4
- # Set logits flag BEFORE any import
5
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
6
-
7
-
8
- # Clear stale module cache
9
- for mod in list(sys.modules.keys()):
10
- if any(x in mod for x in ['trl','unsloth','transformers','peft']):
11
- del sys.modules[mod]
12
-
13
- # Verify — unsloth must be imported first
14
- import unsloth
15
- from unsloth import FastLanguageModel
16
- import transformers, peft, torch
17
-
18
- print(f'βœ… unsloth {unsloth.__version__}')
19
- print(f'βœ… transformers {transformers.__version__}')
20
- print(f'βœ… torch {torch.__version__} | CUDA: {torch.cuda.is_available()}')
21
- print(f'βœ… UNSLOTH_RETURN_LOGITS = {os.environ["UNSLOTH_RETURN_LOGITS"]}')
22
- print('βœ… All dependencies installed')
23
-
24
- # ── Cell 2: Config — SET YOUR HF TOKEN HERE ───────────────────────────────────
25
- import os
26
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
27
-
28
- HF_TOKEN = os.environ.get('HF_TOKEN') # set as env var or paste here
29
  if HF_TOKEN:
30
  os.environ['HF_TOKEN'] = HF_TOKEN
31
-
32
- from huggingface_hub import login
33
  if HF_TOKEN:
34
  login(token=HF_TOKEN, add_to_git_credential=False)
35
  print('Logged in to HuggingFace')
36
  else:
37
  print('HF_TOKEN not set; continuing without HuggingFace login')
38
-
39
- CONFIG = {
40
- # Model — 8B for A100
41
- 'model_name': 'unsloth/Meta-Llama-3.1-8B-Instruct',
42
- 'max_seq_length': 3072,
43
- 'load_in_4bit': True,
44
-
45
- # Environment
46
- 'env_url': 'https://arijit-07-devops-incident-response.hf.space',
47
- 'tasks': ['easy', 'medium', 'hard', 'bonus'],
48
- 'episodes_per_task': 40,
49
- 'max_steps_per_episode': 12, # reduced from 20 — tighter episodes
50
-
51
- # Training — conservative to prevent catastrophic forgetting
52
- 'learning_rate': 5e-6, # FIXED: was 1e-5, caused degradation
53
- 'grpo_group_size': 4,
54
- 'lora_rank': 32,
55
- 'lora_alpha': 64,
56
- 'max_grad_norm': 0.5,
57
- 'kl_coeff': 0.05, # NEW: prevents catastrophic forgetting
58
-
59
- # Output
60
- 'hf_repo': 'Arijit-07/aria-devops-llama8b',
61
- 'output_dir': './outputs',
62
- 'save_every_n_episodes': 20,
63
- }
64
-
65
- import torch
66
- print(f'GPU: {torch.cuda.get_device_name(0)}')
67
- print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
68
- print(f'Model: {CONFIG["model_name"]} | Tasks: {CONFIG["tasks"]}')
69
- print(f'LR: {CONFIG["learning_rate"]} | KL: {CONFIG["kl_coeff"]}')
70
-
71
- # ── Cell 3: Environment Client ────────────────────────────────────────────────
72
- import requests, json, time, random
73
-
74
- BASE_URL = CONFIG['env_url']
75
-
76
- def env_reset(task_id, seed=None):
77
- payload = {'task_id': task_id}
78
- if seed is not None: payload['seed'] = seed
79
- for attempt in range(3):
80
- try:
81
- r = requests.post(f'{BASE_URL}/reset', json=payload, timeout=30)
82
- r.raise_for_status()
83
- return r.json()
84
- except:
85
- if attempt == 2: raise
86
- time.sleep(5)
87
-
88
- def env_step(action):
89
- for attempt in range(3):
90
- try:
91
- r = requests.post(f'{BASE_URL}/step', json=action, timeout=30)
92
- r.raise_for_status()
93
- return r.json()
94
- except:
95
- if attempt == 2: raise
96
- time.sleep(5)
97
-
98
- def env_state():
99
- r = requests.get(f'{BASE_URL}/state', timeout=30)
100
- r.raise_for_status()
101
- return r.json()
102
-
103
- health = requests.get(f'{BASE_URL}/health', timeout=15).json()
104
- print(f'βœ… Environment: {health}')
105
- test_obs = env_reset('easy', seed=0)
106
- print(f'βœ… Reset OK. Services: {len(test_obs.get("services", []))}')
107
-
108
- # ── Cell 4: System Prompt + Observation Formatter ─────────────────────────────
109
-
110
- print('Config loaded:')
111
  for k, v in CONFIG.items():
112
  print(f' {k}: {v}')
113
 
@@ -122,13 +122,13 @@ def observation_to_prompt(obs, task_id):
122
  f"{json.dumps(obs, indent=2, sort_keys=True)}\n"
123
  "Choose the next valid action as JSON."
124
  )
125
-
126
- # ── Cell 5: Load Llama-3.1-8B with Unsloth ───────────────────────────────────
127
- from unsloth import FastLanguageModel
128
- import torch
129
-
130
- print(f'Loading {CONFIG["model_name"]}...')
131
-
132
  checkpoint_path = f"{CONFIG['output_dir']}/latest"
133
  resuming_from_checkpoint = os.path.exists(checkpoint_path)
134
 
@@ -136,17 +136,17 @@ if resuming_from_checkpoint:
136
  print("πŸ” Resuming from checkpoint...")
137
  model, tokenizer = FastLanguageModel.from_pretrained(
138
  model_name=checkpoint_path,
139
- max_seq_length=CONFIG['max_seq_length'],
140
- dtype=None,
141
- load_in_4bit=CONFIG['load_in_4bit'],
142
- )
143
- else:
144
- print("πŸ†• Starting fresh model...")
145
- model, tokenizer = FastLanguageModel.from_pretrained(
146
- model_name=CONFIG['model_name'],
147
- max_seq_length=CONFIG['max_seq_length'],
148
- dtype=None,
149
- load_in_4bit=CONFIG['load_in_4bit'],
150
  token=HF_TOKEN,
151
  )
152
 
@@ -162,216 +162,216 @@ if not resuming_from_checkpoint:
162
  use_gradient_checkpointing='unsloth',
163
  random_state=42,
164
  )
165
-
166
- trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
167
- total = sum(p.numel() for p in model.parameters())
168
- print(f'βœ… Model loaded')
169
- # 🔁 Load training state if it exists
170
- state_path = f"{CONFIG['output_dir']}/state.json"
171
-
172
- if os.path.exists(state_path):
173
- print("πŸ” Restoring training state...")
174
- with open(state_path, "r") as f:
175
- state = json.load(f)
176
-
177
- global_ep = state.get("global_ep", 0)
178
- training_log = state.get("training_log", [])
179
- episode_scores = state.get("episode_scores", {t: [] for t in CONFIG['tasks']})
180
-
181
- print(f"βœ… Resumed from episode {global_ep}")
182
- print(f"πŸš€ Continuing training from episode {global_ep}")
183
- else:
184
- print("πŸ†• Starting fresh training state")
185
- training_log = []
186
- episode_scores = {t: [] for t in CONFIG['tasks']}
187
- global_ep = 0
188
- print(f' Trainable: {trainable:,} ({100*trainable/total:.2f}%)')
189
- print(f' VRAM: {torch.cuda.memory_allocated()/1e9:.2f} GB used')
190
-
191
- # ── Cell 6: Action Parser + Episode Runner ────────────────────────────────────
192
- import re
193
-
194
- def parse_action(text):
195
- text = text.strip()
196
- for pattern in [
197
- r'```json\s*({.*?})\s*```',
198
- r'```\s*({.*?})\s*```',
199
- r'({\s*"action_type"[^}]+})',
200
- ]:
201
- match = re.search(pattern, text, re.DOTALL)
202
- if match:
203
- try: return json.loads(match.group(1))
204
- except: continue
205
- try: return json.loads(text)
206
- except: return {'action_type': 'noop'}
207
-
208
- def generate_action(obs, task_id, temperature=0.7):
209
- messages = [
210
- {'role': 'system', 'content': SYSTEM_PROMPT},
211
- {'role': 'user', 'content': observation_to_prompt(obs, task_id)}
212
- ]
213
- input_ids = tokenizer.apply_chat_template(
214
- messages, tokenize=True, add_generation_prompt=True,
215
- return_tensors='pt'
216
- ).to('cuda')
217
- FastLanguageModel.for_inference(model)
218
- with torch.no_grad():
219
- out = model.generate(
220
- input_ids, max_new_tokens=150,
221
- temperature=temperature, do_sample=True,
222
- pad_token_id=tokenizer.eos_token_id,
223
- )
224
- generated = tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)
225
- return parse_action(generated), generated
226
-
227
- def run_episode(task_id, seed=None, verbose=False):
228
- obs = env_reset(task_id, seed=seed)
229
- total_reward = 0.0
230
- done = False
231
- for step in range(CONFIG['max_steps_per_episode']):
232
- if done: break
233
- action, _ = generate_action(obs, task_id)
234
- if verbose: print(f' Step {step+1}: {action}')
235
- result = env_step(action)
236
- total_reward += result.get('reward', 0.0)
237
- obs = result.get('observation', obs)
238
- done = result.get('done', False)
239
- state = env_state()
240
- return state.get('current_score', total_reward)
241
-
242
- print('βœ… Episode runner ready')
243
- print('Testing one episode...')
244
- test_score = run_episode('easy', seed=99, verbose=True)
245
- print(f'Test score: {test_score:.3f}')
246
-
247
- # ── Cell 7: Pre-Training Baseline ────────────────────────────────────────────
248
- print('Running pre-training baseline (8 episodes per task)...')
249
- baseline_scores = {}
250
-
251
- for task_id in CONFIG['tasks']:
252
- scores = [run_episode(task_id, seed=i*7+3) for i in range(8)]
253
- avg = sum(scores) / len(scores)
254
- baseline_scores[task_id] = {'scores': scores, 'avg': avg}
255
- print(f' [{task_id}] baseline: {avg:.3f} (min={min(scores):.3f} max={max(scores):.3f})')
256
-
257
- print('\nβœ… Baseline done. Starting training...')
258
-
259
-
260
-
261
- import torch
262
- assert torch.cuda.is_available(), "GPU NOT DETECTED!"
263
- print("Using GPU:", torch.cuda.get_device_name(0))
264
-
265
-
266
- # ── Cell 8: GRPO Training Loop (FIXED — Episode-level updates + KL) ──────────
267
- from torch.optim import AdamW
268
- from transformers import get_cosine_schedule_with_warmup
269
- import os, time, random, copy, json
270
-
271
- os.makedirs(CONFIG['output_dir'], exist_ok=True)
272
-
273
  # Frozen reference model for KL penalty
274
  model_device = next(model.parameters()).device
275
  ref_model = copy.deepcopy(model).to(model_device)
276
  ref_model.eval()
277
- for p in ref_model.parameters():
278
- p.requires_grad = False
279
- ref_model.eval()
280
- print('βœ… Reference model frozen for KL penalty')
281
-
282
- optimizer = AdamW(
283
- [p for p in model.parameters() if p.requires_grad],
284
- lr=CONFIG['learning_rate'], weight_decay=0.01
285
- )
286
- total_eps = CONFIG['episodes_per_task'] * len(CONFIG['tasks'])
287
- scheduler = get_cosine_schedule_with_warmup(
288
- optimizer,
289
- num_warmup_steps=max(1, total_eps // 10),
290
- num_training_steps=total_eps
291
- )
292
-
293
-
294
- start_time = time.time()
295
-
296
- print('=' * 65)
297
- print('ARIA GRPO TRAINING β€” Llama-3.1-8B')
298
- print(f'LR={CONFIG["learning_rate"]} | KL={CONFIG["kl_coeff"]} | Groups={CONFIG["grpo_group_size"]}')
299
- print(f'Strategy: collect full episode β†’ score on fresh env β†’ update once')
300
- print('=' * 65)
301
-
302
- def run_episode_collect(task_id, seed):
303
- """
304
- FIXED: group completions scored on FRESH env snapshots.
305
- Only best action advances main episode.
306
- """
307
- obs = env_reset(task_id, seed=seed)
308
- trajectory = []
309
- done = False
310
-
311
- FastLanguageModel.for_inference(model)
312
-
313
- for step in range(CONFIG['max_steps_per_episode']):
314
- if done:
315
- break
316
-
317
- messages = [
318
- {'role': 'system', 'content': SYSTEM_PROMPT},
319
- {'role': 'user', 'content': observation_to_prompt(obs, task_id)}
320
- ]
321
- input_ids = tokenizer.apply_chat_template(
322
- messages, tokenize=True, add_generation_prompt=True,
323
- return_tensors='pt'
324
- ).to('cuda')
325
-
326
- # Generate all completions first — no env calls yet
327
- group_completions, group_texts = [], []
328
- for _ in range(CONFIG['grpo_group_size']):
329
- with torch.no_grad():
330
- out = model.generate(
331
- input_ids, max_new_tokens=128, temperature=0.9,
332
- do_sample=True, pad_token_id=tokenizer.eos_token_id,
333
- )
334
- gen_ids = out[0][input_ids.shape[1]:]
335
- group_completions.append(gen_ids)
336
- group_texts.append(tokenizer.decode(gen_ids, skip_special_tokens=True))
337
-
338
- # Score each completion on a FRESH env snapshot
339
- group_rewards = []
340
- for gen_text in group_texts:
341
- action = parse_action(gen_text)
342
- try:
343
- env_reset(task_id, seed=seed) # fresh snapshot
344
- res = env_step(action)
345
- r = res.get('reward', 0.0)
346
- except:
347
- r = 0.0
348
- if action.get('action_type', 'noop') != 'noop':
349
- r += 0.02 # exploration bonus
350
- group_rewards.append(r)
351
-
352
- # Advance main episode with best action
353
- best_idx = group_rewards.index(max(group_rewards))
354
- best_action = parse_action(group_texts[best_idx])
355
- try:
356
- adv_res = env_step(best_action)
357
- obs = adv_res.get('observation', obs)
358
- done = adv_res.get('done', False)
359
- except:
360
- done = True
361
-
362
- trajectory.append({
363
- 'input_ids': input_ids,
364
- 'completions': group_completions,
365
- 'rewards': group_rewards,
366
- })
367
-
368
- # Get final score from accumulated rewards
369
- total_reward = sum(max(s['rewards']) for s in trajectory) if trajectory else 0.0
370
- return trajectory, total_reward
371
-
372
-
373
- def update_from_trajectory(trajectory):
374
- """Single model update from full episode with KL penalty."""
375
  if not trajectory:
376
  return 0.0
377
 
@@ -388,274 +388,274 @@ def update_from_trajectory(trajectory):
388
  rewards = step_data['rewards']
389
 
390
  rewards_t = torch.tensor(rewards, dtype=torch.float32, device=device)
391
- if rewards_t.std() > 1e-8:
392
- advantages = (rewards_t - rewards_t.mean()) / (rewards_t.std() + 1e-8)
393
- else:
394
- advantages = rewards_t - rewards_t.mean()
395
-
396
- best_idx = rewards.index(max(rewards))
397
  best_ids = completions[best_idx].to(device)
398
- best_adv = advantages[best_idx]
399
-
400
- full_ids = torch.cat([input_ids[0], best_ids]).unsqueeze(0)
401
- labels = full_ids.clone()
402
- labels[0, :input_ids.shape[1]] = -100
403
-
404
- outputs = model(full_ids, labels=labels)
405
- policy_loss = outputs.loss * (-best_adv)
406
-
407
- # KL penalty vs reference model
408
- with torch.no_grad():
409
- ref_out = ref_model(full_ids)
410
- ref_logits = ref_out.logits[:, input_ids.shape[1]-1:-1, :]
411
- pol_logits = outputs.logits[:, input_ids.shape[1]-1:-1, :]
412
- kl = torch.nn.functional.kl_div(
413
- torch.log_softmax(pol_logits, dim=-1),
414
- torch.softmax(ref_logits, dim=-1),
415
- reduction='batchmean'
416
- )
417
- total_loss = total_loss + policy_loss + CONFIG['kl_coeff'] * kl
418
-
419
- total_loss = total_loss / len(trajectory)
420
- total_loss.backward()
421
- torch.nn.utils.clip_grad_norm_(
422
- [p for p in model.parameters() if p.requires_grad],
423
- CONFIG['max_grad_norm']
424
- )
425
- optimizer.step()
426
- scheduler.step()
427
- return total_loss.item()
428
-
429
-
430
-
431
- # ── Main training loop ─────────────────────────────────────────────────────
432
- best_score = -1e9
433
- no_improve_count = 0
434
- PATIENCE = 15
435
- for task_id in CONFIG['tasks']:
436
- print(f'\nπŸ“‹ Task: {task_id.upper()} | Baseline: {baseline_scores[task_id]["avg"]:.3f}')
437
- print('-' * 40)
438
-
439
- for ep in range(CONFIG['episodes_per_task']):
440
- seed = random.randint(0, 9999)
441
-
442
- trajectory, final_score = run_episode_collect(task_id, seed)
443
- loss = update_from_trajectory(trajectory)
444
-
445
-
446
- episode_scores[task_id].append(final_score)
447
- global_ep += 1
448
- elapsed = (time.time() - start_time) / 60
449
  recent = episode_scores[task_id][-10:]
450
  rolling = sum(recent) / len(recent) if recent else final_score
451
- # 🧠 Early stopping logic
452
- if rolling > best_score:
453
- best_score = rolling
454
- no_improve_count = 0
455
- else:
456
- no_improve_count += 1
457
-
458
- if no_improve_count >= PATIENCE:
459
- print("πŸ›‘ Early stopping triggered β€” no improvement")
460
- break
461
-
462
- training_log.append({
463
- 'episode': global_ep, 'task_id': task_id,
464
- 'score': final_score, 'rolling_avg': rolling,
465
- 'loss': loss, 'elapsed_min': round(elapsed, 1)
466
- })
467
- # 🔥 FAIL-SAFE CHECKPOINT (every episode)
468
- try:
469
- latest_ckpt = f"{CONFIG['output_dir']}/latest"
470
-
471
- # ✅ Save model + tokenizer FIRST (atomic checkpoint)
472
- model.save_pretrained(latest_ckpt)
473
- tokenizer.save_pretrained(latest_ckpt)
474
-
475
- # 💾 Then save training state
476
- state = {
477
- "global_ep": global_ep,
478
- "training_log": training_log,
479
- "episode_scores": episode_scores
480
- }
481
-
482
- tmp_path = f"{CONFIG['output_dir']}/state_tmp.json"
483
- final_path = f"{CONFIG['output_dir']}/state.json"
484
-
485
- with open(tmp_path, "w") as f:
486
- json.dump(state, f)
487
-
488
- os.replace(tmp_path, final_path) # atomic replace
489
-
490
- except Exception as e:
491
- print("⚠️ Checkpoint save failed:", e)
492
-
493
- if (ep + 1) % 5 == 0:
494
- delta = rolling - baseline_scores[task_id]['avg']
495
- trend = 'πŸ“ˆ' if delta > 0.02 else 'πŸ“‰' if delta < -0.02 else '➑️'
496
- print(
497
- f' {trend} Ep {ep+1:3d}/{CONFIG["episodes_per_task"]} | '
498
- f'Score: {final_score:.3f} | Roll-10: {rolling:.3f} | '
499
- f'vs baseline: {delta:+.3f} | Loss: {loss:.4f} | {elapsed:.0f}m'
500
- )
501
-
502
- if global_ep % CONFIG['save_every_n_episodes'] == 0:
503
- ckpt = f'{CONFIG["output_dir"]}/ep{global_ep}'
504
- model.save_pretrained(ckpt)
505
- tokenizer.save_pretrained(ckpt)
506
- print(f' πŸ’Ύ Checkpoint ep{global_ep}')
507
-
508
- task_avg = sum(episode_scores[task_id]) / len(episode_scores[task_id])
509
- base_avg = baseline_scores[task_id]['avg']
510
- delta = task_avg - base_avg
511
- result = 'βœ… IMPROVED' if delta > 0.02 else '⚠️ FLAT' if delta > -0.02 else '❌ DEGRADED'
512
- print(f'\n{result} {task_id}: {base_avg:.3f} β†’ {task_avg:.3f} ({delta:+.3f})')
513
-
514
- # Save training log so far (in case of crash)
515
- with open(f'{CONFIG["output_dir"]}/training_log.json', 'w') as f:
516
- json.dump(training_log, f, indent=2)
517
- print(' πŸ“ Training log saved')
518
-
519
- print(f'\nπŸŽ‰ Training complete! {(time.time()-start_time)/60:.0f} minutes')
520
-
521
- # ── Cell 9: Post-Training Eval + Generalization ───────────────────────────────
522
- FastLanguageModel.for_inference(model)
523
- print('Post-training evaluation (8 episodes per task, unseen seeds)...')
524
-
525
- post_scores = {}
526
- for task_id in CONFIG['tasks']:
527
- scores = [run_episode(task_id, seed=i*13+7) for i in range(8)]
528
- avg = sum(scores) / len(scores)
529
- post_scores[task_id] = {'scores': scores, 'avg': avg}
530
- delta = avg - baseline_scores[task_id]['avg']
531
- print(f' [{task_id}] {baseline_scores[task_id]["avg"]:.3f} β†’ {avg:.3f} '
532
- f'({("+" if delta>=0 else "")}{delta:.3f})')
533
-
534
- print('\nZero-shot generalization (ARIA tasks β€” never seen in training):')
535
- gen_scores = {}
536
- for task_id in ['security', 'database', 'failover']:
537
- scores = []
538
- for i in range(5):
539
- try: scores.append(run_episode(task_id, seed=i*17+5))
540
- except: scores.append(0.0)
541
- avg = sum(scores) / len(scores)
542
- gen_scores[task_id] = avg
543
- print(f' [{task_id}] zero-shot: {avg:.3f}')
544
-
545
- # ── Cell 10: Learning Curve Visualization ────────────────────────────────────
546
- import matplotlib.pyplot as plt
547
- import matplotlib.gridspec as gridspec
548
- import numpy as np
549
-
550
- fig = plt.figure(figsize=(20, 12))
551
- fig.patch.set_facecolor('#0d1117')
552
- gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.4, wspace=0.35)
553
- COLORS = {'easy':'#4caf50','medium':'#ff9800','hard':'#f44336','bonus':'#9c27b0'}
554
-
555
- def style_ax(ax, title):
556
- ax.set_facecolor('#161b22')
557
- ax.set_title(title, color='white', fontsize=12, fontweight='bold', pad=10)
558
- ax.tick_params(colors='#8b949e', labelsize=9)
559
- for spine in ax.spines.values(): spine.set_color('#30363d')
560
- ax.spines['top'].set_visible(False)
561
- ax.spines['right'].set_visible(False)
562
- ax.grid(True, alpha=0.1, color='#30363d')
563
-
564
- for idx, task_id in enumerate(CONFIG['tasks']):
565
- row, col = divmod(idx, 3)
566
- ax = fig.add_subplot(gs[row, col])
567
- style_ax(ax, f'Task: {task_id.upper()}')
568
- task_log = [e for e in training_log if e['task_id'] == task_id]
569
- eps = [e['episode'] for e in task_log]
570
- scores = [e['score'] for e in task_log]
571
- rolling = [e['rolling_avg'] for e in task_log]
572
- color = COLORS.get(task_id, '#58a6ff')
573
- ax.plot(eps, scores, alpha=0.15, color=color, linewidth=1)
574
- ax.plot(eps, rolling, color=color, linewidth=2.5, label='Rolling avg (10)')
575
- ax.axhline(y=baseline_scores[task_id]['avg'], color='#f85149',
576
- linestyle='--', linewidth=1.5, label='Baseline')
577
- ax.axhline(y=post_scores[task_id]['avg'], color='#3fb950',
578
- linestyle='--', linewidth=1.5, label='Post-training')
579
- ax.set_ylim(0, 1.05)
580
- ax.set_xlabel('Episode', color='#8b949e', fontsize=9)
581
- ax.set_ylabel('Score', color='#8b949e', fontsize=9)
582
- ax.legend(facecolor='#161b22', labelcolor='white', fontsize=8)
583
-
584
- ax5 = fig.add_subplot(gs[1, 1])
585
- style_ax(ax5, 'Before vs After (all tasks)')
586
- x = np.arange(len(CONFIG['tasks']))
587
- w = 0.35
588
- before_v = [baseline_scores[t]['avg'] for t in CONFIG['tasks']]
589
- after_v = [post_scores[t]['avg'] for t in CONFIG['tasks']]
590
- b1 = ax5.bar(x-w/2, before_v, w, label='Before', color='#f85149', alpha=0.85)
591
- b2 = ax5.bar(x+w/2, after_v, w, label='After', color='#3fb950', alpha=0.85)
592
- for bar, v in zip(b1, before_v):
593
- ax5.text(bar.get_x()+bar.get_width()/2., v+0.01, f'{v:.2f}',
594
- ha='center', color='white', fontsize=8)
595
- for bar, v in zip(b2, after_v):
596
- ax5.text(bar.get_x()+bar.get_width()/2., v+0.01, f'{v:.2f}',
597
- ha='center', color='white', fontsize=8)
598
- ax5.set_xticks(x)
599
- ax5.set_xticklabels(CONFIG['tasks'], color='#8b949e')
600
- ax5.set_ylim(0, 1.15)
601
- ax5.legend(facecolor='#161b22', labelcolor='white', fontsize=9)
602
-
603
- ax6 = fig.add_subplot(gs[1, 2])
604
- ax6.set_facecolor('#161b22')
605
- ax6.set_title('Summary', color='white', fontsize=12, fontweight='bold')
606
- ax6.axis('off')
607
- lines = [
608
- ('Model', 'Llama-3.1-8B (Unsloth 4-bit)'),
609
- ('Algorithm', 'GRPO'),
610
- ('LoRA rank', str(CONFIG['lora_rank'])),
611
- ('Total episodes', str(global_ep)),
612
- ('', ''),
613
- ]
614
- for t in CONFIG['tasks']:
615
- b = baseline_scores[t]['avg']; a = post_scores[t]['avg']
616
- lines.append((f' {t}', f'{b:.2f} β†’ {a:.2f} (+{a-b:.2f})'))
617
- if gen_scores:
618
- lines += [('', ''), ('Zero-shot', '')]
619
- for t, s in gen_scores.items():
620
- lines.append((f' {t}', f'{s:.2f}'))
621
- y = 0.95
622
- for label, val in lines:
623
- if not label: y -= 0.04; continue
624
- ax6.text(0.02, y, label+':', color='#8b949e', fontsize=9,
625
- transform=ax6.transAxes, fontweight='bold')
626
- ax6.text(0.52, y, val, color='#c9d1d9', fontsize=9, transform=ax6.transAxes)
627
- y -= 0.08
628
-
629
- fig.suptitle('ARIA β€” DevOps Incident Response\nGRPO Training (Llama-3.1-8B Full Curriculum)',
630
- color='white', fontsize=16, fontweight='bold', y=0.98)
631
- plt.savefig('training_curve_8b.png', dpi=150, bbox_inches='tight', facecolor='#0d1117')
632
- print('βœ… Saved training_curve_8b.png')
633
- plt.show()
634
-
635
- # ── Cell 11: Save to HuggingFace Hub ─────────────────────────────────────────
636
- from huggingface_hub import HfApi
637
- import json
638
-
639
- print(f'Pushing to {CONFIG["hf_repo"]}...')
640
- FastLanguageModel.for_inference(model)
641
-
642
- model.save_pretrained_merged(CONFIG['output_dir'], tokenizer, save_method='merged_16bit')
643
- model.push_to_hub_merged(CONFIG['hf_repo'], tokenizer,
644
- save_method='merged_16bit', token=HF_TOKEN)
645
- print(f'βœ… Model: https://huggingface.co/{CONFIG["hf_repo"]}')
646
-
647
- api = HfApi()
648
- for fname in ['training_curve_8b.png']:
649
- api.upload_file(path_or_fileobj=fname, path_in_repo=fname,
650
- repo_id=CONFIG['hf_repo'], token=HF_TOKEN)
651
- print(f'βœ… {fname} uploaded')
652
-
653
- with open('training_log_8b.json', 'w') as f:
654
- json.dump(training_log, f, indent=2)
655
- api.upload_file(path_or_fileobj='training_log_8b.json',
656
- path_in_repo='training_log_8b.json',
657
- repo_id=CONFIG['hf_repo'], token=HF_TOKEN)
658
-
659
  print('\nπŸŽ‰ DONE! Shut down the RunPod instance now to stop billing.')
660
  print(f' Model: https://huggingface.co/{CONFIG["hf_repo"]}')
661
  print(f' Curve: check training_curve_8b.png in the repo')
 
1
+ # ── Cell 1: Install dependencies ──────────────────────────────────────────────
2
+ import sys, os, json
3
+
4
+ # Set logits flag BEFORE any import
5
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
6
+
7
+
8
+ # Clear stale module cache
9
+ for mod in list(sys.modules.keys()):
10
+ if any(x in mod for x in ['trl','unsloth','transformers','peft']):
11
+ del sys.modules[mod]
12
+
13
+ # Verify — unsloth must be imported first
14
+ import unsloth
15
+ from unsloth import FastLanguageModel
16
+ import transformers, peft, torch
17
+
18
+ print(f'βœ… unsloth {unsloth.__version__}')
19
+ print(f'βœ… transformers {transformers.__version__}')
20
+ print(f'βœ… torch {torch.__version__} | CUDA: {torch.cuda.is_available()}')
21
+ print(f'βœ… UNSLOTH_RETURN_LOGITS = {os.environ["UNSLOTH_RETURN_LOGITS"]}')
22
+ print('βœ… All dependencies installed')
23
+
24
+ # ── Cell 2: Config — SET YOUR HF TOKEN HERE ───────────────────────────────────
25
+ import os
26
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
27
+
28
+ HF_TOKEN = os.environ.get('HF_TOKEN') # set as env var or paste here
29
  if HF_TOKEN:
30
  os.environ['HF_TOKEN'] = HF_TOKEN
31
+
32
+ from huggingface_hub import login
33
  if HF_TOKEN:
34
  login(token=HF_TOKEN, add_to_git_credential=False)
35
  print('Logged in to HuggingFace')
36
  else:
37
  print('HF_TOKEN not set; continuing without HuggingFace login')
38
+
39
+ CONFIG = {
40
+ # Model — 8B for A100
41
+ 'model_name': 'unsloth/Meta-Llama-3.1-8B-Instruct',
42
+ 'max_seq_length': 3072,
43
+ 'load_in_4bit': True,
44
+
45
+ # Environment
46
+ 'env_url': 'https://arijit-07-devops-incident-response.hf.space',
47
+ 'tasks': ['easy', 'medium', 'hard', 'bonus'],
48
+ 'episodes_per_task': 40,
49
+ 'max_steps_per_episode': 12, # reduced from 20 — tighter episodes
50
+
51
+ # Training — conservative to prevent catastrophic forgetting
52
+ 'learning_rate': 5e-6, # FIXED: was 1e-5, caused degradation
53
+ 'grpo_group_size': 4,
54
+ 'lora_rank': 32,
55
+ 'lora_alpha': 64,
56
+ 'max_grad_norm': 0.5,
57
+ 'kl_coeff': 0.05, # NEW: prevents catastrophic forgetting
58
+
59
+ # Output
60
+ 'hf_repo': 'Arijit-07/aria-devops-llama8b',
61
+ 'output_dir': '/data/outputs',
62
+ 'save_every_n_episodes': 20,
63
+ }
64
+
65
+ import torch
66
+ print(f'GPU: {torch.cuda.get_device_name(0)}')
67
+ print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
68
+ print(f'Model: {CONFIG["model_name"]} | Tasks: {CONFIG["tasks"]}')
69
+ print(f'LR: {CONFIG["learning_rate"]} | KL: {CONFIG["kl_coeff"]}')
70
+
71
+ # ── Cell 3: Environment Client ────────────────────────────────────────────────
72
+ import requests, json, time, random
73
+
74
+ BASE_URL = CONFIG['env_url']
75
+
76
+ def env_reset(task_id, seed=None):
77
+ payload = {'task_id': task_id}
78
+ if seed is not None: payload['seed'] = seed
79
+ for attempt in range(3):
80
+ try:
81
+ r = requests.post(f'{BASE_URL}/reset', json=payload, timeout=30)
82
+ r.raise_for_status()
83
+ return r.json()
84
+ except:
85
+ if attempt == 2: raise
86
+ time.sleep(5)
87
+
88
+ def env_step(action):
89
+ for attempt in range(3):
90
+ try:
91
+ r = requests.post(f'{BASE_URL}/step', json=action, timeout=30)
92
+ r.raise_for_status()
93
+ return r.json()
94
+ except:
95
+ if attempt == 2: raise
96
+ time.sleep(5)
97
+
98
+ def env_state():
99
+ r = requests.get(f'{BASE_URL}/state', timeout=30)
100
+ r.raise_for_status()
101
+ return r.json()
102
+
103
+ health = requests.get(f'{BASE_URL}/health', timeout=15).json()
104
+ print(f'βœ… Environment: {health}')
105
+ test_obs = env_reset('easy', seed=0)
106
+ print(f'βœ… Reset OK. Services: {len(test_obs.get("services", []))}')
107
+
108
+ # ── Cell 4: System Prompt + Observation Formatter ─────────────────────────────
109
+
110
+ print('Config loaded:')
111
  for k, v in CONFIG.items():
112
  print(f' {k}: {v}')
113
 
 
122
  f"{json.dumps(obs, indent=2, sort_keys=True)}\n"
123
  "Choose the next valid action as JSON."
124
  )
125
+
126
+ # ── Cell 5: Load Llama-3.1-8B with Unsloth ───────────────────────────────────
127
+ from unsloth import FastLanguageModel
128
+ import torch
129
+
130
+ print(f'Loading {CONFIG["model_name"]}...')
131
+
132
  checkpoint_path = f"{CONFIG['output_dir']}/latest"
133
  resuming_from_checkpoint = os.path.exists(checkpoint_path)
134
 
 
136
  print("πŸ” Resuming from checkpoint...")
137
  model, tokenizer = FastLanguageModel.from_pretrained(
138
  model_name=checkpoint_path,
139
+ max_seq_length=CONFIG['max_seq_length'],
140
+ dtype=None,
141
+ load_in_4bit=CONFIG['load_in_4bit'],
142
+ )
143
+ else:
144
+ print("πŸ†• Starting fresh model...")
145
+ model, tokenizer = FastLanguageModel.from_pretrained(
146
+ model_name=CONFIG['model_name'],
147
+ max_seq_length=CONFIG['max_seq_length'],
148
+ dtype=None,
149
+ load_in_4bit=CONFIG['load_in_4bit'],
150
  token=HF_TOKEN,
151
  )
152
 
 
162
  use_gradient_checkpointing='unsloth',
163
  random_state=42,
164
  )
165
+
166
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
167
+ total = sum(p.numel() for p in model.parameters())
168
+ print(f'βœ… Model loaded')
169
+ # 🔁 Load training state if it exists
170
+ state_path = f"{CONFIG['output_dir']}/state.json"
171
+
172
+ if os.path.exists(state_path):
173
+ print("πŸ” Restoring training state...")
174
+ with open(state_path, "r") as f:
175
+ state = json.load(f)
176
+
177
+ global_ep = state.get("global_ep", 0)
178
+ training_log = state.get("training_log", [])
179
+ episode_scores = state.get("episode_scores", {t: [] for t in CONFIG['tasks']})
180
+
181
+ print(f"βœ… Resumed from episode {global_ep}")
182
+ print(f"πŸš€ Continuing training from episode {global_ep}")
183
+ else:
184
+ print("πŸ†• Starting fresh training state")
185
+ training_log = []
186
+ episode_scores = {t: [] for t in CONFIG['tasks']}
187
+ global_ep = 0
188
+ print(f' Trainable: {trainable:,} ({100*trainable/total:.2f}%)')
189
+ print(f' VRAM: {torch.cuda.memory_allocated()/1e9:.2f} GB used')
190
+
191
+ # ── Cell 6: Action Parser + Episode Runner ────────────────────────────────────
192
+ import re
193
+
194
+ def parse_action(text):
195
+ text = text.strip()
196
+ for pattern in [
197
+ r'```json\s*({.*?})\s*```',
198
+ r'```\s*({.*?})\s*```',
199
+ r'({\s*"action_type"[^}]+})',
200
+ ]:
201
+ match = re.search(pattern, text, re.DOTALL)
202
+ if match:
203
+ try: return json.loads(match.group(1))
204
+ except: continue
205
+ try: return json.loads(text)
206
+ except: return {'action_type': 'noop'}
207
+
208
+ def generate_action(obs, task_id, temperature=0.7):
209
+ messages = [
210
+ {'role': 'system', 'content': SYSTEM_PROMPT},
211
+ {'role': 'user', 'content': observation_to_prompt(obs, task_id)}
212
+ ]
213
+ input_ids = tokenizer.apply_chat_template(
214
+ messages, tokenize=True, add_generation_prompt=True,
215
+ return_tensors='pt'
216
+ ).to('cuda')
217
+ FastLanguageModel.for_inference(model)
218
+ with torch.no_grad():
219
+ out = model.generate(
220
+ input_ids, max_new_tokens=150,
221
+ temperature=temperature, do_sample=True,
222
+ pad_token_id=tokenizer.eos_token_id,
223
+ )
224
+ generated = tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)
225
+ return parse_action(generated), generated
226
+
227
+ def run_episode(task_id, seed=None, verbose=False):
228
+ obs = env_reset(task_id, seed=seed)
229
+ total_reward = 0.0
230
+ done = False
231
+ for step in range(CONFIG['max_steps_per_episode']):
232
+ if done: break
233
+ action, _ = generate_action(obs, task_id)
234
+ if verbose: print(f' Step {step+1}: {action}')
235
+ result = env_step(action)
236
+ total_reward += result.get('reward', 0.0)
237
+ obs = result.get('observation', obs)
238
+ done = result.get('done', False)
239
+ state = env_state()
240
+ return state.get('current_score', total_reward)
241
+
242
# Smoke test: one verbose episode to confirm the runner and environment work.
print('βœ… Episode runner ready')
print('Testing one episode...')
test_score = run_episode('easy', seed=99, verbose=True)
print(f'Test score: {test_score:.3f}')

# ── Cell 7: Pre-Training Baseline ─────────────────────────────────────────────
# Score the untrained policy on each task (8 fixed seeds per task) so later
# cells have a reference point to compare against.
print('Running pre-training baseline (8 episodes per task)...')
baseline_scores = {}

for task_id in CONFIG['tasks']:
    scores = []
    for i in range(8):
        scores.append(run_episode(task_id, seed=i*7+3))
    avg = sum(scores) / len(scores)
    baseline_scores[task_id] = {'scores': scores, 'avg': avg}
    lowest, highest = min(scores), max(scores)
    print(f' [{task_id}] baseline: {avg:.3f} (min={lowest:.3f} max={highest:.3f})')

print('\nβœ… Baseline done. Starting training...')

# Refuse to start training without a GPU.
import torch
assert torch.cuda.is_available(), "GPU NOT DETECTED!"
print("Using GPU:", torch.cuda.get_device_name(0))
264
+
265
+
266
+ # ── Cell 8: GRPO Training Loop (FIXED — Episode-level updates + KL) ──────────
267
+ from torch.optim import AdamW
268
+ from transformers import get_cosine_schedule_with_warmup
269
+ import os, time, random, copy, json
270
+
271
# Checkpoints and logs are written under this directory; create it up front.
os.makedirs(CONFIG['output_dir'], exist_ok=True)

# Reference policy for the KL penalty: a copy of the model taken before any
# update, placed on the same device, with gradients disabled and in eval mode.
model_device = next(model.parameters()).device
ref_model = copy.deepcopy(model).to(model_device)
for ref_param in ref_model.parameters():
    ref_param.requires_grad = False
ref_model.eval()
print('βœ… Reference model frozen for KL penalty')

# Optimise only the parameters that are marked trainable.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(
    trainable_params,
    lr=CONFIG['learning_rate'],
    weight_decay=0.01,
)

# One scheduler step per episode; warm up over the first 10% (at least 1 step).
total_eps = CONFIG['episodes_per_task'] * len(CONFIG['tasks'])
warmup_steps = max(1, total_eps // 10)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_eps,
)

start_time = time.time()

banner = '=' * 65
print(banner)
print('ARIA GRPO TRAINING β€” Llama-3.1-8B')
print(f'LR={CONFIG["learning_rate"]} | KL={CONFIG["kl_coeff"]} | Groups={CONFIG["grpo_group_size"]}')
print('Strategy: collect full episode β†’ score on fresh env β†’ update once')
print(banner)
301
+
302
def _restore_episode_state(task_id, seed, history):
    """Reset the environment and re-apply ``history`` to rebuild a state.

    Assumes the environment is deterministic for a fixed ``task_id``/``seed``
    (the same assumption the per-candidate "fresh snapshot" scoring relies
    on) -- TODO confirm against the environment implementation.
    """
    env_reset(task_id, seed=seed)
    for past_action in history:
        env_step(past_action)


def run_episode_collect(task_id, seed):
    """Roll out one episode and collect per-step GRPO groups.

    At every step ``CONFIG['grpo_group_size']`` completions are sampled for
    the current observation. Each candidate is scored by restoring the
    environment to the *current* episode state (reset + replay of the
    actions chosen so far) and applying the candidate once. The
    best-scoring candidate is then applied from that same restored state to
    advance the main episode.

    This differs from scoring after a plain ``env_reset``: a plain reset
    would score every step against the initial state, and the chosen action
    would be applied on top of whichever candidate was scored last.

    Args:
        task_id: Task to run.
        seed: Environment seed, reused for every reset in this episode.

    Returns:
        ``(trajectory, total_reward)`` where ``trajectory`` is a list of
        dicts with keys ``input_ids``, ``completions`` and ``rewards`` (one
        per step), and ``total_reward`` is the sum over steps of the best
        candidate reward (``0.0`` for an empty trajectory).
    """
    obs = env_reset(task_id, seed=seed)
    trajectory = []
    history = []  # actions already committed to the main episode
    done = False

    FastLanguageModel.for_inference(model)

    for _step in range(CONFIG['max_steps_per_episode']):
        if done:
            break

        messages = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': observation_to_prompt(obs, task_id)},
        ]
        input_ids = tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True,
            return_tensors='pt',
        ).to('cuda')
        prompt_len = input_ids.shape[1]

        # Generate all completions first -- no env calls yet.
        group_completions, group_texts = [], []
        for _ in range(CONFIG['grpo_group_size']):
            with torch.no_grad():
                out = model.generate(
                    input_ids, max_new_tokens=128, temperature=0.9,
                    do_sample=True, pad_token_id=tokenizer.eos_token_id,
                )
            gen_ids = out[0][prompt_len:]
            group_completions.append(gen_ids)
            group_texts.append(tokenizer.decode(gen_ids, skip_special_tokens=True))

        # Parse once and reuse for both scoring and advancing the episode.
        group_actions = []
        for gen_text in group_texts:
            action = parse_action(gen_text)
            if not isinstance(action, dict):
                action = {'action_type': 'noop'}
            group_actions.append(action)

        # Score each candidate from the current episode state.
        group_rewards = []
        for action in group_actions:
            try:
                _restore_episode_state(task_id, seed, history)
                res = env_step(action)
                r = res.get('reward', 0.0)
            except Exception:
                # Best effort: an env failure for one candidate scores zero.
                r = 0.0
            if action.get('action_type', 'noop') != 'noop':
                r += 0.02  # exploration bonus
            group_rewards.append(r)

        # Advance the main episode with the best action, again starting from
        # the current state rather than the last candidate's leftovers.
        best_idx = group_rewards.index(max(group_rewards))
        best_action = group_actions[best_idx]
        try:
            _restore_episode_state(task_id, seed, history)
            adv_res = env_step(best_action)
            obs = adv_res.get('observation', obs)
            done = adv_res.get('done', False)
            history.append(best_action)
        except Exception:
            # Cannot continue this episode; keep what was collected so far.
            done = True

        trajectory.append({
            'input_ids': input_ids,
            'completions': group_completions,
            'rewards': group_rewards,
        })

    # Episode score: accumulated reward of the best candidate at each step.
    total_reward = sum(max(s['rewards']) for s in trajectory) if trajectory else 0.0
    return trajectory, total_reward
371
+
372
+
373
+ def update_from_trajectory(trajectory):
374
+ """Single model update from full episode with KL penalty."""
375
  if not trajectory:
376
  return 0.0
377
 
 
388
  rewards = step_data['rewards']
389
 
390
  rewards_t = torch.tensor(rewards, dtype=torch.float32, device=device)
391
+ if rewards_t.std() > 1e-8:
392
+ advantages = (rewards_t - rewards_t.mean()) / (rewards_t.std() + 1e-8)
393
+ else:
394
+ advantages = rewards_t - rewards_t.mean()
395
+
396
+ best_idx = rewards.index(max(rewards))
397
  best_ids = completions[best_idx].to(device)
398
+ best_adv = advantages[best_idx]
399
+
400
+ full_ids = torch.cat([input_ids[0], best_ids]).unsqueeze(0)
401
+ labels = full_ids.clone()
402
+ labels[0, :input_ids.shape[1]] = -100
403
+
404
+ outputs = model(full_ids, labels=labels)
405
+ policy_loss = outputs.loss * (-best_adv)
406
+
407
+ # KL penalty vs reference model
408
+ with torch.no_grad():
409
+ ref_out = ref_model(full_ids)
410
+ ref_logits = ref_out.logits[:, input_ids.shape[1]-1:-1, :]
411
+ pol_logits = outputs.logits[:, input_ids.shape[1]-1:-1, :]
412
+ kl = torch.nn.functional.kl_div(
413
+ torch.log_softmax(pol_logits, dim=-1),
414
+ torch.softmax(ref_logits, dim=-1),
415
+ reduction='batchmean'
416
+ )
417
+ total_loss = total_loss + policy_loss + CONFIG['kl_coeff'] * kl
418
+
419
+ total_loss = total_loss / len(trajectory)
420
+ total_loss.backward()
421
+ torch.nn.utils.clip_grad_norm_(
422
+ [p for p in model.parameters() if p.requires_grad],
423
+ CONFIG['max_grad_norm']
424
+ )
425
+ optimizer.step()
426
+ scheduler.step()
427
+ return total_loss.item()
428
+
429
+
430
+
431
+ # ── Main training loop ─────────────────────────────────────────────────────
432
# Curriculum training: for each task, repeatedly collect an episode, update
# the policy once from it, log the result and checkpoint.
best_score = -1e9
no_improve_count = 0
PATIENCE = 15  # episodes without a new best rolling average before stopping a task

for task_id in CONFIG['tasks']:
    # Early-stopping state is tracked per task. If it were shared, the first
    # task to stop would leave the counter at PATIENCE and every later task
    # would stop after a single episode.
    best_score = -1e9
    no_improve_count = 0

    print(f'\nπŸ“‹ Task: {task_id.upper()} | Baseline: {baseline_scores[task_id]["avg"]:.3f}')
    print('-' * 40)

    for ep in range(CONFIG['episodes_per_task']):
        seed = random.randint(0, 9999)

        trajectory, final_score = run_episode_collect(task_id, seed)
        loss = update_from_trajectory(trajectory)

        episode_scores[task_id].append(final_score)
        global_ep += 1
        elapsed = (time.time() - start_time) / 60
        recent = episode_scores[task_id][-10:]
        rolling = sum(recent) / len(recent) if recent else final_score

        # Track improvement of the rolling-10 average for early stopping.
        if rolling > best_score:
            best_score = rolling
            no_improve_count = 0
        else:
            no_improve_count += 1

        training_log.append({
            'episode': global_ep, 'task_id': task_id,
            'score': final_score, 'rolling_avg': rolling,
            'loss': loss, 'elapsed_min': round(elapsed, 1)
        })

        # Fail-safe checkpoint after every episode (best effort: a failed
        # save is reported but does not abort training).
        try:
            latest_ckpt = f"{CONFIG['output_dir']}/latest"

            # Save model + tokenizer first ...
            model.save_pretrained(latest_ckpt)
            tokenizer.save_pretrained(latest_ckpt)

            # ... then the training state.
            state = {
                "global_ep": global_ep,
                "training_log": training_log,
                "episode_scores": episode_scores
            }

            tmp_path = f"{CONFIG['output_dir']}/state_tmp.json"
            final_path = f"{CONFIG['output_dir']}/state.json"

            # Write to a temp file and swap it in, so state.json is never
            # left half-written.
            with open(tmp_path, "w") as f:
                json.dump(state, f)

            os.replace(tmp_path, final_path)

        except Exception as e:
            print("⚠️ Checkpoint save failed:", e)

        if (ep + 1) % 5 == 0:
            delta = rolling - baseline_scores[task_id]['avg']
            trend = 'πŸ“ˆ' if delta > 0.02 else 'πŸ“‰' if delta < -0.02 else '➑️'
            print(
                f' {trend} Ep {ep+1:3d}/{CONFIG["episodes_per_task"]} | '
                f'Score: {final_score:.3f} | Roll-10: {rolling:.3f} | '
                f'vs baseline: {delta:+.3f} | Loss: {loss:.4f} | {elapsed:.0f}m'
            )

        if global_ep % CONFIG['save_every_n_episodes'] == 0:
            ckpt = f'{CONFIG["output_dir"]}/ep{global_ep}'
            model.save_pretrained(ckpt)
            tokenizer.save_pretrained(ckpt)
            print(f' πŸ’Ύ Checkpoint ep{global_ep}')

        # Check the stop condition only after the episode has been logged
        # and checkpointed, so the final episode is not lost.
        if no_improve_count >= PATIENCE:
            print("πŸ›‘ Early stopping triggered β€” no improvement")
            break

    # NOTE(review): training scores come from run_episode_collect while the
    # baseline comes from run_episode; confirm the two are on the same scale
    # before relying on this comparison.
    task_avg = sum(episode_scores[task_id]) / len(episode_scores[task_id])
    base_avg = baseline_scores[task_id]['avg']
    delta = task_avg - base_avg
    result = 'βœ… IMPROVED' if delta > 0.02 else '⚠️ FLAT' if delta > -0.02 else '❌ DEGRADED'
    print(f'\n{result} {task_id}: {base_avg:.3f} β†’ {task_avg:.3f} ({delta:+.3f})')

    # Save training log so far (in case of crash)
    with open(f'{CONFIG["output_dir"]}/training_log.json', 'w') as f:
        json.dump(training_log, f, indent=2)
    print(' πŸ“ Training log saved')

print(f'\nπŸŽ‰ Training complete! {(time.time()-start_time)/60:.0f} minutes')
520
+
521
+ # ── Cell 9: Post-Training Eval + Generalization ───────────────────────────────
522
# Evaluate the trained policy on the training tasks (different seeds from the
# baseline run) and on three tasks that were not part of CONFIG['tasks'].
FastLanguageModel.for_inference(model)
print('Post-training evaluation (8 episodes per task, unseen seeds)...')

post_scores = {}
for task_id in CONFIG['tasks']:
    scores = [run_episode(task_id, seed=i*13+7) for i in range(8)]
    avg = sum(scores) / len(scores)
    post_scores[task_id] = {'scores': scores, 'avg': avg}
    delta = avg - baseline_scores[task_id]['avg']
    print(f' [{task_id}] {baseline_scores[task_id]["avg"]:.3f} β†’ {avg:.3f} '
          f'({("+" if delta>=0 else "")}{delta:.3f})')

print('\nZero-shot generalization (ARIA tasks β€” never seen in training):')
gen_scores = {}
for task_id in ['security', 'database', 'failover']:
    scores = []
    for i in range(5):
        try:
            scores.append(run_episode(task_id, seed=i*17+5))
        except Exception as exc:
            # Best effort: a failed episode counts as 0.0, but say why so an
            # unavailable task is not mistaken for a genuinely bad score.
            # `Exception` (not a bare except) lets Ctrl-C still interrupt.
            print(f' [{task_id}] episode {i} failed: {exc}')
            scores.append(0.0)
    avg = sum(scores) / len(scores)
    gen_scores[task_id] = avg
    print(f' [{task_id}] zero-shot: {avg:.3f}')
544
+
545
+ # ── Cell 10: Learning Curve Visualization ────────────────────────────────────
546
+ import matplotlib.pyplot as plt
547
+ import matplotlib.gridspec as gridspec
548
+ import numpy as np
549
+
550
# Build a 2x3 dark-themed figure: one learning-curve panel per task, a
# before/after bar chart, and a text summary panel; save it and show it.
fig = plt.figure(figsize=(20, 12))
fig.patch.set_facecolor('#0d1117')
gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.4, wspace=0.35)
COLORS = {'easy':'#4caf50','medium':'#ff9800','hard':'#f44336','bonus':'#9c27b0'}


def style_ax(ax, title):
    """Apply the shared dark theme (background, title, ticks, spines, grid) to ``ax``."""
    ax.set_facecolor('#161b22')
    ax.set_title(title, color='white', fontsize=12, fontweight='bold', pad=10)
    ax.tick_params(colors='#8b949e', labelsize=9)
    for spine in ax.spines.values():
        spine.set_color('#30363d')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.grid(True, alpha=0.1, color='#30363d')


# Per-task panels, filled row-major: raw scores (faint), rolling average
# (bold), and horizontal lines for the baseline and post-training averages.
for idx, task_id in enumerate(CONFIG['tasks']):
    row, col = divmod(idx, 3)
    ax = fig.add_subplot(gs[row, col])
    style_ax(ax, f'Task: {task_id.upper()}')
    task_log = [e for e in training_log if e['task_id'] == task_id]
    eps = [e['episode'] for e in task_log]
    scores = [e['score'] for e in task_log]
    rolling = [e['rolling_avg'] for e in task_log]
    color = COLORS.get(task_id, '#58a6ff')
    ax.plot(eps, scores, alpha=0.15, color=color, linewidth=1)
    ax.plot(eps, rolling, color=color, linewidth=2.5, label='Rolling avg (10)')
    ax.axhline(y=baseline_scores[task_id]['avg'], color='#f85149',
               linestyle='--', linewidth=1.5, label='Baseline')
    ax.axhline(y=post_scores[task_id]['avg'], color='#3fb950',
               linestyle='--', linewidth=1.5, label='Post-training')
    ax.set_ylim(0, 1.05)
    ax.set_xlabel('Episode', color='#8b949e', fontsize=9)
    ax.set_ylabel('Score', color='#8b949e', fontsize=9)
    ax.legend(facecolor='#161b22', labelcolor='white', fontsize=8)

# Grouped bars: baseline vs post-training average for every task.
ax5 = fig.add_subplot(gs[1, 1])
style_ax(ax5, 'Before vs After (all tasks)')
x = np.arange(len(CONFIG['tasks']))
w = 0.35
before_v = [baseline_scores[t]['avg'] for t in CONFIG['tasks']]
after_v = [post_scores[t]['avg'] for t in CONFIG['tasks']]
b1 = ax5.bar(x-w/2, before_v, w, label='Before', color='#f85149', alpha=0.85)
b2 = ax5.bar(x+w/2, after_v, w, label='After', color='#3fb950', alpha=0.85)
for bar, v in zip(b1, before_v):
    ax5.text(bar.get_x()+bar.get_width()/2., v+0.01, f'{v:.2f}',
             ha='center', color='white', fontsize=8)
for bar, v in zip(b2, after_v):
    ax5.text(bar.get_x()+bar.get_width()/2., v+0.01, f'{v:.2f}',
             ha='center', color='white', fontsize=8)
ax5.set_xticks(x)
ax5.set_xticklabels(CONFIG['tasks'], color='#8b949e')
ax5.set_ylim(0, 1.15)
ax5.legend(facecolor='#161b22', labelcolor='white', fontsize=9)

# Text-only summary panel.
ax6 = fig.add_subplot(gs[1, 2])
ax6.set_facecolor('#161b22')
ax6.set_title('Summary', color='white', fontsize=12, fontweight='bold')
ax6.axis('off')
lines = [
    ('Model', 'Llama-3.1-8B (Unsloth 4-bit)'),
    ('Algorithm', 'GRPO'),
    ('LoRA rank', str(CONFIG['lora_rank'])),
    ('Total episodes', str(global_ep)),
    ('', ''),
]
for t in CONFIG['tasks']:
    b = baseline_scores[t]['avg']
    a = post_scores[t]['avg']
    # `+` in the format spec emits the correct sign itself; a hard-coded '+'
    # would render a regression as "+-0.05".
    lines.append((f' {t}', f'{b:.2f} β†’ {a:.2f} ({a-b:+.2f})'))
if gen_scores:
    lines += [('', ''), ('Zero-shot', '')]
    for t, s in gen_scores.items():
        lines.append((f' {t}', f'{s:.2f}'))
y = 0.95
for label, val in lines:
    if not label:
        # Empty label acts as a half-height spacer row.
        y -= 0.04
        continue
    ax6.text(0.02, y, label+':', color='#8b949e', fontsize=9,
             transform=ax6.transAxes, fontweight='bold')
    ax6.text(0.52, y, val, color='#c9d1d9', fontsize=9, transform=ax6.transAxes)
    y -= 0.08

fig.suptitle('ARIA β€” DevOps Incident Response\nGRPO Training (Llama-3.1-8B Full Curriculum)',
             color='white', fontsize=16, fontweight='bold', y=0.98)
plt.savefig('training_curve_8b.png', dpi=150, bbox_inches='tight', facecolor='#0d1117')
print('βœ… Saved training_curve_8b.png')
plt.show()
634
+
635
+ # ── Cell 11: Save to HuggingFace Hub ─────────────────────────────────────────
636
+ from huggingface_hub import HfApi
637
+ import json
638
+
639
# Publish the merged 16-bit model, the learning-curve image and the training
# log to the Hub repo named in CONFIG['hf_repo'].
hub_repo = CONFIG["hf_repo"]
print(f'Pushing to {hub_repo}...')
FastLanguageModel.for_inference(model)

# Write the merged weights locally, then push the merged model.
model.save_pretrained_merged(CONFIG['output_dir'], tokenizer, save_method='merged_16bit')
model.push_to_hub_merged(
    CONFIG['hf_repo'], tokenizer,
    save_method='merged_16bit', token=HF_TOKEN,
)
print(f'βœ… Model: https://huggingface.co/{hub_repo}')

api = HfApi()
for fname in ['training_curve_8b.png']:
    api.upload_file(
        path_or_fileobj=fname,
        path_in_repo=fname,
        repo_id=CONFIG['hf_repo'],
        token=HF_TOKEN,
    )
    print(f'βœ… {fname} uploaded')

# Dump the full training log to disk and upload it alongside the model.
log_name = 'training_log_8b.json'
with open(log_name, 'w') as f:
    json.dump(training_log, f, indent=2)
api.upload_file(
    path_or_fileobj=log_name,
    path_in_repo=log_name,
    repo_id=CONFIG['hf_repo'],
    token=HF_TOKEN,
)

print('\nπŸŽ‰ DONE! Shut down the RunPod instance now to stop billing.')
print(f' Model: https://huggingface.co/{hub_repo}')
print(' Curve: check training_curve_8b.png in the repo')