# dee-tulu-train/train_tulu.py
# License: CC BY-NC-SA 4.0. Rights belong to Javad Taghia (taghia.javad@gmail.com).
"""
Minimal QLoRA finetune for a laptop-friendly Tulu checkpoint with W&B logging.
Defaults aim to run on a single consumer GPU using 4-bit quantization.
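
Example invocations (the flags below are defined in parse_args; adjust for your hardware):

    # CPU-only laptop run, capped at 4 threads
    python train_tulu.py --device cpu --cpu_threads 4 --max_seq_length 256

    # Single CUDA GPU with 4-bit (QLoRA) quantization
    python train_tulu.py --device cuda --use_4bit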
"""
from __future__ import annotations
import argparse
import os
import time
from dataclasses import dataclass
from typing import Dict, List, Optional
import torch
import wandb
from datasets import load_dataset
from dotenv import load_dotenv
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)
@dataclass
class ScriptConfig:
model_name: str = "allenai/tulu-2-7b"
dataset_name: str = "mlabonne/guanaco-llama2-1k" # small, instruction-style set
output_dir: str = "outputs/tulu-lora"
offload_folder: str = "offload" # where to offload weights if needed
device: str = "auto" # auto|cpu|mps|cuda
torch_dtype: str = "auto" # auto|float16|float32|bfloat16
cpu_threads: int = 4 # limit CPU usage when running on cpu
instruction_field: str = "instruction"
input_field: str = "input"
output_field: str = "output"
max_seq_length: int = 512
per_device_batch_size: int = 1
gradient_accumulation_steps: int = 16
    num_train_epochs: float = 1.0
learning_rate: float = 2e-4
warmup_ratio: float = 0.03
logging_steps: int = 10
save_steps: int = 200
use_4bit: bool = True
def format_chat(example: Dict[str, str], cfg: ScriptConfig) -> str:
"""Simple instruction->response template that fits Tulu-style tuning."""
instruction = example.get(cfg.instruction_field)
output = example.get(cfg.output_field)
# Common fallback: text-only datasets.
if instruction is None and "text" in example:
instruction = example["text"]
if output is None and "text" in example:
output = example["text"]
if instruction is None or output is None:
available = ", ".join(example.keys())
missing_fields = []
if instruction is None:
missing_fields.append(cfg.instruction_field)
if output is None:
missing_fields.append(cfg.output_field)
missing_str = "/".join(missing_fields)
raise KeyError(
f"Dataset is missing '{missing_str}'. Available fields: {available}. "
"Use --instruction_field/--input_field/--output_field to match your dataset, "
"or set both instruction/output to 'text' for single-text datasets."
)
user_input = example.get(cfg.input_field) or "N/A"
return (
f"### Instruction:\n{instruction}\n\n"
f"### Input:\n{user_input}\n\n"
f"### Response:\n{output}"
)
def tokenize_example(example: Dict[str, str], tokenizer, cfg: ScriptConfig):
prompt = format_chat(example, cfg)
# We build labels that are the same as input_ids for causal LM.
tokenized = tokenizer(
prompt,
truncation=True,
max_length=cfg.max_seq_length,
padding="max_length",
)
tokenized["labels"] = tokenized["input_ids"].copy()
return tokenized
def load_model_and_tokenizer(cfg: ScriptConfig):
os.makedirs(cfg.offload_folder, exist_ok=True)
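    # QLoRA recipe (when enabled): the frozen base weights are stored in 4-bit NF4 with
    # double quantization while compute runs in bfloat16; only the LoRA adapters added
    # below are trained.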
    # bitsandbytes 4-bit quantization only works on CUDA GPUs; skip it elsewhere so
    # CPU/MPS runs do not crash.
    use_4bit = cfg.use_4bit and torch.cuda.is_available()
    quantization_config = None
    if use_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=False)
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token
# Choose device map and offload strategy based on desired/available hardware.
device_map: Optional[Dict[str, str] | str]
offload_folder = cfg.offload_folder
torch_dtype = None
# Respect user override.
if cfg.torch_dtype != "auto":
torch_dtype = {
"float16": torch.float16,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}[cfg.torch_dtype]
if cfg.device == "cuda" and torch.cuda.is_available():
device_map = "auto"
# Let quantization/auto-cast handle dtype on CUDA.
elif cfg.device == "mps" and torch.backends.mps.is_available():
device_map = {"": "mps"}
torch_dtype = torch_dtype or torch.float16 # avoid bf16 on MPS
offload_folder = None # avoid disk offload on MPS
elif cfg.device == "cpu":
device_map = {"": "cpu"}
torch_dtype = torch.float32
offload_folder = None
    else:
        # auto: prefer CUDA, else MPS, else CPU (keep any user dtype override).
        if torch.cuda.is_available():
            device_map = "auto"
        elif torch.backends.mps.is_available():
            device_map = {"": "mps"}
            torch_dtype = torch_dtype or torch.float16
            offload_folder = None
        else:
            device_map = {"": "cpu"}
            torch_dtype = torch_dtype or torch.float32
            offload_folder = None
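    # device_map="auto" lets accelerate spread layers across GPU/CPU (spilling to disk in
    # offload_folder if needed), while an explicit {"": device} map keeps the whole model
    # on a single device.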
model = AutoModelForCausalLM.from_pretrained(
cfg.model_name,
quantization_config=quantization_config,
device_map=device_map,
offload_folder=offload_folder,
use_safetensors=True,
torch_dtype=torch_dtype,
)
    if use_4bit:
        # Freeze the quantized base weights and prepare them for k-bit (LoRA) training.
        model = prepare_model_for_kbit_training(model)
lora_cfg = LoraConfig(
r=64,
lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_cfg)
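    # Optional sanity check: report how many parameters LoRA leaves trainable
    # (a small fraction of the base model).
    model.print_trainable_parameters()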
return model, tokenizer
def init_wandb(cfg: ScriptConfig):
project = os.getenv("WANDB_PROJECT", "tulu-laptop-run")
entity = os.getenv("WANDB_ENTITY")
api_key = os.getenv("WANDB_API_KEY")
if not api_key:
raise RuntimeError("WANDB_API_KEY is missing. Put it in your .env before running.")
wandb.login(key=api_key)
wandb.init(project=project, entity=entity, config=vars(cfg))
def parse_args() -> ScriptConfig:
parser = argparse.ArgumentParser(description="Finetune Tulu with QLoRA + W&B")
parser.add_argument("--model_name", default=ScriptConfig.model_name)
parser.add_argument("--dataset_name", default=ScriptConfig.dataset_name)
parser.add_argument("--output_dir", default=ScriptConfig.output_dir)
parser.add_argument("--offload_folder", default=ScriptConfig.offload_folder)
parser.add_argument(
"--device",
default=ScriptConfig.device,
choices=["auto", "cpu", "mps", "cuda"],
help="Force device placement (default auto).",
)
parser.add_argument(
"--torch_dtype",
default=ScriptConfig.torch_dtype,
choices=["auto", "float16", "float32", "bfloat16"],
help="Force torch dtype (default auto). On MPS use float16.",
)
parser.add_argument(
"--cpu_threads",
type=int,
default=ScriptConfig.cpu_threads,
help="Limit CPU threads when running on CPU (default 4) to avoid overloading.",
)
parser.add_argument("--instruction_field", default=ScriptConfig.instruction_field)
parser.add_argument("--input_field", default=ScriptConfig.input_field)
parser.add_argument("--output_field", default=ScriptConfig.output_field)
parser.add_argument("--max_seq_length", type=int, default=ScriptConfig.max_seq_length)
parser.add_argument("--per_device_batch_size", type=int, default=ScriptConfig.per_device_batch_size)
parser.add_argument("--gradient_accumulation_steps", type=int, default=ScriptConfig.gradient_accumulation_steps)
parser.add_argument("--num_train_epochs", type=float, default=ScriptConfig.num_train_epochs)
parser.add_argument("--learning_rate", type=float, default=ScriptConfig.learning_rate)
parser.add_argument("--warmup_ratio", type=float, default=ScriptConfig.warmup_ratio)
parser.add_argument("--logging_steps", type=int, default=ScriptConfig.logging_steps)
parser.add_argument("--save_steps", type=int, default=ScriptConfig.save_steps)
parser.add_argument("--use_4bit", action=argparse.BooleanOptionalAction, default=False)
args = parser.parse_args()
return ScriptConfig(**vars(args))
def configure_cache_from_env():
"""Allow user to redirect HF cache (models + datasets) via BASE_MODEL_CACHE env."""
cache_dir = os.getenv("BASE_MODEL_CACHE")
if cache_dir:
os.environ.setdefault("HF_HOME", cache_dir)
os.environ.setdefault("TRANSFORMERS_CACHE", cache_dir)
os.environ.setdefault("HF_DATASETS_CACHE", cache_dir)
def main():
    # Load env vars (WANDB keys, optional cache path).
    load_dotenv()
    # Redirect HF cache if BASE_MODEL_CACHE is set.
    configure_cache_from_env()
    # Read CLI hyperparameters/model settings.
    cfg = parse_args()
    # Start a W&B run with config and login.
    init_wandb(cfg)
    # Load base model + tokenizer with LoRA (and 4-bit if enabled).
    model, tokenizer = load_model_and_tokenizer(cfg)
    is_mps = torch.backends.mps.is_available()
    force_cpu = cfg.device == "cpu"
    force_cuda = cfg.device == "cuda"
    if force_cpu:
        # Prevent saturating all cores on CPU runs.
        torch.set_num_threads(max(1, cfg.cpu_threads))
    # Choose the best available mixed precision (bf16 > fp16 > fp32); force fp32 on MPS/CPU.
    wants_cuda = force_cuda or cfg.device == "auto"
    use_bf16 = (
        wants_cuda and not force_cpu and not is_mps
        and torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    )
    use_fp16 = (
        wants_cuda and not force_cpu and not is_mps
        and torch.cuda.is_available() and not use_bf16
    )
    precision_mode = "bf16" if use_bf16 else "fp16" if use_fp16 else "fp32"
    # Download/load the instruction dataset.
    raw_dataset = load_dataset(cfg.dataset_name)
    # Format/tokenize the dataset to fixed length with labels, timing it for W&B.
    tokenize_start = time.time()
    tokenized = raw_dataset["train"].map(
        lambda ex: tokenize_example(ex, tokenizer, cfg),
        remove_columns=raw_dataset["train"].column_names,
    )
    tokenize_duration = time.time() - tokenize_start
    wandb.log({"tokenization_duration_seconds": tokenize_duration})
train_examples = len(tokenized)
total_tokens = train_examples * cfg.max_seq_length
wandb.summary.update(
{
"train_examples": train_examples,
"estimated_tokens": total_tokens,
"precision_mode": precision_mode,
"use_4bit": cfg.use_4bit,
"model_name": cfg.model_name,
"dataset_name": cfg.dataset_name,
"per_device_batch_size": cfg.per_device_batch_size,
"gradient_accumulation_steps": cfg.gradient_accumulation_steps,
"max_seq_length": cfg.max_seq_length,
}
)
    # Pad/batch causal LM examples.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
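    # Note: with mlm=False the collator rebuilds labels from input_ids and sets pad
    # positions to -100, so padded tokens should not contribute to the loss.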
    # Choose optimizer: bitsandbytes paged_adamw_32bit only for 4-bit training on CUDA;
    # fall back to plain AdamW on CPU/MPS or when 4-bit is disabled.
    optim_name = "paged_adamw_32bit" if cfg.use_4bit and torch.cuda.is_available() else "adamw_torch"
    # Trainer configuration (logging, saving, optimizer, precision).
    training_args = TrainingArguments(
output_dir=cfg.output_dir,
per_device_train_batch_size=cfg.per_device_batch_size,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_train_epochs,
learning_rate=cfg.learning_rate,
warmup_ratio=cfg.warmup_ratio,
logging_steps=cfg.logging_steps,
save_steps=cfg.save_steps,
bf16=use_bf16,
fp16=use_fp16,
report_to=["wandb"],
optim=optim_name,
)
    # Wire model, data, and config into the HF Trainer.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    # Run supervised finetuning (cross-entropy) and record wall-clock training time to W&B.
    train_start = time.time()
    trainer.train()
    train_duration = time.time() - train_start
    wandb.log({"train_duration_seconds": train_duration})
    # Save the LoRA adapters and tokenizer to output_dir.
    trainer.save_model(cfg.output_dir)
    tokenizer.save_pretrained(cfg.output_dir)
    # Close the W&B run.
    wandb.finish()
if __name__ == "__main__":
main()