gazet / dataset /scripts /generate_samples.py
srmsoumya's picture
Randomize candidate dataset order
582d1ab
"""
Generate synthetic training samples for text-to-SQL task.
This script:
1. Loads relation tables and entity inventories
2. For each SQL template, samples valid anchors
3. Renders and executes SQL to verify it works
4. Builds candidate lists with controlled distractors
5. Generates natural-language questions using an LLM
6. Saves complete training samples
Output:
- output/samples/sample_*.json (individual samples)
- output/dataset_raw.jsonl (all samples)
"""
import json
import random
import re
import warnings
from pathlib import Path
from typing import List, Dict, Any, Optional
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import partial
import duckdb
import pandas as pd
from pydantic import BaseModel
# Suppress warnings
warnings.filterwarnings('ignore')
from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
# Fixed paths embedded in every training SQL string.
# The model learns these short, stable strings rather than machine-specific
# local paths. At inference, sql.py's _rewrite_data_paths substitutes them
# with the actual runtime paths from gazet.config.
_DIVISIONS_SQL_PATH = 'divisions_area'
_NATURAL_EARTH_SQL_PATH = 'natural_earth'


def _for_execution(sql: str) -> str:
    """Replace symbolic placeholder paths with actual local paths for verification.

    Training SQL refers to the parquet sources by the short placeholder names
    held in ``_DIVISIONS_SQL_PATH`` and ``_NATURAL_EARTH_SQL_PATH``.  Before
    such a statement can be executed locally, every
    ``read_parquet('<placeholder>')`` call is rewritten to point at the real
    file configured in ``gazet.config``.

    Args:
        sql: SQL text containing the symbolic ``read_parquet`` placeholders.

    Returns:
        The same SQL with both placeholders substituted.  Text that does not
        contain a placeholder is returned unchanged.
    """
    # Build the search strings from the module constants so the placeholder
    # names are defined in exactly one place.
    substitutions = (
        (f"read_parquet('{_DIVISIONS_SQL_PATH}')", f"read_parquet('{DIVISIONS_AREA_PATH}')"),
        (f"read_parquet('{_NATURAL_EARTH_SQL_PATH}')", f"read_parquet('{NATURAL_EARTH_PATH}')"),
    )
    for placeholder, actual in substitutions:
        sql = sql.replace(placeholder, actual)
    return sql
# Configurable parameters (can be overridden by CLI)
# Module-level run settings. Nothing in this part of the file reads them, so
# the per-constant notes below are inferred from the names only.
# NOTE(review): confirm each against main() / the CLI parser.
TARGET_COUNTS = None # Will be set in main() or by CLI
MAX_WORKERS = 8  # presumably the size of the worker pool — TODO confirm
RETRY_MULTIPLIER = 2  # presumably scales the number of attempts per requested sample — TODO confirm
APPEND_MODE = False  # presumably append to existing output instead of overwriting — TODO confirm
# Family-independent paraphrase rules applied to every question.
# Each entry is a (label, regex pattern, replacement options) triple.
# _diversify_question_surface matches the pattern case-insensitively and
# substitutes ONE randomly chosen option for the first match only; the label
# is recorded in the list of applied rewrites.
_GENERIC_SURFACE_RULES = [
    ("spelling_neighboring", r"\bneighbouring\b", ["neighboring"]),
    ("spelling_neighbors", r"\bneighbours\b", ["neighbors"]),
    ("expand_whats", r"\bwhat's\b", ["what is"]),
    ("show_me", r"\bshow me\b", ["show", "display"]),
    ("give_me", r"\bgive me\b", ["show", "list"]),
    ("pull_up", r"\bpull up\b", ["show", "display"]),
    ("find_to_show", r"\bfind\b", ["show", "locate"]),
    ("kilometers_variant", r"\bkilometers\b", ["km"]),
    ("metres_variant", r"\bmetres\b", ["meters"]),
    ("recognised_variant", r"\brecognised\b", ["recognized"]),
]
# Paraphrase rules that only apply to questions of one template family.
# Keys are template family names; values are lists of
# (label, regex pattern, replacement options) triples in the same shape as
# _GENERIC_SURFACE_RULES. Replacement options are passed to re.sub, so
# "\1" / "\2" refer to groups captured by the pattern.
# Families without an entry only receive the generic rules.
_FAMILY_SURFACE_RULES = {
    "adjacency": [
        ("which_border_to_next_to", r"\bwhich (.+?) border (.+)\?", [r"which \1 are next to \2?", r"which \1 are adjacent to \2?"]),
        ("bordering_to_next_to", r"\bbordering (.+)", [r"next to \1", r"adjacent to \1"]),
        ("touching_to_next_to", r"\btouching (.+)", [r"next to \1"]),
        ("share_border_to_adjacent", r"share a border with", ["are adjacent to", "are next to"]),
        ("adjacent_to_next_to", r"adjacent to", ["next to"]),
    ],
    "multi_adjacency": [
        ("which_border_both_to_next_to", r"\bwhich (.+?) border both (.+)\?", [r"which \1 are next to both \2?", r"which \1 are adjacent to both \2?"]),
        ("touch_both_to_next_to", r"touch both", ["are next to both"]),
        ("adjacent_both_to_next_to", r"adjacent to both", ["next to both"]),
    ],
    "containment": [
        ("within_to_inside", r"\bwithin\b", ["inside", "in"]),
        ("inside_to_in", r"\binside\b", ["in"]),
        ("belonging_to_in", r"belonging to", ["in"]),
        ("contain_to_have", r"\bcontain\b", ["have"]),
    ],
    "intersection": [
        ("which_intersect_to_overlap", r"\bwhich (.+?) intersect (.+)\?", [r"which \1 overlap \2?"]),
        ("overlap_with_to_intersect", r"overlap with", ["intersect"]),
        ("crossing_to_overlapping", r"crossing into", ["overlapping"]),
        ("partly_in_to_overlap", r"partly in", ["overlapping"]),
    ],
    "buffer": [
        # The captured distance ("10 km", "500 m", ...) is carried into both options.
        ("within_distance_to_from", r"within ([0-9]+\s*(?:km|m)) of", [r"up to \1 from", r"at a distance of \1 from"]),
        ("buffer_to_radius", r"\bbuffer\b", ["radius", "zone"]),
        ("close_to_near", r"close to", ["near"]),
        ("around_to_near", r"what is around", ["what is near"]),
    ],
    "chained": [
        ("coastal_to_seaside", r"\bcoastal\b", ["seaside", "maritime"]),
        ("landlocked_to_inland", r"\blandlocked\b", ["inland"]),
        ("sea_access_to_coast", r"sea access", ["a coastline"]),
    ],
    "difference": [
        ("part_to_portion", r"\bpart of\b", ["portion of", "section of"]),
        ("outside_to_excluding", r"\boutside\b", ["excluding"]),
    ],
    "border_corridor": [
        ("zone_to_buffer", r"\bzone\b", ["buffer", "corridor"]),
        # NOTE(review): the first option ("along the") does not reuse \1, so the
        # distance captured from the question is dropped by that rewrite —
        # confirm this is intended for questions whose SQL depends on the distance.
        ("within_distance_to_along", r"within ([0-9]+ km) of the", [r"along the", r"up to \1 from the"]),
    ],
    "set_operations": [
        ("combined_to_merged", r"combined", ["merged"]),
        ("union_of_to_merged_area", r"\bunion of\b", ["merged area of", "combined area of"]),
        ("merge_to_combine", r"\bmerge\b", ["combine"]),
        ("together_to_combined", r"\btogether\b", ["combined"]),
    ],
    "partial_selection": [
        ("part_to_portion", r"\bpart of\b", ["portion of", "section of"]),
        ("half_to_side", r"\bhalf\b", ["side"]),
    ],
    # "aggregation" and "window_function" currently carry the same two rules.
    "aggregation": [
        ("largest_to_biggest", r"\blargest\b", ["biggest"]),
        ("smallest_to_tiniest", r"\bsmallest\b", ["tiniest"]),
    ],
    "window_function": [
        ("largest_to_biggest", r"\blargest\b", ["biggest"]),
        ("smallest_to_tiniest", r"\bsmallest\b", ["tiniest"]),
    ],
    "attribute_filter": [
        ("official_to_recognized", r"\bofficial\b", ["recognized", "recognized territorial"]),
        ("land_based_to_on_land", r"land-based", ["on-land", "on land"]),
        ("sovereign_to_official", r"\bsovereign\b", ["official"]),
    ],
    "direct_lookup": [
        ("where_is_to_show", r"\bwhere is\b", ["show", "locate"]),
        ("map_of_to_outline", r"\bmap of\b", ["outline of"]),
    ],
    "disambiguation": [
        ("show_me_to_find", r"\bshow me\b", ["find", "show"]),
        ("pull_up_to_find", r"\bpull up\b", ["find", "show"]),
    ],
}
def _diversify_question_surface(question: str, family: str) -> tuple[str, List[str]]:
    """Lightly paraphrase *question* using generic and family-specific rules.

    The rewrites are small, local word swaps meant to stop a model from
    memorising template wording while keeping the question aligned with the
    SQL it was generated for.

    Args:
        question: The rendered question text.  Falsy values are returned as-is.
        family: Template family name used to look up extra rules in
            ``_FAMILY_SURFACE_RULES``; unknown families get generic rules only.

    Returns:
        ``(text, applied)`` where ``text`` is the possibly rewritten question
        and ``applied`` lists one ``"<family>:<rule label>"`` entry per rewrite
        that actually changed the text.

    Note:
        Replacement text is inserted verbatim and matching is case-insensitive,
        so the capitalisation of the matched span is not preserved.
    """
    # Roughly a third of the (non-empty) questions are deliberately left alone.
    if not question or random.random() < 0.35:
        return question, []

    rule_pool = _GENERIC_SURFACE_RULES + _FAMILY_SURFACE_RULES.get(family, [])
    attempts_left = 2 if random.random() < 0.5 else 1

    text = question
    used_labels: List[str] = []
    while attempts_left > 0:
        attempts_left -= 1
        # Every (rule, replacement option) pair whose pattern currently matches
        # is an equally likely choice.
        options = [
            (rule_label, regex, option)
            for rule_label, regex, rule_options in rule_pool
            if re.search(regex, text, flags=re.IGNORECASE)
            for option in rule_options
        ]
        if not options:
            break
        rule_label, regex, option = random.choice(options)
        substituted = re.sub(regex, option, text, count=1, flags=re.IGNORECASE)
        if substituted != text:
            # Collapse any whitespace runs the substitution may have produced.
            text = re.sub(r"\s+", " ", substituted).strip()
            used_labels.append(f"{family}:{rule_label}")
    return text, used_labels
# Import templates from same directory
from . import sql_templates
# Short module-level aliases for names defined in sql_templates.
# SQLTemplate is used below as the type of the template argument;
# get_templates_by_family is presumably a lookup helper — see sql_templates.
TEMPLATES = sql_templates.TEMPLATES
SQLTemplate = sql_templates.SQLTemplate
get_templates_by_family = sql_templates.get_templates_by_family
# Default set of natural_earth subtypes allowed as the anchor of a
# direct-lookup template; used when the template id has no entry in
# _NE_TEMPLATE_SUBTYPES.
_NE_NAMED_LOOKUP_SUBTYPES = {
    'sea', 'ocean', 'lake', 'river', 'basin', 'gulf', 'bay',
    'island group', 'peninsula', 'strait', 'range/mtn', 'depression',
}
# Per-template whitelist of natural_earth subtypes, keyed by template id.
# The sample generator passes the matching set as a subtype filter when it
# picks a natural_earth anchor, so the anchor's feature type fits the
# template. Template ids missing from this dict yield None from .get(),
# which the sampling helpers treat as "no subtype filter".
_NE_TEMPLATE_SUBTYPES = {
    'lookup_02': {'sea', 'ocean', 'lake', 'river', 'basin', 'gulf', 'bay', 'island group', 'peninsula', 'strait', 'range/mtn', 'depression'},
    'adj_03': {'sea', 'ocean'},
    'adj_09': {'river', 'lake', 'basin'},
    'adj_10': {'range/mtn', 'peninsula', 'depression'},
    'adj_06': {'sea', 'ocean', 'lake', 'river', 'basin', 'gulf', 'bay', 'strait', 'range/mtn', 'peninsula', 'depression', 'plateau', 'plain', 'lowland', 'valley', 'gorge'},
    'adj_07': {'sea', 'ocean', 'lake', 'river', 'basin', 'gulf', 'bay', 'strait', 'range/mtn', 'peninsula', 'depression', 'plateau', 'plain', 'lowland', 'valley', 'gorge'},
    'adj_08': {'sea', 'ocean', 'lake', 'river', 'basin', 'gulf', 'bay', 'strait', 'range/mtn', 'peninsula', 'depression', 'plateau', 'plain', 'lowland', 'valley', 'gorge'},
    'contain_04': {'sea', 'ocean', 'gulf', 'bay', 'basin', 'island group', 'peninsula', 'range/mtn', 'depression'},
    'contain_05': {'sea', 'ocean', 'gulf', 'bay', 'strait'},
    'intersect_03': {'river', 'lake', 'basin', 'gulf', 'bay', 'strait', 'range/mtn', 'peninsula', 'depression'},
    'intersect_04': {'river', 'lake', 'basin', 'gulf', 'bay', 'strait', 'range/mtn', 'peninsula', 'depression'},
    'intersect_06': {'river', 'lake', 'basin', 'gulf', 'bay', 'strait', 'range/mtn', 'peninsula', 'depression'},
    'buffer_02': {'sea', 'ocean', 'lake', 'river', 'basin', 'gulf', 'bay', 'island group', 'peninsula', 'strait', 'range/mtn', 'depression'},
    'buffer_11': {'sea', 'ocean', 'lake', 'river', 'basin', 'gulf', 'bay', 'island group', 'peninsula', 'strait', 'range/mtn', 'depression'},
    'chained_03': {'island group', 'peninsula', 'range/mtn', 'depression'},
    'chained_04': {'river', 'lake', 'basin'},
    'chained_05': {'range/mtn', 'depression'},
    'chained_08': {'river', 'lake', 'basin'},
    'chained_09': {'range/mtn', 'depression'},
    'partial_05': {'sea', 'ocean', 'lake', 'river', 'basin', 'gulf', 'bay', 'island group', 'peninsula', 'strait', 'range/mtn', 'depression'},
    'diff_02': {'sea', 'ocean', 'lake', 'river', 'basin', 'gulf', 'bay', 'island group', 'peninsula', 'strait', 'range/mtn', 'depression'},
}
class Candidate(BaseModel):
    """Candidate entity for grounding.

    One entry of the candidate list attached to a training sample: either the
    true anchor entity or a distractor with a similar name.
    """
    # Position label inside one candidate list; the builders in this file
    # renumber entries as "c1", "c2", ... after merging/deduplicating.
    candidate_id: str
    # Dataset the entity was read from: "divisions_area" or "natural_earth".
    source: str
    # Feature id of the entity in its source dataset.
    id: str
    # Primary name of the entity (names."primary" in the source parquet).
    name: str
    # Optional descriptive columns; None when the source value is missing.
    subtype: Optional[str] = None
    country: Optional[str] = None
    region: Optional[str] = None
    admin_level: Optional[int] = None
    # Name-similarity score. The builders in this file use 1.0 for the true
    # anchor and the Jaro-Winkler similarity to the anchor name for distractors.
    similarity: float = 0.0
class TrainingSample(BaseModel):
    """Complete training sample.

    Bundles the natural-language question, the candidate entities shown to the
    model, and the expected output for one generated example.
    """
    # Identifier of the sample (supplied by the caller of the generator).
    id: str
    # Natural-language question text.
    question: str
    # Candidate entities (true anchor(s) plus distractors) for grounding.
    candidates: List[Candidate]
    # Expected model output.
    # NOTE(review): the key schema is not visible in this part of the file —
    # presumably holds the rendered SQL and the selected candidates; confirm.
    target: Dict[str, Any]
    # Free-form bookkeeping about how the sample was produced.
    # NOTE(review): exact keys not visible here — confirm against the writer.
    metadata: Dict[str, Any]
def load_relation_tables(intermediate_dir: Path, quiet: bool = False) -> Dict[str, pd.DataFrame]:
    """Read every ``*.parquet`` file in *intermediate_dir* into a DataFrame.

    Args:
        intermediate_dir: Directory holding the precomputed relation tables.
        quiet: When true, suppress the per-table row-count line on stdout.

    Returns:
        Mapping from file stem (file name without ``.parquet``) to the loaded
        DataFrame.  Empty when the directory contains no parquet files.
    """
    loaded: Dict[str, pd.DataFrame] = {}
    for parquet_file in intermediate_dir.glob("*.parquet"):
        table_name = parquet_file.stem
        frame = pd.read_parquet(parquet_file)
        loaded[table_name] = frame
        if quiet:
            continue
        print(f" {table_name}: {len(frame)} rows")
    return loaded
def sample_adjacency_anchor(
    adjacency_df: pd.DataFrame,
    target_subtype: Optional[str] = None,
    anchor_subtypes: Optional[List[str]] = None,
) -> Optional[Dict[str, Any]]:
    """Draw one random (anchor, neighbour) row from the adjacency table.

    Args:
        adjacency_df: Table of adjacency pairs.
        target_subtype: If given, keep only rows whose ``target_subtype``
            equals this value.  This filter is strict: no match means no
            sample.
        anchor_subtypes: If given, prefer rows whose ``anchor_subtype`` is in
            this list.  This filter is soft: when it would leave nothing, the
            rows that survived the target filter are used unchanged.

    Returns:
        A dict with ``anchor_id``, ``anchor_name``, ``anchor_subtype`` plus
        ``anchor_country`` and the ``target_*`` fields (``None`` for columns
        the table does not have), or ``None`` when the table is empty or the
        target filter matches nothing.
    """
    if adjacency_df.empty:
        return None

    pool = adjacency_df
    if target_subtype is not None:
        pool = pool[pool['target_subtype'] == target_subtype]
        if pool.empty:
            return None

    if anchor_subtypes is not None:
        narrowed = pool[pool['anchor_subtype'].isin(anchor_subtypes)]
        # Soft filter: fall back to the wider pool rather than giving up.
        pool = pool if narrowed.empty else narrowed

    picked = pool.sample(n=1).iloc[0]

    # Anchor columns are mandatory; the remaining ones may be absent in some tables.
    result: Dict[str, Any] = {
        column: picked[column]
        for column in ('anchor_id', 'anchor_name', 'anchor_subtype')
    }
    for column in ('anchor_country', 'target_id', 'target_name', 'target_subtype'):
        result[column] = picked.get(column)
    return result
def sample_intersection_anchor(intersection_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
    """Draw one random row from the intersection-pairs table.

    Args:
        intersection_df: Table of intersecting feature pairs.

    Returns:
        A dict with the ``anchor_*`` fields and the ``target_*`` fields
        (``None`` for target columns the table does not have), or ``None``
        when the table is empty.
    """
    if intersection_df.empty:
        return None

    picked = intersection_df.sample(n=1).iloc[0]
    pair: Dict[str, Any] = {
        column: picked[column]
        for column in ('anchor_id', 'anchor_name', 'anchor_subtype')
    }
    # Target columns are read leniently, unlike the anchor columns above.
    for column in ('target_id', 'target_name', 'target_subtype'):
        pair[column] = picked.get(column)
    return pair
def sample_containment_anchor(containment_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
    """Draw one random (container, contained) row from the containment table.

    Both sides of the pair are returned so that callers needing the contained
    entity as well (for example difference templates that clip the container
    by the contained feature) do not have to sample again.

    Args:
        containment_df: Table of containment pairs.

    Returns:
        A dict with ``container_id/name/subtype`` and
        ``contained_id/name/subtype``, or ``None`` when the table is empty.
    """
    if containment_df.empty:
        return None

    picked = containment_df.sample(n=1).iloc[0]
    columns = (
        'container_id', 'container_name', 'container_subtype',
        'contained_id', 'contained_name', 'contained_subtype',
    )
    return {column: picked[column] for column in columns}
def sample_disambiguation_anchor(
    containment_df: pd.DataFrame,
    contained_subtypes: List[str],
    container_subtypes: List[str],
) -> Optional[Dict[str, Any]]:
    """Draw a (contained, container) pair restricted to the given subtypes.

    Serves disambiguation templates in the style of "Puri, Odisha": the
    contained entity is the query target and the container supplies the
    context that tells same-named places apart.

    Args:
        containment_df: Table of containment pairs.
        contained_subtypes: Allowed values of ``contained_subtype``.
        container_subtypes: Allowed values of ``container_subtype``.

    Returns:
        A dict with ``contained_id/name/subtype`` followed by
        ``container_id/name/subtype``, or ``None`` when the table is empty or
        no row satisfies both subtype filters.
    """
    if containment_df.empty:
        return None

    contained_ok = containment_df['contained_subtype'].isin(contained_subtypes)
    container_ok = containment_df['container_subtype'].isin(container_subtypes)
    matching = containment_df[contained_ok & container_ok]
    if matching.empty:
        return None

    picked = matching.sample(n=1).iloc[0]
    columns = (
        'contained_id', 'contained_name', 'contained_subtype',
        'container_id', 'container_name', 'container_subtype',
    )
    return {column: picked[column] for column in columns}
def sample_cross_source_anchor(
    cross_source_df: pd.DataFrame,
    natural_subtypes: Optional[set[str]] = None,
    relation_types: Optional[set[str]] = None,
) -> Optional[Dict[str, Any]]:
    """Draw one random division/natural-feature relation row.

    Args:
        cross_source_df: Table relating divisions_area features to
            natural_earth features.
        natural_subtypes: If given, keep only rows whose ``natural_subtype``
            is in this collection.
        relation_types: If given, keep only rows whose ``relation_type`` is
            in this collection.

    Returns:
        A dict with ``division_id/name/subtype``, ``natural_id/name/subtype``
        and ``relation_type``, or ``None`` when the table is empty or the
        filters leave no rows.
    """
    if cross_source_df.empty:
        return None

    pool = cross_source_df
    column_filters = (
        ('natural_subtype', natural_subtypes),
        ('relation_type', relation_types),
    )
    for column, allowed in column_filters:
        if allowed is not None:
            pool = pool[pool[column].isin(allowed)]
    if pool.empty:
        return None

    picked = pool.sample(n=1).iloc[0]
    columns = (
        'division_id', 'division_name', 'division_subtype',
        'natural_id', 'natural_name', 'natural_subtype',
        'relation_type',
    )
    return {column: picked[column] for column in columns}
def _merge_candidate_lists(
    *lists: List[Candidate],
    max_total: int = 10,
) -> List[Candidate]:
    """Interleave several candidate lists into one deduplicated list.

    Candidates are taken round-robin (first of every list, then second of
    every list, ...) so each anchor is represented before any anchor gets a
    second candidate — the same grouped-then-interleaved order inference
    produces.  Entries whose ``id`` was already taken are skipped, and
    collection stops once ``max_total`` entries have been gathered.

    Args:
        *lists: Candidate lists, one per anchor.
        max_total: Size at which merging stops (checked after each addition).

    Returns:
        The merged candidates.  Their ``candidate_id`` attributes are
        overwritten in place with ``"c1"``, ``"c2"``, ... in merged order.
    """
    from itertools import zip_longest

    # Round-robin stream over all lists; shorter lists are padded with None.
    round_robin = (cand for tier in zip_longest(*lists) for cand in tier)

    chosen: Dict[str, Candidate] = {}
    for cand in round_robin:
        if cand is None or cand.id in chosen:
            continue
        chosen[cand.id] = cand
        if len(chosen) >= max_total:
            break

    merged = list(chosen.values())
    for position, cand in enumerate(merged, start=1):
        cand.candidate_id = f"c{position}"
    return merged
def _dedupe_country_candidates(
    candidates: List[Candidate],
    max_total: Optional[int] = None,
) -> List[Candidate]:
    """Collapse repeated country rows that share a country code.

    Templates whose SQL filters with ``country IN (...)`` rather than by
    candidate id are weakened when the source data holds several
    country-level rows for one ISO code; only the first such row is kept.
    Every other candidate is deduplicated by its ``id``.

    Args:
        candidates: Candidates in priority order.
        max_total: Optional size at which collection stops (checked after
            each kept candidate).

    Returns:
        The surviving candidates in their original order, with
        ``candidate_id`` overwritten in place as ``"c1"``, ``"c2"``, ...
    """
    def _dedupe_key(cand: Candidate) -> tuple[str, str]:
        # Countries with a known code collapse on that code; the rest on id.
        if cand.subtype == "country" and cand.country:
            return ("country", cand.country)
        return ("id", cand.id)

    kept: List[Candidate] = []
    used_keys: set[tuple[str, str]] = set()
    for cand in candidates:
        key = _dedupe_key(cand)
        if key not in used_keys:
            used_keys.add(key)
            kept.append(cand)
            if max_total is not None and len(kept) >= max_total:
                break

    for position, cand in enumerate(kept, start=1):
        cand.candidate_id = f"c{position}"
    return kept
def build_candidate_list(
    con: duckdb.DuckDBPyConnection,
    anchor_id: str,
    anchor_name: str,
    anchor_source: str,
    num_candidates: int = 10,
    difficulty: str = "medium"
) -> List[Candidate]:
    """Assemble the candidate list for one anchor: true entity plus distractors.

    Args:
        con: Open DuckDB connection used to read the parquet sources.
        anchor_id: Feature id of the true anchor.
        anchor_name: Anchor name; distractors are ranked by similarity to it.
        anchor_source: ``"divisions_area"`` selects the divisions parquet; any
            other value is looked up in the natural_earth parquet.
        num_candidates: Size at which the list is cut off.
        difficulty: Forwarded unchanged to ``build_distractors``.

    Returns:
        Candidates with the true anchor first (similarity 1.0), followed by
        distractors, deduplicated by id and labelled ``"c1"``, ``"c2"``, ...

    Raises:
        IndexError: If no row with ``anchor_id`` exists in the selected source
            (the lookup takes the first row of the query result).
    """
    def _clean(record, field, fallback=None):
        # pandas missing markers (NaN / NA / None) become plain None.
        value = record.get(field, fallback)
        if pd.isna(value):
            return None
        return value

    # Look up the true anchor. The SQL text is kept verbatim; the
    # natural_earth lookup selects fewer columns than the divisions one.
    if anchor_source == "divisions_area":
        parquet_path = DIVISIONS_AREA_PATH
        lookup_sql = """
SELECT
id,
names."primary" AS name,
subtype,
country,
region,
admin_level
FROM read_parquet(?)
WHERE id = ?
"""
    else:
        parquet_path = NATURAL_EARTH_PATH
        lookup_sql = """
SELECT
id,
names."primary" AS name,
subtype
FROM read_parquet(?)
WHERE id = ?
"""
    record = con.execute(lookup_sql, [parquet_path, anchor_id]).fetchdf().iloc[0]

    optional_fields = ('subtype', 'country', 'region', 'admin_level')
    true_candidate = Candidate(
        candidate_id="c1",
        source=anchor_source,
        id=anchor_id,
        name=_clean(record, 'name'),
        similarity=1.0,
        **{field: _clean(record, field) for field in optional_fields},
    )

    distractors = build_distractors(
        con,
        anchor_name,
        anchor_source,
        anchor_id,
        num_candidates - 1,
        difficulty,
    )

    # Some parquet sources repeat rows for one feature id; keep the first
    # occurrence of each id so duplicates do not leak into the dataset.
    by_id: Dict[str, Candidate] = {}
    for cand in [true_candidate, *distractors]:
        if cand.id in by_id:
            continue
        by_id[cand.id] = cand
        if len(by_id) >= num_candidates:
            break

    candidates = list(by_id.values())
    for position, cand in enumerate(candidates, start=1):
        cand.candidate_id = f"c{position}"
    return candidates
def build_distractors(
    con: duckdb.DuckDBPyConnection,
    anchor_name: str,
    anchor_source: str,
    exclude_id: str,
    num_distractors: int,
    difficulty: str,
    cross_source_ratio: float = 0.5,
) -> List[Candidate]:
    """Build distractor candidates using fuzzy name search.

    Distractors are the entities whose primary name is most similar
    (Jaro-Winkler, case-insensitive) to ``anchor_name``.  They are drawn from
    both parquet sources so the model sees mixed ``source`` values in training
    examples — matching the inference behaviour where search.py queries
    divisions_area AND natural_earth.

    Args:
        con: Open DuckDB connection used to read the parquet sources.
        anchor_name: Name the distractors should resemble.
        anchor_source: ``"divisions_area"`` or anything else (treated as
            natural_earth); decides which source counts as "same" vs "cross".
        exclude_id: Id of the true anchor, excluded from the same-source query.
        num_distractors: Total number of distractors wanted.  Zero or a
            negative value yields an empty list without touching the database.
        difficulty: Accepted for interface compatibility; not used here.
        cross_source_ratio: Fraction of distractors drawn from the *other*
            source.  Defaults to 0.5 (50/50 split).  At least one distractor
            always comes from the other source, and never more than
            ``num_distractors``.

    Returns:
        Same-source distractors followed by cross-source distractors, each
        ordered by descending similarity, with ``candidate_id="temp"`` (the
        caller renumbers them).
    """
    # Nothing requested: the split below would otherwise force cross_n to 1
    # and same_n to -1, which ends up as a negative LIMIT in the query.
    if num_distractors <= 0:
        return []

    def safe_get(row, key, default=None):
        # pandas missing markers (NaN / NA / None) become plain None.
        val = row.get(key, default)
        return None if pd.isna(val) else val

    def _query_source(path: str, src_name: str, n: int, excl_id: str) -> List[Candidate]:
        # A zero-sized request can only return nothing, so skip the query.
        if n <= 0:
            return []
        # One row per feature id (rn = 1), best name matches first.
        # SQL text is kept verbatim.
        query = """
WITH ranked AS (
SELECT
id,
names."primary" AS name,
subtype,
country,
region,
admin_level,
jaro_winkler_similarity(lower(names."primary"), lower(?)) AS similarity,
ROW_NUMBER() OVER (
PARTITION BY id
ORDER BY jaro_winkler_similarity(lower(names."primary"), lower(?)) DESC
) AS rn
FROM read_parquet(?)
WHERE id != ?
AND names."primary" IS NOT NULL
AND trim(names."primary") != ''
AND geometry IS NOT NULL
)
SELECT
id,
name,
subtype,
country,
region,
admin_level,
similarity
FROM ranked
WHERE rn = 1
ORDER BY similarity DESC
LIMIT ?
"""
        df = con.execute(query, [anchor_name, anchor_name, path, excl_id, n]).fetchdf()
        return [
            Candidate(
                candidate_id="temp",
                source=src_name,
                id=row["id"],
                name=safe_get(row, "name"),
                subtype=safe_get(row, "subtype"),
                country=safe_get(row, "country"),
                region=safe_get(row, "region"),
                admin_level=safe_get(row, "admin_level"),
                similarity=float(row["similarity"]),
            )
            for _, row in df.iterrows()
        ]

    # At least one cross-source distractor, but never more than requested, so
    # same_n cannot go negative even for ratios above 1.
    cross_n = min(num_distractors, max(1, round(num_distractors * cross_source_ratio)))
    same_n = num_distractors - cross_n

    if anchor_source == "divisions_area":
        same_path, same_src = DIVISIONS_AREA_PATH, "divisions_area"
        cross_path, cross_src = NATURAL_EARTH_PATH, "natural_earth"
    else:
        same_path, same_src = NATURAL_EARTH_PATH, "natural_earth"
        cross_path, cross_src = DIVISIONS_AREA_PATH, "divisions_area"

    same = _query_source(same_path, same_src, same_n, exclude_id)
    # The anchor lives in the other dataset, so the cross-source query gets an
    # empty exclusion id.
    cross = _query_source(cross_path, cross_src, cross_n, "")
    return same + cross
def sample_random_entity(
    con: duckdb.DuckDBPyConnection,
    inventory_df: pd.DataFrame,
    source: str,
    subtypes: Optional[set[str]] = None,
    countries: Optional[set[str]] = None,
) -> Optional[Dict[str, Any]]:
    """Pick one random entity from an inventory table.

    Args:
        con: DuckDB connection; accepted for signature compatibility and not
            used by this function.
        inventory_df: Entity inventory with at least ``id`` and ``name``.
        source: Source label copied into the result unchanged.
        subtypes: If given, keep only rows whose ``subtype`` is in this set.
        countries: If given and the inventory has a ``country`` column, keep
            only rows whose ``country`` is in this set.

    Returns:
        A dict with ``id``, ``name``, ``subtype``, ``country`` (``None`` when
        the column is absent) and ``source``, or ``None`` when the inventory
        is empty or the filters leave no rows.
    """
    if inventory_df.empty:
        return None

    pool = inventory_df
    if subtypes is not None:
        pool = pool[pool['subtype'].isin(subtypes)]
    # The country filter is skipped for inventories without a country column.
    if countries is not None and 'country' in pool.columns:
        pool = pool[pool['country'].isin(countries)]
    if pool.empty:
        return None

    chosen = pool.sample(n=1).iloc[0]
    entity: Dict[str, Any] = {'id': chosen['id'], 'name': chosen['name']}
    entity['subtype'] = chosen.get('subtype')
    entity['country'] = chosen.get('country')
    entity['source'] = source
    return entity
def generate_template_based_sample(
con: duckdb.DuckDBPyConnection,
template: SQLTemplate,
tables: Dict[str, pd.DataFrame],
sample_id: str
) -> Optional[TrainingSample]:
"""Generate a sample based on a SQL template."""
# Sample anchor based on template requirements
if template.family == "direct_lookup":
# Just pick a random entity
if template.anchor_source == "divisions_area":
anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
else:
anchor = sample_random_entity(
con,
tables['natural_earth_inventory'],
'natural_earth',
subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id, _NE_NAMED_LOOKUP_SUBTYPES),
)
if not anchor:
return None
# Render SQL
sql = template.sql_template.format(
anchor_id=anchor['id']
)
# Build candidates
candidates = build_candidate_list(
con, anchor['id'], anchor['name'], anchor['source'],
num_candidates=10, difficulty="easy"
)
# Question
question = random.choice(template.question_hints).format(anchor_name=anchor['name'])
elif template.family == "disambiguation":
# "Puri, Odisha" style: pick a (contained, container) pair whose
# subtypes match the template, build candidates that include the
# container + same-name distractors so the model must read the CSV
# to pick the right entry.
_disambig_subtypes = {
"disambiguate_01": (["county"], ["region", "country"]),
"disambiguate_02": (["county"], ["country"]),
"disambiguate_03": (["region"], ["country"]),
}
contained_sts, container_sts = _disambig_subtypes.get(
template.template_id, (["county"], ["country"])
)
pair = sample_disambiguation_anchor(
tables["containment_pairs"], contained_sts, container_sts
)
if not pair:
return None
candidates = build_candidate_list(
con, pair["contained_id"], pair["contained_name"], "divisions_area",
num_candidates=10, difficulty="hard"
)
# Ensure the container is among the candidates so the model can
# ground the disambiguation context (e.g. "Odisha").
if not any(c.id == pair["container_id"] for c in candidates):
container_rows = con.execute(
'SELECT id, names."primary" AS name, subtype, country, region, admin_level '
'FROM read_parquet(?) WHERE id = ? LIMIT 1',
[DIVISIONS_AREA_PATH, pair["container_id"]]
).fetchdf()
if container_rows.empty:
return None
crow = container_rows.iloc[0]
def _nn(v):
return None if pd.isna(v) else v
container_cand = Candidate(
candidate_id="temp",
source="divisions_area",
id=pair["container_id"],
name=_nn(crow["name"]),
subtype=_nn(crow["subtype"]),
country=_nn(crow["country"]),
region=_nn(crow["region"]),
admin_level=_nn(crow["admin_level"]),
similarity=0.95,
)
# Insert the container right after the true target and drop the
# last filler distractor so the total stays at 10.
candidates = [candidates[0], container_cand] + candidates[1:-1]
for i, c in enumerate(candidates, 1):
c.candidate_id = f"c{i}"
sql = template.sql_template.format(anchor_id=pair["contained_id"])
question = random.choice(template.question_hints).format(
anchor_name=pair["contained_name"],
container_name=pair["container_name"],
)
# Only the contained entity is the query target — the container is
# disambiguation context and stays in candidates but NOT in
# selected_candidates. The model learns to use the container row of
# the CSV (via country/region columns) to pick the right same-name
# county or region.
anchor = {"id": pair["contained_id"], "name": pair["contained_name"]}
elif template.family == "adjacency":
# adj_03/09/10/11/12: division anchor -> natural_earth targets.
# adj_06/07/08: natural_earth anchor -> admin targets.
# Use cross_source_relations so anchors are guaranteed to intersect.
_NE_TARGET_ADJ_SUBTYPES = {
"adj_03": ("ocean", "sea"),
"adj_09": ("river", "lake", "basin"),
"adj_10": ("range/mtn",),
"adj_11": ("plateau",),
"adj_12": ("plain", "lowland", "basin", "valley", "depression", "gorge"),
}
if template.template_id in _NE_TARGET_ADJ_SUBTYPES:
cs_df = tables.get('cross_source_relations', pd.DataFrame())
if cs_df.empty:
return None
ne_types = _NE_TARGET_ADJ_SUBTYPES[template.template_id]
filtered = cs_df[cs_df['natural_subtype'].isin(ne_types)]
if filtered.empty:
return None
row = filtered.sample(n=1).iloc[0]
anchor = {
'anchor_id': row['division_id'],
'anchor_name': row['division_name'],
'anchor_subtype': row['division_subtype'],
'target_subtype': row['natural_subtype'],
'anchor_source': 'divisions_area',
}
elif template.anchor_source == "natural_earth":
cs_anchor = sample_cross_source_anchor(
tables.get('cross_source_relations', pd.DataFrame()),
natural_subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id),
)
if not cs_anchor:
return None
anchor = {
'anchor_id': cs_anchor['natural_id'],
'anchor_name': cs_anchor['natural_name'],
'target_subtype': template.target_subtype,
'anchor_source': 'natural_earth',
}
else:
# divisions_area self-join adjacency.
_ADJ_ANCHOR_SUBTYPES = {
"adj_02": ["country", "region"],
"adj_04": ["region"],
"adj_05": ["country"],
}
filter_subtype = (
template.target_subtype
if '{target_subtype}' in template.sql_template
else None
)
anchor = sample_adjacency_anchor(
tables['adjacency_pairs'],
target_subtype=filter_subtype,
anchor_subtypes=_ADJ_ANCHOR_SUBTYPES.get(template.template_id),
)
if anchor:
anchor['anchor_source'] = 'divisions_area'
if not anchor:
return None
sql = template.sql_template.format(
anchor_id=anchor['anchor_id'],
target_subtype=anchor.get('target_subtype', ''),
)
candidates = build_candidate_list(
con, anchor['anchor_id'], anchor['anchor_name'], anchor.get('anchor_source', 'divisions_area'),
num_candidates=10, difficulty="medium"
)
question = random.choice(template.question_hints).format(
anchor_name=anchor['anchor_name'],
target_subtype=anchor.get('target_subtype', ''),
)
elif template.family == "containment":
if template.anchor_source == "natural_earth":
# contain_04 / contain_05: NE anchor (sea, desert, etc.).
# Use cross_source_relations so the anchor exists in natural_earth
# and is guaranteed to intersect divisions_area features.
cs_anchor = sample_cross_source_anchor(
tables.get('cross_source_relations', pd.DataFrame()),
natural_subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id),
)
if not cs_anchor:
return None
anchor_id = cs_anchor['natural_id']
anchor_name = cs_anchor['natural_name']
target_subtype = template.target_subtype or 'country'
sql = template.sql_template.format(
anchor_id=anchor_id,
target_subtype=target_subtype,
)
candidates = build_candidate_list(
con, anchor_id, anchor_name, 'natural_earth',
num_candidates=10, difficulty="medium"
)
question = random.choice(template.question_hints).format(
anchor_name=anchor_name,
target_subtype=target_subtype,
)
anchor = {'id': anchor_id, 'name': anchor_name}
elif template.template_id == "contain_02":
# "What country contains X?" - anchor is the CONTAINED entity;
# result is the country that ST_Contains it.
# Guard against stale relation tables by only allowing contained
# subtypes that exist in the simplified admin schema.
df = tables['containment_pairs']
df = df[
(df['container_subtype'] == 'country')
& (df['contained_subtype'].isin(['region', 'county']))
]
pair = sample_containment_anchor(df)
if not pair:
return None
sql = template.sql_template.format(
anchor_id=pair['contained_id'],
target_subtype='country',
)
candidates = build_candidate_list(
con, pair['contained_id'], pair['contained_name'], 'divisions_area',
num_candidates=10, difficulty="medium"
)
question = random.choice(template.question_hints).format(
anchor_name=pair['contained_name'],
target_subtype='country',
)
anchor = {'id': pair['contained_id'], 'name': pair['contained_name']}
elif template.template_id == "contain_03":
# "What regions are in country X?" - anchor is a country, target is regions.
df = tables['containment_pairs']
df = df[
(df['container_subtype'] == 'country')
& (df['contained_subtype'] == 'region')
]
pair = sample_containment_anchor(df)
if not pair:
return None
sql = template.sql_template.format(
anchor_id=pair['container_id'],
target_subtype='region',
)
candidates = build_candidate_list(
con, pair['container_id'], pair['container_name'], 'divisions_area',
num_candidates=10, difficulty="medium"
)
question = random.choice(template.question_hints).format(
anchor_name=pair['container_name'],
target_subtype='region',
)
anchor = {'id': pair['container_id'], 'name': pair['container_name']}
else:
# contain_01: standard containment.
# Enforce hierarchy: county must be inside region or country, never
# inside another county. Filter container_subtype accordingly.
# Also filter contained_subtype to match template.target_subtype so
# hardcoded vocab hints (e.g. "districts") always align with the SQL.
_VALID_CONTAINERS = {
"county": ["region", "country"],
"region": ["country"],
}
df = tables['containment_pairs']
if template.target_subtype:
filtered = df[df['contained_subtype'] == template.target_subtype]
if not filtered.empty:
df = filtered
valid_containers = _VALID_CONTAINERS.get(template.target_subtype)
if valid_containers:
filtered = df[df['container_subtype'].isin(valid_containers)]
if not filtered.empty:
df = filtered
anchor = sample_containment_anchor(df)
if not anchor:
return None
target_subtype = template.target_subtype or anchor['contained_subtype']
sql = template.sql_template.format(
anchor_id=anchor['container_id'],
target_subtype=target_subtype,
)
candidates = build_candidate_list(
con, anchor['container_id'], anchor['container_name'], 'divisions_area',
num_candidates=10, difficulty="medium"
)
question = random.choice(template.question_hints).format(
anchor_name=anchor['container_name'],
target_subtype=target_subtype,
)
elif template.family == "intersection":
if template.anchor_source == "natural_earth":
anchor = sample_cross_source_anchor(
tables['cross_source_relations'],
natural_subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id),
)
if not anchor:
return None
target_subtype = template.target_subtype or 'country'
sql = template.sql_template.format(
anchor_id=anchor['natural_id'],
target_subtype=target_subtype,
)
candidates = build_candidate_list(
con, anchor['natural_id'], anchor['natural_name'], 'natural_earth',
num_candidates=10, difficulty="medium"
)
question = random.choice(template.question_hints).format(
anchor_name=anchor['natural_name'],
target_subtype=target_subtype,
)
else:
# Same-source intersection.
# If the template pins a target_subtype (e.g. intersect_02 targets county),
# filter intersection_pairs so the sampled pair is guaranteed to match.
idf = tables['intersection_pairs']
if template.target_subtype and not idf.empty:
filtered = idf[idf['target_subtype'] == template.target_subtype]
if filtered.empty:
return None
idf = filtered
anchor = sample_intersection_anchor(idf)
if not anchor:
return None
target_subtype = template.target_subtype or anchor.get('target_subtype') or 'region'
sql = template.sql_template.format(
anchor_id=anchor['anchor_id'],
target_subtype=target_subtype
)
candidates = build_candidate_list(
con, anchor['anchor_id'], anchor['anchor_name'], 'divisions_area',
num_candidates=10, difficulty="medium"
)
question = random.choice(template.question_hints).format(
anchor_name=anchor['anchor_name'],
target_subtype=target_subtype
)
elif template.family == "set_operations":
if template.template_id == "union_03":
# 3-anchor union by ID — candidates: 3 per anchor (9 total)
anchors = [
sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
for _ in range(3)
]
if any(a is None for a in anchors):
return None
anchor1, anchor2, anchor3 = anchors
sql = template.sql_template.format(
anchor_id_1=anchor1['id'],
anchor_id_2=anchor2['id'],
anchor_id_3=anchor3['id'],
)
per_anchor = 3
cands = [
build_candidate_list(con, a['id'], a['name'], 'divisions_area',
num_candidates=per_anchor, difficulty="medium")
for a in anchors
]
candidates = _merge_candidate_lists(*cands, max_total=9)
question = random.choice(template.question_hints).format(
anchor_1_name=anchor1['name'],
anchor_2_name=anchor2['name'],
anchor_3_name=anchor3['name'],
)
elif template.template_id in ("contain_multi_01", "contain_multi_02", "contain_multi_03"):
# country IN clause — 2 or 3 anchors, each contributes its country code.
# Sample unique countries so the query actually teaches a multi-country
# pattern rather than repeating the same ISO code multiple times.
num_a = 3 if template.template_id == "contain_multi_02" else 2
country_inventory = tables['divisions_area_inventory']
country_inventory = country_inventory[
(country_inventory['subtype'] == 'country')
& country_inventory['country'].notna()
].drop_duplicates(subset=['country'])
if len(country_inventory) < num_a:
return None
sampled = country_inventory.sample(n=num_a, replace=False)
anchors = [
{
'id': row['id'],
'name': row['name'],
'subtype': row.get('subtype'),
'country': row.get('country'),
'source': 'divisions_area',
}
for _, row in sampled.iterrows()
]
countries = [a.get('country') or 'US' for a in anchors]
target_subtype = template.target_subtype or 'region'
per_anchor = 3 if num_a == 3 else 4
fmt_kwargs = dict(
target_subtype=target_subtype,
)
for i, c in enumerate(countries, 1):
fmt_kwargs[f'country_{i}'] = c
sql = template.sql_template.format(**fmt_kwargs)
cands = [
build_candidate_list(con, a['id'], a['name'], 'divisions_area',
num_candidates=per_anchor, difficulty="medium")
for a in anchors
]
candidates = _dedupe_country_candidates(
_merge_candidate_lists(*cands, max_total=num_a * per_anchor),
max_total=num_a * per_anchor,
)
q_kwargs = dict(target_subtype=target_subtype)
for i, a in enumerate(anchors, 1):
q_kwargs[f'anchor_{i}_name'] = a['name']
question = random.choice(template.question_hints).format(**q_kwargs)
elif template.template_id == "union_02":
# Filtered union: ST_Union_Agg of contained sub-features.
# Pin to template.target_subtype so hardcoded vocabulary hints
# (e.g. "districts") always match the SQL subtype.
df = tables['containment_pairs']
if template.target_subtype:
filtered = df[df['contained_subtype'] == template.target_subtype]
if not filtered.empty:
df = filtered
pair = sample_containment_anchor(df)
if not pair:
return None
target_subtype = template.target_subtype or pair.get('contained_subtype', 'county')
sql = template.sql_template.format(
anchor_id=pair['container_id'],
target_subtype=target_subtype,
)
candidates = build_candidate_list(
con, pair['container_id'], pair['container_name'], 'divisions_area',
num_candidates=10, difficulty="medium"
)
question = random.choice(template.question_hints).format(
anchor_name=pair['container_name'],
target_subtype=target_subtype,
)
else:
# union_01: 2-anchor union by ID — candidates: 5 per anchor
anchor1 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
anchor2 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
if not anchor1 or not anchor2:
return None
sql = template.sql_template.format(
anchor_id_1=anchor1['id'],
anchor_id_2=anchor2['id'],
)
cands1 = build_candidate_list(
con, anchor1['id'], anchor1['name'], 'divisions_area',
num_candidates=5, difficulty="medium"
)
cands2 = build_candidate_list(
con, anchor2['id'], anchor2['name'], 'divisions_area',
num_candidates=5, difficulty="medium"
)
candidates = _merge_candidate_lists(cands1, cands2, max_total=10)
question = random.choice(template.question_hints).format(
anchor_1_name=anchor1['name'],
anchor_2_name=anchor2['name'],
)
elif template.family == "buffer":
# Buffer operations use metre distances in SQL and a human-readable
# buffer_label in questions, e.g. (1000, "1 km") or (250, "250 m").
# The template SQL divides by 111 320 to approximate metres in degrees.
_buffer_choices = [
(100, "100 m"),
(250, "250 m"),
(500, "500 m"),
(1000, "1 km"),
(2000, "2 km"),
(5000, "5 km"),
(10000, "10 km"),
(25000, "25 km"),
(50000, "50 km"),
(100000, "100 km"),
(200000, "200 km"),
]
if template.num_anchors == 1:
if template.anchor_source == "natural_earth":
anchor = sample_random_entity(
con,
tables['natural_earth_inventory'],
'natural_earth',
subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id, _NE_NAMED_LOOKUP_SUBTYPES),
)
else:
anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
if not anchor:
return None
buffer_m, buffer_label = random.choice(_buffer_choices)
fmt_kwargs = dict(anchor_id=anchor['id'], buffer_m=buffer_m)
q_kwargs = dict(anchor_name=anchor['name'], buffer_label=buffer_label)
if template.target_subtype:
fmt_kwargs['target_subtype'] = template.target_subtype
q_kwargs['target_subtype'] = template.target_subtype
sql = template.sql_template.format(**fmt_kwargs)
candidates = build_candidate_list(
con, anchor['id'], anchor['name'], anchor['source'],
num_candidates=10, difficulty="medium"
)
question = random.choice(template.question_hints).format(**q_kwargs)
else:
# Multi-anchor buffer (2–5 places): union of individual buffers.
num_a = template.num_anchors
anchors = []
for _ in range(num_a):
a = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
if not a:
return None
anchors.append(a)
buffer_m, buffer_label = random.choice(_buffer_choices[:7])
fmt_kwargs = {f'anchor_id_{i+1}': a['id'] for i, a in enumerate(anchors)}
fmt_kwargs['buffer_m'] = buffer_m
if template.target_subtype:
fmt_kwargs['target_subtype'] = template.target_subtype
sql = template.sql_template.format(**fmt_kwargs)
# Build one candidate list per anchor then merge.
per_anchor_n = max(2, 10 // num_a)
cand_lists = [
build_candidate_list(
con, a['id'], a['name'], 'divisions_area',
num_candidates=per_anchor_n, difficulty="medium",
)
for a in anchors
]
candidates = _merge_candidate_lists(*cand_lists)
q_kwargs = {f'anchor_{i+1}_name': a['name'] for i, a in enumerate(anchors)}
q_kwargs['buffer_label'] = buffer_label
if template.target_subtype:
q_kwargs['target_subtype'] = template.target_subtype
question = random.choice(template.question_hints).format(**q_kwargs)
elif template.family == "partial_selection":
# Partial selection (northern half, clipping, etc.)
anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
if not anchor:
return None
if template.num_anchors == 1:
sql = template.sql_template.format(
anchor_id=anchor['id'],
)
question = random.choice(template.question_hints).format(
anchor_name=anchor['name'],
)
candidates = build_candidate_list(
con, anchor['id'], anchor['name'], 'divisions_area',
num_candidates=10, difficulty="hard",
)
else:
# Mixed-source clip: division intersected with a natural_earth feature.
# Use cross_source_relations so the pair is guaranteed to intersect —
# random sampling almost never produces an intersecting pair.
cs_anchor = sample_cross_source_anchor(
tables.get('cross_source_relations', pd.DataFrame()),
natural_subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id),
)
if not cs_anchor:
return None
clip_feature = {
'id': cs_anchor['natural_id'],
'name': cs_anchor['natural_name'],
'source': 'natural_earth',
}
# Override the division anchor with the paired division so the
# ST_Intersects check in the SQL is guaranteed to pass.
anchor = {
'id': cs_anchor['division_id'],
'name': cs_anchor['division_name'],
'source': 'divisions_area',
}
sql = template.sql_template.format(
anchor_id=anchor['id'],
clip_feature_id=clip_feature['id'],
)
question = random.choice(template.question_hints).format(
anchor_name=anchor['name'],
clip_feature_name=clip_feature['name'],
)
# Build candidates for BOTH anchors so the model sees both IDs
# in context and learns to pick the right one for each placeholder.
div_cands = build_candidate_list(
con, anchor['id'], anchor['name'], 'divisions_area',
num_candidates=5, difficulty="hard",
)
ne_cands = build_candidate_list(
con, clip_feature['id'], clip_feature['name'], 'natural_earth',
num_candidates=5, difficulty="hard",
)
candidates = _merge_candidate_lists(div_cands, ne_cands, max_total=10)
elif template.family == "aggregation":
# Teach the model to distinguish singular superlatives ("the largest")
# from explicit top-N requests ("top 5 largest").
top_n = random.choice([1, 3, 5, 10])
target_subtype = random.choice(['county', 'region'])
singular_hints = [h for h in template.question_hints if '{top_n}' not in h]
plural_hints = [h for h in template.question_hints if '{top_n}' in h]
question_hint_pool = singular_hints if top_n == 1 and singular_hints else plural_hints or template.question_hints
if template.template_id in ['agg_03', 'agg_04']:
# Country-level aggregation: SQL uses country code, so the anchor
# in the question must also be a country.
anchor = sample_random_entity(
con,
tables['divisions_area_inventory'],
'divisions_area',
subtypes={'country'},
)
if not anchor:
return None
country = anchor.get('country') or 'US'
sql = template.sql_template.format(
country=country,
target_subtype=target_subtype,
top_n=top_n,
)
candidates = build_candidate_list(
con, anchor['id'], anchor['name'], 'divisions_area',
num_candidates=10, difficulty="hard"
)
question = random.choice(question_hint_pool).format(
top_n=top_n,
target_subtype=target_subtype,
anchor_name=anchor['name'],
)
else:
# Containment-based aggregation: anchor is the container region.
anchor = sample_containment_anchor(tables['containment_pairs'])
if not anchor:
return None
sql = template.sql_template.format(
anchor_id=anchor['container_id'],
target_subtype=target_subtype,
top_n=top_n,
)
candidates = build_candidate_list(
con, anchor['container_id'], anchor['container_name'], 'divisions_area',
num_candidates=10, difficulty="hard"
)
question = random.choice(question_hint_pool).format(
top_n=top_n,
target_subtype=target_subtype,
anchor_name=anchor['container_name'],
)
elif template.family == "chained":
# chained_12/13: country-level coastal/landlocked via adjacency.
# The SQL uses ST_Touches (not containment), so bypass the containment
# pair sampling and use adjacency_pairs with country-level anchors.
if template.template_id in {"chained_12", "chained_13"}:
adj_df = tables.get('adjacency_pairs', pd.DataFrame())
country_adj = (
adj_df[
(adj_df['anchor_subtype'] == 'country')
& (adj_df['target_subtype'] == 'country')
]
if not adj_df.empty else pd.DataFrame()
)
if country_adj.empty:
return None
pair = sample_adjacency_anchor(country_adj)
if not pair:
return None
sql = template.sql_template.format(anchor_id=pair['anchor_id'])
candidates = build_candidate_list(
con, pair['anchor_id'], pair['anchor_name'], 'divisions_area',
num_candidates=10, difficulty="hard"
)
question = random.choice(template.question_hints).format(
anchor_name=pair['anchor_name']
)
anchor = {'id': pair['anchor_id'], 'name': pair['anchor_name']}
else:
# Use pre-filtered coastal/landlocked containment pairs so the SQL
# verification step doesn't constantly return empty results.
_COASTAL_CHAINED = {"chained_01", "chained_06", "chained_10"}
_LANDLOCKED_CHAINED = {"chained_02", "chained_07", "chained_11"}
if template.template_id in _COASTAL_CHAINED:
table_key = 'coastal_containment_pairs'
elif template.template_id in _LANDLOCKED_CHAINED:
table_key = 'landlocked_containment_pairs'
else:
table_key = 'containment_pairs'
df = tables.get(table_key, tables['containment_pairs'])
# When the template pins a target_subtype (e.g. chained_06 wants
# counties), only consider pairs whose contained entity already
# matches — guarantees the sampled container holds at least one
# entity of the right subtype so the SQL filter returns rows.
if template.target_subtype:
df = df[df['contained_subtype'] == template.target_subtype]
# chained_10/11 additionally need a country-level container so
# phrasings like "coastal states of India" line up.
if template.template_id in {"chained_10", "chained_11"}:
df = df[df['container_subtype'] == 'country']
anchor = sample_containment_anchor(df)
if not anchor:
return None
target_subtype = template.target_subtype or anchor.get('contained_subtype', 'county')
sql = template.sql_template.format(
anchor_id=anchor['container_id'],
target_subtype=target_subtype,
)
candidates = build_candidate_list(
con, anchor['container_id'], anchor['container_name'], 'divisions_area',
num_candidates=10, difficulty="hard"
)
question = random.choice(template.question_hints).format(
anchor_name=anchor['container_name'],
target_subtype=target_subtype,
)
elif template.family == "multi_adjacency":
# Use common_neighbor_pairs so anchor1 and anchor2 are guaranteed to
# share at least one touching neighbour — SQL will return non-empty.
# Filter by both anchor subtypes AND shared_neighbor_subtype so the
# sampled pair is geographically coherent with the template intent:
# multi_adj_01: region anchors → region result
# multi_adj_02: country anchors → country result
# multi_adj_03: region anchors → county result
_MULTI_ADJ_ANCHOR_SUBTYPES = {
"multi_adj_01": ("region", "region"),
"multi_adj_02": ("country", "country"),
"multi_adj_03": ("region", "region"),
}
cn_df = tables.get('common_neighbor_pairs', pd.DataFrame())
if cn_df.empty:
return None
if template.target_subtype and 'shared_neighbor_subtype' in cn_df.columns:
filtered = cn_df[cn_df['shared_neighbor_subtype'] == template.target_subtype]
if not filtered.empty:
cn_df = filtered
if template.template_id in _MULTI_ADJ_ANCHOR_SUBTYPES and 'anchor_subtype_1' in cn_df.columns:
a1_st, a2_st = _MULTI_ADJ_ANCHOR_SUBTYPES[template.template_id]
filtered = cn_df[
(cn_df['anchor_subtype_1'] == a1_st) &
(cn_df['anchor_subtype_2'] == a2_st)
]
if not filtered.empty:
cn_df = filtered
row = cn_df.sample(n=1).iloc[0]
anchor1 = {'id': row['anchor_id_1'], 'name': row['anchor_name_1'], 'source': 'divisions_area'}
anchor2 = {'id': row['anchor_id_2'], 'name': row['anchor_name_2'], 'source': 'divisions_area'}
target_subtype = template.target_subtype or row.get('shared_neighbor_subtype', 'region')
sql = template.sql_template.format(
anchor_id_1=anchor1['id'],
anchor_id_2=anchor2['id'],
target_subtype=target_subtype,
)
candidates1 = build_candidate_list(
con, anchor1['id'], anchor1['name'], 'divisions_area',
num_candidates=5, difficulty="medium"
)
candidates2 = build_candidate_list(
con, anchor2['id'], anchor2['name'], 'divisions_area',
num_candidates=5, difficulty="medium"
)
candidates = _merge_candidate_lists(candidates1, candidates2)
question = random.choice(template.question_hints).format(
anchor_1_name=anchor1['name'],
anchor_2_name=anchor2['name'],
target_subtype=target_subtype,
)
elif template.family == "difference":
if template.anchor_source == "mixed":
# divisions_area anchor differenced against a natural_earth feature.
# Use cross_source_relations so the pair is guaranteed to intersect
# (ST_Difference on non-intersecting geometries is always equal to
# the original geometry — a trivial and uninformative sample).
cs_anchor = sample_cross_source_anchor(
tables.get('cross_source_relations', pd.DataFrame()),
natural_subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id),
)
if not cs_anchor:
return None
anchor = {
'id': cs_anchor['division_id'],
'name': cs_anchor['division_name'],
'source': 'divisions_area',
}
clip_feature = {
'id': cs_anchor['natural_id'],
'name': cs_anchor['natural_name'],
'source': 'natural_earth',
}
sql = template.sql_template.format(
anchor_id=anchor['id'],
clip_feature_id=clip_feature['id'],
)
question = random.choice(template.question_hints).format(
anchor_name=anchor['name'],
clip_feature_name=clip_feature['name'],
)
# Build candidates for BOTH anchors — model must see both IDs
# to correctly assign anchor_id vs clip_feature_id in the SQL.
div_cands = build_candidate_list(
con, anchor['id'], anchor['name'], 'divisions_area',
num_candidates=5, difficulty="hard",
)
ne_cands = build_candidate_list(
con, clip_feature['id'], clip_feature['name'], 'natural_earth',
num_candidates=5, difficulty="hard",
)
candidates = _merge_candidate_lists(div_cands, ne_cands, max_total=10)
else:
# Two divisions_area anchors: use both ends of a containment
# pair so the contained entity is guaranteed to intersect the
# container. ST_Difference(container, contained) yields the
# portion of the container outside the contained piece.
pair = sample_containment_anchor(tables['containment_pairs'])
if not pair:
return None
anchor1 = {'id': pair['container_id'], 'name': pair['container_name']}
anchor2 = {'id': pair['contained_id'], 'name': pair['contained_name']}
sql = template.sql_template.format(
anchor_id_1=anchor1['id'],
anchor_id_2=anchor2['id'],
)
candidates1 = build_candidate_list(
con, anchor1['id'], anchor1['name'], 'divisions_area',
num_candidates=5, difficulty="medium"
)
candidates2 = build_candidate_list(
con, anchor2['id'], anchor2['name'], 'divisions_area',
num_candidates=5, difficulty="medium"
)
candidates = _merge_candidate_lists(candidates1, candidates2)
question = random.choice(template.question_hints).format(
anchor_1_name=anchor1['name'],
anchor_2_name=anchor2['name'],
)
elif template.family == "border_corridor":
# Buffered border zone — needs two anchors that actually touch.
pair = sample_adjacency_anchor(tables['adjacency_pairs'])
if not pair:
return None
anchor1 = {'id': pair['anchor_id'], 'name': pair['anchor_name']}
anchor2 = {'id': pair['target_id'], 'name': pair['target_name']}
buffer_val = random.choice([5, 10, 25, 50])
sql = template.sql_template.format(
anchor_id_1=anchor1['id'],
anchor_id_2=anchor2['id'],
buffer_km=buffer_val,
)
candidates1 = build_candidate_list(
con, anchor1['id'], anchor1['name'], 'divisions_area',
num_candidates=5, difficulty="medium"
)
candidates2 = build_candidate_list(
con, anchor2['id'], anchor2['name'], 'divisions_area',
num_candidates=5, difficulty="medium"
)
candidates = _merge_candidate_lists(candidates1, candidates2)
question = random.choice(template.question_hints).format(
anchor_1_name=anchor1['name'],
anchor_2_name=anchor2['name'],
buffer_km=buffer_val,
)
elif template.family == "window_function":
anchor = sample_random_entity(
con,
tables['divisions_area_inventory'],
'divisions_area',
subtypes={'country'},
)
if not anchor:
return None
country = anchor.get('country') or 'US'
target_subtype = template.target_subtype or 'county'
sql = template.sql_template.format(
country=country,
target_subtype=target_subtype,
)
candidates = build_candidate_list(
con, anchor['id'], anchor['name'], 'divisions_area',
num_candidates=10, difficulty="hard"
)
question = random.choice(template.question_hints).format(
anchor_name=anchor['name'],
target_subtype=target_subtype,
)
elif template.family == "attribute_filter":
anchor = sample_random_entity(
con,
tables['divisions_area_inventory'],
'divisions_area',
subtypes={'country'},
)
if not anchor:
return None
country = anchor.get('country') or 'US'
target_subtype = template.target_subtype or 'region'
sql = template.sql_template.format(
country=country,
target_subtype=target_subtype,
)
candidates = build_candidate_list(
con, anchor['id'], anchor['name'], 'divisions_area',
num_candidates=10, difficulty="medium"
)
question = random.choice(template.question_hints).format(
anchor_name=anchor['name'],
target_subtype=target_subtype,
country=country,
)
else:
# Skip unsupported families
return None
# Execute SQL to verify
try:
result = con.execute(_for_execution(sql)).fetchdf()
if result.empty:
return None
except Exception as e:
# Errors are tracked in worker return, no need to print
return None
# Collect every anchor ID that appears in the generated SQL so we can
# mark them as the "selected" candidates in the training sample.
_multi_anchor_families = {"set_operations", "multi_adjacency", "difference", "border_corridor", "buffer"}
# Mixed partial_selection (partial_05) and mixed difference (diff_02) each
# have two anchors from different sources — both must be marked selected.
_is_mixed_two_anchor = (
template.anchor_source == "mixed" and template.num_anchors == 2
)
if _is_mixed_two_anchor:
# partial_05 / diff_02: anchor (division) + clip_feature (natural_earth)
mixed_ids = {anchor.get("id", ""), clip_feature.get("id", "")}
selected_candidate_ids = [c.candidate_id for c in candidates if c.id in mixed_ids]
elif template.family in _multi_anchor_families and template.num_anchors >= 2:
anchor_ids: set = set()
for var in ("anchor1", "anchor2", "anchor3"):
obj = locals().get(var)
if obj:
anchor_ids.add(obj.get("id", ""))
if "anchors" in locals():
for a in locals()["anchors"]:
if a:
anchor_ids.add(a.get("id", ""))
selected_candidate_ids = [c.candidate_id for c in candidates if c.id in anchor_ids]
else:
anchor_id_to_find = (
anchor.get('anchor_id')
or anchor.get('container_id')
or anchor.get('natural_id')
or anchor.get('id')
)
selected_candidate_ids = [c.candidate_id for c in candidates if c.id == anchor_id_to_find]
question, surface_variants = _diversify_question_surface(question, template.family)
return TrainingSample(
id=sample_id,
question=question,
candidates=candidates,
target={
"selected_candidates": selected_candidate_ids,
"sql": sql,
},
metadata={
"task_family": template.family,
"sql_difficulty": template.sql_difficulty,
"grounding_difficulty": "medium",
"template_id": template.template_id,
"num_candidates": len(candidates),
"anchor_source": template.anchor_source,
"sql_verified": True,
"surface_variants": surface_variants,
}
)
def generate_sample_batch_worker(args):
    """Process one batch of work items with a single DuckDB connection.

    DuckDB, the spatial extension and the relation tables are initialised
    ONCE per batch; all work items are then processed sequentially against
    that shared state.

    Args:
        args: ``(work_items, intermediate_dir_str)`` tuple. ``work_items`` is
            a list of ``(family, template_dict, sample_id, _)`` tuples and
            ``intermediate_dir_str`` is the relation-table directory, passed
            as a string and converted back to a ``Path`` here.

    Returns:
        One ``(sample, family, template_id, error)`` tuple per work item.
        ``sample`` is the generated sample or ``None``; ``error`` is ``None``
        on success, ``"Empty result"`` when the generator returned nothing,
        or the exception text when the item raised.
    """
    from pathlib import Path

    work_items, intermediate_dir_str = args
    # Convert string back to Path
    intermediate_dir = Path(intermediate_dir_str)

    # Initialize DuckDB ONCE for the entire batch
    con = duckdb.connect()
    try:
        con.execute("SET enable_progress_bar=false")
        con.execute("INSTALL spatial")
        con.execute("LOAD spatial")

        # Load relation tables ONCE
        tables = load_relation_tables(intermediate_dir, quiet=True)

        # Process all items in batch
        results = []
        for family, template_dict, sample_id, _ in work_items:
            try:
                # Rebuild the template inside the try so that one malformed
                # template dict is reported as a failed item instead of
                # aborting every remaining item of the batch.
                template = sql_templates.SQLTemplate(**template_dict)
                sample = generate_template_based_sample(con, template, tables, sample_id)
                if sample:
                    results.append((sample, family, template.template_id, None))
                else:
                    results.append((None, family, template.template_id, "Empty result"))
            except Exception as e:
                results.append((None, family, template_dict.get('template_id', 'unknown'), str(e)))
        return results
    finally:
        # Always release the connection, even when setup or table loading fails.
        con.close()
def generate_batch_core(
    work_items: List[tuple],
    intermediate_dir: str,
) -> List[Dict[str, Any]]:
    """Standalone batch worker usable from Modal or any remote context.

    Data paths are resolved via GAZET_DATA_DIR env var (set in Modal image).
    A single DuckDB connection (with the spatial extension loaded) and a
    single load of the relation tables are shared by all items of the batch.

    Args:
        work_items: List of (family, template_dict, sample_id, _) tuples
        intermediate_dir: Path to intermediate dir with relation parquets

    Returns:
        List of dicts with keys: sample (dict or None), family, template_id,
        error. ``error`` is ``None`` on success, ``"Empty result"`` when no
        sample was produced, or the exception text when the item raised.
    """
    from pathlib import Path as _Path

    intermediate = _Path(intermediate_dir)

    con = duckdb.connect()
    try:
        con.execute("SET enable_progress_bar=false")
        con.execute("INSTALL spatial")
        con.execute("LOAD spatial")

        tables = load_relation_tables(intermediate, quiet=True)

        results = []
        for family, template_dict, sample_id, _ in work_items:
            try:
                # Template reconstruction lives inside the try so a single bad
                # template dict becomes one failed item, not a failed batch.
                template = sql_templates.SQLTemplate(**template_dict)
                sample = generate_template_based_sample(con, template, tables, sample_id)
                if sample:
                    results.append({
                        "sample": sample.model_dump(),
                        "family": family,
                        "template_id": template.template_id,
                        "error": None,
                    })
                else:
                    results.append({
                        "sample": None,
                        "family": family,
                        "template_id": template.template_id,
                        "error": "Empty result",
                    })
            except Exception as e:
                results.append({
                    "sample": None,
                    "family": family,
                    "template_id": template_dict.get('template_id', 'unknown'),
                    "error": str(e),
                })
        return results
    finally:
        # Close the connection on every exit path, including setup failures.
        con.close()
def prepare_work_items(
    target_counts: Dict[str, int],
    retry_multiplier: int = 2,
    start_counter: int = 1,
    intermediate_dir_str: str = "",
) -> List[tuple]:
    """Build the shuffled list of work items for sample generation.

    For every family with a non-zero target, the target is split as evenly
    as possible over that family's templates (the first templates absorb the
    remainder) and each template's share is multiplied by
    ``retry_multiplier``. Sample IDs are numbered consecutively starting at
    ``start_counter``.

    Reusable by both local main() and Modal orchestrator.

    Returns:
        Shuffled list of
        (family, template_dict, sample_id, intermediate_dir_str) tuples.
    """
    # Template attributes copied into the plain dict shipped with each item.
    template_fields = (
        'template_id',
        'family',
        'sql_difficulty',
        'anchor_source',
        'num_anchors',
        'sql_template',
        'question_hints',
        'target_subtype',
        'requires_buffer',
        'requires_aggregation',
    )

    items = []
    next_id = start_counter

    for family, wanted in target_counts.items():
        if wanted == 0:
            continue

        templates = [tpl for tpl in TEMPLATES if tpl.family == family]
        if not templates:
            print(f"No templates found for {family}, skipping...")
            continue

        # Even split so every template_id gets a guaranteed share; a uniform
        # random pick could starve rare variants of a family.
        base_share, extra = divmod(wanted, len(templates))

        for position, tpl in enumerate(templates):
            share = base_share + 1 if position < extra else base_share
            payload = {field: getattr(tpl, field) for field in template_fields}

            attempts = share * retry_multiplier
            for _ in range(attempts):
                items.append(
                    (family, payload, f"sample_{next_id:06d}", intermediate_dir_str)
                )
                next_id += 1

    random.shuffle(items)
    return items
def main():
    """Generate training samples and write them to ``output/dataset_raw.jsonl``.

    Workflow:
      1. Resolve the ``intermediate`` and ``output`` directories next to the
         scripts directory and create the output directory if needed.
      2. Load the relation tables once up front (availability check).
      3. Use ``TARGET_COUNTS`` when set, otherwise the built-in defaults.
      4. In append mode, read the existing dataset so new sample IDs continue
         after the highest existing ``sample_<n>`` number.
      5. Build the work items, split them into one batch per worker and run
         the batches in a ``ProcessPoolExecutor``.
      6. Print per-family success/failure counts and write the JSONL file
         (appending in append mode, overwriting otherwise).

    Reads the module-level settings ``TARGET_COUNTS``, ``MAX_WORKERS``,
    ``RETRY_MULTIPLIER`` and ``APPEND_MODE``; it never rebinds them.
    """
    # Setup paths
    script_dir = Path(__file__).parent
    intermediate_dir = script_dir.parent / "intermediate"
    output_dir = script_dir.parent / "output"
    output_dir.mkdir(exist_ok=True, parents=True)

    # Load relation tables once to check availability. The result is not
    # needed here: every worker loads its own copy.
    print("Loading relation tables...")
    load_relation_tables(intermediate_dir, quiet=False)

    # Use configured target counts or defaults
    if TARGET_COUNTS is None:
        target_counts = {
            'direct_lookup': 100,
            'adjacency': 150,
            'multi_adjacency': 75,
            'containment': 100,
            'intersection': 100,
            'buffer': 100,
            'chained': 150,
            'difference': 75,
            'border_corridor': 75,
            'set_operations': 150,
            'partial_selection': 75,
            'aggregation': 100,
            'window_function': 75,
            'attribute_filter': 75,
        }
    else:
        target_counts = TARGET_COUNTS

    # Load existing samples if in append mode
    existing_samples = []
    missing_final_newline = False
    jsonl_file = output_dir / "dataset_raw.jsonl"
    if APPEND_MODE and jsonl_file.exists():
        print(f"\nAppend mode: Loading existing samples from {jsonl_file}")
        with open(jsonl_file, 'r', encoding='utf-8') as f:
            for line in f:
                # Remember whether the last physical line is newline-terminated
                # so appended records never get glued onto it.
                missing_final_newline = not line.endswith('\n')
                if line.strip():
                    existing_samples.append(json.loads(line))
        print(f" Found {len(existing_samples)} existing samples")

        # Determine starting sample counter. IDs that do not follow the
        # ``sample_<number>`` pattern are ignored instead of aborting the run.
        existing_numbers = []
        for existing in existing_samples:
            existing_id = existing.get('id', '')
            if not existing_id.startswith('sample_'):
                continue
            try:
                existing_numbers.append(int(existing_id.split('_')[1]))
            except ValueError:
                continue
        sample_counter = max(existing_numbers, default=0) + 1
    else:
        sample_counter = 1

    # Prepare work items using shared helper
    work_items = prepare_work_items(
        target_counts=target_counts,
        retry_multiplier=RETRY_MULTIPLIER,
        start_counter=sample_counter,
        intermediate_dir_str=str(intermediate_dir),
    )
    starting_sample_counter = sample_counter

    # Partition work items into batches (one per worker)
    num_workers = min(MAX_WORKERS, len(work_items))
    if num_workers == 0:
        print("No work items to process")
        return

    batch_size = (len(work_items) + num_workers - 1) // num_workers  # ceil division
    batches = [
        (work_items[i:i + batch_size], str(intermediate_dir))
        for i in range(0, len(work_items), batch_size)
    ]

    # Generate samples in parallel (one batch per worker)
    active_families = sum(1 for count in target_counts.values() if count > 0)
    print(f"\nGenerating {len(work_items)} samples across {active_families} families...")
    print(f" Split into {len(batches)} batches of ~{batch_size} items (1 DuckDB init per batch)")
    if APPEND_MODE and existing_samples:
        # Same zero-padding as the generated sample IDs.
        print(f"Appending: starting from sample_{starting_sample_counter:06d}")

    all_samples = []
    family_progress = {
        family: {'success': 0, 'failed': 0}
        for family, count in target_counts.items()
        if count > 0
    }

    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        # Submit one batch per worker; remember each future's items so a
        # crashed batch can still be accounted for in the statistics.
        futures = {
            executor.submit(generate_sample_batch_worker, batch): batch[0]
            for batch in batches
        }

        # Collect results as batches complete
        batches_done = 0
        for future in as_completed(futures):
            try:
                batch_results = future.result()
            except Exception as e:
                print(f"\n Batch failed: {e}")
                # The worker died before reporting: count all of its items as
                # failed so the totals below stay consistent.
                for item in futures[future]:
                    family_progress[item[0]]['failed'] += 1
            else:
                for sample, family, _template_id, _error in batch_results:
                    if sample:
                        all_samples.append(sample)
                        family_progress[family]['success'] += 1
                    else:
                        family_progress[family]['failed'] += 1

            batches_done += 1
            total_done = sum(p['success'] + p['failed'] for p in family_progress.values())
            print(f"\r Progress: {total_done}/{len(work_items)} samples ({batches_done}/{len(batches)} batches) ", end='', flush=True)

    print()  # New line after progress

    # Show distribution (keep all samples, no filtering)
    print("\nResults by family:")
    for family in sorted(family_progress):
        success = family_progress[family]['success']
        failed = family_progress[family]['failed']
        target = target_counts.get(family, 0)
        total = success + failed
        success_rate = (success / total * 100) if total > 0 else 0
        print(f" {family:20s}: {success:3d} success / {failed:3d} failed ({success_rate:5.1f}% success rate, target: {target})")

    # Save combined JSONL (skip individual JSON files for speed at scale)
    print(f"\nSaving {len(all_samples)} new samples...")
    if APPEND_MODE and existing_samples:
        # Append to existing dataset
        print(f"Appending to existing dataset ({len(existing_samples)} existing samples)")
        with open(jsonl_file, 'a', encoding='utf-8') as f:
            if missing_final_newline and all_samples:
                f.write('\n')
            f.writelines(json.dumps(sample.model_dump()) + '\n' for sample in all_samples)
        total_samples = len(existing_samples) + len(all_samples)
    else:
        # Overwrite dataset
        with open(jsonl_file, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(sample.model_dump()) + '\n' for sample in all_samples)
        total_samples = len(all_samples)

    print(f"\nGenerated {len(all_samples)} new samples")
    print(f"Total dataset size: {total_samples} samples")
    print(f" Dataset: {jsonl_file}")


if __name__ == "__main__":
    main()