Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| CLI for synthetic dataset generation. | |
| Usage: | |
| python cli.py build-relations --config ../config.yaml | |
| python cli.py generate-samples --config ../config.yaml | |
| python cli.py generate-samples --config ../config.yaml --append | |
| python cli.py full-pipeline --config ../config.yaml | |
| """ | |
| import argparse | |
| import subprocess | |
| import sys | |
| from pathlib import Path | |
| from typing import Dict | |
| import pandas as pd | |
| import yaml | |
def load_config(config_path: Path) -> dict:
    """Load the pipeline configuration from a YAML file.

    Args:
        config_path: Location of the YAML configuration file.

    Returns:
        The parsed top-level mapping. An empty document yields an empty
        dict so the declared return type always holds.

    Raises:
        ValueError: If the top level of the document is not a mapping.
    """
    # Decode explicitly as UTF-8; relying on the platform default encoding
    # can mis-read non-ASCII configuration values on some systems.
    with open(config_path, encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # safe_load returns None for an empty document.
    if data is None:
        return {}
    if not isinstance(data, dict):
        raise ValueError(
            f"Config file {config_path} must contain a mapping at the top level, "
            f"got {type(data).__name__}"
        )
    return data
def should_rebuild_relations(config: dict, intermediate_dir: Path, append: bool) -> bool:
    """Decide whether the relation tables have to be regenerated.

    A rebuild is requested when any of the following holds:
    - append mode is off (a fresh run always rebuilds),
    - ``adjacency_pairs.parquet`` is missing from ``intermediate_dir``,
    - the table has no ``anchor_country`` column,
    - the set of ``anchor_country`` values differs from ``config['countries']``,
    - anything goes wrong while inspecting the existing table.

    Returns False only in append mode when the stored countries match the
    configured ones.
    """
    if not append:
        return True

    # The adjacency table stands in for the whole set of relation tables.
    adjacency_file = intermediate_dir / "adjacency_pairs.parquet"
    if not adjacency_file.exists():
        print("WARNING: Relation tables not found, will rebuild despite append mode")
        return True

    try:
        pairs = pd.read_parquet(adjacency_file)
        if 'anchor_country' not in pairs.columns:
            # Can't determine countries, rebuild to be safe
            print("WARNING: Cannot determine countries from existing tables, will rebuild")
            return True

        previous = set(pairs['anchor_country'].unique())
        requested = set(config['countries'])
        if previous == requested:
            print(f"Countries unchanged: {sorted(requested)}")
            return False

        print("WARNING: Countries changed:")
        print(f" Previous: {sorted(previous)}")
        print(f" New: {sorted(requested)}")
        print(" Will rebuild relation tables to include new countries")
        return True
    except Exception as e:
        # Best effort: any failure while inspecting the table means rebuild.
        print(f"WARNING: Error reading existing relation tables: {e}")
        print(" Will rebuild to be safe")
        return True
def calculate_relation_limits(config: dict) -> Dict[str, int]:
    """Auto-calculate relation limits based on sample targets.

    Each task family in ``config['sample_targets']`` contributes
    ``int(target * retry_multiplier * safety_factor)`` to every relation
    type it draws anchors from. Families that are not in the mapping below
    contribute nothing.

    Args:
        config: Parsed configuration. Requires ``sample_targets`` and
            ``generation.retry_multiplier``. The optional ``auto_scaling``
            section may provide ``safety_factor`` (default 1.5) and
            ``manual_limits``; any of these may be left empty (``None``).

    Returns:
        Mapping of relation type to row limit, with manual overrides applied
        last.

    Raises:
        KeyError: If ``sample_targets`` or ``generation.retry_multiplier``
            is missing.
    """
    sample_targets = config['sample_targets']
    retry_mult = config['generation']['retry_multiplier']

    # A YAML key with no value parses as None, so `.get(key, default)` is not
    # enough: the key exists and its value is None.
    auto_scaling = config.get('auto_scaling') or {}
    safety = auto_scaling.get('safety_factor')
    if safety is None:
        safety = 1.5

    # Map each task family to the relation tables it draws anchors from.
    # A family can need multiple relation types.
    family_to_relations = {
        'direct_lookup': [],
        'adjacency': ['adjacency'],
        'multi_adjacency': ['adjacency', 'common_neighbor'],
        'containment': ['containment'],
        'intersection': ['intersection', 'cross_source'],
        'buffer': ['adjacency'],
        'chained': ['coastal_containment', 'landlocked_containment', 'containment'],
        'difference': ['containment', 'cross_source'],
        'border_corridor': ['adjacency'],
        'set_operations': ['containment', 'cross_source'],
        'partial_selection': ['containment', 'cross_source'],
        'aggregation': ['containment'],
        'window_function': [],
        'attribute_filter': [],
    }

    relation_needs: Dict[str, int] = {}
    for family, target in sample_targets.items():
        needed = int(target * retry_mult * safety)
        for rel_type in family_to_relations.get(family, []):
            relation_needs[rel_type] = relation_needs.get(rel_type, 0) + needed

    # common_neighbor is derived from adjacency: when no family requested it
    # directly, size it as three times the adjacency limit.
    if 'common_neighbor' not in relation_needs and 'adjacency' in relation_needs:
        relation_needs['common_neighbor'] = relation_needs['adjacency'] * 3

    # Apply manual overrides if specified (an empty section is a no-op).
    manual = auto_scaling.get('manual_limits') or {}
    relation_needs.update(manual)
    return relation_needs
def normalize_data():
    """Build normalized source parquet copies with harmonized geometry metadata.

    Delegates to ``dataset.scripts.normalize_geodata.normalize_geodata`` and
    prints each ``name: path`` pair of the mapping it returns.
    """
    rule = "=" * 60
    print(rule, "STEP 0: Normalizing Source Geodata", rule, sep="\n")

    # Imported on demand, after the banner, so other commands don't load it.
    from dataset.scripts.normalize_geodata import normalize_geodata

    outputs = normalize_geodata()
    for label, location in outputs.items():
        print(f" {label}: {location}")
def build_relations(config_path: Path):
    """Build the relation tables described by the config file.

    Loads the configuration, derives per-relation limits from it, prints a
    summary, and hands ``countries`` and the limits to the relation builder
    module's ``main``.
    """
    config = load_config(config_path)

    # Limits are derived from the sample targets rather than set by hand.
    relation_limits = calculate_relation_limits(config)

    rule = "=" * 60
    print(rule, "STEP 1: Building Relation Tables", rule, sep="\n")
    print(f"Countries: {', '.join(config['countries'])}")
    print("\nAuto-calculated relation limits:")
    for relation, cap in relation_limits.items():
        print(f" {relation:20s}: {cap:,}")
    print()

    # Aliased so the imported module does not shadow this function's name.
    from dataset.scripts import build_relations as relation_builder

    relation_builder.main(
        countries=config['countries'],
        relation_limits=relation_limits,
    )
    print("\nRelation tables built successfully")
def generate_samples(config_path: Path, append: bool = False):
    """Generate samples using the settings from the config file.

    The generator module is configured by overwriting its module-level
    ``TARGET_COUNTS``, ``MAX_WORKERS``, ``RETRY_MULTIPLIER`` and
    ``APPEND_MODE`` before its ``main`` is called. Append mode is on when
    either ``append`` is true or ``generation.append_mode`` is set in the
    config.
    """
    config = load_config(config_path)

    rule = "=" * 60
    print(rule, "STEP 2: Generating Samples", rule, sep="\n")
    print(f"Targets: {config['sample_targets']}")
    generation = config['generation']
    print(f"Workers: {generation['max_workers']}")
    append_mode = append or generation['append_mode']
    print(f"Append mode: {append_mode}")
    print()

    from dataset.scripts import generate_samples as gs_module

    # Push the config values into the generator's module globals.
    gs_module.TARGET_COUNTS = config['sample_targets']
    gs_module.MAX_WORKERS = generation['max_workers']
    gs_module.RETRY_MULTIPLIER = generation['retry_multiplier']
    gs_module.APPEND_MODE = append_mode

    gs_module.main()
    print("\nSamples generated successfully")
def validate_dataset(config_path: Path):
    """Validate the dataset by running ``validate_dataset.py`` as a child process.

    The script is looked up next to this file and executed with the current
    interpreter. ``config_path`` is accepted for a uniform command signature
    but is not used here.

    Raises:
        subprocess.CalledProcessError: If the validation script exits non-zero.
    """
    rule = "=" * 60
    print(rule, "STEP 3: Validating Dataset", rule, sep="\n")

    validator = Path(__file__).parent / "validate_dataset.py"
    subprocess.run([sys.executable, str(validator)], check=True)
    print("\nDataset validated successfully")
def export_dataset(config_path: Path):
    """Export the dataset for both the SQL generation and place extraction tasks.

    Delegates to ``dataset.scripts.export_training_data.main``, passing the
    config path through as ``config_path``.
    """
    rule = "=" * 60
    print(rule, "STEP 4: Exporting Dataset", rule, sep="\n")

    from dataset.scripts.export_training_data import main as run_export

    run_export(config_path=config_path)
    print("\nDataset exported successfully")
def modal_upload(config_path: Path):
    """Upload local data to the Modal volume.

    Runs ``modal run dataset/modal_app.py::upload_data`` with the current
    interpreter. The app path is relative, so it is resolved against the
    working directory. ``config_path`` is not used here.

    Raises:
        subprocess.CalledProcessError: If the Modal command exits non-zero.
    """
    command = [
        sys.executable,
        "-m",
        "modal",
        "run",
        "dataset/modal_app.py::upload_data",
    ]
    subprocess.run(command, check=True)
def modal_generate(config_path: Path, num_containers: int = 0,
                   skip_inventory: bool = False, skip_relations: bool = False,
                   fresh: bool = False):
    """Run distributed generation on Modal (appends by default).

    Builds the ``modal run dataset/modal_app.py::run_pipeline`` command line,
    runs it, then validates and exports the dataset locally.

    Args:
        config_path: Config file forwarded as ``--config-path``.
        num_containers: Forwarded as ``--num-containers`` only when positive.
        skip_inventory: Adds ``--skip-inventory`` when true.
        skip_relations: Adds ``--skip-relations`` when true.
        fresh: Adds ``--fresh`` when true.

    Raises:
        subprocess.CalledProcessError: If the Modal command exits non-zero.
    """
    cmd = [
        sys.executable, "-m", "modal", "run",
        "dataset/modal_app.py::run_pipeline",
        "--config-path", str(config_path),
    ]
    if num_containers > 0:
        cmd += ["--num-containers", str(num_containers)]

    # Boolean switches, appended in a fixed order when enabled.
    switches = (
        (skip_inventory, "--skip-inventory"),
        (skip_relations, "--skip-relations"),
        (fresh, "--fresh"),
    )
    for enabled, flag in switches:
        if enabled:
            cmd.append(flag)

    subprocess.run(cmd, check=True)

    # Post-processing happens locally once the remote run has succeeded.
    validate_dataset(config_path)
    export_dataset(config_path)
def full_pipeline(config_path: Path, append: bool = False):
    """Run every stage: inventory (if missing), relations, samples, validation, export.

    The inventory is built only when one of its parquet files is absent from
    the ``intermediate`` directory, which sits beside this file's parent
    directory. Relation tables are rebuilt unless ``should_rebuild_relations``
    reports the existing ones can be reused.
    """
    print("Running full dataset generation pipeline")
    config = load_config(config_path)

    intermediate_dir = Path(__file__).parent.parent / "intermediate"
    required_inventory = (
        intermediate_dir / "divisions_area_inventory.parquet",
        intermediate_dir / "natural_earth_inventory.parquet",
    )
    if not all(path.exists() for path in required_inventory):
        rule = "=" * 60
        print(rule, "STEP 0: Building Entity Inventory", rule, sep="\n")
        print("Inventory files not found, building...")
        from dataset.scripts import build_inventory
        build_inventory.main()

    # Reuse relation tables only when that is known to be safe.
    if should_rebuild_relations(config, intermediate_dir, append):
        build_relations(config_path)
    else:
        print("Using existing relation tables (append mode, same countries)")

    generate_samples(config_path, append=append)
    validate_dataset(config_path)
    export_dataset(config_path)
    print("\nPipeline complete")
def main(argv=None):
    """Parse command-line arguments and run the selected command.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so running
            the script from a shell behaves as before; passing a list allows
            programmatic invocation.

    Exits with status 1 (via ``sys.exit``) when the config file does not
    exist or when the selected command raises; the error message is written
    to stderr.
    """
    # Single source of truth for the available commands: the keys feed
    # argparse's ``choices`` and the values are invoked with the parsed args.
    handlers = {
        'normalize-data': lambda a: normalize_data(),
        'build-relations': lambda a: build_relations(a.config),
        'generate-samples': lambda a: generate_samples(a.config, a.append),
        'validate': lambda a: validate_dataset(a.config),
        'export': lambda a: export_dataset(a.config),
        'full-pipeline': lambda a: full_pipeline(a.config, a.append),
        'modal-upload': lambda a: modal_upload(a.config),
        'modal-generate': lambda a: modal_generate(
            a.config,
            num_containers=a.num_containers,
            skip_inventory=a.skip_inventory,
            skip_relations=a.skip_relations,
            fresh=a.fresh,
        ),
    }

    parser = argparse.ArgumentParser(
        description="Synthetic dataset generation CLI",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Normalize source geodata first (recommended before Modal upload)
python cli.py normalize-data --config ../config.yaml
# Build relation tables only
python cli.py build-relations --config ../config.yaml
# Generate samples only
python cli.py generate-samples --config ../config.yaml
# Generate and append to existing dataset
python cli.py generate-samples --config ../config.yaml --append
# Run full pipeline
python cli.py full-pipeline --config ../config.yaml
# Run full pipeline in append mode (skip relation building)
python cli.py full-pipeline --config ../config.yaml --append
# Upload data to Modal volume (one-time)
python cli.py modal-upload --config ../config.yaml
# Run distributed generation on Modal
python cli.py modal-generate --config ../config.yaml
python cli.py modal-generate --config ../config.yaml --num-containers 100
python cli.py modal-generate --config ../config.yaml --skip-inventory --skip-relations
"""
    )
    parser.add_argument(
        'command',
        choices=list(handlers),
        help='Command to run'
    )
    parser.add_argument(
        '--config',
        type=Path,
        required=True,
        help='Path to config YAML file'
    )
    parser.add_argument(
        '--append',
        action='store_true',
        help='Append to existing dataset instead of overwriting'
    )
    parser.add_argument(
        '--num-containers',
        type=int,
        default=0,
        help='Number of Modal containers (0 = use config default)'
    )
    parser.add_argument(
        '--skip-inventory',
        action='store_true',
        help='Skip inventory building on Modal'
    )
    parser.add_argument(
        '--skip-relations',
        action='store_true',
        help='Skip relation building on Modal'
    )
    parser.add_argument(
        '--fresh',
        action='store_true',
        help='Overwrite existing dataset instead of appending (Modal only)'
    )
    args = parser.parse_args(argv)

    # Validate config file exists
    if not args.config.exists():
        print(f"Error: Config file not found: {args.config}", file=sys.stderr)
        sys.exit(1)

    # Top-level boundary: report any failure as a one-line error on stderr
    # and signal it through the exit status.
    try:
        handlers[args.command](args)
    except Exception as e:
        print(f"\nError: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()