| |
| """ |
| Unsloth-accelerated LoRA training for collusion model organisms. |
| |
| Drop-in replacement for train_local.py using Unsloth's FastLanguageModel |
| for ~2x speedup on B200. Same config YAML format, same data format, |
| same manual Llama 3.3 chat template. |
| |
| Key differences from train_local.py: |
| - FastLanguageModel instead of AutoModelForCausalLM |
| - Unsloth gradient checkpointing (30% less VRAM) |
| - No DeepSpeed/accelerate needed (single GPU) |
| - Larger micro-batch (8 vs 2) thanks to VRAM savings |
| |
| Usage: |
| python3 experiments/260409_unsloth_training/scripts/train_unsloth.py \ |
| --config experiments/260409_unsloth_training/configs/example.yaml |
| """ |
|
|
| import argparse |
| import json |
| import os |
| import random |
| import sys |
| from pathlib import Path |
|
|
| import torch |
| import yaml |
| from datasets import Dataset |
| from unsloth import FastLanguageModel |
|
|
# Script lives at <project>/experiments/<experiment>/scripts/<this file>
# (see the usage example in the module docstring).
_SCRIPT_FILE = Path(__file__).resolve()
# parents: [0] scripts/  [1] <experiment>/  [2] experiments/  [3] <project>/
PROJECT_ROOT = _SCRIPT_FILE.parents[3]
EXPERIMENT_DIR = _SCRIPT_FILE.parents[1]
|
|
|
|
| |
| |
| |
|
|
|
|
def load_config(config_path: Path) -> dict:
    """Load a YAML training config and return its top-level mapping.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The parsed top-level mapping.

    Raises:
        ValueError: if the document is empty or its top level is not a
            mapping. ``yaml.safe_load`` would otherwise hand back ``None`` or
            a list, and the caller would fail later with an opaque
            ``TypeError`` when indexing ``config["model"]``.
    """
    # Explicit encoding so parsing does not depend on the machine's locale.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        raise ValueError(
            f"Config {config_path} must contain a YAML mapping at the top level, "
            f"got {type(config).__name__}"
        )
    return config
|
|
|
|
def resolve_path(path_str: str) -> Path:
    """Return *path_str* as a Path, anchored at PROJECT_ROOT unless already absolute."""
    candidate = Path(path_str)
    return candidate if candidate.is_absolute() else PROJECT_ROOT / candidate
|
|
|
|
| |
| |
| |
|
|
|
|
def load_jsonl(path: Path) -> list[dict]:
    """Read a JSON Lines file into a list of parsed records.

    Blank (whitespace-only) lines are skipped; every other line must be one
    JSON value.

    Args:
        path: Path to the ``.jsonl`` file (read as UTF-8).

    Returns:
        The parsed records, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON. The
            message is prefixed with ``"<path>:<line number>:"`` so the bad
            record can be located in large files.
    """
    samples = []
    # Explicit encoding: training data must not depend on the locale default.
    with open(path, encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            try:
                samples.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Keep the exception type so existing handlers still match;
                # only add file/line context to the message.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
    return samples
|
|
|
|
| |
| |
| |
|
|
|
|
def build_chat_text(messages: list[dict]) -> tuple[str, str]:
    """
    Build manual Llama 3.3 chat template.

    Returns (prompt_text, full_text) where:
    - prompt_text = everything through 'assistant<|end_header_id|>\\n\\n'
    - full_text = prompt_text + assistant_content + '<|eot_id|>'

    When no system message is present, injects the default Llama 3.3 preamble
    ("Cutting Knowledge Date: December 2023\\nToday Date: 26 Jul 2024") to match
    what apply_chat_template() produces at eval time.

    Only one system/user/assistant exchange is rendered. If a role appears
    more than once, the LAST message with that role wins and earlier ones are
    dropped. Messages with any other role are ignored.

    Args:
        messages: list of dicts with "role" and "content" keys.

    Raises:
        ValueError: if there is no user message or no assistant message.
    """
    DEFAULT_SYSTEM = "Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n"

    system_content = None
    user_content = None
    assistant_content = None

    for msg in messages:
        role = msg["role"]
        if role == "system":
            system_content = msg["content"]
        elif role == "user":
            user_content = msg["content"]
        elif role == "assistant":
            assistant_content = msg["content"]

    if system_content is None:
        # NOTE(review): the stock Llama 3.1/3.3 Hugging Face chat template also
        # emits this date preamble in front of a *custom* system message; here a
        # provided system message is used verbatim. Confirm eval-time
        # formatting matches for system-prompted samples.
        system_content = DEFAULT_SYSTEM

    # Explicit raises instead of `assert`. Asserts are stripped under
    # `python -O`, and a missing user message would then silently render as
    # the literal string "None" in the prompt.
    if user_content is None:
        raise ValueError("Missing user message")
    if assistant_content is None:
        raise ValueError("Missing assistant message")

    prompt_text = (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        f"{system_content}<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{user_content}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    # The trailing <|eot_id|> is part of the supervised target, so the model
    # learns to end its turn.
    full_text = prompt_text + assistant_content + "<|eot_id|>"

    return prompt_text, full_text
|
|
|
|
def tokenize_chat(sample: dict, tokenizer, max_seq_length: int = 4096) -> dict:
    """
    Tokenize a chat sample with manual template.

    Labels are -100 for prompt tokens — only assistant response gets loss.

    The prompt and the full conversation are tokenized separately. Both are
    tokenized without special tokens (the template already contains
    <|begin_of_text|>) and both are truncated to ``max_seq_length``. The
    prompt's token count decides how many leading positions of the full
    sequence are masked.

    Args:
        sample: record with a "messages" list (rendered by build_chat_text).
        tokenizer: Hugging Face-style tokenizer. Called with ``truncation`` /
            ``max_length`` and indexed for "input_ids" and "attention_mask".
        max_seq_length: truncation limit in tokens.

    Returns:
        dict with "input_ids", "attention_mask" and "labels". Labels always
        have the same length as input_ids.
    """
    messages = sample["messages"]
    prompt_text, full_text = build_chat_text(messages)

    prompt_ids = tokenizer(
        prompt_text, add_special_tokens=False, truncation=True, max_length=max_seq_length
    )["input_ids"]
    full_encoding = tokenizer(
        full_text, add_special_tokens=False, truncation=True, max_length=max_seq_length
    )

    prompt_len = len(prompt_ids)
    full_ids = full_encoding["input_ids"]

    # Mask position-by-position so labels have exactly len(full_ids) entries.
    # `[-100] * prompt_len + full_ids[prompt_len:]` would grow past input_ids
    # whenever the separately tokenized prompt has more tokens than the full
    # sequence, giving misaligned input_ids/labels.
    labels = [-100 if pos < prompt_len else token for pos, token in enumerate(full_ids)]

    return {
        "input_ids": full_ids,
        "attention_mask": full_encoding["attention_mask"],
        "labels": labels,
    }
|
|
|
|
def build_dataset(samples: list[dict], tokenizer, max_seq_length: int = 4096) -> Dataset:
    """Tokenize every sample and pack the columns into a ``datasets.Dataset``.

    The process exits with status 1 on the first sample that either fails to
    tokenize or has every label masked (it would contribute no loss).

    Samples whose token count reaches ``max_seq_length`` are counted, and a
    warning is printed. Such samples were probably truncated, so part of the
    response and/or the final <|eot_id|> may be missing.

    Returns:
        Dataset with "input_ids", "attention_mask" and "labels" columns, in
        the order of ``samples``.
    """
    tokenized = []
    n_at_limit = 0
    for i, sample in enumerate(samples):
        try:
            tok = tokenize_chat(sample, tokenizer, max_seq_length=max_seq_length)
        except Exception as e:
            # Data-error boundary: report which sample broke, then stop the run.
            print(f"FATAL: tokenizing sample {i}: {e}")
            sys.exit(1)

        # A fully masked sample contributes no loss at all; treat it as a data
        # error rather than silently training on nothing.
        if all(label == -100 for label in tok["labels"]):
            print(f"FATAL: sample {i} has all labels masked (-100) — prompt alone exceeds max_seq_length={max_seq_length}")
            sys.exit(1)

        if len(tok["input_ids"]) >= max_seq_length:
            n_at_limit += 1

        tokenized.append(tok)

    if n_at_limit:
        print(
            f"WARNING: {n_at_limit}/{len(samples)} samples reached "
            f"max_seq_length={max_seq_length} and may have been truncated "
            "(response tail / <|eot_id|> possibly cut)"
        )

    columns = ("input_ids", "attention_mask", "labels")
    return Dataset.from_dict({col: [t[col] for t in tokenized] for col in columns})
|
|
|
|
| |
| |
| |
|
|
|
|
def derive_output_dir(wandb_run_name: str) -> Path:
    """Map a wandb run name to its directory under ``EXPERIMENT_DIR/output``.

    A trailing '-local' is dropped first, so 'foo-local' and 'foo' both map
    to 'output/foo'.
    """
    run_dir_name = wandb_run_name.removesuffix("-local")
    return EXPERIMENT_DIR.joinpath("output", run_dir_name)
|
|
|
|
| |
| |
| |
|
|
|
|
def main():
    """Train a LoRA adapter with Unsloth from a YAML config.

    The pipeline:
    - Read ``--config`` and validate it.
    - Load and shuffle the JSONL data.
    - Load the base model with ``FastLanguageModel``.
    - Attach a fresh LoRA (seeded with ``training.lora_seed``), or continue
      from ``training.adapter_path``.
    - Train with ``transformers.Trainer``.
    - Save adapter + tokenizer to the directory derived from
      ``logging.wandb_run_name``.

    Returns:
        0 on success, 1 when a config/data problem is detected (a ``FATAL:``
        line is printed first).
    """
    import shutil

    parser = argparse.ArgumentParser(description="Unsloth LoRA training")
    parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
    args = parser.parse_args()

    # ---- Load config ----------------------------------------------------
    config_path = resolve_path(args.config)
    if not config_path.exists():
        print(f"FATAL: Config not found: {config_path}")
        return 1

    config = load_config(config_path)

    model_name = config["model"]["name"]
    data_path = config["data"]["path"]

    training_cfg = config["training"]
    epochs = training_cfg["epochs"]
    batch_size = training_cfg["batch_size"]
    gradient_accumulation_steps = training_cfg.get("gradient_accumulation_steps", 1)
    learning_rate = float(training_cfg["learning_rate"])
    lora_seed = training_cfg.get("lora_seed")
    # .get() so a missing key reaches the explicit FATAL check below instead
    # of aborting with a bare KeyError.
    shuffle_seed = training_cfg.get("shuffle_seed")
    adapter_path = training_cfg.get("adapter_path")
    max_steps = training_cfg.get("max_steps", -1)
    max_seq_length = training_cfg.get("max_seq_length", 4096)

    # NOTE: in continuation mode the lora.* values below are only printed and
    # logged; the loaded adapter keeps its own saved configuration.
    lora_cfg = config["lora"]
    lora_rank = lora_cfg["rank"]
    lora_alpha = lora_cfg.get("alpha", 64)
    lora_dropout = lora_cfg.get("dropout", 0.0)
    default_targets = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]
    target_modules = lora_cfg.get("target_modules", default_targets)
    if target_modules == "all-linear":
        target_modules = list(default_targets)

    logging_cfg = config["logging"]
    wandb_project = logging_cfg["wandb_project"]
    wandb_run_name = logging_cfg["wandb_run_name"]
    require_wandb = logging_cfg.get("require_wandb", False)
    log_every = logging_cfg.get("log_every_n_steps", 1)
    save_every = logging_cfg.get("save_every_n_steps", 500)

    output_dir = str(derive_output_dir(wandb_run_name))
    is_continuation = adapter_path is not None

    # ---- Validate -------------------------------------------------------
    if training_cfg.get("resume_from"):
        print("FATAL: resume_from is not supported in unsloth training. Use adapter_path for continuation.")
        return 1

    if lora_seed is None and not is_continuation:
        print("FATAL: training.lora_seed is required when not loading an existing adapter")
        return 1

    if shuffle_seed is None:
        print("FATAL: training.shuffle_seed is required (no default)")
        return 1

    if require_wandb and not os.environ.get("WANDB_API_KEY"):
        print("FATAL: WANDB_API_KEY not set but require_wandb=true")
        return 1

    if not os.environ.get("WANDB_API_KEY"):
        print("WARNING: WANDB_API_KEY not set — wandb disabled")
        os.environ["WANDB_DISABLED"] = "true"

    # ---- Banner ---------------------------------------------------------
    mode_label = "CONTINUATION" if is_continuation else "FRESH"
    print("=" * 60)
    print(f"UNSLOTH TRAINING [{mode_label}]")
    print("=" * 60)
    print(f" Model: {model_name}")
    print(f" Data: {data_path}")
    print(f" Output: {output_dir}")
    print(f" Epochs: {epochs}")
    print(f" Batch size: {batch_size} (eff={batch_size * gradient_accumulation_steps})")
    print(f" LR: {learning_rate}")
    print(f" LoRA: r={lora_rank} alpha={lora_alpha} dropout={lora_dropout}")
    print(f" Targets: {target_modules}")
    print(f" Max seq len: {max_seq_length}")
    if is_continuation:
        print(f" Adapter from: {adapter_path}")
    print(f" wandb: {wandb_project} / {wandb_run_name}")
    print("=" * 60)

    # ---- Data -----------------------------------------------------------
    data_resolved = resolve_path(data_path)
    if not data_resolved.exists():
        print(f"FATAL: Data file not found: {data_resolved}")
        return 1

    samples = load_jsonl(data_resolved)
    print(f"Loaded {len(samples)} samples")
    # Fail before the slow model load. An empty dataset would otherwise crash
    # later in min()/max() over token lengths.
    if not samples:
        print(f"FATAL: No samples in data file: {data_resolved}")
        return 1

    random.Random(shuffle_seed).shuffle(samples)
    print(f"Shuffled with seed={shuffle_seed}")

    # ---- Model ----------------------------------------------------------
    print("Loading model via Unsloth FastLanguageModel...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=False,
        dtype=torch.bfloat16,
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # ---- Tokenize -------------------------------------------------------
    print("Tokenizing...")
    dataset = build_dataset(samples, tokenizer, max_seq_length=max_seq_length)
    lengths = [len(ids) for ids in dataset["input_ids"]]
    print(
        f"Tokenized {len(dataset)} samples "
        f"(tokens: min={min(lengths)}, max={max(lengths)}, "
        f"mean={sum(lengths) / len(lengths):.0f})"
    )

    # ---- LoRA -----------------------------------------------------------
    if is_continuation:
        adapter_resolved = str(resolve_path(adapter_path))
        print(f"Continuation mode: loading adapter from {adapter_resolved}")
        from peft import PeftModel

        # is_trainable=True is required: PeftModel.from_pretrained loads the
        # adapter frozen (inference mode) by default, which would leave no
        # trainable parameters. for_training() does not re-enable gradients.
        model = PeftModel.from_pretrained(model, adapter_resolved, is_trainable=True)
        model.train()
        FastLanguageModel.for_training(model, use_gradient_checkpointing="unsloth")
        model.print_trainable_parameters()
    else:
        print(f"Fresh mode: seeding LoRA init with lora_seed={lora_seed}")
        torch.manual_seed(lora_seed)
        torch.cuda.manual_seed_all(lora_seed)

        model = FastLanguageModel.get_peft_model(
            model,
            r=lora_rank,
            target_modules=target_modules,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            use_gradient_checkpointing="unsloth",
            random_state=lora_seed,
        )
        model.print_trainable_parameters()

    # Re-seed after LoRA init so the training RNG depends only on shuffle_seed.
    torch.manual_seed(shuffle_seed)
    torch.cuda.manual_seed_all(shuffle_seed)

    # ---- Trainer --------------------------------------------------------
    has_wandb = bool(os.environ.get("WANDB_API_KEY"))
    report_to = "wandb" if has_wandb else "none"
    if has_wandb:
        os.environ["WANDB_PROJECT"] = wandb_project

    from transformers import DataCollatorForSeq2Seq, Trainer, TrainingArguments

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        max_steps=max_steps,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        lr_scheduler_type="constant",
        warmup_ratio=0.0,
        weight_decay=0.0,
        optim="adamw_torch",
        seed=shuffle_seed,
        data_seed=shuffle_seed,
        bf16=True,
        logging_steps=log_every,
        save_steps=save_every,
        save_total_limit=3,
        report_to=report_to,
        run_name=wandb_run_name,
        remove_unused_columns=False,
        dataloader_pin_memory=True,
        dataloader_num_workers=8,
        dataloader_persistent_workers=True,
        dataloader_prefetch_factor=2,
    )

    # Pads input_ids/attention_mask with the tokenizer and labels with -100.
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        padding=True,
        return_tensors="pt",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
        processing_class=tokenizer,
    )

    # ---- Snapshot config next to the outputs ----------------------------
    os.makedirs(output_dir, exist_ok=True)
    shutil.copy2(config_path, Path(output_dir) / "training_config.yaml")
    print(f"Saved config copy to {output_dir}/training_config.yaml")

    # ---- Train ----------------------------------------------------------
    print("Starting training...")
    trainer.train()

    # ---- Record run metadata in wandb -----------------------------------
    if has_wandb:
        import wandb

        if wandb.run is not None:
            wandb.config.update({
                "lora_seed": lora_seed,
                "shuffle_seed": shuffle_seed,
                "lora_rank": lora_rank,
                "lora_alpha": lora_alpha,
                "lora_dropout": lora_dropout,
                "lora_target_modules": target_modules,
                "data_path": str(data_path),
                "model_name": model_name,
                "gradient_accumulation_steps": gradient_accumulation_steps,
                "effective_batch_size": batch_size * gradient_accumulation_steps,
                "adapter_path": adapter_path,
                "max_seq_length": max_seq_length,
                "config_file": str(config_path),
                "backend": "unsloth",
            }, allow_val_change=True)
            print("Logged seeds and config to wandb")

    # ---- Save -----------------------------------------------------------
    print("Saving adapter...")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    print("=" * 60)
    print("TRAINING COMPLETE")
    print(f" Adapter: {output_dir}")
    print(f" Samples: {len(samples)}")
    print(" Backend: unsloth")
    print("=" * 60)

    return 0
|
|
|
|
if __name__ == "__main__":
    # main() returns the process exit status; SystemExit hands it to the shell.
    raise SystemExit(main())
|
|