""" Generate synthetic training samples for text-to-SQL task. This script: 1. Loads relation tables and entity inventories 2. For each SQL template, samples valid anchors 3. Renders and executes SQL to verify it works 4. Builds candidate lists with controlled distractors 5. Generates natural language questions using LLM 6. Saves complete training samples Output: - output/samples/sample_*.json (individual samples) - output/dataset_raw.jsonl (all samples) """ import json import random import re import warnings from pathlib import Path from typing import List, Dict, Any, Optional from concurrent.futures import ProcessPoolExecutor, as_completed from functools import partial import duckdb import pandas as pd from pydantic import BaseModel # Suppress warnings warnings.filterwarnings('ignore') from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH # Fixed paths embedded in every training SQL string. # The model learns these short, stable strings rather than machine-specific # local paths. At inference, sql.py's _rewrite_data_paths substitutes them # with the actual runtime paths from gazet.config. _DIVISIONS_SQL_PATH = 'divisions_area' _NATURAL_EARTH_SQL_PATH = 'natural_earth' def _for_execution(sql: str) -> str: """Replace symbolic placeholder paths with actual local paths for verification.""" return ( sql .replace("read_parquet('divisions_area')", f"read_parquet('{DIVISIONS_AREA_PATH}')") .replace("read_parquet('natural_earth')", f"read_parquet('{NATURAL_EARTH_PATH}')") ) # Configurable parameters (can be overridden by CLI) TARGET_COUNTS = None # Will be set in main() or by CLI MAX_WORKERS = 8 RETRY_MULTIPLIER = 2 APPEND_MODE = False _GENERIC_SURFACE_RULES = [ ("spelling_neighboring", r"\bneighbouring\b", ["neighboring"]), ("spelling_neighbors", r"\bneighbours\b", ["neighbors"]), ("expand_whats", r"\bwhat's\b", ["what is"]), ("show_me", r"\bshow me\b", ["show", "display"]), ("give_me", r"\bgive me\b", ["show", "list"]), ("pull_up", r"\bpull up\b", ["show", "display"]), ("find_to_show", r"\bfind\b", ["show", "locate"]), ("kilometers_variant", r"\bkilometers\b", ["km"]), ("metres_variant", r"\bmetres\b", ["meters"]), ("recognised_variant", r"\brecognised\b", ["recognized"]), ] _FAMILY_SURFACE_RULES = { "adjacency": [ ("which_border_to_next_to", r"\bwhich (.+?) border (.+)\?", [r"which \1 are next to \2?", r"which \1 are adjacent to \2?"]), ("bordering_to_next_to", r"\bbordering (.+)", [r"next to \1", r"adjacent to \1"]), ("touching_to_next_to", r"\btouching (.+)", [r"next to \1"]), ("share_border_to_adjacent", r"share a border with", ["are adjacent to", "are next to"]), ("adjacent_to_next_to", r"adjacent to", ["next to"]), ], "multi_adjacency": [ ("which_border_both_to_next_to", r"\bwhich (.+?) border both (.+)\?", [r"which \1 are next to both \2?", r"which \1 are adjacent to both \2?"]), ("touch_both_to_next_to", r"touch both", ["are next to both"]), ("adjacent_both_to_next_to", r"adjacent to both", ["next to both"]), ], "containment": [ ("within_to_inside", r"\bwithin\b", ["inside", "in"]), ("inside_to_in", r"\binside\b", ["in"]), ("belonging_to_in", r"belonging to", ["in"]), ("contain_to_have", r"\bcontain\b", ["have"]), ], "intersection": [ ("which_intersect_to_overlap", r"\bwhich (.+?) intersect (.+)\?", [r"which \1 overlap \2?"]), ("overlap_with_to_intersect", r"overlap with", ["intersect"]), ("crossing_to_overlapping", r"crossing into", ["overlapping"]), ("partly_in_to_overlap", r"partly in", ["overlapping"]), ], "buffer": [ ("within_distance_to_from", r"within ([0-9]+\s*(?:km|m)) of", [r"up to \1 from", r"at a distance of \1 from"]), ("buffer_to_radius", r"\bbuffer\b", ["radius", "zone"]), ("close_to_near", r"close to", ["near"]), ("around_to_near", r"what is around", ["what is near"]), ], "chained": [ ("coastal_to_seaside", r"\bcoastal\b", ["seaside", "maritime"]), ("landlocked_to_inland", r"\blandlocked\b", ["inland"]), ("sea_access_to_coast", r"sea access", ["a coastline"]), ], "difference": [ ("part_to_portion", r"\bpart of\b", ["portion of", "section of"]), ("outside_to_excluding", r"\boutside\b", ["excluding"]), ], "border_corridor": [ ("zone_to_buffer", r"\bzone\b", ["buffer", "corridor"]), ("within_distance_to_along", r"within ([0-9]+ km) of the", [r"along the", r"up to \1 from the"]), ], "set_operations": [ ("combined_to_merged", r"combined", ["merged"]), ("union_of_to_merged_area", r"\bunion of\b", ["merged area of", "combined area of"]), ("merge_to_combine", r"\bmerge\b", ["combine"]), ("together_to_combined", r"\btogether\b", ["combined"]), ], "partial_selection": [ ("part_to_portion", r"\bpart of\b", ["portion of", "section of"]), ("half_to_side", r"\bhalf\b", ["side"]), ], "aggregation": [ ("largest_to_biggest", r"\blargest\b", ["biggest"]), ("smallest_to_tiniest", r"\bsmallest\b", ["tiniest"]), ], "window_function": [ ("largest_to_biggest", r"\blargest\b", ["biggest"]), ("smallest_to_tiniest", r"\bsmallest\b", ["tiniest"]), ], "attribute_filter": [ ("official_to_recognized", r"\bofficial\b", ["recognized", "recognized territorial"]), ("land_based_to_on_land", r"land-based", ["on-land", "on land"]), ("sovereign_to_official", r"\bsovereign\b", ["official"]), ], "direct_lookup": [ ("where_is_to_show", r"\bwhere is\b", ["show", "locate"]), ("map_of_to_outline", r"\bmap of\b", ["outline of"]), ], "disambiguation": [ ("show_me_to_find", r"\bshow me\b", ["find", "show"]), ("pull_up_to_find", r"\bpull up\b", ["find", "show"]), ], } def _diversify_question_surface(question: str, family: str) -> tuple[str, List[str]]: """Apply light family-aware paraphrasing to reduce template memorization. Rewrites are intentionally shallow and lexically local so the generated question stays aligned with the underlying SQL intent. """ if not question or random.random() < 0.35: return question, [] rules = _GENERIC_SURFACE_RULES + _FAMILY_SURFACE_RULES.get(family, []) rewritten = question applied: List[str] = [] max_rewrites = 2 if random.random() < 0.5 else 1 for _ in range(max_rewrites): matches = [] for label, pattern, replacements in rules: if re.search(pattern, rewritten, flags=re.IGNORECASE): for replacement in replacements: matches.append((label, pattern, replacement)) if not matches: break label, pattern, replacement = random.choice(matches) updated = re.sub(pattern, replacement, rewritten, count=1, flags=re.IGNORECASE) if updated == rewritten: continue rewritten = re.sub(r"\s+", " ", updated).strip() applied.append(f"{family}:{label}") return rewritten, applied # Import templates from same directory from . import sql_templates TEMPLATES = sql_templates.TEMPLATES SQLTemplate = sql_templates.SQLTemplate get_templates_by_family = sql_templates.get_templates_by_family _NE_NAMED_LOOKUP_SUBTYPES = { 'sea', 'ocean', 'lake', 'river', 'basin', 'gulf', 'bay', 'island group', 'peninsula', 'strait', 'range/mtn', 'depression', } _NE_TEMPLATE_SUBTYPES = { 'lookup_02': {'sea', 'ocean', 'lake', 'river', 'basin', 'gulf', 'bay', 'island group', 'peninsula', 'strait', 'range/mtn', 'depression'}, 'adj_03': {'sea', 'ocean'}, 'adj_09': {'river', 'lake', 'basin'}, 'adj_10': {'range/mtn', 'peninsula', 'depression'}, 'adj_06': {'sea', 'ocean', 'lake', 'river', 'basin', 'gulf', 'bay', 'strait', 'range/mtn', 'peninsula', 'depression', 'plateau', 'plain', 'lowland', 'valley', 'gorge'}, 'adj_07': {'sea', 'ocean', 'lake', 'river', 'basin', 'gulf', 'bay', 'strait', 'range/mtn', 'peninsula', 'depression', 'plateau', 'plain', 'lowland', 'valley', 'gorge'}, 'adj_08': {'sea', 'ocean', 'lake', 'river', 'basin', 'gulf', 'bay', 'strait', 'range/mtn', 'peninsula', 'depression', 'plateau', 'plain', 'lowland', 'valley', 'gorge'}, 'contain_04': {'sea', 'ocean', 'gulf', 'bay', 'basin', 'island group', 'peninsula', 'range/mtn', 'depression'}, 'contain_05': {'sea', 'ocean', 'gulf', 'bay', 'strait'}, 'intersect_03': {'river', 'lake', 'basin', 'gulf', 'bay', 'strait', 'range/mtn', 'peninsula', 'depression'}, 'intersect_04': {'river', 'lake', 'basin', 'gulf', 'bay', 'strait', 'range/mtn', 'peninsula', 'depression'}, 'intersect_06': {'river', 'lake', 'basin', 'gulf', 'bay', 'strait', 'range/mtn', 'peninsula', 'depression'}, 'buffer_02': {'sea', 'ocean', 'lake', 'river', 'basin', 'gulf', 'bay', 'island group', 'peninsula', 'strait', 'range/mtn', 'depression'}, 'buffer_11': {'sea', 'ocean', 'lake', 'river', 'basin', 'gulf', 'bay', 'island group', 'peninsula', 'strait', 'range/mtn', 'depression'}, 'chained_03': {'island group', 'peninsula', 'range/mtn', 'depression'}, 'chained_04': {'river', 'lake', 'basin'}, 'chained_05': {'range/mtn', 'depression'}, 'chained_08': {'river', 'lake', 'basin'}, 'chained_09': {'range/mtn', 'depression'}, 'partial_05': {'sea', 'ocean', 'lake', 'river', 'basin', 'gulf', 'bay', 'island group', 'peninsula', 'strait', 'range/mtn', 'depression'}, 'diff_02': {'sea', 'ocean', 'lake', 'river', 'basin', 'gulf', 'bay', 'island group', 'peninsula', 'strait', 'range/mtn', 'depression'}, } class Candidate(BaseModel): """Candidate entity for grounding.""" candidate_id: str source: str id: str name: str subtype: Optional[str] = None country: Optional[str] = None region: Optional[str] = None admin_level: Optional[int] = None similarity: float = 0.0 class TrainingSample(BaseModel): """Complete training sample.""" id: str question: str candidates: List[Candidate] target: Dict[str, Any] metadata: Dict[str, Any] def load_relation_tables(intermediate_dir: Path, quiet: bool = False) -> Dict[str, pd.DataFrame]: """Load all precomputed relation tables.""" tables = {} for file in intermediate_dir.glob("*.parquet"): name = file.stem tables[name] = pd.read_parquet(file) if not quiet: print(f" {name}: {len(tables[name])} rows") return tables def sample_adjacency_anchor( adjacency_df: pd.DataFrame, target_subtype: Optional[str] = None, anchor_subtypes: Optional[List[str]] = None, ) -> Optional[Dict[str, Any]]: """Sample a random adjacency pair, optionally filtered by subtypes. When ``target_subtype`` is provided, only rows whose neighbouring feature matches that subtype are considered. When ``anchor_subtypes`` is provided, only rows whose anchor feature is one of those subtypes are considered. Both filters are applied together so sampled pairs are geographically coherent with the template intent (e.g. country anchor → country result). """ if adjacency_df.empty: return None df = adjacency_df if target_subtype is not None: df = df[df['target_subtype'] == target_subtype] if df.empty: return None if anchor_subtypes is not None: filtered = df[df['anchor_subtype'].isin(anchor_subtypes)] if not filtered.empty: df = filtered row = df.sample(n=1).iloc[0] return { 'anchor_id': row['anchor_id'], 'anchor_name': row['anchor_name'], 'anchor_subtype': row['anchor_subtype'], 'anchor_country': row.get('anchor_country'), # May not exist in all tables 'target_id': row.get('target_id'), 'target_name': row.get('target_name'), 'target_subtype': row.get('target_subtype') } def sample_intersection_anchor(intersection_df: pd.DataFrame) -> Optional[Dict[str, Any]]: """Sample a random intersection pair.""" if intersection_df.empty: return None row = intersection_df.sample(n=1).iloc[0] return { 'anchor_id': row['anchor_id'], 'anchor_name': row['anchor_name'], 'anchor_subtype': row['anchor_subtype'], 'target_id': row.get('target_id'), 'target_name': row.get('target_name'), 'target_subtype': row.get('target_subtype') } def sample_containment_anchor(containment_df: pd.DataFrame) -> Optional[Dict[str, Any]]: """Sample a random containment pair. Returns both ends of the pair so callers that need the contained entity (e.g. difference templates that clip container by contained) can use it directly without a second random draw. """ if containment_df.empty: return None row = containment_df.sample(n=1).iloc[0] return { 'container_id': row['container_id'], 'container_name': row['container_name'], 'container_subtype': row['container_subtype'], 'contained_id': row['contained_id'], 'contained_name': row['contained_name'], 'contained_subtype': row['contained_subtype'], } def sample_disambiguation_anchor( containment_df: pd.DataFrame, contained_subtypes: List[str], container_subtypes: List[str], ) -> Optional[Dict[str, Any]]: """Sample a (contained, container) pair from containment_pairs. Used by disambiguation templates like "Puri, Odisha" where the contained entity is the target and the container provides disambiguation context. """ if containment_df.empty: return None df = containment_df[ containment_df['contained_subtype'].isin(contained_subtypes) & containment_df['container_subtype'].isin(container_subtypes) ] if df.empty: return None row = df.sample(n=1).iloc[0] return { 'contained_id': row['contained_id'], 'contained_name': row['contained_name'], 'contained_subtype': row['contained_subtype'], 'container_id': row['container_id'], 'container_name': row['container_name'], 'container_subtype': row['container_subtype'], } def sample_cross_source_anchor( cross_source_df: pd.DataFrame, natural_subtypes: Optional[set[str]] = None, relation_types: Optional[set[str]] = None, ) -> Optional[Dict[str, Any]]: """Sample a random cross-source relation with optional subtype filters.""" if cross_source_df.empty: return None df = cross_source_df if natural_subtypes is not None: df = df[df['natural_subtype'].isin(natural_subtypes)] if relation_types is not None: df = df[df['relation_type'].isin(relation_types)] if df.empty: return None row = df.sample(n=1).iloc[0] return { 'division_id': row['division_id'], 'division_name': row['division_name'], 'division_subtype': row['division_subtype'], 'natural_id': row['natural_id'], 'natural_name': row['natural_name'], 'natural_subtype': row['natural_subtype'], 'relation_type': row['relation_type'] } def _merge_candidate_lists( *lists: List[Candidate], max_total: int = 10, ) -> List[Candidate]: """Merge N candidate lists, deduplicate by id, reassign candidate_ids. Interleaves the lists so each anchor is represented before any anchor gets a second candidate — matching the grouped-then-interleaved order that inference produces. """ from itertools import zip_longest seen: set = set() merged: List[Candidate] = [] for row in zip_longest(*lists): for c in row: if c is None: continue if c.id not in seen: merged.append(c) seen.add(c.id) if len(merged) >= max_total: break if len(merged) >= max_total: break for i, c in enumerate(merged, 1): c.candidate_id = f"c{i}" return merged def _dedupe_country_candidates( candidates: List[Candidate], max_total: Optional[int] = None, ) -> List[Candidate]: """Deduplicate country candidates by country code, preserving first match. This is useful for templates whose SQL uses ``country IN (...)`` rather than candidate IDs. Overture can contain multiple country-level rows for the same ISO code, which weakens grounding if they all remain in the list. """ deduped: List[Candidate] = [] seen_keys: set[tuple[str, str]] = set() for cand in candidates: if cand.subtype == "country" and cand.country: key = ("country", cand.country) else: key = ("id", cand.id) if key in seen_keys: continue deduped.append(cand) seen_keys.add(key) if max_total is not None and len(deduped) >= max_total: break for i, cand in enumerate(deduped, 1): cand.candidate_id = f"c{i}" return deduped def build_candidate_list( con: duckdb.DuckDBPyConnection, anchor_id: str, anchor_name: str, anchor_source: str, num_candidates: int = 10, difficulty: str = "medium" ) -> List[Candidate]: """Build candidate list with true anchor + distractors.""" # Helper to convert pandas NA to None def safe_get(row, key, default=None): val = row.get(key, default) return None if pd.isna(val) else val # Get the true anchor if anchor_source == "divisions_area": query = """ SELECT id, names."primary" AS name, subtype, country, region, admin_level FROM read_parquet(?) WHERE id = ? """ anchor_row = con.execute(query, [DIVISIONS_AREA_PATH, anchor_id]).fetchdf().iloc[0] else: query = """ SELECT id, names."primary" AS name, subtype FROM read_parquet(?) WHERE id = ? """ anchor_row = con.execute(query, [NATURAL_EARTH_PATH, anchor_id]).fetchdf().iloc[0] true_candidate = Candidate( candidate_id="c1", source=anchor_source, id=anchor_id, name=safe_get(anchor_row, 'name'), subtype=safe_get(anchor_row, 'subtype'), country=safe_get(anchor_row, 'country'), region=safe_get(anchor_row, 'region'), admin_level=safe_get(anchor_row, 'admin_level'), similarity=1.0, ) distractors = build_distractors( con, anchor_name, anchor_source, anchor_id, num_candidates - 1, difficulty, ) # Deduplicate by underlying entity id while preserving order. # Some parquet sources contain repeated rows for the same feature id, # which can otherwise leak duplicate candidates into the dataset. candidates: List[Candidate] = [] seen_ids: set[str] = set() for cand in [true_candidate] + distractors: if cand.id in seen_ids: continue candidates.append(cand) seen_ids.add(cand.id) if len(candidates) >= num_candidates: break for i, cand in enumerate(candidates, 1): cand.candidate_id = f"c{i}" return candidates def build_distractors( con: duckdb.DuckDBPyConnection, anchor_name: str, anchor_source: str, exclude_id: str, num_distractors: int, difficulty: str, cross_source_ratio: float = 0.5, ) -> List[Candidate]: """Build distractor candidates using fuzzy search. Always includes candidates from both sources so the model sees mixed ``source`` values in every training example — matching the inference behaviour where search.py queries divisions_area AND natural_earth equally (5 results each per place). Args: cross_source_ratio: Fraction of distractors drawn from the *other* source. Defaults to 0.5 (50/50 split) to match inference exactly. """ def safe_get(row, key, default=None): val = row.get(key, default) return None if pd.isna(val) else val def _query_source(path: str, src_name: str, n: int, excl_id: str) -> List[Candidate]: query = """ WITH ranked AS ( SELECT id, names."primary" AS name, subtype, country, region, admin_level, jaro_winkler_similarity(lower(names."primary"), lower(?)) AS similarity, ROW_NUMBER() OVER ( PARTITION BY id ORDER BY jaro_winkler_similarity(lower(names."primary"), lower(?)) DESC ) AS rn FROM read_parquet(?) WHERE id != ? AND names."primary" IS NOT NULL AND trim(names."primary") != '' AND geometry IS NOT NULL ) SELECT id, name, subtype, country, region, admin_level, similarity FROM ranked WHERE rn = 1 ORDER BY similarity DESC LIMIT ? """ df = con.execute(query, [anchor_name, anchor_name, path, excl_id, n]).fetchdf() results = [] for _, row in df.iterrows(): results.append(Candidate( candidate_id="temp", source=src_name, id=row["id"], name=safe_get(row, "name"), subtype=safe_get(row, "subtype"), country=safe_get(row, "country"), region=safe_get(row, "region"), admin_level=safe_get(row, "admin_level"), similarity=float(row["similarity"]), )) return results cross_n = max(1, round(num_distractors * cross_source_ratio)) same_n = num_distractors - cross_n if anchor_source == "divisions_area": same = _query_source(DIVISIONS_AREA_PATH, "divisions_area", same_n, exclude_id) cross = _query_source(NATURAL_EARTH_PATH, "natural_earth", cross_n, "") else: same = _query_source(NATURAL_EARTH_PATH, "natural_earth", same_n, exclude_id) cross = _query_source(DIVISIONS_AREA_PATH, "divisions_area", cross_n, "") return same + cross def sample_random_entity( con: duckdb.DuckDBPyConnection, inventory_df: pd.DataFrame, source: str, subtypes: Optional[set[str]] = None, countries: Optional[set[str]] = None, ) -> Optional[Dict[str, Any]]: """Sample a random entity from inventory with optional filters.""" if inventory_df.empty: return None df = inventory_df if subtypes is not None: df = df[df['subtype'].isin(subtypes)] if countries is not None and 'country' in df.columns: df = df[df['country'].isin(countries)] if df.empty: return None row = df.sample(n=1).iloc[0] return { 'id': row['id'], 'name': row['name'], 'subtype': row.get('subtype'), 'country': row.get('country'), 'source': source } def generate_template_based_sample( con: duckdb.DuckDBPyConnection, template: SQLTemplate, tables: Dict[str, pd.DataFrame], sample_id: str ) -> Optional[TrainingSample]: """Generate a sample based on a SQL template.""" # Sample anchor based on template requirements if template.family == "direct_lookup": # Just pick a random entity if template.anchor_source == "divisions_area": anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area') else: anchor = sample_random_entity( con, tables['natural_earth_inventory'], 'natural_earth', subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id, _NE_NAMED_LOOKUP_SUBTYPES), ) if not anchor: return None # Render SQL sql = template.sql_template.format( anchor_id=anchor['id'] ) # Build candidates candidates = build_candidate_list( con, anchor['id'], anchor['name'], anchor['source'], num_candidates=10, difficulty="easy" ) # Question question = random.choice(template.question_hints).format(anchor_name=anchor['name']) elif template.family == "disambiguation": # "Puri, Odisha" style: pick a (contained, container) pair whose # subtypes match the template, build candidates that include the # container + same-name distractors so the model must read the CSV # to pick the right entry. _disambig_subtypes = { "disambiguate_01": (["county"], ["region", "country"]), "disambiguate_02": (["county"], ["country"]), "disambiguate_03": (["region"], ["country"]), } contained_sts, container_sts = _disambig_subtypes.get( template.template_id, (["county"], ["country"]) ) pair = sample_disambiguation_anchor( tables["containment_pairs"], contained_sts, container_sts ) if not pair: return None candidates = build_candidate_list( con, pair["contained_id"], pair["contained_name"], "divisions_area", num_candidates=10, difficulty="hard" ) # Ensure the container is among the candidates so the model can # ground the disambiguation context (e.g. "Odisha"). if not any(c.id == pair["container_id"] for c in candidates): container_rows = con.execute( 'SELECT id, names."primary" AS name, subtype, country, region, admin_level ' 'FROM read_parquet(?) WHERE id = ? LIMIT 1', [DIVISIONS_AREA_PATH, pair["container_id"]] ).fetchdf() if container_rows.empty: return None crow = container_rows.iloc[0] def _nn(v): return None if pd.isna(v) else v container_cand = Candidate( candidate_id="temp", source="divisions_area", id=pair["container_id"], name=_nn(crow["name"]), subtype=_nn(crow["subtype"]), country=_nn(crow["country"]), region=_nn(crow["region"]), admin_level=_nn(crow["admin_level"]), similarity=0.95, ) # Insert the container right after the true target and drop the # last filler distractor so the total stays at 10. candidates = [candidates[0], container_cand] + candidates[1:-1] for i, c in enumerate(candidates, 1): c.candidate_id = f"c{i}" sql = template.sql_template.format(anchor_id=pair["contained_id"]) question = random.choice(template.question_hints).format( anchor_name=pair["contained_name"], container_name=pair["container_name"], ) # Only the contained entity is the query target — the container is # disambiguation context and stays in candidates but NOT in # selected_candidates. The model learns to use the container row of # the CSV (via country/region columns) to pick the right same-name # county or region. anchor = {"id": pair["contained_id"], "name": pair["contained_name"]} elif template.family == "adjacency": # adj_03/09/10/11/12: division anchor -> natural_earth targets. # adj_06/07/08: natural_earth anchor -> admin targets. # Use cross_source_relations so anchors are guaranteed to intersect. _NE_TARGET_ADJ_SUBTYPES = { "adj_03": ("ocean", "sea"), "adj_09": ("river", "lake", "basin"), "adj_10": ("range/mtn",), "adj_11": ("plateau",), "adj_12": ("plain", "lowland", "basin", "valley", "depression", "gorge"), } if template.template_id in _NE_TARGET_ADJ_SUBTYPES: cs_df = tables.get('cross_source_relations', pd.DataFrame()) if cs_df.empty: return None ne_types = _NE_TARGET_ADJ_SUBTYPES[template.template_id] filtered = cs_df[cs_df['natural_subtype'].isin(ne_types)] if filtered.empty: return None row = filtered.sample(n=1).iloc[0] anchor = { 'anchor_id': row['division_id'], 'anchor_name': row['division_name'], 'anchor_subtype': row['division_subtype'], 'target_subtype': row['natural_subtype'], 'anchor_source': 'divisions_area', } elif template.anchor_source == "natural_earth": cs_anchor = sample_cross_source_anchor( tables.get('cross_source_relations', pd.DataFrame()), natural_subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id), ) if not cs_anchor: return None anchor = { 'anchor_id': cs_anchor['natural_id'], 'anchor_name': cs_anchor['natural_name'], 'target_subtype': template.target_subtype, 'anchor_source': 'natural_earth', } else: # divisions_area self-join adjacency. _ADJ_ANCHOR_SUBTYPES = { "adj_02": ["country", "region"], "adj_04": ["region"], "adj_05": ["country"], } filter_subtype = ( template.target_subtype if '{target_subtype}' in template.sql_template else None ) anchor = sample_adjacency_anchor( tables['adjacency_pairs'], target_subtype=filter_subtype, anchor_subtypes=_ADJ_ANCHOR_SUBTYPES.get(template.template_id), ) if anchor: anchor['anchor_source'] = 'divisions_area' if not anchor: return None sql = template.sql_template.format( anchor_id=anchor['anchor_id'], target_subtype=anchor.get('target_subtype', ''), ) candidates = build_candidate_list( con, anchor['anchor_id'], anchor['anchor_name'], anchor.get('anchor_source', 'divisions_area'), num_candidates=10, difficulty="medium" ) question = random.choice(template.question_hints).format( anchor_name=anchor['anchor_name'], target_subtype=anchor.get('target_subtype', ''), ) elif template.family == "containment": if template.anchor_source == "natural_earth": # contain_04 / contain_05: NE anchor (sea, desert, etc.). # Use cross_source_relations so the anchor exists in natural_earth # and is guaranteed to intersect divisions_area features. cs_anchor = sample_cross_source_anchor( tables.get('cross_source_relations', pd.DataFrame()), natural_subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id), ) if not cs_anchor: return None anchor_id = cs_anchor['natural_id'] anchor_name = cs_anchor['natural_name'] target_subtype = template.target_subtype or 'country' sql = template.sql_template.format( anchor_id=anchor_id, target_subtype=target_subtype, ) candidates = build_candidate_list( con, anchor_id, anchor_name, 'natural_earth', num_candidates=10, difficulty="medium" ) question = random.choice(template.question_hints).format( anchor_name=anchor_name, target_subtype=target_subtype, ) anchor = {'id': anchor_id, 'name': anchor_name} elif template.template_id == "contain_02": # "What country contains X?" - anchor is the CONTAINED entity; # result is the country that ST_Contains it. # Guard against stale relation tables by only allowing contained # subtypes that exist in the simplified admin schema. df = tables['containment_pairs'] df = df[ (df['container_subtype'] == 'country') & (df['contained_subtype'].isin(['region', 'county'])) ] pair = sample_containment_anchor(df) if not pair: return None sql = template.sql_template.format( anchor_id=pair['contained_id'], target_subtype='country', ) candidates = build_candidate_list( con, pair['contained_id'], pair['contained_name'], 'divisions_area', num_candidates=10, difficulty="medium" ) question = random.choice(template.question_hints).format( anchor_name=pair['contained_name'], target_subtype='country', ) anchor = {'id': pair['contained_id'], 'name': pair['contained_name']} elif template.template_id == "contain_03": # "What regions are in country X?" - anchor is a country, target is regions. df = tables['containment_pairs'] df = df[ (df['container_subtype'] == 'country') & (df['contained_subtype'] == 'region') ] pair = sample_containment_anchor(df) if not pair: return None sql = template.sql_template.format( anchor_id=pair['container_id'], target_subtype='region', ) candidates = build_candidate_list( con, pair['container_id'], pair['container_name'], 'divisions_area', num_candidates=10, difficulty="medium" ) question = random.choice(template.question_hints).format( anchor_name=pair['container_name'], target_subtype='region', ) anchor = {'id': pair['container_id'], 'name': pair['container_name']} else: # contain_01: standard containment. # Enforce hierarchy: county must be inside region or country, never # inside another county. Filter container_subtype accordingly. # Also filter contained_subtype to match template.target_subtype so # hardcoded vocab hints (e.g. "districts") always align with the SQL. _VALID_CONTAINERS = { "county": ["region", "country"], "region": ["country"], } df = tables['containment_pairs'] if template.target_subtype: filtered = df[df['contained_subtype'] == template.target_subtype] if not filtered.empty: df = filtered valid_containers = _VALID_CONTAINERS.get(template.target_subtype) if valid_containers: filtered = df[df['container_subtype'].isin(valid_containers)] if not filtered.empty: df = filtered anchor = sample_containment_anchor(df) if not anchor: return None target_subtype = template.target_subtype or anchor['contained_subtype'] sql = template.sql_template.format( anchor_id=anchor['container_id'], target_subtype=target_subtype, ) candidates = build_candidate_list( con, anchor['container_id'], anchor['container_name'], 'divisions_area', num_candidates=10, difficulty="medium" ) question = random.choice(template.question_hints).format( anchor_name=anchor['container_name'], target_subtype=target_subtype, ) elif template.family == "intersection": if template.anchor_source == "natural_earth": anchor = sample_cross_source_anchor( tables['cross_source_relations'], natural_subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id), ) if not anchor: return None target_subtype = template.target_subtype or 'country' sql = template.sql_template.format( anchor_id=anchor['natural_id'], target_subtype=target_subtype, ) candidates = build_candidate_list( con, anchor['natural_id'], anchor['natural_name'], 'natural_earth', num_candidates=10, difficulty="medium" ) question = random.choice(template.question_hints).format( anchor_name=anchor['natural_name'], target_subtype=target_subtype, ) else: # Same-source intersection. # If the template pins a target_subtype (e.g. intersect_02 targets county), # filter intersection_pairs so the sampled pair is guaranteed to match. idf = tables['intersection_pairs'] if template.target_subtype and not idf.empty: filtered = idf[idf['target_subtype'] == template.target_subtype] if filtered.empty: return None idf = filtered anchor = sample_intersection_anchor(idf) if not anchor: return None target_subtype = template.target_subtype or anchor.get('target_subtype') or 'region' sql = template.sql_template.format( anchor_id=anchor['anchor_id'], target_subtype=target_subtype ) candidates = build_candidate_list( con, anchor['anchor_id'], anchor['anchor_name'], 'divisions_area', num_candidates=10, difficulty="medium" ) question = random.choice(template.question_hints).format( anchor_name=anchor['anchor_name'], target_subtype=target_subtype ) elif template.family == "set_operations": if template.template_id == "union_03": # 3-anchor union by ID — candidates: 3 per anchor (9 total) anchors = [ sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area') for _ in range(3) ] if any(a is None for a in anchors): return None anchor1, anchor2, anchor3 = anchors sql = template.sql_template.format( anchor_id_1=anchor1['id'], anchor_id_2=anchor2['id'], anchor_id_3=anchor3['id'], ) per_anchor = 3 cands = [ build_candidate_list(con, a['id'], a['name'], 'divisions_area', num_candidates=per_anchor, difficulty="medium") for a in anchors ] candidates = _merge_candidate_lists(*cands, max_total=9) question = random.choice(template.question_hints).format( anchor_1_name=anchor1['name'], anchor_2_name=anchor2['name'], anchor_3_name=anchor3['name'], ) elif template.template_id in ("contain_multi_01", "contain_multi_02", "contain_multi_03"): # country IN clause — 2 or 3 anchors, each contributes its country code. # Sample unique countries so the query actually teaches a multi-country # pattern rather than repeating the same ISO code multiple times. num_a = 3 if template.template_id == "contain_multi_02" else 2 country_inventory = tables['divisions_area_inventory'] country_inventory = country_inventory[ (country_inventory['subtype'] == 'country') & country_inventory['country'].notna() ].drop_duplicates(subset=['country']) if len(country_inventory) < num_a: return None sampled = country_inventory.sample(n=num_a, replace=False) anchors = [ { 'id': row['id'], 'name': row['name'], 'subtype': row.get('subtype'), 'country': row.get('country'), 'source': 'divisions_area', } for _, row in sampled.iterrows() ] countries = [a.get('country') or 'US' for a in anchors] target_subtype = template.target_subtype or 'region' per_anchor = 3 if num_a == 3 else 4 fmt_kwargs = dict( target_subtype=target_subtype, ) for i, c in enumerate(countries, 1): fmt_kwargs[f'country_{i}'] = c sql = template.sql_template.format(**fmt_kwargs) cands = [ build_candidate_list(con, a['id'], a['name'], 'divisions_area', num_candidates=per_anchor, difficulty="medium") for a in anchors ] candidates = _dedupe_country_candidates( _merge_candidate_lists(*cands, max_total=num_a * per_anchor), max_total=num_a * per_anchor, ) q_kwargs = dict(target_subtype=target_subtype) for i, a in enumerate(anchors, 1): q_kwargs[f'anchor_{i}_name'] = a['name'] question = random.choice(template.question_hints).format(**q_kwargs) elif template.template_id == "union_02": # Filtered union: ST_Union_Agg of contained sub-features. # Pin to template.target_subtype so hardcoded vocabulary hints # (e.g. "districts") always match the SQL subtype. df = tables['containment_pairs'] if template.target_subtype: filtered = df[df['contained_subtype'] == template.target_subtype] if not filtered.empty: df = filtered pair = sample_containment_anchor(df) if not pair: return None target_subtype = template.target_subtype or pair.get('contained_subtype', 'county') sql = template.sql_template.format( anchor_id=pair['container_id'], target_subtype=target_subtype, ) candidates = build_candidate_list( con, pair['container_id'], pair['container_name'], 'divisions_area', num_candidates=10, difficulty="medium" ) question = random.choice(template.question_hints).format( anchor_name=pair['container_name'], target_subtype=target_subtype, ) else: # union_01: 2-anchor union by ID — candidates: 5 per anchor anchor1 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area') anchor2 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area') if not anchor1 or not anchor2: return None sql = template.sql_template.format( anchor_id_1=anchor1['id'], anchor_id_2=anchor2['id'], ) cands1 = build_candidate_list( con, anchor1['id'], anchor1['name'], 'divisions_area', num_candidates=5, difficulty="medium" ) cands2 = build_candidate_list( con, anchor2['id'], anchor2['name'], 'divisions_area', num_candidates=5, difficulty="medium" ) candidates = _merge_candidate_lists(cands1, cands2, max_total=10) question = random.choice(template.question_hints).format( anchor_1_name=anchor1['name'], anchor_2_name=anchor2['name'], ) elif template.family == "buffer": # Buffer operations use metre distances in SQL and a human-readable # buffer_label in questions, e.g. (1000, "1 km") or (250, "250 m"). # The template SQL divides by 111 320 to approximate metres in degrees. _buffer_choices = [ (100, "100 m"), (250, "250 m"), (500, "500 m"), (1000, "1 km"), (2000, "2 km"), (5000, "5 km"), (10000, "10 km"), (25000, "25 km"), (50000, "50 km"), (100000, "100 km"), (200000, "200 km"), ] if template.num_anchors == 1: if template.anchor_source == "natural_earth": anchor = sample_random_entity( con, tables['natural_earth_inventory'], 'natural_earth', subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id, _NE_NAMED_LOOKUP_SUBTYPES), ) else: anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area') if not anchor: return None buffer_m, buffer_label = random.choice(_buffer_choices) fmt_kwargs = dict(anchor_id=anchor['id'], buffer_m=buffer_m) q_kwargs = dict(anchor_name=anchor['name'], buffer_label=buffer_label) if template.target_subtype: fmt_kwargs['target_subtype'] = template.target_subtype q_kwargs['target_subtype'] = template.target_subtype sql = template.sql_template.format(**fmt_kwargs) candidates = build_candidate_list( con, anchor['id'], anchor['name'], anchor['source'], num_candidates=10, difficulty="medium" ) question = random.choice(template.question_hints).format(**q_kwargs) else: # Multi-anchor buffer (2–5 places): union of individual buffers. num_a = template.num_anchors anchors = [] for _ in range(num_a): a = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area') if not a: return None anchors.append(a) buffer_m, buffer_label = random.choice(_buffer_choices[:7]) fmt_kwargs = {f'anchor_id_{i+1}': a['id'] for i, a in enumerate(anchors)} fmt_kwargs['buffer_m'] = buffer_m if template.target_subtype: fmt_kwargs['target_subtype'] = template.target_subtype sql = template.sql_template.format(**fmt_kwargs) # Build one candidate list per anchor then merge. per_anchor_n = max(2, 10 // num_a) cand_lists = [ build_candidate_list( con, a['id'], a['name'], 'divisions_area', num_candidates=per_anchor_n, difficulty="medium", ) for a in anchors ] candidates = _merge_candidate_lists(*cand_lists) q_kwargs = {f'anchor_{i+1}_name': a['name'] for i, a in enumerate(anchors)} q_kwargs['buffer_label'] = buffer_label if template.target_subtype: q_kwargs['target_subtype'] = template.target_subtype question = random.choice(template.question_hints).format(**q_kwargs) elif template.family == "partial_selection": # Partial selection (northern half, clipping, etc.) anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area') if not anchor: return None if template.num_anchors == 1: sql = template.sql_template.format( anchor_id=anchor['id'], ) question = random.choice(template.question_hints).format( anchor_name=anchor['name'], ) candidates = build_candidate_list( con, anchor['id'], anchor['name'], 'divisions_area', num_candidates=10, difficulty="hard", ) else: # Mixed-source clip: division intersected with a natural_earth feature. # Use cross_source_relations so the pair is guaranteed to intersect — # random sampling almost never produces an intersecting pair. cs_anchor = sample_cross_source_anchor( tables.get('cross_source_relations', pd.DataFrame()), natural_subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id), ) if not cs_anchor: return None clip_feature = { 'id': cs_anchor['natural_id'], 'name': cs_anchor['natural_name'], 'source': 'natural_earth', } # Override the division anchor with the paired division so the # ST_Intersects check in the SQL is guaranteed to pass. anchor = { 'id': cs_anchor['division_id'], 'name': cs_anchor['division_name'], 'source': 'divisions_area', } sql = template.sql_template.format( anchor_id=anchor['id'], clip_feature_id=clip_feature['id'], ) question = random.choice(template.question_hints).format( anchor_name=anchor['name'], clip_feature_name=clip_feature['name'], ) # Build candidates for BOTH anchors so the model sees both IDs # in context and learns to pick the right one for each placeholder. div_cands = build_candidate_list( con, anchor['id'], anchor['name'], 'divisions_area', num_candidates=5, difficulty="hard", ) ne_cands = build_candidate_list( con, clip_feature['id'], clip_feature['name'], 'natural_earth', num_candidates=5, difficulty="hard", ) candidates = _merge_candidate_lists(div_cands, ne_cands, max_total=10) elif template.family == "aggregation": # Teach the model to distinguish singular superlatives ("the largest") # from explicit top-N requests ("top 5 largest"). top_n = random.choice([1, 3, 5, 10]) target_subtype = random.choice(['county', 'region']) singular_hints = [h for h in template.question_hints if '{top_n}' not in h] plural_hints = [h for h in template.question_hints if '{top_n}' in h] question_hint_pool = singular_hints if top_n == 1 and singular_hints else plural_hints or template.question_hints if template.template_id in ['agg_03', 'agg_04']: # Country-level aggregation: SQL uses country code, so the anchor # in the question must also be a country. anchor = sample_random_entity( con, tables['divisions_area_inventory'], 'divisions_area', subtypes={'country'}, ) if not anchor: return None country = anchor.get('country') or 'US' sql = template.sql_template.format( country=country, target_subtype=target_subtype, top_n=top_n, ) candidates = build_candidate_list( con, anchor['id'], anchor['name'], 'divisions_area', num_candidates=10, difficulty="hard" ) question = random.choice(question_hint_pool).format( top_n=top_n, target_subtype=target_subtype, anchor_name=anchor['name'], ) else: # Containment-based aggregation: anchor is the container region. anchor = sample_containment_anchor(tables['containment_pairs']) if not anchor: return None sql = template.sql_template.format( anchor_id=anchor['container_id'], target_subtype=target_subtype, top_n=top_n, ) candidates = build_candidate_list( con, anchor['container_id'], anchor['container_name'], 'divisions_area', num_candidates=10, difficulty="hard" ) question = random.choice(question_hint_pool).format( top_n=top_n, target_subtype=target_subtype, anchor_name=anchor['container_name'], ) elif template.family == "chained": # chained_12/13: country-level coastal/landlocked via adjacency. # The SQL uses ST_Touches (not containment), so bypass the containment # pair sampling and use adjacency_pairs with country-level anchors. if template.template_id in {"chained_12", "chained_13"}: adj_df = tables.get('adjacency_pairs', pd.DataFrame()) country_adj = ( adj_df[ (adj_df['anchor_subtype'] == 'country') & (adj_df['target_subtype'] == 'country') ] if not adj_df.empty else pd.DataFrame() ) if country_adj.empty: return None pair = sample_adjacency_anchor(country_adj) if not pair: return None sql = template.sql_template.format(anchor_id=pair['anchor_id']) candidates = build_candidate_list( con, pair['anchor_id'], pair['anchor_name'], 'divisions_area', num_candidates=10, difficulty="hard" ) question = random.choice(template.question_hints).format( anchor_name=pair['anchor_name'] ) anchor = {'id': pair['anchor_id'], 'name': pair['anchor_name']} else: # Use pre-filtered coastal/landlocked containment pairs so the SQL # verification step doesn't constantly return empty results. _COASTAL_CHAINED = {"chained_01", "chained_06", "chained_10"} _LANDLOCKED_CHAINED = {"chained_02", "chained_07", "chained_11"} if template.template_id in _COASTAL_CHAINED: table_key = 'coastal_containment_pairs' elif template.template_id in _LANDLOCKED_CHAINED: table_key = 'landlocked_containment_pairs' else: table_key = 'containment_pairs' df = tables.get(table_key, tables['containment_pairs']) # When the template pins a target_subtype (e.g. chained_06 wants # counties), only consider pairs whose contained entity already # matches — guarantees the sampled container holds at least one # entity of the right subtype so the SQL filter returns rows. if template.target_subtype: df = df[df['contained_subtype'] == template.target_subtype] # chained_10/11 additionally need a country-level container so # phrasings like "coastal states of India" line up. if template.template_id in {"chained_10", "chained_11"}: df = df[df['container_subtype'] == 'country'] anchor = sample_containment_anchor(df) if not anchor: return None target_subtype = template.target_subtype or anchor.get('contained_subtype', 'county') sql = template.sql_template.format( anchor_id=anchor['container_id'], target_subtype=target_subtype, ) candidates = build_candidate_list( con, anchor['container_id'], anchor['container_name'], 'divisions_area', num_candidates=10, difficulty="hard" ) question = random.choice(template.question_hints).format( anchor_name=anchor['container_name'], target_subtype=target_subtype, ) elif template.family == "multi_adjacency": # Use common_neighbor_pairs so anchor1 and anchor2 are guaranteed to # share at least one touching neighbour — SQL will return non-empty. # Filter by both anchor subtypes AND shared_neighbor_subtype so the # sampled pair is geographically coherent with the template intent: # multi_adj_01: region anchors → region result # multi_adj_02: country anchors → country result # multi_adj_03: region anchors → county result _MULTI_ADJ_ANCHOR_SUBTYPES = { "multi_adj_01": ("region", "region"), "multi_adj_02": ("country", "country"), "multi_adj_03": ("region", "region"), } cn_df = tables.get('common_neighbor_pairs', pd.DataFrame()) if cn_df.empty: return None if template.target_subtype and 'shared_neighbor_subtype' in cn_df.columns: filtered = cn_df[cn_df['shared_neighbor_subtype'] == template.target_subtype] if not filtered.empty: cn_df = filtered if template.template_id in _MULTI_ADJ_ANCHOR_SUBTYPES and 'anchor_subtype_1' in cn_df.columns: a1_st, a2_st = _MULTI_ADJ_ANCHOR_SUBTYPES[template.template_id] filtered = cn_df[ (cn_df['anchor_subtype_1'] == a1_st) & (cn_df['anchor_subtype_2'] == a2_st) ] if not filtered.empty: cn_df = filtered row = cn_df.sample(n=1).iloc[0] anchor1 = {'id': row['anchor_id_1'], 'name': row['anchor_name_1'], 'source': 'divisions_area'} anchor2 = {'id': row['anchor_id_2'], 'name': row['anchor_name_2'], 'source': 'divisions_area'} target_subtype = template.target_subtype or row.get('shared_neighbor_subtype', 'region') sql = template.sql_template.format( anchor_id_1=anchor1['id'], anchor_id_2=anchor2['id'], target_subtype=target_subtype, ) candidates1 = build_candidate_list( con, anchor1['id'], anchor1['name'], 'divisions_area', num_candidates=5, difficulty="medium" ) candidates2 = build_candidate_list( con, anchor2['id'], anchor2['name'], 'divisions_area', num_candidates=5, difficulty="medium" ) candidates = _merge_candidate_lists(candidates1, candidates2) question = random.choice(template.question_hints).format( anchor_1_name=anchor1['name'], anchor_2_name=anchor2['name'], target_subtype=target_subtype, ) elif template.family == "difference": if template.anchor_source == "mixed": # divisions_area anchor differenced against a natural_earth feature. # Use cross_source_relations so the pair is guaranteed to intersect # (ST_Difference on non-intersecting geometries is always equal to # the original geometry — a trivial and uninformative sample). cs_anchor = sample_cross_source_anchor( tables.get('cross_source_relations', pd.DataFrame()), natural_subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id), ) if not cs_anchor: return None anchor = { 'id': cs_anchor['division_id'], 'name': cs_anchor['division_name'], 'source': 'divisions_area', } clip_feature = { 'id': cs_anchor['natural_id'], 'name': cs_anchor['natural_name'], 'source': 'natural_earth', } sql = template.sql_template.format( anchor_id=anchor['id'], clip_feature_id=clip_feature['id'], ) question = random.choice(template.question_hints).format( anchor_name=anchor['name'], clip_feature_name=clip_feature['name'], ) # Build candidates for BOTH anchors — model must see both IDs # to correctly assign anchor_id vs clip_feature_id in the SQL. div_cands = build_candidate_list( con, anchor['id'], anchor['name'], 'divisions_area', num_candidates=5, difficulty="hard", ) ne_cands = build_candidate_list( con, clip_feature['id'], clip_feature['name'], 'natural_earth', num_candidates=5, difficulty="hard", ) candidates = _merge_candidate_lists(div_cands, ne_cands, max_total=10) else: # Two divisions_area anchors: use both ends of a containment # pair so the contained entity is guaranteed to intersect the # container. ST_Difference(container, contained) yields the # portion of the container outside the contained piece. pair = sample_containment_anchor(tables['containment_pairs']) if not pair: return None anchor1 = {'id': pair['container_id'], 'name': pair['container_name']} anchor2 = {'id': pair['contained_id'], 'name': pair['contained_name']} sql = template.sql_template.format( anchor_id_1=anchor1['id'], anchor_id_2=anchor2['id'], ) candidates1 = build_candidate_list( con, anchor1['id'], anchor1['name'], 'divisions_area', num_candidates=5, difficulty="medium" ) candidates2 = build_candidate_list( con, anchor2['id'], anchor2['name'], 'divisions_area', num_candidates=5, difficulty="medium" ) candidates = _merge_candidate_lists(candidates1, candidates2) question = random.choice(template.question_hints).format( anchor_1_name=anchor1['name'], anchor_2_name=anchor2['name'], ) elif template.family == "border_corridor": # Buffered border zone — needs two anchors that actually touch. pair = sample_adjacency_anchor(tables['adjacency_pairs']) if not pair: return None anchor1 = {'id': pair['anchor_id'], 'name': pair['anchor_name']} anchor2 = {'id': pair['target_id'], 'name': pair['target_name']} buffer_val = random.choice([5, 10, 25, 50]) sql = template.sql_template.format( anchor_id_1=anchor1['id'], anchor_id_2=anchor2['id'], buffer_km=buffer_val, ) candidates1 = build_candidate_list( con, anchor1['id'], anchor1['name'], 'divisions_area', num_candidates=5, difficulty="medium" ) candidates2 = build_candidate_list( con, anchor2['id'], anchor2['name'], 'divisions_area', num_candidates=5, difficulty="medium" ) candidates = _merge_candidate_lists(candidates1, candidates2) question = random.choice(template.question_hints).format( anchor_1_name=anchor1['name'], anchor_2_name=anchor2['name'], buffer_km=buffer_val, ) elif template.family == "window_function": anchor = sample_random_entity( con, tables['divisions_area_inventory'], 'divisions_area', subtypes={'country'}, ) if not anchor: return None country = anchor.get('country') or 'US' target_subtype = template.target_subtype or 'county' sql = template.sql_template.format( country=country, target_subtype=target_subtype, ) candidates = build_candidate_list( con, anchor['id'], anchor['name'], 'divisions_area', num_candidates=10, difficulty="hard" ) question = random.choice(template.question_hints).format( anchor_name=anchor['name'], target_subtype=target_subtype, ) elif template.family == "attribute_filter": anchor = sample_random_entity( con, tables['divisions_area_inventory'], 'divisions_area', subtypes={'country'}, ) if not anchor: return None country = anchor.get('country') or 'US' target_subtype = template.target_subtype or 'region' sql = template.sql_template.format( country=country, target_subtype=target_subtype, ) candidates = build_candidate_list( con, anchor['id'], anchor['name'], 'divisions_area', num_candidates=10, difficulty="medium" ) question = random.choice(template.question_hints).format( anchor_name=anchor['name'], target_subtype=target_subtype, country=country, ) else: # Skip unsupported families return None # Execute SQL to verify try: result = con.execute(_for_execution(sql)).fetchdf() if result.empty: return None except Exception as e: # Errors are tracked in worker return, no need to print return None # Collect every anchor ID that appears in the generated SQL so we can # mark them as the "selected" candidates in the training sample. _multi_anchor_families = {"set_operations", "multi_adjacency", "difference", "border_corridor", "buffer"} # Mixed partial_selection (partial_05) and mixed difference (diff_02) each # have two anchors from different sources — both must be marked selected. _is_mixed_two_anchor = ( template.anchor_source == "mixed" and template.num_anchors == 2 ) if _is_mixed_two_anchor: # partial_05 / diff_02: anchor (division) + clip_feature (natural_earth) mixed_ids = {anchor.get("id", ""), clip_feature.get("id", "")} selected_candidate_ids = [c.candidate_id for c in candidates if c.id in mixed_ids] elif template.family in _multi_anchor_families and template.num_anchors >= 2: anchor_ids: set = set() for var in ("anchor1", "anchor2", "anchor3"): obj = locals().get(var) if obj: anchor_ids.add(obj.get("id", "")) if "anchors" in locals(): for a in locals()["anchors"]: if a: anchor_ids.add(a.get("id", "")) selected_candidate_ids = [c.candidate_id for c in candidates if c.id in anchor_ids] else: anchor_id_to_find = ( anchor.get('anchor_id') or anchor.get('container_id') or anchor.get('natural_id') or anchor.get('id') ) selected_candidate_ids = [c.candidate_id for c in candidates if c.id == anchor_id_to_find] question, surface_variants = _diversify_question_surface(question, template.family) return TrainingSample( id=sample_id, question=question, candidates=candidates, target={ "selected_candidates": selected_candidate_ids, "sql": sql, }, metadata={ "task_family": template.family, "sql_difficulty": template.sql_difficulty, "grounding_difficulty": "medium", "template_id": template.template_id, "num_candidates": len(candidates), "anchor_source": template.anchor_source, "sql_verified": True, "surface_variants": surface_variants, } ) def generate_sample_batch_worker(args): """Worker function that processes a batch of work items with a single DuckDB connection. Initializes DuckDB, spatial extension, templates module, and relation tables ONCE per batch, then processes all items sequentially. """ from pathlib import Path work_items, intermediate_dir_str = args # Convert string back to Path intermediate_dir = Path(intermediate_dir_str) # Initialize DuckDB ONCE for the entire batch con = duckdb.connect() con.execute("SET enable_progress_bar=false") con.execute("INSTALL spatial") con.execute("LOAD spatial") # Load relation tables ONCE tables = load_relation_tables(intermediate_dir, quiet=True) # Process all items in batch results = [] for family, template_dict, sample_id, _ in work_items: # Reconstruct template from dict (sql_templates is already imported at module level) template = sql_templates.SQLTemplate(**template_dict) try: sample = generate_template_based_sample(con, template, tables, sample_id) if sample: results.append((sample, family, template.template_id, None)) else: results.append((None, family, template.template_id, "Empty result")) except Exception as e: results.append((None, family, template_dict.get('template_id', 'unknown'), str(e))) con.close() return results def generate_batch_core( work_items: List[tuple], intermediate_dir: str, ) -> List[Dict[str, Any]]: """Standalone batch worker usable from Modal or any remote context. Data paths are resolved via GAZET_DATA_DIR env var (set in Modal image). Args: work_items: List of (family, template_dict, sample_id, _) tuples intermediate_dir: Path to intermediate dir with relation parquets Returns: List of dicts with keys: sample (dict or None), family, template_id, error """ from pathlib import Path as _Path intermediate = _Path(intermediate_dir) con = duckdb.connect() con.execute("SET enable_progress_bar=false") con.execute("INSTALL spatial") con.execute("LOAD spatial") tables = load_relation_tables(intermediate, quiet=True) results = [] for family, template_dict, sample_id, _ in work_items: template = sql_templates.SQLTemplate(**template_dict) try: sample = generate_template_based_sample(con, template, tables, sample_id) if sample: results.append({ "sample": sample.model_dump(), "family": family, "template_id": template.template_id, "error": None, }) else: results.append({ "sample": None, "family": family, "template_id": template.template_id, "error": "Empty result", }) except Exception as e: results.append({ "sample": None, "family": family, "template_id": template_dict.get('template_id', 'unknown'), "error": str(e), }) con.close() return results def prepare_work_items( target_counts: Dict[str, int], retry_multiplier: int = 2, start_counter: int = 1, intermediate_dir_str: str = "", ) -> List[tuple]: """Prepare shuffled work items for sample generation. Returns list of (family, template_dict, sample_id, intermediate_dir_str) tuples. Reusable by both local main() and Modal orchestrator. """ work_items = [] sample_counter = start_counter for family, target_count in target_counts.items(): if target_count == 0: continue family_templates = [t for t in TEMPLATES if t.family == family] if not family_templates: print(f"No templates found for {family}, skipping...") continue # Distribute target evenly across templates so every template_id gets # a guaranteed share. Uniform random choice previously let rare # variants like partial_05 / diff_02 get under-represented or dropped # entirely when their mixed-source branch hit transient failures. n_tpl = len(family_templates) per_tpl = target_count // n_tpl remainder = target_count % n_tpl for i, template in enumerate(family_templates): count = per_tpl + (1 if i < remainder else 0) template_dict = { 'template_id': template.template_id, 'family': template.family, 'sql_difficulty': template.sql_difficulty, 'anchor_source': template.anchor_source, 'num_anchors': template.num_anchors, 'sql_template': template.sql_template, 'question_hints': template.question_hints, 'target_subtype': template.target_subtype, 'requires_buffer': template.requires_buffer, 'requires_aggregation': template.requires_aggregation } for _ in range(count * retry_multiplier): work_items.append(( family, template_dict, f"sample_{sample_counter:06d}", intermediate_dir_str, )) sample_counter += 1 random.shuffle(work_items) return work_items def main(): """Generate training samples.""" global TARGET_COUNTS, MAX_WORKERS, RETRY_MULTIPLIER, APPEND_MODE # Setup paths script_dir = Path(__file__).parent intermediate_dir = script_dir.parent / "intermediate" output_dir = script_dir.parent / "output" output_dir.mkdir(exist_ok=True, parents=True) # Load relation tables once to check availability print("Loading relation tables...") tables = load_relation_tables(intermediate_dir, quiet=False) # Use configured target counts or defaults if TARGET_COUNTS is None: target_counts = { 'direct_lookup': 100, 'adjacency': 150, 'multi_adjacency': 75, 'containment': 100, 'intersection': 100, 'buffer': 100, 'chained': 150, 'difference': 75, 'border_corridor': 75, 'set_operations': 150, 'partial_selection': 75, 'aggregation': 100, 'window_function': 75, 'attribute_filter': 75, } else: target_counts = TARGET_COUNTS # Load existing samples if in append mode existing_samples = [] existing_sample_ids = set() jsonl_file = output_dir / "dataset_raw.jsonl" if APPEND_MODE and jsonl_file.exists(): print(f"\nAppend mode: Loading existing samples from {jsonl_file}") with open(jsonl_file, 'r') as f: for line in f: if line.strip(): sample_data = json.loads(line) existing_samples.append(sample_data) existing_sample_ids.add(sample_data['id']) print(f" Found {len(existing_samples)} existing samples") # Determine starting sample counter max_existing_id = max([int(s['id'].split('_')[1]) for s in existing_samples if s['id'].startswith('sample_')], default=0) sample_counter = max_existing_id + 1 else: sample_counter = 1 # Prepare work items using shared helper work_items = prepare_work_items( target_counts=target_counts, retry_multiplier=RETRY_MULTIPLIER, start_counter=sample_counter, intermediate_dir_str=str(intermediate_dir), ) starting_sample_counter = sample_counter # Partition work items into batches (one per worker) num_workers = min(MAX_WORKERS, len(work_items)) if num_workers == 0: print("No work items to process") return batch_size = (len(work_items) + num_workers - 1) // num_workers batches = [] for i in range(0, len(work_items), batch_size): batch = work_items[i:i + batch_size] batches.append((batch, str(intermediate_dir))) # Generate samples in parallel (one batch per worker) active_families = len([f for f in target_counts.values() if f > 0]) print(f"\nGenerating {len(work_items)} samples across {active_families} families...") print(f" Split into {len(batches)} batches of ~{batch_size} items (1 DuckDB init per batch)") if APPEND_MODE and existing_samples: print(f"Appending: starting from sample_{starting_sample_counter:03d}") all_samples = [] family_progress = {f: {'success': 0, 'failed': 0} for f in target_counts.keys() if target_counts[f] > 0} with ProcessPoolExecutor(max_workers=num_workers) as executor: # Submit one batch per worker futures = {executor.submit(generate_sample_batch_worker, batch): i for i, batch in enumerate(batches)} # Collect results as batches complete batches_done = 0 for future in as_completed(futures): try: batch_results = future.result() for sample, family, template_id, error in batch_results: if sample: all_samples.append(sample) family_progress[family]['success'] += 1 else: family_progress[family]['failed'] += 1 except Exception as e: print(f"\n Batch failed: {e}") batches_done += 1 total_done = sum(p['success'] + p['failed'] for p in family_progress.values()) print(f"\r Progress: {total_done}/{len(work_items)} samples ({batches_done}/{len(batches)} batches) ", end='', flush=True) print() # New line after progress # Show distribution (keep all samples, no filtering) print("\nResults by family:") for family in sorted(family_progress.keys()): success = family_progress[family]['success'] failed = family_progress[family]['failed'] target = target_counts.get(family, 0) total = success + failed success_rate = (success / total * 100) if total > 0 else 0 print(f" {family:20s}: {success:3d} success / {failed:3d} failed ({success_rate:5.1f}% success rate, target: {target})") # Save combined JSONL (skip individual JSON files for speed at scale) print(f"\nSaving {len(all_samples)} new samples...") if APPEND_MODE and existing_samples: # Append to existing dataset print(f"Appending to existing dataset ({len(existing_samples)} existing samples)") with open(jsonl_file, 'a') as f: for sample in all_samples: f.write(json.dumps(sample.model_dump()) + '\n') total_samples = len(existing_samples) + len(all_samples) else: # Overwrite dataset with open(jsonl_file, 'w') as f: for sample in all_samples: f.write(json.dumps(sample.model_dump()) + '\n') total_samples = len(all_samples) print(f"\nGenerated {len(all_samples)} new samples") print(f"Total dataset size: {total_samples} samples") print(f" Dataset: {jsonl_file}") if __name__ == "__main__": main()