#!/usr/bin/env python3
"""
Unsloth-accelerated LoRA training for collusion model organisms.

Drop-in replacement for train_local.py using Unsloth's FastLanguageModel for
~2x speedup on B200. Same config YAML format, same data format, same manual
Llama 3.3 chat template.

Key differences from train_local.py:
  - FastLanguageModel instead of AutoModelForCausalLM
  - Unsloth gradient checkpointing (30% less VRAM)
  - No DeepSpeed/accelerate needed (single GPU)
  - Larger micro-batch (8 vs 2) thanks to VRAM savings

Usage:
    python3 experiments/260409_unsloth_training/scripts/train_unsloth.py \\
        --config experiments/260409_unsloth_training/configs/example.yaml
"""

import argparse
import json
import os
import random
import shutil
import sys
from pathlib import Path

import torch
import yaml
from datasets import Dataset
from unsloth import FastLanguageModel

PROJECT_ROOT = Path(__file__).resolve().parents[3]
EXPERIMENT_DIR = Path(__file__).resolve().parent.parent

# Default preamble — matches tokenizer.apply_chat_template() output when no
# system message is supplied.
DEFAULT_SYSTEM = "Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n"

# Llama projection layers used both as the default LoRA targets and as the
# expansion of the "all-linear" shorthand.
LLAMA_LINEAR_MODULES = (
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
)


# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

def load_config(config_path: Path) -> dict:
    """Parse the YAML training config at *config_path* with the safe loader."""
    with open(config_path, encoding="utf-8") as f:
        return yaml.safe_load(f)


def resolve_path(path_str: str) -> Path:
    """Return *path_str* as-is if absolute, else relative to PROJECT_ROOT."""
    p = Path(path_str)
    if p.is_absolute():
        return p
    return PROJECT_ROOT / p


# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------

def load_jsonl(path: Path) -> list[dict]:
    """Read one JSON object per non-blank line of *path*.

    Raises:
        ValueError: if a line is not valid JSON; the message names the file
            and 1-based line number.
    """
    samples = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                samples.append(json.loads(line))
            except json.JSONDecodeError as err:
                raise ValueError(f"{path}:{lineno}: invalid JSON: {err}") from err
    return samples


# ---------------------------------------------------------------------------
# Manual Llama 3.3 chat template
# ---------------------------------------------------------------------------

def build_chat_text(messages: list[dict]) -> tuple[str, str]:
    """
    Build manual Llama 3.3 chat template.

    Returns (prompt_text, full_text) where:
      - prompt_text = everything through 'assistant<|end_header_id|>\\n\\n'
      - full_text   = prompt_text + assistant_content + '<|eot_id|>'

    Only a single system/user/assistant turn is rendered: if a role appears
    more than once, the last message with that role wins.

    When no system message is present, injects the default Llama 3.3 preamble
    ("Cutting Knowledge Date: December 2023\\nToday Date: 26 Jul 2024") to
    match what apply_chat_template() produces at eval time.

    NOTE(review): the stock HF Llama 3.1/3.3 template appears to emit the date
    preamble even when a system message IS supplied (and trims message content);
    this manual template does not. Kept as-is to stay identical to train_local.py
    — confirm against the eval-time template if your data has system messages.

    Raises:
        ValueError: if the user or assistant message is missing.
    """
    system_content = None
    user_content = None
    assistant_content = None
    for msg in messages:
        role = msg["role"]
        if role == "system":
            system_content = msg["content"]
        elif role == "user":
            user_content = msg["content"]
        elif role == "assistant":
            assistant_content = msg["content"]

    if system_content is None:
        system_content = DEFAULT_SYSTEM
    # Explicit raises (not assert) so validation survives `python -O`.
    if user_content is None:
        raise ValueError("Missing user message")
    if assistant_content is None:
        raise ValueError("Missing assistant message")

    prompt_text = (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        f"{system_content}<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{user_content}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    full_text = prompt_text + assistant_content + "<|eot_id|>"
    return prompt_text, full_text


def tokenize_chat(sample: dict, tokenizer, max_seq_length: int = 4096) -> dict:
    """
    Tokenize a chat sample with manual template.

    Labels are -100 for prompt tokens — only assistant response gets loss.
    Both prompt and full text are truncated to *max_seq_length*; when the full
    text is truncated, the trailing <|eot_id|> is dropped.
    """
    messages = sample["messages"]
    prompt_text, full_text = build_chat_text(messages)

    prompt_ids = tokenizer(
        prompt_text, add_special_tokens=False, truncation=True, max_length=max_seq_length
    )["input_ids"]
    full_encoding = tokenizer(
        full_text, add_special_tokens=False, truncation=True, max_length=max_seq_length
    )

    # Assumes prompt_ids is a token-level prefix of full_ids (i.e. the tokenizer
    # does not merge across the prompt/response boundary) — TODO confirm for
    # responses that begin with whitespace/newlines.
    prompt_len = len(prompt_ids)
    full_ids = full_encoding["input_ids"]
    labels = [-100] * prompt_len + full_ids[prompt_len:]

    return {
        "input_ids": full_ids,
        "attention_mask": full_encoding["attention_mask"],
        "labels": labels,
    }


def build_dataset(samples: list[dict], tokenizer, max_seq_length: int = 4096) -> Dataset:
    """Tokenize every sample into a Dataset; exit(1) on any bad sample."""
    tokenized = []
    for i, sample in enumerate(samples):
        try:
            tok = tokenize_chat(sample, tokenizer, max_seq_length=max_seq_length)
        except Exception as e:
            print(f"FATAL: tokenizing sample {i}: {e}")
            sys.exit(1)
        # Guard: if all labels are -100, the assistant response was truncated away
        if all(label == -100 for label in tok["labels"]):
            print(f"FATAL: sample {i} has all labels masked (-100) — prompt alone exceeds max_seq_length={max_seq_length}")
            sys.exit(1)
        tokenized.append(tok)

    return Dataset.from_dict(
        {
            "input_ids": [t["input_ids"] for t in tokenized],
            "attention_mask": [t["attention_mask"] for t in tokenized],
            "labels": [t["labels"] for t in tokenized],
        }
    )


# ---------------------------------------------------------------------------
# Output directory
# ---------------------------------------------------------------------------

def derive_output_dir(wandb_run_name: str) -> Path:
    """Derive output dir from wandb_run_name, stripping '-local' suffix."""
    return EXPERIMENT_DIR / "output" / wandb_run_name.removesuffix("-local")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> int:
    """Run one Unsloth LoRA training job; return a process exit code."""
    parser = argparse.ArgumentParser(description="Unsloth LoRA training")
    parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
    args = parser.parse_args()

    # ------------------------------------------------------------------
    # Load config
    # ------------------------------------------------------------------
    config_path = resolve_path(args.config)
    if not config_path.exists():
        print(f"FATAL: Config not found: {config_path}")
        return 1
    config = load_config(config_path)

    # ------------------------------------------------------------------
    # Extract config values
    # ------------------------------------------------------------------
    model_name = config["model"]["name"]
    data_path = config["data"]["path"]

    training_cfg = config["training"]
    epochs = training_cfg["epochs"]
    batch_size = training_cfg["batch_size"]
    gradient_accumulation_steps = training_cfg.get("gradient_accumulation_steps", 1)
    learning_rate = float(training_cfg["learning_rate"])
    lora_seed = training_cfg.get("lora_seed")
    # .get() so a missing key reaches the explicit FATAL check below instead
    # of surfacing as a bare KeyError traceback.
    shuffle_seed = training_cfg.get("shuffle_seed")
    adapter_path = training_cfg.get("adapter_path")
    max_steps = training_cfg.get("max_steps", -1)
    max_seq_length = training_cfg.get("max_seq_length", 4096)

    lora_cfg = config["lora"]
    lora_rank = lora_cfg["rank"]
    lora_alpha = lora_cfg.get("alpha", 64)
    lora_dropout = lora_cfg.get("dropout", 0.0)
    target_modules = lora_cfg.get("target_modules", list(LLAMA_LINEAR_MODULES))
    if target_modules == "all-linear":
        target_modules = list(LLAMA_LINEAR_MODULES)

    logging_cfg = config["logging"]
    wandb_project = logging_cfg["wandb_project"]
    wandb_run_name = logging_cfg["wandb_run_name"]
    require_wandb = logging_cfg.get("require_wandb", False)
    log_every = logging_cfg.get("log_every_n_steps", 1)
    save_every = logging_cfg.get("save_every_n_steps", 500)

    output_dir = str(derive_output_dir(wandb_run_name))
    is_continuation = adapter_path is not None

    # ------------------------------------------------------------------
    # Validate
    # ------------------------------------------------------------------
    if training_cfg.get("resume_from"):
        print("FATAL: resume_from is not supported in unsloth training. Use adapter_path for continuation.")
        return 1

    if lora_seed is None and not is_continuation:
        print("FATAL: training.lora_seed is required when not loading an existing adapter")
        return 1

    if shuffle_seed is None:
        print("FATAL: training.shuffle_seed is required (no default)")
        return 1

    if require_wandb and not os.environ.get("WANDB_API_KEY"):
        print("FATAL: WANDB_API_KEY not set but require_wandb=true")
        return 1

    if not os.environ.get("WANDB_API_KEY"):
        print("WARNING: WANDB_API_KEY not set — wandb disabled")
        os.environ["WANDB_DISABLED"] = "true"

    # ------------------------------------------------------------------
    # Print summary
    # ------------------------------------------------------------------
    mode_label = "CONTINUATION" if is_continuation else "FRESH"
    print("=" * 60)
    print(f"UNSLOTH TRAINING [{mode_label}]")
    print("=" * 60)
    print(f" Model: {model_name}")
    print(f" Data: {data_path}")
    print(f" Output: {output_dir}")
    print(f" Epochs: {epochs}")
    print(f" Batch size: {batch_size} (eff={batch_size * gradient_accumulation_steps})")
    print(f" LR: {learning_rate}")
    print(f" LoRA: r={lora_rank} alpha={lora_alpha} dropout={lora_dropout}")
    print(f" Targets: {target_modules}")
    print(f" Max seq len: {max_seq_length}")
    if is_continuation:
        print(f" Adapter from: {adapter_path}")
    print(f" wandb: {wandb_project} / {wandb_run_name}")
    print("=" * 60)

    # ------------------------------------------------------------------
    # Load and shuffle data
    # ------------------------------------------------------------------
    data_resolved = resolve_path(data_path)
    if not data_resolved.exists():
        print(f"FATAL: Data file not found: {data_resolved}")
        return 1

    try:
        samples = load_jsonl(data_resolved)
    except ValueError as err:
        print(f"FATAL: {err}")
        return 1
    # Guard: an empty dataset would otherwise crash later in min()/max().
    if not samples:
        print(f"FATAL: No samples found in {data_resolved}")
        return 1
    print(f"Loaded {len(samples)} samples")

    random.Random(shuffle_seed).shuffle(samples)
    print(f"Shuffled with seed={shuffle_seed}")

    # ------------------------------------------------------------------
    # Load model + tokenizer via Unsloth
    # ------------------------------------------------------------------
    print("Loading model via Unsloth FastLanguageModel...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=False,
        dtype=torch.bfloat16,
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # ------------------------------------------------------------------
    # Tokenize dataset
    # ------------------------------------------------------------------
    print("Tokenizing...")
    dataset = build_dataset(samples, tokenizer, max_seq_length=max_seq_length)

    lengths = [len(ids) for ids in dataset["input_ids"]]
    print(
        f"Tokenized {len(dataset)} samples "
        f"(tokens: min={min(lengths)}, max={max(lengths)}, "
        f"mean={sum(lengths) / len(lengths):.0f})"
    )

    # ------------------------------------------------------------------
    # Apply LoRA — fresh init or load existing adapter
    # ------------------------------------------------------------------
    if is_continuation:
        from peft import PeftModel

        adapter_resolved = str(resolve_path(adapter_path))
        print(f"Continuation mode: loading adapter from {adapter_resolved}")
        # PEFT loads adapters frozen (inference-only) unless is_trainable=True,
        # which would leave the optimizer with nothing to update.
        model = PeftModel.from_pretrained(model, adapter_resolved, is_trainable=True)
        model.train()
        # The base weights are frozen, so make embedding outputs require grad;
        # otherwise checkpointed layers can pass no gradient to the LoRA weights.
        model.enable_input_require_grads()
        # Apply Unsloth gradient checkpointing for VRAM savings
        FastLanguageModel.for_training(model, use_gradient_checkpointing="unsloth")
        model.print_trainable_parameters()
        n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        if n_trainable == 0:
            print(f"FATAL: adapter at {adapter_resolved} loaded with no trainable parameters")
            return 1
    else:
        print(f"Fresh mode: seeding LoRA init with lora_seed={lora_seed}")
        torch.manual_seed(lora_seed)
        torch.cuda.manual_seed_all(lora_seed)

        model = FastLanguageModel.get_peft_model(
            model,
            r=lora_rank,
            target_modules=target_modules,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            use_gradient_checkpointing="unsloth",
            random_state=lora_seed,
        )
        model.print_trainable_parameters()

    # Reset seed to shuffle_seed after LoRA init/load
    torch.manual_seed(shuffle_seed)
    torch.cuda.manual_seed_all(shuffle_seed)

    # ------------------------------------------------------------------
    # Training arguments (plain Trainer — pre-tokenized data, no SFTTrainer)
    # ------------------------------------------------------------------
    has_wandb = bool(os.environ.get("WANDB_API_KEY"))
    report_to = "wandb" if has_wandb else "none"
    if has_wandb:
        os.environ["WANDB_PROJECT"] = wandb_project

    from transformers import DataCollatorForSeq2Seq, Trainer, TrainingArguments

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        max_steps=max_steps,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        lr_scheduler_type="constant",
        warmup_ratio=0.0,
        weight_decay=0.0,
        optim="adamw_torch",
        seed=shuffle_seed,
        data_seed=shuffle_seed,
        bf16=True,
        logging_steps=log_every,
        save_steps=save_every,
        save_total_limit=3,
        report_to=report_to,
        run_name=wandb_run_name,
        remove_unused_columns=False,
        dataloader_pin_memory=True,
        dataloader_num_workers=8,
        dataloader_persistent_workers=True,
        dataloader_prefetch_factor=2,
    )

    # Pads input_ids with the pad token and labels with -100 (collator default).
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        padding=True,
        return_tensors="pt",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
        processing_class=tokenizer,
    )

    # ------------------------------------------------------------------
    # Save config YAML alongside output for reproducibility
    # ------------------------------------------------------------------
    os.makedirs(output_dir, exist_ok=True)
    shutil.copy2(config_path, Path(output_dir) / "training_config.yaml")
    print(f"Saved config copy to {output_dir}/training_config.yaml")

    # ------------------------------------------------------------------
    # Train
    # ------------------------------------------------------------------
    print("Starting training...")
    trainer.train()

    # ------------------------------------------------------------------
    # Log full config to wandb (the run only exists once training started)
    # ------------------------------------------------------------------
    if has_wandb:
        import wandb

        if wandb.run is not None:
            wandb.config.update({
                "lora_seed": lora_seed,
                "shuffle_seed": shuffle_seed,
                "lora_rank": lora_rank,
                "lora_alpha": lora_alpha,
                "lora_dropout": lora_dropout,
                "lora_target_modules": target_modules,
                "data_path": str(data_path),
                "model_name": model_name,
                "gradient_accumulation_steps": gradient_accumulation_steps,
                "effective_batch_size": batch_size * gradient_accumulation_steps,
                "adapter_path": adapter_path,
                "max_seq_length": max_seq_length,
                "config_file": str(config_path),
                "backend": "unsloth",
            }, allow_val_change=True)
            print("Logged seeds and config to wandb")

    # ------------------------------------------------------------------
    # Save adapter
    # ------------------------------------------------------------------
    print("Saving adapter...")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    print("=" * 60)
    print("TRAINING COMPLETE")
    print(f" Adapter: {output_dir}")
    print(f" Samples: {len(samples)}")
    print(" Backend: unsloth")
    print("=" * 60)
    return 0


if __name__ == "__main__":
    sys.exit(main())