""" Precompute spatial relation tables for efficient anchor sampling. This script computes: - Adjacency pairs (touching features) - Containment pairs (features within other features) - Intersection pairs (overlapping features) - Cross-source relations (divisions_area ↔ natural_earth) Output: - intermediate/adjacency_pairs.parquet - intermediate/containment_pairs.parquet - intermediate/intersection_pairs.parquet - intermediate/cross_source_relations.parquet """ import os from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path import duckdb import pandas as pd from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH # (container_subtype, contained_subtype) combos used by the chained, # containment, and disambiguation templates. The normalized Overture input is # now restricted to country / region / county, so keep relation building in # sync and avoid wasting work on removed subtypes. _CONTAINMENT_SUBTYPE_PAIRS = ( ("country", "region"), ("country", "county"), ("region", "county"), ) # Natural Earth subtype vocabulary normalized to lowercase. # We lowercase the source subtype values while building relation tables so # mixed casing in upstream data (e.g. Lake vs lake, Range/mtn vs range/mtn) # does not fragment anchor pools or break template matching. 
_NE_CROSS_SOURCE_SUBTYPES = (
    "sea",
    "ocean",
    "lake",
    "river",
    "basin",
    "gulf",
    "bay",
    "strait",
    "island group",
    "peninsula",
    "range/mtn",
    "plateau",
    "plain",
    "lowland",
    "valley",
    "depression",
    "gorge",
)
_DIVISION_SUBTYPES = ("country", "region", "county")

# Natural Earth subtypes that mark a container as coastal. They are compared
# against lower(subtype), matching the cross-source relations, so upstream
# casing differences (e.g. "Sea" vs "sea") don't flip the classification.
_SEA_SUBTYPES = ("sea", "ocean")


def _country_filter(countries: list) -> tuple[str, list]:
    """Return ``(SQL WHERE clause, params)`` restricting rows to ``countries``.

    The sentinel ``["all"]`` disables filtering and yields ``("", [])``.
    Otherwise the clause contains one ``?`` placeholder that is bound to the
    whole country list and expanded with ``unnest``.
    """
    if countries == ["all"]:
        return "", []
    return "WHERE country IN (SELECT unnest(?))", [countries]


def _country_filter_for_join(countries: list) -> tuple[str, list]:
    """Return a country filter for self-joins over normalized admin data.

    Same contract as ``_country_filter``. The name is kept so self-join
    callers stay readable.
    """
    return _country_filter(countries)


def _as_and_clause(where_clause: str) -> str:
    """Rewrite a leading ``WHERE`` clause as an ``AND`` clause ("" stays "")."""
    return where_clause.replace("WHERE", "AND", 1) if where_clause else ""


def _country_chunks(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    chunk_size: int = 40,
) -> list[list[str]]:
    """Return explicit country batches for safer global containment joins.

    An explicit country list is returned unchanged as a single batch. For
    the ``["all"]`` sentinel, the distinct non-blank country codes in the
    divisions file are read in sorted order. They are split into batches of
    at most ``chunk_size`` codes, so each spatial join only sees a bounded
    subset of countries.
    """
    if countries != ["all"]:
        return [countries]
    rows = con.execute(
        """
        SELECT DISTINCT country
        FROM read_parquet(?)
        WHERE country IS NOT NULL AND trim(country) != ''
        ORDER BY country
        """,
        [DIVISIONS_AREA_PATH],
    ).fetchall()
    codes = [row[0] for row in rows]
    return [codes[i:i + chunk_size] for i in range(0, len(codes), chunk_size)]


def compute_adjacency_pairs(
    con: duckdb.DuckDBPyConnection, countries: list, limit: int
) -> pd.DataFrame:
    """Find all pairs of features that touch (share a boundary).

    Each unordered pair is emitted once (``a.id < b.id``). A cheap envelope
    intersection test runs before the exact ``ST_Touches`` check, to avoid
    evaluating the exact predicate on every pair. At most ``limit`` rows
    are returned.
    """
    print("Computing adjacency pairs (optimized with spatial index)...")
    cfilter, cparams = _country_filter_for_join(countries)

    # Use bounding box pre-filter to avoid full cartesian product
    query = f"""
        WITH features AS (
            SELECT
                id,
                names."primary" AS name,
                subtype,
                country,
                admin_level,
                geometry,
                ST_Envelope(geometry) AS bbox
            FROM read_parquet(?)
            {cfilter}
        )
        SELECT
            a.id AS anchor_id,
            a.name AS anchor_name,
            a.subtype AS anchor_subtype,
            a.country AS anchor_country,
            b.id AS target_id,
            b.name AS target_name,
            b.subtype AS target_subtype,
            b.country AS target_country,
            'adjacency' AS relation_type
        FROM features AS a
        JOIN features AS b
          ON (
                a.id < b.id
            AND ST_Intersects(a.bbox, b.bbox)
            AND ST_Touches(a.geometry, b.geometry)
          )
        LIMIT ?
    """

    df = con.execute(query, [DIVISIONS_AREA_PATH] + cparams + [limit]).fetchdf()
    print(f"Found {len(df)} adjacency pairs")
    return df


def _stratified_containment(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int,
    relation_type: str,
    extra_where: str = "",
    extra_params: list = None,
) -> pd.DataFrame:
    """Compute containment pairs stratified by (container_subtype, contained_subtype).

    A single global self-join with LIMIT fills up with coarse country->region
    pairs before it emits country->county and region->county pairs. Instead,
    one focused query runs per subtype combo, so every combo gets a fair
    share of the overall limit.

    ``extra_where`` / ``extra_params`` let the coastal and landlocked
    variants inject their container filter without duplicating the whole
    body. ``extra_where`` is appended to the container CTE *after* the
    country filter and may refer to the container row as ``src``. Its
    ``?`` placeholders are therefore bound after the country-filter params.
    """
    extra_params = list(extra_params or [])
    num_combos = len(_CONTAINMENT_SUBTYPE_PAIRS)

    # Use a lower target per subtype combo for global runs; they are the most
    # memory-intensive part of the pipeline and don't need huge anchor tables.
    if countries == ["all"]:
        per_combo = min(max(limit // num_combos, 100), 1500)
    else:
        per_combo = max(limit // num_combos, 100)

    country_batches = _country_chunks(con, countries)
    frames: list[pd.DataFrame] = []

    for container_st, contained_st in _CONTAINMENT_SUBTYPE_PAIRS:
        remaining = per_combo
        combo_parts: list[pd.DataFrame] = []

        for batch in country_batches:
            if remaining <= 0:
                break

            cfilter, cparams = _country_filter(batch)
            country_and = _as_and_clause(cfilter)
            query = f"""
                WITH a AS (
                    SELECT
                        src.id,
                        src.names."primary" AS name,
                        src.subtype,
                        src.country,
                        src.admin_level,
                        src.geometry,
                        ST_Envelope(src.geometry) AS bbox
                    FROM read_parquet(?) AS src
                    WHERE src.subtype = '{container_st}'
                    {country_and}
                    {extra_where}
                ),
                b AS (
                    SELECT
                        dst.id,
                        dst.names."primary" AS name,
                        dst.subtype,
                        dst.country,
                        dst.admin_level,
                        dst.geometry,
                        ST_Envelope(dst.geometry) AS bbox
                    FROM read_parquet(?) AS dst
                    WHERE dst.subtype = '{contained_st}'
                    {country_and}
                )
                SELECT
                    a.id AS container_id,
                    a.name AS container_name,
                    a.subtype AS container_subtype,
                    b.id AS contained_id,
                    b.name AS contained_name,
                    b.subtype AS contained_subtype,
                    a.country AS container_country,
                    '{relation_type}' AS relation_type
                FROM a
                JOIN b
                  ON (
                        a.id != b.id
                    AND ST_Intersects(a.bbox, b.bbox)
                    AND ST_Within(b.geometry, a.geometry)
                  )
                LIMIT {remaining}
            """
            # Positional '?' placeholders bind in textual order:
            # CTE a -> path, country list, extra_where params;
            # CTE b -> path, country list.
            params = (
                [DIVISIONS_AREA_PATH]
                + cparams
                + extra_params
                + [DIVISIONS_AREA_PATH]
                + cparams
            )
            df_part = con.execute(query, params).fetchdf()
            if not df_part.empty:
                combo_parts.append(df_part)
                remaining -= len(df_part)

        df_combo = (
            pd.concat(combo_parts, ignore_index=True)
            if combo_parts
            else pd.DataFrame()
        )
        print(
            f"  {relation_type} {container_st:>10s} -> {contained_st:<10s}: "
            f"{len(df_combo)} pairs"
        )
        frames.append(df_combo)

    # Skip empty combo frames: concatenating column-less empties with real
    # frames can change result dtypes (and warns in recent pandas).
    non_empty = [frame for frame in frames if not frame.empty]
    return pd.concat(non_empty, ignore_index=True) if non_empty else pd.DataFrame()


def compute_containment_pairs(
    con: duckdb.DuckDBPyConnection, countries: list, limit: int
) -> pd.DataFrame:
    """Find containment pairs stratified across admin-level combinations."""
    print("\nComputing containment pairs (stratified by subtype combo)...")
    df = _stratified_containment(con, countries, limit, relation_type="containment")
    print(f"Found {len(df)} containment pairs")
    return df


def compute_intersection_pairs(
    con: duckdb.DuckDBPyConnection, countries: list, limit: int
) -> pd.DataFrame:
    """Find pairs that intersect but don't touch or contain.

    Each unordered pair is emitted once (``a.id < b.id``). Pairs are excluded
    when they only touch, or when either geometry lies within the other.
    At most ``limit`` rows are returned.
    """
    print("\nComputing intersection pairs (optimized)...")
    cfilter, cparams = _country_filter_for_join(countries)

    query = f"""
        WITH features AS (
            SELECT
                id,
                names."primary" AS name,
                subtype,
                country,
                admin_level,
                geometry,
                ST_Envelope(geometry) AS bbox
            FROM read_parquet(?)
            {cfilter}
        )
        SELECT
            a.id AS anchor_id,
            a.name AS anchor_name,
            a.subtype AS anchor_subtype,
            b.id AS target_id,
            b.name AS target_name,
            b.subtype AS target_subtype,
            'intersection' AS relation_type
        FROM features AS a
        JOIN features AS b
          ON (
                a.id < b.id
            AND ST_Intersects(a.bbox, b.bbox)
            AND ST_Intersects(a.geometry, b.geometry)
            AND NOT ST_Touches(a.geometry, b.geometry)
            AND NOT ST_Within(a.geometry, b.geometry)
            AND NOT ST_Within(b.geometry, a.geometry)
          )
        LIMIT ?
    """

    df = con.execute(query, [DIVISIONS_AREA_PATH] + cparams + [limit]).fetchdf()
    print(f"Found {len(df)} same-source intersection pairs")
    return df


def compute_cross_source_relations(
    con: duckdb.DuckDBPyConnection, countries: list, limit: int
) -> pd.DataFrame:
    """Find relations between divisions_area and natural_earth.

    The join is skewed both by abundant Natural Earth subtypes and by coarse
    admin features. We therefore stratify by (natural_subtype,
    division_subtype) so country / region / county anchors all make it into
    the pool used by mixed-source, NE-intersection, and NE-adjacency
    templates. Each combo returns at most ``max(limit // num_combos, 10)``
    rows. ``relation_type`` is the first matching label among touches,
    within, contains and intersects.
    """
    print("\nComputing cross-source relations (stratified by NE subtype and admin subtype)...")
    cfilter, cparams = _country_filter(countries)
    country_and = _as_and_clause(cfilter)

    num_combos = len(_NE_CROSS_SOURCE_SUBTYPES) * len(_DIVISION_SUBTYPES)
    per_combo = max(limit // num_combos, 10)
    frames: list[pd.DataFrame] = []

    for natural_subtype in _NE_CROSS_SOURCE_SUBTYPES:
        for division_subtype in _DIVISION_SUBTYPES:
            query = f"""
                WITH divisions AS (
                    SELECT
                        id,
                        names."primary" AS name,
                        subtype,
                        country,
                        geometry
                    FROM read_parquet(?)
                    WHERE geometry IS NOT NULL
                      AND names."primary" IS NOT NULL
                      AND trim(names."primary") != ''
                      AND subtype = '{division_subtype}'
                      {country_and}
                ),
                natural_features AS (
                    SELECT
                        id,
                        names."primary" AS name,
                        lower(subtype) AS natural_subtype,
                        geometry
                    FROM read_parquet(?)
                    WHERE geometry IS NOT NULL
                      AND names."primary" IS NOT NULL
                      AND trim(names."primary") != ''
                      AND lower(subtype) = '{natural_subtype}'
                )
                SELECT
                    d.id AS division_id,
                    d.name AS division_name,
                    d.subtype AS division_subtype,
                    d.country AS division_country,
                    n.id AS natural_id,
                    n.name AS natural_name,
                    n.natural_subtype AS natural_subtype,
                    CASE
                        WHEN ST_Touches(d.geometry, n.geometry) THEN 'touches'
                        WHEN ST_Within(d.geometry, n.geometry) THEN 'within'
                        WHEN ST_Contains(d.geometry, n.geometry) THEN 'contains'
                        WHEN ST_Intersects(d.geometry, n.geometry) THEN 'intersects'
                    END AS relation_type
                FROM divisions AS d
                JOIN natural_features AS n
                  ON ST_Intersects(d.geometry, n.geometry)
                LIMIT {per_combo}
            """
            df_part = con.execute(
                query,
                [DIVISIONS_AREA_PATH] + cparams + [NATURAL_EARTH_PATH],
            ).fetchdf()
            print(
                f"  cross_source {natural_subtype:>12s} x {division_subtype:<7s}: "
                f"{len(df_part)} rows"
            )
            frames.append(df_part)

    df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
    print(f"Found {len(df)} cross-source relations")
    return df


def _sea_contact_clause(*, negate: bool) -> str:
    """Return an ``AND [NOT] EXISTS`` filter testing ``src`` against seas/oceans.

    The clause is meant for ``_stratified_containment``'s ``extra_where``. It
    contains exactly one ``?`` placeholder (the Natural Earth parquet path),
    so callers must pass ``extra_params=[NATURAL_EARTH_PATH]``.
    """
    subtypes = ", ".join(f"'{subtype}'" for subtype in _SEA_SUBTYPES)
    exists = "NOT EXISTS" if negate else "EXISTS"
    return f"""
        AND {exists} (
            SELECT 1
            FROM read_parquet(?) AS n
            WHERE n.geometry IS NOT NULL
              AND n.names."primary" IS NOT NULL
              AND trim(n.names."primary") != ''
              AND lower(n.subtype) IN ({subtypes})
              AND ST_Intersects(src.geometry, n.geometry)
        )
    """


def compute_coastal_containment_pairs(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int,
) -> pd.DataFrame:
    """Stratified containment pairs limited to coastal-country containers.

    Used by chained templates so sampled anchors actually have sea-adjacent
    sub-features. A container counts as coastal when its geometry
    intersects a named Natural Earth sea or ocean. Stratification
    guarantees coverage of every supported admin-level combination
    (country->region, country->county, region->county).
    """
    print("\nComputing coastal containment pairs (stratified)...")
    df = _stratified_containment(
        con,
        countries,
        limit,
        relation_type="coastal_containment",
        extra_where=_sea_contact_clause(negate=False),
        extra_params=[NATURAL_EARTH_PATH],
    )
    print(f"Found {len(df)} coastal containment pairs")
    return df


def compute_landlocked_containment_pairs(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int,
) -> pd.DataFrame:
    """Stratified containment pairs limited to landlocked-country containers.

    Used by chained templates that need inland anchors. A container counts
    as landlocked when its geometry intersects no named Natural Earth sea
    or ocean. Stratification by subtype combo ensures county-level pairs are
    actually present in the output instead of being starved by coarse
    country->region pairs.
    """
    print("\nComputing landlocked containment pairs (stratified)...")
    df = _stratified_containment(
        con,
        countries,
        limit,
        relation_type="landlocked_containment",
        extra_where=_sea_contact_clause(negate=True),
        extra_params=[NATURAL_EARTH_PATH],
    )
    print(f"Found {len(df)} landlocked containment pairs")
    return df


def compute_common_neighbor_pairs(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int,
) -> pd.DataFrame:
    """Pairs of anchors that share at least one common touching neighbour.

    Used by multi_adj_01 (borders both X and Y) so that the generated SQL is
    guaranteed to return at least one result, rather than failing
    constantly on random pairs that have no common neighbour. The pairs are
    derived by self-joining adjacency_pairs on the shared target_id, after
    making the adjacency table undirected.

    ``countries`` is unused: the country filter was already applied when
    adjacency_pairs was built. If the adjacency file is missing, an empty
    frame with the expected columns is returned.
    """
    print("\nComputing common-neighbor pairs...")
    # NOTE(review): this path is fixed to <repo>/intermediate, so it does not
    # follow a custom output_dir passed to compute_single_relation. Confirm
    # that Modal runs write adjacency output to this same location.
    adj_path = Path(__file__).parent.parent / "intermediate" / "adjacency_pairs.parquet"
    if not adj_path.exists():
        print("  adjacency_pairs.parquet not found — skipping (run adjacency first)")
        return pd.DataFrame(columns=[
            "anchor_id_1", "anchor_name_1", "anchor_subtype_1",
            "anchor_id_2", "anchor_name_2", "anchor_subtype_2",
            "shared_neighbor_id", "shared_neighbor_name", "shared_neighbor_subtype",
        ])

    query = """
        WITH undirected AS (
            SELECT
                anchor_id, anchor_name, anchor_subtype,
                target_id, target_name, target_subtype
            FROM read_parquet(?)
            UNION ALL
            SELECT
                target_id AS anchor_id,
                target_name AS anchor_name,
                target_subtype AS anchor_subtype,
                anchor_id AS target_id,
                anchor_name AS target_name,
                anchor_subtype AS target_subtype
            FROM read_parquet(?)
        )
        SELECT DISTINCT
            a1.anchor_id AS anchor_id_1,
            a1.anchor_name AS anchor_name_1,
            a1.anchor_subtype AS anchor_subtype_1,
            a2.anchor_id AS anchor_id_2,
            a2.anchor_name AS anchor_name_2,
            a2.anchor_subtype AS anchor_subtype_2,
            a1.target_id AS shared_neighbor_id,
            a1.target_name AS shared_neighbor_name,
            a1.target_subtype AS shared_neighbor_subtype
        FROM undirected AS a1
        JOIN undirected AS a2
          ON a1.target_id = a2.target_id
         AND a1.anchor_id < a2.anchor_id
        LIMIT ?
    """
    df = con.execute(query, [str(adj_path), str(adj_path), limit]).fetchdf()
    print(f"Found {len(df)} common-neighbor pairs")
    return df


def _make_connection():
    """Create a new DuckDB connection with the spatial extension loaded.

    The memory limit and thread count come from the
    ``GAZET_DUCKDB_MEMORY_LIMIT`` (default ``"12GB"``) and
    ``GAZET_DUCKDB_THREADS`` (default ``1``) environment variables. Spill
    files go to ``/tmp/duckdb_tmp``.
    """
    con = duckdb.connect()
    con.execute("INSTALL spatial")
    con.execute("LOAD spatial")
    memory_limit = os.environ.get("GAZET_DUCKDB_MEMORY_LIMIT", "12GB")
    threads = int(os.environ.get("GAZET_DUCKDB_THREADS", "1"))
    con.execute(f"SET memory_limit='{memory_limit}'")
    con.execute("SET temp_directory='/tmp/duckdb_tmp'")
    con.execute(f"SET threads={threads}")
    return con


def _compute_and_save(compute_fn, countries, limit, output_path):
    """Compute a relation table and save it to parquet.

    Uses its own DuckDB connection, which is always closed afterwards, so
    tasks running on separate threads never share a connection. Returns the
    computed DataFrame.
    """
    con = _make_connection()
    try:
        df = compute_fn(con, countries, limit)
        df.to_parquet(output_path, index=False)
        print(f"Saved to {output_path}")
        return df
    finally:
        con.close()


RELATION_FUNCTIONS = {
    "adjacency": compute_adjacency_pairs,
    "containment": compute_containment_pairs,
    "intersection": compute_intersection_pairs,
    "cross_source": compute_cross_source_relations,
    "coastal_containment": compute_coastal_containment_pairs,
    "landlocked_containment": compute_landlocked_containment_pairs,
    "common_neighbor": compute_common_neighbor_pairs,
}

# Single source of truth for the on-disk filename for each relation.
# Both local and Modal paths must use this so the sample generator loads
# the same file regardless of where the pipeline ran.
RELATION_FILENAMES = {
    "adjacency": "adjacency_pairs.parquet",
    "containment": "containment_pairs.parquet",
    "intersection": "intersection_pairs.parquet",
    "cross_source": "cross_source_relations.parquet",
    "coastal_containment": "coastal_containment_pairs.parquet",
    "landlocked_containment": "landlocked_containment_pairs.parquet",
    "common_neighbor": "common_neighbor_pairs.parquet",
}

# Default per-relation row limits used by main(). A caller-supplied
# relation_limits dict is overlaid on these, so partial overrides are allowed.
DEFAULT_RELATION_LIMITS = {
    "adjacency": 50000,
    "containment": 3000,
    "intersection": 3000,
    "cross_source": 1800,
    "coastal_containment": 3000,
    "landlocked_containment": 1500,
    "common_neighbor": 5000,
}

# Default country sample used by main() when no countries are given.
DEFAULT_COUNTRIES = ["EC", "BE", "KE", "AE", "SG", "CH"]


def compute_single_relation(
    relation_type: str,
    countries: list,
    limit: int,
    output_dir: Path,
) -> int:
    """Compute one relation type and save it under ``output_dir``.

    The file name comes from ``RELATION_FILENAMES``. Usable from Modal or
    locally.

    Returns:
        The number of rows computed.

    Raises:
        ValueError: If ``relation_type`` is not a known relation.
    """
    compute_fn = RELATION_FUNCTIONS.get(relation_type)
    if compute_fn is None:
        raise ValueError(
            f"Unknown relation type: {relation_type}. "
            f"Expected one of {list(RELATION_FUNCTIONS)}"
        )
    output_dir.mkdir(exist_ok=True, parents=True)
    output_path = output_dir / RELATION_FILENAMES[relation_type]
    df = _compute_and_save(compute_fn, countries, limit, output_path)
    return len(df)


def main(countries: list = None, relation_limits: dict = None):
    """Compute and save all relation tables.

    Independent relations run in parallel, one thread and one DuckDB
    connection each. ``common_neighbor`` runs afterwards because it reads
    the adjacency parquet written in the first wave.

    Args:
        countries: Country codes to process, or ``["all"]`` for every
            country. Defaults to ``DEFAULT_COUNTRIES``.
        relation_limits: Optional per-relation row limits keyed by relation
            name (the keys of ``RELATION_FILENAMES``). Missing keys fall back
            to ``DEFAULT_RELATION_LIMITS``.

    Raises:
        Exception: Re-raises the first failure observed from any relation.
    """
    if countries is None:
        countries = list(DEFAULT_COUNTRIES)
    # Overlay caller limits on the defaults so a partial dict no longer
    # raises KeyError for the relations it omits.
    limits = {**DEFAULT_RELATION_LIMITS, **(relation_limits or {})}

    output_dir = Path(__file__).parent.parent / "intermediate"
    output_dir.mkdir(exist_ok=True, parents=True)

    # Define all relation tasks. Filenames come from RELATION_FILENAMES so
    # local and Modal pipelines produce identically-named parquet files.
    # common_neighbor depends on adjacency_pairs so it runs after adjacency.
    tasks = [
        (
            rel_type,
            RELATION_FUNCTIONS[rel_type],
            limits[rel_type],
            output_dir / RELATION_FILENAMES[rel_type],
        )
        for rel_type in (
            "adjacency",
            "containment",
            "intersection",
            "cross_source",
            "coastal_containment",
            "landlocked_containment",
            "common_neighbor",
        )
    ]

    # common_neighbor reads adjacency_pairs.parquet so it must run after
    # adjacency finishes. Split into two waves.
    independent_tasks = [t for t in tasks if t[0] != "common_neighbor"]
    dependent_tasks = [t for t in tasks if t[0] == "common_neighbor"]

    print(f"Computing {len(independent_tasks)} relation types in parallel...")
    with ThreadPoolExecutor(max_workers=len(independent_tasks)) as executor:
        futures = {
            executor.submit(_compute_and_save, compute_fn, countries, limit, path): name
            for name, compute_fn, limit, path in independent_tasks
        }
        for future in as_completed(futures):
            name = futures[future]
            try:
                future.result()
            except Exception as e:
                print(f"ERROR computing {name}: {e}")
                raise

    for name, compute_fn, limit, path in dependent_tasks:
        print(f"\nComputing {name} (depends on adjacency)...")
        try:
            _compute_and_save(compute_fn, countries, limit, path)
        except Exception as e:
            print(f"ERROR computing {name}: {e}")
            raise

    print("\nRelation tables build complete")


if __name__ == "__main__":
    main()