# Source: Hugging Face repository upload by guerilla7 (commit 06de37a, verified)
# This python script is the main fine-tuning script optimized for NVIDIA DGX Spark
# Copy and paste this code into a file named finetune_foundation_sec.py
import torch
import os
# Disable Triton compilation to avoid ARM64 issues
os.environ['TORCHDYNAMO_DISABLE'] = '1'
os.environ['TORCH_COMPILE_DISABLE'] = '1'
from unsloth import FastLanguageModel
from datasets import load_dataset
from trl import SFTTrainer
from transformers import TrainingArguments, TrainerCallback
import time
import sys
from datetime import datetime, timedelta
import json
# Progress tracking class
class ProgressCallback(TrainerCallback):
    """Console progress reporter for Trainer runs.

    On every logging step it prints a progress bar, the current loss and
    learning rate, elapsed time, an ETA, and a small ASCII trend graph of
    the last 20 recorded losses.
    """

    def __init__(self, total_steps):
        # total_steps: the trainer's max_steps, used for percent/ETA math.
        self.total_steps = total_steps
        self.start_time = None   # wall-clock time set in on_train_begin
        self.step_times = []     # kept for interface compatibility (unused here)
        self.losses = []         # every logged loss, feeds the trend graph
        self.last_update = 0     # last global_step we printed, to dedupe logs
        self.crashed = False     # toggled by the driver script on failure
        self._begin_step = 0     # global_step when this session started

    def on_train_begin(self, args, state, control, **kwargs):
        self.start_time = time.time()
        # Remember where this session started: when resuming from a
        # checkpoint, state.global_step is already > 0, and ETA must be
        # based only on steps completed *since the resume*, not since step 0.
        self._begin_step = state.global_step
        print("\n" + "="*80)
        print("Training Started: {}".format(datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
        print("="*80 + "\n")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is None:
            return
        current_step = state.global_step
        # Skip the step-0 log and repeated logs for the same step.
        if current_step == 0 or current_step == self.last_update:
            return
        self.last_update = current_step
        # Record metrics for the trend graph.
        if 'loss' in logs:
            self.losses.append(logs['loss'])
        # Overall progress, clamped so the bar never overflows.
        progress = min(current_step / self.total_steps, 1.0)
        # Defensive: some Trainer flows can emit a log before on_train_begin.
        if self.start_time is None:
            self.start_time = time.time()
        elapsed = time.time() - self.start_time
        # ETA from steps done in THIS session only (correct across resumes).
        steps_done = current_step - self._begin_step
        if steps_done > 0:
            avg_time_per_step = elapsed / steps_done
            remaining_steps = self.total_steps - current_step
            eta_seconds = avg_time_per_step * remaining_steps
            eta = str(timedelta(seconds=int(eta_seconds)))
        else:
            eta = "calculating..."
        # Progress bar (50 chars wide)
        bar_length = 50
        filled = int(bar_length * progress)
        bar = '█' * filled + '░' * (bar_length - filled)
        # Clear the current console line before printing a fresh update.
        sys.stdout.write('\r\033[K')
        progress_line = f"Progress: [{bar}] {progress*100:.1f}% | Step {current_step}/{self.total_steps}"
        print(progress_line)
        loss_str = f"{logs['loss']:.4f}" if 'loss' in logs else "N/A"
        lr_str = f"{logs['learning_rate']:.2e}" if 'learning_rate' in logs else "N/A"
        metrics_line = f"Loss: {loss_str} | LR: {lr_str} | Elapsed: {str(timedelta(seconds=int(elapsed)))} | ETA: {eta}"
        print(metrics_line)
        # Mini loss graph (last 20 steps)
        if len(self.losses) > 1:
            self._print_mini_graph()
        print()  # New line for next update

    def _print_mini_graph(self):
        """Print a simple ASCII graph of recent losses."""
        recent_losses = self.losses[-20:]  # Last 20 losses
        if len(recent_losses) < 2:
            return
        # Normalize losses into [0, 1] for plotting.
        min_loss = min(recent_losses)
        max_loss = max(recent_losses)
        range_loss = max_loss - min_loss if max_loss > min_loss else 1
        graph_height = 5
        graph = [[] for _ in range(graph_height)]
        for loss in recent_losses:
            normalized = (loss - min_loss) / range_loss
            level = int(normalized * (graph_height - 1))
            # Row 0 is the top of the plot, so invert the level.
            for i in range(graph_height):
                if i == (graph_height - 1 - level):
                    graph[i].append('●')
                else:
                    graph[i].append(' ')
        print("\nLoss Trend (last 20 steps):")
        print(f" {max_loss:.4f} ┤" + ''.join(graph[0]))
        for i in range(1, graph_height - 1):
            print(" │" + ''.join(graph[i]))
        print(f" {min_loss:.4f} └" + '─' * len(recent_losses))

    def on_train_end(self, args, state, control, **kwargs):
        # start_time may be None if training ended before on_train_begin fired.
        total_time = time.time() - self.start_time if self.start_time else 0.0
        print("\n" + "="*80)
        print("Training Completed: {}".format(datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
        print("="*80)
        print(f"Total training time: {str(timedelta(seconds=int(total_time)))}")
        print(f"Average time per step: {total_time/self.total_steps:.2f}s")
        if self.losses:
            print(f"Final loss: {self.losses[-1]:.4f}")
            print(f"Best loss: {min(self.losses):.4f}")
        print("="*80 + "\n")
print("="*80)
print("LLM-as-a-Judge Watchdog Training - Comprehensive Security & Evaluation")
print("NVIDIA DGX Spark - Unsloth Optimized Training")
print("="*80)
# Configuration optimized for DGX Spark (128 GB unified memory)
max_seq_length = 8192  # Foundation-Sec supports up to 64k
dtype = None  # Auto-detect (will use bfloat16 on DGX Spark)
load_in_4bit = True  # QLoRA for memory efficiency
print("\n[1/6] Loading Foundation-Sec-1.1-8B-Instruct model...")
print(f" - Max sequence length: {max_seq_length}")
print(f" - Quantization: 4-bit (QLoRA)")
# Load the model from Hugging Face.
# NOTE(review): first run downloads the 8B checkpoint from the Hub;
# subsequent runs use the local HF cache.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "fdtn-ai/Foundation-Sec-1.1-8B-Instruct",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)
print("✓ Model loaded successfully")
print("\n[2/6] Applying LoRA adapters for efficient fine-tuning...")
# Apply LoRA for parameter-efficient fine-tuning: only small adapter
# matrices on the listed projection layers are trained.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,  # LoRA rank (higher = more parameters, better quality)
    # Adapt all attention and MLP projections of the transformer blocks.
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,      # scaling factor (alpha/r = 1.0 here)
    lora_dropout = 0.05,  # light regularization on adapter activations
    bias = "none",
    use_gradient_checkpointing = "unsloth",  # Unsloth's memory optimization
    random_state = 3407,
)
print("✓ LoRA adapters applied")
print("\n[3/6] Loading and formatting training data...")
# Formatting function for Llama 3.1 chat template
def formatting_prompts_func(examples):
    """Format a batch of instruction/response pairs for SFT training.

    Takes a batched-dataset dict with "instruction" and "response" lists
    and returns a dict with a single "text" list holding each pair wrapped
    in the Llama 3.1 Instruct chat template (fixed security-analyst system
    prompt, then user turn, then assistant turn).
    """
    def render(instruction, response):
        # Llama 3.1 Instruct format
        return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a cybersecurity AI assistant specialized in analyzing agentic workflow executions for security threats and vulnerabilities. You have deep expertise in:
- Detecting multi-step attack patterns in autonomous AI systems
- Analyzing attack propagation through complex workflows
- Assessing the effectiveness of security guardrails
- Providing actionable security recommendations
Your analysis should be thorough, technically accurate, and focused on protecting enterprise agentic AI deployments.<|eot_id|><|start_header_id|>user<|end_header_id|>
{instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{response}<|eot_id|>"""

    pairs = zip(examples["instruction"], examples["response"])
    return {"text": [render(ins, res) for ins, res in pairs]}
# Load training dataset: a local JSONL file whose records carry
# "instruction" and "response" fields (consumed by formatting_prompts_func).
print("Loading dataset from: ./training_data_v3_synthetic.jsonl")
dataset = load_dataset('json', data_files='./training_data_v3_synthetic.jsonl', split='train')
# Batched map adds the chat-formatted "text" column used by SFTTrainer.
dataset = dataset.map(formatting_prompts_func, batched=True)
print(f"✓ Training dataset loaded: {len(dataset):,} examples")
print(f" Dataset size: {os.path.getsize('./training_data_v3_synthetic.jsonl') / (1024*1024):.1f} MB")
print("\n[4/6] Configuring training parameters...")
# Training configuration optimized for DGX Spark
max_training_steps = 1500  # Total training steps for comprehensive dataset
training_args = TrainingArguments(
    per_device_train_batch_size = 2,  # Batch size per device
    gradient_accumulation_steps = 4,  # Effective batch size = 2 * 4 = 8
    warmup_steps = 100,               # Warmup for stable training
    max_steps = max_training_steps,   # Total training steps
    learning_rate = 2e-4,             # Learning rate for AdamW
    # Pick exactly one half-precision mode based on hardware capability.
    fp16 = not torch.cuda.is_bf16_supported(),
    bf16 = torch.cuda.is_bf16_supported(),  # Use BF16 on DGX Spark
    logging_steps = 1,       # Log every step so ProgressCallback updates each step
    optim = "adamw_8bit",    # 8-bit Adam for memory efficiency
    weight_decay = 0.01,     # Regularization
    lr_scheduler_type = "linear",  # Linear learning rate decay
    seed = 3407,
    output_dir = "./outputs",
    save_strategy = "steps",
    save_steps = 250,        # Save checkpoint every 250 steps (enables resume)
    save_total_limit = 3,    # Keep only 3 most recent checkpoints
    report_to = "none",      # Disable W&B/tensorboard
    disable_tqdm = True,     # Disable default tqdm (we have custom progress)
)
print("✓ Training configuration set")
print(f" - Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f" - Total steps: {training_args.max_steps:,}")
print(f" - Learning rate: {training_args.learning_rate}")
print(f" - Dataset: Security (63%) + Judge (37%) = {len(dataset):,} examples")
print(f" - Estimated time: 4-6 hours (~10-15 sec/step)")
print("\n[5/6] Initializing SFTTrainer with progress tracking...")
# Create the custom progress callback sized to the configured step budget.
progress_callback = ProgressCallback(total_steps=max_training_steps)
# Create the supervised fine-tuning trainer over the formatted "text" column.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",       # column produced by formatting_prompts_func
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,              # parallel tokenization workers
    packing = False,                   # Disable packing for clearer learning
    args = training_args,
    callbacks = [progress_callback],   # custom console progress reporting
)
print("✓ Trainer initialized with progress monitoring")
print("\n[6/6] Starting fine-tuning...")
# Resume support: look for the newest checkpoint left by a previous run.
# Checkpoint directories are named "checkpoint-<step>" under ./outputs.
checkpoint_dir = None
if os.path.exists("./outputs"):
    checkpoints = [d for d in os.listdir("./outputs") if d.startswith("checkpoint-")]
    if checkpoints:
        # Latest checkpoint = highest step number in the directory name.
        latest_checkpoint = max(checkpoints, key=lambda x: int(x.split("-")[1]))
        checkpoint_dir = os.path.join("./outputs", latest_checkpoint)
        checkpoint_step = int(latest_checkpoint.split("-")[1])
        print("="*80)
        print("🔄 RESUMING FROM CHECKPOINT")
        print("="*80)
        print(f"Found checkpoint: {latest_checkpoint}")
        print(f"Resuming from step: {checkpoint_step:,}/{max_training_steps:,}")
        print(f"Remaining steps: {max_training_steps - checkpoint_step:,}")
        print(f"Progress saved: {checkpoint_step/max_training_steps*100:.1f}%")
        print("="*80 + "\n")
if checkpoint_dir is None:
    # Fresh run (no ./outputs dir, or it holds no checkpoints): show the
    # start-of-training banner.
    print("="*80)
    print("Training LLM-as-a-Judge Watchdog Model")
    print(f"Total examples: {len(dataset):,} | Steps: {max_training_steps:,}")
    print("Estimated duration: 4-6 hours (~10-15 sec/step)")
    print("Monitor GPU: nvidia-smi")
    print("="*80)
# Train the model with error handling. Both failure paths point the user at
# the newest checkpoint so the run can be resumed by re-executing the script.
try:
    if checkpoint_dir:
        # Continue from the checkpoint discovered above.
        trainer_stats = trainer.train(resume_from_checkpoint=checkpoint_dir)
    else:
        trainer_stats = trainer.train()
    progress_callback.crashed = False
except KeyboardInterrupt:
    # User pressed Ctrl-C: report where progress was last saved, then re-raise.
    print("\n\n" + "="*80)
    print("⚠️ TRAINING INTERRUPTED BY USER")
    print("="*80)
    # Find latest checkpoint
    if os.path.exists("./outputs"):
        checkpoints = [d for d in os.listdir("./outputs") if d.startswith("checkpoint-")]
        if checkpoints:
            latest = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))[-1]
            step = int(latest.split("-")[1])
            print(f"\n✓ Progress saved up to step {step:,}/{max_training_steps:,}")
            print(f"✓ Checkpoint: ./outputs/{latest}")
            print(f"\n🔄 To resume: Just run this script again")
            print(f" Progress will automatically resume from step {step:,}")
        else:
            print("\n⚠️ No checkpoints found. Training was in early stages.")
    print("="*80 + "\n")
    progress_callback.crashed = True
    raise
except Exception as e:
    # Any other failure: report, dump a traceback to training_error.log,
    # then re-raise so the process exits non-zero.
    print("\n\n" + "="*80)
    print("❌ TRAINING FAILED - ERROR DETECTED")
    print("="*80)
    print(f"Error: {str(e)}")
    # Check for saved checkpoints
    if os.path.exists("./outputs"):
        checkpoints = [d for d in os.listdir("./outputs") if d.startswith("checkpoint-")]
        if checkpoints:
            latest = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))[-1]
            step = int(latest.split("-")[1])
            print(f"\n✓ Progress saved up to step {step:,}/{max_training_steps:,}")
            print(f"✓ You can resume from: ./outputs/{latest}")
    print("\nError details saved to: training_error.log")
    with open("training_error.log", "w") as f:
        f.write(f"Training failed at: {datetime.now()}\n")
        f.write(f"Error: {str(e)}\n")
        import traceback
        f.write(traceback.format_exc())
    print("="*80 + "\n")
    progress_callback.crashed = True
    raise
print("\n" + "="*80)
print("Fine-tuning completed!")
print("="*80)
print("\n[Saving] Saving fine-tuned model...")
# Save LoRA adapters only (small — just the trained adapter weights).
model.save_pretrained("./agentic-safety-foundation-sec-lora")
tokenizer.save_pretrained("./agentic-safety-foundation-sec-lora")
print("✓ LoRA adapters saved to: ./agentic-safety-foundation-sec-lora")
# Save merged model (optional - full precision): folds the LoRA deltas
# back into the base weights for standalone deployment.
print("\n[Saving] Merging and saving full model...")
model.save_pretrained_merged(
    "./agentic-safety-foundation-sec-merged",
    tokenizer,
    save_method = "merged_16bit",  # Save in 16-bit for quality
)
print("✓ Merged model saved to: ./agentic-safety-foundation-sec-merged")
print("\n" + "="*80)
print("Training Statistics:")
print("="*80)
# trainer_stats is the TrainOutput returned by trainer.train() above.
print(trainer_stats)
print("\n✓ All outputs saved successfully!")
print("\nCheckpoint Information:")
print(" - Checkpoints saved every 250 steps to: ./outputs/")
print(" - Last 3 checkpoints are kept automatically")
print(" - To resume interrupted training: Just run this script again")
print("\nNext steps:")
print("1. Test the model with: python test_model.py")
print("2. Convert to GGUF for deployment (optional)")
print("3. Deploy with production_inference.py")
# Save the file and run it
# python finetune_foundation_sec.py