| """QLoRA training pipeline for BLUX-cA adapters.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, List, Optional | |
| import torch | |
| import yaml | |
| from datasets import load_dataset | |
| from peft import LoraConfig | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| BitsAndBytesConfig, | |
| TrainingArguments, | |
| ) | |
| from trl import SFTTrainer | |
| from prepare_dataset import prepare_dataset | |
| from validate_dataset import run_cli_validator, validate_file | |
| EXAMPLE_DATASET_CMD = "export DATASET_DIR=/absolute/path/to/blux-ca-dataset" | |
| def _load_yaml(path: Path) -> Dict: | |
| with path.open("r", encoding="utf-8") as handle: | |
| return yaml.safe_load(handle) | |
| def _write_json(path: Path, payload: Dict) -> None: | |
| with path.open("w", encoding="utf-8") as handle: | |
| json.dump(payload, handle, indent=2, sort_keys=True) | |
| def _resolve_dataset_dir(raw: Optional[Path]) -> Path: | |
| if raw: | |
| return raw | |
| env_dir = os.environ.get("DATASET_DIR") | |
| if env_dir: | |
| return Path(env_dir) | |
| raise ValueError( | |
| f"Dataset directory is required. Provide --dataset-dir or set DATASET_DIR (e.g., {EXAMPLE_DATASET_CMD})" | |
| ) | |
| def _resolve_base_model(cfg: Dict, prefer_cpu_safe: bool = False) -> str: | |
| env_base_model = os.environ.get("BASE_MODEL") | |
| if env_base_model: | |
| return env_base_model | |
| if prefer_cpu_safe: | |
| return cfg.get("cpu_base_model", cfg.get("base_model")) | |
| return cfg.get("base_model") | |
| def _validate_sources(dataset_dir: Path, mix_config: Path) -> None: | |
| mix_cfg = _load_yaml(mix_config) | |
| data_dir = dataset_dir / "data" | |
| errors: List[str] = [] | |
| for source in mix_cfg.get("sources", []): | |
| path = data_dir / source.get("file", "") | |
| if not path.exists(): | |
| errors.append(f"Missing file: {path}") | |
| continue | |
| _, _, file_errors = validate_file(path, strict=True) | |
| errors.extend(file_errors) | |
| if errors: | |
| joined = "\n".join(errors) | |
| raise ValueError(f"Dataset validation failed:\n{joined}") | |
| def _format_messages(messages: List[Dict], tokenizer) -> str: | |
| if hasattr(tokenizer, "apply_chat_template"): | |
| return tokenizer.apply_chat_template(messages, tokenize=False) | |
| parts = [] | |
| for msg in messages: | |
| role = msg.get("role", "") | |
| content = msg.get("content", "") | |
| parts.append(f"[{role}] {content}") | |
| return "\n".join(parts) | |
| def _build_dataset(prepared_path: Path, tokenizer): | |
| dataset = load_dataset("json", data_files=str(prepared_path))["train"] | |
| def add_text(example): | |
| example["text"] = _format_messages(example["messages"], tokenizer) | |
| return example | |
| text_dataset = dataset.map(add_text, remove_columns=[]) | |
| return text_dataset | |
| def _init_model(base_model: str, lora_config: Dict) -> AutoModelForCausalLM: | |
| quant_config = BitsAndBytesConfig( | |
| load_in_4bit=True, | |
| bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, | |
| bnb_4bit_use_double_quant=True, | |
| bnb_4bit_quant_type="nf4", | |
| ) | |
| model = AutoModelForCausalLM.from_pretrained( | |
| base_model, | |
| quantization_config=quant_config, | |
| device_map="auto", | |
| ) | |
| peft_config = LoraConfig( | |
| r=int(lora_config.get("r", 16)), | |
| lora_alpha=int(lora_config.get("alpha", 32)), | |
| target_modules=lora_config.get("target_modules", []), | |
| lora_dropout=float(lora_config.get("dropout", 0.05)), | |
| bias="none", | |
| task_type="CAUSAL_LM", | |
| ) | |
| model.add_adapter(peft_config) | |
| return model | |
| def train(args: argparse.Namespace) -> Path: | |
| dataset_dir = _resolve_dataset_dir(args.dataset_dir) | |
| if not dataset_dir.exists(): | |
| raise FileNotFoundError( | |
| f"Dataset directory not found: {dataset_dir}. Set DATASET_DIR first (e.g., `{EXAMPLE_DATASET_CMD}`)." | |
| ) | |
| if not args.config.exists(): | |
| raise FileNotFoundError(f"Config not found: {args.config}") | |
| if not args.mix_config.exists(): | |
| raise FileNotFoundError(f"Mix config not found: {args.mix_config}") | |
| qlora_cfg = _load_yaml(args.config) | |
| mix_config = args.mix_config | |
| prefer_cpu_safe = args.dry_run and not torch.cuda.is_available() and not os.environ.get("BASE_MODEL") | |
| qlora_cfg["base_model"] = _resolve_base_model(qlora_cfg, prefer_cpu_safe=prefer_cpu_safe) | |
| validation_errors = run_cli_validator(dataset_dir) | |
| if validation_errors: | |
| raise ValueError("\n".join(validation_errors)) | |
| _validate_sources(dataset_dir, mix_config) | |
| prepared_path = prepare_dataset(dataset_dir, mix_config, args.output_root, run_name=args.run_name, strict=args.strict) | |
| run_dir = prepared_path.parent | |
| tokenizer = AutoTokenizer.from_pretrained(qlora_cfg["base_model"], use_fast=True) | |
| tokenizer.padding_side = "right" | |
| tokenizer.pad_token = tokenizer.eos_token | |
| train_dataset = _build_dataset(prepared_path, tokenizer) | |
| if args.dry_run: | |
| sample = train_dataset.select(range(min(5, len(train_dataset)))) | |
| _ = tokenizer( | |
| sample["text"], | |
| max_length=qlora_cfg["max_seq_length"], | |
| truncation=True, | |
| padding="longest", | |
| ) | |
| print("Dry-run successful: model and tokenizer loaded; tokenization ok.") | |
| return run_dir | |
| model = _init_model(qlora_cfg["base_model"], qlora_cfg["lora"]) | |
| training_args = TrainingArguments( | |
| output_dir=str(run_dir / "adapter_model"), | |
| num_train_epochs=int(qlora_cfg["epochs"]), | |
| per_device_train_batch_size=int(qlora_cfg["per_device_train_batch_size"]), | |
| gradient_accumulation_steps=int(qlora_cfg["gradient_accumulation_steps"]), | |
| learning_rate=float(qlora_cfg["learning_rate"]), | |
| warmup_ratio=float(qlora_cfg["warmup_ratio"]), | |
| logging_steps=10, | |
| save_strategy="epoch", | |
| bf16=bool(qlora_cfg.get("bf16", False)), | |
| fp16=bool(qlora_cfg.get("fp16", False)), | |
| gradient_checkpointing=True, | |
| report_to=[], | |
| seed=int(qlora_cfg.get("seed", 42)), | |
| ) | |
| trainer = SFTTrainer( | |
| model=model, | |
| tokenizer=tokenizer, | |
| train_dataset=train_dataset, | |
| dataset_text_field="text", | |
| packing=False, | |
| max_seq_length=qlora_cfg["max_seq_length"], | |
| args=training_args, | |
| ) | |
| trainer.train() | |
| trainer.model.save_pretrained(training_args.output_dir) | |
| tokenizer.save_pretrained(training_args.output_dir) | |
| _write_json(run_dir / "training_args.json", training_args.to_dict()) | |
| with (run_dir / "config_snapshot.yaml").open("w", encoding="utf-8") as handle: | |
| yaml.safe_dump( | |
| {"qlora": qlora_cfg, "mix_config": _load_yaml(mix_config)}, | |
| handle, | |
| sort_keys=False, | |
| ) | |
| print(f"Training complete. Adapter saved to {training_args.output_dir}") | |
| return run_dir | |
| def parse_args() -> argparse.Namespace: | |
| parser = argparse.ArgumentParser(description="Train a BLUX-cA QLoRA adapter") | |
| parser.add_argument( | |
| "--dataset-dir", | |
| required=False, | |
| type=Path, | |
| default=Path(os.environ["DATASET_DIR"]) if os.environ.get("DATASET_DIR") else None, | |
| help="Path to dataset repository (or set DATASET_DIR)", | |
| ) | |
| parser.add_argument("--config", type=Path, default=Path("train/configs/qlora.yaml"), help="QLoRA config path") | |
| parser.add_argument("--mix-config", type=Path, default=Path("train/configs/dataset_mix.yaml"), help="Dataset mixing config") | |
| parser.add_argument("--output-root", type=Path, default=Path("runs"), help="Root directory for outputs") | |
| parser.add_argument("--run-name", type=str, default=os.environ.get("RUN_NAME"), help="Optional run folder name") | |
| parser.add_argument("--dry-run", action="store_true", help="Load model/tokenizer and tokenize sample without training") | |
| parser.add_argument("--strict", action="store_true", help="Validate dataset strictly before mixing") | |
| return parser.parse_args() | |
| if __name__ == "__main__": | |
| cli_args = parse_args() | |
| try: | |
| train(cli_args) | |
| except (FileNotFoundError, ValueError) as exc: | |
| print(exc) | |
| raise SystemExit(1) | |