# Source: aria-training / train_model.py (uploaded by Arijit-07)
# Commit: "Update train_model.py" (4f2fad3, verified)
import sys, os, json, time, random, re, copy
# ── CRITICAL: Set before ANY import ──────────────────────────────────────────
# Exported before unsloth is imported further down; update_from_trajectory
# reads `outputs.logits` for its KL term.
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
# ── Install dependencies ──────────────────────────────────────────────────────
import subprocess

# Run pip through the current interpreter so the pinned packages are installed
# into the same environment that imports them below.
_pip_result = subprocess.run(
    [
        sys.executable, '-m', 'pip', 'install', '-q',
        'unsloth==2025.7.7',
        'transformers==4.51.3',
        'accelerate==0.34.2',
        'peft==0.13.2',
        'trl==0.14.0',
        'requests',
        'matplotlib',
        'scipy',
        'huggingface_hub',
    ],
    capture_output=True,
    text=True,
)
if _pip_result.returncode != 0:
    # Best effort: the packages may already be present, so keep going, but
    # surface the failure instead of discarding it.
    print(
        f"pip install exited with code {_pip_result.returncode}:\n"
        f"{(_pip_result.stderr or '')[-2000:]}"
    )
# ── Clear stale module cache ──────────────────────────────────────────────────
# Drop any already-imported copies so the imports below pick up the versions
# that were just installed.
for mod in list(sys.modules):
    if any(x in mod for x in ('trl', 'unsloth', 'transformers', 'peft')):
        del sys.modules[mod]
# ── Verify imports ────────────────────────────────────────────────────────────
# unsloth is imported first, ahead of transformers/peft/torch.
import unsloth
from unsloth import FastLanguageModel
import transformers, peft, torch

# Report the versions that actually got imported.
for _banner in (
    f"βœ… unsloth {unsloth.__version__}",
    f"βœ… transformers {transformers.__version__}",
    f"βœ… torch {torch.__version__} | CUDA: {torch.cuda.is_available()}",
    f"βœ… UNSLOTH_RETURN_LOGITS = {os.environ['UNSLOTH_RETURN_LOGITS']}",
):
    print(_banner)
# ── Auth ──────────────────────────────────────────────────────────────────────
# An empty HF_TOKEN disables both the login here and the Hub push at the end
# of run_training.
HF_TOKEN = os.environ.get('HF_TOKEN', '')
if not HF_TOKEN:
    print("⚠️ HF_TOKEN not set β€” will not push to Hub")
else:
    from huggingface_hub import login
    login(token=HF_TOKEN, add_to_git_credential=False)
    print("βœ… Logged in to HuggingFace")
# ── Config ────────────────────────────────────────────────────────────────────
# Central run configuration; every section below reads its settings from here.
CONFIG = {
    # Model
    "model_name": "unsloth/Meta-Llama-3.1-8B-Instruct",
    "max_seq_length": 2048,      # reduced from 3072 - safer on L4
    "load_in_4bit": True,        # always 4-bit - L4 has 23.7GB

    # Remote environment
    "env_url": "https://arijit-07-devops-incident-response.hf.space",
    "tasks": ["easy", "medium", "hard", "bonus"],
    "episodes_per_task": 40,
    "max_steps_per_episode": 12,

    # Optimisation
    "learning_rate": 5e-6,
    "grpo_group_size": 4,        # completions sampled per step
    "lora_rank": 32,
    "lora_alpha": 64,
    "max_grad_norm": 0.5,
    "kl_coeff": 0.05,

    # Outputs
    "hf_repo": "Arijit-07/aria-devops-llama8b",
    "output_dir": "/data/outputs",
    "save_every_n_episodes": 20,  # NOTE(review): not read anywhere in the visible code
}
# Report the accelerator in use (device index 0).
_gpu_name = torch.cuda.get_device_name(0)
print(f"GPU: {_gpu_name}")
_gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
print(f"VRAM: {_gpu_total_bytes / 1e9:.1f} GB")
# ── Environment Client ────────────────────────────────────────────────────────
import requests

# Root URL of the remote environment; all client helpers build on it.
BASE_URL = CONFIG['env_url']
def _post_json(path, payload, attempts=3, pause_s=5):
    """POST ``payload`` as JSON to ``BASE_URL + path`` and return the decoded body.

    Request failures (connection errors, timeouts, non-2xx statuses) and
    undecodable response bodies are retried with ``pause_s`` seconds between
    tries. After ``attempts`` tries the last exception is re-raised unchanged.
    Other exception types (programming errors) propagate immediately instead
    of being retried.
    """
    for attempt in range(1, attempts + 1):
        try:
            r = requests.post(f'{BASE_URL}{path}', json=payload, timeout=30)
            r.raise_for_status()
            return r.json()
        except (requests.RequestException, ValueError):
            if attempt == attempts:
                raise
            time.sleep(pause_s)


def env_reset(task_id, seed=None):
    """Reset the remote environment for ``task_id`` and return the JSON reply.

    ``seed`` is only included in the request when it is not None.
    Raises the underlying request/decoding error after three failed attempts.
    """
    payload = {'task_id': task_id}
    if seed is not None:
        payload['seed'] = seed
    return _post_json('/reset', payload)


def env_step(action):
    """Send ``action`` (a JSON-serialisable dict) to /step and return the JSON reply.

    Raises the underlying request/decoding error after three failed attempts.
    """
    return _post_json('/step', action)
def env_state():
    """Return the decoded JSON body of ``GET {BASE_URL}/state``.

    A single attempt with a 30 second timeout; HTTP error statuses raise via
    ``raise_for_status``.
    """
    state_url = f'{BASE_URL}/state'
    response = requests.get(state_url, timeout=30)
    response.raise_for_status()
    return response.json()
# Action types accepted by sanitize_action; anything else is replaced by
# "read_logs".
VALID_ACTIONS = {
    "diagnose", "read_logs", "read_metrics", "read_runbook",
    "search_logs", "restart_service", "rollback", "scale_up",
    "alert_oncall", "acknowledge", "noop", "block_ip_range",
    "create_index", "failover"
}


def sanitize_action(action):
    """Normalise a parsed model action into a clean dict for the environment.

    Args:
        action: Whatever the parser produced; normally a dict, but any type
            is tolerated.

    Returns:
        A new dict that always has ``action_type`` (a member of
        ``VALID_ACTIONS``, falling back to ``"read_logs"``) and ``service``
        (first non-empty string among ``service`` / ``service_name``, falling
        back to ``"order-service"``). The known optional parameters are copied
        over only when their value is a string; every other key is dropped.
    """
    DEFAULT_SERVICE = "order-service"
    if not isinstance(action, dict):
        return {"action_type": "read_logs", "service": DEFAULT_SERVICE}

    # Model output is untrusted JSON: action_type may be null, a number, a
    # list... Only strings can be normalised.
    raw_type = action.get("action_type")
    action_type = raw_type.strip().lower() if isinstance(raw_type, str) else ""
    if action_type not in VALID_ACTIONS:
        action_type = "read_logs"

    service = DEFAULT_SERVICE
    for candidate in (action.get("service"), action.get("service_name")):
        if isinstance(candidate, str) and candidate:
            service = candidate
            break

    clean = {"action_type": action_type, "service": service}
    for key in ("root_cause", "runbook", "version", "reason",
                "query", "ip_range", "table", "column", "target_region"):
        value = action.get(key)
        if isinstance(value, str):
            clean[key] = value
    return clean
# Test connection
# Probe the environment's /health endpoint once at start-up and echo the reply.
_health_response = requests.get(f'{BASE_URL}/health', timeout=15)
health = _health_response.json()
print(f"βœ… Environment: {health}")
# ── System Prompt ─────────────────────────────────────────────────────────────
# System message sent as the first chat turn by run_episode and
# run_episode_collect. The action list here matches VALID_ACTIONS; keep the
# two in sync when adding or removing an action type.
SYSTEM_PROMPT = """You are an autonomous DevOps agent.
You MUST return ONLY valid JSON.
action_type MUST be one of:
diagnose, read_logs, read_metrics, read_runbook, search_logs,
restart_service, rollback, scale_up, alert_oncall, acknowledge,
noop, block_ip_range, create_index, failover
Always include "service" field. Use exact parameter names.
Output valid JSON only. Example:
{"action_type": "read_logs", "service": "order-service"}"""
def observation_to_prompt(obs, task_id):
    """Render an environment observation as a compact user prompt.

    Includes at most the 6 services with the highest ``error_rate``, the first
    4 active alerts and the last 3 evidence entries (content cut to 100
    characters). The ``Evidence`` section is omitted when there is no evidence.

    Missing, ``None`` or wrongly-typed fields are tolerated: numbers fall back
    to 0, text falls back to an empty string, and non-list collections or
    non-dict entries are ignored. This keeps one malformed observation from
    aborting an episode.

    Args:
        obs: Observation dict returned by the environment.
        task_id: Task label echoed in the first line of the prompt.

    Returns:
        The prompt string, ending with ``"\\nChoose next action as JSON:"``.
    """
    def _num(value):
        # Format-safe number: None / non-numeric values become 0.0.
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    def _text(value):
        return "" if value is None else str(value)

    def _dicts(value):
        if not isinstance(value, (list, tuple)):
            return []
        return [item for item in value if isinstance(item, dict)]

    if not isinstance(obs, dict):
        obs = {}
    services = _dicts(obs.get('services'))
    alerts = _dicts(obs.get('active_alerts'))
    evidence = _dicts(obs.get('evidence_log'))

    # Worst services first; sorted() is stable, so ties keep their input order.
    worst_first = sorted(
        services, key=lambda s: _num(s.get('error_rate')), reverse=True
    )[:6]
    svc_lines = [
        f" {_text(s.get('name'))}: {_text(s.get('status'))} "
        f"err={_num(s.get('error_rate')):.3f} mem={_num(s.get('memory')):.1f}%"
        for s in worst_first
    ]
    alert_lines = [
        f" [{_text(a.get('severity')).upper()}] "
        f"{_text(a.get('service'))}: {_text(a.get('message'))}"
        for a in alerts[:4]
    ]
    ev_lines = [
        f" [{_text(e.get('action_type')).upper()}] {_text(e.get('content'))[:100]}"
        for e in evidence[-3:]
    ]

    header = f"Task: {task_id} | Step {obs.get('step', 0)}/{obs.get('max_steps', 15)}\n"
    evidence_block = ("Evidence:\n" + "\n".join(ev_lines)) if ev_lines else ""
    return (
        header
        + "Services:\n" + "\n".join(svc_lines) + "\n"
        + "Alerts:\n" + "\n".join(alert_lines) + "\n"
        + evidence_block
        + "\nChoose next action as JSON:"
    )
# ── Action Parser ─────────────────────────────────────────────────────────────
def parse_action(text):
    """Extract the action dict from raw model output.

    Tried in order:
      1. a JSON object inside a ```json fenced block,
      2. a JSON object inside any fenced block,
      3. a flat ``{"action_type": ...}`` object anywhere in the text,
      4. the whole text as JSON,
      5. the first position at which a complete JSON object can be decoded
         (covers objects with nested braces embedded in prose).

    Args:
        text: Decoded completion text.

    Returns:
        The first candidate that decodes to a dict, otherwise
        ``{'action_type': 'noop'}``. Non-dict JSON (numbers, lists, strings)
        is never returned.
    """
    if not isinstance(text, str):
        return {'action_type': 'noop'}
    text = text.strip()

    patterns = (
        r'```json\s*({.*?})\s*```',
        r'```\s*({.*?})\s*```',
        r'({\s*"action_type"[^}]+})',
    )
    for pattern in patterns:
        match = re.search(pattern, text, re.DOTALL)
        if not match:
            continue
        try:
            parsed = json.loads(match.group(1))
        except ValueError:  # json.JSONDecodeError is a ValueError
            continue
        if isinstance(parsed, dict):
            return parsed

    try:
        parsed = json.loads(text)
    except ValueError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    # Last resort: the non-greedy patterns above stop at the first '}', so an
    # object containing nested braces never matches them. Let the JSON decoder
    # find the end of the object instead.
    decoder = json.JSONDecoder()
    start = text.find('{')
    while start != -1:
        try:
            parsed, _ = decoder.raw_decode(text, start)
        except ValueError:
            parsed = None
        if isinstance(parsed, dict):
            return parsed
        start = text.find('{', start + 1)

    return {'action_type': 'noop'}
# ── Load Model ────────────────────────────────────────────────────────────────
_output_root = CONFIG['output_dir']
os.makedirs(_output_root, exist_ok=True)
# Locations of the resumable artefacts: the saved model/tokenizer directory
# and the JSON file holding episode counters and logs.
checkpoint_path = f"{_output_root}/latest"
state_path = f"{_output_root}/state.json"
def is_valid_checkpoint(path):
    """Return True when ``path`` looks like a checkpoint this script can resume.

    The directory must contain ``config.json`` or ``adapter_config.json``.
    When ``adapter_config.json`` is present it must also be a readable JSON
    object without the ``alora_invocation_tokens`` key, which this script
    treats as a marker of an adapter config the installed peft cannot load.

    Args:
        path: Checkpoint directory to inspect.

    Returns:
        bool: False for a missing, unreadable, malformed or incompatible
        checkpoint; True otherwise.
    """
    config_file = os.path.join(path, 'config.json')
    adapter_file = os.path.join(path, 'adapter_config.json')
    has_adapter = os.path.exists(adapter_file)
    if not has_adapter and not os.path.exists(config_file):
        return False
    if not has_adapter:
        return True

    try:
        with open(adapter_file, encoding='utf-8') as f:
            cfg = json.load(f)
    except (OSError, ValueError):  # unreadable file or invalid JSON/encoding
        return False
    if not isinstance(cfg, dict):
        return False
    if 'alora_invocation_tokens' in cfg:
        print(f"⚠️ Checkpoint has incompatible peft config field 'alora_invocation_tokens'")
        return False
    return True
# Training progress, either fresh or restored from a previous run.
resuming = False
training_log = []
episode_scores = {t: [] for t in CONFIG['tasks']}
global_ep = 0
if os.path.exists(checkpoint_path):
    if is_valid_checkpoint(checkpoint_path):
        print("πŸ” Valid checkpoint found β€” resuming...")
        resuming = True
        if os.path.exists(state_path):
            with open(state_path) as f:
                state = json.load(f)
            global_ep = state.get('global_ep', 0)
            training_log = state.get('training_log', [])
            # Keep whatever was saved, but make sure every configured task has
            # an entry: the training loop indexes episode_scores[task_id]
            # directly and would raise KeyError for a task missing from an
            # older state file.
            episode_scores = dict(state.get('episode_scores') or {})
            for t in CONFIG['tasks']:
                episode_scores.setdefault(t, [])
            print(f"βœ… Resumed from episode {global_ep}")
    else:
        print("⚠️ Incompatible checkpoint found β€” deleting and starting fresh")
        import shutil
        shutil.rmtree(checkpoint_path, ignore_errors=True)
        # The saved counters describe the deleted checkpoint; drop them too.
        if os.path.exists(state_path):
            os.remove(state_path)
        resuming = False
print(f"Loading model: {CONFIG['model_name']} ({'resuming' if resuming else 'fresh'})")
# A resumed run loads from the saved checkpoint directory; a fresh run loads
# the base model and then attaches new LoRA adapters.
_load_source = checkpoint_path if resuming else CONFIG['model_name']
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=_load_source,
    max_seq_length=CONFIG['max_seq_length'],
    dtype=None,
    load_in_4bit=CONFIG['load_in_4bit'],  # ALWAYS 4bit
    token=HF_TOKEN or None,
)
if not resuming:
    model = FastLanguageModel.get_peft_model(
        model,
        r=CONFIG['lora_rank'],
        target_modules=[
            'q_proj', 'k_proj', 'v_proj', 'o_proj',
            'gate_proj', 'up_proj', 'down_proj',
        ],
        lora_alpha=CONFIG['lora_alpha'],
        lora_dropout=0.05,
        bias='none',
        use_gradient_checkpointing='unsloth',
        random_state=42,
    )

# Parameter census: how much of the model will actually receive gradients.
trainable = 0
total = 0
for _param in model.parameters():
    _count = _param.numel()
    total += _count
    if _param.requires_grad:
        trainable += _count
print(f"βœ… Model loaded | Trainable: {trainable:,} ({100*trainable/total:.2f}%)")
print(f" VRAM: {torch.cuda.memory_allocated()/1e9:.2f} GB used")
# Frozen reference model for KL penalty
# Snapshot of the policy as loaded above, kept in eval mode with gradients
# disabled. update_from_trajectory measures the KL penalty against it.
# NOTE(review): on a resumed run this snapshot is the resumed checkpoint, not
# the original base model.
ref_model = copy.deepcopy(model).eval()
for _ref_param in ref_model.parameters():
    _ref_param.requires_grad = False
print("βœ… Reference model frozen for KL penalty")
# ── Episode Runner (for baseline) ─────────────────────────────────────────────
def run_episode(task_id, seed=None, verbose=False):
    """Play one episode with the current policy and return its summed reward.

    One completion is sampled per step (temperature 0.7, up to 100 new tokens),
    parsed and sanitised into an action, and sent to the environment. The
    episode ends when the environment reports ``done`` or after
    ``CONFIG['max_steps_per_episode']`` steps.

    Args:
        task_id: Task to reset the environment to.
        seed: Optional seed forwarded to ``env_reset``.
        verbose: When True, print the action taken at each step.

    Returns:
        float: Sum of the ``reward`` values returned by the environment.
    """
    obs = env_reset(task_id, seed=seed)
    FastLanguageModel.for_inference(model)
    episode_return = 0.0
    for step_no in range(1, CONFIG['max_steps_per_episode'] + 1):
        chat = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': observation_to_prompt(obs, task_id)},
        ]
        prompt_ids = tokenizer.apply_chat_template(
            chat,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors='pt',
        )
        # Keep only the trailing max_seq_length tokens of the prompt.
        prompt_ids = prompt_ids[:, -CONFIG['max_seq_length']:].to('cuda')
        with torch.no_grad():
            generated = model.generate(
                prompt_ids,
                max_new_tokens=100,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tail, not the echoed prompt.
        reply = tokenizer.decode(
            generated[0][prompt_ids.shape[1]:], skip_special_tokens=True
        )
        action = sanitize_action(parse_action(reply))
        if verbose:
            print(f" Step {step_no}: {action}")
        outcome = env_step(action)
        episode_return += outcome.get('reward', 0.0)
        obs = outcome.get('observation', obs)
        if outcome.get('done', False):
            break
    return episode_return
# ── Pre-Training Baseline ─────────────────────────────────────────────────────
# Score the untrained policy on fixed seeds (3, 10, 17, 24, 31) so later
# training progress can be reported relative to it.
print("\nRunning pre-training baseline (5 episodes per task)...")
baseline_scores = {}
for task_id in CONFIG['tasks']:
    scores = []
    for i in range(5):
        scores.append(run_episode(task_id, seed=i * 7 + 3))
    avg = sum(scores) / len(scores)
    baseline_scores[task_id] = {'scores': scores, 'avg': avg}
    print(f" [{task_id}] baseline: {avg:.3f}")
print("βœ… Baseline done")
# ── GRPO Training Functions ───────────────────────────────────────────────────
def run_episode_collect(task_id, seed):
    """Roll out one episode, sampling a group of candidate actions per step.

    At every step ``CONFIG['grpo_group_size']`` completions are sampled for the
    current observation. Each candidate is scored by restoring the environment
    to the current episode state and applying that candidate alone, so no
    candidate sees the side effects of another. The highest-scoring candidate
    then advances the real episode.

    The environment exposes a single shared session (/reset, /step), so
    "restoring" means resetting with the same seed and replaying the actions
    chosen so far. This assumes the environment is deterministic for a given
    seed and action sequence (TODO confirm against the environment). The
    replay makes the number of HTTP calls grow quadratically with episode
    length; a snapshot endpoint on the environment would remove that cost.

    Args:
        task_id: Task to run.
        seed: Seed passed to every ``env_reset`` so replays start identically.

    Returns:
        tuple: ``(trajectory, total_reward)``. ``trajectory`` has one dict per
        step with ``input_ids`` (prompt tensor), ``completions`` (list of
        generated token-id tensors) and ``rewards`` (list of floats, including
        the exploration bonus). ``total_reward`` is the sum over steps of the
        best reward in each group.
    """
    chosen_actions = []  # actions that advanced the real episode so far

    def _restore_env():
        # Bring the shared environment back to the state the current
        # observation was taken from.
        env_reset(task_id, seed=seed)
        for past_action in chosen_actions:
            env_step(past_action)

    obs = env_reset(task_id, seed=seed)
    trajectory = []
    done = False
    FastLanguageModel.for_inference(model)
    for step in range(CONFIG['max_steps_per_episode']):
        if done:
            break
        messages = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': observation_to_prompt(obs, task_id)}
        ]
        input_ids = tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True,
            return_tensors='pt'
        )
        input_ids = input_ids[:, -CONFIG['max_seq_length']:].to('cuda')

        # Generate all completions first - no env calls yet.
        group_completions, group_texts = [], []
        for _ in range(CONFIG['grpo_group_size']):
            with torch.no_grad():
                out = model.generate(
                    input_ids, max_new_tokens=100, temperature=0.9,
                    do_sample=True, pad_token_id=tokenizer.eos_token_id,
                )
            gen_ids = out[0][input_ids.shape[1]:]
            group_completions.append(gen_ids)
            group_texts.append(tokenizer.decode(gen_ids, skip_special_tokens=True))
        group_actions = [sanitize_action(parse_action(t)) for t in group_texts]

        # Score each candidate from the same pre-action state.
        group_rewards = []
        for action in group_actions:
            try:
                _restore_env()
                res = env_step(action)
                r = res.get('reward', 0.0)
            except Exception as exc:  # best effort: a failed call scores 0
                print(f" candidate scoring failed at step {step + 1}: {exc}")
                r = 0.0
            if action.get('action_type', 'noop') != 'noop':
                r += 0.02  # exploration bonus
            group_rewards.append(r)

        # Advance the real episode with the best action. The environment is
        # currently in the post-state of the last scored candidate, so it has
        # to be restored before stepping.
        best_idx = group_rewards.index(max(group_rewards))
        best_action = group_actions[best_idx]
        try:
            _restore_env()
            adv_res = env_step(best_action)
            obs = adv_res.get('observation', obs)
            done = adv_res.get('done', False)
            chosen_actions.append(best_action)
        except Exception as exc:
            print(f" could not advance episode at step {step + 1}: {exc}")
            done = True
        trajectory.append({
            'input_ids': input_ids,
            'completions': group_completions,
            'rewards': group_rewards,
        })
    total_reward = sum(max(s['rewards']) for s in trajectory) if trajectory else 0.0
    return trajectory, total_reward
def update_from_trajectory(trajectory):
    """Apply one optimizer step built from every step of an episode.

    For each trajectory step, rewards are normalised within the sampled group
    to get advantages (mean-centred, divided by the group std when it is
    non-zero). Only the highest-reward completion of each group contributes a
    gradient: its token-level negative log-likelihood is weighted by its
    advantage, and a KL penalty against the frozen ``ref_model`` is added with
    weight ``CONFIG['kl_coeff']``. Per-step losses are averaged over the
    trajectory length.

    Args:
        trajectory: List of step dicts as produced by ``run_episode_collect``
            (``input_ids``, ``completions``, ``rewards``).

    Returns:
        float: The averaged loss that was back-propagated, or 0.0 for an empty
        trajectory.
    """
    if not trajectory:
        return 0.0
    device = next(model.parameters()).device
    FastLanguageModel.for_training(model)
    model.train()
    optimizer.zero_grad()
    n_steps = len(trajectory)
    loss_value = 0.0
    for step_data in trajectory:
        input_ids = step_data['input_ids'].to(device)
        completions = step_data['completions']
        rewards = step_data['rewards']
        rewards_t = torch.tensor(rewards, dtype=torch.float32, device=device)
        centered = rewards_t - rewards_t.mean()
        spread = rewards_t.std()
        # A NaN std (single-element group) compares False and falls through to
        # the un-normalised branch.
        advantages = centered / (spread + 1e-8) if spread > 1e-8 else centered
        best_idx = rewards.index(max(rewards))
        best_ids = completions[best_idx].to(device)
        if best_ids.numel() == 0:
            # No completion tokens: every label would be masked and the
            # cross-entropy would be NaN.
            continue
        best_adv = advantages[best_idx]
        prompt_len = input_ids.shape[1]
        full_ids = torch.cat([input_ids[0], best_ids]).unsqueeze(0)
        labels = full_ids.clone()
        labels[0, :prompt_len] = -100  # score completion tokens only
        outputs = model(full_ids, labels=labels)
        # outputs.loss is the mean NEGATIVE log-likelihood of the completion.
        # Policy gradient minimises -A * log p == A * NLL, so a positive
        # advantage must make the completion more likely. The previous
        # `* (-best_adv)` pushed the best completion's probability down.
        policy_loss = outputs.loss * best_adv
        # KL penalty
        with torch.no_grad():
            ref_out = ref_model(full_ids)
        # Logits at position i predict token i+1, hence the shift by one.
        ref_logits = ref_out.logits[:, prompt_len - 1:-1, :]
        pol_logits = outputs.logits[:, prompt_len - 1:-1, :]
        # kl_div(log q, p) = KL(p || q) with p = reference and q = policy.
        # 'batchmean' divides by the batch size (1), so this is summed over
        # the completion tokens.
        kl = torch.nn.functional.kl_div(
            torch.log_softmax(pol_logits, dim=-1),
            torch.softmax(ref_logits, dim=-1),
            reduction='batchmean'
        )
        step_loss = (policy_loss + CONFIG['kl_coeff'] * kl) / n_steps
        if not torch.isfinite(step_loss):
            # Never back-propagate NaN/inf into the adapter weights.
            continue
        # Back-propagate per step: gradients accumulate exactly as with one
        # final backward on the summed loss, but each step's graph is freed
        # right away instead of being held for the whole episode.
        step_loss.backward()
        loss_value += step_loss.item()
    torch.nn.utils.clip_grad_norm_(
        [p for p in model.parameters() if p.requires_grad],
        CONFIG['max_grad_norm']
    )
    optimizer.step()
    scheduler.step()
    return loss_value
# ── Optimizer ─────────────────────────────────────────────────────────────────
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup
# Optimise only the parameters that require gradients.
_trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(
    _trainable_params,
    lr=CONFIG['learning_rate'],
    weight_decay=0.01,
)
# One scheduler step is taken per training episode (see update_from_trajectory),
# so the schedule length equals the total number of episodes, with roughly the
# first 10% used as warm-up.
total_eps = len(CONFIG['tasks']) * CONFIG['episodes_per_task']
_warmup_steps = max(1, total_eps // 10)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=_warmup_steps,
    num_training_steps=total_eps,
)
# ── Training Loop ─────────────────────────────────────────────────────────────
def run_training():
    """Run the full training schedule, then evaluate and optionally publish.

    For every task in ``CONFIG['tasks']`` this runs
    ``CONFIG['episodes_per_task']`` collect/update episodes, checkpointing the
    model, tokenizer and progress state after each one. Episodes already
    recorded in ``episode_scores`` (restored from ``state.json`` on resume) are
    skipped rather than re-run. Afterwards it re-evaluates each task on 5 fixed
    seeds and, when ``HF_TOKEN`` is set, pushes the merged model and the
    training log to ``CONFIG['hf_repo']``.

    Mutates the module-level ``global_ep``, ``training_log`` and
    ``episode_scores``.

    NOTE(review): optimizer and scheduler state are not part of the checkpoint,
    so the learning-rate schedule starts over after a resume.
    """
    global global_ep
    start_time = time.time()
    episodes_per_task = CONFIG['episodes_per_task']
    print("=" * 65)
    print("ARIA GRPO TRAINING β€” Llama-3.1-8B")
    print(f"LR={CONFIG['learning_rate']} | KL={CONFIG['kl_coeff']} | Groups={CONFIG['grpo_group_size']}")
    print(f"Strategy: fresh env per completion β†’ episode-level update")
    print("=" * 65)
    for task_id in CONFIG['tasks']:
        print(f"\nπŸ“‹ Task: {task_id.upper()} | Baseline: {baseline_scores[task_id]['avg']:.3f}")
        print("-" * 40)
        task_scores = episode_scores.setdefault(task_id, [])
        # Resume support: one score is stored per finished episode, so the
        # number of stored scores is the index of the next episode to run.
        first_ep = min(len(task_scores), episodes_per_task)
        if first_ep:
            print(f" Resuming {task_id}: skipping {first_ep} completed episode(s)")
        for ep in range(first_ep, episodes_per_task):
            seed = random.randint(0, 9999)
            trajectory, final_score = run_episode_collect(task_id, seed)
            loss = update_from_trajectory(trajectory)
            task_scores.append(final_score)
            global_ep += 1
            elapsed = (time.time() - start_time) / 60
            recent = task_scores[-10:]
            rolling = sum(recent) / len(recent) if recent else 0.0
            training_log.append({
                'episode': global_ep, 'task_id': task_id,
                'score': final_score, 'rolling_avg': rolling,
                'loss': loss, 'elapsed_min': round(elapsed, 1)
            })
            # Save a checkpoint after every episode. Only state.json is
            # written atomically (temp file + os.replace); the model and
            # tokenizer files are written in place.
            # CONFIG['save_every_n_episodes'] is not consulted here.
            try:
                latest_ckpt = f"{CONFIG['output_dir']}/latest"
                model.save_pretrained(latest_ckpt)
                tokenizer.save_pretrained(latest_ckpt)
                state = {
                    'global_ep': global_ep,
                    'training_log': training_log,
                    'episode_scores': episode_scores
                }
                tmp = f"{CONFIG['output_dir']}/state_tmp.json"
                final_path = f"{CONFIG['output_dir']}/state.json"
                with open(tmp, 'w') as f:
                    json.dump(state, f)
                os.replace(tmp, final_path)
            except Exception as e:
                # A failed save must not abort training.
                print(f"⚠️ Checkpoint save failed: {e}")
            if (ep + 1) % 5 == 0:
                delta = rolling - baseline_scores[task_id]['avg']
                trend = 'πŸ“ˆ' if delta > 0.02 else 'πŸ“‰' if delta < -0.02 else '➑️'
                print(
                    f" {trend} Ep {ep+1:3d}/{CONFIG['episodes_per_task']} | "
                    f"Score: {final_score:.3f} | Roll-10: {rolling:.3f} | "
                    f"vs baseline: {delta:+.3f} | Loss: {loss:.4f} | {elapsed:.0f}m"
                )
        # Guard against a task with no recorded episodes (episodes_per_task == 0).
        task_avg = sum(task_scores) / len(task_scores) if task_scores else 0.0
        base_avg = baseline_scores[task_id]['avg']
        delta = task_avg - base_avg
        result = 'βœ… IMPROVED' if delta > 0.02 else '⚠️ FLAT' if delta > -0.02 else '❌ DEGRADED'
        print(f"\n{result} {task_id}: {base_avg:.3f} β†’ {task_avg:.3f} ({delta:+.3f})")
        # Save training log after each task
        with open(f"{CONFIG['output_dir']}/training_log.json", 'w') as f:
            json.dump(training_log, f, indent=2)
    print(f"\nπŸŽ‰ Training complete! {(time.time()-start_time)/60:.0f} minutes")
    # Post-training eval (seeds 7, 20, 33, 46, 59 - different from the baseline seeds)
    FastLanguageModel.for_inference(model)
    print("\nPost-training evaluation...")
    for task_id in CONFIG['tasks']:
        scores = [run_episode(task_id, seed=i*13+7) for i in range(5)]
        avg = sum(scores) / len(scores)
        print(f" [{task_id}] {baseline_scores[task_id]['avg']:.3f} β†’ {avg:.3f} ({avg-baseline_scores[task_id]['avg']:+.3f})")
    # Push to Hub
    if HF_TOKEN:
        print(f"\nPushing to {CONFIG['hf_repo']}...")
        model.push_to_hub_merged(
            CONFIG['hf_repo'], tokenizer,
            save_method='merged_16bit', token=HF_TOKEN,
        )
        from huggingface_hub import HfApi
        api = HfApi()
        for fname in ['training_log.json']:
            fpath = f"{CONFIG['output_dir']}/{fname}"
            if os.path.exists(fpath):
                api.upload_file(
                    path_or_fileobj=fpath,
                    path_in_repo=fname,
                    repo_id=CONFIG['hf_repo'],
                    token=HF_TOKEN,
                )
        print(f"βœ… Model live: https://huggingface.co/{CONFIG['hf_repo']}")
# ── Entry Point ───────────────────────────────────────────────────────────────
import threading
from http.server import HTTPServer, BaseHTTPRequestHandler
def make_handler():
    """Build the request-handler class for the training status page.

    Returns:
        A ``BaseHTTPRequestHandler`` subclass that answers every GET/HEAD
        request with a small auto-refreshing HTML page summarising
        ``state.json`` (episode counter, last task, score, rolling average,
        elapsed minutes). Request logging is silenced.
    """
    import html  # local import: only needed to escape the status line

    class Handler(BaseHTTPRequestHandler):
        def do_GET(self):
            self._respond()

        def do_HEAD(self):
            self._respond()

        def _respond(self):
            """Render the status page; never raises on a bad state file."""
            state_file = f"{CONFIG['output_dir']}/state.json"
            try:
                if os.path.exists(state_file):
                    with open(state_file) as f:
                        state = json.load(f)
                    ep = state.get('global_ep', 0)
                    log = state.get('training_log', [])
                    last = log[-1] if log else {}
                    msg = (
                        f"Episode {ep}/{total_eps} | "
                        f"task={last.get('task_id','-')} | "
                        f"score={last.get('score',0):.3f} | "
                        f"roll10={last.get('rolling_avg',0):.3f} | "
                        f"elapsed={last.get('elapsed_min',0):.0f}m"
                    )
                else:
                    msg = "Starting up β€” model loading"
            except Exception as e:
                # The file may be mid-replacement or malformed; report it
                # instead of failing the request.
                msg = f"Running (state: {e})"
            body = (
                b"<!DOCTYPE html><html><head>"
                b"<meta http-equiv='refresh' content='20'>"
                b"<style>body{background:#0d1117;color:#10b981;"
                b"font-family:monospace;padding:40px;font-size:18px}"
                b"h1{color:#3b82f6}</style></head><body>"
                b"<h1>ARIA Training</h1><pre>" +
                # msg contains file-derived values and exception text; escape
                # it so it can never be interpreted as markup.
                html.escape(msg).encode('utf-8') +
                b"</pre><p style='color:#6b7280'>"
                b"Auto-refreshes every 20s</p></body></html>"
            )
            self.send_response(200)
            # The body is UTF-8 and may contain non-ASCII characters, so the
            # charset has to be declared.
            self.send_header('Content-Type', 'text/html; charset=utf-8')
            self.send_header('Content-Length', str(len(body)))
            self.end_headers()
            if self.command != 'HEAD':
                self.wfile.write(body)

        def log_message(self, *args):
            # Suppress the default per-request stderr logging.
            pass

    return Handler
if __name__ == "__main__":
    # Training runs in a daemon thread so the main thread is free to serve
    # the status page. The daemon flag means the thread does not keep the
    # process alive on its own.
    print("πŸš€ Starting training in background thread...")
    thread = threading.Thread(target=run_training, daemon=True)
    thread.start()

    print("🌐 Status server on port 7860...")
    bind_address = ('0.0.0.0', 7860)
    server = HTTPServer(bind_address, make_handler())
    print("βœ… Server ready")
    # Blocks for the lifetime of the process.
    server.serve_forever()