"""Dataset validation for BLUX-cA QLoRA pipeline.

Checks that training JSONL files conform to the expected message schema and
contain the BLUX-cA system placeholder.
"""
from __future__ import annotations

import argparse
import importlib.util
import json
import subprocess
import sys
from pathlib import Path
from typing import Callable, Dict, Iterator, List, Optional, Tuple

# Exact content the first system message of every record must carry.
# NOTE(review): the value is an empty string. If the placeholder was meant to be
# a tag-like token it may have been lost upstream - confirm against the dataset
# repository before relying on this check.
SYSTEM_PLACEHOLDER = ""


def run_cli_validator(dataset_dir: Path, files: Optional[List[Path]] = None) -> List[str]:
    """Invoke the dataset repository's validator script via subprocess.

    Args:
        dataset_dir: Root of the dataset repository; the script is looked up at
            ``tools/validate_jsonl.py`` beneath it.
        files: Optional files to pass to the script as positional arguments.

    Returns:
        An empty list when the script is absent or exits with code 0. Otherwise
        the non-blank lines of its combined stdout/stderr, or a generic
        two-line message when it printed nothing.
    """
    validator_path = dataset_dir / "tools" / "validate_jsonl.py"
    if not validator_path.exists():
        return []

    # The child process runs with cwd=dataset_dir, so every path handed to it
    # must be absolute or relative to that directory. Paths relative to the
    # caller's working directory would point at the wrong location.
    root = dataset_dir.absolute()
    script = validator_path.absolute()

    rel_files: List[str] = []
    for file_path in files or []:
        absolute = file_path.absolute()
        try:
            rel_files.append(str(absolute.relative_to(root)))
        except ValueError:
            # Not located under the dataset directory: pass it through absolute.
            rel_files.append(str(absolute))

    cmd = [sys.executable, str(script), *rel_files]
    result = subprocess.run(cmd, capture_output=True, text=True, cwd=root)
    if result.returncode == 0:
        return []

    output = (result.stdout + "\n" + result.stderr).strip()
    return [line for line in output.splitlines() if line.strip()] or [
        f"Validator exited with code {result.returncode}",
        f"Re-run manually: python {validator_path}",
    ]


def _load_external_validator(
    dataset_dir: Path,
) -> Optional[Callable[[List[Path]], Dict[Path, List[str]]]]:
    """Load dataset-provided validator if available.

    Returns a callable with signature List[Path] -> Dict[Path, List[str]]
    mapping file paths to lists of errors. If not available, returns None.
    "Not available" covers a missing script, an unloadable module spec, and a
    module that does not expose a callable ``validate_file``.
    """
    validator_path = dataset_dir / "tools" / "validate_jsonl.py"
    if not validator_path.exists():
        return None

    spec = importlib.util.spec_from_file_location("blux_ca_dataset_validator", validator_path)
    if not spec or not spec.loader:  # pragma: no cover - defensive
        return None

    module = importlib.util.module_from_spec(spec)
    sys.modules["blux_ca_dataset_validator"] = module
    spec.loader.exec_module(module)

    # Look the hook up once, here, so a script without it makes the caller fall
    # back to built-in validation instead of failing later with AttributeError.
    validate_one = getattr(module, "validate_file", None)
    if not callable(validate_one):
        return None

    def validate(files: List[Path]) -> Dict[Path, List[str]]:
        error_map: Dict[Path, List[str]] = {}
        for file_path in files:
            errs = validate_one(file_path)
            if errs:
                error_map[file_path] = errs
        return error_map

    return validate


def _iter_jsonl(path: Path) -> Iterator[Tuple[int, Dict]]:
    """Yield ``(line_number, parsed_value)`` for each non-blank line of *path*.

    Line numbers are 1-based and count blank lines, which are skipped.

    Raises:
        ValueError: On the first line that is not valid JSON; the message names
            the file and line.
    """
    with path.open("r", encoding="utf-8") as handle:
        for idx, line in enumerate(handle, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                yield idx, json.loads(line)
            except json.JSONDecodeError as exc:
                raise ValueError(f"{path} line {idx}: invalid JSON ({exc})") from exc


def _count_records(path: Path) -> int:
    """Count the records of *path* that parse, stopping at the first bad line."""
    count = 0
    try:
        for _ in _iter_jsonl(path):
            count += 1
    except ValueError:
        # Counting is informational only; it must not abort the run and
        # discard errors that have already been collected.
        pass
    return count


def _validate_messages(messages: List[Dict], path: Path, line_no: int, strict: bool) -> List[str]:
    """Check one record's ``messages`` value and return the problems found.

    Checks performed: the value is a non-empty list; the system, user and
    assistant roles each occur at least once; the first system message equals
    ``SYSTEM_PLACEHOLDER``; every message is an object with a role and content;
    string content passes the Audit Notes bullet check; and, when *strict* is
    set, the last message has the assistant role.
    """
    errors: List[str] = []
    if not isinstance(messages, list) or not messages:
        errors.append(f"{path} line {line_no}: 'messages' must be a non-empty list")
        return errors

    roles = [m.get("role") for m in messages if isinstance(m, dict)]
    for required in ("system", "user", "assistant"):
        if required not in roles:
            errors.append(f"{path} line {line_no}: missing {required} role")

    system_messages = [m for m in messages if isinstance(m, dict) and m.get("role") == "system"]
    if system_messages:
        system_content = system_messages[0].get("content", "")
        if system_content != SYSTEM_PLACEHOLDER:
            errors.append(
                f"{path} line {line_no}: system content must equal {SYSTEM_PLACEHOLDER!r}"
            )
    else:
        errors.append(f"{path} line {line_no}: system message missing")

    for m in messages:
        if not isinstance(m, dict):
            errors.append(f"{path} line {line_no}: each message must be an object")
            continue
        role = m.get("role")
        content = m.get("content")
        if not role:
            errors.append(f"{path} line {line_no}: message missing role")
        # The first system message is required to equal the placeholder, so it
        # must not also be reported as "missing content" when the placeholder
        # itself is falsy (an empty string).
        is_placeholder = (
            bool(system_messages)
            and m is system_messages[0]
            and content == SYSTEM_PLACEHOLDER
        )
        if not content and not is_placeholder:
            errors.append(f"{path} line {line_no}: message missing content for role {role}")
        # Audit Notes are a text convention; non-string content (None, numbers,
        # structured values) is skipped here rather than raising TypeError.
        if isinstance(content, str):
            errors.extend(_validate_audit_notes(content, path, line_no))

    if strict:
        last_role = messages[-1].get("role") if isinstance(messages[-1], dict) else None
        if last_role != "assistant":
            errors.append(f"{path} line {line_no}: strict mode requires assistant as last role")
    return errors


def _validate_audit_notes(content: str, path: Path, line_no: int) -> List[str]:
    """Check that lines following a ``## Audit Notes`` header are ``- `` bullets.

    The header must occupy a whole line to be recognised. Every non-blank line
    after it, through the end of *content*, that does not start with ``- ``
    (after stripping) produces one error.
    """
    lines = content.splitlines()
    try:
        header_index = lines.index("## Audit Notes")
    except ValueError:
        return []

    errors: List[str] = []
    for bullet in lines[header_index + 1 :]:
        stripped = bullet.strip()
        if stripped and not stripped.startswith("- "):
            errors.append(f"{path} line {line_no}: Audit Notes must use '- ' bullets")
    return errors


def validate_file(path: Path, strict: bool) -> Tuple[int, int, List[str]]:
    """Validate every record in one JSONL file.

    Returns:
        ``(records_seen, error_count, errors)``. Undecodable or malformed input
        is reported as an error entry instead of raising; reading stops at the
        first such line, so later lines are not examined.
    """
    total = 0
    errors: List[str] = []
    try:
        for line_no, record in _iter_jsonl(path):
            total += 1
            if not isinstance(record, dict):
                errors.append(f"{path} line {line_no}: expected JSON object per line")
                continue
            errors.extend(_validate_messages(record.get("messages"), path, line_no, strict))
    except UnicodeDecodeError as exc:
        # Must precede ValueError: UnicodeDecodeError is a subclass whose
        # message does not name the file.
        errors.append(f"{path}: file is not valid UTF-8 ({exc})")
    except ValueError as exc:
        # Raised by _iter_jsonl for invalid JSON; the message already carries
        # the path and line number.
        errors.append(str(exc))
    return total, len(errors), errors


def validate_dataset(dataset_dir: Path, files: Optional[str] = None, strict: bool = False) -> Tuple[int, List[str]]:
    """Validate the JSONL files of a dataset repository.

    Args:
        dataset_dir: Repository root; data files live in ``data/`` and an
            ``eval/`` directory must exist.
        files: Comma-separated file names under ``data/``. Surrounding
            whitespace and empty entries are ignored. When omitted, every
            ``data/*.jsonl`` file is validated.
        strict: Passed to the built-in checks (assistant must be the last role).

    Returns:
        ``(total_records, errors)``. Precondition failures and failures
        reported by the repository's CLI validator return ``0`` records.
    """
    if not dataset_dir.exists():
        return 0, [f"Dataset directory not found: {dataset_dir}"]

    data_dir = dataset_dir / "data"
    eval_dir = dataset_dir / "eval"

    candidates: List[Path]
    if files:
        names = [name.strip() for name in files.split(",")]
        candidates = [data_dir / name for name in names if name]
    else:
        candidates = sorted(data_dir.glob("*.jsonl"))

    if not candidates:
        return 0, [f"No JSONL files found under {data_dir}"]
    if not eval_dir.exists():
        return 0, [f"Eval probes missing: {eval_dir}"]

    missing_files = [path for path in candidates if not path.exists()]
    if missing_files:
        return 0, [f"Missing file: {path}" for path in missing_files]

    cli_errors = run_cli_validator(dataset_dir, candidates)
    if cli_errors:
        return 0, cli_errors

    external_validator = _load_external_validator(dataset_dir)
    if external_validator:
        print("Using dataset-supplied validator")
        errors_map = external_validator(candidates)
        overall_errors = [f"{path}: {err}" for path, errs in errors_map.items() for err in errs]
        total_lines = sum(_count_records(path) for path in candidates)
        return total_lines, overall_errors

    overall_errors = []
    total_lines = 0
    for path in candidates:
        # Re-checked on purpose: the file may have disappeared since the
        # up-front existence check.
        if not path.exists():
            overall_errors.append(f"Missing file: {path}")
            continue
        count, error_count, errors = validate_file(path, strict=strict)
        total_lines += count
        overall_errors.extend(errors)
        print(f"Validated {path} - lines: {count}, errors: {error_count}")

    return total_lines, overall_errors


def main() -> int:
    """Parse command-line arguments, run validation and return the exit code.

    Returns 1 after printing the errors when any were found, otherwise 0.
    """
    parser = argparse.ArgumentParser(description="Validate BLUX-cA JSONL datasets")
    parser.add_argument("--dataset-dir", required=True, type=Path, help="Path to dataset repository")
    parser.add_argument("--files", type=str, default=None, help="Comma-separated list of data/*.jsonl files")
    parser.add_argument("--strict", action="store_true", help="Enable strict ordering and audit checks")
    args = parser.parse_args()

    total_lines, errors = validate_dataset(args.dataset_dir, files=args.files, strict=args.strict)
    if errors:
        print("Validation errors:")
        for err in errors:
            print(f"- {err}")
        return 1

    print(f"Validation passed. Files checked: {total_lines} lines total")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())