"""
Validate and balance the generated dataset.
This script:
1. Loads all generated samples
2. Validates SQL executability
3. Checks candidate list quality
4. Balances across task families and difficulty
5. Removes duplicates
6. Generates dataset statistics
Output:
- output/dataset_validated.jsonl
- output/dataset_stats.json
"""
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from concurrent.futures import ProcessPoolExecutor, as_completed

import duckdb
import pandas as pd  # noqa: F401 -- fetchdf() needs pandas installed at runtime

from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
def load_samples(jsonl_path: Path) -> List[Dict[str, Any]]:
"""Load samples from JSONL file."""
samples = []
with open(jsonl_path, 'r') as f:
for line in f:
samples.append(json.loads(line))
return samples
def _resolve_paths(sql: str) -> str:
"""Replace symbolic placeholder paths with actual runtime paths for execution."""
sql = sql.replace(
"read_parquet('divisions_area')", f"read_parquet('{DIVISIONS_AREA_PATH}')"
)
sql = sql.replace(
"read_parquet('natural_earth')", f"read_parquet('{NATURAL_EARTH_PATH}')"
)
# Legacy fixed Docker paths from earlier dataset versions
sql = sql.replace("/data/overture/division_area/*.parquet", DIVISIONS_AREA_PATH)
sql = sql.replace("/data/overture/divisions_area/*.parquet", DIVISIONS_AREA_PATH)
sql = sql.replace("/data/natural_earth_geoparquet/ne_geography.parquet", NATURAL_EARTH_PATH)
return sql
def _to_symbolic_sql(sql: str) -> str:
"""Normalize any hardcoded or runtime paths back to symbolic names for storage."""
# Current local runtime paths
sql = sql.replace(DIVISIONS_AREA_PATH, "divisions_area")
sql = sql.replace(NATURAL_EARTH_PATH, "natural_earth")
# Legacy Docker paths
sql = sql.replace("/data/overture/division_area/*.parquet", "divisions_area")
sql = sql.replace("/data/overture/divisions_area/*.parquet", "divisions_area")
sql = sql.replace("/data/natural_earth_geoparquet/ne_geography.parquet", "natural_earth")
return sql
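# Illustrative round-trip (the resolved path comes from gazet.config at
# runtime, so the concrete value below is only an assumption):
#   stored:   SELECT ... FROM read_parquet('divisions_area')
#   resolved: SELECT ... FROM read_parquet('/local/data/divisions_area/*.parquet')
# _to_symbolic_sql() is the inverse of _resolve_paths(): validated samples are
# stored with portable symbolic names rather than machine-specific paths.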
def validate_sql(con: duckdb.DuckDBPyConnection, sql: str) -> Tuple[bool, str]:
    """Validate that SQL executes without error and returns at least one row.

    Resolves symbolic path placeholders to actual runtime paths before execution.
    """
    try:
        result = con.execute(_resolve_paths(sql)).fetchdf()
        if result.empty:
            return False, "Empty result"
        return True, "OK"
    except Exception as e:
        return False, str(e)
def validate_candidates(sample: Dict[str, Any]) -> Tuple[bool, str]:
    """Validate candidate list quality."""
    candidates = sample['candidates']
    selected = sample['target']['selected_candidates']

    # Check we have candidates
    if not candidates:
        return False, "No candidates"

    # Check selected candidates exist
    candidate_ids = {c['candidate_id'] for c in candidates}
    for sel_id in selected:
        if sel_id not in candidate_ids:
            return False, f"Selected candidate {sel_id} not in candidate list"

    # Check for duplicate candidate IDs ('candidate_id' is the ID key used
    # everywhere else in this script)
    ids = [c['candidate_id'] for c in candidates]
    if len(ids) != len(set(ids)):
        return False, "Duplicate candidates"

    return True, "OK"
def validate_sample(con: duckdb.DuckDBPyConnection, sample: Dict[str, Any]) -> Tuple[bool, List[str]]:
"""Validate a single sample. Returns (is_valid, list_of_issues)."""
issues = []
# Skip SQL re-execution if already verified during generation
if not sample.get('metadata', {}).get('sql_verified', False):
sql_valid, sql_msg = validate_sql(con, sample['target']['sql'])
if not sql_valid:
issues.append(f"SQL: {sql_msg}")
# Validate candidates
cand_valid, cand_msg = validate_candidates(sample)
if not cand_valid:
issues.append(f"Candidates: {cand_msg}")
# Check question exists
if not sample.get('question') or len(sample['question'].strip()) == 0:
issues.append("Empty question")
return len(issues) == 0, issues
def validate_sample_worker(sample: Dict[str, Any]) -> Tuple[str, bool, List[str], Optional[Dict[str, Any]]]:
    """Worker function for parallel validation.

    Returns (sample_id, is_valid, issues, validated_sample); validated_sample
    is None when the sample fails validation.
    """
    # Each worker opens its own DuckDB connection: connections hold native state
    # and cannot be pickled across process boundaries. INSTALL is a no-op when
    # the extension is already present.
    con = duckdb.connect()
    con.execute("SET enable_progress_bar=false")
    con.execute("INSTALL spatial")
    con.execute("LOAD spatial")
    try:
        is_valid, issues = validate_sample(con, sample)
        if is_valid:
            # Store SQL with symbolic paths so the dataset stays portable
            sample['target']['sql'] = _to_symbolic_sql(sample['target']['sql'])
        return (sample['id'], is_valid, issues, sample if is_valid else None)
    except Exception as e:
        return (sample['id'], False, [f"Validation error: {str(e)}"], None)
    finally:
        con.close()
def compute_statistics(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Compute dataset statistics."""
stats = {
'total_samples': len(samples),
'task_families': {},
'sql_difficulty': {},
'grounding_difficulty': {},
'anchor_sources': {},
'avg_candidates_per_sample': 0,
'avg_question_length': 0,
'countries_covered': set(),
'subtypes_covered': set()
}
total_candidates = 0
total_question_length = 0
for sample in samples:
meta = sample['metadata']
# Count by family
family = meta['task_family']
stats['task_families'][family] = stats['task_families'].get(family, 0) + 1
# Count by SQL difficulty
sql_diff = meta['sql_difficulty']
stats['sql_difficulty'][sql_diff] = stats['sql_difficulty'].get(sql_diff, 0) + 1
# Count by grounding difficulty
ground_diff = meta['grounding_difficulty']
stats['grounding_difficulty'][ground_diff] = stats['grounding_difficulty'].get(ground_diff, 0) + 1
# Count by anchor source
anchor_src = meta['anchor_source']
stats['anchor_sources'][anchor_src] = stats['anchor_sources'].get(anchor_src, 0) + 1
# Candidates
total_candidates += len(sample['candidates'])
# Question length
total_question_length += len(sample['question'].split())
# Countries and subtypes (from selected/answer candidates only)
selected_ids = set(sample.get('target', {}).get('selected_candidates', []))
for cand in sample['candidates']:
if cand['candidate_id'] in selected_ids:
if cand.get('country'):
stats['countries_covered'].add(cand['country'])
if cand.get('subtype'):
stats['subtypes_covered'].add(cand['subtype'])
stats['avg_candidates_per_sample'] = total_candidates / len(samples) if samples else 0
stats['avg_question_length'] = total_question_length / len(samples) if samples else 0
stats['countries_covered'] = sorted(list(stats['countries_covered']))
stats['subtypes_covered'] = sorted(list(stats['subtypes_covered']))
return stats
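# Illustrative shape of the resulting stats object (keys match the dict built
# above; the counts and names shown are made up):
#   {"total_samples": 500,
#    "task_families": {"<family>": 120, ...},
#    "sql_difficulty": {"<level>": 200, ...},
#    "avg_candidates_per_sample": 9.7,
#    "countries_covered": ["DE", "FR", ...],
#    ...}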
def main():
"""Validate and analyze dataset."""
script_dir = Path(__file__).parent
output_dir = script_dir.parent / "output"
raw_file = output_dir / "dataset_raw.jsonl"
validated_file = output_dir / "dataset_validated.jsonl"
stats_file = output_dir / "dataset_stats.json"
if not raw_file.exists():
print(f"Error: {raw_file} not found. Run generate_samples.py first.")
return
# Load samples
print("Loading samples...")
samples = load_samples(raw_file)
print(f"Loaded {len(samples)} samples")
# Validate samples in parallel
print("\nValidating samples in parallel...")
valid_samples = []
invalid_samples = []
with ProcessPoolExecutor(max_workers=8) as executor:
# Submit all validation tasks
futures = {executor.submit(validate_sample_worker, sample): sample for sample in samples}
# Collect results as they complete
completed = 0
for future in as_completed(futures):
sample_id, is_valid, issues, validated_sample = future.result()
if is_valid:
valid_samples.append(validated_sample)
else:
invalid_samples.append((sample_id, issues))
completed += 1
if completed % 50 == 0 or completed == len(samples):
print(f"\r Progress: {completed}/{len(samples)} ", end='', flush=True)
print() # New line after progress
print(f"\nValidation results:")
print(f" Valid: {len(valid_samples)}")
print(f" Invalid: {len(invalid_samples)}")
if invalid_samples and len(invalid_samples) <= 20:
print("\nInvalid samples:")
for sample_id, issues in invalid_samples[:20]:
print(f" {sample_id}: {', '.join(issues)}")
elif invalid_samples:
print(f"\n{len(invalid_samples)} invalid samples (showing first 20):")
for sample_id, issues in invalid_samples[:20]:
print(f" {sample_id}: {', '.join(issues)}")
# Save validated samples
if valid_samples:
with open(validated_file, 'w') as f:
for sample in valid_samples:
f.write(json.dumps(sample) + '\n')
print(f"\nSaved {len(valid_samples)} valid samples to {validated_file}")
# Compute statistics
print("\nComputing statistics...")
stats = compute_statistics(valid_samples)
    # Save statistics (compute_statistics already converted the sets to sorted
    # lists, so the dict is JSON-serializable as-is)
    with open(stats_file, 'w') as f:
        json.dump(stats, f, indent=2)
print(f"Saved statistics to {stats_file}")
# Print summary
print("\n" + "=" * 60)
print("DATASET STATISTICS")
print("=" * 60)
print(f"\nTotal samples: {stats['total_samples']}")
print("\nTask families:")
for family, count in sorted(stats['task_families'].items()):
print(f" {family:20s}: {count:3d}")
print("\nSQL difficulty:")
for diff, count in sorted(stats['sql_difficulty'].items()):
print(f" {diff:20s}: {count:3d}")
print("\nGrounding difficulty:")
for diff, count in sorted(stats['grounding_difficulty'].items()):
print(f" {diff:20s}: {count:3d}")
print("\nAnchor sources:")
for src, count in sorted(stats['anchor_sources'].items()):
print(f" {src:20s}: {count:3d}")
print(f"\nAverage candidates per sample: {stats['avg_candidates_per_sample']:.1f}")
print(f"Average question length (words): {stats['avg_question_length']:.1f}")
print(f"Countries covered: {len(stats['countries_covered'])}")
print(f"Subtypes covered: {len(stats['subtypes_covered'])}")
print("\n✓ Validation complete")
if __name__ == "__main__":
main()