arcisvlm / data /validation.py
Hardik Sanghvi
feat: integrate Gemma 4 E2B backbone for production-quality VLM inference
7a564e3
Raw
History Blame Contribute Delete
5.41 kB
"""
Data validation utilities for ArcisVLM training.
Call validate_dataset() before ANY training run to ensure data is real and clean.
Catches: dummy images, blank questions, list-string answers, data poisoning.
"""
import json
import os
import re
import torch
from torch.utils.data import Dataset
def validate_dataset(dataset: Dataset, stage_name: str, min_samples: int = 100) -> bool:
    """Validate that a dataset contains real, clean image data.

    Only the first ``min(10, len(dataset))`` samples are inspected.

    Checks:
        1. The dataset has at least ``min_samples`` samples.
        2. Every inspected sample has an ``"image"`` entry that is a
           ``torch.Tensor``.
        3. No inspected image is (near) solid color (pixel std < 0.001).
        4. When more than 3 images are inspected, their pixel stds are not
           all (near) identical (range of stds < 0.0001).

    Args:
        dataset: Indexable dataset whose samples expose ``.get("image")``.
        stage_name: Label used in error and status messages.
        min_samples: Minimum number of samples the dataset must contain.

    Returns:
        True when every check passes.

    Raises:
        RuntimeError: If any check fails, or if the dataset is empty.
    """
    n = len(dataset)
    if n < min_samples:
        raise RuntimeError(
            f"FATAL: {stage_name} dataset has only {n} samples (need {min_samples}+).\n"
            "Download real data: python3 scripts/download_all_data.py"
        )

    check_count = min(10, n)
    if check_count == 0:
        # Reachable only when min_samples <= 0; without this guard the
        # max()/min() below would fail with an unrelated ValueError.
        raise RuntimeError(
            f"FATAL: {stage_name} dataset is empty; there are no images to validate."
        )

    stds = []
    for i in range(check_count):
        sample = dataset[i]
        img = sample.get("image", None)
        if img is None:
            raise RuntimeError(f"FATAL: {stage_name} sample {i} has no image tensor.")
        if not isinstance(img, torch.Tensor):
            raise RuntimeError(
                f"FATAL: {stage_name} sample {i} image is {type(img)}, expected torch.Tensor."
            )
        # Tensor.std() only supports floating-point/complex dtypes; integer
        # images (e.g. uint8) are cast so they can be checked as well.
        if not (img.is_floating_point() or img.is_complex()):
            img = img.float()
        std = img.std().item()
        stds.append(std)
        if std < 0.001:
            raise RuntimeError(
                f"FATAL: {stage_name} sample {i} appears to be solid color (std={std:.6f}).\n"
                "This looks like dummy data. Download real images first."
            )

    # Identical statistics across several images suggests generated data.
    std_range = max(stds) - min(stds)
    if std_range < 0.0001 and check_count > 3:
        raise RuntimeError(
            f"FATAL: {stage_name} all {check_count} checked images have identical statistics.\n"
            f"std range: {std_range:.6f}. This looks like dummy/synthetic data."
        )

    print(f" Data validation PASSED for {stage_name}: "
          f"{n} samples, image std range={std_range:.4f}")
    return True
def validate_jsonl_data(jsonl_dir: str, stage_name: str) -> None:
    """Validate JSONL training data BEFORE loading it into a Dataset.

    Every ``*.jsonl`` file directly inside ``jsonl_dir`` is scanned (in sorted
    filename order). Each line is expected to be a JSON object with
    ``"question"`` and ``"answer"`` keys.

    Checks and thresholds:
        1. Blank/empty questions: warning; fatal above 5% of samples.
        2. Answers that look like Python list strings (``"['a', 'b']"``):
           warning; fatal above 5% of samples.
        3. Single dominant answer: warning above 10%; fatal above 30%.
        4. Single dominant question: warning above 20% (never fatal).

    Answers and questions are compared on their first 50 characters only.
    Blank lines are ignored. Lines that are not valid JSON objects are
    skipped and reported as a non-fatal warning.

    Args:
        jsonl_dir: Directory containing the ``.jsonl`` files. If the path
            does not exist the function returns without doing anything.
        stage_name: Label used in status and error messages.

    Raises:
        RuntimeError: If any fatal threshold above is exceeded.
    """
    from collections import Counter

    if not os.path.exists(jsonl_dir):
        return

    total = 0
    malformed = 0
    empty_questions = 0
    list_string_answers = 0
    answer_counts = Counter()
    question_counts = Counter()

    for fname in sorted(os.listdir(jsonl_dir)):
        if not fname.endswith(".jsonl"):
            continue
        fpath = os.path.join(jsonl_dir, fname)
        # JSONL is UTF-8; do not depend on the platform default encoding.
        with open(fpath, encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    item = json.loads(line)
                except json.JSONDecodeError:
                    malformed += 1
                    continue
                if not isinstance(item, dict):
                    # Valid JSON but not an object (list, number, string...):
                    # it cannot carry question/answer fields.
                    malformed += 1
                    continue

                total += 1
                q = item.get("question", "")
                a = item.get("answer", "")

                # Check 1: empty questions. str() keeps truthy non-string
                # values (numbers, lists) from crashing on .strip().
                if not q or not str(q).strip():
                    empty_questions += 1

                # Check 2: answer looks like a stringified Python list.
                if isinstance(a, str) and a.startswith("[") and "'" in a:
                    list_string_answers += 1

                # Track distributions on a 50-character prefix.
                answer_counts[str(a)[:50]] += 1
                question_counts[str(q)[:50]] += 1

    issues = []

    # Report
    print(f"\n [{stage_name}] Data Quality Check: {total} samples across {jsonl_dir}")

    if malformed > 0:
        print(f" WARNING: {malformed} lines are not valid JSON objects and were skipped")

    if empty_questions > 0:
        pct = empty_questions / max(total, 1) * 100
        msg = f" WARNING: {empty_questions} samples ({pct:.1f}%) have EMPTY questions"
        print(msg)
        if pct > 5:
            issues.append(msg)

    if list_string_answers > 0:
        pct = list_string_answers / max(total, 1) * 100
        msg = f" WARNING: {list_string_answers} samples ({pct:.1f}%) have Python list-string answers"
        print(msg)
        if pct > 5:
            issues.append(msg)

    # Check dominant answers
    if answer_counts:
        top_answer, top_count = answer_counts.most_common(1)[0]
        top_pct = top_count / max(total, 1) * 100
        if top_pct > 10:
            msg = f" WARNING: Most common answer '{top_answer}' appears {top_count}x ({top_pct:.1f}%)"
            print(msg)
            if top_pct > 30:
                issues.append(msg)

    # Check dominant questions (warning only; never added to fatal issues)
    if question_counts:
        top_q, top_q_count = question_counts.most_common(1)[0]
        top_q_pct = top_q_count / max(total, 1) * 100
        if top_q_pct > 20:
            msg = f" WARNING: Most common question '{top_q[:40]}' appears {top_q_count}x ({top_q_pct:.1f}%)"
            print(msg)

    if issues:
        print(f"\n [{stage_name}] DATA QUALITY ISSUES FOUND:")
        for issue in issues:
            print(f" {issue}")
        raise RuntimeError(
            f"FATAL: {stage_name} has data quality issues that will poison training.\n"
            f"Fix the download script and re-download: python3 scripts/download_all_data.py\n"
            f"Issues:\n" + "\n".join(issues)
        )
    else:
        print(f" [{stage_name}] Data quality PASSED")