| |
| """ |
| Minimal QLoRA finetune for a laptop-friendly Tulu checkpoint with W&B logging. |
| |
| Defaults aim to run on a single consumer GPU using 4-bit quantization. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import os |
| import time |
| from dataclasses import dataclass |
| from typing import Dict, List, Optional |
|
|
| import torch |
| import wandb |
| from datasets import load_dataset |
| from dotenv import load_dotenv |
| from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training |
| from transformers import ( |
| AutoModelForCausalLM, |
| AutoTokenizer, |
| BitsAndBytesConfig, |
| DataCollatorForLanguageModeling, |
| Trainer, |
| TrainingArguments, |
| ) |
|
|
|
|
@dataclass
class ScriptConfig:
    """All runtime knobs for the finetune; mirrored one-to-one by the CLI flags in parse_args()."""

    # Hugging Face hub id of the base checkpoint to finetune.
    model_name: str = "allenai/tulu-2-7b"
    # Hub id of the training dataset.
    dataset_name: str = "mlabonne/guanaco-llama2-1k"
    # Where checkpoints and the final adapter are written.
    output_dir: str = "outputs/tulu-lora"
    # Disk folder used by accelerate when layers are offloaded.
    offload_folder: str = "offload"
    # One of: auto / cpu / mps / cuda (see load_model_and_tokenizer).
    device: str = "auto"
    # One of: auto / float16 / float32 / bfloat16.
    torch_dtype: str = "auto"
    # Thread cap applied only when training on CPU.
    cpu_threads: int = 4
    # Dataset column names fed into the prompt template (format_chat).
    instruction_field: str = "instruction"
    input_field: str = "input"
    output_field: str = "output"
    # Tokenizer truncation/padding length per example.
    max_seq_length: int = 512
    per_device_batch_size: int = 1
    gradient_accumulation_steps: int = 16
    num_train_epochs: int = 1
    learning_rate: float = 2e-4
    warmup_ratio: float = 0.03
    logging_steps: int = 10
    save_steps: int = 200
    # Enable 4-bit NF4 quantization (QLoRA).
    use_4bit: bool = True
|
|
|
|
def format_chat(example: Dict[str, str], cfg: ScriptConfig) -> str:
    """Render one dataset row into the instruction/input/response prompt used for Tulu-style tuning.

    Falls back to a single ``text`` column when the configured instruction or
    output column is absent; raises KeyError when neither is available.
    """
    fallback = example.get("text")
    instruction = example.get(cfg.instruction_field)
    if instruction is None:
        instruction = fallback
    response = example.get(cfg.output_field)
    if response is None:
        response = fallback

    if instruction is None or response is None:
        # Name exactly the columns that could not be resolved so the error is actionable.
        missing = [
            field
            for field, value in ((cfg.instruction_field, instruction), (cfg.output_field, response))
            if value is None
        ]
        raise KeyError(
            f"Dataset is missing '{'/'.join(missing)}'. Available fields: {', '.join(example.keys())}. "
            "Use --instruction_field/--input_field/--output_field to match your dataset, "
            "or set both instruction/output to 'text' for single-text datasets."
        )

    # Optional context column; "N/A" placeholder when missing or empty.
    extra = example.get(cfg.input_field) or "N/A"
    return f"### Instruction:\n{instruction}\n\n### Input:\n{extra}\n\n### Response:\n{response}"
|
|
|
|
def tokenize_example(example: Dict[str, str], tokenizer, cfg: ScriptConfig):
    """Tokenize one rendered prompt into fixed-length features for causal-LM training."""
    rendered = format_chat(example, cfg)
    encoded = tokenizer(
        rendered,
        truncation=True,
        max_length=cfg.max_seq_length,
        padding="max_length",
    )
    # Causal LM objective: the model predicts the same token ids it reads.
    # NOTE(review): labels on pad positions are not masked to -100 here; since
    # pad_token == eos_token upstream, loss is also computed on padding — confirm intended.
    encoded["labels"] = list(encoded["input_ids"])
    return encoded
|
|
|
|
def load_model_and_tokenizer(cfg: ScriptConfig):
    """Load the (optionally 4-bit) base model plus tokenizer and wrap the model with LoRA adapters.

    Returns:
        (model, tokenizer): the PEFT-wrapped causal LM and its tokenizer.
    """
    # Offload dir must exist before from_pretrained may spill layers to disk.
    os.makedirs(cfg.offload_folder, exist_ok=True)
    quantization_config = None
    if cfg.use_4bit:
        # NF4 + double quantization with bf16 compute — the usual QLoRA recipe.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=False)
    tokenizer.padding_side = "right"
    # Llama-style tokenizers ship without a pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

    # Resolve device placement and dtype. Later branches deliberately override
    # the dtype chosen from --torch_dtype, so keep statement order intact.
    device_map: Optional[Dict[str, str] | str]
    offload_folder = cfg.offload_folder
    torch_dtype = None

    if cfg.torch_dtype != "auto":
        torch_dtype = {
            "float16": torch.float16,
            "float32": torch.float32,
            "bfloat16": torch.bfloat16,
        }[cfg.torch_dtype]
    if cfg.device == "cuda" and torch.cuda.is_available():
        # Let accelerate shard across CUDA devices (with disk offload allowed).
        device_map = "auto"

    elif cfg.device == "mps" and torch.backends.mps.is_available():
        device_map = {"": "mps"}
        torch_dtype = torch_dtype or torch.float16  # fp16 unless user forced a dtype
        offload_folder = None
    elif cfg.device == "cpu":
        device_map = {"": "cpu"}
        torch_dtype = torch.float32  # NOTE(review): overrides an explicit --torch_dtype on CPU
        offload_folder = None
    else:
        # --device auto, or a requested accelerator that is unavailable:
        # probe CUDA, then MPS, then fall back to CPU.
        if torch.cuda.is_available():
            device_map = "auto"
            torch_dtype = None  # let transformers pick the checkpoint dtype
        elif torch.backends.mps.is_available():
            device_map = {"": "mps"}
            torch_dtype = torch.float16
            offload_folder = None
        else:
            device_map = {"": "cpu"}
            torch_dtype = torch.float32
            offload_folder = None

    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_name,
        quantization_config=quantization_config,
        device_map=device_map,
        offload_folder=offload_folder,
        use_safetensors=True,
        torch_dtype=torch_dtype,
    )
    if cfg.use_4bit:
        # Standard PEFT preparation step for k-bit (quantized) training.
        model = prepare_model_for_kbit_training(model)

    # Attention-projection-only LoRA; adapter size governed by r/alpha below.
    lora_cfg = LoraConfig(
        r=64,
        lora_alpha=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_cfg)
    return model, tokenizer
|
|
|
|
def init_wandb(cfg: ScriptConfig):
    """Authenticate with Weights & Biases from env vars and start a run tagged with the script config.

    Raises:
        RuntimeError: when WANDB_API_KEY is not set.
    """
    api_key = os.getenv("WANDB_API_KEY")
    if not api_key:
        raise RuntimeError("WANDB_API_KEY is missing. Put it in your .env before running.")
    wandb.login(key=api_key)
    wandb.init(
        project=os.getenv("WANDB_PROJECT", "tulu-laptop-run"),
        entity=os.getenv("WANDB_ENTITY"),
        config=vars(cfg),
    )
|
|
|
|
def parse_args() -> ScriptConfig:
    """Parse CLI flags into a ScriptConfig.

    Every flag defaults to the corresponding ScriptConfig field so the CLI
    and the dataclass cannot drift apart.
    """
    parser = argparse.ArgumentParser(description="Finetune Tulu with QLoRA + W&B")
    parser.add_argument("--model_name", default=ScriptConfig.model_name)
    parser.add_argument("--dataset_name", default=ScriptConfig.dataset_name)
    parser.add_argument("--output_dir", default=ScriptConfig.output_dir)
    parser.add_argument("--offload_folder", default=ScriptConfig.offload_folder)
    parser.add_argument(
        "--device",
        default=ScriptConfig.device,
        choices=["auto", "cpu", "mps", "cuda"],
        help="Force device placement (default auto).",
    )
    parser.add_argument(
        "--torch_dtype",
        default=ScriptConfig.torch_dtype,
        choices=["auto", "float16", "float32", "bfloat16"],
        help="Force torch dtype (default auto). On MPS use float16.",
    )
    parser.add_argument(
        "--cpu_threads",
        type=int,
        default=ScriptConfig.cpu_threads,
        help="Limit CPU threads when running on CPU (default 4) to avoid overloading.",
    )
    parser.add_argument("--instruction_field", default=ScriptConfig.instruction_field)
    parser.add_argument("--input_field", default=ScriptConfig.input_field)
    parser.add_argument("--output_field", default=ScriptConfig.output_field)
    parser.add_argument("--max_seq_length", type=int, default=ScriptConfig.max_seq_length)
    parser.add_argument("--per_device_batch_size", type=int, default=ScriptConfig.per_device_batch_size)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=ScriptConfig.gradient_accumulation_steps)
    # float so fractional epochs (e.g. 0.1 for a smoke test) are accepted.
    parser.add_argument("--num_train_epochs", type=float, default=ScriptConfig.num_train_epochs)
    parser.add_argument("--learning_rate", type=float, default=ScriptConfig.learning_rate)
    parser.add_argument("--warmup_ratio", type=float, default=ScriptConfig.warmup_ratio)
    parser.add_argument("--logging_steps", type=int, default=ScriptConfig.logging_steps)
    parser.add_argument("--save_steps", type=int, default=ScriptConfig.save_steps)
    # BUGFIX: the default was hard-coded False, silently contradicting
    # ScriptConfig.use_4bit (True) and the module docstring's 4-bit claim.
    # BooleanOptionalAction still provides --no-use_4bit to disable explicitly.
    parser.add_argument("--use_4bit", action=argparse.BooleanOptionalAction, default=ScriptConfig.use_4bit)
    args = parser.parse_args()
    return ScriptConfig(**vars(args))
|
|
|
|
def configure_cache_from_env():
    """Point the HF caches (hub, transformers, datasets) at BASE_MODEL_CACHE when it is set.

    Uses setdefault so explicitly-exported cache variables always win.
    """
    target = os.getenv("BASE_MODEL_CACHE")
    if not target:
        return
    for var in ("HF_HOME", "TRANSFORMERS_CACHE", "HF_DATASETS_CACHE"):
        os.environ.setdefault(var, target)
|
|
|
|
def main():
    """End-to-end run: env setup -> W&B -> model/dataset -> tokenize -> train -> save."""
    # .env must load before anything reads WANDB_* / BASE_MODEL_CACHE.
    load_dotenv()

    configure_cache_from_env()

    cfg = parse_args()

    # Fail fast on a missing API key before downloading the model.
    init_wandb(cfg)

    model, tokenizer = load_model_and_tokenizer(cfg)

    # Device flags used for precision/optimizer selection below.
    # NOTE(review): force_mps is currently unused.
    is_mps = torch.backends.mps.is_available()
    force_cpu = cfg.device == "cpu"
    force_mps = cfg.device == "mps"
    force_cuda = cfg.device == "cuda"

    if cfg.device == "cpu":
        # Cap the thread pool so a laptop stays responsive during training.
        torch.set_num_threads(max(1, cfg.cpu_threads))

    # Mixed precision only on CUDA (explicit or auto-detected); bf16 preferred when supported.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported() and not is_mps and not force_cpu and (force_cuda or cfg.device == "auto")
    use_fp16 = torch.cuda.is_available() and not use_bf16 and not is_mps and not force_cpu and (force_cuda or cfg.device == "auto")

    precision_mode = "bf16" if use_bf16 else "fp16" if use_fp16 else "fp32"

    raw_dataset = load_dataset(cfg.dataset_name)

    tokenize_start = time.time()
    # Drop the raw columns so only input_ids/attention_mask/labels reach the Trainer.
    tokenized = raw_dataset["train"].map(
        lambda ex: tokenize_example(ex, tokenizer, cfg),
        remove_columns=raw_dataset["train"].column_names,
    )
    tokenize_duration = time.time() - tokenize_start
    wandb.log({"tokenization_duration_seconds": tokenize_duration})

    # Rough token count: every example is padded/truncated to max_seq_length.
    train_examples = len(tokenized)
    total_tokens = train_examples * cfg.max_seq_length
    wandb.summary.update(
        {
            "train_examples": train_examples,
            "estimated_tokens": total_tokens,
            "precision_mode": precision_mode,
            "use_4bit": cfg.use_4bit,
            "model_name": cfg.model_name,
            "dataset_name": cfg.dataset_name,
            "per_device_batch_size": cfg.per_device_batch_size,
            "gradient_accumulation_steps": cfg.gradient_accumulation_steps,
            "max_seq_length": cfg.max_seq_length,
        }
    )

    # mlm=False -> plain causal-LM collation.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Paged 8-bit-style optimizer only makes sense with bitsandbytes on an accelerator.
    optim_name = "paged_adamw_32bit" if cfg.use_4bit and not force_cpu else "adamw_torch"

    training_args = TrainingArguments(
        output_dir=cfg.output_dir,
        per_device_train_batch_size=cfg.per_device_batch_size,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        num_train_epochs=cfg.num_train_epochs,
        learning_rate=cfg.learning_rate,
        warmup_ratio=cfg.warmup_ratio,
        logging_steps=cfg.logging_steps,
        save_steps=cfg.save_steps,
        bf16=use_bf16,
        fp16=use_fp16,
        report_to=["wandb"],
        optim=optim_name,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    train_start = time.time()
    trainer.train()

    train_duration = time.time() - train_start
    wandb.log({"train_duration_seconds": train_duration})

    # Persist the LoRA adapter + tokenizer for later merging/inference.
    trainer.save_model(cfg.output_dir)
    tokenizer.save_pretrained(cfg.output_dir)

    wandb.finish()
| |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|