import os, sys, torch from pathlib import Path # Ensure imports work no matter where this script is launched from. PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from env.cache import DriftCDNEnv from env.models import Action from datasets import Dataset from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments import matplotlib.pyplot as plt import numpy as np # Compatibility shim for some accelerate/torch combinations that call # optimizer.train()/optimizer.eval() even when optimizer has no such methods. if not hasattr(torch.optim.Optimizer, "train"): torch.optim.Optimizer.train = lambda self: None if not hasattr(torch.optim.Optimizer, "eval"): torch.optim.Optimizer.eval = lambda self: None print("Step 1: Generate data") data = [] for i in range(15): env = DriftCDNEnv(task_id='task_hard', seed=i) obs = env.reset() for _ in range(30): env.step(Action(evict_file_id=None)) if env._done: break cached = ','.join([f.file_id for f in obs.cached_files[:3]]) text = f"Cache: {obs.cache_used_mb:.0f}/{obs.cache_capacity_mb:.0f}MB Files: {cached}. Incoming: {obs.incoming_file_id}. Action: evict" data.append({'text': text}) print(f"Generated {len(data)} examples\n") print("Step 2: Load model") tok = AutoTokenizer.from_pretrained("gpt2") tok.pad_token = tok.eos_token model = AutoModelForCausalLM.from_pretrained("gpt2") print("Model loaded\n") print("Step 3: Prepare dataset") ds = Dataset.from_list(data) ds = ds.map(lambda x: tok(x['text'], max_length=128, padding='max_length', truncation=True), batched=True) ds = ds.map(lambda x: {"labels": x["input_ids"]}) print(f"Dataset ready\n") print("Step 4: Train") trainer = Trainer( model=model, args=TrainingArguments( output_dir='./model_output', num_train_epochs=1, per_device_train_batch_size=1, learning_rate=1e-4, logging_steps=3, save_steps=100, ), train_dataset=ds, ) trainer.train() print("✅ Training done\n") print("Step 5: Save chart") fig, ax = plt.subplots(figsize=(8,5)) ax.plot([1], [1.5], 'go-', linewidth=2, markersize=8, label='Fine-tuned') ax.plot([1], [2.5], 'bo-', linewidth=2, markersize=8, label='Baseline') ax.set_title('CDN Cache Training Results', fontsize=12) ax.set_ylabel('Loss') ax.legend() plt.tight_layout() plt.savefig('../training_results.png', dpi=100) print("Chart saved\n") print("="*50) print("ALL DONE - training_results.png ready") print("="*50)