| """ | |
| Precompute spatial relation tables for efficient anchor sampling. | |
| This script computes: | |
| - Adjacency pairs (touching features) | |
| - Containment pairs (features within other features) | |
| - Intersection pairs (overlapping features) | |
| - Cross-source relations (divisions_area ↔ natural_earth) | |
| Output: | |
| - intermediate/adjacency_pairs.parquet | |
| - intermediate/containment_pairs.parquet | |
| - intermediate/intersection_pairs.parquet | |
| - intermediate/cross_source_relations.parquet | |
| """ | |
| import os | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from pathlib import Path | |
| import duckdb | |
| import pandas as pd | |
| from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH | |
# (container_subtype, contained_subtype) combos used by the chained,
# containment, and disambiguation templates. The normalized Overture input is
# now restricted to country / region / county, so keep relation building in
# sync and avoid wasting work on removed subtypes.
_CONTAINMENT_SUBTYPE_PAIRS: tuple[tuple[str, str], ...] = (
    ("country", "region"),
    ("country", "county"),
    ("region", "county"),
)

# Natural Earth subtype vocabulary normalized to lowercase.
# We lowercase the source subtype values while building relation tables so
# mixed casing in upstream data (e.g. Lake vs lake, Range/mtn vs range/mtn)
# does not fragment anchor pools or break template matching.
_NE_CROSS_SOURCE_SUBTYPES: tuple[str, ...] = (
    "sea",
    "ocean",
    "lake",
    "river",
    "basin",
    "gulf",
    "bay",
    "strait",
    "island group",
    "peninsula",
    "range/mtn",
    "plateau",
    "plain",
    "lowland",
    "valley",
    "depression",
    "gorge",
)

# Admin subtypes paired against Natural Earth features in
# compute_cross_source_relations; mirrors _CONTAINMENT_SUBTYPE_PAIRS' vocabulary.
_DIVISION_SUBTYPES: tuple[str, ...] = ("country", "region", "county")
| def _country_filter(countries: list) -> tuple[str, list]: | |
| """Return (SQL WHERE clause, params) handling 'all' sentinel.""" | |
| if countries == ["all"]: | |
| return "", [] | |
| return "WHERE country IN (SELECT unnest(?))", [countries] | |
def _country_filter_for_join(countries: list) -> tuple[str, list]:
    """Return a country filter for self-joins over normalized admin data.

    The clause is currently identical to the plain anchor filter, so this
    delegates to _country_filter instead of duplicating the SQL fragment and
    the ["all"] sentinel handling; the two call sites can no longer drift
    apart. Kept as a separate name so join-specific filtering can diverge
    later without touching callers.
    """
    return _country_filter(countries)
| def _country_chunks( | |
| con: duckdb.DuckDBPyConnection, | |
| countries: list, | |
| chunk_size: int = 40, | |
| ) -> list[list[str]]: | |
| """Return explicit country batches for safer global containment joins.""" | |
| if countries != ["all"]: | |
| return [countries] | |
| rows = con.execute( | |
| """ | |
| SELECT DISTINCT country | |
| FROM read_parquet(?) | |
| WHERE country IS NOT NULL | |
| AND trim(country) != '' | |
| ORDER BY country | |
| """, | |
| [DIVISIONS_AREA_PATH], | |
| ).fetchall() | |
| codes = [row[0] for row in rows] | |
| return [codes[i:i + chunk_size] for i in range(0, len(codes), chunk_size)] | |
def compute_adjacency_pairs(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int
) -> pd.DataFrame:
    """Find all pairs of features that touch (share a boundary).

    Args:
        con: DuckDB connection with the spatial extension loaded.
        countries: Country codes to include, or ["all"] for no filter.
        limit: Maximum number of pairs returned by the join.

    Returns:
        DataFrame of anchor_*/target_* columns with a constant
        relation_type of 'adjacency'. The a.id < b.id join condition keeps
        each unordered pair exactly once.
    """
    print("Computing adjacency pairs (optimized with spatial index)...")
    cfilter, cparams = _country_filter_for_join(countries)
    # Use bounding box pre-filter to avoid full cartesian product
    query = f"""
    WITH features AS (
        SELECT
            id,
            names."primary" AS name,
            subtype,
            country,
            admin_level,
            geometry,
            ST_Envelope(geometry) AS bbox
        FROM read_parquet(?)
        {cfilter}
    )
    SELECT
        a.id AS anchor_id,
        a.name AS anchor_name,
        a.subtype AS anchor_subtype,
        a.country AS anchor_country,
        b.id AS target_id,
        b.name AS target_name,
        b.subtype AS target_subtype,
        b.country AS target_country,
        'adjacency' AS relation_type
    FROM features AS a
    JOIN features AS b ON (
        a.id < b.id
        AND ST_Intersects(a.bbox, b.bbox)
        AND ST_Touches(a.geometry, b.geometry)
    )
    LIMIT ?
    """
    # Placeholder order: read_parquet path, optional country list, LIMIT.
    df = con.execute(query, [DIVISIONS_AREA_PATH] + cparams + [limit]).fetchdf()
    print(f"Found {len(df)} adjacency pairs")
    return df
def _stratified_containment(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int,
    relation_type: str,
    extra_where: str = "",
    extra_params: "list | None" = None,
) -> pd.DataFrame:
    """Compute containment pairs stratified by (container_subtype, contained_subtype).

    A single global self-join with LIMIT fills up with coarse country->region
    pairs before emitting country->county and region->county pairs. We run one
    focused query per subtype combo instead so every combo receives a fair
    share of the overall limit.

    ``extra_where`` / ``extra_params`` let the coastal and landlocked variants
    inject their country-set filter without duplicating the whole body.

    Args:
        con: DuckDB connection with the spatial extension loaded.
        countries: Country codes to include, or ["all"] for a batched global run.
        limit: Overall pair budget, split evenly across subtype combos.
        relation_type: Tag written verbatim into the relation_type column.
        extra_where: Extra SQL appended to the container CTE's WHERE clause
            (may reference the container alias ``src``).
        extra_params: Bind parameters for any '?' placeholders in extra_where.

    Returns:
        Concatenated DataFrame of container/contained pairs over all combos.
    """
    extra_params = extra_params or []
    # Use a lower target per subtype combo for global runs; they are the most
    # memory-intensive part of the pipeline and don't need huge anchor tables.
    if countries == ["all"]:
        per_combo = min(max(limit // len(_CONTAINMENT_SUBTYPE_PAIRS), 100), 1500)
    else:
        per_combo = max(limit // len(_CONTAINMENT_SUBTYPE_PAIRS), 100)
    country_batches = _country_chunks(con, countries)
    frames: list[pd.DataFrame] = []
    for container_st, contained_st in _CONTAINMENT_SUBTYPE_PAIRS:
        remaining = per_combo
        combo_parts: list[pd.DataFrame] = []
        for batch in country_batches:
            if remaining <= 0:
                break
            cfilter, cparams = _country_filter(batch)
            query = f"""
            WITH a AS (
                SELECT src.id, src.names."primary" AS name, src.subtype, src.country, src.admin_level,
                src.geometry, ST_Envelope(src.geometry) AS bbox
                FROM read_parquet(?) AS src
                WHERE src.subtype = '{container_st}'
                {cfilter.replace("WHERE", "AND") if cfilter else ""}
                {extra_where}
            ),
            b AS (
                SELECT dst.id, dst.names."primary" AS name, dst.subtype, dst.country, dst.admin_level,
                dst.geometry, ST_Envelope(dst.geometry) AS bbox
                FROM read_parquet(?) AS dst
                WHERE dst.subtype = '{contained_st}'
                {cfilter.replace("WHERE", "AND") if cfilter else ""}
            )
            SELECT
                a.id AS container_id,
                a.name AS container_name,
                a.subtype AS container_subtype,
                b.id AS contained_id,
                b.name AS contained_name,
                b.subtype AS contained_subtype,
                a.country AS container_country,
                '{relation_type}' AS relation_type
            FROM a JOIN b ON (
                a.id != b.id
                AND ST_Intersects(a.bbox, b.bbox)
                AND ST_Within(b.geometry, a.geometry)
            )
            LIMIT {remaining}
            """
            # Bind parameters must follow placeholder order in the query:
            # CTE a's read_parquet, then its country-filter '?', then any
            # placeholders inside extra_where, then CTE b's read_parquet and
            # country filter. The previous ordering put extra_params before
            # cparams, which would mis-bind if extra_where ever carried '?'
            # placeholders (today extra_params is always empty, so behavior
            # for existing callers is unchanged).
            params = [DIVISIONS_AREA_PATH] + cparams + extra_params + [DIVISIONS_AREA_PATH] + cparams
            df_part = con.execute(query, params).fetchdf()
            if not df_part.empty:
                combo_parts.append(df_part)
                remaining -= len(df_part)
        df_combo = (
            pd.concat(combo_parts, ignore_index=True)
            if combo_parts else pd.DataFrame()
        )
        print(f" {relation_type} {container_st:>10s} -> {contained_st:<10s}: {len(df_combo)} pairs")
        frames.append(df_combo)
    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
def compute_containment_pairs(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int
) -> pd.DataFrame:
    """Containment pairs over every admin-level combo, stratified for fairness."""
    print("\nComputing containment pairs (stratified by subtype combo)...")
    pairs = _stratified_containment(
        con,
        countries,
        limit,
        relation_type="containment",
    )
    print(f"Found {len(pairs)} containment pairs")
    return pairs
def compute_intersection_pairs(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int
) -> pd.DataFrame:
    """Find pairs that intersect but don't touch or contain.

    Args:
        con: DuckDB connection with the spatial extension loaded.
        countries: Country codes to include, or ["all"] for no filter.
        limit: Maximum number of pairs returned by the join.

    Returns:
        DataFrame of anchor/target pairs tagged relation_type 'intersection';
        boundary-only contacts and full containment are excluded by the
        NOT ST_Touches / NOT ST_Within predicates.
    """
    print("\nComputing intersection pairs (optimized)...")
    cfilter, cparams = _country_filter_for_join(countries)
    # Cheap bbox ST_Intersects first prunes candidates before the exact
    # geometry predicates run; a.id < b.id keeps each unordered pair once.
    query = f"""
    WITH features AS (
        SELECT
            id,
            names."primary" AS name,
            subtype,
            country,
            admin_level,
            geometry,
            ST_Envelope(geometry) AS bbox
        FROM read_parquet(?)
        {cfilter}
    )
    SELECT
        a.id AS anchor_id,
        a.name AS anchor_name,
        a.subtype AS anchor_subtype,
        b.id AS target_id,
        b.name AS target_name,
        b.subtype AS target_subtype,
        'intersection' AS relation_type
    FROM features AS a
    JOIN features AS b ON (
        a.id < b.id
        AND ST_Intersects(a.bbox, b.bbox)
        AND ST_Intersects(a.geometry, b.geometry)
        AND NOT ST_Touches(a.geometry, b.geometry)
        AND NOT ST_Within(a.geometry, b.geometry)
        AND NOT ST_Within(b.geometry, a.geometry)
    )
    LIMIT ?
    """
    # Placeholder order: read_parquet path, optional country list, LIMIT.
    df = con.execute(query, [DIVISIONS_AREA_PATH] + cparams + [limit]).fetchdf()
    print(f"Found {len(df)} same-source intersection pairs")
    return df
def compute_cross_source_relations(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int
) -> pd.DataFrame:
    """Find relations between divisions_area and natural_earth.

    The join is skewed both by abundant Natural Earth subtypes and by coarse
    admin features. We therefore stratify by (natural_subtype,
    division_subtype) so country / region / county anchors all make it into
    the pool used by mixed-source, NE-intersection, and NE-adjacency templates.

    Args:
        con: DuckDB connection with the spatial extension loaded.
        countries: Country codes restricting the admin side, or ["all"].
        limit: Overall row budget, split evenly across subtype combos.

    Returns:
        DataFrame pairing each admin division with a Natural Earth feature,
        tagged 'touches', 'within', 'contains', or 'intersects'.
    """
    print("\nComputing cross-source relations (stratified by NE subtype and admin subtype)...")
    cfilter, cparams = _country_filter(countries)
    # Even split of the budget across every combo; the floor of 10 keeps
    # rare combos from rounding down to zero rows.
    num_combos = len(_NE_CROSS_SOURCE_SUBTYPES) * len(_DIVISION_SUBTYPES)
    per_combo = max(limit // num_combos, 10)
    frames: list[pd.DataFrame] = []
    for natural_subtype in _NE_CROSS_SOURCE_SUBTYPES:
        for division_subtype in _DIVISION_SUBTYPES:
            # relation_type is never NULL: the join already requires
            # ST_Intersects, which is also the CASE's final WHEN arm.
            query = f"""
            WITH divisions AS (
                SELECT
                    id,
                    names."primary" AS name,
                    subtype,
                    country,
                    geometry
                FROM read_parquet(?)
                WHERE geometry IS NOT NULL
                AND names."primary" IS NOT NULL
                AND trim(names."primary") != ''
                AND subtype = '{division_subtype}'
                {cfilter.replace("WHERE", "AND") if cfilter else ''}
            ),
            natural_features AS (
                SELECT
                    id,
                    names."primary" AS name,
                    lower(subtype) AS natural_subtype,
                    geometry
                FROM read_parquet(?)
                WHERE geometry IS NOT NULL
                AND names."primary" IS NOT NULL
                AND trim(names."primary") != ''
                AND lower(subtype) = '{natural_subtype}'
            )
            SELECT
                d.id AS division_id,
                d.name AS division_name,
                d.subtype AS division_subtype,
                d.country AS division_country,
                n.id AS natural_id,
                n.name AS natural_name,
                n.natural_subtype AS natural_subtype,
                CASE
                    WHEN ST_Touches(d.geometry, n.geometry) THEN 'touches'
                    WHEN ST_Within(d.geometry, n.geometry) THEN 'within'
                    WHEN ST_Contains(d.geometry, n.geometry) THEN 'contains'
                    WHEN ST_Intersects(d.geometry, n.geometry) THEN 'intersects'
                END AS relation_type
            FROM divisions AS d
            JOIN natural_features AS n
            ON ST_Intersects(d.geometry, n.geometry)
            LIMIT {per_combo}
            """
            # Placeholder order: divisions path, optional country list, NE path.
            df_part = con.execute(
                query,
                [DIVISIONS_AREA_PATH] + cparams + [NATURAL_EARTH_PATH],
            ).fetchdf()
            print(
                f" cross_source {natural_subtype:>12s} x {division_subtype:<7s}: "
                f"{len(df_part)} rows"
            )
            frames.append(df_part)
    df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
    print(f"Found {len(df)} cross-source relations")
    return df
def compute_coastal_containment_pairs(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int,
) -> pd.DataFrame:
    """Stratified containment pairs limited to coastal-country containers.

    Used by chained templates so sampled anchors actually have sea-adjacent
    sub-features. Stratification guarantees coverage of every supported
    admin-level combination (country->region, country->county, region->county).

    Args:
        con: DuckDB connection with the spatial extension loaded.
        countries: Country codes to include, or ["all"].
        limit: Overall pair budget forwarded to _stratified_containment.

    Returns:
        DataFrame of pairs tagged relation_type 'coastal_containment'.
    """
    print("\nComputing coastal containment pairs (stratified)...")
    # EXISTS keeps only containers whose geometry intersects a named
    # sea/ocean polygon; `src` is the container alias inside the helper's CTE.
    # NOTE(review): cross-source matching lowercases NE subtypes, but this
    # clause compares n.subtype directly — confirm sea/ocean rows are already
    # lowercase upstream.
    extra_where = f"""
    AND EXISTS (
        SELECT 1
        FROM read_parquet('{NATURAL_EARTH_PATH}') AS n
        WHERE n.geometry IS NOT NULL
        AND n.names."primary" IS NOT NULL
        AND trim(n.names."primary") != ''
        AND n.subtype IN ('sea', 'ocean')
        AND ST_Intersects(src.geometry, n.geometry)
    )
    """
    df = _stratified_containment(
        con, countries, limit,
        relation_type="coastal_containment",
        extra_where=extra_where,
    )
    print(f"Found {len(df)} coastal containment pairs")
    return df
def compute_landlocked_containment_pairs(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int,
) -> pd.DataFrame:
    """Stratified containment pairs limited to landlocked-country containers.

    Used by chained templates that need inland anchors. Stratification by
    subtype combo ensures county-level pairs are actually present in the
    output instead of being starved by coarse country->region pairs.

    Args:
        con: DuckDB connection with the spatial extension loaded.
        countries: Country codes to include, or ["all"].
        limit: Overall pair budget forwarded to _stratified_containment.

    Returns:
        DataFrame of pairs tagged relation_type 'landlocked_containment'.
    """
    print("\nComputing landlocked containment pairs (stratified)...")
    # Exact negation of the coastal variant: drop containers intersecting any
    # named sea/ocean polygon. NOTE(review): like the coastal clause, this
    # compares n.subtype without lower() — confirm NE sea/ocean rows are
    # already lowercase upstream.
    extra_where = f"""
    AND NOT EXISTS (
        SELECT 1
        FROM read_parquet('{NATURAL_EARTH_PATH}') AS n
        WHERE n.geometry IS NOT NULL
        AND n.names."primary" IS NOT NULL
        AND trim(n.names."primary") != ''
        AND n.subtype IN ('sea', 'ocean')
        AND ST_Intersects(src.geometry, n.geometry)
    )
    """
    df = _stratified_containment(
        con, countries, limit,
        relation_type="landlocked_containment",
        extra_where=extra_where,
    )
    print(f"Found {len(df)} landlocked containment pairs")
    return df
def compute_common_neighbor_pairs(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int,
) -> pd.DataFrame:
    """Pairs of anchors that share at least one common touching neighbour.

    Used by multi_adj_01 (borders both X and Y) so that the generated SQL
    is guaranteed to return at least one result rather than failing constantly
    on random pairs that have no common neighbour.

    Derived by self-joining adjacency_pairs on the shared target_id.

    Args:
        con: DuckDB connection (no spatial functions required here).
        countries: Unused; kept so the signature matches the other
            RELATION_FUNCTIONS entries invoked by _compute_and_save.
        limit: Maximum number of distinct (pair, shared neighbor) rows.

    Returns:
        DataFrame of anchor pairs plus the shared neighbour, or an empty
        frame with the expected columns when adjacency output is missing.
    """
    print("\nComputing common-neighbor pairs...")
    adj_path = Path(__file__).parent.parent / "intermediate" / "adjacency_pairs.parquet"
    # Depends on the adjacency stage's parquet; degrade gracefully instead of
    # crashing when it has not been produced yet.
    if not adj_path.exists():
        print(" adjacency_pairs.parquet not found — skipping (run adjacency first)")
        return pd.DataFrame(columns=[
            "anchor_id_1", "anchor_name_1", "anchor_subtype_1",
            "anchor_id_2", "anchor_name_2", "anchor_subtype_2",
            "shared_neighbor_id", "shared_neighbor_name", "shared_neighbor_subtype",
        ])
    # Adjacency stores each pair once (anchor < target), so the CTE mirrors
    # every edge before the self-join; a1.anchor_id < a2.anchor_id then keeps
    # each unordered result pair once per shared neighbour, and DISTINCT
    # collapses duplicates from the doubled edge list.
    query = """
    WITH undirected AS (
        SELECT
            anchor_id,
            anchor_name,
            anchor_subtype,
            target_id,
            target_name,
            target_subtype
        FROM read_parquet(?)
        UNION ALL
        SELECT
            target_id AS anchor_id,
            target_name AS anchor_name,
            target_subtype AS anchor_subtype,
            anchor_id AS target_id,
            anchor_name AS target_name,
            anchor_subtype AS target_subtype
        FROM read_parquet(?)
    )
    SELECT DISTINCT
        a1.anchor_id AS anchor_id_1,
        a1.anchor_name AS anchor_name_1,
        a1.anchor_subtype AS anchor_subtype_1,
        a2.anchor_id AS anchor_id_2,
        a2.anchor_name AS anchor_name_2,
        a2.anchor_subtype AS anchor_subtype_2,
        a1.target_id AS shared_neighbor_id,
        a1.target_name AS shared_neighbor_name,
        a1.target_subtype AS shared_neighbor_subtype
    FROM undirected AS a1
    JOIN undirected AS a2
    ON a1.target_id = a2.target_id
    AND a1.anchor_id < a2.anchor_id
    LIMIT ?
    """
    df = con.execute(query, [str(adj_path), str(adj_path), limit]).fetchdf()
    print(f"Found {len(df)} common-neighbor pairs")
    return df
def _make_connection():
    """Open a fresh DuckDB connection configured for the relation pipeline.

    Installs/loads the spatial extension and applies memory, temp-dir, and
    thread settings; memory and thread counts are tunable through the
    GAZET_DUCKDB_MEMORY_LIMIT and GAZET_DUCKDB_THREADS env vars.
    """
    connection = duckdb.connect()
    memory_limit = os.environ.get("GAZET_DUCKDB_MEMORY_LIMIT", "12GB")
    thread_count = int(os.environ.get("GAZET_DUCKDB_THREADS", "1"))
    setup_statements = (
        "INSTALL spatial",
        "LOAD spatial",
        f"SET memory_limit='{memory_limit}'",
        "SET temp_directory='/tmp/duckdb_tmp'",
        f"SET threads={thread_count}",
    )
    for statement in setup_statements:
        connection.execute(statement)
    return connection
def _compute_and_save(compute_fn, countries, limit, output_path):
    """Run one relation computation on a private DuckDB connection and persist it.

    The connection is always closed, even when the computation or the parquet
    write raises; the computed DataFrame is returned to the caller.
    """
    connection = _make_connection()
    try:
        frame = compute_fn(connection, countries, limit)
        frame.to_parquet(output_path, index=False)
        print(f"Saved to {output_path}")
    finally:
        connection.close()
    return frame
# Dispatch table mapping relation name -> compute function. Every function
# shares the (con, countries, limit) -> DataFrame shape that
# _compute_and_save invokes.
RELATION_FUNCTIONS = {
    "adjacency": compute_adjacency_pairs,
    "containment": compute_containment_pairs,
    "intersection": compute_intersection_pairs,
    "cross_source": compute_cross_source_relations,
    "coastal_containment": compute_coastal_containment_pairs,
    "landlocked_containment": compute_landlocked_containment_pairs,
    "common_neighbor": compute_common_neighbor_pairs,
}

# Single source of truth for the on-disk filename for each relation.
# Both local and Modal paths must use this so the sample generator loads
# the same file regardless of where the pipeline ran.
RELATION_FILENAMES: dict[str, str] = {
    "adjacency": "adjacency_pairs.parquet",
    "containment": "containment_pairs.parquet",
    "intersection": "intersection_pairs.parquet",
    "cross_source": "cross_source_relations.parquet",
    "coastal_containment": "coastal_containment_pairs.parquet",
    "landlocked_containment": "landlocked_containment_pairs.parquet",
    "common_neighbor": "common_neighbor_pairs.parquet",
}
def compute_single_relation(
    relation_type: str,
    countries: list,
    limit: int,
    output_dir: Path,
) -> int:
    """Build one relation table, write its parquet file, and return the row count.

    Usable from Modal or locally; an unrecognized relation name raises
    ValueError before any work is done.
    """
    if relation_type not in RELATION_FUNCTIONS:
        raise ValueError(
            f"Unknown relation type: {relation_type}. "
            f"Expected one of {list(RELATION_FUNCTIONS)}"
        )
    compute_fn = RELATION_FUNCTIONS[relation_type]
    output_dir.mkdir(exist_ok=True, parents=True)
    destination = output_dir / RELATION_FILENAMES[relation_type]
    result = _compute_and_save(compute_fn, countries, limit, destination)
    return len(result)
def main(countries: "list | None" = None, relation_limits: "dict | None" = None):
    """Compute and save all relation tables in parallel.

    Args:
        countries: List of country codes to process; defaults to a small
            six-country sample. ["all"] runs globally.
        relation_limits: Dict with keys: adjacency, containment, intersection,
            cross_source, coastal_containment, landlocked_containment,
            common_neighbor. Missing dict falls back to built-in defaults.
    """
    # Defaults
    if countries is None:
        countries = ['EC', 'BE', 'KE', 'AE', 'SG', 'CH']
    if relation_limits is None:
        relation_limits = {
            'adjacency': 50000,
            'containment': 3000,
            'intersection': 3000,
            'cross_source': 1800,
            'coastal_containment': 3000,
            'landlocked_containment': 1500,
            'common_neighbor': 5000,
        }
    output_dir = Path(__file__).parent.parent / "intermediate"
    output_dir.mkdir(exist_ok=True, parents=True)
    # Define all relation tasks. Filenames come from RELATION_FILENAMES so
    # local and Modal pipelines produce identically-named parquet files.
    # common_neighbor depends on adjacency_pairs so it runs after adjacency.
    tasks = [
        (rel_type, RELATION_FUNCTIONS[rel_type], relation_limits[rel_type], output_dir / RELATION_FILENAMES[rel_type])
        for rel_type in (
            "adjacency", "containment", "intersection", "cross_source",
            "coastal_containment", "landlocked_containment", "common_neighbor",
        )
    ]
    # common_neighbor reads adjacency_pairs.parquet so it must run after
    # adjacency finishes. Split into two waves.
    independent_tasks = [t for t in tasks if t[0] != "common_neighbor"]
    dependent_tasks = [t for t in tasks if t[0] == "common_neighbor"]
    print(f"Computing {len(independent_tasks)} relation types in parallel...")
    # Each worker calls _compute_and_save, which opens its own DuckDB
    # connection, so threads never share connection state.
    with ThreadPoolExecutor(max_workers=len(independent_tasks)) as executor:
        futures = {
            executor.submit(_compute_and_save, compute_fn, countries, limit, path): name
            for name, compute_fn, limit, path in independent_tasks
        }
        for future in as_completed(futures):
            name = futures[future]
            try:
                future.result()
            except Exception as e:
                # Report which table failed, then abort the whole build.
                print(f"ERROR computing {name}: {e}")
                raise
    # Second wave: tasks that consume the first wave's parquet output.
    for name, compute_fn, limit, path in dependent_tasks:
        print(f"\nComputing {name} (depends on adjacency)...")
        try:
            _compute_and_save(compute_fn, countries, limit, path)
        except Exception as e:
            print(f"ERROR computing {name}: {e}")
            raise
    print("\nRelation tables build complete")
if __name__ == "__main__":
    # Script entry point: run the full build with default countries/limits.
    main()