| |
| |
|
|
import torch
import os

# Disable TorchDynamo / torch.compile via env vars BEFORE importing unsloth,
# presumably so no compilation path is triggered at import time — TODO confirm
# this ordering requirement against the Unsloth docs.
os.environ['TORCHDYNAMO_DISABLE'] = '1'
os.environ['TORCH_COMPILE_DISABLE'] = '1'

from unsloth import FastLanguageModel
from datasets import load_dataset
from trl import SFTTrainer
from transformers import TrainingArguments, TrainerCallback
import time
import sys
from datetime import datetime, timedelta
import json
|
|
| |
class ProgressCallback(TrainerCallback):
    """Console progress reporter for a Hugging Face ``Trainer`` run.

    On every logging event it renders: a Unicode progress bar, the current
    loss and learning rate, elapsed time, an ETA estimate, and a small
    ASCII chart of the recent loss trend.
    """

    def __init__(self, total_steps):
        self.total_steps = total_steps  # planned number of optimizer steps
        self.start_time = None          # wall-clock start; set in on_train_begin
        self.step_times = []            # reserved for per-step timings (currently unused)
        self.losses = []                # every logged 'loss' value, in order
        self.last_update = 0            # last step rendered; dedupes repeated logs
        self.crashed = False            # flipped by the driver script on failure

    def on_train_begin(self, args, state, control, **kwargs):
        """Record the start time and print the opening banner."""
        self.start_time = time.time()
        print("\n" + "="*80)
        print("Training Started: {}".format(datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
        print("="*80 + "\n")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Render the progress bar, metrics line, and loss trend for one log event."""
        if logs is None:
            return

        current_step = state.global_step
        # Skip the pre-training log (step 0) and duplicate logs for the same step.
        if current_step == 0 or current_step == self.last_update:
            return

        self.last_update = current_step

        if 'loss' in logs:
            self.losses.append(logs['loss'])

        # FIX: guard the denominator so total_steps == 0 cannot raise
        # ZeroDivisionError, and tolerate a log event arriving before
        # on_train_begin has recorded a start time.
        total = self.total_steps if self.total_steps > 0 else 1
        progress = current_step / total
        started = self.start_time if self.start_time is not None else time.time()
        elapsed = time.time() - started

        # ETA from the average wall time per completed step.
        # (current_step is always > 0 here thanks to the early return above;
        # the else branch is kept only as a defensive fallback.)
        if current_step > 0:
            avg_time_per_step = elapsed / current_step
            remaining_steps = total - current_step
            eta_seconds = avg_time_per_step * remaining_steps
            eta = str(timedelta(seconds=int(eta_seconds)))
        else:
            eta = "calculating..."

        # 50-character Unicode progress bar.
        bar_length = 50
        filled = int(bar_length * progress)
        bar = '█' * filled + '░' * (bar_length - filled)

        # Clear the current terminal line before printing fresh output.
        sys.stdout.write('\r\033[K')

        progress_line = f"Progress: [{bar}] {progress*100:.1f}% | Step {current_step}/{self.total_steps}"
        print(progress_line)

        loss_str = f"{logs.get('loss', 0):.4f}" if 'loss' in logs else "N/A"
        lr_str = f"{logs.get('learning_rate', 0):.2e}" if 'learning_rate' in logs else "N/A"

        metrics_line = f"Loss: {loss_str} | LR: {lr_str} | Elapsed: {str(timedelta(seconds=int(elapsed)))} | ETA: {eta}"
        print(metrics_line)

        # Need at least two points for a trend line.
        if len(self.losses) > 1:
            self._print_mini_graph()

        print()

    def _print_mini_graph(self):
        """Print a simple ASCII graph of recent losses"""
        recent_losses = self.losses[-20:]
        if len(recent_losses) < 2:
            return

        # Normalize losses into [0, 1] for vertical placement; a flat series
        # (max == min) uses a range of 1 so every point lands on one row.
        min_loss = min(recent_losses)
        max_loss = max(recent_losses)
        range_loss = max_loss - min_loss if max_loss > min_loss else 1

        graph_height = 5
        graph = [[] for _ in range(graph_height)]

        # One column per loss value; row 0 is the top (max) of the chart.
        for loss in recent_losses:
            normalized = (loss - min_loss) / range_loss
            level = int(normalized * (graph_height - 1))

            for i in range(graph_height):
                if i == (graph_height - 1 - level):
                    graph[i].append('●')
                else:
                    graph[i].append(' ')

        print("\nLoss Trend (last 20 steps):")
        print(f"  {max_loss:.4f} ┤" + ''.join(graph[0]))
        for i in range(1, graph_height - 1):
            print("         │" + ''.join(graph[i]))
        print(f"  {min_loss:.4f} └" + '─' * len(recent_losses))

    def on_train_end(self, args, state, control, **kwargs):
        """Print the closing banner with timing and loss summary."""
        # FIX: tolerate on_train_end firing without a recorded start time
        # (e.g. training aborted before on_train_begin) instead of raising
        # a TypeError on ``time.time() - None``.
        started = self.start_time if self.start_time is not None else time.time()
        total_time = time.time() - started

        print("\n" + "="*80)
        print("Training Completed: {}".format(datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
        print("="*80)
        print(f"Total training time: {str(timedelta(seconds=int(total_time)))}")
        # FIX: avoid ZeroDivisionError when total_steps == 0.
        print(f"Average time per step: {total_time/max(self.total_steps, 1):.2f}s")
        if self.losses:
            print(f"Final loss: {self.losses[-1]:.4f}")
            print(f"Best loss: {min(self.losses):.4f}")
        print("="*80 + "\n")
|
|
# Run banner.
print("="*80)
print("LLM-as-a-Judge Watchdog Training - Comprehensive Security & Evaluation")
print("NVIDIA DGX Spark - Unsloth Optimized Training")
print("="*80)

# Model-loading configuration.
max_seq_length = 8192   # max token length per training sequence (also passed to SFTTrainer)
dtype = None            # None = let the loader pick the dtype automatically — TODO confirm Unsloth default
load_in_4bit = True     # QLoRA: base weights loaded 4-bit quantized

print("\n[1/6] Loading Foundation-Sec-1.1-8B-Instruct model...")
print(f" - Max sequence length: {max_seq_length}")
print(f" - Quantization: 4-bit (QLoRA)")

# Load the base model and its matching tokenizer via Unsloth.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "fdtn-ai/Foundation-Sec-1.1-8B-Instruct",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

print("✓ Model loaded successfully")
|
|
print("\n[2/6] Applying LoRA adapters for efficient fine-tuning...")

# Wrap the base model with LoRA adapters; only the low-rank adapter
# matrices are trained, the 4-bit base weights stay frozen.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,                                  # LoRA rank
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,                         # LoRA scaling factor (alpha == r here)
    lora_dropout = 0.05,
    bias = "none",                           # bias terms are not trained
    use_gradient_checkpointing = "unsloth",  # Unsloth's memory-saving checkpointing mode
    random_state = 3407,                     # same seed as TrainingArguments below
)

print("✓ LoRA adapters applied")
|
|
print("\n[3/6] Loading and formatting training data...")


def formatting_prompts_func(examples):
    """Render batched instruction/response pairs as Llama-3 chat strings.

    ``examples`` is a batch mapping with parallel "instruction" and
    "response" lists (as supplied by ``Dataset.map(..., batched=True)``);
    returns ``{"text": [...]}`` with one formatted string per pair.
    """
    rendered = []
    for user_turn, assistant_turn in zip(examples["instruction"], examples["response"]):
        rendered.append(
            f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a cybersecurity AI assistant specialized in analyzing agentic workflow executions for security threats and vulnerabilities. You have deep expertise in:
- Detecting multi-step attack patterns in autonomous AI systems
- Analyzing attack propagation through complex workflows
- Assessing the effectiveness of security guardrails
- Providing actionable security recommendations

Your analysis should be thorough, technically accurate, and focused on protecting enterprise agentic AI deployments.<|eot_id|><|start_header_id|>user<|end_header_id|>

{user_turn}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{assistant_turn}<|eot_id|>"""
        )
    return {"text": rendered}
|
|
| |
| print("Loading dataset from: ./training_data_v3_synthetic.jsonl") |
| dataset = load_dataset('json', data_files='./training_data_v3_synthetic.jsonl', split='train') |
| dataset = dataset.map(formatting_prompts_func, batched=True) |
|
|
| print(f"✓ Training dataset loaded: {len(dataset):,} examples") |
| print(f" Dataset size: {os.path.getsize('./training_data_v3_synthetic.jsonl') / (1024*1024):.1f} MB") |
|
|
print("\n[4/6] Configuring training parameters...")

# Hard cap on optimizer steps (max_steps overrides any epoch count).
max_training_steps = 1500

training_args = TrainingArguments(
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 4,            # effective batch size = 2 * 4 = 8
    warmup_steps = 100,
    max_steps = max_training_steps,
    learning_rate = 2e-4,
    fp16 = not torch.cuda.is_bf16_supported(),  # fall back to fp16 only where bf16 is unavailable
    bf16 = torch.cuda.is_bf16_supported(),
    logging_steps = 1,                          # log every step so ProgressCallback updates continuously
    optim = "adamw_8bit",                       # 8-bit AdamW to reduce optimizer memory
    weight_decay = 0.01,
    lr_scheduler_type = "linear",
    seed = 3407,
    output_dir = "./outputs",
    save_strategy = "steps",
    save_steps = 250,                           # checkpoint cadence; enables the auto-resume logic below
    save_total_limit = 3,                       # keep only the 3 newest checkpoints
    report_to = "none",                         # no external experiment trackers
    disable_tqdm = True,                        # tqdm off; ProgressCallback renders progress instead
)

print("✓ Training configuration set")
print(f" - Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f" - Total steps: {training_args.max_steps:,}")
print(f" - Learning rate: {training_args.learning_rate}")
print(f" - Dataset: Security (63%) + Judge (37%) = {len(dataset):,} examples")
print(f" - Estimated time: 4-6 hours (~10-15 sec/step)")
|
|
print("\n[5/6] Initializing SFTTrainer with progress tracking...")

# Callback that prints progress bar / metrics / ETA on every log step.
progress_callback = ProgressCallback(total_steps=max_training_steps)

# TRL supervised fine-tuning trainer over the formatted "text" column.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",     # column produced by formatting_prompts_func
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,            # parallel workers for dataset preprocessing
    packing = False,                 # one example per sequence; no sample packing
    args = training_args,
    callbacks = [progress_callback],
)

print("✓ Trainer initialized with progress monitoring")
|
|
print("\n[6/6] Starting fine-tuning...")

# Auto-resume: pick the highest-numbered checkpoint-<step> directory under
# ./outputs (written every `save_steps` steps) and resume from it if present.
checkpoint_dir = None
if os.path.exists("./outputs"):
    checkpoints = [d for d in os.listdir("./outputs") if d.startswith("checkpoint-")]
    if checkpoints:
        # Newest checkpoint by numeric step suffix ("checkpoint-1250" -> 1250).
        latest_checkpoint = max(checkpoints, key=lambda name: int(name.split("-")[1]))
        checkpoint_dir = os.path.join("./outputs", latest_checkpoint)
        checkpoint_step = int(latest_checkpoint.split("-")[1])

        print("="*80)
        print("🔄 RESUMING FROM CHECKPOINT")
        print("="*80)
        print(f"Found checkpoint: {latest_checkpoint}")
        print(f"Resuming from step: {checkpoint_step:,}/{max_training_steps:,}")
        print(f"Remaining steps: {max_training_steps - checkpoint_step:,}")
        print(f"Progress saved: {checkpoint_step/max_training_steps*100:.1f}%")
        print("="*80 + "\n")

if checkpoint_dir is None:
    # FIX: the fresh-run banner was previously attached to the inner
    # `if checkpoints:` branch, so it was silently skipped when ./outputs
    # existed but contained no checkpoints. Print it whenever not resuming.
    print("="*80)
    print("Training LLM-as-a-Judge Watchdog Model")
    print(f"Total examples: {len(dataset):,} | Steps: {max_training_steps:,}")
    print("Estimated duration: 4-6 hours (~10-15 sec/step)")
    print("Monitor GPU: nvidia-smi")
    print("="*80)
|
|
| |
| try: |
| if checkpoint_dir: |
| trainer_stats = trainer.train(resume_from_checkpoint=checkpoint_dir) |
| else: |
| trainer_stats = trainer.train() |
| progress_callback.crashed = False |
| except KeyboardInterrupt: |
| print("\n\n" + "="*80) |
| print("⚠️ TRAINING INTERRUPTED BY USER") |
| print("="*80) |
| |
| |
| if os.path.exists("./outputs"): |
| checkpoints = [d for d in os.listdir("./outputs") if d.startswith("checkpoint-")] |
| if checkpoints: |
| latest = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))[-1] |
| step = int(latest.split("-")[1]) |
| print(f"\n✓ Progress saved up to step {step:,}/{max_training_steps:,}") |
| print(f"✓ Checkpoint: ./outputs/{latest}") |
| print(f"\n🔄 To resume: Just run this script again") |
| print(f" Progress will automatically resume from step {step:,}") |
| else: |
| print("\n⚠️ No checkpoints found. Training was in early stages.") |
| print("="*80 + "\n") |
| progress_callback.crashed = True |
| raise |
| except Exception as e: |
| print("\n\n" + "="*80) |
| print("❌ TRAINING FAILED - ERROR DETECTED") |
| print("="*80) |
| print(f"Error: {str(e)}") |
| |
| |
| if os.path.exists("./outputs"): |
| checkpoints = [d for d in os.listdir("./outputs") if d.startswith("checkpoint-")] |
| if checkpoints: |
| latest = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))[-1] |
| step = int(latest.split("-")[1]) |
| print(f"\n✓ Progress saved up to step {step:,}/{max_training_steps:,}") |
| print(f"✓ You can resume from: ./outputs/{latest}") |
| |
| print("\nError details saved to: training_error.log") |
| with open("training_error.log", "w") as f: |
| f.write(f"Training failed at: {datetime.now()}\n") |
| f.write(f"Error: {str(e)}\n") |
| import traceback |
| f.write(traceback.format_exc()) |
| print("="*80 + "\n") |
| progress_callback.crashed = True |
| raise |
|
|
print("\n" + "="*80)
print("Fine-tuning completed!")
print("="*80)

print("\n[Saving] Saving fine-tuned model...")

# Save only the LoRA adapter weights (small; needs the base model at load time).
model.save_pretrained("./agentic-safety-foundation-sec-lora")
tokenizer.save_pretrained("./agentic-safety-foundation-sec-lora")

print("✓ LoRA adapters saved to: ./agentic-safety-foundation-sec-lora")

# Also export a standalone model with the adapters merged into the base
# weights at 16-bit precision (larger, but loadable without PEFT).
print("\n[Saving] Merging and saving full model...")
model.save_pretrained_merged(
    "./agentic-safety-foundation-sec-merged",
    tokenizer,
    save_method = "merged_16bit",
)

print("✓ Merged model saved to: ./agentic-safety-foundation-sec-merged")

print("\n" + "="*80)
print("Training Statistics:")
print("="*80)
print(trainer_stats)  # summary object returned by trainer.train()

print("\n✓ All outputs saved successfully!")
print("\nCheckpoint Information:")
print(" - Checkpoints saved every 250 steps to: ./outputs/")
print(" - Last 3 checkpoints are kept automatically")
print(" - To resume interrupted training: Just run this script again")
print("\nNext steps:")
print("1. Test the model with: python test_model.py")
print("2. Convert to GGUF for deployment (optional)")
print("3. Deploy with production_inference.py")
|
|
| |
| |
|
|