File size: 8,361 Bytes
f526878
 
 
 
 
 
 
 
6e691a3
f526878
5ce8003
6e691a3
f526878
 
 
 
 
 
5ce8003
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e691a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f526878
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e691a3
f526878
 
 
 
 
 
 
 
6e691a3
 
 
5ce8003
 
 
 
 
 
 
 
6e691a3
 
 
 
 
 
 
 
 
f526878
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
"""Dataset validation for BLUX-cA QLoRA pipeline.

Checks that training JSONL files conform to the expected message schema and
contain the BLUX-cA system placeholder.
"""
from __future__ import annotations

import argparse
import importlib.util
import json
import subprocess
import sys
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple

SYSTEM_PLACEHOLDER = "<SYSTEM_PROMPT_FROM_BLUX_CA>"


def run_cli_validator(dataset_dir: Path, files: Optional[List[Path]] = None) -> List[str]:
    """Invoke the dataset repository's validator script via subprocess.

    Args:
        dataset_dir: Root of the dataset repository.
        files: Optional candidate files to pass through to the validator.

    Returns:
        A list of error lines; empty when the validator script is absent or
        exits successfully.
    """

    validator_path = dataset_dir / "tools" / "validate_jsonl.py"
    if not validator_path.exists():
        return []

    rel_files: List[str] = []
    if files:
        resolved_root = dataset_dir.resolve()
        for f in files:
            # The subprocess runs with cwd=dataset_dir, so paths must be
            # either relative to the dataset root or fully absolute.
            # Previously, candidates built from a *relative* dataset_dir were
            # forwarded unchanged and resolved against the wrong directory.
            resolved = f.resolve()
            try:
                rel_files.append(str(resolved.relative_to(resolved_root)))
            except ValueError:
                # Outside the dataset tree: an absolute path works regardless
                # of the subprocess working directory.
                rel_files.append(str(resolved))

    cmd = [sys.executable, str(validator_path), *rel_files]
    result = subprocess.run(cmd, capture_output=True, text=True, cwd=dataset_dir)
    if result.returncode == 0:
        return []

    output = (result.stdout + "\n" + result.stderr).strip()
    # Fall back to a generic message when the validator produced no output.
    return [line for line in output.splitlines() if line.strip()] or [
        f"Validator exited with code {result.returncode}",
        f"Re-run manually: python {validator_path}",
    ]


def _load_external_validator(dataset_dir: Path):
    """Load dataset-provided validator if available.

    Returns a callable with signature List[Path] -> Dict[Path, List[str]]
    mapping file paths to lists of errors. If not available, returns None.
    """

    validator_path = dataset_dir / "tools" / "validate_jsonl.py"
    if not validator_path.exists():
        return None

    spec = importlib.util.spec_from_file_location("blux_ca_dataset_validator", validator_path)
    if not spec or not spec.loader:  # pragma: no cover - defensive
        return None

    module = importlib.util.module_from_spec(spec)
    sys.modules["blux_ca_dataset_validator"] = module
    spec.loader.exec_module(module)

    def validate(files: List[Path]) -> Dict[Path, List[str]]:
        error_map: Dict[Path, List[str]] = {}
        for file_path in files:
            errs = getattr(module, "validate_file")(file_path)
            if errs:
                error_map[file_path] = errs
        return error_map

    return validate


def _iter_jsonl(path: Path) -> Tuple[int, Dict]:
    with path.open("r", encoding="utf-8") as handle:
        for idx, line in enumerate(handle, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                yield idx, json.loads(line)
            except json.JSONDecodeError as exc:
                raise ValueError(f"{path} line {idx}: invalid JSON ({exc})") from exc


def _validate_messages(messages: List[Dict], path: Path, line_no: int, strict: bool) -> List[str]:
    """Validate one record's ``messages`` list against the BLUX-cA schema.

    Checks: non-empty list, presence of system/user/assistant roles, the
    system placeholder content, per-message role/content fields, audit-note
    bullet formatting, and (in strict mode) assistant-last ordering.

    Returns:
        All error strings for this record; empty when the record is valid.
    """
    errors: List[str] = []
    if not isinstance(messages, list) or not messages:
        errors.append(f"{path} line {line_no}: 'messages' must be a non-empty list")
        return errors

    roles = [m.get("role") for m in messages if isinstance(m, dict)]
    if roles.count("system") < 1:
        errors.append(f"{path} line {line_no}: missing system role")
    if roles.count("user") < 1:
        errors.append(f"{path} line {line_no}: missing user role")
    if roles.count("assistant") < 1:
        errors.append(f"{path} line {line_no}: missing assistant role")

    # Only the first system message is checked against the placeholder.
    system_messages = [m for m in messages if isinstance(m, dict) and m.get("role") == "system"]
    if system_messages:
        system_content = system_messages[0].get("content", "")
        if system_content != SYSTEM_PLACEHOLDER:
            errors.append(
                f"{path} line {line_no}: system content must equal {SYSTEM_PLACEHOLDER!r}"
            )
    else:
        errors.append(f"{path} line {line_no}: system message missing")

    for m in messages:
        if not isinstance(m, dict):
            errors.append(f"{path} line {line_no}: each message must be an object")
            continue
        if not m.get("role"):
            errors.append(f"{path} line {line_no}: message missing role")
        content = m.get("content", "")
        if not content:
            errors.append(f"{path} line {line_no}: message missing content for role {m.get('role')}")
        # Bug fix: non-string content (e.g. an explicit null or nested
        # structure) previously reached _validate_audit_notes and crashed
        # with a TypeError on the `in` check. Only strings can carry notes.
        if isinstance(content, str):
            errors.extend(_validate_audit_notes(content, path, line_no))

    if strict:
        last_role = messages[-1].get("role") if isinstance(messages[-1], dict) else None
        if last_role != "assistant":
            errors.append(f"{path} line {line_no}: strict mode requires assistant as last role")

    return errors


def _validate_audit_notes(content: str, path: Path, line_no: int) -> List[str]:
    if "## Audit Notes" not in content:
        return []
    errors: List[str] = []
    lines = content.splitlines()
    try:
        header_index = lines.index("## Audit Notes")
    except ValueError:
        return errors
    bullets = lines[header_index + 1 :]
    for bullet in bullets:
        if bullet.strip() and not bullet.strip().startswith("- "):
            errors.append(f"{path} line {line_no}: Audit Notes must use '- ' bullets")
    return errors


def validate_file(path: Path, strict: bool) -> Tuple[int, int, List[str]]:
    """Validate one JSONL file.

    Returns:
        ``(record_count, error_count, errors)`` for the file at *path*.
    """
    errors: List[str] = []
    total = 0
    for line_no, record in _iter_jsonl(path):
        total += 1
        if isinstance(record, dict):
            errors.extend(_validate_messages(record.get("messages"), path, line_no, strict))
        else:
            errors.append(f"{path} line {line_no}: expected JSON object per line")
    return total, len(errors), errors


def validate_dataset(dataset_dir: Path, files: Optional[str] = None, strict: bool = False) -> Tuple[int, List[str]]:
    """Validate candidate JSONL files under *dataset_dir*.

    Args:
        dataset_dir: Dataset repository root; expects ``data/`` and ``eval/``
            subdirectories.
        files: Optional comma-separated file names under ``data/``. When
            omitted, every ``data/*.jsonl`` file is validated.
        strict: Forwarded to the built-in per-record checks.

    Returns:
        ``(total_lines, errors)`` — total records inspected and every error
        string collected. Validation is delegated to the dataset's own
        validator (CLI, then in-process) when one ships with the repository.
    """
    if not dataset_dir.exists():
        return 0, [f"Dataset directory not found: {dataset_dir}"]

    data_dir = dataset_dir / "data"
    eval_dir = dataset_dir / "eval"
    candidates: List[Path]
    if files:
        # Bug fix: strip whitespace and drop empty entries so inputs like
        # "a.jsonl, b.jsonl" or a trailing comma resolve correctly instead of
        # producing " b.jsonl" / a bogus bare-directory candidate.
        candidates = [data_dir / name.strip() for name in files.split(",") if name.strip()]
    else:
        candidates = sorted(data_dir.glob("*.jsonl"))

    if not candidates:
        return 0, [f"No JSONL files found under {data_dir}"]
    if not eval_dir.exists():
        return 0, [f"Eval probes missing: {eval_dir}"]

    missing_files = [path for path in candidates if not path.exists()]
    if missing_files:
        return 0, [f"Missing file: {path}" for path in missing_files]

    cli_errors = run_cli_validator(dataset_dir, candidates)
    if cli_errors:
        return 0, cli_errors

    external_validator = _load_external_validator(dataset_dir)
    if external_validator is not None:
        print("Using dataset-supplied validator")
        errors_map = external_validator(candidates)
        overall_errors = [f"{path}: {err}" for path, errs in errors_map.items() for err in errs]
        total_lines = sum(1 for path in candidates for _ in _iter_jsonl(path))
        return total_lines, overall_errors

    overall_errors: List[str] = []
    total_lines = 0
    # All candidates were verified to exist above, so the per-file existence
    # re-check the loop used to perform was dead code and has been removed.
    for path in candidates:
        count, error_count, errors = validate_file(path, strict=strict)
        total_lines += count
        overall_errors.extend(errors)
        print(f"Validated {path} - lines: {count}, errors: {error_count}")

    return total_lines, overall_errors


def main() -> int:
    """CLI entry point: parse arguments, run validation, report results.

    Returns ``0`` on success and ``1`` when any validation error is found.
    """
    parser = argparse.ArgumentParser(description="Validate BLUX-cA JSONL datasets")
    parser.add_argument("--dataset-dir", required=True, type=Path, help="Path to dataset repository")
    parser.add_argument("--files", type=str, default=None, help="Comma-separated list of data/*.jsonl files")
    parser.add_argument("--strict", action="store_true", help="Enable strict ordering and audit checks")
    args = parser.parse_args()

    total_lines, errors = validate_dataset(args.dataset_dir, files=args.files, strict=args.strict)
    if not errors:
        print(f"Validation passed. Files checked: {total_lines} lines total")
        return 0

    print("Validation errors:")
    for err in errors:
        print(f"- {err}")
    return 1


if __name__ == "__main__":
    raise SystemExit(main())