""" Validate and balance the generated dataset. This script: 1. Loads all generated samples 2. Validates SQL executability 3. Checks candidate list quality 4. Balances across task families and difficulty 5. Removes duplicates 6. Generates dataset statistics Output: - output/dataset_validated.jsonl - output/dataset_stats.json """ import json from pathlib import Path from typing import List, Dict, Any, Tuple from collections import Counter from concurrent.futures import ProcessPoolExecutor, as_completed import duckdb import pandas as pd from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH def load_samples(jsonl_path: Path) -> List[Dict[str, Any]]: """Load samples from JSONL file.""" samples = [] with open(jsonl_path, 'r') as f: for line in f: samples.append(json.loads(line)) return samples def _resolve_paths(sql: str) -> str: """Replace symbolic placeholder paths with actual runtime paths for execution.""" sql = sql.replace( "read_parquet('divisions_area')", f"read_parquet('{DIVISIONS_AREA_PATH}')" ) sql = sql.replace( "read_parquet('natural_earth')", f"read_parquet('{NATURAL_EARTH_PATH}')" ) # Legacy fixed Docker paths from earlier dataset versions sql = sql.replace("/data/overture/division_area/*.parquet", DIVISIONS_AREA_PATH) sql = sql.replace("/data/overture/divisions_area/*.parquet", DIVISIONS_AREA_PATH) sql = sql.replace("/data/natural_earth_geoparquet/ne_geography.parquet", NATURAL_EARTH_PATH) return sql def _to_symbolic_sql(sql: str) -> str: """Normalize any hardcoded or runtime paths back to symbolic names for storage.""" # Current local runtime paths sql = sql.replace(DIVISIONS_AREA_PATH, "divisions_area") sql = sql.replace(NATURAL_EARTH_PATH, "natural_earth") # Legacy Docker paths sql = sql.replace("/data/overture/division_area/*.parquet", "divisions_area") sql = sql.replace("/data/overture/divisions_area/*.parquet", "divisions_area") sql = sql.replace("/data/natural_earth_geoparquet/ne_geography.parquet", "natural_earth") return sql def validate_sql(con: duckdb.DuckDBPyConnection, sql: str) -> tuple[bool, str]: """Validate that SQL executes without error. Resolves symbolic path placeholders to actual runtime paths before execution. """ try: result = con.execute(_resolve_paths(sql)).fetchdf() if result.empty: return False, "Empty result" return True, "OK" except Exception as e: return False, str(e) def validate_candidates(sample: Dict[str, Any]) -> tuple[bool, str]: """Validate candidate list quality.""" candidates = sample['candidates'] selected = sample['target']['selected_candidates'] # Check we have candidates if not candidates: return False, "No candidates" # Check selected candidates exist candidate_ids = {c['candidate_id'] for c in candidates} for sel_id in selected: if sel_id not in candidate_ids: return False, f"Selected candidate {sel_id} not in candidate list" # Check for duplicates ids = [c['id'] for c in candidates] if len(ids) != len(set(ids)): return False, "Duplicate candidates" return True, "OK" def validate_sample(con: duckdb.DuckDBPyConnection, sample: Dict[str, Any]) -> tuple[bool, List[str]]: """Validate a single sample. Returns (is_valid, list_of_issues).""" issues = [] # Skip SQL re-execution if already verified during generation if not sample.get('metadata', {}).get('sql_verified', False): sql_valid, sql_msg = validate_sql(con, sample['target']['sql']) if not sql_valid: issues.append(f"SQL: {sql_msg}") # Validate candidates cand_valid, cand_msg = validate_candidates(sample) if not cand_valid: issues.append(f"Candidates: {cand_msg}") # Check question exists if not sample.get('question') or len(sample['question'].strip()) == 0: issues.append("Empty question") return len(issues) == 0, issues def validate_sample_worker(sample: Dict[str, Any]) -> Tuple[str, bool, List[str]]: """Worker function for parallel validation. Returns (sample_id, is_valid, issues).""" # Each worker creates its own DuckDB connection con = duckdb.connect() con.execute("SET enable_progress_bar=false") con.execute("INSTALL spatial") con.execute("LOAD spatial") try: is_valid, issues = validate_sample(con, sample) con.close() if is_valid: sample['target']['sql'] = _to_symbolic_sql(sample['target']['sql']) return (sample['id'], is_valid, issues, sample if is_valid else None) except Exception as e: con.close() return (sample['id'], False, [f"Validation error: {str(e)}"], None) def compute_statistics(samples: List[Dict[str, Any]]) -> Dict[str, Any]: """Compute dataset statistics.""" stats = { 'total_samples': len(samples), 'task_families': {}, 'sql_difficulty': {}, 'grounding_difficulty': {}, 'anchor_sources': {}, 'avg_candidates_per_sample': 0, 'avg_question_length': 0, 'countries_covered': set(), 'subtypes_covered': set() } total_candidates = 0 total_question_length = 0 for sample in samples: meta = sample['metadata'] # Count by family family = meta['task_family'] stats['task_families'][family] = stats['task_families'].get(family, 0) + 1 # Count by SQL difficulty sql_diff = meta['sql_difficulty'] stats['sql_difficulty'][sql_diff] = stats['sql_difficulty'].get(sql_diff, 0) + 1 # Count by grounding difficulty ground_diff = meta['grounding_difficulty'] stats['grounding_difficulty'][ground_diff] = stats['grounding_difficulty'].get(ground_diff, 0) + 1 # Count by anchor source anchor_src = meta['anchor_source'] stats['anchor_sources'][anchor_src] = stats['anchor_sources'].get(anchor_src, 0) + 1 # Candidates total_candidates += len(sample['candidates']) # Question length total_question_length += len(sample['question'].split()) # Countries and subtypes (from selected/answer candidates only) selected_ids = set(sample.get('target', {}).get('selected_candidates', [])) for cand in sample['candidates']: if cand['candidate_id'] in selected_ids: if cand.get('country'): stats['countries_covered'].add(cand['country']) if cand.get('subtype'): stats['subtypes_covered'].add(cand['subtype']) stats['avg_candidates_per_sample'] = total_candidates / len(samples) if samples else 0 stats['avg_question_length'] = total_question_length / len(samples) if samples else 0 stats['countries_covered'] = sorted(list(stats['countries_covered'])) stats['subtypes_covered'] = sorted(list(stats['subtypes_covered'])) return stats def main(): """Validate and analyze dataset.""" script_dir = Path(__file__).parent output_dir = script_dir.parent / "output" raw_file = output_dir / "dataset_raw.jsonl" validated_file = output_dir / "dataset_validated.jsonl" stats_file = output_dir / "dataset_stats.json" if not raw_file.exists(): print(f"Error: {raw_file} not found. Run generate_samples.py first.") return # Load samples print("Loading samples...") samples = load_samples(raw_file) print(f"Loaded {len(samples)} samples") # Validate samples in parallel print("\nValidating samples in parallel...") valid_samples = [] invalid_samples = [] with ProcessPoolExecutor(max_workers=8) as executor: # Submit all validation tasks futures = {executor.submit(validate_sample_worker, sample): sample for sample in samples} # Collect results as they complete completed = 0 for future in as_completed(futures): sample_id, is_valid, issues, validated_sample = future.result() if is_valid: valid_samples.append(validated_sample) else: invalid_samples.append((sample_id, issues)) completed += 1 if completed % 50 == 0 or completed == len(samples): print(f"\r Progress: {completed}/{len(samples)} ", end='', flush=True) print() # New line after progress print(f"\nValidation results:") print(f" Valid: {len(valid_samples)}") print(f" Invalid: {len(invalid_samples)}") if invalid_samples and len(invalid_samples) <= 20: print("\nInvalid samples:") for sample_id, issues in invalid_samples[:20]: print(f" {sample_id}: {', '.join(issues)}") elif invalid_samples: print(f"\n{len(invalid_samples)} invalid samples (showing first 20):") for sample_id, issues in invalid_samples[:20]: print(f" {sample_id}: {', '.join(issues)}") # Save validated samples if valid_samples: with open(validated_file, 'w') as f: for sample in valid_samples: f.write(json.dumps(sample) + '\n') print(f"\nSaved {len(valid_samples)} valid samples to {validated_file}") # Compute statistics print("\nComputing statistics...") stats = compute_statistics(valid_samples) # Save statistics # Convert sets to lists for JSON serialization stats_json = {k: (list(v) if isinstance(v, set) else v) for k, v in stats.items()} with open(stats_file, 'w') as f: json.dump(stats_json, f, indent=2) print(f"Saved statistics to {stats_file}") # Print summary print("\n" + "=" * 60) print("DATASET STATISTICS") print("=" * 60) print(f"\nTotal samples: {stats['total_samples']}") print("\nTask families:") for family, count in sorted(stats['task_families'].items()): print(f" {family:20s}: {count:3d}") print("\nSQL difficulty:") for diff, count in sorted(stats['sql_difficulty'].items()): print(f" {diff:20s}: {count:3d}") print("\nGrounding difficulty:") for diff, count in sorted(stats['grounding_difficulty'].items()): print(f" {diff:20s}: {count:3d}") print("\nAnchor sources:") for src, count in sorted(stats['anchor_sources'].items()): print(f" {src:20s}: {count:3d}") print(f"\nAverage candidates per sample: {stats['avg_candidates_per_sample']:.1f}") print(f"Average question length (words): {stats['avg_question_length']:.1f}") print(f"Countries covered: {len(stats['countries_covered'])}") print(f"Subtypes covered: {len(stats['subtypes_covered'])}") print("\n✓ Validation complete") if __name__ == "__main__": main()