""" Data validation utilities for ArcisVLM training. Call validate_dataset() before ANY training run to ensure data is real and clean. Catches: dummy images, blank questions, list-string answers, data poisoning. """ import json import os import re import torch from torch.utils.data import Dataset def validate_dataset(dataset: Dataset, stage_name: str, min_samples: int = 100): """Validate that a dataset contains real, clean data. Checks: 1. Dataset has enough samples 2. Images are not solid color (dummy) 3. Images vary across samples (not all identical) Raises RuntimeError if validation fails. """ n = len(dataset) if n < min_samples: raise RuntimeError( f"FATAL: {stage_name} dataset has only {n} samples (need {min_samples}+).\n" f"Download real data: python3 scripts/download_all_data.py" ) check_count = min(10, n) stds = [] for i in range(check_count): sample = dataset[i] img = sample.get("image", None) if img is None: raise RuntimeError(f"FATAL: {stage_name} sample {i} has no image tensor.") if not isinstance(img, torch.Tensor): raise RuntimeError(f"FATAL: {stage_name} sample {i} image is {type(img)}, expected torch.Tensor.") std = img.std().item() stds.append(std) if std < 0.001: raise RuntimeError( f"FATAL: {stage_name} sample {i} appears to be solid color (std={std:.6f}).\n" f"This looks like dummy data. Download real images first." ) std_range = max(stds) - min(stds) if std_range < 0.0001 and check_count > 3: raise RuntimeError( f"FATAL: {stage_name} all {check_count} checked images have identical statistics.\n" f"std range: {std_range:.6f}. This looks like dummy/synthetic data." ) print(f" Data validation PASSED for {stage_name}: " f"{n} samples, image std range={std_range:.4f}") return True def validate_jsonl_data(jsonl_dir: str, stage_name: str): """Validate JSONL training data BEFORE loading into Dataset. Catches data poisoning issues: 1. Blank/empty questions 2. Answers that are Python list strings (['a', 'b', ...]) 3. Single dominant answer (>30% of data) 4. Questions or answers that are always the same Raises RuntimeError if critical issues found. """ if not os.path.exists(jsonl_dir): return from collections import Counter total = 0 empty_questions = 0 list_string_answers = 0 answer_counts = Counter() question_counts = Counter() for fname in sorted(os.listdir(jsonl_dir)): if not fname.endswith('.jsonl'): continue fpath = os.path.join(jsonl_dir, fname) with open(fpath) as f: for line in f: try: item = json.loads(line.strip()) except json.JSONDecodeError: continue total += 1 q = item.get("question", "") a = item.get("answer", "") # Check 1: Empty questions if not q or len(q.strip()) == 0: empty_questions += 1 # Check 2: Answer looks like Python list string if isinstance(a, str) and a.startswith("[") and "'" in a: list_string_answers += 1 # Track distributions answer_counts[str(a)[:50]] += 1 question_counts[str(q)[:50]] += 1 issues = [] # Report print(f"\n [{stage_name}] Data Quality Check: {total} samples across {jsonl_dir}") if empty_questions > 0: pct = empty_questions / max(total, 1) * 100 msg = f" WARNING: {empty_questions} samples ({pct:.1f}%) have EMPTY questions" print(msg) if pct > 5: issues.append(msg) if list_string_answers > 0: pct = list_string_answers / max(total, 1) * 100 msg = f" WARNING: {list_string_answers} samples ({pct:.1f}%) have Python list-string answers" print(msg) if pct > 5: issues.append(msg) # Check dominant answers if answer_counts: top_answer, top_count = answer_counts.most_common(1)[0] top_pct = top_count / max(total, 1) * 100 if top_pct > 10: msg = f" WARNING: Most common answer '{top_answer}' appears {top_count}x ({top_pct:.1f}%)" print(msg) if top_pct > 30: issues.append(msg) # Check dominant questions if question_counts: top_q, top_q_count = question_counts.most_common(1)[0] top_q_pct = top_q_count / max(total, 1) * 100 if top_q_pct > 20: msg = f" WARNING: Most common question '{top_q[:40]}' appears {top_q_count}x ({top_q_pct:.1f}%)" print(msg) if issues: print(f"\n [{stage_name}] DATA QUALITY ISSUES FOUND:") for issue in issues: print(f" {issue}") raise RuntimeError( f"FATAL: {stage_name} has data quality issues that will poison training.\n" f"Fix the download script and re-download: python3 scripts/download_all_data.py\n" f"Issues:\n" + "\n".join(issues) ) else: print(f" [{stage_name}] Data quality PASSED")