"""Prepare a weighted, shuffled training set for BLUX-cA QLoRA.""" from __future__ import annotations import argparse import json import random from datetime import datetime from pathlib import Path from typing import Dict, List, Optional import yaml from validate_dataset import SYSTEM_PLACEHOLDER, validate_dataset def _timestamp() -> str: return datetime.utcnow().strftime("%Y%m%d_%H%M%S") def _load_config(path: Path) -> Dict: with path.open("r", encoding="utf-8") as handle: return yaml.safe_load(handle) def _load_jsonl(path: Path) -> List[Dict]: records: List[Dict] = [] with path.open("r", encoding="utf-8") as handle: for line in handle: if line.strip(): records.append(json.loads(line)) return records def _sample_records(records: List[Dict], target: int, rng: random.Random) -> List[Dict]: if not records: return [] if target <= len(records): return rng.sample(records, target) return [rng.choice(records) for _ in range(target)] def prepare_dataset( dataset_dir: Path, mix_config: Path, output_root: Path, run_name: Optional[str] = None, override_max_samples: Optional[int] = None, strict: bool = False, ) -> Path: """Create a weighted, deterministic training mix. If ``override_max_samples`` is provided it takes precedence over the YAML ``max_samples`` value. When ``strict`` is True, data files are validated before mixing. """ if strict: _, errors = validate_dataset(dataset_dir, strict=True) if errors: joined = "\n".join(errors) raise ValueError(f"Dataset validation failed before mixing:\n{joined}") config = _load_config(mix_config) sources = config.get("sources", []) shuffle = bool(config.get("shuffle", True)) max_samples = override_max_samples if override_max_samples is not None else config.get("max_samples") seed = int(config.get("seed", 42)) rng = random.Random(seed) total_weight = sum(float(src.get("weight", 1.0)) for src in sources) if total_weight <= 0: raise ValueError("Total weight must be positive") collected: List[Dict] = [] for src in sources: raw_path = Path(src["file"]) if raw_path.is_absolute(): file_path = raw_path elif raw_path.parts and raw_path.parts[0] == "data": file_path = dataset_dir / raw_path else: file_path = dataset_dir / "data" / raw_path if not file_path.exists(): raise FileNotFoundError(f"Missing dataset file: {file_path}") weight = float(src.get("weight", 1.0)) records = _load_jsonl(file_path) if max_samples is None: target = max(len(records), 0) else: target = max(1, round((weight / total_weight) * max_samples)) sampled = _sample_records(records, target, rng) collected.extend(sampled) if not collected: raise ValueError("No samples collected from provided sources") if shuffle: rng.shuffle(collected) folder_name = _timestamp() if not run_name else f"{_timestamp()}_{run_name}" run_dir = output_root / folder_name run_dir.mkdir(parents=True, exist_ok=True) output_path = run_dir / "prepared_train.jsonl" with output_path.open("w", encoding="utf-8") as handle: for record in collected: if "messages" in record: system_msgs = [m for m in record["messages"] if m.get("role") == "system"] if system_msgs: system_msgs[0]["content"] = SYSTEM_PLACEHOLDER handle.write(json.dumps(record, ensure_ascii=False) + "\n") resolved_mix_path = run_dir / "mix_config_resolved.yaml" with resolved_mix_path.open("w", encoding="utf-8") as handle: yaml.safe_dump( { **config, "max_samples": max_samples, "seed": seed, "dataset_dir": str(dataset_dir), }, handle, sort_keys=False, ) return output_path def main() -> int: parser = argparse.ArgumentParser(description="Prepare weighted training data for QLoRA") parser.add_argument("--dataset-dir", required=True, type=Path, help="Path to dataset repository") parser.add_argument("--mix-config", type=Path, default=Path("train/configs/dataset_mix.yaml"), help="Mixing config YAML") parser.add_argument("--output-root", type=Path, default=Path("runs"), help="Root directory for run outputs") parser.add_argument("--run-name", type=str, default=None, help="Optional run folder name (otherwise timestamp)") parser.add_argument("--max-samples", type=int, default=None, help="Override max_samples in config for quick smoke runs") parser.add_argument("--strict", action="store_true", help="Validate input files strictly before mixing") args = parser.parse_args() if not args.dataset_dir.exists(): print(f"Dataset directory not found: {args.dataset_dir}") return 1 output_path = prepare_dataset( args.dataset_dir, args.mix_config, args.output_root, run_name=args.run_name, override_max_samples=args.max_samples, strict=args.strict, ) print(f"Prepared dataset written to {output_path}") return 0 if __name__ == "__main__": raise SystemExit(main())