Spaces:
Running
Running
Fix: No pairs are created for mixed queries
Browse files- Normalize overture & natural earth to EPSG:4326
- Add better adjancency matrix to increase the overlap in query pairs
- Fix titled & lower case subtype names in NE
- Reduce number of threads in duckdb to prevent memory issues
- dataset/README.md +40 -1
- dataset/config.smalltest.yaml +40 -0
- dataset/config.yaml +16 -16
- dataset/modal_app.py +47 -11
- dataset/scripts/build_inventory.py +6 -4
- dataset/scripts/build_relations.py +274 -201
- dataset/scripts/cli.py +20 -2
- dataset/scripts/export_training_data.py +1 -1
- dataset/scripts/generate_samples.py +320 -132
- dataset/scripts/normalize_geodata.py +87 -0
- dataset/scripts/sql_templates.py +18 -18
- dataset/scripts/validate_dataset.py +2 -2
- finetune/README.md +29 -11
- finetune/eval_demo.py +12 -6
- finetune/train_modal_qwen35.py +1 -1
- gazet_demo.py +81 -10
- ingest/convert_natural_earth.py +1 -1
- src/gazet/config.py +27 -2
- src/gazet/sql.py +25 -0
dataset/README.md
CHANGED
|
@@ -20,6 +20,20 @@ uv sync
|
|
| 20 |
You need the Overture and Natural Earth parquet files under `data/` locally,
|
| 21 |
or on a Modal volume if running in the cloud.
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
---
|
| 24 |
|
| 25 |
## Option A — Run locally (small datasets, development)
|
|
@@ -72,6 +86,14 @@ Modal uses two volumes:
|
|
| 72 |
|
| 73 |
**Step 1 — One-time setup (only first time, or when source parquets change)**
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
```bash
|
| 76 |
modal setup # authenticate
|
| 77 |
gazet-dataset modal-upload --config dataset/config.yaml # ~15 min, uploads data/ to gazet-data volume
|
|
@@ -81,7 +103,7 @@ Verify:
|
|
| 81 |
|
| 82 |
```bash
|
| 83 |
modal volume ls gazet-data
|
| 84 |
-
# should show: overture/,
|
| 85 |
```
|
| 86 |
|
| 87 |
Skip this step on subsequent runs — the volume persists across runs.
|
|
@@ -207,6 +229,23 @@ by default; pass `--fresh` to overwrite existing samples.
|
|
| 207 |
|
| 208 |
---
|
| 209 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
## Troubleshooting
|
| 211 |
|
| 212 |
**Very few samples generated for a family**
|
|
|
|
| 20 |
You need the Overture and Natural Earth parquet files under `data/` locally,
|
| 21 |
or on a Modal volume if running in the cloud.
|
| 22 |
|
| 23 |
+
Before large runs, normalize them once so both datasets use harmonized
|
| 24 |
+
geometry metadata and cross-source joins behave the same locally and on Modal:
|
| 25 |
+
|
| 26 |
+
```bash
|
| 27 |
+
gazet-dataset normalize-data --config dataset/config.yaml
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
This writes:
|
| 31 |
+
|
| 32 |
+
- `data/overture_normalized/divisions_area/part-000.parquet`
|
| 33 |
+
- `data/natural_earth_normalized/ne_geography.parquet`
|
| 34 |
+
|
| 35 |
+
When those files exist, `gazet.config` will prefer them automatically.
|
| 36 |
+
|
| 37 |
---
|
| 38 |
|
| 39 |
## Option A — Run locally (small datasets, development)
|
|
|
|
| 86 |
|
| 87 |
**Step 1 — One-time setup (only first time, or when source parquets change)**
|
| 88 |
|
| 89 |
+
First normalize the source geodata locally:
|
| 90 |
+
|
| 91 |
+
```bash
|
| 92 |
+
gazet-dataset normalize-data --config dataset/config.yaml
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
Then upload `data/` to Modal:
|
| 96 |
+
|
| 97 |
```bash
|
| 98 |
modal setup # authenticate
|
| 99 |
gazet-dataset modal-upload --config dataset/config.yaml # ~15 min, uploads data/ to gazet-data volume
|
|
|
|
| 103 |
|
| 104 |
```bash
|
| 105 |
modal volume ls gazet-data
|
| 106 |
+
# should show: overture/, overture_normalized/, natural_earth_geoparquet/, natural_earth_normalized/
|
| 107 |
```
|
| 108 |
|
| 109 |
Skip this step on subsequent runs — the volume persists across runs.
|
|
|
|
| 229 |
|
| 230 |
---
|
| 231 |
|
| 232 |
+
## Data quality checks
|
| 233 |
+
|
| 234 |
+
After a run, spot-check the output with the pytest suite:
|
| 235 |
+
|
| 236 |
+
```bash
|
| 237 |
+
uv run --extra dev pytest dataset/tests/ -v
|
| 238 |
+
```
|
| 239 |
+
|
| 240 |
+
The suite reads `dataset/output/dataset_validated.jsonl` plus the exported
|
| 241 |
+
`runs/{run_name}/*.jsonl` and verifies: schema, no unresolved `{placeholders}`
|
| 242 |
+
in questions, candidate refs resolve, SQL shape, template coverage,
|
| 243 |
+
subtype-filtered templates match their phrasing, disambiguation samples have
|
| 244 |
+
same-name distractors, and exported assistant payloads parse as valid
|
| 245 |
+
JSON / SQL. Tests skip gracefully when outputs are missing.
|
| 246 |
+
|
| 247 |
+
---
|
| 248 |
+
|
| 249 |
## Troubleshooting
|
| 250 |
|
| 251 |
**Very few samples generated for a family**
|
dataset/config.smalltest.yaml
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
countries:
|
| 2 |
+
- BE
|
| 3 |
+
- CH
|
| 4 |
+
- AE
|
| 5 |
+
|
| 6 |
+
sample_targets:
|
| 7 |
+
direct_lookup: 4
|
| 8 |
+
disambiguation: 6
|
| 9 |
+
adjacency: 12
|
| 10 |
+
multi_adjacency: 2
|
| 11 |
+
containment: 8
|
| 12 |
+
intersection: 8
|
| 13 |
+
buffer: 10
|
| 14 |
+
chained: 22
|
| 15 |
+
difference: 4
|
| 16 |
+
border_corridor: 2
|
| 17 |
+
set_operations: 10
|
| 18 |
+
partial_selection: 10
|
| 19 |
+
aggregation: 8
|
| 20 |
+
window_function: 4
|
| 21 |
+
attribute_filter: 6
|
| 22 |
+
|
| 23 |
+
generation:
|
| 24 |
+
max_workers: 4
|
| 25 |
+
retry_multiplier: 2
|
| 26 |
+
append_mode: false
|
| 27 |
+
|
| 28 |
+
auto_scaling:
|
| 29 |
+
safety_factor: 1.25
|
| 30 |
+
manual_limits: {}
|
| 31 |
+
|
| 32 |
+
modal:
|
| 33 |
+
volume_name: "gazet-data"
|
| 34 |
+
app_name: "gazet-dataset"
|
| 35 |
+
num_containers: 10
|
| 36 |
+
container_cpu: 2
|
| 37 |
+
container_memory: 4096
|
| 38 |
+
timeout: 7200
|
| 39 |
+
|
| 40 |
+
run_name: "smalltest-v1"
|
dataset/config.yaml
CHANGED
|
@@ -17,21 +17,21 @@ countries:
|
|
| 17 |
# Bumped families with many templates or mixed-source variants so each
|
| 18 |
# template_id gets enough coverage after uniform sampling + stratified split.
|
| 19 |
sample_targets:
|
| 20 |
-
direct_lookup:
|
| 21 |
-
disambiguation:
|
| 22 |
-
adjacency:
|
| 23 |
-
multi_adjacency:
|
| 24 |
-
containment:
|
| 25 |
-
intersection:
|
| 26 |
-
buffer:
|
| 27 |
-
chained:
|
| 28 |
-
difference:
|
| 29 |
-
border_corridor:
|
| 30 |
-
set_operations:
|
| 31 |
-
partial_selection:
|
| 32 |
-
aggregation:
|
| 33 |
-
window_function:
|
| 34 |
-
attribute_filter:
|
| 35 |
|
| 36 |
# Generation settings
|
| 37 |
generation:
|
|
@@ -58,7 +58,7 @@ modal:
|
|
| 58 |
num_containers: 100 # Number of parallel containers for sample generation
|
| 59 |
container_cpu: 2 # CPUs per container
|
| 60 |
container_memory: 4096 # Memory (MB) per container
|
| 61 |
-
timeout:
|
| 62 |
|
| 63 |
# Run name — used to version exported splits so re-runs never overwrite previous data.
|
| 64 |
# Change this whenever you regenerate from scratch (e.g. after template changes).
|
|
|
|
| 17 |
# Bumped families with many templates or mixed-source variants so each
|
| 18 |
# template_id gets enough coverage after uniform sampling + stratified split.
|
| 19 |
sample_targets:
|
| 20 |
+
direct_lookup: 1000
|
| 21 |
+
disambiguation: 2000 # 3 templates (disambiguate_01..03) - "Puri, Odisha" pattern
|
| 22 |
+
adjacency: 2000 # 6 templates (adj_01..06) - adj_06 is counties
|
| 23 |
+
multi_adjacency: 1000
|
| 24 |
+
containment: 2000 # 4 templates (contain_01..04) - contain_02 reversed, contain_03/04 NE anchor
|
| 25 |
+
intersection: 2000 # 4 templates (intersect_01..04) - intersect_02/03 NE anchor
|
| 26 |
+
buffer: 2000 # 5 templates (buffer_01..05)
|
| 27 |
+
chained: 2000 # 11 templates (chained_01..11) - 10/11 are coastal/inland regions
|
| 28 |
+
difference: 2000 # 2 templates, one is mixed (diff_02)
|
| 29 |
+
border_corridor: 1000
|
| 30 |
+
set_operations: 2000
|
| 31 |
+
partial_selection: 2000 # 5 templates, one is mixed (partial_05)
|
| 32 |
+
aggregation: 1500
|
| 33 |
+
window_function: 1000
|
| 34 |
+
attribute_filter: 1000 # 3 templates (attr_01..03)
|
| 35 |
|
| 36 |
# Generation settings
|
| 37 |
generation:
|
|
|
|
| 58 |
num_containers: 100 # Number of parallel containers for sample generation
|
| 59 |
container_cpu: 2 # CPUs per container
|
| 60 |
container_memory: 4096 # Memory (MB) per container
|
| 61 |
+
timeout: 7200 # Per-container timeout in seconds
|
| 62 |
|
| 63 |
# Run name — used to version exported splits so re-runs never overwrite previous data.
|
| 64 |
# Change this whenever you regenerate from scratch (e.g. after template changes).
|
dataset/modal_app.py
CHANGED
|
@@ -24,7 +24,16 @@ image = (
|
|
| 24 |
"pyarrow>=17.0.0",
|
| 25 |
"pyyaml>=6.0",
|
| 26 |
)
|
| 27 |
-
.env(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
.add_local_dir("src/gazet", "/root/gazet")
|
| 29 |
.add_local_dir("dataset", "/root/dataset")
|
| 30 |
)
|
|
@@ -50,9 +59,9 @@ def build_inventory_remote():
|
|
| 50 |
@app.function(
|
| 51 |
image=image,
|
| 52 |
volumes={VOLUME_MOUNT: volume, INTERMEDIATE_MOUNT: intermediate_volume},
|
| 53 |
-
timeout=
|
| 54 |
-
cpu=
|
| 55 |
-
memory=
|
| 56 |
)
|
| 57 |
def build_relation_remote(relation_type: str, countries: list, limit: int):
|
| 58 |
"""Compute one relation type and save to intermediate volume."""
|
|
@@ -72,7 +81,7 @@ def build_relation_remote(relation_type: str, countries: list, limit: int):
|
|
| 72 |
@app.function(
|
| 73 |
image=image,
|
| 74 |
volumes={VOLUME_MOUNT: volume, INTERMEDIATE_MOUNT: intermediate_volume},
|
| 75 |
-
timeout=
|
| 76 |
cpu=2,
|
| 77 |
memory=4096,
|
| 78 |
)
|
|
@@ -126,13 +135,40 @@ def run_pipeline(
|
|
| 126 |
|
| 127 |
relation_needs = calculate_relation_limits(config)
|
| 128 |
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
-
for rel_type
|
| 135 |
-
|
|
|
|
|
|
|
| 136 |
print(f" {rel_type}: {result['count']} pairs")
|
| 137 |
|
| 138 |
print(f"Generating samples across {n_containers} containers...")
|
|
|
|
| 24 |
"pyarrow>=17.0.0",
|
| 25 |
"pyyaml>=6.0",
|
| 26 |
)
|
| 27 |
+
.env(
|
| 28 |
+
{
|
| 29 |
+
"GAZET_DATA_DIR": VOLUME_MOUNT,
|
| 30 |
+
"PYTHONPATH": "/root",
|
| 31 |
+
# Spatial self-joins are much more stable with conservative
|
| 32 |
+
# DuckDB settings inside Modal containers.
|
| 33 |
+
"GAZET_DUCKDB_THREADS": "1",
|
| 34 |
+
"GAZET_DUCKDB_MEMORY_LIMIT": "20GB",
|
| 35 |
+
}
|
| 36 |
+
)
|
| 37 |
.add_local_dir("src/gazet", "/root/gazet")
|
| 38 |
.add_local_dir("dataset", "/root/dataset")
|
| 39 |
)
|
|
|
|
| 59 |
@app.function(
|
| 60 |
image=image,
|
| 61 |
volumes={VOLUME_MOUNT: volume, INTERMEDIATE_MOUNT: intermediate_volume},
|
| 62 |
+
timeout=7200,
|
| 63 |
+
cpu=2,
|
| 64 |
+
memory=65536,
|
| 65 |
)
|
| 66 |
def build_relation_remote(relation_type: str, countries: list, limit: int):
|
| 67 |
"""Compute one relation type and save to intermediate volume."""
|
|
|
|
| 81 |
@app.function(
|
| 82 |
image=image,
|
| 83 |
volumes={VOLUME_MOUNT: volume, INTERMEDIATE_MOUNT: intermediate_volume},
|
| 84 |
+
timeout=7200,
|
| 85 |
cpu=2,
|
| 86 |
memory=4096,
|
| 87 |
)
|
|
|
|
| 135 |
|
| 136 |
relation_needs = calculate_relation_limits(config)
|
| 137 |
|
| 138 |
+
# Global containment-style relations are the most expensive and don't
|
| 139 |
+
# need extremely large anchor tables to support sample generation.
|
| 140 |
+
if countries == ["all"]:
|
| 141 |
+
for rel_type, cap in {
|
| 142 |
+
"containment": 12000,
|
| 143 |
+
"coastal_containment": 8000,
|
| 144 |
+
"landlocked_containment": 8000,
|
| 145 |
+
}.items():
|
| 146 |
+
if rel_type in relation_needs:
|
| 147 |
+
relation_needs[rel_type] = min(relation_needs[rel_type], cap)
|
| 148 |
+
|
| 149 |
+
# Spatial relation builds are the most crash-prone part of the Modal
|
| 150 |
+
# pipeline. Run them sequentially with conservative DuckDB settings
|
| 151 |
+
# rather than fanning out several large native spatial joins at once.
|
| 152 |
+
# common_neighbor still runs after adjacency because it depends on the
|
| 153 |
+
# adjacency parquet being committed first.
|
| 154 |
+
ordered_relations = [
|
| 155 |
+
rel_type
|
| 156 |
+
for rel_type in (
|
| 157 |
+
"adjacency",
|
| 158 |
+
"containment",
|
| 159 |
+
"intersection",
|
| 160 |
+
"cross_source",
|
| 161 |
+
"coastal_containment",
|
| 162 |
+
"landlocked_containment",
|
| 163 |
+
"common_neighbor",
|
| 164 |
+
)
|
| 165 |
+
if rel_type in relation_needs
|
| 166 |
+
]
|
| 167 |
|
| 168 |
+
for rel_type in ordered_relations:
|
| 169 |
+
limit = max(relation_needs[rel_type], 500)
|
| 170 |
+
print(f" building {rel_type} (limit={limit})...")
|
| 171 |
+
result = build_relation_remote.remote(rel_type, countries, limit)
|
| 172 |
print(f" {rel_type}: {result['count']} pairs")
|
| 173 |
|
| 174 |
print(f"Generating samples across {n_containers} containers...")
|
dataset/scripts/build_inventory.py
CHANGED
|
@@ -37,10 +37,11 @@ def build_divisions_area_inventory(con: duckdb.DuckDBPyConnection) -> pd.DataFra
|
|
| 37 |
ST_XMax(geometry) AS xmax,
|
| 38 |
ST_YMax(geometry) AS ymax
|
| 39 |
FROM read_parquet(?)
|
| 40 |
-
WHERE names."primary" IS NOT NULL
|
| 41 |
AND trim(names."primary") != ''
|
|
|
|
| 42 |
"""
|
| 43 |
-
|
| 44 |
df = con.execute(query, [DIVISIONS_AREA_PATH]).fetchdf()
|
| 45 |
print(f"Divisions area inventory: {len(df)} entities")
|
| 46 |
print(f"Subtypes: {df['subtype'].value_counts().to_dict()}")
|
|
@@ -69,10 +70,11 @@ def build_natural_earth_inventory(con: duckdb.DuckDBPyConnection) -> pd.DataFram
|
|
| 69 |
ST_XMax(geometry) AS xmax,
|
| 70 |
ST_YMax(geometry) AS ymax
|
| 71 |
FROM read_parquet(?)
|
| 72 |
-
WHERE names."primary" IS NOT NULL
|
| 73 |
AND trim(names."primary") != ''
|
|
|
|
| 74 |
"""
|
| 75 |
-
|
| 76 |
df = con.execute(query, [NATURAL_EARTH_PATH]).fetchdf()
|
| 77 |
print(f"\nNatural earth inventory: {len(df)} entities")
|
| 78 |
print(f"Subtypes: {df['subtype'].value_counts().to_dict()}")
|
|
|
|
| 37 |
ST_XMax(geometry) AS xmax,
|
| 38 |
ST_YMax(geometry) AS ymax
|
| 39 |
FROM read_parquet(?)
|
| 40 |
+
WHERE names."primary" IS NOT NULL
|
| 41 |
AND trim(names."primary") != ''
|
| 42 |
+
AND geometry IS NOT NULL
|
| 43 |
"""
|
| 44 |
+
|
| 45 |
df = con.execute(query, [DIVISIONS_AREA_PATH]).fetchdf()
|
| 46 |
print(f"Divisions area inventory: {len(df)} entities")
|
| 47 |
print(f"Subtypes: {df['subtype'].value_counts().to_dict()}")
|
|
|
|
| 70 |
ST_XMax(geometry) AS xmax,
|
| 71 |
ST_YMax(geometry) AS ymax
|
| 72 |
FROM read_parquet(?)
|
| 73 |
+
WHERE names."primary" IS NOT NULL
|
| 74 |
AND trim(names."primary") != ''
|
| 75 |
+
AND geometry IS NOT NULL
|
| 76 |
"""
|
| 77 |
+
|
| 78 |
df = con.execute(query, [NATURAL_EARTH_PATH]).fetchdf()
|
| 79 |
print(f"\nNatural earth inventory: {len(df)} entities")
|
| 80 |
print(f"Subtypes: {df['subtype'].value_counts().to_dict()}")
|
dataset/scripts/build_relations.py
CHANGED
|
@@ -14,10 +14,12 @@ Output:
|
|
| 14 |
- intermediate/cross_source_relations.parquet
|
| 15 |
"""
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
import duckdb
|
| 18 |
import pandas as pd
|
| 19 |
-
from pathlib import Path
|
| 20 |
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 21 |
|
| 22 |
from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
|
| 23 |
|
|
@@ -26,6 +28,43 @@ from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
|
|
| 26 |
_EXCLUDED_SUBTYPES_FOR_GLOBAL = ("locality", "neighborhood", "microhood", "macrohood")
|
| 27 |
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
def _country_filter(countries: list) -> tuple[str, list]:
|
| 30 |
"""Return (SQL WHERE clause, params) handling 'all' sentinel."""
|
| 31 |
if countries == ["all"]:
|
|
@@ -46,6 +85,29 @@ def _country_filter_for_join(countries: list) -> tuple[str, list]:
|
|
| 46 |
return f"WHERE country IN (SELECT unnest(?)) {subtype_clause}", [countries]
|
| 47 |
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
def compute_adjacency_pairs(
|
| 50 |
con: duckdb.DuckDBPyConnection,
|
| 51 |
countries: list,
|
|
@@ -95,50 +157,100 @@ def compute_adjacency_pairs(
|
|
| 95 |
return df
|
| 96 |
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
def compute_containment_pairs(
|
| 99 |
con: duckdb.DuckDBPyConnection,
|
| 100 |
countries: list,
|
| 101 |
limit: int
|
| 102 |
) -> pd.DataFrame:
|
| 103 |
-
"""Find
|
| 104 |
-
print("\nComputing containment pairs (
|
| 105 |
-
|
| 106 |
-
cfilter, cparams = _country_filter(countries)
|
| 107 |
-
|
| 108 |
-
query = f"""
|
| 109 |
-
WITH features AS (
|
| 110 |
-
SELECT
|
| 111 |
-
id,
|
| 112 |
-
names."primary" AS name,
|
| 113 |
-
subtype,
|
| 114 |
-
country,
|
| 115 |
-
admin_level,
|
| 116 |
-
geometry,
|
| 117 |
-
ST_Envelope(geometry) AS bbox
|
| 118 |
-
FROM read_parquet(?)
|
| 119 |
-
{cfilter}
|
| 120 |
-
)
|
| 121 |
-
SELECT
|
| 122 |
-
a.id AS container_id,
|
| 123 |
-
a.name AS container_name,
|
| 124 |
-
a.subtype AS container_subtype,
|
| 125 |
-
b.id AS contained_id,
|
| 126 |
-
b.name AS contained_name,
|
| 127 |
-
b.subtype AS contained_subtype,
|
| 128 |
-
'containment' AS relation_type
|
| 129 |
-
FROM features AS a
|
| 130 |
-
JOIN features AS b ON (
|
| 131 |
-
a.id != b.id
|
| 132 |
-
AND a.admin_level < b.admin_level
|
| 133 |
-
AND ST_Intersects(a.bbox, b.bbox)
|
| 134 |
-
AND ST_Within(b.geometry, a.geometry)
|
| 135 |
-
)
|
| 136 |
-
LIMIT ?
|
| 137 |
-
"""
|
| 138 |
-
|
| 139 |
-
df = con.execute(query, [DIVISIONS_AREA_PATH] + cparams + [limit]).fetchdf()
|
| 140 |
print(f"Found {len(df)} containment pairs")
|
| 141 |
-
|
| 142 |
return df
|
| 143 |
|
| 144 |
|
|
@@ -198,60 +310,71 @@ def compute_cross_source_relations(
|
|
| 198 |
) -> pd.DataFrame:
|
| 199 |
"""Find relations between divisions_area and natural_earth.
|
| 200 |
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
|
|
|
| 204 |
"""
|
| 205 |
-
print("\nComputing cross-source relations...")
|
| 206 |
|
| 207 |
cfilter, cparams = _country_filter(countries)
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
)
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
print(f"Found {len(df)} cross-source relations")
|
| 256 |
return df
|
| 257 |
|
|
@@ -261,64 +384,29 @@ def compute_coastal_containment_pairs(
|
|
| 261 |
countries: list,
|
| 262 |
limit: int,
|
| 263 |
) -> pd.DataFrame:
|
| 264 |
-
"""
|
| 265 |
|
| 266 |
-
Used by chained_01 (coastal towns of X)
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
Strategy: find countries whose geometry intersects any ocean/sea in
|
| 271 |
-
natural_earth, then filter containment_pairs to those countries.
|
| 272 |
"""
|
| 273 |
-
print("\nComputing coastal containment pairs...")
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
AND n.subtype IN ('sea', 'ocean')
|
| 285 |
-
),
|
| 286 |
-
features AS (
|
| 287 |
-
SELECT
|
| 288 |
-
id,
|
| 289 |
-
names."primary" AS name,
|
| 290 |
-
subtype,
|
| 291 |
-
country,
|
| 292 |
-
admin_level,
|
| 293 |
-
geometry,
|
| 294 |
-
ST_Envelope(geometry) AS bbox
|
| 295 |
-
FROM read_parquet(?)
|
| 296 |
-
{cfilter}
|
| 297 |
-
)
|
| 298 |
-
SELECT
|
| 299 |
-
a.id AS container_id,
|
| 300 |
-
a.name AS container_name,
|
| 301 |
-
a.subtype AS container_subtype,
|
| 302 |
-
b.id AS contained_id,
|
| 303 |
-
b.name AS contained_name,
|
| 304 |
-
b.subtype AS contained_subtype,
|
| 305 |
-
a.country AS container_country,
|
| 306 |
-
'coastal_containment' AS relation_type
|
| 307 |
-
FROM features AS a
|
| 308 |
-
JOIN features AS b ON (
|
| 309 |
-
a.id != b.id
|
| 310 |
-
AND a.admin_level < b.admin_level
|
| 311 |
-
AND ST_Intersects(a.bbox, b.bbox)
|
| 312 |
-
AND ST_Within(b.geometry, a.geometry)
|
| 313 |
-
)
|
| 314 |
-
WHERE a.country IN (SELECT country FROM coastal_countries)
|
| 315 |
-
LIMIT ?
|
| 316 |
"""
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
)
|
| 322 |
print(f"Found {len(df)} coastal containment pairs")
|
| 323 |
return df
|
| 324 |
|
|
@@ -328,61 +416,29 @@ def compute_landlocked_containment_pairs(
|
|
| 328 |
countries: list,
|
| 329 |
limit: int,
|
| 330 |
) -> pd.DataFrame:
|
| 331 |
-
"""
|
| 332 |
|
| 333 |
-
Used by chained_02 (landlocked
|
| 334 |
-
|
| 335 |
-
|
| 336 |
"""
|
| 337 |
-
print("\nComputing landlocked containment pairs...")
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
AND n.subtype IN ('sea', 'ocean')
|
| 349 |
-
),
|
| 350 |
-
features AS (
|
| 351 |
-
SELECT
|
| 352 |
-
id,
|
| 353 |
-
names."primary" AS name,
|
| 354 |
-
subtype,
|
| 355 |
-
country,
|
| 356 |
-
admin_level,
|
| 357 |
-
geometry,
|
| 358 |
-
ST_Envelope(geometry) AS bbox
|
| 359 |
-
FROM read_parquet(?)
|
| 360 |
-
{cfilter}
|
| 361 |
-
)
|
| 362 |
-
SELECT
|
| 363 |
-
a.id AS container_id,
|
| 364 |
-
a.name AS container_name,
|
| 365 |
-
a.subtype AS container_subtype,
|
| 366 |
-
b.id AS contained_id,
|
| 367 |
-
b.name AS contained_name,
|
| 368 |
-
b.subtype AS contained_subtype,
|
| 369 |
-
a.country AS container_country,
|
| 370 |
-
'landlocked_containment' AS relation_type
|
| 371 |
-
FROM features AS a
|
| 372 |
-
JOIN features AS b ON (
|
| 373 |
-
a.id != b.id
|
| 374 |
-
AND a.admin_level < b.admin_level
|
| 375 |
-
AND ST_Intersects(a.bbox, b.bbox)
|
| 376 |
-
AND ST_Within(b.geometry, a.geometry)
|
| 377 |
-
)
|
| 378 |
-
WHERE a.country NOT IN (SELECT country FROM coastal_countries)
|
| 379 |
-
LIMIT ?
|
| 380 |
"""
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
)
|
| 386 |
print(f"Found {len(df)} landlocked containment pairs")
|
| 387 |
return df
|
| 388 |
|
|
@@ -411,6 +467,21 @@ def compute_common_neighbor_pairs(
|
|
| 411 |
])
|
| 412 |
|
| 413 |
query = """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
SELECT DISTINCT
|
| 415 |
a1.anchor_id AS anchor_id_1,
|
| 416 |
a1.anchor_name AS anchor_name_1,
|
|
@@ -418,8 +489,8 @@ def compute_common_neighbor_pairs(
|
|
| 418 |
a2.anchor_name AS anchor_name_2,
|
| 419 |
a1.target_id AS shared_neighbor_id,
|
| 420 |
a1.target_name AS shared_neighbor_name
|
| 421 |
-
FROM
|
| 422 |
-
JOIN
|
| 423 |
ON a1.target_id = a2.target_id
|
| 424 |
AND a1.anchor_id < a2.anchor_id
|
| 425 |
LIMIT ?
|
|
@@ -435,9 +506,11 @@ def _make_connection():
|
|
| 435 |
con = duckdb.connect()
|
| 436 |
con.execute("INSTALL spatial")
|
| 437 |
con.execute("LOAD spatial")
|
| 438 |
-
|
|
|
|
|
|
|
| 439 |
con.execute("SET temp_directory='/tmp/duckdb_tmp'")
|
| 440 |
-
con.execute("SET threads=
|
| 441 |
return con
|
| 442 |
|
| 443 |
|
|
|
|
| 14 |
- intermediate/cross_source_relations.parquet
|
| 15 |
"""
|
| 16 |
|
| 17 |
+
import os
|
| 18 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
import duckdb
|
| 22 |
import pandas as pd
|
|
|
|
|
|
|
| 23 |
|
| 24 |
from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
|
| 25 |
|
|
|
|
| 28 |
_EXCLUDED_SUBTYPES_FOR_GLOBAL = ("locality", "neighborhood", "microhood", "macrohood")
|
| 29 |
|
| 30 |
|
| 31 |
+
# (container_subtype, contained_subtype) combos used by the chained,
|
| 32 |
+
# containment, and disambiguation templates. A naive self-join with
|
| 33 |
+
# LIMIT fills up with coarse pairs (country -> region / county) first and
|
| 34 |
+
# never emits locality-level pairs; stratifying by combo ensures each
|
| 35 |
+
# template has anchors to draw from.
|
| 36 |
+
_CONTAINMENT_SUBTYPE_PAIRS = (
|
| 37 |
+
("country", "region"),
|
| 38 |
+
("country", "county"),
|
| 39 |
+
("country", "localadmin"),
|
| 40 |
+
("country", "locality"),
|
| 41 |
+
("region", "county"),
|
| 42 |
+
("region", "localadmin"),
|
| 43 |
+
("region", "locality"),
|
| 44 |
+
("county", "locality"),
|
| 45 |
+
("localadmin", "locality"),
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Natural Earth subtype vocabulary in the current GeoParquet.
|
| 50 |
+
# Keep these exact strings in one place so relation building and templates
|
| 51 |
+
# stay aligned with the underlying dataset.
|
| 52 |
+
_NE_CROSS_SOURCE_SUBTYPES = (
|
| 53 |
+
"sea",
|
| 54 |
+
"ocean",
|
| 55 |
+
"Lake",
|
| 56 |
+
"River",
|
| 57 |
+
"Basin",
|
| 58 |
+
"gulf",
|
| 59 |
+
"bay",
|
| 60 |
+
"Island group",
|
| 61 |
+
"Peninsula",
|
| 62 |
+
"strait",
|
| 63 |
+
"Range/mtn",
|
| 64 |
+
"Depression",
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
def _country_filter(countries: list) -> tuple[str, list]:
|
| 69 |
"""Return (SQL WHERE clause, params) handling 'all' sentinel."""
|
| 70 |
if countries == ["all"]:
|
|
|
|
| 85 |
return f"WHERE country IN (SELECT unnest(?)) {subtype_clause}", [countries]
|
| 86 |
|
| 87 |
|
| 88 |
+
def _country_chunks(
|
| 89 |
+
con: duckdb.DuckDBPyConnection,
|
| 90 |
+
countries: list,
|
| 91 |
+
chunk_size: int = 40,
|
| 92 |
+
) -> list[list[str]]:
|
| 93 |
+
"""Return explicit country batches for safer global containment joins."""
|
| 94 |
+
if countries != ["all"]:
|
| 95 |
+
return [countries]
|
| 96 |
+
|
| 97 |
+
rows = con.execute(
|
| 98 |
+
"""
|
| 99 |
+
SELECT DISTINCT country
|
| 100 |
+
FROM read_parquet(?)
|
| 101 |
+
WHERE country IS NOT NULL
|
| 102 |
+
AND trim(country) != ''
|
| 103 |
+
ORDER BY country
|
| 104 |
+
""",
|
| 105 |
+
[DIVISIONS_AREA_PATH],
|
| 106 |
+
).fetchall()
|
| 107 |
+
codes = [row[0] for row in rows]
|
| 108 |
+
return [codes[i:i + chunk_size] for i in range(0, len(codes), chunk_size)]
|
| 109 |
+
|
| 110 |
+
|
| 111 |
def compute_adjacency_pairs(
|
| 112 |
con: duckdb.DuckDBPyConnection,
|
| 113 |
countries: list,
|
|
|
|
| 157 |
return df
|
| 158 |
|
| 159 |
|
| 160 |
+
def _stratified_containment(
|
| 161 |
+
con: duckdb.DuckDBPyConnection,
|
| 162 |
+
countries: list,
|
| 163 |
+
limit: int,
|
| 164 |
+
relation_type: str,
|
| 165 |
+
extra_where: str = "",
|
| 166 |
+
extra_params: list = None,
|
| 167 |
+
) -> pd.DataFrame:
|
| 168 |
+
"""Compute containment pairs stratified by (container_subtype, contained_subtype).
|
| 169 |
+
|
| 170 |
+
A single global self-join with LIMIT fills up with coarse country->region
|
| 171 |
+
pairs before emitting any locality-level pairs. We run one focused query
|
| 172 |
+
per subtype combo instead so every combo receives a fair share of the
|
| 173 |
+
overall limit.
|
| 174 |
+
|
| 175 |
+
``extra_where`` / ``extra_params`` let the coastal and landlocked variants
|
| 176 |
+
inject their country-set filter without duplicating the whole body.
|
| 177 |
+
"""
|
| 178 |
+
extra_params = extra_params or []
|
| 179 |
+
# Use a lower target per subtype combo for global runs; they are the most
|
| 180 |
+
# memory-intensive part of the pipeline and don't need huge anchor tables.
|
| 181 |
+
if countries == ["all"]:
|
| 182 |
+
per_combo = min(max(limit // len(_CONTAINMENT_SUBTYPE_PAIRS), 100), 1500)
|
| 183 |
+
else:
|
| 184 |
+
per_combo = max(limit // len(_CONTAINMENT_SUBTYPE_PAIRS), 100)
|
| 185 |
+
|
| 186 |
+
country_batches = _country_chunks(con, countries)
|
| 187 |
+
frames: list[pd.DataFrame] = []
|
| 188 |
+
for container_st, contained_st in _CONTAINMENT_SUBTYPE_PAIRS:
|
| 189 |
+
remaining = per_combo
|
| 190 |
+
combo_parts: list[pd.DataFrame] = []
|
| 191 |
+
|
| 192 |
+
for batch in country_batches:
|
| 193 |
+
if remaining <= 0:
|
| 194 |
+
break
|
| 195 |
+
|
| 196 |
+
cfilter, cparams = _country_filter(batch)
|
| 197 |
+
query = f"""
|
| 198 |
+
WITH a AS (
|
| 199 |
+
SELECT src.id, src.names."primary" AS name, src.subtype, src.country, src.admin_level,
|
| 200 |
+
src.geometry, ST_Envelope(src.geometry) AS bbox
|
| 201 |
+
FROM read_parquet(?) AS src
|
| 202 |
+
WHERE src.subtype = '{container_st}'
|
| 203 |
+
{cfilter.replace("WHERE", "AND") if cfilter else ""}
|
| 204 |
+
{extra_where}
|
| 205 |
+
),
|
| 206 |
+
b AS (
|
| 207 |
+
SELECT dst.id, dst.names."primary" AS name, dst.subtype, dst.country, dst.admin_level,
|
| 208 |
+
dst.geometry, ST_Envelope(dst.geometry) AS bbox
|
| 209 |
+
FROM read_parquet(?) AS dst
|
| 210 |
+
WHERE dst.subtype = '{contained_st}'
|
| 211 |
+
{cfilter.replace("WHERE", "AND") if cfilter else ""}
|
| 212 |
+
)
|
| 213 |
+
SELECT
|
| 214 |
+
a.id AS container_id,
|
| 215 |
+
a.name AS container_name,
|
| 216 |
+
a.subtype AS container_subtype,
|
| 217 |
+
b.id AS contained_id,
|
| 218 |
+
b.name AS contained_name,
|
| 219 |
+
b.subtype AS contained_subtype,
|
| 220 |
+
a.country AS container_country,
|
| 221 |
+
'{relation_type}' AS relation_type
|
| 222 |
+
FROM a JOIN b ON (
|
| 223 |
+
a.id != b.id
|
| 224 |
+
AND ST_Intersects(a.bbox, b.bbox)
|
| 225 |
+
AND ST_Within(b.geometry, a.geometry)
|
| 226 |
+
)
|
| 227 |
+
LIMIT {remaining}
|
| 228 |
+
"""
|
| 229 |
+
params = [DIVISIONS_AREA_PATH] + extra_params + cparams + [DIVISIONS_AREA_PATH] + cparams
|
| 230 |
+
df_part = con.execute(query, params).fetchdf()
|
| 231 |
+
if not df_part.empty:
|
| 232 |
+
combo_parts.append(df_part)
|
| 233 |
+
remaining -= len(df_part)
|
| 234 |
+
|
| 235 |
+
df_combo = (
|
| 236 |
+
pd.concat(combo_parts, ignore_index=True)
|
| 237 |
+
if combo_parts else pd.DataFrame()
|
| 238 |
+
)
|
| 239 |
+
print(f" {relation_type} {container_st:>10s} -> {contained_st:<10s}: {len(df_combo)} pairs")
|
| 240 |
+
frames.append(df_combo)
|
| 241 |
+
|
| 242 |
+
return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
|
| 243 |
+
|
| 244 |
+
|
| 245 |
def compute_containment_pairs(
|
| 246 |
con: duckdb.DuckDBPyConnection,
|
| 247 |
countries: list,
|
| 248 |
limit: int
|
| 249 |
) -> pd.DataFrame:
|
| 250 |
+
"""Find containment pairs stratified across admin-level combinations."""
|
| 251 |
+
print("\nComputing containment pairs (stratified by subtype combo)...")
|
| 252 |
+
df = _stratified_containment(con, countries, limit, relation_type="containment")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
print(f"Found {len(df)} containment pairs")
|
|
|
|
| 254 |
return df
|
| 255 |
|
| 256 |
|
|
|
|
| 310 |
) -> pd.DataFrame:
|
| 311 |
"""Find relations between divisions_area and natural_earth.
|
| 312 |
|
| 313 |
+
The join is heavily skewed by a few abundant Natural Earth subtypes
|
| 314 |
+
(especially rivers and mountain ranges). We therefore stratify by exact
|
| 315 |
+
natural subtype so seas/oceans, gulfs/bays, island groups, and rarer
|
| 316 |
+
landforms all make it into the anchor pool.
|
| 317 |
"""
|
| 318 |
+
print("\nComputing cross-source relations (stratified by NE subtype)...")
|
| 319 |
|
| 320 |
cfilter, cparams = _country_filter(countries)
|
| 321 |
+
per_subtype = max(limit // len(_NE_CROSS_SOURCE_SUBTYPES), 50)
|
| 322 |
+
|
| 323 |
+
frames: list[pd.DataFrame] = []
|
| 324 |
+
for natural_subtype in _NE_CROSS_SOURCE_SUBTYPES:
|
| 325 |
+
query = f"""
|
| 326 |
+
WITH divisions AS (
|
| 327 |
+
SELECT
|
| 328 |
+
id,
|
| 329 |
+
names."primary" AS name,
|
| 330 |
+
subtype,
|
| 331 |
+
country,
|
| 332 |
+
geometry
|
| 333 |
+
FROM read_parquet(?)
|
| 334 |
+
WHERE geometry IS NOT NULL
|
| 335 |
+
AND names."primary" IS NOT NULL
|
| 336 |
+
AND trim(names."primary") != ''
|
| 337 |
+
{cfilter.replace("WHERE", "AND") if cfilter else ''}
|
| 338 |
+
),
|
| 339 |
+
natural_features AS (
|
| 340 |
+
SELECT
|
| 341 |
+
id,
|
| 342 |
+
names."primary" AS name,
|
| 343 |
+
subtype,
|
| 344 |
+
geometry
|
| 345 |
+
FROM read_parquet(?)
|
| 346 |
+
WHERE geometry IS NOT NULL
|
| 347 |
+
AND names."primary" IS NOT NULL
|
| 348 |
+
AND trim(names."primary") != ''
|
| 349 |
+
AND subtype = '{natural_subtype}'
|
| 350 |
)
|
| 351 |
+
SELECT
|
| 352 |
+
d.id AS division_id,
|
| 353 |
+
d.name AS division_name,
|
| 354 |
+
d.subtype AS division_subtype,
|
| 355 |
+
d.country AS division_country,
|
| 356 |
+
n.id AS natural_id,
|
| 357 |
+
n.name AS natural_name,
|
| 358 |
+
n.subtype AS natural_subtype,
|
| 359 |
+
CASE
|
| 360 |
+
WHEN ST_Touches(d.geometry, n.geometry) THEN 'touches'
|
| 361 |
+
WHEN ST_Within(d.geometry, n.geometry) THEN 'within'
|
| 362 |
+
WHEN ST_Contains(d.geometry, n.geometry) THEN 'contains'
|
| 363 |
+
WHEN ST_Intersects(d.geometry, n.geometry) THEN 'intersects'
|
| 364 |
+
END AS relation_type
|
| 365 |
+
FROM divisions AS d
|
| 366 |
+
JOIN natural_features AS n
|
| 367 |
+
ON ST_Intersects(d.geometry, n.geometry)
|
| 368 |
+
LIMIT {per_subtype}
|
| 369 |
+
"""
|
| 370 |
+
df_part = con.execute(
|
| 371 |
+
query,
|
| 372 |
+
[DIVISIONS_AREA_PATH] + cparams + [NATURAL_EARTH_PATH],
|
| 373 |
+
).fetchdf()
|
| 374 |
+
print(f" cross_source {natural_subtype:>12s}: {len(df_part)} rows")
|
| 375 |
+
frames.append(df_part)
|
| 376 |
+
|
| 377 |
+
df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
|
| 378 |
print(f"Found {len(df)} cross-source relations")
|
| 379 |
return df
|
| 380 |
|
|
|
|
| 384 |
countries: list,
|
| 385 |
limit: int,
|
| 386 |
) -> pd.DataFrame:
|
| 387 |
+
"""Stratified containment pairs limited to coastal-country containers.
|
| 388 |
|
| 389 |
+
Used by chained_01 (coastal towns of X) so sampled anchors actually have
|
| 390 |
+
sea-adjacent sub-features. Stratification guarantees coverage of every
|
| 391 |
+
admin-level combination (country->locality, region->locality, etc.).
|
|
|
|
|
|
|
|
|
|
| 392 |
"""
|
| 393 |
+
print("\nComputing coastal containment pairs (stratified)...")
|
| 394 |
+
extra_where = f"""
|
| 395 |
+
AND EXISTS (
|
| 396 |
+
SELECT 1
|
| 397 |
+
FROM read_parquet('{NATURAL_EARTH_PATH}') AS n
|
| 398 |
+
WHERE n.geometry IS NOT NULL
|
| 399 |
+
AND n.names."primary" IS NOT NULL
|
| 400 |
+
AND trim(n.names."primary") != ''
|
| 401 |
+
AND n.subtype IN ('sea', 'ocean')
|
| 402 |
+
AND ST_Intersects(src.geometry, n.geometry)
|
| 403 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
"""
|
| 405 |
+
df = _stratified_containment(
|
| 406 |
+
con, countries, limit,
|
| 407 |
+
relation_type="coastal_containment",
|
| 408 |
+
extra_where=extra_where,
|
| 409 |
+
)
|
| 410 |
print(f"Found {len(df)} coastal containment pairs")
|
| 411 |
return df
|
| 412 |
|
|
|
|
| 416 |
countries: list,
|
| 417 |
limit: int,
|
| 418 |
) -> pd.DataFrame:
|
| 419 |
+
"""Stratified containment pairs limited to landlocked-country containers.
|
| 420 |
|
| 421 |
+
Used by chained_02 (landlocked localities within X). Stratification by
|
| 422 |
+
subtype combo ensures locality-level pairs are actually present in the
|
| 423 |
+
output instead of being starved by coarse country->region pairs.
|
| 424 |
"""
|
| 425 |
+
print("\nComputing landlocked containment pairs (stratified)...")
|
| 426 |
+
extra_where = f"""
|
| 427 |
+
AND NOT EXISTS (
|
| 428 |
+
SELECT 1
|
| 429 |
+
FROM read_parquet('{NATURAL_EARTH_PATH}') AS n
|
| 430 |
+
WHERE n.geometry IS NOT NULL
|
| 431 |
+
AND n.names."primary" IS NOT NULL
|
| 432 |
+
AND trim(n.names."primary") != ''
|
| 433 |
+
AND n.subtype IN ('sea', 'ocean')
|
| 434 |
+
AND ST_Intersects(src.geometry, n.geometry)
|
| 435 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
"""
|
| 437 |
+
df = _stratified_containment(
|
| 438 |
+
con, countries, limit,
|
| 439 |
+
relation_type="landlocked_containment",
|
| 440 |
+
extra_where=extra_where,
|
| 441 |
+
)
|
| 442 |
print(f"Found {len(df)} landlocked containment pairs")
|
| 443 |
return df
|
| 444 |
|
|
|
|
| 467 |
])
|
| 468 |
|
| 469 |
query = """
|
| 470 |
+
WITH undirected AS (
|
| 471 |
+
SELECT
|
| 472 |
+
anchor_id,
|
| 473 |
+
anchor_name,
|
| 474 |
+
target_id,
|
| 475 |
+
target_name
|
| 476 |
+
FROM read_parquet(?)
|
| 477 |
+
UNION ALL
|
| 478 |
+
SELECT
|
| 479 |
+
target_id AS anchor_id,
|
| 480 |
+
target_name AS anchor_name,
|
| 481 |
+
anchor_id AS target_id,
|
| 482 |
+
anchor_name AS target_name
|
| 483 |
+
FROM read_parquet(?)
|
| 484 |
+
)
|
| 485 |
SELECT DISTINCT
|
| 486 |
a1.anchor_id AS anchor_id_1,
|
| 487 |
a1.anchor_name AS anchor_name_1,
|
|
|
|
| 489 |
a2.anchor_name AS anchor_name_2,
|
| 490 |
a1.target_id AS shared_neighbor_id,
|
| 491 |
a1.target_name AS shared_neighbor_name
|
| 492 |
+
FROM undirected AS a1
|
| 493 |
+
JOIN undirected AS a2
|
| 494 |
ON a1.target_id = a2.target_id
|
| 495 |
AND a1.anchor_id < a2.anchor_id
|
| 496 |
LIMIT ?
|
|
|
|
| 506 |
con = duckdb.connect()
|
| 507 |
con.execute("INSTALL spatial")
|
| 508 |
con.execute("LOAD spatial")
|
| 509 |
+
memory_limit = os.environ.get("GAZET_DUCKDB_MEMORY_LIMIT", "12GB")
|
| 510 |
+
threads = int(os.environ.get("GAZET_DUCKDB_THREADS", "1"))
|
| 511 |
+
con.execute(f"SET memory_limit='{memory_limit}'")
|
| 512 |
con.execute("SET temp_directory='/tmp/duckdb_tmp'")
|
| 513 |
+
con.execute(f"SET threads={threads}")
|
| 514 |
return con
|
| 515 |
|
| 516 |
|
dataset/scripts/cli.py
CHANGED
|
@@ -110,6 +110,19 @@ def calculate_relation_limits(config: dict) -> Dict[str, int]:
|
|
| 110 |
return relation_needs
|
| 111 |
|
| 112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
def build_relations(config_path: Path):
|
| 114 |
"""Run relation building with config."""
|
| 115 |
config = load_config(config_path)
|
|
@@ -269,6 +282,9 @@ def main():
|
|
| 269 |
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 270 |
epilog="""
|
| 271 |
Examples:
|
|
|
|
|
|
|
|
|
|
| 272 |
# Build relation tables only
|
| 273 |
python cli.py build-relations --config ../config.yaml
|
| 274 |
|
|
@@ -296,7 +312,7 @@ Examples:
|
|
| 296 |
|
| 297 |
parser.add_argument(
|
| 298 |
'command',
|
| 299 |
-
choices=['build-relations', 'generate-samples', 'validate', 'export',
|
| 300 |
'full-pipeline', 'modal-upload', 'modal-generate'],
|
| 301 |
help='Command to run'
|
| 302 |
)
|
|
@@ -348,7 +364,9 @@ Examples:
|
|
| 348 |
|
| 349 |
# Run the appropriate command
|
| 350 |
try:
|
| 351 |
-
if args.command == '
|
|
|
|
|
|
|
| 352 |
build_relations(args.config)
|
| 353 |
elif args.command == 'generate-samples':
|
| 354 |
generate_samples(args.config, args.append)
|
|
|
|
| 110 |
return relation_needs
|
| 111 |
|
| 112 |
|
| 113 |
+
def normalize_data():
|
| 114 |
+
"""Build normalized source parquet copies with harmonized geometry metadata."""
|
| 115 |
+
print("=" * 60)
|
| 116 |
+
print("STEP 0: Normalizing Source Geodata")
|
| 117 |
+
print("=" * 60)
|
| 118 |
+
from dataset.scripts.normalize_geodata import normalize_geodata
|
| 119 |
+
|
| 120 |
+
result = normalize_geodata()
|
| 121 |
+
for name, path in result.items():
|
| 122 |
+
print(f" {name}: {path}")
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
|
| 126 |
def build_relations(config_path: Path):
|
| 127 |
"""Run relation building with config."""
|
| 128 |
config = load_config(config_path)
|
|
|
|
| 282 |
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 283 |
epilog="""
|
| 284 |
Examples:
|
| 285 |
+
# Normalize source geodata first (recommended before Modal upload)
|
| 286 |
+
python cli.py normalize-data --config ../config.yaml
|
| 287 |
+
|
| 288 |
# Build relation tables only
|
| 289 |
python cli.py build-relations --config ../config.yaml
|
| 290 |
|
|
|
|
| 312 |
|
| 313 |
parser.add_argument(
|
| 314 |
'command',
|
| 315 |
+
choices=['normalize-data', 'build-relations', 'generate-samples', 'validate', 'export',
|
| 316 |
'full-pipeline', 'modal-upload', 'modal-generate'],
|
| 317 |
help='Command to run'
|
| 318 |
)
|
|
|
|
| 364 |
|
| 365 |
# Run the appropriate command
|
| 366 |
try:
|
| 367 |
+
if args.command == 'normalize-data':
|
| 368 |
+
normalize_data()
|
| 369 |
+
elif args.command == 'build-relations':
|
| 370 |
build_relations(args.config)
|
| 371 |
elif args.command == 'generate-samples':
|
| 372 |
generate_samples(args.config, args.append)
|
dataset/scripts/export_training_data.py
CHANGED
|
@@ -124,7 +124,7 @@ You have access to two DuckDB parquet tables. Given a set of candidate entities
|
|
| 124 |
id VARCHAR -- unique feature id prefixed 'ne_'
|
| 125 |
names STRUCT("primary" VARCHAR, ...)
|
| 126 |
country VARCHAR
|
| 127 |
-
subtype VARCHAR -- e.g. 'ocean', 'sea', 'bay', '
|
| 128 |
class VARCHAR
|
| 129 |
region VARCHAR
|
| 130 |
admin_level INTEGER
|
|
|
|
| 124 |
id VARCHAR -- unique feature id prefixed 'ne_'
|
| 125 |
names STRUCT("primary" VARCHAR, ...)
|
| 126 |
country VARCHAR
|
| 127 |
+
subtype VARCHAR -- e.g. 'ocean', 'sea', 'bay', 'Range/mtn', 'Island group'
|
| 128 |
class VARCHAR
|
| 129 |
region VARCHAR
|
| 130 |
admin_level INTEGER
|
dataset/scripts/generate_samples.py
CHANGED
|
@@ -44,7 +44,7 @@ def _for_execution(sql: str) -> str:
|
|
| 44 |
return (
|
| 45 |
sql
|
| 46 |
.replace("read_parquet('divisions_area')", f"read_parquet('{DIVISIONS_AREA_PATH}')")
|
| 47 |
-
.replace("read_parquet('natural_earth')",
|
| 48 |
)
|
| 49 |
|
| 50 |
# Configurable parameters (can be overridden by CLI)
|
|
@@ -60,6 +60,33 @@ SQLTemplate = sql_templates.SQLTemplate
|
|
| 60 |
get_templates_by_family = sql_templates.get_templates_by_family
|
| 61 |
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
class Candidate(BaseModel):
|
| 64 |
"""Candidate entity for grounding."""
|
| 65 |
candidate_id: str
|
|
@@ -121,6 +148,8 @@ def sample_adjacency_anchor(
|
|
| 121 |
'anchor_name': row['anchor_name'],
|
| 122 |
'anchor_subtype': row['anchor_subtype'],
|
| 123 |
'anchor_country': row.get('anchor_country'), # May not exist in all tables
|
|
|
|
|
|
|
| 124 |
'target_subtype': row.get('target_subtype')
|
| 125 |
}
|
| 126 |
|
|
@@ -142,16 +171,23 @@ def sample_intersection_anchor(intersection_df: pd.DataFrame) -> Optional[Dict[s
|
|
| 142 |
|
| 143 |
|
| 144 |
def sample_containment_anchor(containment_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
|
| 145 |
-
"""Sample a random containment pair.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
if containment_df.empty:
|
| 147 |
return None
|
| 148 |
-
|
| 149 |
row = containment_df.sample(n=1).iloc[0]
|
| 150 |
return {
|
| 151 |
'container_id': row['container_id'],
|
| 152 |
'container_name': row['container_name'],
|
| 153 |
'container_subtype': row['container_subtype'],
|
| 154 |
-
'
|
|
|
|
|
|
|
| 155 |
}
|
| 156 |
|
| 157 |
|
|
@@ -186,12 +222,24 @@ def sample_disambiguation_anchor(
|
|
| 186 |
}
|
| 187 |
|
| 188 |
|
| 189 |
-
def sample_cross_source_anchor(
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
if cross_source_df.empty:
|
| 192 |
return None
|
| 193 |
-
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
return {
|
| 196 |
'division_id': row['division_id'],
|
| 197 |
'division_name': row['division_name'],
|
|
@@ -242,16 +290,16 @@ def build_candidate_list(
|
|
| 242 |
difficulty: str = "medium"
|
| 243 |
) -> List[Candidate]:
|
| 244 |
"""Build candidate list with true anchor + distractors."""
|
| 245 |
-
|
| 246 |
# Helper to convert pandas NA to None
|
| 247 |
def safe_get(row, key, default=None):
|
| 248 |
val = row.get(key, default)
|
| 249 |
return None if pd.isna(val) else val
|
| 250 |
-
|
| 251 |
# Get the true anchor
|
| 252 |
if anchor_source == "divisions_area":
|
| 253 |
query = """
|
| 254 |
-
SELECT
|
| 255 |
id,
|
| 256 |
names."primary" AS name,
|
| 257 |
subtype,
|
|
@@ -264,7 +312,7 @@ def build_candidate_list(
|
|
| 264 |
anchor_row = con.execute(query, [DIVISIONS_AREA_PATH, anchor_id]).fetchdf().iloc[0]
|
| 265 |
else:
|
| 266 |
query = """
|
| 267 |
-
SELECT
|
| 268 |
id,
|
| 269 |
names."primary" AS name,
|
| 270 |
subtype
|
|
@@ -272,8 +320,7 @@ def build_candidate_list(
|
|
| 272 |
WHERE id = ?
|
| 273 |
"""
|
| 274 |
anchor_row = con.execute(query, [NATURAL_EARTH_PATH, anchor_id]).fetchdf().iloc[0]
|
| 275 |
-
|
| 276 |
-
# Build true candidate
|
| 277 |
true_candidate = Candidate(
|
| 278 |
candidate_id="c1",
|
| 279 |
source=anchor_source,
|
|
@@ -283,25 +330,31 @@ def build_candidate_list(
|
|
| 283 |
country=safe_get(anchor_row, 'country'),
|
| 284 |
region=safe_get(anchor_row, 'region'),
|
| 285 |
admin_level=safe_get(anchor_row, 'admin_level'),
|
| 286 |
-
similarity=1.0
|
| 287 |
)
|
| 288 |
-
|
| 289 |
-
# Build distractors based on difficulty
|
| 290 |
distractors = build_distractors(
|
| 291 |
-
con,
|
| 292 |
-
anchor_name,
|
| 293 |
anchor_source,
|
| 294 |
anchor_id,
|
| 295 |
num_candidates - 1,
|
| 296 |
-
difficulty
|
| 297 |
)
|
| 298 |
-
|
| 299 |
-
# Order: true anchor first, then same-source distractors, then cross-source
|
| 300 |
-
# distractors. This mirrors inference order (anchor at top by similarity,
|
| 301 |
-
# same source grouped before the other source).
|
| 302 |
-
candidates = [true_candidate] + distractors
|
| 303 |
|
| 304 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
for i, cand in enumerate(candidates, 1):
|
| 306 |
cand.candidate_id = f"c{i}"
|
| 307 |
|
|
@@ -335,21 +388,39 @@ def build_distractors(
|
|
| 335 |
|
| 336 |
def _query_source(path: str, src_name: str, n: int, excl_id: str) -> List[Candidate]:
|
| 337 |
query = """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
SELECT
|
| 339 |
id,
|
| 340 |
-
|
| 341 |
subtype,
|
| 342 |
country,
|
| 343 |
region,
|
| 344 |
admin_level,
|
| 345 |
-
|
| 346 |
-
FROM
|
| 347 |
-
WHERE
|
| 348 |
-
AND names."primary" IS NOT NULL
|
| 349 |
ORDER BY similarity DESC
|
| 350 |
LIMIT ?
|
| 351 |
"""
|
| 352 |
-
df = con.execute(query, [anchor_name, path, excl_id, n]).fetchdf()
|
| 353 |
results = []
|
| 354 |
for _, row in df.iterrows():
|
| 355 |
results.append(Candidate(
|
|
@@ -515,13 +586,23 @@ WHERE b.id != '{anchor['container_id']}'
|
|
| 515 |
def sample_random_entity(
|
| 516 |
con: duckdb.DuckDBPyConnection,
|
| 517 |
inventory_df: pd.DataFrame,
|
| 518 |
-
source: str
|
|
|
|
|
|
|
| 519 |
) -> Optional[Dict[str, Any]]:
|
| 520 |
-
"""Sample a random entity from inventory."""
|
| 521 |
if inventory_df.empty:
|
| 522 |
return None
|
| 523 |
-
|
| 524 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 525 |
return {
|
| 526 |
'id': row['id'],
|
| 527 |
'name': row['name'],
|
|
@@ -545,7 +626,12 @@ def generate_template_based_sample(
|
|
| 545 |
if template.anchor_source == "divisions_area":
|
| 546 |
anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
|
| 547 |
else:
|
| 548 |
-
anchor = sample_random_entity(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 549 |
|
| 550 |
if not anchor:
|
| 551 |
return None
|
|
@@ -636,70 +722,156 @@ def generate_template_based_sample(
|
|
| 636 |
anchor = {"id": pair["contained_id"], "name": pair["contained_name"]}
|
| 637 |
|
| 638 |
elif template.family == "adjacency":
|
| 639 |
-
#
|
| 640 |
-
#
|
| 641 |
-
#
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 646 |
if not anchor:
|
| 647 |
return None
|
| 648 |
|
| 649 |
sql = template.sql_template.format(
|
| 650 |
anchor_id=anchor['anchor_id'],
|
| 651 |
-
target_subtype=anchor
|
| 652 |
)
|
| 653 |
-
|
| 654 |
candidates = build_candidate_list(
|
| 655 |
con, anchor['anchor_id'], anchor['anchor_name'], 'divisions_area',
|
| 656 |
num_candidates=10, difficulty="medium"
|
| 657 |
)
|
| 658 |
-
|
| 659 |
question = random.choice(template.question_hints).format(
|
| 660 |
anchor_name=anchor['anchor_name'],
|
| 661 |
-
target_subtype=anchor
|
| 662 |
)
|
| 663 |
|
| 664 |
elif template.family == "containment":
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 683 |
|
| 684 |
elif template.family == "intersection":
|
| 685 |
if template.anchor_source == "natural_earth":
|
| 686 |
-
anchor = sample_cross_source_anchor(
|
|
|
|
|
|
|
|
|
|
| 687 |
if not anchor:
|
| 688 |
return None
|
| 689 |
-
|
|
|
|
|
|
|
| 690 |
sql = template.sql_template.format(
|
| 691 |
anchor_id=anchor['natural_id'],
|
| 692 |
-
target_subtype=
|
| 693 |
)
|
| 694 |
-
|
| 695 |
candidates = build_candidate_list(
|
| 696 |
con, anchor['natural_id'], anchor['natural_name'], 'natural_earth',
|
| 697 |
num_candidates=10, difficulty="medium"
|
| 698 |
)
|
| 699 |
-
|
| 700 |
question = random.choice(template.question_hints).format(
|
| 701 |
anchor_name=anchor['natural_name'],
|
| 702 |
-
target_subtype=
|
| 703 |
)
|
| 704 |
else:
|
| 705 |
# Same-source intersection
|
|
@@ -760,14 +932,19 @@ def generate_template_based_sample(
|
|
| 760 |
# country IN clause — 2 or 3 anchors, each contributes its country code
|
| 761 |
num_a = 3 if template.template_id == "contain_multi_02" else 2
|
| 762 |
anchors = [
|
| 763 |
-
sample_random_entity(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 764 |
for _ in range(num_a)
|
| 765 |
]
|
| 766 |
if any(a is None for a in anchors):
|
| 767 |
return None
|
| 768 |
|
| 769 |
countries = [a.get('country') or 'US' for a in anchors]
|
| 770 |
-
target_subtype =
|
| 771 |
per_anchor = 3 if num_a == 3 else 4
|
| 772 |
|
| 773 |
fmt_kwargs = dict(
|
|
@@ -850,7 +1027,12 @@ def generate_template_based_sample(
|
|
| 850 |
|
| 851 |
if template.num_anchors == 1:
|
| 852 |
if template.anchor_source == "natural_earth":
|
| 853 |
-
anchor = sample_random_entity(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 854 |
else:
|
| 855 |
anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
|
| 856 |
if not anchor:
|
|
@@ -944,20 +1126,22 @@ def generate_template_based_sample(
|
|
| 944 |
# Mixed-source clip: division intersected with a natural_earth feature.
|
| 945 |
# Use cross_source_relations so the pair is guaranteed to intersect —
|
| 946 |
# random sampling almost never produces an intersecting pair.
|
| 947 |
-
|
| 948 |
-
|
|
|
|
|
|
|
|
|
|
| 949 |
return None
|
| 950 |
-
row = cs_df.sample(n=1).iloc[0]
|
| 951 |
clip_feature = {
|
| 952 |
-
'id':
|
| 953 |
-
'name':
|
| 954 |
'source': 'natural_earth',
|
| 955 |
}
|
| 956 |
# Override the division anchor with the paired division so the
|
| 957 |
# ST_Intersects check in the SQL is guaranteed to pass.
|
| 958 |
anchor = {
|
| 959 |
-
'id':
|
| 960 |
-
'name':
|
| 961 |
'source': 'divisions_area',
|
| 962 |
}
|
| 963 |
|
|
@@ -986,8 +1170,14 @@ def generate_template_based_sample(
|
|
| 986 |
target_subtype = random.choice(['locality', 'region'])
|
| 987 |
|
| 988 |
if template.template_id in ['agg_03', 'agg_04']:
|
| 989 |
-
# Country-level aggregation: SQL uses country code,
|
| 990 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 991 |
if not anchor:
|
| 992 |
return None
|
| 993 |
|
|
@@ -1035,35 +1225,33 @@ def generate_template_based_sample(
|
|
| 1035 |
elif template.family == "chained":
|
| 1036 |
# Use pre-filtered coastal/landlocked containment pairs so the SQL
|
| 1037 |
# verification step doesn't constantly return empty results.
|
| 1038 |
-
|
|
|
|
|
|
|
| 1039 |
table_key = 'coastal_containment_pairs'
|
| 1040 |
-
elif template.template_id
|
| 1041 |
table_key = 'landlocked_containment_pairs'
|
| 1042 |
else:
|
| 1043 |
table_key = 'containment_pairs'
|
| 1044 |
|
| 1045 |
-
# chained_10/11 need a country-level anchor ("coastal states of
|
| 1046 |
-
# India") and region-level targets, so filter the containment pairs
|
| 1047 |
-
# to (container=country, contained=region) before sampling.
|
| 1048 |
-
_chained_subtype_filter = {
|
| 1049 |
-
"chained_10": ("country", "region"),
|
| 1050 |
-
"chained_11": ("country", "region"),
|
| 1051 |
-
}
|
| 1052 |
df = tables.get(table_key, tables['containment_pairs'])
|
| 1053 |
-
|
| 1054 |
-
|
| 1055 |
-
|
| 1056 |
-
|
| 1057 |
-
|
| 1058 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1059 |
|
| 1060 |
anchor = sample_containment_anchor(df)
|
| 1061 |
if not anchor:
|
| 1062 |
return None
|
| 1063 |
|
| 1064 |
-
# Prefer the template-pinned target_subtype when set (e.g. chained_10
|
| 1065 |
-
# always wants 'region') so the SQL filter and question phrasing stay
|
| 1066 |
-
# in sync regardless of what the sampled pair happens to contain.
|
| 1067 |
target_subtype = template.target_subtype or anchor.get('contained_subtype', 'locality')
|
| 1068 |
|
| 1069 |
sql = template.sql_template.format(
|
|
@@ -1117,18 +1305,20 @@ def generate_template_based_sample(
|
|
| 1117 |
# Use cross_source_relations so the pair is guaranteed to intersect
|
| 1118 |
# (ST_Difference on non-intersecting geometries is always equal to
|
| 1119 |
# the original geometry — a trivial and uninformative sample).
|
| 1120 |
-
|
| 1121 |
-
|
|
|
|
|
|
|
|
|
|
| 1122 |
return None
|
| 1123 |
-
row = cs_df.sample(n=1).iloc[0]
|
| 1124 |
anchor = {
|
| 1125 |
-
'id':
|
| 1126 |
-
'name':
|
| 1127 |
'source': 'divisions_area',
|
| 1128 |
}
|
| 1129 |
clip_feature = {
|
| 1130 |
-
'id':
|
| 1131 |
-
'name':
|
| 1132 |
'source': 'natural_earth',
|
| 1133 |
}
|
| 1134 |
|
|
@@ -1153,20 +1343,19 @@ def generate_template_based_sample(
|
|
| 1153 |
candidates = _merge_candidate_lists(div_cands, ne_cands, max_total=10)
|
| 1154 |
|
| 1155 |
else:
|
| 1156 |
-
# Two divisions_area anchors
|
| 1157 |
-
#
|
|
|
|
|
|
|
| 1158 |
pair = sample_containment_anchor(tables['containment_pairs'])
|
| 1159 |
if not pair:
|
| 1160 |
return None
|
| 1161 |
|
| 1162 |
anchor1 = {'id': pair['container_id'], 'name': pair['container_name']}
|
| 1163 |
-
|
| 1164 |
-
if not anchor2_row:
|
| 1165 |
-
return None
|
| 1166 |
-
anchor2 = anchor2_row
|
| 1167 |
|
| 1168 |
sql = template.sql_template.format(
|
| 1169 |
-
|
| 1170 |
anchor_id_2=anchor2['id'],
|
| 1171 |
)
|
| 1172 |
|
|
@@ -1191,19 +1380,8 @@ def generate_template_based_sample(
|
|
| 1191 |
if not pair:
|
| 1192 |
return None
|
| 1193 |
|
| 1194 |
-
# The adjacency table only records one direction; sample a second
|
| 1195 |
-
# anchor that is known to be adjacent to the first.
|
| 1196 |
anchor1 = {'id': pair['anchor_id'], 'name': pair['anchor_name']}
|
| 1197 |
-
|
| 1198 |
-
# Find a random neighbour of anchor1 from adjacency pairs
|
| 1199 |
-
neighbours = tables['adjacency_pairs']
|
| 1200 |
-
neighbours = neighbours[neighbours['anchor_id'] == anchor1['id']]
|
| 1201 |
-
if neighbours.empty:
|
| 1202 |
-
return None
|
| 1203 |
-
nb_row = neighbours.sample(n=1).iloc[0]
|
| 1204 |
-
anchor2 = {'id': nb_row.get('target_id', nb_row['anchor_id']), 'name': nb_row.get('target_name', nb_row['anchor_name'])}
|
| 1205 |
-
if anchor1['id'] == anchor2['id']:
|
| 1206 |
-
return None
|
| 1207 |
|
| 1208 |
buffer_val = random.choice([5, 10, 25, 50])
|
| 1209 |
|
|
@@ -1230,12 +1408,17 @@ def generate_template_based_sample(
|
|
| 1230 |
)
|
| 1231 |
|
| 1232 |
elif template.family == "window_function":
|
| 1233 |
-
anchor = sample_random_entity(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1234 |
if not anchor:
|
| 1235 |
return None
|
| 1236 |
|
| 1237 |
country = anchor.get('country') or 'US'
|
| 1238 |
-
target_subtype =
|
| 1239 |
|
| 1240 |
sql = template.sql_template.format(
|
| 1241 |
country=country,
|
|
@@ -1253,12 +1436,17 @@ def generate_template_based_sample(
|
|
| 1253 |
)
|
| 1254 |
|
| 1255 |
elif template.family == "attribute_filter":
|
| 1256 |
-
anchor = sample_random_entity(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1257 |
if not anchor:
|
| 1258 |
return None
|
| 1259 |
|
| 1260 |
country = anchor.get('country') or 'US'
|
| 1261 |
-
target_subtype = template.target_subtype or
|
| 1262 |
|
| 1263 |
sql = template.sql_template.format(
|
| 1264 |
country=country,
|
|
|
|
| 44 |
return (
|
| 45 |
sql
|
| 46 |
.replace("read_parquet('divisions_area')", f"read_parquet('{DIVISIONS_AREA_PATH}')")
|
| 47 |
+
.replace("read_parquet('natural_earth')", f"read_parquet('{NATURAL_EARTH_PATH}')")
|
| 48 |
)
|
| 49 |
|
| 50 |
# Configurable parameters (can be overridden by CLI)
|
|
|
|
| 60 |
get_templates_by_family = sql_templates.get_templates_by_family
|
| 61 |
|
| 62 |
|
| 63 |
+
_NE_NAMED_LOOKUP_SUBTYPES = {
|
| 64 |
+
'sea', 'ocean', 'Lake', 'River', 'Basin', 'gulf', 'bay',
|
| 65 |
+
'Island group', 'Peninsula', 'strait', 'Range/mtn', 'Depression',
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
_NE_TEMPLATE_SUBTYPES = {
|
| 69 |
+
'lookup_02': {'sea', 'ocean', 'Lake', 'River', 'Basin', 'gulf', 'bay', 'Island group', 'Peninsula', 'strait', 'Range/mtn', 'Depression'},
|
| 70 |
+
'adj_03': {'sea', 'ocean'},
|
| 71 |
+
'adj_04': {'River', 'Lake', 'Basin'},
|
| 72 |
+
'adj_05': {'Range/mtn', 'Peninsula', 'Depression'},
|
| 73 |
+
'contain_03': {'sea', 'ocean', 'gulf', 'bay', 'Basin', 'Island group', 'Peninsula', 'Range/mtn', 'Depression'},
|
| 74 |
+
'contain_04': {'sea', 'ocean', 'gulf', 'bay', 'strait'},
|
| 75 |
+
'intersect_02': {'River', 'Lake', 'Basin', 'gulf', 'bay', 'strait', 'Range/mtn', 'Peninsula', 'Depression'},
|
| 76 |
+
'intersect_03': {'River', 'Lake', 'Basin', 'gulf', 'bay', 'strait', 'Range/mtn', 'Peninsula', 'Depression'},
|
| 77 |
+
'buffer_03': {'sea', 'ocean', 'Lake', 'River', 'Basin', 'gulf', 'bay', 'Island group', 'Peninsula', 'strait', 'Range/mtn', 'Depression'},
|
| 78 |
+
'buffer_04': {'sea', 'ocean', 'Lake', 'River', 'Basin', 'gulf', 'bay', 'Island group', 'Peninsula', 'strait', 'Range/mtn', 'Depression'},
|
| 79 |
+
'buffer_05': {'sea', 'ocean', 'Lake', 'River', 'Basin', 'gulf', 'bay', 'Island group', 'Peninsula', 'strait', 'Range/mtn', 'Depression'},
|
| 80 |
+
'chained_03': {'Island group', 'Peninsula', 'Range/mtn', 'Depression'},
|
| 81 |
+
'chained_04': {'River', 'Lake', 'Basin'},
|
| 82 |
+
'chained_05': {'Range/mtn', 'Depression'},
|
| 83 |
+
'chained_08': {'River', 'Lake', 'Basin'},
|
| 84 |
+
'chained_09': {'Range/mtn', 'Depression'},
|
| 85 |
+
'partial_05': {'sea', 'ocean', 'Lake', 'River', 'Basin', 'gulf', 'bay', 'Island group', 'Peninsula', 'strait', 'Range/mtn', 'Depression'},
|
| 86 |
+
'diff_02': {'sea', 'ocean', 'Lake', 'River', 'Basin', 'gulf', 'bay', 'Island group', 'Peninsula', 'strait', 'Range/mtn', 'Depression'},
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
|
| 90 |
class Candidate(BaseModel):
|
| 91 |
"""Candidate entity for grounding."""
|
| 92 |
candidate_id: str
|
|
|
|
| 148 |
'anchor_name': row['anchor_name'],
|
| 149 |
'anchor_subtype': row['anchor_subtype'],
|
| 150 |
'anchor_country': row.get('anchor_country'), # May not exist in all tables
|
| 151 |
+
'target_id': row.get('target_id'),
|
| 152 |
+
'target_name': row.get('target_name'),
|
| 153 |
'target_subtype': row.get('target_subtype')
|
| 154 |
}
|
| 155 |
|
|
|
|
| 171 |
|
| 172 |
|
| 173 |
def sample_containment_anchor(containment_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
|
| 174 |
+
"""Sample a random containment pair.
|
| 175 |
+
|
| 176 |
+
Returns both ends of the pair so callers that need the contained entity
|
| 177 |
+
(e.g. difference templates that clip container by contained) can use it
|
| 178 |
+
directly without a second random draw.
|
| 179 |
+
"""
|
| 180 |
if containment_df.empty:
|
| 181 |
return None
|
| 182 |
+
|
| 183 |
row = containment_df.sample(n=1).iloc[0]
|
| 184 |
return {
|
| 185 |
'container_id': row['container_id'],
|
| 186 |
'container_name': row['container_name'],
|
| 187 |
'container_subtype': row['container_subtype'],
|
| 188 |
+
'contained_id': row['contained_id'],
|
| 189 |
+
'contained_name': row['contained_name'],
|
| 190 |
+
'contained_subtype': row['contained_subtype'],
|
| 191 |
}
|
| 192 |
|
| 193 |
|
|
|
|
| 222 |
}
|
| 223 |
|
| 224 |
|
| 225 |
+
def sample_cross_source_anchor(
|
| 226 |
+
cross_source_df: pd.DataFrame,
|
| 227 |
+
natural_subtypes: Optional[set[str]] = None,
|
| 228 |
+
relation_types: Optional[set[str]] = None,
|
| 229 |
+
) -> Optional[Dict[str, Any]]:
|
| 230 |
+
"""Sample a random cross-source relation with optional subtype filters."""
|
| 231 |
if cross_source_df.empty:
|
| 232 |
return None
|
| 233 |
+
|
| 234 |
+
df = cross_source_df
|
| 235 |
+
if natural_subtypes is not None:
|
| 236 |
+
df = df[df['natural_subtype'].isin(natural_subtypes)]
|
| 237 |
+
if relation_types is not None:
|
| 238 |
+
df = df[df['relation_type'].isin(relation_types)]
|
| 239 |
+
if df.empty:
|
| 240 |
+
return None
|
| 241 |
+
|
| 242 |
+
row = df.sample(n=1).iloc[0]
|
| 243 |
return {
|
| 244 |
'division_id': row['division_id'],
|
| 245 |
'division_name': row['division_name'],
|
|
|
|
| 290 |
difficulty: str = "medium"
|
| 291 |
) -> List[Candidate]:
|
| 292 |
"""Build candidate list with true anchor + distractors."""
|
| 293 |
+
|
| 294 |
# Helper to convert pandas NA to None
|
| 295 |
def safe_get(row, key, default=None):
|
| 296 |
val = row.get(key, default)
|
| 297 |
return None if pd.isna(val) else val
|
| 298 |
+
|
| 299 |
# Get the true anchor
|
| 300 |
if anchor_source == "divisions_area":
|
| 301 |
query = """
|
| 302 |
+
SELECT
|
| 303 |
id,
|
| 304 |
names."primary" AS name,
|
| 305 |
subtype,
|
|
|
|
| 312 |
anchor_row = con.execute(query, [DIVISIONS_AREA_PATH, anchor_id]).fetchdf().iloc[0]
|
| 313 |
else:
|
| 314 |
query = """
|
| 315 |
+
SELECT
|
| 316 |
id,
|
| 317 |
names."primary" AS name,
|
| 318 |
subtype
|
|
|
|
| 320 |
WHERE id = ?
|
| 321 |
"""
|
| 322 |
anchor_row = con.execute(query, [NATURAL_EARTH_PATH, anchor_id]).fetchdf().iloc[0]
|
| 323 |
+
|
|
|
|
| 324 |
true_candidate = Candidate(
|
| 325 |
candidate_id="c1",
|
| 326 |
source=anchor_source,
|
|
|
|
| 330 |
country=safe_get(anchor_row, 'country'),
|
| 331 |
region=safe_get(anchor_row, 'region'),
|
| 332 |
admin_level=safe_get(anchor_row, 'admin_level'),
|
| 333 |
+
similarity=1.0,
|
| 334 |
)
|
| 335 |
+
|
|
|
|
| 336 |
distractors = build_distractors(
|
| 337 |
+
con,
|
| 338 |
+
anchor_name,
|
| 339 |
anchor_source,
|
| 340 |
anchor_id,
|
| 341 |
num_candidates - 1,
|
| 342 |
+
difficulty,
|
| 343 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
|
| 345 |
+
# Deduplicate by underlying entity id while preserving order.
|
| 346 |
+
# Some parquet sources contain repeated rows for the same feature id,
|
| 347 |
+
# which can otherwise leak duplicate candidates into the dataset.
|
| 348 |
+
candidates: List[Candidate] = []
|
| 349 |
+
seen_ids: set[str] = set()
|
| 350 |
+
for cand in [true_candidate] + distractors:
|
| 351 |
+
if cand.id in seen_ids:
|
| 352 |
+
continue
|
| 353 |
+
candidates.append(cand)
|
| 354 |
+
seen_ids.add(cand.id)
|
| 355 |
+
if len(candidates) >= num_candidates:
|
| 356 |
+
break
|
| 357 |
+
|
| 358 |
for i, cand in enumerate(candidates, 1):
|
| 359 |
cand.candidate_id = f"c{i}"
|
| 360 |
|
|
|
|
| 388 |
|
| 389 |
def _query_source(path: str, src_name: str, n: int, excl_id: str) -> List[Candidate]:
|
| 390 |
query = """
|
| 391 |
+
WITH ranked AS (
|
| 392 |
+
SELECT
|
| 393 |
+
id,
|
| 394 |
+
names."primary" AS name,
|
| 395 |
+
subtype,
|
| 396 |
+
country,
|
| 397 |
+
region,
|
| 398 |
+
admin_level,
|
| 399 |
+
jaro_winkler_similarity(lower(names."primary"), lower(?)) AS similarity,
|
| 400 |
+
ROW_NUMBER() OVER (
|
| 401 |
+
PARTITION BY id
|
| 402 |
+
ORDER BY jaro_winkler_similarity(lower(names."primary"), lower(?)) DESC
|
| 403 |
+
) AS rn
|
| 404 |
+
FROM read_parquet(?)
|
| 405 |
+
WHERE id != ?
|
| 406 |
+
AND names."primary" IS NOT NULL
|
| 407 |
+
AND trim(names."primary") != ''
|
| 408 |
+
AND geometry IS NOT NULL
|
| 409 |
+
)
|
| 410 |
SELECT
|
| 411 |
id,
|
| 412 |
+
name,
|
| 413 |
subtype,
|
| 414 |
country,
|
| 415 |
region,
|
| 416 |
admin_level,
|
| 417 |
+
similarity
|
| 418 |
+
FROM ranked
|
| 419 |
+
WHERE rn = 1
|
|
|
|
| 420 |
ORDER BY similarity DESC
|
| 421 |
LIMIT ?
|
| 422 |
"""
|
| 423 |
+
df = con.execute(query, [anchor_name, anchor_name, path, excl_id, n]).fetchdf()
|
| 424 |
results = []
|
| 425 |
for _, row in df.iterrows():
|
| 426 |
results.append(Candidate(
|
|
|
|
| 586 |
def sample_random_entity(
|
| 587 |
con: duckdb.DuckDBPyConnection,
|
| 588 |
inventory_df: pd.DataFrame,
|
| 589 |
+
source: str,
|
| 590 |
+
subtypes: Optional[set[str]] = None,
|
| 591 |
+
countries: Optional[set[str]] = None,
|
| 592 |
) -> Optional[Dict[str, Any]]:
|
| 593 |
+
"""Sample a random entity from inventory with optional filters."""
|
| 594 |
if inventory_df.empty:
|
| 595 |
return None
|
| 596 |
+
|
| 597 |
+
df = inventory_df
|
| 598 |
+
if subtypes is not None:
|
| 599 |
+
df = df[df['subtype'].isin(subtypes)]
|
| 600 |
+
if countries is not None and 'country' in df.columns:
|
| 601 |
+
df = df[df['country'].isin(countries)]
|
| 602 |
+
if df.empty:
|
| 603 |
+
return None
|
| 604 |
+
|
| 605 |
+
row = df.sample(n=1).iloc[0]
|
| 606 |
return {
|
| 607 |
'id': row['id'],
|
| 608 |
'name': row['name'],
|
|
|
|
| 626 |
if template.anchor_source == "divisions_area":
|
| 627 |
anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
|
| 628 |
else:
|
| 629 |
+
anchor = sample_random_entity(
|
| 630 |
+
con,
|
| 631 |
+
tables['natural_earth_inventory'],
|
| 632 |
+
'natural_earth',
|
| 633 |
+
subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id, _NE_NAMED_LOOKUP_SUBTYPES),
|
| 634 |
+
)
|
| 635 |
|
| 636 |
if not anchor:
|
| 637 |
return None
|
|
|
|
| 722 |
anchor = {"id": pair["contained_id"], "name": pair["contained_name"]}
|
| 723 |
|
| 724 |
elif template.family == "adjacency":
|
| 725 |
+
# adj_03/04/05 target natural_earth features (seas, rivers, ranges).
|
| 726 |
+
# Their SQL hardcodes NE subtypes and does not use {target_subtype}.
|
| 727 |
+
# Sample from cross_source_relations so the anchor is a division
|
| 728 |
+
# that actually intersects the right NE features.
|
| 729 |
+
_NE_ADJ_SUBTYPES = {
|
| 730 |
+
"adj_03": ("ocean", "sea"),
|
| 731 |
+
"adj_04": ("River", "Lake", "Basin"),
|
| 732 |
+
"adj_05": ("Range/mtn", "Peninsula", "Depression"),
|
| 733 |
+
}
|
| 734 |
+
if template.template_id in _NE_ADJ_SUBTYPES:
|
| 735 |
+
cs_df = tables.get('cross_source_relations', pd.DataFrame())
|
| 736 |
+
if cs_df.empty:
|
| 737 |
+
return None
|
| 738 |
+
ne_types = _NE_ADJ_SUBTYPES[template.template_id]
|
| 739 |
+
filtered = cs_df[cs_df['natural_subtype'].isin(ne_types)]
|
| 740 |
+
if filtered.empty:
|
| 741 |
+
return None
|
| 742 |
+
row = filtered.sample(n=1).iloc[0]
|
| 743 |
+
anchor = {
|
| 744 |
+
'anchor_id': row['division_id'],
|
| 745 |
+
'anchor_name': row['division_name'],
|
| 746 |
+
'anchor_subtype': row['division_subtype'],
|
| 747 |
+
'target_subtype': row['natural_subtype'],
|
| 748 |
+
}
|
| 749 |
+
else:
|
| 750 |
+
# adj_01/02/06: divisions_area self-join adjacency.
|
| 751 |
+
# Only filter by target_subtype when the SQL uses {target_subtype}.
|
| 752 |
+
filter_subtype = (
|
| 753 |
+
template.target_subtype
|
| 754 |
+
if '{target_subtype}' in template.sql_template
|
| 755 |
+
else None
|
| 756 |
+
)
|
| 757 |
+
anchor = sample_adjacency_anchor(
|
| 758 |
+
tables['adjacency_pairs'],
|
| 759 |
+
target_subtype=filter_subtype,
|
| 760 |
+
)
|
| 761 |
if not anchor:
|
| 762 |
return None
|
| 763 |
|
| 764 |
sql = template.sql_template.format(
|
| 765 |
anchor_id=anchor['anchor_id'],
|
| 766 |
+
target_subtype=anchor.get('target_subtype', ''),
|
| 767 |
)
|
| 768 |
+
|
| 769 |
candidates = build_candidate_list(
|
| 770 |
con, anchor['anchor_id'], anchor['anchor_name'], 'divisions_area',
|
| 771 |
num_candidates=10, difficulty="medium"
|
| 772 |
)
|
| 773 |
+
|
| 774 |
question = random.choice(template.question_hints).format(
|
| 775 |
anchor_name=anchor['anchor_name'],
|
| 776 |
+
target_subtype=anchor.get('target_subtype', ''),
|
| 777 |
)
|
| 778 |
|
| 779 |
elif template.family == "containment":
|
| 780 |
+
if template.anchor_source == "natural_earth":
|
| 781 |
+
# contain_03 / contain_04: NE anchor (sea, desert, etc.).
|
| 782 |
+
# Use cross_source_relations so the anchor exists in natural_earth
|
| 783 |
+
# and is guaranteed to intersect divisions_area features.
|
| 784 |
+
cs_anchor = sample_cross_source_anchor(
|
| 785 |
+
tables.get('cross_source_relations', pd.DataFrame()),
|
| 786 |
+
natural_subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id),
|
| 787 |
+
)
|
| 788 |
+
if not cs_anchor:
|
| 789 |
+
return None
|
| 790 |
+
anchor_id = cs_anchor['natural_id']
|
| 791 |
+
anchor_name = cs_anchor['natural_name']
|
| 792 |
+
target_subtype = template.target_subtype or 'country'
|
| 793 |
+
|
| 794 |
+
sql = template.sql_template.format(
|
| 795 |
+
anchor_id=anchor_id,
|
| 796 |
+
target_subtype=target_subtype,
|
| 797 |
+
)
|
| 798 |
+
candidates = build_candidate_list(
|
| 799 |
+
con, anchor_id, anchor_name, 'natural_earth',
|
| 800 |
+
num_candidates=10, difficulty="medium"
|
| 801 |
+
)
|
| 802 |
+
question = random.choice(template.question_hints).format(
|
| 803 |
+
anchor_name=anchor_name,
|
| 804 |
+
target_subtype=target_subtype,
|
| 805 |
+
)
|
| 806 |
+
anchor = {'id': anchor_id, 'name': anchor_name}
|
| 807 |
+
|
| 808 |
+
elif template.template_id == "contain_02":
|
| 809 |
+
# "What country contains X?" - anchor is the CONTAINED entity;
|
| 810 |
+
# result is the country that ST_Contains it.
|
| 811 |
+
df = tables['containment_pairs']
|
| 812 |
+
df = df[df['container_subtype'] == 'country']
|
| 813 |
+
pair = sample_containment_anchor(df)
|
| 814 |
+
if not pair:
|
| 815 |
+
return None
|
| 816 |
+
|
| 817 |
+
sql = template.sql_template.format(
|
| 818 |
+
anchor_id=pair['contained_id'],
|
| 819 |
+
target_subtype='country',
|
| 820 |
+
)
|
| 821 |
+
candidates = build_candidate_list(
|
| 822 |
+
con, pair['contained_id'], pair['contained_name'], 'divisions_area',
|
| 823 |
+
num_candidates=10, difficulty="medium"
|
| 824 |
+
)
|
| 825 |
+
question = random.choice(template.question_hints).format(
|
| 826 |
+
anchor_name=pair['contained_name'],
|
| 827 |
+
target_subtype='country',
|
| 828 |
+
)
|
| 829 |
+
anchor = {'id': pair['contained_id'], 'name': pair['contained_name']}
|
| 830 |
+
|
| 831 |
+
else:
|
| 832 |
+
# contain_01: standard containment.
|
| 833 |
+
# Anchor = container, target_subtype = contained entity's subtype.
|
| 834 |
+
anchor = sample_containment_anchor(tables['containment_pairs'])
|
| 835 |
+
if not anchor:
|
| 836 |
+
return None
|
| 837 |
+
|
| 838 |
+
sql = template.sql_template.format(
|
| 839 |
+
anchor_id=anchor['container_id'],
|
| 840 |
+
target_subtype=anchor['contained_subtype'],
|
| 841 |
+
)
|
| 842 |
+
candidates = build_candidate_list(
|
| 843 |
+
con, anchor['container_id'], anchor['container_name'], 'divisions_area',
|
| 844 |
+
num_candidates=10, difficulty="medium"
|
| 845 |
+
)
|
| 846 |
+
question = random.choice(template.question_hints).format(
|
| 847 |
+
anchor_name=anchor['container_name'],
|
| 848 |
+
target_subtype=anchor['contained_subtype'],
|
| 849 |
+
)
|
| 850 |
|
| 851 |
elif template.family == "intersection":
|
| 852 |
if template.anchor_source == "natural_earth":
|
| 853 |
+
anchor = sample_cross_source_anchor(
|
| 854 |
+
tables['cross_source_relations'],
|
| 855 |
+
natural_subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id),
|
| 856 |
+
)
|
| 857 |
if not anchor:
|
| 858 |
return None
|
| 859 |
+
|
| 860 |
+
target_subtype = template.target_subtype or 'country'
|
| 861 |
+
|
| 862 |
sql = template.sql_template.format(
|
| 863 |
anchor_id=anchor['natural_id'],
|
| 864 |
+
target_subtype=target_subtype,
|
| 865 |
)
|
| 866 |
+
|
| 867 |
candidates = build_candidate_list(
|
| 868 |
con, anchor['natural_id'], anchor['natural_name'], 'natural_earth',
|
| 869 |
num_candidates=10, difficulty="medium"
|
| 870 |
)
|
| 871 |
+
|
| 872 |
question = random.choice(template.question_hints).format(
|
| 873 |
anchor_name=anchor['natural_name'],
|
| 874 |
+
target_subtype=target_subtype,
|
| 875 |
)
|
| 876 |
else:
|
| 877 |
# Same-source intersection
|
|
|
|
| 932 |
# country IN clause — 2 or 3 anchors, each contributes its country code
|
| 933 |
num_a = 3 if template.template_id == "contain_multi_02" else 2
|
| 934 |
anchors = [
|
| 935 |
+
sample_random_entity(
|
| 936 |
+
con,
|
| 937 |
+
tables['divisions_area_inventory'],
|
| 938 |
+
'divisions_area',
|
| 939 |
+
subtypes={'country'},
|
| 940 |
+
)
|
| 941 |
for _ in range(num_a)
|
| 942 |
]
|
| 943 |
if any(a is None for a in anchors):
|
| 944 |
return None
|
| 945 |
|
| 946 |
countries = [a.get('country') or 'US' for a in anchors]
|
| 947 |
+
target_subtype = template.target_subtype or 'region'
|
| 948 |
per_anchor = 3 if num_a == 3 else 4
|
| 949 |
|
| 950 |
fmt_kwargs = dict(
|
|
|
|
| 1027 |
|
| 1028 |
if template.num_anchors == 1:
|
| 1029 |
if template.anchor_source == "natural_earth":
|
| 1030 |
+
anchor = sample_random_entity(
|
| 1031 |
+
con,
|
| 1032 |
+
tables['natural_earth_inventory'],
|
| 1033 |
+
'natural_earth',
|
| 1034 |
+
subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id, _NE_NAMED_LOOKUP_SUBTYPES),
|
| 1035 |
+
)
|
| 1036 |
else:
|
| 1037 |
anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
|
| 1038 |
if not anchor:
|
|
|
|
| 1126 |
# Mixed-source clip: division intersected with a natural_earth feature.
|
| 1127 |
# Use cross_source_relations so the pair is guaranteed to intersect —
|
| 1128 |
# random sampling almost never produces an intersecting pair.
|
| 1129 |
+
cs_anchor = sample_cross_source_anchor(
|
| 1130 |
+
tables.get('cross_source_relations', pd.DataFrame()),
|
| 1131 |
+
natural_subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id),
|
| 1132 |
+
)
|
| 1133 |
+
if not cs_anchor:
|
| 1134 |
return None
|
|
|
|
| 1135 |
clip_feature = {
|
| 1136 |
+
'id': cs_anchor['natural_id'],
|
| 1137 |
+
'name': cs_anchor['natural_name'],
|
| 1138 |
'source': 'natural_earth',
|
| 1139 |
}
|
| 1140 |
# Override the division anchor with the paired division so the
|
| 1141 |
# ST_Intersects check in the SQL is guaranteed to pass.
|
| 1142 |
anchor = {
|
| 1143 |
+
'id': cs_anchor['division_id'],
|
| 1144 |
+
'name': cs_anchor['division_name'],
|
| 1145 |
'source': 'divisions_area',
|
| 1146 |
}
|
| 1147 |
|
|
|
|
| 1170 |
target_subtype = random.choice(['locality', 'region'])
|
| 1171 |
|
| 1172 |
if template.template_id in ['agg_03', 'agg_04']:
|
| 1173 |
+
# Country-level aggregation: SQL uses country code, so the anchor
|
| 1174 |
+
# in the question must also be a country.
|
| 1175 |
+
anchor = sample_random_entity(
|
| 1176 |
+
con,
|
| 1177 |
+
tables['divisions_area_inventory'],
|
| 1178 |
+
'divisions_area',
|
| 1179 |
+
subtypes={'country'},
|
| 1180 |
+
)
|
| 1181 |
if not anchor:
|
| 1182 |
return None
|
| 1183 |
|
|
|
|
| 1225 |
elif template.family == "chained":
|
| 1226 |
# Use pre-filtered coastal/landlocked containment pairs so the SQL
|
| 1227 |
# verification step doesn't constantly return empty results.
|
| 1228 |
+
_COASTAL_CHAINED = {"chained_01", "chained_06", "chained_10"}
|
| 1229 |
+
_LANDLOCKED_CHAINED = {"chained_02", "chained_07", "chained_11"}
|
| 1230 |
+
if template.template_id in _COASTAL_CHAINED:
|
| 1231 |
table_key = 'coastal_containment_pairs'
|
| 1232 |
+
elif template.template_id in _LANDLOCKED_CHAINED:
|
| 1233 |
table_key = 'landlocked_containment_pairs'
|
| 1234 |
else:
|
| 1235 |
table_key = 'containment_pairs'
|
| 1236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1237 |
df = tables.get(table_key, tables['containment_pairs'])
|
| 1238 |
+
|
| 1239 |
+
# When the template pins a target_subtype (e.g. chained_06 wants
|
| 1240 |
+
# counties), only consider pairs whose contained entity already
|
| 1241 |
+
# matches — guarantees the sampled container holds at least one
|
| 1242 |
+
# entity of the right subtype so the SQL filter returns rows.
|
| 1243 |
+
if template.target_subtype:
|
| 1244 |
+
df = df[df['contained_subtype'] == template.target_subtype]
|
| 1245 |
+
|
| 1246 |
+
# chained_10/11 additionally need a country-level container so
|
| 1247 |
+
# phrasings like "coastal states of India" line up.
|
| 1248 |
+
if template.template_id in {"chained_10", "chained_11"}:
|
| 1249 |
+
df = df[df['container_subtype'] == 'country']
|
| 1250 |
|
| 1251 |
anchor = sample_containment_anchor(df)
|
| 1252 |
if not anchor:
|
| 1253 |
return None
|
| 1254 |
|
|
|
|
|
|
|
|
|
|
| 1255 |
target_subtype = template.target_subtype or anchor.get('contained_subtype', 'locality')
|
| 1256 |
|
| 1257 |
sql = template.sql_template.format(
|
|
|
|
| 1305 |
# Use cross_source_relations so the pair is guaranteed to intersect
|
| 1306 |
# (ST_Difference on non-intersecting geometries is always equal to
|
| 1307 |
# the original geometry — a trivial and uninformative sample).
|
| 1308 |
+
cs_anchor = sample_cross_source_anchor(
|
| 1309 |
+
tables.get('cross_source_relations', pd.DataFrame()),
|
| 1310 |
+
natural_subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id),
|
| 1311 |
+
)
|
| 1312 |
+
if not cs_anchor:
|
| 1313 |
return None
|
|
|
|
| 1314 |
anchor = {
|
| 1315 |
+
'id': cs_anchor['division_id'],
|
| 1316 |
+
'name': cs_anchor['division_name'],
|
| 1317 |
'source': 'divisions_area',
|
| 1318 |
}
|
| 1319 |
clip_feature = {
|
| 1320 |
+
'id': cs_anchor['natural_id'],
|
| 1321 |
+
'name': cs_anchor['natural_name'],
|
| 1322 |
'source': 'natural_earth',
|
| 1323 |
}
|
| 1324 |
|
|
|
|
| 1343 |
candidates = _merge_candidate_lists(div_cands, ne_cands, max_total=10)
|
| 1344 |
|
| 1345 |
else:
|
| 1346 |
+
# Two divisions_area anchors: use both ends of a containment
|
| 1347 |
+
# pair so the contained entity is guaranteed to intersect the
|
| 1348 |
+
# container. ST_Difference(container, contained) yields the
|
| 1349 |
+
# portion of the container outside the contained piece.
|
| 1350 |
pair = sample_containment_anchor(tables['containment_pairs'])
|
| 1351 |
if not pair:
|
| 1352 |
return None
|
| 1353 |
|
| 1354 |
anchor1 = {'id': pair['container_id'], 'name': pair['container_name']}
|
| 1355 |
+
anchor2 = {'id': pair['contained_id'], 'name': pair['contained_name']}
|
|
|
|
|
|
|
|
|
|
| 1356 |
|
| 1357 |
sql = template.sql_template.format(
|
| 1358 |
+
anchor_id_1=anchor1['id'],
|
| 1359 |
anchor_id_2=anchor2['id'],
|
| 1360 |
)
|
| 1361 |
|
|
|
|
| 1380 |
if not pair:
|
| 1381 |
return None
|
| 1382 |
|
|
|
|
|
|
|
| 1383 |
anchor1 = {'id': pair['anchor_id'], 'name': pair['anchor_name']}
|
| 1384 |
+
anchor2 = {'id': pair['target_id'], 'name': pair['target_name']}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1385 |
|
| 1386 |
buffer_val = random.choice([5, 10, 25, 50])
|
| 1387 |
|
|
|
|
| 1408 |
)
|
| 1409 |
|
| 1410 |
elif template.family == "window_function":
|
| 1411 |
+
anchor = sample_random_entity(
|
| 1412 |
+
con,
|
| 1413 |
+
tables['divisions_area_inventory'],
|
| 1414 |
+
'divisions_area',
|
| 1415 |
+
subtypes={'country'},
|
| 1416 |
+
)
|
| 1417 |
if not anchor:
|
| 1418 |
return None
|
| 1419 |
|
| 1420 |
country = anchor.get('country') or 'US'
|
| 1421 |
+
target_subtype = template.target_subtype or 'locality'
|
| 1422 |
|
| 1423 |
sql = template.sql_template.format(
|
| 1424 |
country=country,
|
|
|
|
| 1436 |
)
|
| 1437 |
|
| 1438 |
elif template.family == "attribute_filter":
|
| 1439 |
+
anchor = sample_random_entity(
|
| 1440 |
+
con,
|
| 1441 |
+
tables['divisions_area_inventory'],
|
| 1442 |
+
'divisions_area',
|
| 1443 |
+
subtypes={'country'},
|
| 1444 |
+
)
|
| 1445 |
if not anchor:
|
| 1446 |
return None
|
| 1447 |
|
| 1448 |
country = anchor.get('country') or 'US'
|
| 1449 |
+
target_subtype = template.target_subtype or 'region'
|
| 1450 |
|
| 1451 |
sql = template.sql_template.format(
|
| 1452 |
country=country,
|
dataset/scripts/normalize_geodata.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Normalize source GeoParquet files to a shared CRS-neutral geometry encoding.
|
| 2 |
+
|
| 3 |
+
The training pipeline mixes Overture divisions_area and Natural Earth geometry.
|
| 4 |
+
Across environments these sources can advertise different CRS metadata labels
|
| 5 |
+
(`EPSG:4326` vs `OGC:CRS84`), which causes DuckDB spatial joins to fail even
|
| 6 |
+
when coordinates are already compatible lon/lat values.
|
| 7 |
+
|
| 8 |
+
This script rewrites both datasets into normalized copies whose geometry column
|
| 9 |
+
is rebuilt from WKB. That preserves coordinates while dropping conflicting CRS
|
| 10 |
+
metadata, so downstream joins behave consistently locally and on Modal.
|
| 11 |
+
|
| 12 |
+
Output layout under data/ by default:
|
| 13 |
+
overture_normalized/divisions_area/part-000.parquet
|
| 14 |
+
natural_earth_normalized/ne_geography.parquet
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import duckdb
|
| 20 |
+
|
| 21 |
+
from gazet.config import _DATA_DIR
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def normalize_geodata(output_root: Path | None = None) -> dict[str, str]:
|
| 25 |
+
"""Write normalized copies of both source datasets.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
output_root: Base directory to write normalized datasets into.
|
| 29 |
+
Defaults to the project data dir.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
Mapping of dataset name to written path/glob.
|
| 33 |
+
"""
|
| 34 |
+
root = output_root or _DATA_DIR
|
| 35 |
+
overture_dir = root / "overture_normalized" / "divisions_area"
|
| 36 |
+
natural_earth_dir = root / "natural_earth_normalized"
|
| 37 |
+
overture_dir.mkdir(parents=True, exist_ok=True)
|
| 38 |
+
natural_earth_dir.mkdir(parents=True, exist_ok=True)
|
| 39 |
+
|
| 40 |
+
overture_path = overture_dir / "part-000.parquet"
|
| 41 |
+
natural_earth_path = natural_earth_dir / "ne_geography.parquet"
|
| 42 |
+
|
| 43 |
+
con = duckdb.connect()
|
| 44 |
+
con.execute("INSTALL spatial")
|
| 45 |
+
con.execute("LOAD spatial")
|
| 46 |
+
|
| 47 |
+
# Rebuild geometry from WKB so conflicting CRS metadata is dropped.
|
| 48 |
+
con.execute(
|
| 49 |
+
f"""
|
| 50 |
+
COPY (
|
| 51 |
+
SELECT * REPLACE (
|
| 52 |
+
ST_GeomFromWKB(ST_AsWKB(geometry)) AS geometry
|
| 53 |
+
)
|
| 54 |
+
FROM read_parquet('{root / 'overture/divisions_area/*.parquet'}')
|
| 55 |
+
WHERE geometry IS NOT NULL
|
| 56 |
+
) TO '{overture_path}' (FORMAT PARQUET)
|
| 57 |
+
"""
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
con.execute(
|
| 61 |
+
f"""
|
| 62 |
+
COPY (
|
| 63 |
+
SELECT * REPLACE (
|
| 64 |
+
ST_GeomFromWKB(ST_AsWKB(geometry)) AS geometry
|
| 65 |
+
)
|
| 66 |
+
FROM read_parquet('{root / 'natural_earth_geoparquet/ne_geography.parquet'}')
|
| 67 |
+
WHERE geometry IS NOT NULL
|
| 68 |
+
) TO '{natural_earth_path}' (FORMAT PARQUET)
|
| 69 |
+
"""
|
| 70 |
+
)
|
| 71 |
+
con.close()
|
| 72 |
+
|
| 73 |
+
return {
|
| 74 |
+
"divisions_area": str(overture_dir / "*.parquet"),
|
| 75 |
+
"natural_earth": str(natural_earth_path),
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def main() -> None:
|
| 80 |
+
result = normalize_geodata()
|
| 81 |
+
print("Normalized datasets written:")
|
| 82 |
+
for name, path in result.items():
|
| 83 |
+
print(f" {name}: {path}")
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
if __name__ == "__main__":
|
| 87 |
+
main()
|
dataset/scripts/sql_templates.py
CHANGED
|
@@ -293,7 +293,7 @@ TEMPLATES = [
|
|
| 293 |
" ST_AsGeoJSON(n.geometry) AS geometry"
|
| 294 |
" FROM read_parquet('natural_earth') AS n, a"
|
| 295 |
" WHERE n.subtype IN ('ocean', 'sea')"
|
| 296 |
-
" AND
|
| 297 |
),
|
| 298 |
question_hints=[
|
| 299 |
"which seas touch {anchor_name}?",
|
|
@@ -679,7 +679,7 @@ TEMPLATES = [
|
|
| 679 |
sql_difficulty="hard",
|
| 680 |
anchor_source="divisions_area",
|
| 681 |
num_anchors=1,
|
| 682 |
-
target_subtype="
|
| 683 |
sql_template=(
|
| 684 |
"WITH region AS ("
|
| 685 |
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
|
@@ -723,7 +723,7 @@ TEMPLATES = [
|
|
| 723 |
" AND ST_Within(b.geometry, region.geometry)"
|
| 724 |
" AND EXISTS ("
|
| 725 |
" SELECT 1 FROM read_parquet('natural_earth') AS n"
|
| 726 |
-
" WHERE n.subtype IN ('
|
| 727 |
" AND ST_Intersects(b.geometry, n.geometry)"
|
| 728 |
" )"
|
| 729 |
),
|
|
@@ -1301,12 +1301,12 @@ TEMPLATES = [
|
|
| 1301 |
" AND subtype = '{target_subtype}'"
|
| 1302 |
),
|
| 1303 |
question_hints=[
|
| 1304 |
-
"
|
| 1305 |
-
"
|
| 1306 |
-
"which
|
| 1307 |
-
"land
|
| 1308 |
-
"
|
| 1309 |
-
"
|
| 1310 |
],
|
| 1311 |
),
|
| 1312 |
|
|
@@ -1339,20 +1339,20 @@ TEMPLATES = [
|
|
| 1339 |
sql_difficulty="medium",
|
| 1340 |
anchor_source="divisions_area",
|
| 1341 |
num_anchors=1,
|
| 1342 |
-
target_subtype="
|
| 1343 |
sql_template=(
|
| 1344 |
"SELECT id, names.\"primary\" AS name, subtype, country,"
|
| 1345 |
" ST_AsGeoJSON(geometry) AS geometry"
|
| 1346 |
" FROM read_parquet('divisions_area')"
|
| 1347 |
" WHERE country = '{country}'"
|
| 1348 |
" AND subtype = '{target_subtype}'"
|
| 1349 |
-
" AND is_land =
|
| 1350 |
),
|
| 1351 |
question_hints=[
|
| 1352 |
-
"
|
| 1353 |
-
"{target_subtype}s
|
| 1354 |
-
"
|
| 1355 |
-
"
|
| 1356 |
],
|
| 1357 |
),
|
| 1358 |
|
|
@@ -1403,7 +1403,7 @@ TEMPLATES = [
|
|
| 1403 |
" SELECT n.id, n.names.\"primary\" AS name, n.subtype,"
|
| 1404 |
" ST_AsGeoJSON(n.geometry) AS geometry"
|
| 1405 |
" FROM read_parquet('natural_earth') AS n, a"
|
| 1406 |
-
" WHERE n.subtype IN ('Range/
|
| 1407 |
" AND ST_Intersects(a.geometry, n.geometry)"
|
| 1408 |
),
|
| 1409 |
question_hints=[
|
|
@@ -1533,7 +1533,7 @@ TEMPLATES = [
|
|
| 1533 |
" AND ST_Within(b.geometry, region.geometry)"
|
| 1534 |
" AND EXISTS ("
|
| 1535 |
" SELECT 1 FROM read_parquet('natural_earth') AS n"
|
| 1536 |
-
" WHERE n.subtype IN ('Range/
|
| 1537 |
" AND ST_Intersects(b.geometry, n.geometry)"
|
| 1538 |
" )"
|
| 1539 |
),
|
|
@@ -1666,7 +1666,7 @@ TEMPLATES = [
|
|
| 1666 |
" AND ST_Within(b.geometry, region.geometry)"
|
| 1667 |
" AND EXISTS ("
|
| 1668 |
" SELECT 1 FROM read_parquet('natural_earth') AS n"
|
| 1669 |
-
" WHERE n.subtype IN ('Range/
|
| 1670 |
" AND ST_Intersects(b.geometry, n.geometry)"
|
| 1671 |
" )"
|
| 1672 |
),
|
|
|
|
| 293 |
" ST_AsGeoJSON(n.geometry) AS geometry"
|
| 294 |
" FROM read_parquet('natural_earth') AS n, a"
|
| 295 |
" WHERE n.subtype IN ('ocean', 'sea')"
|
| 296 |
+
" AND ST_Intersects(a.geometry, n.geometry)"
|
| 297 |
),
|
| 298 |
question_hints=[
|
| 299 |
"which seas touch {anchor_name}?",
|
|
|
|
| 679 |
sql_difficulty="hard",
|
| 680 |
anchor_source="divisions_area",
|
| 681 |
num_anchors=1,
|
| 682 |
+
target_subtype="locality",
|
| 683 |
sql_template=(
|
| 684 |
"WITH region AS ("
|
| 685 |
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
|
|
|
| 723 |
" AND ST_Within(b.geometry, region.geometry)"
|
| 724 |
" AND EXISTS ("
|
| 725 |
" SELECT 1 FROM read_parquet('natural_earth') AS n"
|
| 726 |
+
" WHERE n.subtype IN ('Range/mtn', 'Island group', 'Peninsula', 'Depression')"
|
| 727 |
" AND ST_Intersects(b.geometry, n.geometry)"
|
| 728 |
" )"
|
| 729 |
),
|
|
|
|
| 1301 |
" AND subtype = '{target_subtype}'"
|
| 1302 |
),
|
| 1303 |
question_hints=[
|
| 1304 |
+
"land {target_subtype}s of {anchor_name}",
|
| 1305 |
+
"dependencies of {anchor_name} that are on land",
|
| 1306 |
+
"which land dependencies belong to {anchor_name}?",
|
| 1307 |
+
"{anchor_name}'s land {target_subtype}s",
|
| 1308 |
+
"dependencies of {anchor_name} with land area",
|
| 1309 |
+
"show the land dependencies of {anchor_name}",
|
| 1310 |
],
|
| 1311 |
),
|
| 1312 |
|
|
|
|
| 1339 |
sql_difficulty="medium",
|
| 1340 |
anchor_source="divisions_area",
|
| 1341 |
num_anchors=1,
|
| 1342 |
+
target_subtype="region",
|
| 1343 |
sql_template=(
|
| 1344 |
"SELECT id, names.\"primary\" AS name, subtype, country,"
|
| 1345 |
" ST_AsGeoJSON(geometry) AS geometry"
|
| 1346 |
" FROM read_parquet('divisions_area')"
|
| 1347 |
" WHERE country = '{country}'"
|
| 1348 |
" AND subtype = '{target_subtype}'"
|
| 1349 |
+
" AND is_land = FALSE"
|
| 1350 |
),
|
| 1351 |
question_hints=[
|
| 1352 |
+
"offshore {target_subtype}s of {anchor_name}",
|
| 1353 |
+
"{target_subtype}s of {anchor_name} that are not on land",
|
| 1354 |
+
"water-associated {target_subtype}s of {anchor_name}",
|
| 1355 |
+
"marine or offshore {target_subtype}s of {anchor_name}",
|
| 1356 |
],
|
| 1357 |
),
|
| 1358 |
|
|
|
|
| 1403 |
" SELECT n.id, n.names.\"primary\" AS name, n.subtype,"
|
| 1404 |
" ST_AsGeoJSON(n.geometry) AS geometry"
|
| 1405 |
" FROM read_parquet('natural_earth') AS n, a"
|
| 1406 |
+
" WHERE n.subtype IN ('Range/mtn', 'Peninsula', 'Depression')"
|
| 1407 |
" AND ST_Intersects(a.geometry, n.geometry)"
|
| 1408 |
),
|
| 1409 |
question_hints=[
|
|
|
|
| 1533 |
" AND ST_Within(b.geometry, region.geometry)"
|
| 1534 |
" AND EXISTS ("
|
| 1535 |
" SELECT 1 FROM read_parquet('natural_earth') AS n"
|
| 1536 |
+
" WHERE n.subtype IN ('Range/mtn', 'Depression')"
|
| 1537 |
" AND ST_Intersects(b.geometry, n.geometry)"
|
| 1538 |
" )"
|
| 1539 |
),
|
|
|
|
| 1666 |
" AND ST_Within(b.geometry, region.geometry)"
|
| 1667 |
" AND EXISTS ("
|
| 1668 |
" SELECT 1 FROM read_parquet('natural_earth') AS n"
|
| 1669 |
+
" WHERE n.subtype IN ('Range/mtn', 'Depression')"
|
| 1670 |
" AND ST_Intersects(b.geometry, n.geometry)"
|
| 1671 |
" )"
|
| 1672 |
),
|
dataset/scripts/validate_dataset.py
CHANGED
|
@@ -44,8 +44,8 @@ def _resolve_paths(sql: str) -> str:
|
|
| 44 |
"read_parquet('natural_earth')", f"read_parquet('{NATURAL_EARTH_PATH}')"
|
| 45 |
)
|
| 46 |
# Legacy fixed Docker paths from earlier dataset versions
|
| 47 |
-
sql = sql.replace("/data/overture/division_area/*.parquet",
|
| 48 |
-
sql = sql.replace("/data/overture/divisions_area/*.parquet",
|
| 49 |
sql = sql.replace("/data/natural_earth_geoparquet/ne_geography.parquet", NATURAL_EARTH_PATH)
|
| 50 |
return sql
|
| 51 |
|
|
|
|
| 44 |
"read_parquet('natural_earth')", f"read_parquet('{NATURAL_EARTH_PATH}')"
|
| 45 |
)
|
| 46 |
# Legacy fixed Docker paths from earlier dataset versions
|
| 47 |
+
sql = sql.replace("/data/overture/division_area/*.parquet", DIVISIONS_AREA_PATH)
|
| 48 |
+
sql = sql.replace("/data/overture/divisions_area/*.parquet", DIVISIONS_AREA_PATH)
|
| 49 |
sql = sql.replace("/data/natural_earth_geoparquet/ne_geography.parquet", NATURAL_EARTH_PATH)
|
| 50 |
return sql
|
| 51 |
|
finetune/README.md
CHANGED
|
@@ -26,7 +26,7 @@ SQL samples are long (schema + candidates + SQL), places samples are short.
|
|
| 26 |
|
| 27 |
```bash
|
| 28 |
modal run finetune/check_token_lengths.py
|
| 29 |
-
modal run finetune/check_token_lengths.py --run-dir /mnt/gazet/data/v1
|
| 30 |
```
|
| 31 |
|
| 32 |
This prints per-split statistics (min, max, P95, P99) and recommends a
|
|
@@ -66,7 +66,7 @@ overridden, `lora_alpha` is automatically set to `2 * r`.
|
|
| 66 |
|
| 67 |
```
|
| 68 |
base_model: unsloth/Qwen3.5-0.8B
|
| 69 |
-
run_dir: /mnt/gazet/data/v1
|
| 70 |
lora_r: 16
|
| 71 |
lora_alpha: 32 (2 * r, Unsloth recommendation for Qwen)
|
| 72 |
lora_dropout: 0.0
|
|
@@ -106,19 +106,19 @@ for local inference with llama-server.
|
|
| 106 |
|
| 107 |
```bash
|
| 108 |
# Download from Modal volume
|
| 109 |
-
modal volume get gazet checkpoints/qwen35-v1/merged ./finetune/models/merged
|
| 110 |
|
| 111 |
# Convert to GGUF (requires llama.cpp repo)
|
| 112 |
-
uv run \
|
| 113 |
--no-project \
|
| 114 |
--with transformers \
|
| 115 |
--with sentencepiece \
|
| 116 |
--with protobuf \
|
| 117 |
--with torch \
|
| 118 |
python convert_hf_to_gguf.py \
|
| 119 |
-
.
|
| 120 |
--outtype q8_0 \
|
| 121 |
-
--outfile .
|
| 122 |
```
|
| 123 |
|
| 124 |
---
|
|
@@ -129,7 +129,7 @@ uv run \
|
|
| 129 |
|
| 130 |
```bash
|
| 131 |
llama-server \
|
| 132 |
-
-m finetune/models/
|
| 133 |
-ngl 99 \
|
| 134 |
--port 9000 \
|
| 135 |
--ctx-size 2048
|
|
@@ -151,7 +151,7 @@ docker run \
|
|
| 151 |
-v $(pwd)/finetune/models:/models \
|
| 152 |
-p 9000:9000 \
|
| 153 |
ghcr.io/ggml-org/llama.cpp:server \
|
| 154 |
-
-m /models/
|
| 155 |
--port 9000 --host 0.0.0.0 \
|
| 156 |
--ctx-size 2048 -t 2 -v
|
| 157 |
```
|
|
@@ -211,7 +211,7 @@ All batch CLI args:
|
|
| 211 |
| `--label` | `local-gguf` | Label used in the output filename |
|
| 212 |
| `--task` | `sql` | `sql` or `places` |
|
| 213 |
| `--split` | `val` | Data split to evaluate (`val`, `test`) |
|
| 214 |
-
| `--run-dir` | `dataset/output/runs/v1` | Directory with `{task}/{split}.jsonl` |
|
| 215 |
| `--max-samples` | all | Cap the number of samples |
|
| 216 |
| `--output` | `eval-{label}-{task}.json` | Output JSON path |
|
| 217 |
| `--workers` | `4` | Concurrent requests; match llama-server `--parallel` |
|
|
@@ -250,6 +250,20 @@ GAZET_EVAL_DIR=/path/to/results streamlit run finetune/eval_demo.py
|
|
| 250 |
```
|
| 251 |
|
| 252 |
Set `GAZET_DATA_DIR` if your parquet data is not in the default `data/` directory.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
|
| 254 |
---
|
| 255 |
|
|
@@ -287,5 +301,9 @@ loss is computed only on the completion tokens.
|
|
| 287 |
|
| 288 |
SQL in the training data uses symbolic path placeholders
|
| 289 |
(`read_parquet('divisions_area')`) instead of real file paths. At inference
|
| 290 |
-
time, `src/gazet/sql.py`
|
| 291 |
-
executing against DuckDB.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
```bash
|
| 28 |
modal run finetune/check_token_lengths.py
|
| 29 |
+
modal run finetune/check_token_lengths.py --run-dir /mnt/gazet/data/smalltest-v1
|
| 30 |
```
|
| 31 |
|
| 32 |
This prints per-split statistics (min, max, P95, P99) and recommends a
|
|
|
|
| 66 |
|
| 67 |
```
|
| 68 |
base_model: unsloth/Qwen3.5-0.8B
|
| 69 |
+
run_dir: /mnt/gazet/data/v1 # override to your exported run, e.g. /mnt/gazet/data/smalltest-v1
|
| 70 |
lora_r: 16
|
| 71 |
lora_alpha: 32 (2 * r, Unsloth recommendation for Qwen)
|
| 72 |
lora_dropout: 0.0
|
|
|
|
| 106 |
|
| 107 |
```bash
|
| 108 |
# Download from Modal volume
|
| 109 |
+
modal volume get gazet checkpoints/qwen35-v1/merged ./finetune/models/qwen35-v1-merged
|
| 110 |
|
| 111 |
# Convert to GGUF (requires llama.cpp repo)
|
| 112 |
+
uv run \
|
| 113 |
--no-project \
|
| 114 |
--with transformers \
|
| 115 |
--with sentencepiece \
|
| 116 |
--with protobuf \
|
| 117 |
--with torch \
|
| 118 |
python convert_hf_to_gguf.py \
|
| 119 |
+
./finetune/models/qwen35-v1-merged \
|
| 120 |
--outtype q8_0 \
|
| 121 |
+
--outfile ./finetune/models/qwen35-v1-q8_0.gguf
|
| 122 |
```
|
| 123 |
|
| 124 |
---
|
|
|
|
| 129 |
|
| 130 |
```bash
|
| 131 |
llama-server \
|
| 132 |
+
-m finetune/models/qwen35-v1-q8_0.gguf \
|
| 133 |
-ngl 99 \
|
| 134 |
--port 9000 \
|
| 135 |
--ctx-size 2048
|
|
|
|
| 151 |
-v $(pwd)/finetune/models:/models \
|
| 152 |
-p 9000:9000 \
|
| 153 |
ghcr.io/ggml-org/llama.cpp:server \
|
| 154 |
+
-m /models/qwen35-v1-q8_0.gguf \
|
| 155 |
--port 9000 --host 0.0.0.0 \
|
| 156 |
--ctx-size 2048 -t 2 -v
|
| 157 |
```
|
|
|
|
| 211 |
| `--label` | `local-gguf` | Label used in the output filename |
|
| 212 |
| `--task` | `sql` | `sql` or `places` |
|
| 213 |
| `--split` | `val` | Data split to evaluate (`val`, `test`) |
|
| 214 |
+
| `--run-dir` | `dataset/output/runs/v1` | Directory with `{task}/{split}.jsonl`; override to your exported run, e.g. `dataset/output/runs/smalltest-v1` |
|
| 215 |
| `--max-samples` | all | Cap the number of samples |
|
| 216 |
| `--output` | `eval-{label}-{task}.json` | Output JSON path |
|
| 217 |
| `--workers` | `4` | Concurrent requests; match llama-server `--parallel` |
|
|
|
|
| 250 |
```
|
| 251 |
|
| 252 |
Set `GAZET_DATA_DIR` if your parquet data is not in the default `data/` directory.
|
| 253 |
+
This only affects the visual SQL viewer (`eval_demo.py`), which executes SQL
|
| 254 |
+
against DuckDB; `eval_cli.py` does not read parquet files directly.
|
| 255 |
+
|
| 256 |
+
The eval viewer resolves parquet paths through `gazet.config`, which now
|
| 257 |
+
prefers normalized copies automatically when present:
|
| 258 |
+
|
| 259 |
+
- `data/overture_normalized/divisions_area/*.parquet`
|
| 260 |
+
- `data/natural_earth_normalized/ne_geography.parquet`
|
| 261 |
+
|
| 262 |
+
Disable that fallback only if needed with:
|
| 263 |
+
|
| 264 |
+
```bash
|
| 265 |
+
GAZET_USE_NORMALIZED_DATA=0 streamlit run finetune/eval_demo.py
|
| 266 |
+
```
|
| 267 |
|
| 268 |
---
|
| 269 |
|
|
|
|
| 301 |
|
| 302 |
SQL in the training data uses symbolic path placeholders
|
| 303 |
(`read_parquet('divisions_area')`) instead of real file paths. At inference
|
| 304 |
+
and eval time, `src/gazet/sql.py` / `finetune/eval_demo.py` replace these with
|
| 305 |
+
actual runtime paths before executing against DuckDB. When normalized parquet
|
| 306 |
+
copies are present, `gazet.config` prefers:
|
| 307 |
+
|
| 308 |
+
- `data/overture_normalized/divisions_area/*.parquet`
|
| 309 |
+
- `data/natural_earth_normalized/ne_geography.parquet`
|
finetune/eval_demo.py
CHANGED
|
@@ -16,6 +16,8 @@ import pydeck as pdk
|
|
| 16 |
import sqlparse
|
| 17 |
import streamlit as st
|
| 18 |
|
|
|
|
|
|
|
| 19 |
PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
| 20 |
DATA_DIR = pathlib.Path(
|
| 21 |
os.environ.get("GAZET_DATA_DIR", str(PROJECT_ROOT / "data"))
|
|
@@ -31,13 +33,17 @@ def load_eval_results(path):
|
|
| 31 |
|
| 32 |
|
| 33 |
def rewrite_data_paths(sql):
|
| 34 |
-
"""Replace symbolic and legacy paths with
|
| 35 |
-
# Legacy fixed Docker paths must be replaced first to avoid double-expansion
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
sql = sql.replace("/data/", f"{DATA_DIR}/")
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
sql = sql.replace("read_parquet('divisions_area')", f"read_parquet('{div_path}')")
|
| 40 |
-
sql = sql.replace("read_parquet('natural_earth')", f"read_parquet('{ne_path}')")
|
| 41 |
return sql
|
| 42 |
|
| 43 |
|
|
|
|
| 16 |
import sqlparse
|
| 17 |
import streamlit as st
|
| 18 |
|
| 19 |
+
from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
|
| 20 |
+
|
| 21 |
PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
| 22 |
DATA_DIR = pathlib.Path(
|
| 23 |
os.environ.get("GAZET_DATA_DIR", str(PROJECT_ROOT / "data"))
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
def rewrite_data_paths(sql):
|
| 36 |
+
"""Replace symbolic and legacy paths with the configured runtime data paths."""
|
| 37 |
+
# Legacy fixed Docker paths must be replaced first to avoid double-expansion.
|
| 38 |
+
sql = sql.replace("/data/overture/division_area/*.parquet", DIVISIONS_AREA_PATH)
|
| 39 |
+
sql = sql.replace("/data/overture/divisions_area/*.parquet", DIVISIONS_AREA_PATH)
|
| 40 |
+
sql = sql.replace(
|
| 41 |
+
"/data/natural_earth_geoparquet/ne_geography.parquet",
|
| 42 |
+
NATURAL_EARTH_PATH,
|
| 43 |
+
)
|
| 44 |
sql = sql.replace("/data/", f"{DATA_DIR}/")
|
| 45 |
+
sql = sql.replace("read_parquet('divisions_area')", f"read_parquet('{DIVISIONS_AREA_PATH}')")
|
| 46 |
+
sql = sql.replace("read_parquet('natural_earth')", f"read_parquet('{NATURAL_EARTH_PATH}')")
|
|
|
|
|
|
|
| 47 |
return sql
|
| 48 |
|
| 49 |
|
finetune/train_modal_qwen35.py
CHANGED
|
@@ -101,7 +101,7 @@ class Qwen35Config:
|
|
| 101 |
# Logging / saving
|
| 102 |
logging_steps: int = 10
|
| 103 |
save_strategy: str = "steps"
|
| 104 |
-
save_steps: int =
|
| 105 |
eval_strategy: str = "steps"
|
| 106 |
eval_steps: int = 200
|
| 107 |
report_to: str = "trackio"
|
|
|
|
| 101 |
# Logging / saving
|
| 102 |
logging_steps: int = 10
|
| 103 |
save_strategy: str = "steps"
|
| 104 |
+
save_steps: int = 1000
|
| 105 |
eval_strategy: str = "steps"
|
| 106 |
eval_steps: int = 200
|
| 107 |
report_to: str = "trackio"
|
gazet_demo.py
CHANGED
|
@@ -68,24 +68,58 @@ def view_state_for_bbox(bbox, padding_zoom=0.8):
|
|
| 68 |
return pdk.ViewState(latitude=lat, longitude=lng, zoom=zoom)
|
| 69 |
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
def _render_map(geojson, placeholder):
|
| 72 |
-
|
|
|
|
| 73 |
if pdk and n:
|
|
|
|
| 74 |
layer = pdk.Layer(
|
| 75 |
"GeoJsonLayer",
|
| 76 |
data=geojson,
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
| 80 |
pickable=True,
|
| 81 |
)
|
| 82 |
-
bbox = bbox_from_geojson(geojson)
|
| 83 |
-
view = (
|
| 84 |
-
view_state_for_bbox(bbox)
|
| 85 |
-
if bbox
|
| 86 |
-
else pdk.ViewState(latitude=0, longitude=0, zoom=1)
|
| 87 |
-
)
|
| 88 |
with placeholder.container():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
st.pydeck_chart(
|
| 90 |
pdk.Deck(
|
| 91 |
layers=[layer],
|
|
@@ -128,6 +162,8 @@ st.sidebar.caption(
|
|
| 128 |
|
| 129 |
if "run_q" not in st.session_state:
|
| 130 |
st.session_state.run_q = None
|
|
|
|
|
|
|
| 131 |
|
| 132 |
col1, col2 = st.columns([1, 2])
|
| 133 |
with col1:
|
|
@@ -150,6 +186,7 @@ with col2:
|
|
| 150 |
to_run = st.session_state.run_q
|
| 151 |
if to_run:
|
| 152 |
st.session_state.run_q = None
|
|
|
|
| 153 |
|
| 154 |
status_ph = st.empty()
|
| 155 |
map_ph = st.empty()
|
|
@@ -159,6 +196,8 @@ with col2:
|
|
| 159 |
|
| 160 |
status_ph.info("Extracting places…")
|
| 161 |
|
|
|
|
|
|
|
| 162 |
try:
|
| 163 |
with requests.get(
|
| 164 |
f"{API}/search/stream", params={"q": to_run, "backend": backend}, stream=True, timeout=120
|
|
@@ -173,6 +212,7 @@ with col2:
|
|
| 173 |
|
| 174 |
if t == "places":
|
| 175 |
places = event["data"].get("places", [])
|
|
|
|
| 176 |
status_ph.info("Fuzzy-matching candidates…")
|
| 177 |
if places:
|
| 178 |
with places_ph.container():
|
|
@@ -192,6 +232,7 @@ with col2:
|
|
| 192 |
)
|
| 193 |
|
| 194 |
elif t == "candidates":
|
|
|
|
| 195 |
status_ph.info("Generating SQL…")
|
| 196 |
with candidates_ph.container():
|
| 197 |
with st.expander("Candidate datasets", expanded=True):
|
|
@@ -203,6 +244,7 @@ with col2:
|
|
| 203 |
|
| 204 |
elif t == "sql_attempt":
|
| 205 |
iteration = event.get("iteration", "")
|
|
|
|
| 206 |
status_ph.info(f"Running SQL (attempt {iteration})…")
|
| 207 |
with sql_ph.container():
|
| 208 |
with st.expander("SQL", expanded=True):
|
|
@@ -216,6 +258,7 @@ with col2:
|
|
| 216 |
|
| 217 |
elif t == "geojson":
|
| 218 |
geojson = event["data"]
|
|
|
|
| 219 |
n = len(geojson.get("features", []))
|
| 220 |
status_ph.success(f"**{to_run}** → {n} feature(s)")
|
| 221 |
_render_map(geojson, map_ph)
|
|
@@ -227,3 +270,31 @@ with col2:
|
|
| 227 |
status_ph.error(
|
| 228 |
f"API error: {e}. Is the API running? `uv run uvicorn gazet.api:app --reload`"
|
| 229 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
return pdk.ViewState(latitude=lat, longitude=lng, zoom=zoom)
|
| 69 |
|
| 70 |
|
| 71 |
+
def _has_line_geometries(features):
|
| 72 |
+
"""Return True if features are predominantly line/point (non-polygon) geometries."""
|
| 73 |
+
line_types = {"LineString", "MultiLineString", "Point", "MultiPoint"}
|
| 74 |
+
count = sum(
|
| 75 |
+
1 for f in features
|
| 76 |
+
if f.get("geometry", {}).get("type") in line_types
|
| 77 |
+
)
|
| 78 |
+
return count > len(features) / 2
|
| 79 |
+
|
| 80 |
+
|
| 81 |
def _render_map(geojson, placeholder):
|
| 82 |
+
features = geojson.get("features", [])
|
| 83 |
+
n = len(features)
|
| 84 |
if pdk and n:
|
| 85 |
+
is_linear = _has_line_geometries(features)
|
| 86 |
layer = pdk.Layer(
|
| 87 |
"GeoJsonLayer",
|
| 88 |
data=geojson,
|
| 89 |
+
stroked=True,
|
| 90 |
+
filled=not is_linear,
|
| 91 |
+
get_fill_color=[40, 180, 160, 120],
|
| 92 |
+
get_line_color=[0, 140, 255, 255] if is_linear else [10, 50, 46, 255],
|
| 93 |
+
get_line_width=500 if is_linear else 80,
|
| 94 |
+
line_width_min_pixels=2 if is_linear else 1,
|
| 95 |
pickable=True,
|
| 96 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
with placeholder.container():
|
| 98 |
+
selected_idx = None
|
| 99 |
+
if n > 1:
|
| 100 |
+
names = [
|
| 101 |
+
f.get("properties", {}).get("name", f"Feature {i}")
|
| 102 |
+
for i, f in enumerate(features)
|
| 103 |
+
]
|
| 104 |
+
choice = st.selectbox(
|
| 105 |
+
"Zoom to feature",
|
| 106 |
+
["All features"] + names,
|
| 107 |
+
key="feature_zoom",
|
| 108 |
+
)
|
| 109 |
+
if choice != "All features":
|
| 110 |
+
selected_idx = names.index(choice)
|
| 111 |
+
|
| 112 |
+
if selected_idx is not None:
|
| 113 |
+
single = {"type": "FeatureCollection", "features": [features[selected_idx]]}
|
| 114 |
+
bbox = bbox_from_geojson(single)
|
| 115 |
+
else:
|
| 116 |
+
bbox = bbox_from_geojson(geojson)
|
| 117 |
+
|
| 118 |
+
view = (
|
| 119 |
+
view_state_for_bbox(bbox)
|
| 120 |
+
if bbox
|
| 121 |
+
else pdk.ViewState(latitude=0, longitude=0, zoom=1)
|
| 122 |
+
)
|
| 123 |
st.pydeck_chart(
|
| 124 |
pdk.Deck(
|
| 125 |
layers=[layer],
|
|
|
|
| 162 |
|
| 163 |
if "run_q" not in st.session_state:
|
| 164 |
st.session_state.run_q = None
|
| 165 |
+
if "last_result" not in st.session_state:
|
| 166 |
+
st.session_state.last_result = None
|
| 167 |
|
| 168 |
col1, col2 = st.columns([1, 2])
|
| 169 |
with col1:
|
|
|
|
| 186 |
to_run = st.session_state.run_q
|
| 187 |
if to_run:
|
| 188 |
st.session_state.run_q = None
|
| 189 |
+
st.session_state.last_result = None
|
| 190 |
|
| 191 |
status_ph = st.empty()
|
| 192 |
map_ph = st.empty()
|
|
|
|
| 196 |
|
| 197 |
status_ph.info("Extracting places…")
|
| 198 |
|
| 199 |
+
result = {"query": to_run, "places": None, "candidates": None, "sql": None, "geojson": None}
|
| 200 |
+
|
| 201 |
try:
|
| 202 |
with requests.get(
|
| 203 |
f"{API}/search/stream", params={"q": to_run, "backend": backend}, stream=True, timeout=120
|
|
|
|
| 212 |
|
| 213 |
if t == "places":
|
| 214 |
places = event["data"].get("places", [])
|
| 215 |
+
result["places"] = places
|
| 216 |
status_ph.info("Fuzzy-matching candidates…")
|
| 217 |
if places:
|
| 218 |
with places_ph.container():
|
|
|
|
| 232 |
)
|
| 233 |
|
| 234 |
elif t == "candidates":
|
| 235 |
+
result["candidates"] = event["data"]
|
| 236 |
status_ph.info("Generating SQL…")
|
| 237 |
with candidates_ph.container():
|
| 238 |
with st.expander("Candidate datasets", expanded=True):
|
|
|
|
| 244 |
|
| 245 |
elif t == "sql_attempt":
|
| 246 |
iteration = event.get("iteration", "")
|
| 247 |
+
result["sql"] = event["data"]
|
| 248 |
status_ph.info(f"Running SQL (attempt {iteration})…")
|
| 249 |
with sql_ph.container():
|
| 250 |
with st.expander("SQL", expanded=True):
|
|
|
|
| 258 |
|
| 259 |
elif t == "geojson":
|
| 260 |
geojson = event["data"]
|
| 261 |
+
result["geojson"] = geojson
|
| 262 |
n = len(geojson.get("features", []))
|
| 263 |
status_ph.success(f"**{to_run}** → {n} feature(s)")
|
| 264 |
_render_map(geojson, map_ph)
|
|
|
|
| 270 |
status_ph.error(
|
| 271 |
f"API error: {e}. Is the API running? `uv run uvicorn gazet.api:app --reload`"
|
| 272 |
)
|
| 273 |
+
|
| 274 |
+
st.session_state.last_result = result
|
| 275 |
+
|
| 276 |
+
elif st.session_state.last_result:
|
| 277 |
+
result = st.session_state.last_result
|
| 278 |
+
query = result["query"]
|
| 279 |
+
n_feat = len((result["geojson"] or {}).get("features", []))
|
| 280 |
+
st.success(f"**{query}** -> {n_feat} feature(s)")
|
| 281 |
+
_render_map(result["geojson"], st.empty())
|
| 282 |
+
if result["places"]:
|
| 283 |
+
with st.expander("Extracted place names"):
|
| 284 |
+
st.dataframe(
|
| 285 |
+
pd.DataFrame(result["places"]).rename(
|
| 286 |
+
columns={"place": "Place", "country": "Country", "subtype": "Subtype"}
|
| 287 |
+
),
|
| 288 |
+
use_container_width=True,
|
| 289 |
+
hide_index=True,
|
| 290 |
+
)
|
| 291 |
+
if result["candidates"]:
|
| 292 |
+
with st.expander("Candidate datasets"):
|
| 293 |
+
st.dataframe(
|
| 294 |
+
pd.DataFrame(result["candidates"]),
|
| 295 |
+
use_container_width=True,
|
| 296 |
+
hide_index=True,
|
| 297 |
+
)
|
| 298 |
+
if result["sql"]:
|
| 299 |
+
with st.expander("SQL"):
|
| 300 |
+
st.code(result["sql"], language="sql")
|
ingest/convert_natural_earth.py
CHANGED
|
@@ -107,7 +107,7 @@ def _load_shapefile(src: pathlib.Path, source_key: str) -> gpd.GeoDataFrame:
|
|
| 107 |
|
| 108 |
# subtype: featurecla or source key
|
| 109 |
if "featurecla" in gdf.columns:
|
| 110 |
-
subtype = gdf["featurecla"]
|
| 111 |
else:
|
| 112 |
subtype = pd.Series([source_key] * n)
|
| 113 |
|
|
|
|
| 107 |
|
| 108 |
# subtype: featurecla or source key
|
| 109 |
if "featurecla" in gdf.columns:
|
| 110 |
+
subtype = gdf["featurecla"].str.lower()
|
| 111 |
else:
|
| 112 |
subtype = pd.Series([source_key] * n)
|
| 113 |
|
src/gazet/config.py
CHANGED
|
@@ -6,8 +6,33 @@ import pathlib
|
|
| 6 |
_DATA_DIR = pathlib.Path(os.environ.get("GAZET_DATA_DIR", str(
|
| 7 |
pathlib.Path(__file__).resolve().parent.parent.parent / "data"
|
| 8 |
)))
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
# MODEL = "qwen3.5:cloud"
|
| 13 |
# MODEL = "granite4:350m"
|
|
|
|
| 6 |
_DATA_DIR = pathlib.Path(os.environ.get("GAZET_DATA_DIR", str(
|
| 7 |
pathlib.Path(__file__).resolve().parent.parent.parent / "data"
|
| 8 |
)))
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _prefer_normalized(path_normalized: pathlib.Path, path_original: pathlib.Path) -> pathlib.Path:
|
| 12 |
+
"""Prefer normalized geodata copies when present."""
|
| 13 |
+
use_normalized = os.environ.get("GAZET_USE_NORMALIZED_DATA", "1") != "0"
|
| 14 |
+
if use_normalized:
|
| 15 |
+
parent = path_normalized.parent
|
| 16 |
+
if "*" in path_normalized.name:
|
| 17 |
+
if parent.exists() and any(parent.glob(path_normalized.name)):
|
| 18 |
+
return path_normalized
|
| 19 |
+
elif path_normalized.exists():
|
| 20 |
+
return path_normalized
|
| 21 |
+
return path_original
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
DIVISIONS_AREA_PATH = str(
|
| 25 |
+
_prefer_normalized(
|
| 26 |
+
_DATA_DIR / "overture_normalized/divisions_area/*.parquet",
|
| 27 |
+
_DATA_DIR / "overture/divisions_area/*.parquet",
|
| 28 |
+
)
|
| 29 |
+
)
|
| 30 |
+
NATURAL_EARTH_PATH = str(
|
| 31 |
+
_prefer_normalized(
|
| 32 |
+
_DATA_DIR / "natural_earth_normalized/ne_geography.parquet",
|
| 33 |
+
_DATA_DIR / "natural_earth_geoparquet/ne_geography.parquet",
|
| 34 |
+
)
|
| 35 |
+
)
|
| 36 |
|
| 37 |
# MODEL = "qwen3.5:cloud"
|
| 38 |
# MODEL = "granite4:350m"
|
src/gazet/sql.py
CHANGED
|
@@ -74,6 +74,28 @@ def _rewrite_data_paths(sql: str) -> str:
|
|
| 74 |
return sql
|
| 75 |
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
def _strip_fences(sql: Optional[str]) -> str:
|
| 78 |
"""Remove markdown code fences that the LM may wrap the SQL in."""
|
| 79 |
if not sql:
|
|
@@ -143,6 +165,7 @@ def run_geo_sql_gguf(
|
|
| 143 |
return
|
| 144 |
|
| 145 |
sql = _rewrite_data_paths(sql)
|
|
|
|
| 146 |
print(f"\n[SQL·GGUF] Generated:\n{sql}\n")
|
| 147 |
yield {"type": "sql_attempt", "sql": sql, "iteration": 1}
|
| 148 |
yield from _execute_sql(con, sql, "SQL·GGUF", iteration=1)
|
|
@@ -183,6 +206,8 @@ def run_geo_sql_dspy(
|
|
| 183 |
execution_error=error,
|
| 184 |
)
|
| 185 |
sql = _strip_fences(pred.sql)
|
|
|
|
|
|
|
| 186 |
except Exception as exc:
|
| 187 |
error = f"LM generation failed: {exc}"
|
| 188 |
print(f"Generation error: {error}")
|
|
|
|
| 74 |
return sql
|
| 75 |
|
| 76 |
|
| 77 |
+
# Title-cased NE subtype literals the trained model may emit.
|
| 78 |
+
# Data is now fully lowercased, so we normalise at query time.
|
| 79 |
+
_NE_SUBTYPE_FIXES = {
|
| 80 |
+
"'River'": "'river'",
|
| 81 |
+
"'Lake'": "'lake'",
|
| 82 |
+
"'Basin'": "'basin'",
|
| 83 |
+
"'Range/mtn'": "'range/mtn'",
|
| 84 |
+
"'Peninsula'": "'peninsula'",
|
| 85 |
+
"'Depression'": "'depression'",
|
| 86 |
+
"'Island group'": "'island group'",
|
| 87 |
+
"'Ocean'": "'ocean'",
|
| 88 |
+
"'Sea'": "'sea'",
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _normalize_ne_subtypes(sql: str) -> str:
|
| 93 |
+
"""Lowercase known NE subtype literals so they match the normalised data."""
|
| 94 |
+
for old, new in _NE_SUBTYPE_FIXES.items():
|
| 95 |
+
sql = sql.replace(old, new)
|
| 96 |
+
return sql
|
| 97 |
+
|
| 98 |
+
|
| 99 |
def _strip_fences(sql: Optional[str]) -> str:
|
| 100 |
"""Remove markdown code fences that the LM may wrap the SQL in."""
|
| 101 |
if not sql:
|
|
|
|
| 165 |
return
|
| 166 |
|
| 167 |
sql = _rewrite_data_paths(sql)
|
| 168 |
+
sql = _normalize_ne_subtypes(sql)
|
| 169 |
print(f"\n[SQL·GGUF] Generated:\n{sql}\n")
|
| 170 |
yield {"type": "sql_attempt", "sql": sql, "iteration": 1}
|
| 171 |
yield from _execute_sql(con, sql, "SQL·GGUF", iteration=1)
|
|
|
|
| 206 |
execution_error=error,
|
| 207 |
)
|
| 208 |
sql = _strip_fences(pred.sql)
|
| 209 |
+
sql = _rewrite_data_paths(sql)
|
| 210 |
+
sql = _normalize_ne_subtypes(sql)
|
| 211 |
except Exception as exc:
|
| 212 |
error = f"LM generation failed: {exc}"
|
| 213 |
print(f"Generation error: {error}")
|