# Source: mo8_combined_lora/scripts/train_unsloth.py
# (Hugging Face Hub page residue: uploaded by jprivera44 — "Upload training script", commit beef688, verified)
#!/usr/bin/env python3
"""
Unsloth-accelerated LoRA training for collusion model organisms.
Drop-in replacement for train_local.py using Unsloth's FastLanguageModel
for ~2x speedup on B200. Same config YAML format, same data format,
same manual Llama 3.3 chat template.
Key differences from train_local.py:
- FastLanguageModel instead of AutoModelForCausalLM
- Unsloth gradient checkpointing (30% less VRAM)
- No DeepSpeed/accelerate needed (single GPU)
- Larger micro-batch (8 vs 2) thanks to VRAM savings
Usage:
python3 experiments/260409_unsloth_training/scripts/train_unsloth.py \
--config experiments/260409_unsloth_training/configs/example.yaml
"""
import argparse
import json
import os
import random
import sys
from pathlib import Path
import torch
import yaml
from datasets import Dataset
from unsloth import FastLanguageModel
PROJECT_ROOT = Path(__file__).resolve().parents[3]
EXPERIMENT_DIR = Path(__file__).resolve().parent.parent
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
def load_config(config_path: Path) -> dict:
    """Parse a training-config YAML file and return its top-level mapping.

    Args:
        config_path: Path to the YAML config file.

    Returns:
        The parsed top-level mapping.

    Raises:
        ValueError: If the document is empty or its top level is not a mapping.
            ``yaml.safe_load`` returns ``None`` for an empty file, which would
            otherwise only fail later as an opaque ``TypeError`` on subscripting.
    """
    # Explicit UTF-8 so parsing does not depend on the host locale.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        raise ValueError(
            f"Config {config_path} must contain a YAML mapping at the top level, "
            f"got {type(config).__name__}"
        )
    return config
def resolve_path(path_str: str) -> Path:
    """Return *path_str* as a Path, anchoring relative paths at PROJECT_ROOT.

    Absolute paths are returned unchanged.
    """
    candidate = Path(path_str)
    if not candidate.is_absolute():
        candidate = PROJECT_ROOT / candidate
    return candidate
# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
def load_jsonl(path: Path) -> list[dict]:
    """Load a JSON-Lines file into a list of decoded objects.

    Blank (or whitespace-only) lines are skipped.

    Args:
        path: Path to the ``.jsonl`` file.

    Returns:
        One decoded JSON value per non-blank line, in file order.

    Raises:
        ValueError: If a line is not valid JSON. The message includes the file
            path and 1-based line number; the original ``json.JSONDecodeError``
            is chained as the cause.
    """
    samples = []
    # Explicit UTF-8 so reading does not depend on the host locale.
    with open(path, encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            try:
                samples.append(json.loads(line))
            except json.JSONDecodeError as err:
                # JSONDecodeError positions are relative to this line only;
                # report where in the file the bad record lives.
                raise ValueError(f"{path}:{lineno}: invalid JSON ({err.msg})") from err
    return samples
# ---------------------------------------------------------------------------
# Manual Llama 3.3 chat template
# ---------------------------------------------------------------------------
def build_chat_text(messages: list[dict]) -> tuple[str, str]:
    """
    Build manual Llama 3.3 chat template.

    Returns (prompt_text, full_text) where:
    - prompt_text = everything through 'assistant<|end_header_id|>\\n\\n'
    - full_text = prompt_text + assistant_content + '<|eot_id|>'

    When no system message is present, injects the default Llama 3.3 preamble
    ("Cutting Knowledge Date: December 2023 / Today Date: 26 Jul 2024") to match
    what apply_chat_template() produces at eval time. When a system message IS
    present it is used verbatim, with no date preamble.
    NOTE(review): the stock Llama 3.1/3.3 chat template appears to emit the date
    preamble before a user-supplied system message as well — confirm this matches
    the eval-time prompt before relying on train/eval parity for system-prompted data.

    Only one system/user/assistant exchange is rendered: if a role occurs more
    than once, the LAST message with that role wins and earlier ones are dropped.
    Messages with any other role are ignored.

    Raises:
        ValueError: if there is no user message or no assistant message.
    """
    # Default preamble — matches tokenizer.apply_chat_template() output
    DEFAULT_SYSTEM = "Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n"
    system_content = None
    user_content = None
    assistant_content = None
    for msg in messages:
        role = msg["role"]
        if role == "system":
            system_content = msg["content"]
        elif role == "user":
            user_content = msg["content"]
        elif role == "assistant":
            assistant_content = msg["content"]
    if system_content is None:
        system_content = DEFAULT_SYSTEM
    # Explicit raises rather than `assert`: asserts are stripped under `python -O`,
    # which would silently render the literal text "None" into the prompt.
    if user_content is None:
        raise ValueError("Missing user message")
    if assistant_content is None:
        raise ValueError("Missing assistant message")
    prompt_text = (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        f"{system_content}<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{user_content}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    full_text = prompt_text + assistant_content + "<|eot_id|>"
    return prompt_text, full_text
def tokenize_chat(sample: dict, tokenizer, max_seq_length: int = 4096) -> dict:
    """
    Tokenize a chat sample with manual template.

    Labels are -100 for prompt tokens — only assistant response gets loss.

    Args:
        sample: A record with a ``"messages"`` list (see ``build_chat_text``).
        tokenizer: Callable tokenizer; called with ``add_special_tokens=False``,
            ``truncation=True`` and ``max_length=max_seq_length`` and expected to
            return ``input_ids`` and ``attention_mask``.
        max_seq_length: Truncation length applied to both encodings.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels``; ``labels``
        has the same length as ``input_ids``.
    """
    messages = sample["messages"]
    prompt_text, full_text = build_chat_text(messages)
    prompt_ids = tokenizer(
        prompt_text, add_special_tokens=False, truncation=True, max_length=max_seq_length
    )["input_ids"]
    full_encoding = tokenizer(
        full_text, add_special_tokens=False, truncation=True, max_length=max_seq_length
    )
    full_ids = full_encoding["input_ids"]
    # The prompt and the full text are tokenized separately; masking the first
    # len(prompt_ids) positions assumes the prompt's tokens are a prefix of the
    # full text's tokens.
    # NOTE(review): this holds only if the tokenizer does not merge tokens across
    # the prompt/response boundary — consider checking full_ids[:n] == prompt_ids.
    # Clamp so labels can never be longer than input_ids (which would break
    # collation), e.g. when the prompt encoding outgrows the truncated full text.
    prompt_len = min(len(prompt_ids), len(full_ids))
    labels = [-100] * prompt_len + full_ids[prompt_len:]
    return {
        "input_ids": full_ids,
        "attention_mask": full_encoding["attention_mask"],
        "labels": labels,
    }
def build_dataset(samples: list[dict], tokenizer, max_seq_length: int = 4096) -> Dataset:
    """Tokenize every chat sample and pack the results into a column-wise Dataset.

    The process exits with status 1 (after printing a FATAL line) if any sample
    fails to tokenize, or if a sample ends up with every label masked to -100,
    i.e. it contributes no assistant tokens to the loss.
    """
    rows = []
    for idx, sample in enumerate(samples):
        try:
            row = tokenize_chat(sample, tokenizer, max_seq_length=max_seq_length)
        except Exception as e:
            print(f"FATAL: tokenizing sample {idx}: {e}")
            sys.exit(1)
        # No unmasked label means the assistant response was truncated away.
        if not any(label != -100 for label in row["labels"]):
            print(f"FATAL: sample {idx} has all labels masked (-100) — prompt alone exceeds max_seq_length={max_seq_length}")
            sys.exit(1)
        rows.append(row)
    column_names = ("input_ids", "attention_mask", "labels")
    return Dataset.from_dict({name: [row[name] for row in rows] for name in column_names})
# ---------------------------------------------------------------------------
# Output directory
# ---------------------------------------------------------------------------
def derive_output_dir(wandb_run_name: str) -> Path:
    """Map a wandb run name to ``EXPERIMENT_DIR/output/<name>``.

    A trailing ``-local`` suffix is dropped first, so ``"foo-local"`` and
    ``"foo"`` resolve to the same output directory.
    """
    run_dir_name = wandb_run_name.removesuffix("-local")
    return EXPERIMENT_DIR.joinpath("output", run_dir_name)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Train a LoRA adapter with Unsloth from a YAML config.

    Reads ``--config``, validates it, loads and shuffles the JSONL data, and
    loads the base model through ``FastLanguageModel``. It then either
    initialises a fresh LoRA adapter (seeded by ``training.lora_seed``) or
    continues training an existing one (``training.adapter_path``). Training
    runs with a plain HF ``Trainer``, and the adapter plus tokenizer are saved
    to the directory derived from ``logging.wandb_run_name``.

    Returns:
        Process exit code: 0 on success, 1 on a config/data validation error
        (reported on stdout with a ``FATAL:`` prefix).
    """
    parser = argparse.ArgumentParser(description="Unsloth LoRA training")
    parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
    args = parser.parse_args()

    # ------------------------------------------------------------------
    # Load config
    # ------------------------------------------------------------------
    config_path = resolve_path(args.config)
    if not config_path.exists():
        print(f"FATAL: Config not found: {config_path}")
        return 1
    config = load_config(config_path)

    # ------------------------------------------------------------------
    # Extract config values
    # ------------------------------------------------------------------
    model_name = config["model"]["name"]
    data_path = config["data"]["path"]
    training_cfg = config["training"]
    epochs = training_cfg["epochs"]
    batch_size = training_cfg["batch_size"]
    gradient_accumulation_steps = training_cfg.get("gradient_accumulation_steps", 1)
    learning_rate = float(training_cfg["learning_rate"])
    lora_seed = training_cfg.get("lora_seed")
    # .get() so a missing key reaches the explicit FATAL check below instead of
    # crashing with a bare KeyError.
    shuffle_seed = training_cfg.get("shuffle_seed")
    adapter_path = training_cfg.get("adapter_path")
    max_steps = training_cfg.get("max_steps", -1)
    max_seq_length = training_cfg.get("max_seq_length", 4096)
    lora_cfg = config["lora"]
    lora_rank = lora_cfg["rank"]
    lora_alpha = lora_cfg.get("alpha", 64)
    lora_dropout = lora_cfg.get("dropout", 0.0)
    default_target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]
    target_modules = lora_cfg.get("target_modules", default_target_modules)
    if target_modules == "all-linear":
        target_modules = list(default_target_modules)
    logging_cfg = config["logging"]
    wandb_project = logging_cfg["wandb_project"]
    wandb_run_name = logging_cfg["wandb_run_name"]
    require_wandb = logging_cfg.get("require_wandb", False)
    log_every = logging_cfg.get("log_every_n_steps", 1)
    save_every = logging_cfg.get("save_every_n_steps", 500)
    output_dir = str(derive_output_dir(wandb_run_name))
    is_continuation = adapter_path is not None

    # ------------------------------------------------------------------
    # Validate
    # ------------------------------------------------------------------
    if training_cfg.get("resume_from"):
        print("FATAL: resume_from is not supported in unsloth training. Use adapter_path for continuation.")
        return 1
    if lora_seed is None and not is_continuation:
        print("FATAL: training.lora_seed is required when not loading an existing adapter")
        return 1
    if shuffle_seed is None:
        print("FATAL: training.shuffle_seed is required (no default)")
        return 1
    if require_wandb and not os.environ.get("WANDB_API_KEY"):
        print("FATAL: WANDB_API_KEY not set but require_wandb=true")
        return 1
    if not os.environ.get("WANDB_API_KEY"):
        print("WARNING: WANDB_API_KEY not set — wandb disabled")
        os.environ["WANDB_DISABLED"] = "true"

    # ------------------------------------------------------------------
    # Print summary
    # ------------------------------------------------------------------
    mode_label = "CONTINUATION" if is_continuation else "FRESH"
    print("=" * 60)
    print(f"UNSLOTH TRAINING [{mode_label}]")
    print("=" * 60)
    print(f" Model: {model_name}")
    print(f" Data: {data_path}")
    print(f" Output: {output_dir}")
    print(f" Epochs: {epochs}")
    print(f" Batch size: {batch_size} (eff={batch_size * gradient_accumulation_steps})")
    print(f" LR: {learning_rate}")
    print(f" LoRA: r={lora_rank} alpha={lora_alpha} dropout={lora_dropout}")
    print(f" Targets: {target_modules}")
    print(f" Max seq len: {max_seq_length}")
    if is_continuation:
        print(f" Adapter from: {adapter_path}")
    print(f" wandb: {wandb_project} / {wandb_run_name}")
    print("=" * 60)

    # ------------------------------------------------------------------
    # Load and shuffle data
    # ------------------------------------------------------------------
    data_resolved = resolve_path(data_path)
    if not data_resolved.exists():
        print(f"FATAL: Data file not found: {data_resolved}")
        return 1
    samples = load_jsonl(data_resolved)
    # Fail fast (before loading the model) instead of crashing later on
    # min()/division over an empty list of token lengths.
    if not samples:
        print(f"FATAL: No samples found in {data_resolved}")
        return 1
    print(f"Loaded {len(samples)} samples")
    random.Random(shuffle_seed).shuffle(samples)
    print(f"Shuffled with seed={shuffle_seed}")

    # ------------------------------------------------------------------
    # Load model + tokenizer via Unsloth
    # ------------------------------------------------------------------
    print("Loading model via Unsloth FastLanguageModel...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=False,
        dtype=torch.bfloat16,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # ------------------------------------------------------------------
    # Tokenize dataset
    # ------------------------------------------------------------------
    print("Tokenizing...")
    dataset = build_dataset(samples, tokenizer, max_seq_length=max_seq_length)
    lengths = [len(ids) for ids in dataset["input_ids"]]
    print(
        f"Tokenized {len(dataset)} samples "
        f"(tokens: min={min(lengths)}, max={max(lengths)}, "
        f"mean={sum(lengths) / len(lengths):.0f})"
    )

    # ------------------------------------------------------------------
    # Apply LoRA — fresh init or load existing adapter
    # ------------------------------------------------------------------
    if is_continuation:
        adapter_resolved = str(resolve_path(adapter_path))
        print(f"Continuation mode: loading adapter from {adapter_resolved}")
        from peft import PeftModel
        # is_trainable=True is required: PeftModel.from_pretrained loads adapters
        # frozen (inference mode) by default, and model.train() does not unfreeze
        # them, which would leave no trainable parameters.
        model = PeftModel.from_pretrained(model, adapter_resolved, is_trainable=True)
        model.train()
        # Apply Unsloth gradient checkpointing for VRAM savings
        FastLanguageModel.for_training(model, use_gradient_checkpointing="unsloth")
        model.print_trainable_parameters()
    else:
        print(f"Fresh mode: seeding LoRA init with lora_seed={lora_seed}")
        torch.manual_seed(lora_seed)
        torch.cuda.manual_seed_all(lora_seed)
        model = FastLanguageModel.get_peft_model(
            model,
            r=lora_rank,
            target_modules=target_modules,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            use_gradient_checkpointing="unsloth",
            random_state=lora_seed,
        )
        model.print_trainable_parameters()
    # Reset seed to shuffle_seed after LoRA init/load
    torch.manual_seed(shuffle_seed)
    torch.cuda.manual_seed_all(shuffle_seed)

    # ------------------------------------------------------------------
    # Training arguments (plain Trainer — pre-tokenized data, no SFTTrainer)
    # ------------------------------------------------------------------
    has_wandb = bool(os.environ.get("WANDB_API_KEY"))
    report_to = "wandb" if has_wandb else "none"
    if has_wandb:
        os.environ["WANDB_PROJECT"] = wandb_project
    from transformers import DataCollatorForSeq2Seq, Trainer, TrainingArguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        max_steps=max_steps,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        lr_scheduler_type="constant",
        warmup_ratio=0.0,
        weight_decay=0.0,
        optim="adamw_torch",
        seed=shuffle_seed,
        data_seed=shuffle_seed,
        bf16=True,
        logging_steps=log_every,
        save_steps=save_every,
        save_total_limit=3,
        report_to=report_to,
        run_name=wandb_run_name,
        remove_unused_columns=False,
        dataloader_pin_memory=True,
        dataloader_num_workers=8,
        dataloader_persistent_workers=True,
        dataloader_prefetch_factor=2,
    )
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        padding=True,
        return_tensors="pt",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
        processing_class=tokenizer,
    )

    # ------------------------------------------------------------------
    # Save config YAML alongside output for reproducibility
    # ------------------------------------------------------------------
    import shutil
    os.makedirs(output_dir, exist_ok=True)
    shutil.copy2(config_path, Path(output_dir) / "training_config.yaml")
    print(f"Saved config copy to {output_dir}/training_config.yaml")

    # ------------------------------------------------------------------
    # Train
    # ------------------------------------------------------------------
    print("Starting training...")
    trainer.train()

    # ------------------------------------------------------------------
    # Log full config to wandb (the run only exists once training has
    # started, hence this happens after trainer.train()).
    # ------------------------------------------------------------------
    if has_wandb:
        import wandb
        if wandb.run is not None:
            wandb.config.update({
                "lora_seed": lora_seed,
                "shuffle_seed": shuffle_seed,
                "lora_rank": lora_rank,
                "lora_alpha": lora_alpha,
                "lora_dropout": lora_dropout,
                "lora_target_modules": target_modules,
                "data_path": str(data_path),
                "model_name": model_name,
                "gradient_accumulation_steps": gradient_accumulation_steps,
                "effective_batch_size": batch_size * gradient_accumulation_steps,
                "adapter_path": adapter_path,
                "max_seq_length": max_seq_length,
                "config_file": str(config_path),
                "backend": "unsloth",
            }, allow_val_change=True)
            print("Logged seeds and config to wandb")

    # ------------------------------------------------------------------
    # Save adapter
    # ------------------------------------------------------------------
    print("Saving adapter...")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print("=" * 60)
    print("TRAINING COMPLETE")
    print(f" Adapter: {output_dir}")
    print(f" Samples: {len(samples)}")
    print(" Backend: unsloth")
    print("=" * 60)
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())