srmsoumya commited on
Commit
dfb9466
·
1 Parent(s): ca28f70

Fix: No pairs are created for mixed queries

Browse files

- Normalize overture & natural earth to EPSG:4326
- Add better adjancency matrix to increase the overlap in query pairs
- Fix titled & lower case subtype names in NE
- Reduce number of threads in duckdb to prevent memory issues

dataset/README.md CHANGED
@@ -20,6 +20,20 @@ uv sync
20
  You need the Overture and Natural Earth parquet files under `data/` locally,
21
  or on a Modal volume if running in the cloud.
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  ---
24
 
25
  ## Option A — Run locally (small datasets, development)
@@ -72,6 +86,14 @@ Modal uses two volumes:
72
 
73
  **Step 1 — One-time setup (only first time, or when source parquets change)**
74
 
 
 
 
 
 
 
 
 
75
  ```bash
76
  modal setup # authenticate
77
  gazet-dataset modal-upload --config dataset/config.yaml # ~15 min, uploads data/ to gazet-data volume
@@ -81,7 +103,7 @@ Verify:
81
 
82
  ```bash
83
  modal volume ls gazet-data
84
- # should show: overture/, natural_earth/, natural_earth_geoparquet/
85
  ```
86
 
87
  Skip this step on subsequent runs — the volume persists across runs.
@@ -207,6 +229,23 @@ by default; pass `--fresh` to overwrite existing samples.
207
 
208
  ---
209
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  ## Troubleshooting
211
 
212
  **Very few samples generated for a family**
 
20
  You need the Overture and Natural Earth parquet files under `data/` locally,
21
  or on a Modal volume if running in the cloud.
22
 
23
+ Before large runs, normalize them once so both datasets use harmonized
24
+ geometry metadata and cross-source joins behave the same locally and on Modal:
25
+
26
+ ```bash
27
+ gazet-dataset normalize-data --config dataset/config.yaml
28
+ ```
29
+
30
+ This writes:
31
+
32
+ - `data/overture_normalized/divisions_area/part-000.parquet`
33
+ - `data/natural_earth_normalized/ne_geography.parquet`
34
+
35
+ When those files exist, `gazet.config` will prefer them automatically.
36
+
37
  ---
38
 
39
  ## Option A — Run locally (small datasets, development)
 
86
 
87
  **Step 1 — One-time setup (only first time, or when source parquets change)**
88
 
89
+ First normalize the source geodata locally:
90
+
91
+ ```bash
92
+ gazet-dataset normalize-data --config dataset/config.yaml
93
+ ```
94
+
95
+ Then upload `data/` to Modal:
96
+
97
  ```bash
98
  modal setup # authenticate
99
  gazet-dataset modal-upload --config dataset/config.yaml # ~15 min, uploads data/ to gazet-data volume
 
103
 
104
  ```bash
105
  modal volume ls gazet-data
106
+ # should show: overture/, overture_normalized/, natural_earth_geoparquet/, natural_earth_normalized/
107
  ```
108
 
109
  Skip this step on subsequent runs — the volume persists across runs.
 
229
 
230
  ---
231
 
232
+ ## Data quality checks
233
+
234
+ After a run, spot-check the output with the pytest suite:
235
+
236
+ ```bash
237
+ uv run --extra dev pytest dataset/tests/ -v
238
+ ```
239
+
240
+ The suite reads `dataset/output/dataset_validated.jsonl` plus the exported
241
+ `runs/{run_name}/*.jsonl` and verifies: schema, no unresolved `{placeholders}`
242
+ in questions, candidate refs resolve, SQL shape, template coverage,
243
+ subtype-filtered templates match their phrasing, disambiguation samples have
244
+ same-name distractors, and exported assistant payloads parse as valid
245
+ JSON / SQL. Tests skip gracefully when outputs are missing.
246
+
247
+ ---
248
+
249
  ## Troubleshooting
250
 
251
  **Very few samples generated for a family**
dataset/config.smalltest.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ countries:
2
+ - BE
3
+ - CH
4
+ - AE
5
+
6
+ sample_targets:
7
+ direct_lookup: 4
8
+ disambiguation: 6
9
+ adjacency: 12
10
+ multi_adjacency: 2
11
+ containment: 8
12
+ intersection: 8
13
+ buffer: 10
14
+ chained: 22
15
+ difference: 4
16
+ border_corridor: 2
17
+ set_operations: 10
18
+ partial_selection: 10
19
+ aggregation: 8
20
+ window_function: 4
21
+ attribute_filter: 6
22
+
23
+ generation:
24
+ max_workers: 4
25
+ retry_multiplier: 2
26
+ append_mode: false
27
+
28
+ auto_scaling:
29
+ safety_factor: 1.25
30
+ manual_limits: {}
31
+
32
+ modal:
33
+ volume_name: "gazet-data"
34
+ app_name: "gazet-dataset"
35
+ num_containers: 10
36
+ container_cpu: 2
37
+ container_memory: 4096
38
+ timeout: 7200
39
+
40
+ run_name: "smalltest-v1"
dataset/config.yaml CHANGED
@@ -17,21 +17,21 @@ countries:
17
  # Bumped families with many templates or mixed-source variants so each
18
  # template_id gets enough coverage after uniform sampling + stratified split.
19
  sample_targets:
20
- direct_lookup: 500
21
- disambiguation: 1500 # 3 templates (disambiguate_01..03) - "Puri, Odisha" pattern
22
- adjacency: 1800 # 6 templates (adj_01..06) - adj_06 is counties
23
- multi_adjacency: 300
24
- containment: 1200 # 4 templates (contain_01..04)
25
- intersection: 1200 # 4 templates (intersect_01..04)
26
- buffer: 1200 # 5 templates (buffer_01..05)
27
- chained: 3300 # 11 templates (chained_01..11) - 10/11 are coastal/inland regions
28
- difference: 900 # 2 templates, one is mixed (diff_02)
29
- border_corridor: 300
30
- set_operations: 900
31
- partial_selection: 1500 # 5 templates, one is mixed (partial_05)
32
- aggregation: 750
33
- window_function: 300
34
- attribute_filter: 600 # 3 templates (attr_01..03)
35
 
36
  # Generation settings
37
  generation:
@@ -58,7 +58,7 @@ modal:
58
  num_containers: 100 # Number of parallel containers for sample generation
59
  container_cpu: 2 # CPUs per container
60
  container_memory: 4096 # Memory (MB) per container
61
- timeout: 3600 # Per-container timeout in seconds
62
 
63
  # Run name — used to version exported splits so re-runs never overwrite previous data.
64
  # Change this whenever you regenerate from scratch (e.g. after template changes).
 
17
  # Bumped families with many templates or mixed-source variants so each
18
  # template_id gets enough coverage after uniform sampling + stratified split.
19
  sample_targets:
20
+ direct_lookup: 1000
21
+ disambiguation: 2000 # 3 templates (disambiguate_01..03) - "Puri, Odisha" pattern
22
+ adjacency: 2000 # 6 templates (adj_01..06) - adj_06 is counties
23
+ multi_adjacency: 1000
24
+ containment: 2000 # 4 templates (contain_01..04) - contain_02 reversed, contain_03/04 NE anchor
25
+ intersection: 2000 # 4 templates (intersect_01..04) - intersect_02/03 NE anchor
26
+ buffer: 2000 # 5 templates (buffer_01..05)
27
+ chained: 2000 # 11 templates (chained_01..11) - 10/11 are coastal/inland regions
28
+ difference: 2000 # 2 templates, one is mixed (diff_02)
29
+ border_corridor: 1000
30
+ set_operations: 2000
31
+ partial_selection: 2000 # 5 templates, one is mixed (partial_05)
32
+ aggregation: 1500
33
+ window_function: 1000
34
+ attribute_filter: 1000 # 3 templates (attr_01..03)
35
 
36
  # Generation settings
37
  generation:
 
58
  num_containers: 100 # Number of parallel containers for sample generation
59
  container_cpu: 2 # CPUs per container
60
  container_memory: 4096 # Memory (MB) per container
61
+ timeout: 7200 # Per-container timeout in seconds
62
 
63
  # Run name — used to version exported splits so re-runs never overwrite previous data.
64
  # Change this whenever you regenerate from scratch (e.g. after template changes).
dataset/modal_app.py CHANGED
@@ -24,7 +24,16 @@ image = (
24
  "pyarrow>=17.0.0",
25
  "pyyaml>=6.0",
26
  )
27
- .env({"GAZET_DATA_DIR": VOLUME_MOUNT, "PYTHONPATH": "/root"})
 
 
 
 
 
 
 
 
 
28
  .add_local_dir("src/gazet", "/root/gazet")
29
  .add_local_dir("dataset", "/root/dataset")
30
  )
@@ -50,9 +59,9 @@ def build_inventory_remote():
50
  @app.function(
51
  image=image,
52
  volumes={VOLUME_MOUNT: volume, INTERMEDIATE_MOUNT: intermediate_volume},
53
- timeout=3600,
54
- cpu=4,
55
- memory=32768,
56
  )
57
  def build_relation_remote(relation_type: str, countries: list, limit: int):
58
  """Compute one relation type and save to intermediate volume."""
@@ -72,7 +81,7 @@ def build_relation_remote(relation_type: str, countries: list, limit: int):
72
  @app.function(
73
  image=image,
74
  volumes={VOLUME_MOUNT: volume, INTERMEDIATE_MOUNT: intermediate_volume},
75
- timeout=3600,
76
  cpu=2,
77
  memory=4096,
78
  )
@@ -126,13 +135,40 @@ def run_pipeline(
126
 
127
  relation_needs = calculate_relation_limits(config)
128
 
129
- handles = []
130
- for rel_type, limit in relation_needs.items():
131
- h = build_relation_remote.spawn(rel_type, countries, max(limit, 500))
132
- handles.append((rel_type, h))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
- for rel_type, h in handles:
135
- result = h.get()
 
 
136
  print(f" {rel_type}: {result['count']} pairs")
137
 
138
  print(f"Generating samples across {n_containers} containers...")
 
24
  "pyarrow>=17.0.0",
25
  "pyyaml>=6.0",
26
  )
27
+ .env(
28
+ {
29
+ "GAZET_DATA_DIR": VOLUME_MOUNT,
30
+ "PYTHONPATH": "/root",
31
+ # Spatial self-joins are much more stable with conservative
32
+ # DuckDB settings inside Modal containers.
33
+ "GAZET_DUCKDB_THREADS": "1",
34
+ "GAZET_DUCKDB_MEMORY_LIMIT": "20GB",
35
+ }
36
+ )
37
  .add_local_dir("src/gazet", "/root/gazet")
38
  .add_local_dir("dataset", "/root/dataset")
39
  )
 
59
  @app.function(
60
  image=image,
61
  volumes={VOLUME_MOUNT: volume, INTERMEDIATE_MOUNT: intermediate_volume},
62
+ timeout=7200,
63
+ cpu=2,
64
+ memory=65536,
65
  )
66
  def build_relation_remote(relation_type: str, countries: list, limit: int):
67
  """Compute one relation type and save to intermediate volume."""
 
81
  @app.function(
82
  image=image,
83
  volumes={VOLUME_MOUNT: volume, INTERMEDIATE_MOUNT: intermediate_volume},
84
+ timeout=7200,
85
  cpu=2,
86
  memory=4096,
87
  )
 
135
 
136
  relation_needs = calculate_relation_limits(config)
137
 
138
+ # Global containment-style relations are the most expensive and don't
139
+ # need extremely large anchor tables to support sample generation.
140
+ if countries == ["all"]:
141
+ for rel_type, cap in {
142
+ "containment": 12000,
143
+ "coastal_containment": 8000,
144
+ "landlocked_containment": 8000,
145
+ }.items():
146
+ if rel_type in relation_needs:
147
+ relation_needs[rel_type] = min(relation_needs[rel_type], cap)
148
+
149
+ # Spatial relation builds are the most crash-prone part of the Modal
150
+ # pipeline. Run them sequentially with conservative DuckDB settings
151
+ # rather than fanning out several large native spatial joins at once.
152
+ # common_neighbor still runs after adjacency because it depends on the
153
+ # adjacency parquet being committed first.
154
+ ordered_relations = [
155
+ rel_type
156
+ for rel_type in (
157
+ "adjacency",
158
+ "containment",
159
+ "intersection",
160
+ "cross_source",
161
+ "coastal_containment",
162
+ "landlocked_containment",
163
+ "common_neighbor",
164
+ )
165
+ if rel_type in relation_needs
166
+ ]
167
 
168
+ for rel_type in ordered_relations:
169
+ limit = max(relation_needs[rel_type], 500)
170
+ print(f" building {rel_type} (limit={limit})...")
171
+ result = build_relation_remote.remote(rel_type, countries, limit)
172
  print(f" {rel_type}: {result['count']} pairs")
173
 
174
  print(f"Generating samples across {n_containers} containers...")
dataset/scripts/build_inventory.py CHANGED
@@ -37,10 +37,11 @@ def build_divisions_area_inventory(con: duckdb.DuckDBPyConnection) -> pd.DataFra
37
  ST_XMax(geometry) AS xmax,
38
  ST_YMax(geometry) AS ymax
39
  FROM read_parquet(?)
40
- WHERE names."primary" IS NOT NULL
41
  AND trim(names."primary") != ''
 
42
  """
43
-
44
  df = con.execute(query, [DIVISIONS_AREA_PATH]).fetchdf()
45
  print(f"Divisions area inventory: {len(df)} entities")
46
  print(f"Subtypes: {df['subtype'].value_counts().to_dict()}")
@@ -69,10 +70,11 @@ def build_natural_earth_inventory(con: duckdb.DuckDBPyConnection) -> pd.DataFram
69
  ST_XMax(geometry) AS xmax,
70
  ST_YMax(geometry) AS ymax
71
  FROM read_parquet(?)
72
- WHERE names."primary" IS NOT NULL
73
  AND trim(names."primary") != ''
 
74
  """
75
-
76
  df = con.execute(query, [NATURAL_EARTH_PATH]).fetchdf()
77
  print(f"\nNatural earth inventory: {len(df)} entities")
78
  print(f"Subtypes: {df['subtype'].value_counts().to_dict()}")
 
37
  ST_XMax(geometry) AS xmax,
38
  ST_YMax(geometry) AS ymax
39
  FROM read_parquet(?)
40
+ WHERE names."primary" IS NOT NULL
41
  AND trim(names."primary") != ''
42
+ AND geometry IS NOT NULL
43
  """
44
+
45
  df = con.execute(query, [DIVISIONS_AREA_PATH]).fetchdf()
46
  print(f"Divisions area inventory: {len(df)} entities")
47
  print(f"Subtypes: {df['subtype'].value_counts().to_dict()}")
 
70
  ST_XMax(geometry) AS xmax,
71
  ST_YMax(geometry) AS ymax
72
  FROM read_parquet(?)
73
+ WHERE names."primary" IS NOT NULL
74
  AND trim(names."primary") != ''
75
+ AND geometry IS NOT NULL
76
  """
77
+
78
  df = con.execute(query, [NATURAL_EARTH_PATH]).fetchdf()
79
  print(f"\nNatural earth inventory: {len(df)} entities")
80
  print(f"Subtypes: {df['subtype'].value_counts().to_dict()}")
dataset/scripts/build_relations.py CHANGED
@@ -14,10 +14,12 @@ Output:
14
  - intermediate/cross_source_relations.parquet
15
  """
16
 
 
 
 
 
17
  import duckdb
18
  import pandas as pd
19
- from pathlib import Path
20
- from concurrent.futures import ThreadPoolExecutor, as_completed
21
 
22
  from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
23
 
@@ -26,6 +28,43 @@ from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
26
  _EXCLUDED_SUBTYPES_FOR_GLOBAL = ("locality", "neighborhood", "microhood", "macrohood")
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def _country_filter(countries: list) -> tuple[str, list]:
30
  """Return (SQL WHERE clause, params) handling 'all' sentinel."""
31
  if countries == ["all"]:
@@ -46,6 +85,29 @@ def _country_filter_for_join(countries: list) -> tuple[str, list]:
46
  return f"WHERE country IN (SELECT unnest(?)) {subtype_clause}", [countries]
47
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def compute_adjacency_pairs(
50
  con: duckdb.DuckDBPyConnection,
51
  countries: list,
@@ -95,50 +157,100 @@ def compute_adjacency_pairs(
95
  return df
96
 
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  def compute_containment_pairs(
99
  con: duckdb.DuckDBPyConnection,
100
  countries: list,
101
  limit: int
102
  ) -> pd.DataFrame:
103
- """Find all pairs where one feature contains another."""
104
- print("\nComputing containment pairs (optimized)...")
105
-
106
- cfilter, cparams = _country_filter(countries)
107
-
108
- query = f"""
109
- WITH features AS (
110
- SELECT
111
- id,
112
- names."primary" AS name,
113
- subtype,
114
- country,
115
- admin_level,
116
- geometry,
117
- ST_Envelope(geometry) AS bbox
118
- FROM read_parquet(?)
119
- {cfilter}
120
- )
121
- SELECT
122
- a.id AS container_id,
123
- a.name AS container_name,
124
- a.subtype AS container_subtype,
125
- b.id AS contained_id,
126
- b.name AS contained_name,
127
- b.subtype AS contained_subtype,
128
- 'containment' AS relation_type
129
- FROM features AS a
130
- JOIN features AS b ON (
131
- a.id != b.id
132
- AND a.admin_level < b.admin_level
133
- AND ST_Intersects(a.bbox, b.bbox)
134
- AND ST_Within(b.geometry, a.geometry)
135
- )
136
- LIMIT ?
137
- """
138
-
139
- df = con.execute(query, [DIVISIONS_AREA_PATH] + cparams + [limit]).fetchdf()
140
  print(f"Found {len(df)} containment pairs")
141
-
142
  return df
143
 
144
 
@@ -198,60 +310,71 @@ def compute_cross_source_relations(
198
  ) -> pd.DataFrame:
199
  """Find relations between divisions_area and natural_earth.
200
 
201
- Covers all natural_earth subtypes that appear in SQL templates:
202
- seas/oceans (adjacency, buffer, chained), terrain areas and island
203
- groups (chained_03, intersect_02, buffer_03/04).
 
204
  """
205
- print("\nComputing cross-source relations...")
206
 
207
  cfilter, cparams = _country_filter(countries)
208
-
209
- query = f"""
210
- WITH divisions AS (
211
- SELECT
212
- id,
213
- names."primary" AS name,
214
- subtype,
215
- country,
216
- geometry
217
- FROM read_parquet(?)
218
- {cfilter}
219
- ),
220
- natural_features AS (
221
- SELECT
222
- id,
223
- names."primary" AS name,
224
- subtype,
225
- ST_SetCRS(geometry, 'OGC:CRS84') AS geometry
226
- FROM read_parquet(?)
227
- WHERE subtype IN (
228
- 'sea', 'ocean', 'Lake', 'River', 'Basin', 'gulf', 'bay',
229
- 'Terrain area', 'Island group', 'Peninsula', 'Strait',
230
- 'Reef', 'Range/Mts', 'Depression'
 
 
 
 
 
 
231
  )
232
- )
233
- SELECT
234
- d.id AS division_id,
235
- d.name AS division_name,
236
- d.subtype AS division_subtype,
237
- d.country AS division_country,
238
- n.id AS natural_id,
239
- n.name AS natural_name,
240
- n.subtype AS natural_subtype,
241
- CASE
242
- WHEN ST_Touches(d.geometry, n.geometry) THEN 'touches'
243
- WHEN ST_Within(d.geometry, n.geometry) THEN 'within'
244
- WHEN ST_Contains(d.geometry, n.geometry) THEN 'contains'
245
- WHEN ST_Intersects(d.geometry, n.geometry) THEN 'intersects'
246
- END AS relation_type
247
- FROM divisions AS d
248
- JOIN natural_features AS n ON ST_Intersects(d.geometry, n.geometry)
249
- LIMIT ?
250
- """
251
-
252
- df = con.execute(
253
- query, [DIVISIONS_AREA_PATH] + cparams + [NATURAL_EARTH_PATH, limit]
254
- ).fetchdf()
 
 
 
 
255
  print(f"Found {len(df)} cross-source relations")
256
  return df
257
 
@@ -261,64 +384,29 @@ def compute_coastal_containment_pairs(
261
  countries: list,
262
  limit: int,
263
  ) -> pd.DataFrame:
264
- """Containment pairs where the container is in a coastal country.
265
 
266
- Used by chained_01 (coastal towns of X) to ensure sampled containment
267
- anchors actually have sea-adjacent sub-features, keeping the SQL
268
- verification step from constantly returning empty results.
269
-
270
- Strategy: find countries whose geometry intersects any ocean/sea in
271
- natural_earth, then filter containment_pairs to those countries.
272
  """
273
- print("\nComputing coastal containment pairs...")
274
-
275
- cfilter, cparams = _country_filter(countries)
276
-
277
- query = f"""
278
- WITH coastal_countries AS (
279
- SELECT DISTINCT d.country
280
- FROM read_parquet(?) AS d
281
- JOIN read_parquet(?) AS n
282
- ON ST_Intersects(d.geometry, ST_SetCRS(n.geometry, 'OGC:CRS84'))
283
- WHERE d.subtype = 'country'
284
- AND n.subtype IN ('sea', 'ocean')
285
- ),
286
- features AS (
287
- SELECT
288
- id,
289
- names."primary" AS name,
290
- subtype,
291
- country,
292
- admin_level,
293
- geometry,
294
- ST_Envelope(geometry) AS bbox
295
- FROM read_parquet(?)
296
- {cfilter}
297
- )
298
- SELECT
299
- a.id AS container_id,
300
- a.name AS container_name,
301
- a.subtype AS container_subtype,
302
- b.id AS contained_id,
303
- b.name AS contained_name,
304
- b.subtype AS contained_subtype,
305
- a.country AS container_country,
306
- 'coastal_containment' AS relation_type
307
- FROM features AS a
308
- JOIN features AS b ON (
309
- a.id != b.id
310
- AND a.admin_level < b.admin_level
311
- AND ST_Intersects(a.bbox, b.bbox)
312
- AND ST_Within(b.geometry, a.geometry)
313
- )
314
- WHERE a.country IN (SELECT country FROM coastal_countries)
315
- LIMIT ?
316
  """
317
-
318
- df = con.execute(
319
- query,
320
- [DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH] + cparams + [DIVISIONS_AREA_PATH, limit],
321
- ).fetchdf()
322
  print(f"Found {len(df)} coastal containment pairs")
323
  return df
324
 
@@ -328,61 +416,29 @@ def compute_landlocked_containment_pairs(
328
  countries: list,
329
  limit: int,
330
  ) -> pd.DataFrame:
331
- """Containment pairs where the container is in a landlocked country.
332
 
333
- Used by chained_02 (landlocked regions within X) to ensure sampled
334
- anchors genuinely have no sea access, keeping SQL verification from
335
- always returning empty.
336
  """
337
- print("\nComputing landlocked containment pairs...")
338
-
339
- cfilter, cparams = _country_filter(countries)
340
-
341
- query = f"""
342
- WITH coastal_countries AS (
343
- SELECT DISTINCT d.country
344
- FROM read_parquet(?) AS d
345
- JOIN read_parquet(?) AS n
346
- ON ST_Intersects(d.geometry, ST_SetCRS(n.geometry, 'OGC:CRS84'))
347
- WHERE d.subtype = 'country'
348
- AND n.subtype IN ('sea', 'ocean')
349
- ),
350
- features AS (
351
- SELECT
352
- id,
353
- names."primary" AS name,
354
- subtype,
355
- country,
356
- admin_level,
357
- geometry,
358
- ST_Envelope(geometry) AS bbox
359
- FROM read_parquet(?)
360
- {cfilter}
361
- )
362
- SELECT
363
- a.id AS container_id,
364
- a.name AS container_name,
365
- a.subtype AS container_subtype,
366
- b.id AS contained_id,
367
- b.name AS contained_name,
368
- b.subtype AS contained_subtype,
369
- a.country AS container_country,
370
- 'landlocked_containment' AS relation_type
371
- FROM features AS a
372
- JOIN features AS b ON (
373
- a.id != b.id
374
- AND a.admin_level < b.admin_level
375
- AND ST_Intersects(a.bbox, b.bbox)
376
- AND ST_Within(b.geometry, a.geometry)
377
- )
378
- WHERE a.country NOT IN (SELECT country FROM coastal_countries)
379
- LIMIT ?
380
  """
381
-
382
- df = con.execute(
383
- query,
384
- [DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH] + cparams + [DIVISIONS_AREA_PATH, limit],
385
- ).fetchdf()
386
  print(f"Found {len(df)} landlocked containment pairs")
387
  return df
388
 
@@ -411,6 +467,21 @@ def compute_common_neighbor_pairs(
411
  ])
412
 
413
  query = """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
  SELECT DISTINCT
415
  a1.anchor_id AS anchor_id_1,
416
  a1.anchor_name AS anchor_name_1,
@@ -418,8 +489,8 @@ def compute_common_neighbor_pairs(
418
  a2.anchor_name AS anchor_name_2,
419
  a1.target_id AS shared_neighbor_id,
420
  a1.target_name AS shared_neighbor_name
421
- FROM read_parquet(?) AS a1
422
- JOIN read_parquet(?) AS a2
423
  ON a1.target_id = a2.target_id
424
  AND a1.anchor_id < a2.anchor_id
425
  LIMIT ?
@@ -435,9 +506,11 @@ def _make_connection():
435
  con = duckdb.connect()
436
  con.execute("INSTALL spatial")
437
  con.execute("LOAD spatial")
438
- con.execute("SET memory_limit='24GB'")
 
 
439
  con.execute("SET temp_directory='/tmp/duckdb_tmp'")
440
- con.execute("SET threads=4")
441
  return con
442
 
443
 
 
14
  - intermediate/cross_source_relations.parquet
15
  """
16
 
17
+ import os
18
+ from concurrent.futures import ThreadPoolExecutor, as_completed
19
+ from pathlib import Path
20
+
21
  import duckdb
22
  import pandas as pd
 
 
23
 
24
  from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
25
 
 
28
  _EXCLUDED_SUBTYPES_FOR_GLOBAL = ("locality", "neighborhood", "microhood", "macrohood")
29
 
30
 
31
+ # (container_subtype, contained_subtype) combos used by the chained,
32
+ # containment, and disambiguation templates. A naive self-join with
33
+ # LIMIT fills up with coarse pairs (country -> region / county) first and
34
+ # never emits locality-level pairs; stratifying by combo ensures each
35
+ # template has anchors to draw from.
36
+ _CONTAINMENT_SUBTYPE_PAIRS = (
37
+ ("country", "region"),
38
+ ("country", "county"),
39
+ ("country", "localadmin"),
40
+ ("country", "locality"),
41
+ ("region", "county"),
42
+ ("region", "localadmin"),
43
+ ("region", "locality"),
44
+ ("county", "locality"),
45
+ ("localadmin", "locality"),
46
+ )
47
+
48
+
49
+ # Natural Earth subtype vocabulary in the current GeoParquet.
50
+ # Keep these exact strings in one place so relation building and templates
51
+ # stay aligned with the underlying dataset.
52
+ _NE_CROSS_SOURCE_SUBTYPES = (
53
+ "sea",
54
+ "ocean",
55
+ "Lake",
56
+ "River",
57
+ "Basin",
58
+ "gulf",
59
+ "bay",
60
+ "Island group",
61
+ "Peninsula",
62
+ "strait",
63
+ "Range/mtn",
64
+ "Depression",
65
+ )
66
+
67
+
68
  def _country_filter(countries: list) -> tuple[str, list]:
69
  """Return (SQL WHERE clause, params) handling 'all' sentinel."""
70
  if countries == ["all"]:
 
85
  return f"WHERE country IN (SELECT unnest(?)) {subtype_clause}", [countries]
86
 
87
 
88
+ def _country_chunks(
89
+ con: duckdb.DuckDBPyConnection,
90
+ countries: list,
91
+ chunk_size: int = 40,
92
+ ) -> list[list[str]]:
93
+ """Return explicit country batches for safer global containment joins."""
94
+ if countries != ["all"]:
95
+ return [countries]
96
+
97
+ rows = con.execute(
98
+ """
99
+ SELECT DISTINCT country
100
+ FROM read_parquet(?)
101
+ WHERE country IS NOT NULL
102
+ AND trim(country) != ''
103
+ ORDER BY country
104
+ """,
105
+ [DIVISIONS_AREA_PATH],
106
+ ).fetchall()
107
+ codes = [row[0] for row in rows]
108
+ return [codes[i:i + chunk_size] for i in range(0, len(codes), chunk_size)]
109
+
110
+
111
  def compute_adjacency_pairs(
112
  con: duckdb.DuckDBPyConnection,
113
  countries: list,
 
157
  return df
158
 
159
 
160
+ def _stratified_containment(
161
+ con: duckdb.DuckDBPyConnection,
162
+ countries: list,
163
+ limit: int,
164
+ relation_type: str,
165
+ extra_where: str = "",
166
+ extra_params: list = None,
167
+ ) -> pd.DataFrame:
168
+ """Compute containment pairs stratified by (container_subtype, contained_subtype).
169
+
170
+ A single global self-join with LIMIT fills up with coarse country->region
171
+ pairs before emitting any locality-level pairs. We run one focused query
172
+ per subtype combo instead so every combo receives a fair share of the
173
+ overall limit.
174
+
175
+ ``extra_where`` / ``extra_params`` let the coastal and landlocked variants
176
+ inject their country-set filter without duplicating the whole body.
177
+ """
178
+ extra_params = extra_params or []
179
+ # Use a lower target per subtype combo for global runs; they are the most
180
+ # memory-intensive part of the pipeline and don't need huge anchor tables.
181
+ if countries == ["all"]:
182
+ per_combo = min(max(limit // len(_CONTAINMENT_SUBTYPE_PAIRS), 100), 1500)
183
+ else:
184
+ per_combo = max(limit // len(_CONTAINMENT_SUBTYPE_PAIRS), 100)
185
+
186
+ country_batches = _country_chunks(con, countries)
187
+ frames: list[pd.DataFrame] = []
188
+ for container_st, contained_st in _CONTAINMENT_SUBTYPE_PAIRS:
189
+ remaining = per_combo
190
+ combo_parts: list[pd.DataFrame] = []
191
+
192
+ for batch in country_batches:
193
+ if remaining <= 0:
194
+ break
195
+
196
+ cfilter, cparams = _country_filter(batch)
197
+ query = f"""
198
+ WITH a AS (
199
+ SELECT src.id, src.names."primary" AS name, src.subtype, src.country, src.admin_level,
200
+ src.geometry, ST_Envelope(src.geometry) AS bbox
201
+ FROM read_parquet(?) AS src
202
+ WHERE src.subtype = '{container_st}'
203
+ {cfilter.replace("WHERE", "AND") if cfilter else ""}
204
+ {extra_where}
205
+ ),
206
+ b AS (
207
+ SELECT dst.id, dst.names."primary" AS name, dst.subtype, dst.country, dst.admin_level,
208
+ dst.geometry, ST_Envelope(dst.geometry) AS bbox
209
+ FROM read_parquet(?) AS dst
210
+ WHERE dst.subtype = '{contained_st}'
211
+ {cfilter.replace("WHERE", "AND") if cfilter else ""}
212
+ )
213
+ SELECT
214
+ a.id AS container_id,
215
+ a.name AS container_name,
216
+ a.subtype AS container_subtype,
217
+ b.id AS contained_id,
218
+ b.name AS contained_name,
219
+ b.subtype AS contained_subtype,
220
+ a.country AS container_country,
221
+ '{relation_type}' AS relation_type
222
+ FROM a JOIN b ON (
223
+ a.id != b.id
224
+ AND ST_Intersects(a.bbox, b.bbox)
225
+ AND ST_Within(b.geometry, a.geometry)
226
+ )
227
+ LIMIT {remaining}
228
+ """
229
+ params = [DIVISIONS_AREA_PATH] + extra_params + cparams + [DIVISIONS_AREA_PATH] + cparams
230
+ df_part = con.execute(query, params).fetchdf()
231
+ if not df_part.empty:
232
+ combo_parts.append(df_part)
233
+ remaining -= len(df_part)
234
+
235
+ df_combo = (
236
+ pd.concat(combo_parts, ignore_index=True)
237
+ if combo_parts else pd.DataFrame()
238
+ )
239
+ print(f" {relation_type} {container_st:>10s} -> {contained_st:<10s}: {len(df_combo)} pairs")
240
+ frames.append(df_combo)
241
+
242
+ return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
243
+
244
+
245
  def compute_containment_pairs(
246
  con: duckdb.DuckDBPyConnection,
247
  countries: list,
248
  limit: int
249
  ) -> pd.DataFrame:
250
+ """Find containment pairs stratified across admin-level combinations."""
251
+ print("\nComputing containment pairs (stratified by subtype combo)...")
252
+ df = _stratified_containment(con, countries, limit, relation_type="containment")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  print(f"Found {len(df)} containment pairs")
 
254
  return df
255
 
256
 
 
310
  ) -> pd.DataFrame:
311
  """Find relations between divisions_area and natural_earth.
312
 
313
+ The join is heavily skewed by a few abundant Natural Earth subtypes
314
+ (especially rivers and mountain ranges). We therefore stratify by exact
315
+ natural subtype so seas/oceans, gulfs/bays, island groups, and rarer
316
+ landforms all make it into the anchor pool.
317
  """
318
+ print("\nComputing cross-source relations (stratified by NE subtype)...")
319
 
320
  cfilter, cparams = _country_filter(countries)
321
+ per_subtype = max(limit // len(_NE_CROSS_SOURCE_SUBTYPES), 50)
322
+
323
+ frames: list[pd.DataFrame] = []
324
+ for natural_subtype in _NE_CROSS_SOURCE_SUBTYPES:
325
+ query = f"""
326
+ WITH divisions AS (
327
+ SELECT
328
+ id,
329
+ names."primary" AS name,
330
+ subtype,
331
+ country,
332
+ geometry
333
+ FROM read_parquet(?)
334
+ WHERE geometry IS NOT NULL
335
+ AND names."primary" IS NOT NULL
336
+ AND trim(names."primary") != ''
337
+ {cfilter.replace("WHERE", "AND") if cfilter else ''}
338
+ ),
339
+ natural_features AS (
340
+ SELECT
341
+ id,
342
+ names."primary" AS name,
343
+ subtype,
344
+ geometry
345
+ FROM read_parquet(?)
346
+ WHERE geometry IS NOT NULL
347
+ AND names."primary" IS NOT NULL
348
+ AND trim(names."primary") != ''
349
+ AND subtype = '{natural_subtype}'
350
  )
351
+ SELECT
352
+ d.id AS division_id,
353
+ d.name AS division_name,
354
+ d.subtype AS division_subtype,
355
+ d.country AS division_country,
356
+ n.id AS natural_id,
357
+ n.name AS natural_name,
358
+ n.subtype AS natural_subtype,
359
+ CASE
360
+ WHEN ST_Touches(d.geometry, n.geometry) THEN 'touches'
361
+ WHEN ST_Within(d.geometry, n.geometry) THEN 'within'
362
+ WHEN ST_Contains(d.geometry, n.geometry) THEN 'contains'
363
+ WHEN ST_Intersects(d.geometry, n.geometry) THEN 'intersects'
364
+ END AS relation_type
365
+ FROM divisions AS d
366
+ JOIN natural_features AS n
367
+ ON ST_Intersects(d.geometry, n.geometry)
368
+ LIMIT {per_subtype}
369
+ """
370
+ df_part = con.execute(
371
+ query,
372
+ [DIVISIONS_AREA_PATH] + cparams + [NATURAL_EARTH_PATH],
373
+ ).fetchdf()
374
+ print(f" cross_source {natural_subtype:>12s}: {len(df_part)} rows")
375
+ frames.append(df_part)
376
+
377
+ df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
378
  print(f"Found {len(df)} cross-source relations")
379
  return df
380
 
 
384
  countries: list,
385
  limit: int,
386
  ) -> pd.DataFrame:
387
+ """Stratified containment pairs limited to coastal-country containers.
388
 
389
+ Used by chained_01 (coastal towns of X) so sampled anchors actually have
390
+ sea-adjacent sub-features. Stratification guarantees coverage of every
391
+ admin-level combination (country->locality, region->locality, etc.).
 
 
 
392
  """
393
+ print("\nComputing coastal containment pairs (stratified)...")
394
+ extra_where = f"""
395
+ AND EXISTS (
396
+ SELECT 1
397
+ FROM read_parquet('{NATURAL_EARTH_PATH}') AS n
398
+ WHERE n.geometry IS NOT NULL
399
+ AND n.names."primary" IS NOT NULL
400
+ AND trim(n.names."primary") != ''
401
+ AND n.subtype IN ('sea', 'ocean')
402
+ AND ST_Intersects(src.geometry, n.geometry)
403
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  """
405
+ df = _stratified_containment(
406
+ con, countries, limit,
407
+ relation_type="coastal_containment",
408
+ extra_where=extra_where,
409
+ )
410
  print(f"Found {len(df)} coastal containment pairs")
411
  return df
412
 
 
416
  countries: list,
417
  limit: int,
418
  ) -> pd.DataFrame:
419
+ """Stratified containment pairs limited to landlocked-country containers.
420
 
421
+ Used by chained_02 (landlocked localities within X). Stratification by
422
+ subtype combo ensures locality-level pairs are actually present in the
423
+ output instead of being starved by coarse country->region pairs.
424
  """
425
+ print("\nComputing landlocked containment pairs (stratified)...")
426
+ extra_where = f"""
427
+ AND NOT EXISTS (
428
+ SELECT 1
429
+ FROM read_parquet('{NATURAL_EARTH_PATH}') AS n
430
+ WHERE n.geometry IS NOT NULL
431
+ AND n.names."primary" IS NOT NULL
432
+ AND trim(n.names."primary") != ''
433
+ AND n.subtype IN ('sea', 'ocean')
434
+ AND ST_Intersects(src.geometry, n.geometry)
435
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  """
437
+ df = _stratified_containment(
438
+ con, countries, limit,
439
+ relation_type="landlocked_containment",
440
+ extra_where=extra_where,
441
+ )
442
  print(f"Found {len(df)} landlocked containment pairs")
443
  return df
444
 
 
467
  ])
468
 
469
  query = """
470
+ WITH undirected AS (
471
+ SELECT
472
+ anchor_id,
473
+ anchor_name,
474
+ target_id,
475
+ target_name
476
+ FROM read_parquet(?)
477
+ UNION ALL
478
+ SELECT
479
+ target_id AS anchor_id,
480
+ target_name AS anchor_name,
481
+ anchor_id AS target_id,
482
+ anchor_name AS target_name
483
+ FROM read_parquet(?)
484
+ )
485
  SELECT DISTINCT
486
  a1.anchor_id AS anchor_id_1,
487
  a1.anchor_name AS anchor_name_1,
 
489
  a2.anchor_name AS anchor_name_2,
490
  a1.target_id AS shared_neighbor_id,
491
  a1.target_name AS shared_neighbor_name
492
+ FROM undirected AS a1
493
+ JOIN undirected AS a2
494
  ON a1.target_id = a2.target_id
495
  AND a1.anchor_id < a2.anchor_id
496
  LIMIT ?
 
506
  con = duckdb.connect()
507
  con.execute("INSTALL spatial")
508
  con.execute("LOAD spatial")
509
+ memory_limit = os.environ.get("GAZET_DUCKDB_MEMORY_LIMIT", "12GB")
510
+ threads = int(os.environ.get("GAZET_DUCKDB_THREADS", "1"))
511
+ con.execute(f"SET memory_limit='{memory_limit}'")
512
  con.execute("SET temp_directory='/tmp/duckdb_tmp'")
513
+ con.execute(f"SET threads={threads}")
514
  return con
515
 
516
 
dataset/scripts/cli.py CHANGED
@@ -110,6 +110,19 @@ def calculate_relation_limits(config: dict) -> Dict[str, int]:
110
  return relation_needs
111
 
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  def build_relations(config_path: Path):
114
  """Run relation building with config."""
115
  config = load_config(config_path)
@@ -269,6 +282,9 @@ def main():
269
  formatter_class=argparse.RawDescriptionHelpFormatter,
270
  epilog="""
271
  Examples:
 
 
 
272
  # Build relation tables only
273
  python cli.py build-relations --config ../config.yaml
274
 
@@ -296,7 +312,7 @@ Examples:
296
 
297
  parser.add_argument(
298
  'command',
299
- choices=['build-relations', 'generate-samples', 'validate', 'export',
300
  'full-pipeline', 'modal-upload', 'modal-generate'],
301
  help='Command to run'
302
  )
@@ -348,7 +364,9 @@ Examples:
348
 
349
  # Run the appropriate command
350
  try:
351
- if args.command == 'build-relations':
 
 
352
  build_relations(args.config)
353
  elif args.command == 'generate-samples':
354
  generate_samples(args.config, args.append)
 
110
  return relation_needs
111
 
112
 
113
+ def normalize_data():
114
+ """Build normalized source parquet copies with harmonized geometry metadata."""
115
+ print("=" * 60)
116
+ print("STEP 0: Normalizing Source Geodata")
117
+ print("=" * 60)
118
+ from dataset.scripts.normalize_geodata import normalize_geodata
119
+
120
+ result = normalize_geodata()
121
+ for name, path in result.items():
122
+ print(f" {name}: {path}")
123
+
124
+
125
+
126
  def build_relations(config_path: Path):
127
  """Run relation building with config."""
128
  config = load_config(config_path)
 
282
  formatter_class=argparse.RawDescriptionHelpFormatter,
283
  epilog="""
284
  Examples:
285
+ # Normalize source geodata first (recommended before Modal upload)
286
+ python cli.py normalize-data --config ../config.yaml
287
+
288
  # Build relation tables only
289
  python cli.py build-relations --config ../config.yaml
290
 
 
312
 
313
  parser.add_argument(
314
  'command',
315
+ choices=['normalize-data', 'build-relations', 'generate-samples', 'validate', 'export',
316
  'full-pipeline', 'modal-upload', 'modal-generate'],
317
  help='Command to run'
318
  )
 
364
 
365
  # Run the appropriate command
366
  try:
367
+ if args.command == 'normalize-data':
368
+ normalize_data()
369
+ elif args.command == 'build-relations':
370
  build_relations(args.config)
371
  elif args.command == 'generate-samples':
372
  generate_samples(args.config, args.append)
dataset/scripts/export_training_data.py CHANGED
@@ -124,7 +124,7 @@ You have access to two DuckDB parquet tables. Given a set of candidate entities
124
  id VARCHAR -- unique feature id prefixed 'ne_'
125
  names STRUCT("primary" VARCHAR, ...)
126
  country VARCHAR
127
- subtype VARCHAR -- e.g. 'ocean', 'sea', 'bay', 'Terrain area', 'Island group'
128
  class VARCHAR
129
  region VARCHAR
130
  admin_level INTEGER
 
124
  id VARCHAR -- unique feature id prefixed 'ne_'
125
  names STRUCT("primary" VARCHAR, ...)
126
  country VARCHAR
127
+ subtype VARCHAR -- e.g. 'ocean', 'sea', 'bay', 'Range/mtn', 'Island group'
128
  class VARCHAR
129
  region VARCHAR
130
  admin_level INTEGER
dataset/scripts/generate_samples.py CHANGED
@@ -44,7 +44,7 @@ def _for_execution(sql: str) -> str:
44
  return (
45
  sql
46
  .replace("read_parquet('divisions_area')", f"read_parquet('{DIVISIONS_AREA_PATH}')")
47
- .replace("read_parquet('natural_earth')", f"read_parquet('{NATURAL_EARTH_PATH}')")
48
  )
49
 
50
  # Configurable parameters (can be overridden by CLI)
@@ -60,6 +60,33 @@ SQLTemplate = sql_templates.SQLTemplate
60
  get_templates_by_family = sql_templates.get_templates_by_family
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  class Candidate(BaseModel):
64
  """Candidate entity for grounding."""
65
  candidate_id: str
@@ -121,6 +148,8 @@ def sample_adjacency_anchor(
121
  'anchor_name': row['anchor_name'],
122
  'anchor_subtype': row['anchor_subtype'],
123
  'anchor_country': row.get('anchor_country'), # May not exist in all tables
 
 
124
  'target_subtype': row.get('target_subtype')
125
  }
126
 
@@ -142,16 +171,23 @@ def sample_intersection_anchor(intersection_df: pd.DataFrame) -> Optional[Dict[s
142
 
143
 
144
  def sample_containment_anchor(containment_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
145
- """Sample a random containment pair."""
 
 
 
 
 
146
  if containment_df.empty:
147
  return None
148
-
149
  row = containment_df.sample(n=1).iloc[0]
150
  return {
151
  'container_id': row['container_id'],
152
  'container_name': row['container_name'],
153
  'container_subtype': row['container_subtype'],
154
- 'contained_subtype': row['contained_subtype']
 
 
155
  }
156
 
157
 
@@ -186,12 +222,24 @@ def sample_disambiguation_anchor(
186
  }
187
 
188
 
189
- def sample_cross_source_anchor(cross_source_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
190
- """Sample a random cross-source relation."""
 
 
 
 
191
  if cross_source_df.empty:
192
  return None
193
-
194
- row = cross_source_df.sample(n=1).iloc[0]
 
 
 
 
 
 
 
 
195
  return {
196
  'division_id': row['division_id'],
197
  'division_name': row['division_name'],
@@ -242,16 +290,16 @@ def build_candidate_list(
242
  difficulty: str = "medium"
243
  ) -> List[Candidate]:
244
  """Build candidate list with true anchor + distractors."""
245
-
246
  # Helper to convert pandas NA to None
247
  def safe_get(row, key, default=None):
248
  val = row.get(key, default)
249
  return None if pd.isna(val) else val
250
-
251
  # Get the true anchor
252
  if anchor_source == "divisions_area":
253
  query = """
254
- SELECT
255
  id,
256
  names."primary" AS name,
257
  subtype,
@@ -264,7 +312,7 @@ def build_candidate_list(
264
  anchor_row = con.execute(query, [DIVISIONS_AREA_PATH, anchor_id]).fetchdf().iloc[0]
265
  else:
266
  query = """
267
- SELECT
268
  id,
269
  names."primary" AS name,
270
  subtype
@@ -272,8 +320,7 @@ def build_candidate_list(
272
  WHERE id = ?
273
  """
274
  anchor_row = con.execute(query, [NATURAL_EARTH_PATH, anchor_id]).fetchdf().iloc[0]
275
-
276
- # Build true candidate
277
  true_candidate = Candidate(
278
  candidate_id="c1",
279
  source=anchor_source,
@@ -283,25 +330,31 @@ def build_candidate_list(
283
  country=safe_get(anchor_row, 'country'),
284
  region=safe_get(anchor_row, 'region'),
285
  admin_level=safe_get(anchor_row, 'admin_level'),
286
- similarity=1.0
287
  )
288
-
289
- # Build distractors based on difficulty
290
  distractors = build_distractors(
291
- con,
292
- anchor_name,
293
  anchor_source,
294
  anchor_id,
295
  num_candidates - 1,
296
- difficulty
297
  )
298
-
299
- # Order: true anchor first, then same-source distractors, then cross-source
300
- # distractors. This mirrors inference order (anchor at top by similarity,
301
- # same source grouped before the other source).
302
- candidates = [true_candidate] + distractors
303
 
304
- # Reassign candidate IDs in order
 
 
 
 
 
 
 
 
 
 
 
 
305
  for i, cand in enumerate(candidates, 1):
306
  cand.candidate_id = f"c{i}"
307
 
@@ -335,21 +388,39 @@ def build_distractors(
335
 
336
  def _query_source(path: str, src_name: str, n: int, excl_id: str) -> List[Candidate]:
337
  query = """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  SELECT
339
  id,
340
- names."primary" AS name,
341
  subtype,
342
  country,
343
  region,
344
  admin_level,
345
- jaro_winkler_similarity(lower(names."primary"), lower(?)) AS similarity
346
- FROM read_parquet(?)
347
- WHERE id != ?
348
- AND names."primary" IS NOT NULL
349
  ORDER BY similarity DESC
350
  LIMIT ?
351
  """
352
- df = con.execute(query, [anchor_name, path, excl_id, n]).fetchdf()
353
  results = []
354
  for _, row in df.iterrows():
355
  results.append(Candidate(
@@ -515,13 +586,23 @@ WHERE b.id != '{anchor['container_id']}'
515
  def sample_random_entity(
516
  con: duckdb.DuckDBPyConnection,
517
  inventory_df: pd.DataFrame,
518
- source: str
 
 
519
  ) -> Optional[Dict[str, Any]]:
520
- """Sample a random entity from inventory."""
521
  if inventory_df.empty:
522
  return None
523
-
524
- row = inventory_df.sample(n=1).iloc[0]
 
 
 
 
 
 
 
 
525
  return {
526
  'id': row['id'],
527
  'name': row['name'],
@@ -545,7 +626,12 @@ def generate_template_based_sample(
545
  if template.anchor_source == "divisions_area":
546
  anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
547
  else:
548
- anchor = sample_random_entity(con, tables['natural_earth_inventory'], 'natural_earth')
 
 
 
 
 
549
 
550
  if not anchor:
551
  return None
@@ -636,70 +722,156 @@ def generate_template_based_sample(
636
  anchor = {"id": pair["contained_id"], "name": pair["contained_name"]}
637
 
638
  elif template.family == "adjacency":
639
- # If the template pins a target_subtype (e.g. adj_02='region',
640
- # adj_06='county'), honour it so the sampled pair is guaranteed to
641
- # match the question phrasing ("neighbouring counties of X").
642
- anchor = sample_adjacency_anchor(
643
- tables['adjacency_pairs'],
644
- target_subtype=template.target_subtype,
645
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
646
  if not anchor:
647
  return None
648
 
649
  sql = template.sql_template.format(
650
  anchor_id=anchor['anchor_id'],
651
- target_subtype=anchor['target_subtype']
652
  )
653
-
654
  candidates = build_candidate_list(
655
  con, anchor['anchor_id'], anchor['anchor_name'], 'divisions_area',
656
  num_candidates=10, difficulty="medium"
657
  )
658
-
659
  question = random.choice(template.question_hints).format(
660
  anchor_name=anchor['anchor_name'],
661
- target_subtype=anchor['target_subtype']
662
  )
663
 
664
  elif template.family == "containment":
665
- anchor = sample_containment_anchor(tables['containment_pairs'])
666
- if not anchor:
667
- return None
668
-
669
- sql = template.sql_template.format(
670
- anchor_id=anchor['container_id'],
671
- target_subtype=anchor['contained_subtype']
672
- )
673
-
674
- candidates = build_candidate_list(
675
- con, anchor['container_id'], anchor['container_name'], 'divisions_area',
676
- num_candidates=10, difficulty="medium"
677
- )
678
-
679
- question = random.choice(template.question_hints).format(
680
- anchor_name=anchor['container_name'],
681
- target_subtype=anchor['contained_subtype']
682
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
683
 
684
  elif template.family == "intersection":
685
  if template.anchor_source == "natural_earth":
686
- anchor = sample_cross_source_anchor(tables['cross_source_relations'])
 
 
 
687
  if not anchor:
688
  return None
689
-
 
 
690
  sql = template.sql_template.format(
691
  anchor_id=anchor['natural_id'],
692
- target_subtype='country'
693
  )
694
-
695
  candidates = build_candidate_list(
696
  con, anchor['natural_id'], anchor['natural_name'], 'natural_earth',
697
  num_candidates=10, difficulty="medium"
698
  )
699
-
700
  question = random.choice(template.question_hints).format(
701
  anchor_name=anchor['natural_name'],
702
- target_subtype='country'
703
  )
704
  else:
705
  # Same-source intersection
@@ -760,14 +932,19 @@ def generate_template_based_sample(
760
  # country IN clause — 2 or 3 anchors, each contributes its country code
761
  num_a = 3 if template.template_id == "contain_multi_02" else 2
762
  anchors = [
763
- sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
 
 
 
 
 
764
  for _ in range(num_a)
765
  ]
766
  if any(a is None for a in anchors):
767
  return None
768
 
769
  countries = [a.get('country') or 'US' for a in anchors]
770
- target_subtype = random.choice(['region', 'locality'])
771
  per_anchor = 3 if num_a == 3 else 4
772
 
773
  fmt_kwargs = dict(
@@ -850,7 +1027,12 @@ def generate_template_based_sample(
850
 
851
  if template.num_anchors == 1:
852
  if template.anchor_source == "natural_earth":
853
- anchor = sample_random_entity(con, tables['natural_earth_inventory'], 'natural_earth')
 
 
 
 
 
854
  else:
855
  anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
856
  if not anchor:
@@ -944,20 +1126,22 @@ def generate_template_based_sample(
944
  # Mixed-source clip: division intersected with a natural_earth feature.
945
  # Use cross_source_relations so the pair is guaranteed to intersect —
946
  # random sampling almost never produces an intersecting pair.
947
- cs_df = tables.get('cross_source_relations', pd.DataFrame())
948
- if cs_df.empty:
 
 
 
949
  return None
950
- row = cs_df.sample(n=1).iloc[0]
951
  clip_feature = {
952
- 'id': row['natural_id'],
953
- 'name': row['natural_name'],
954
  'source': 'natural_earth',
955
  }
956
  # Override the division anchor with the paired division so the
957
  # ST_Intersects check in the SQL is guaranteed to pass.
958
  anchor = {
959
- 'id': row['division_id'],
960
- 'name': row['division_name'],
961
  'source': 'divisions_area',
962
  }
963
 
@@ -986,8 +1170,14 @@ def generate_template_based_sample(
986
  target_subtype = random.choice(['locality', 'region'])
987
 
988
  if template.template_id in ['agg_03', 'agg_04']:
989
- # Country-level aggregation: SQL uses country code, not anchor id.
990
- anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
 
 
 
 
 
 
991
  if not anchor:
992
  return None
993
 
@@ -1035,35 +1225,33 @@ def generate_template_based_sample(
1035
  elif template.family == "chained":
1036
  # Use pre-filtered coastal/landlocked containment pairs so the SQL
1037
  # verification step doesn't constantly return empty results.
1038
- if template.template_id == "chained_01":
 
 
1039
  table_key = 'coastal_containment_pairs'
1040
- elif template.template_id == "chained_02":
1041
  table_key = 'landlocked_containment_pairs'
1042
  else:
1043
  table_key = 'containment_pairs'
1044
 
1045
- # chained_10/11 need a country-level anchor ("coastal states of
1046
- # India") and region-level targets, so filter the containment pairs
1047
- # to (container=country, contained=region) before sampling.
1048
- _chained_subtype_filter = {
1049
- "chained_10": ("country", "region"),
1050
- "chained_11": ("country", "region"),
1051
- }
1052
  df = tables.get(table_key, tables['containment_pairs'])
1053
- filt = _chained_subtype_filter.get(template.template_id)
1054
- if filt:
1055
- df = df[
1056
- (df['container_subtype'] == filt[0])
1057
- & (df['contained_subtype'] == filt[1])
1058
- ]
 
 
 
 
 
 
1059
 
1060
  anchor = sample_containment_anchor(df)
1061
  if not anchor:
1062
  return None
1063
 
1064
- # Prefer the template-pinned target_subtype when set (e.g. chained_10
1065
- # always wants 'region') so the SQL filter and question phrasing stay
1066
- # in sync regardless of what the sampled pair happens to contain.
1067
  target_subtype = template.target_subtype or anchor.get('contained_subtype', 'locality')
1068
 
1069
  sql = template.sql_template.format(
@@ -1117,18 +1305,20 @@ def generate_template_based_sample(
1117
  # Use cross_source_relations so the pair is guaranteed to intersect
1118
  # (ST_Difference on non-intersecting geometries is always equal to
1119
  # the original geometry — a trivial and uninformative sample).
1120
- cs_df = tables.get('cross_source_relations', pd.DataFrame())
1121
- if cs_df.empty:
 
 
 
1122
  return None
1123
- row = cs_df.sample(n=1).iloc[0]
1124
  anchor = {
1125
- 'id': row['division_id'],
1126
- 'name': row['division_name'],
1127
  'source': 'divisions_area',
1128
  }
1129
  clip_feature = {
1130
- 'id': row['natural_id'],
1131
- 'name': row['natural_name'],
1132
  'source': 'natural_earth',
1133
  }
1134
 
@@ -1153,20 +1343,19 @@ def generate_template_based_sample(
1153
  candidates = _merge_candidate_lists(div_cands, ne_cands, max_total=10)
1154
 
1155
  else:
1156
- # Two divisions_area anchors use containment pairs so the
1157
- # smaller (contained) is guaranteed to intersect the larger.
 
 
1158
  pair = sample_containment_anchor(tables['containment_pairs'])
1159
  if not pair:
1160
  return None
1161
 
1162
  anchor1 = {'id': pair['container_id'], 'name': pair['container_name']}
1163
- anchor2_row = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
1164
- if not anchor2_row:
1165
- return None
1166
- anchor2 = anchor2_row
1167
 
1168
  sql = template.sql_template.format(
1169
- anchor_id_1=anchor1['id'],
1170
  anchor_id_2=anchor2['id'],
1171
  )
1172
 
@@ -1191,19 +1380,8 @@ def generate_template_based_sample(
1191
  if not pair:
1192
  return None
1193
 
1194
- # The adjacency table only records one direction; sample a second
1195
- # anchor that is known to be adjacent to the first.
1196
  anchor1 = {'id': pair['anchor_id'], 'name': pair['anchor_name']}
1197
-
1198
- # Find a random neighbour of anchor1 from adjacency pairs
1199
- neighbours = tables['adjacency_pairs']
1200
- neighbours = neighbours[neighbours['anchor_id'] == anchor1['id']]
1201
- if neighbours.empty:
1202
- return None
1203
- nb_row = neighbours.sample(n=1).iloc[0]
1204
- anchor2 = {'id': nb_row.get('target_id', nb_row['anchor_id']), 'name': nb_row.get('target_name', nb_row['anchor_name'])}
1205
- if anchor1['id'] == anchor2['id']:
1206
- return None
1207
 
1208
  buffer_val = random.choice([5, 10, 25, 50])
1209
 
@@ -1230,12 +1408,17 @@ def generate_template_based_sample(
1230
  )
1231
 
1232
  elif template.family == "window_function":
1233
- anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
 
 
 
 
 
1234
  if not anchor:
1235
  return None
1236
 
1237
  country = anchor.get('country') or 'US'
1238
- target_subtype = random.choice(['locality', 'neighborhood'])
1239
 
1240
  sql = template.sql_template.format(
1241
  country=country,
@@ -1253,12 +1436,17 @@ def generate_template_based_sample(
1253
  )
1254
 
1255
  elif template.family == "attribute_filter":
1256
- anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
 
 
 
 
 
1257
  if not anchor:
1258
  return None
1259
 
1260
  country = anchor.get('country') or 'US'
1261
- target_subtype = template.target_subtype or random.choice(['dependency', 'region', 'locality'])
1262
 
1263
  sql = template.sql_template.format(
1264
  country=country,
 
44
  return (
45
  sql
46
  .replace("read_parquet('divisions_area')", f"read_parquet('{DIVISIONS_AREA_PATH}')")
47
+ .replace("read_parquet('natural_earth')", f"read_parquet('{NATURAL_EARTH_PATH}')")
48
  )
49
 
50
  # Configurable parameters (can be overridden by CLI)
 
60
  get_templates_by_family = sql_templates.get_templates_by_family
61
 
62
 
63
+ _NE_NAMED_LOOKUP_SUBTYPES = {
64
+ 'sea', 'ocean', 'Lake', 'River', 'Basin', 'gulf', 'bay',
65
+ 'Island group', 'Peninsula', 'strait', 'Range/mtn', 'Depression',
66
+ }
67
+
68
+ _NE_TEMPLATE_SUBTYPES = {
69
+ 'lookup_02': {'sea', 'ocean', 'Lake', 'River', 'Basin', 'gulf', 'bay', 'Island group', 'Peninsula', 'strait', 'Range/mtn', 'Depression'},
70
+ 'adj_03': {'sea', 'ocean'},
71
+ 'adj_04': {'River', 'Lake', 'Basin'},
72
+ 'adj_05': {'Range/mtn', 'Peninsula', 'Depression'},
73
+ 'contain_03': {'sea', 'ocean', 'gulf', 'bay', 'Basin', 'Island group', 'Peninsula', 'Range/mtn', 'Depression'},
74
+ 'contain_04': {'sea', 'ocean', 'gulf', 'bay', 'strait'},
75
+ 'intersect_02': {'River', 'Lake', 'Basin', 'gulf', 'bay', 'strait', 'Range/mtn', 'Peninsula', 'Depression'},
76
+ 'intersect_03': {'River', 'Lake', 'Basin', 'gulf', 'bay', 'strait', 'Range/mtn', 'Peninsula', 'Depression'},
77
+ 'buffer_03': {'sea', 'ocean', 'Lake', 'River', 'Basin', 'gulf', 'bay', 'Island group', 'Peninsula', 'strait', 'Range/mtn', 'Depression'},
78
+ 'buffer_04': {'sea', 'ocean', 'Lake', 'River', 'Basin', 'gulf', 'bay', 'Island group', 'Peninsula', 'strait', 'Range/mtn', 'Depression'},
79
+ 'buffer_05': {'sea', 'ocean', 'Lake', 'River', 'Basin', 'gulf', 'bay', 'Island group', 'Peninsula', 'strait', 'Range/mtn', 'Depression'},
80
+ 'chained_03': {'Island group', 'Peninsula', 'Range/mtn', 'Depression'},
81
+ 'chained_04': {'River', 'Lake', 'Basin'},
82
+ 'chained_05': {'Range/mtn', 'Depression'},
83
+ 'chained_08': {'River', 'Lake', 'Basin'},
84
+ 'chained_09': {'Range/mtn', 'Depression'},
85
+ 'partial_05': {'sea', 'ocean', 'Lake', 'River', 'Basin', 'gulf', 'bay', 'Island group', 'Peninsula', 'strait', 'Range/mtn', 'Depression'},
86
+ 'diff_02': {'sea', 'ocean', 'Lake', 'River', 'Basin', 'gulf', 'bay', 'Island group', 'Peninsula', 'strait', 'Range/mtn', 'Depression'},
87
+ }
88
+
89
+
90
  class Candidate(BaseModel):
91
  """Candidate entity for grounding."""
92
  candidate_id: str
 
148
  'anchor_name': row['anchor_name'],
149
  'anchor_subtype': row['anchor_subtype'],
150
  'anchor_country': row.get('anchor_country'), # May not exist in all tables
151
+ 'target_id': row.get('target_id'),
152
+ 'target_name': row.get('target_name'),
153
  'target_subtype': row.get('target_subtype')
154
  }
155
 
 
171
 
172
 
173
  def sample_containment_anchor(containment_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
174
+ """Sample a random containment pair.
175
+
176
+ Returns both ends of the pair so callers that need the contained entity
177
+ (e.g. difference templates that clip container by contained) can use it
178
+ directly without a second random draw.
179
+ """
180
  if containment_df.empty:
181
  return None
182
+
183
  row = containment_df.sample(n=1).iloc[0]
184
  return {
185
  'container_id': row['container_id'],
186
  'container_name': row['container_name'],
187
  'container_subtype': row['container_subtype'],
188
+ 'contained_id': row['contained_id'],
189
+ 'contained_name': row['contained_name'],
190
+ 'contained_subtype': row['contained_subtype'],
191
  }
192
 
193
 
 
222
  }
223
 
224
 
225
+ def sample_cross_source_anchor(
226
+ cross_source_df: pd.DataFrame,
227
+ natural_subtypes: Optional[set[str]] = None,
228
+ relation_types: Optional[set[str]] = None,
229
+ ) -> Optional[Dict[str, Any]]:
230
+ """Sample a random cross-source relation with optional subtype filters."""
231
  if cross_source_df.empty:
232
  return None
233
+
234
+ df = cross_source_df
235
+ if natural_subtypes is not None:
236
+ df = df[df['natural_subtype'].isin(natural_subtypes)]
237
+ if relation_types is not None:
238
+ df = df[df['relation_type'].isin(relation_types)]
239
+ if df.empty:
240
+ return None
241
+
242
+ row = df.sample(n=1).iloc[0]
243
  return {
244
  'division_id': row['division_id'],
245
  'division_name': row['division_name'],
 
290
  difficulty: str = "medium"
291
  ) -> List[Candidate]:
292
  """Build candidate list with true anchor + distractors."""
293
+
294
  # Helper to convert pandas NA to None
295
  def safe_get(row, key, default=None):
296
  val = row.get(key, default)
297
  return None if pd.isna(val) else val
298
+
299
  # Get the true anchor
300
  if anchor_source == "divisions_area":
301
  query = """
302
+ SELECT
303
  id,
304
  names."primary" AS name,
305
  subtype,
 
312
  anchor_row = con.execute(query, [DIVISIONS_AREA_PATH, anchor_id]).fetchdf().iloc[0]
313
  else:
314
  query = """
315
+ SELECT
316
  id,
317
  names."primary" AS name,
318
  subtype
 
320
  WHERE id = ?
321
  """
322
  anchor_row = con.execute(query, [NATURAL_EARTH_PATH, anchor_id]).fetchdf().iloc[0]
323
+
 
324
  true_candidate = Candidate(
325
  candidate_id="c1",
326
  source=anchor_source,
 
330
  country=safe_get(anchor_row, 'country'),
331
  region=safe_get(anchor_row, 'region'),
332
  admin_level=safe_get(anchor_row, 'admin_level'),
333
+ similarity=1.0,
334
  )
335
+
 
336
  distractors = build_distractors(
337
+ con,
338
+ anchor_name,
339
  anchor_source,
340
  anchor_id,
341
  num_candidates - 1,
342
+ difficulty,
343
  )
 
 
 
 
 
344
 
345
+ # Deduplicate by underlying entity id while preserving order.
346
+ # Some parquet sources contain repeated rows for the same feature id,
347
+ # which can otherwise leak duplicate candidates into the dataset.
348
+ candidates: List[Candidate] = []
349
+ seen_ids: set[str] = set()
350
+ for cand in [true_candidate] + distractors:
351
+ if cand.id in seen_ids:
352
+ continue
353
+ candidates.append(cand)
354
+ seen_ids.add(cand.id)
355
+ if len(candidates) >= num_candidates:
356
+ break
357
+
358
  for i, cand in enumerate(candidates, 1):
359
  cand.candidate_id = f"c{i}"
360
 
 
388
 
389
  def _query_source(path: str, src_name: str, n: int, excl_id: str) -> List[Candidate]:
390
  query = """
391
+ WITH ranked AS (
392
+ SELECT
393
+ id,
394
+ names."primary" AS name,
395
+ subtype,
396
+ country,
397
+ region,
398
+ admin_level,
399
+ jaro_winkler_similarity(lower(names."primary"), lower(?)) AS similarity,
400
+ ROW_NUMBER() OVER (
401
+ PARTITION BY id
402
+ ORDER BY jaro_winkler_similarity(lower(names."primary"), lower(?)) DESC
403
+ ) AS rn
404
+ FROM read_parquet(?)
405
+ WHERE id != ?
406
+ AND names."primary" IS NOT NULL
407
+ AND trim(names."primary") != ''
408
+ AND geometry IS NOT NULL
409
+ )
410
  SELECT
411
  id,
412
+ name,
413
  subtype,
414
  country,
415
  region,
416
  admin_level,
417
+ similarity
418
+ FROM ranked
419
+ WHERE rn = 1
 
420
  ORDER BY similarity DESC
421
  LIMIT ?
422
  """
423
+ df = con.execute(query, [anchor_name, anchor_name, path, excl_id, n]).fetchdf()
424
  results = []
425
  for _, row in df.iterrows():
426
  results.append(Candidate(
 
586
  def sample_random_entity(
587
  con: duckdb.DuckDBPyConnection,
588
  inventory_df: pd.DataFrame,
589
+ source: str,
590
+ subtypes: Optional[set[str]] = None,
591
+ countries: Optional[set[str]] = None,
592
  ) -> Optional[Dict[str, Any]]:
593
+ """Sample a random entity from inventory with optional filters."""
594
  if inventory_df.empty:
595
  return None
596
+
597
+ df = inventory_df
598
+ if subtypes is not None:
599
+ df = df[df['subtype'].isin(subtypes)]
600
+ if countries is not None and 'country' in df.columns:
601
+ df = df[df['country'].isin(countries)]
602
+ if df.empty:
603
+ return None
604
+
605
+ row = df.sample(n=1).iloc[0]
606
  return {
607
  'id': row['id'],
608
  'name': row['name'],
 
626
  if template.anchor_source == "divisions_area":
627
  anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
628
  else:
629
+ anchor = sample_random_entity(
630
+ con,
631
+ tables['natural_earth_inventory'],
632
+ 'natural_earth',
633
+ subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id, _NE_NAMED_LOOKUP_SUBTYPES),
634
+ )
635
 
636
  if not anchor:
637
  return None
 
722
  anchor = {"id": pair["contained_id"], "name": pair["contained_name"]}
723
 
724
  elif template.family == "adjacency":
725
+ # adj_03/04/05 target natural_earth features (seas, rivers, ranges).
726
+ # Their SQL hardcodes NE subtypes and does not use {target_subtype}.
727
+ # Sample from cross_source_relations so the anchor is a division
728
+ # that actually intersects the right NE features.
729
+ _NE_ADJ_SUBTYPES = {
730
+ "adj_03": ("ocean", "sea"),
731
+ "adj_04": ("River", "Lake", "Basin"),
732
+ "adj_05": ("Range/mtn", "Peninsula", "Depression"),
733
+ }
734
+ if template.template_id in _NE_ADJ_SUBTYPES:
735
+ cs_df = tables.get('cross_source_relations', pd.DataFrame())
736
+ if cs_df.empty:
737
+ return None
738
+ ne_types = _NE_ADJ_SUBTYPES[template.template_id]
739
+ filtered = cs_df[cs_df['natural_subtype'].isin(ne_types)]
740
+ if filtered.empty:
741
+ return None
742
+ row = filtered.sample(n=1).iloc[0]
743
+ anchor = {
744
+ 'anchor_id': row['division_id'],
745
+ 'anchor_name': row['division_name'],
746
+ 'anchor_subtype': row['division_subtype'],
747
+ 'target_subtype': row['natural_subtype'],
748
+ }
749
+ else:
750
+ # adj_01/02/06: divisions_area self-join adjacency.
751
+ # Only filter by target_subtype when the SQL uses {target_subtype}.
752
+ filter_subtype = (
753
+ template.target_subtype
754
+ if '{target_subtype}' in template.sql_template
755
+ else None
756
+ )
757
+ anchor = sample_adjacency_anchor(
758
+ tables['adjacency_pairs'],
759
+ target_subtype=filter_subtype,
760
+ )
761
  if not anchor:
762
  return None
763
 
764
  sql = template.sql_template.format(
765
  anchor_id=anchor['anchor_id'],
766
+ target_subtype=anchor.get('target_subtype', ''),
767
  )
768
+
769
  candidates = build_candidate_list(
770
  con, anchor['anchor_id'], anchor['anchor_name'], 'divisions_area',
771
  num_candidates=10, difficulty="medium"
772
  )
773
+
774
  question = random.choice(template.question_hints).format(
775
  anchor_name=anchor['anchor_name'],
776
+ target_subtype=anchor.get('target_subtype', ''),
777
  )
778
 
779
  elif template.family == "containment":
780
+ if template.anchor_source == "natural_earth":
781
+ # contain_03 / contain_04: NE anchor (sea, desert, etc.).
782
+ # Use cross_source_relations so the anchor exists in natural_earth
783
+ # and is guaranteed to intersect divisions_area features.
784
+ cs_anchor = sample_cross_source_anchor(
785
+ tables.get('cross_source_relations', pd.DataFrame()),
786
+ natural_subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id),
787
+ )
788
+ if not cs_anchor:
789
+ return None
790
+ anchor_id = cs_anchor['natural_id']
791
+ anchor_name = cs_anchor['natural_name']
792
+ target_subtype = template.target_subtype or 'country'
793
+
794
+ sql = template.sql_template.format(
795
+ anchor_id=anchor_id,
796
+ target_subtype=target_subtype,
797
+ )
798
+ candidates = build_candidate_list(
799
+ con, anchor_id, anchor_name, 'natural_earth',
800
+ num_candidates=10, difficulty="medium"
801
+ )
802
+ question = random.choice(template.question_hints).format(
803
+ anchor_name=anchor_name,
804
+ target_subtype=target_subtype,
805
+ )
806
+ anchor = {'id': anchor_id, 'name': anchor_name}
807
+
808
+ elif template.template_id == "contain_02":
809
+ # "What country contains X?" - anchor is the CONTAINED entity;
810
+ # result is the country that ST_Contains it.
811
+ df = tables['containment_pairs']
812
+ df = df[df['container_subtype'] == 'country']
813
+ pair = sample_containment_anchor(df)
814
+ if not pair:
815
+ return None
816
+
817
+ sql = template.sql_template.format(
818
+ anchor_id=pair['contained_id'],
819
+ target_subtype='country',
820
+ )
821
+ candidates = build_candidate_list(
822
+ con, pair['contained_id'], pair['contained_name'], 'divisions_area',
823
+ num_candidates=10, difficulty="medium"
824
+ )
825
+ question = random.choice(template.question_hints).format(
826
+ anchor_name=pair['contained_name'],
827
+ target_subtype='country',
828
+ )
829
+ anchor = {'id': pair['contained_id'], 'name': pair['contained_name']}
830
+
831
+ else:
832
+ # contain_01: standard containment.
833
+ # Anchor = container, target_subtype = contained entity's subtype.
834
+ anchor = sample_containment_anchor(tables['containment_pairs'])
835
+ if not anchor:
836
+ return None
837
+
838
+ sql = template.sql_template.format(
839
+ anchor_id=anchor['container_id'],
840
+ target_subtype=anchor['contained_subtype'],
841
+ )
842
+ candidates = build_candidate_list(
843
+ con, anchor['container_id'], anchor['container_name'], 'divisions_area',
844
+ num_candidates=10, difficulty="medium"
845
+ )
846
+ question = random.choice(template.question_hints).format(
847
+ anchor_name=anchor['container_name'],
848
+ target_subtype=anchor['contained_subtype'],
849
+ )
850
 
851
  elif template.family == "intersection":
852
  if template.anchor_source == "natural_earth":
853
+ anchor = sample_cross_source_anchor(
854
+ tables['cross_source_relations'],
855
+ natural_subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id),
856
+ )
857
  if not anchor:
858
  return None
859
+
860
+ target_subtype = template.target_subtype or 'country'
861
+
862
  sql = template.sql_template.format(
863
  anchor_id=anchor['natural_id'],
864
+ target_subtype=target_subtype,
865
  )
866
+
867
  candidates = build_candidate_list(
868
  con, anchor['natural_id'], anchor['natural_name'], 'natural_earth',
869
  num_candidates=10, difficulty="medium"
870
  )
871
+
872
  question = random.choice(template.question_hints).format(
873
  anchor_name=anchor['natural_name'],
874
+ target_subtype=target_subtype,
875
  )
876
  else:
877
  # Same-source intersection
 
932
  # country IN clause — 2 or 3 anchors, each contributes its country code
933
  num_a = 3 if template.template_id == "contain_multi_02" else 2
934
  anchors = [
935
+ sample_random_entity(
936
+ con,
937
+ tables['divisions_area_inventory'],
938
+ 'divisions_area',
939
+ subtypes={'country'},
940
+ )
941
  for _ in range(num_a)
942
  ]
943
  if any(a is None for a in anchors):
944
  return None
945
 
946
  countries = [a.get('country') or 'US' for a in anchors]
947
+ target_subtype = template.target_subtype or 'region'
948
  per_anchor = 3 if num_a == 3 else 4
949
 
950
  fmt_kwargs = dict(
 
1027
 
1028
  if template.num_anchors == 1:
1029
  if template.anchor_source == "natural_earth":
1030
+ anchor = sample_random_entity(
1031
+ con,
1032
+ tables['natural_earth_inventory'],
1033
+ 'natural_earth',
1034
+ subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id, _NE_NAMED_LOOKUP_SUBTYPES),
1035
+ )
1036
  else:
1037
  anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
1038
  if not anchor:
 
1126
  # Mixed-source clip: division intersected with a natural_earth feature.
1127
  # Use cross_source_relations so the pair is guaranteed to intersect —
1128
  # random sampling almost never produces an intersecting pair.
1129
+ cs_anchor = sample_cross_source_anchor(
1130
+ tables.get('cross_source_relations', pd.DataFrame()),
1131
+ natural_subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id),
1132
+ )
1133
+ if not cs_anchor:
1134
  return None
 
1135
  clip_feature = {
1136
+ 'id': cs_anchor['natural_id'],
1137
+ 'name': cs_anchor['natural_name'],
1138
  'source': 'natural_earth',
1139
  }
1140
  # Override the division anchor with the paired division so the
1141
  # ST_Intersects check in the SQL is guaranteed to pass.
1142
  anchor = {
1143
+ 'id': cs_anchor['division_id'],
1144
+ 'name': cs_anchor['division_name'],
1145
  'source': 'divisions_area',
1146
  }
1147
 
 
1170
  target_subtype = random.choice(['locality', 'region'])
1171
 
1172
  if template.template_id in ['agg_03', 'agg_04']:
1173
+ # Country-level aggregation: SQL uses country code, so the anchor
1174
+ # in the question must also be a country.
1175
+ anchor = sample_random_entity(
1176
+ con,
1177
+ tables['divisions_area_inventory'],
1178
+ 'divisions_area',
1179
+ subtypes={'country'},
1180
+ )
1181
  if not anchor:
1182
  return None
1183
 
 
1225
  elif template.family == "chained":
1226
  # Use pre-filtered coastal/landlocked containment pairs so the SQL
1227
  # verification step doesn't constantly return empty results.
1228
+ _COASTAL_CHAINED = {"chained_01", "chained_06", "chained_10"}
1229
+ _LANDLOCKED_CHAINED = {"chained_02", "chained_07", "chained_11"}
1230
+ if template.template_id in _COASTAL_CHAINED:
1231
  table_key = 'coastal_containment_pairs'
1232
+ elif template.template_id in _LANDLOCKED_CHAINED:
1233
  table_key = 'landlocked_containment_pairs'
1234
  else:
1235
  table_key = 'containment_pairs'
1236
 
 
 
 
 
 
 
 
1237
  df = tables.get(table_key, tables['containment_pairs'])
1238
+
1239
+ # When the template pins a target_subtype (e.g. chained_06 wants
1240
+ # counties), only consider pairs whose contained entity already
1241
+ # matches — guarantees the sampled container holds at least one
1242
+ # entity of the right subtype so the SQL filter returns rows.
1243
+ if template.target_subtype:
1244
+ df = df[df['contained_subtype'] == template.target_subtype]
1245
+
1246
+ # chained_10/11 additionally need a country-level container so
1247
+ # phrasings like "coastal states of India" line up.
1248
+ if template.template_id in {"chained_10", "chained_11"}:
1249
+ df = df[df['container_subtype'] == 'country']
1250
 
1251
  anchor = sample_containment_anchor(df)
1252
  if not anchor:
1253
  return None
1254
 
 
 
 
1255
  target_subtype = template.target_subtype or anchor.get('contained_subtype', 'locality')
1256
 
1257
  sql = template.sql_template.format(
 
1305
  # Use cross_source_relations so the pair is guaranteed to intersect
1306
  # (ST_Difference on non-intersecting geometries is always equal to
1307
  # the original geometry — a trivial and uninformative sample).
1308
+ cs_anchor = sample_cross_source_anchor(
1309
+ tables.get('cross_source_relations', pd.DataFrame()),
1310
+ natural_subtypes=_NE_TEMPLATE_SUBTYPES.get(template.template_id),
1311
+ )
1312
+ if not cs_anchor:
1313
  return None
 
1314
  anchor = {
1315
+ 'id': cs_anchor['division_id'],
1316
+ 'name': cs_anchor['division_name'],
1317
  'source': 'divisions_area',
1318
  }
1319
  clip_feature = {
1320
+ 'id': cs_anchor['natural_id'],
1321
+ 'name': cs_anchor['natural_name'],
1322
  'source': 'natural_earth',
1323
  }
1324
 
 
1343
  candidates = _merge_candidate_lists(div_cands, ne_cands, max_total=10)
1344
 
1345
  else:
1346
+ # Two divisions_area anchors: use both ends of a containment
1347
+ # pair so the contained entity is guaranteed to intersect the
1348
+ # container. ST_Difference(container, contained) yields the
1349
+ # portion of the container outside the contained piece.
1350
  pair = sample_containment_anchor(tables['containment_pairs'])
1351
  if not pair:
1352
  return None
1353
 
1354
  anchor1 = {'id': pair['container_id'], 'name': pair['container_name']}
1355
+ anchor2 = {'id': pair['contained_id'], 'name': pair['contained_name']}
 
 
 
1356
 
1357
  sql = template.sql_template.format(
1358
+ anchor_id_1=anchor1['id'],
1359
  anchor_id_2=anchor2['id'],
1360
  )
1361
 
 
1380
  if not pair:
1381
  return None
1382
 
 
 
1383
  anchor1 = {'id': pair['anchor_id'], 'name': pair['anchor_name']}
1384
+ anchor2 = {'id': pair['target_id'], 'name': pair['target_name']}
 
 
 
 
 
 
 
 
 
1385
 
1386
  buffer_val = random.choice([5, 10, 25, 50])
1387
 
 
1408
  )
1409
 
1410
  elif template.family == "window_function":
1411
+ anchor = sample_random_entity(
1412
+ con,
1413
+ tables['divisions_area_inventory'],
1414
+ 'divisions_area',
1415
+ subtypes={'country'},
1416
+ )
1417
  if not anchor:
1418
  return None
1419
 
1420
  country = anchor.get('country') or 'US'
1421
+ target_subtype = template.target_subtype or 'locality'
1422
 
1423
  sql = template.sql_template.format(
1424
  country=country,
 
1436
  )
1437
 
1438
  elif template.family == "attribute_filter":
1439
+ anchor = sample_random_entity(
1440
+ con,
1441
+ tables['divisions_area_inventory'],
1442
+ 'divisions_area',
1443
+ subtypes={'country'},
1444
+ )
1445
  if not anchor:
1446
  return None
1447
 
1448
  country = anchor.get('country') or 'US'
1449
+ target_subtype = template.target_subtype or 'region'
1450
 
1451
  sql = template.sql_template.format(
1452
  country=country,
dataset/scripts/normalize_geodata.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Normalize source GeoParquet files to a shared CRS-neutral geometry encoding.
2
+
3
+ The training pipeline mixes Overture divisions_area and Natural Earth geometry.
4
+ Across environments these sources can advertise different CRS metadata labels
5
+ (`EPSG:4326` vs `OGC:CRS84`), which causes DuckDB spatial joins to fail even
6
+ when coordinates are already compatible lon/lat values.
7
+
8
+ This script rewrites both datasets into normalized copies whose geometry column
9
+ is rebuilt from WKB. That preserves coordinates while dropping conflicting CRS
10
+ metadata, so downstream joins behave consistently locally and on Modal.
11
+
12
+ Output layout under data/ by default:
13
+ overture_normalized/divisions_area/part-000.parquet
14
+ natural_earth_normalized/ne_geography.parquet
15
+ """
16
+
17
+ from pathlib import Path
18
+
19
+ import duckdb
20
+
21
+ from gazet.config import _DATA_DIR
22
+
23
+
24
+ def normalize_geodata(output_root: Path | None = None) -> dict[str, str]:
25
+ """Write normalized copies of both source datasets.
26
+
27
+ Args:
28
+ output_root: Base directory to write normalized datasets into.
29
+ Defaults to the project data dir.
30
+
31
+ Returns:
32
+ Mapping of dataset name to written path/glob.
33
+ """
34
+ root = output_root or _DATA_DIR
35
+ overture_dir = root / "overture_normalized" / "divisions_area"
36
+ natural_earth_dir = root / "natural_earth_normalized"
37
+ overture_dir.mkdir(parents=True, exist_ok=True)
38
+ natural_earth_dir.mkdir(parents=True, exist_ok=True)
39
+
40
+ overture_path = overture_dir / "part-000.parquet"
41
+ natural_earth_path = natural_earth_dir / "ne_geography.parquet"
42
+
43
+ con = duckdb.connect()
44
+ con.execute("INSTALL spatial")
45
+ con.execute("LOAD spatial")
46
+
47
+ # Rebuild geometry from WKB so conflicting CRS metadata is dropped.
48
+ con.execute(
49
+ f"""
50
+ COPY (
51
+ SELECT * REPLACE (
52
+ ST_GeomFromWKB(ST_AsWKB(geometry)) AS geometry
53
+ )
54
+ FROM read_parquet('{root / 'overture/divisions_area/*.parquet'}')
55
+ WHERE geometry IS NOT NULL
56
+ ) TO '{overture_path}' (FORMAT PARQUET)
57
+ """
58
+ )
59
+
60
+ con.execute(
61
+ f"""
62
+ COPY (
63
+ SELECT * REPLACE (
64
+ ST_GeomFromWKB(ST_AsWKB(geometry)) AS geometry
65
+ )
66
+ FROM read_parquet('{root / 'natural_earth_geoparquet/ne_geography.parquet'}')
67
+ WHERE geometry IS NOT NULL
68
+ ) TO '{natural_earth_path}' (FORMAT PARQUET)
69
+ """
70
+ )
71
+ con.close()
72
+
73
+ return {
74
+ "divisions_area": str(overture_dir / "*.parquet"),
75
+ "natural_earth": str(natural_earth_path),
76
+ }
77
+
78
+
79
+ def main() -> None:
80
+ result = normalize_geodata()
81
+ print("Normalized datasets written:")
82
+ for name, path in result.items():
83
+ print(f" {name}: {path}")
84
+
85
+
86
+ if __name__ == "__main__":
87
+ main()
dataset/scripts/sql_templates.py CHANGED
@@ -293,7 +293,7 @@ TEMPLATES = [
293
  " ST_AsGeoJSON(n.geometry) AS geometry"
294
  " FROM read_parquet('natural_earth') AS n, a"
295
  " WHERE n.subtype IN ('ocean', 'sea')"
296
- " AND ST_Touches(a.geometry, n.geometry)"
297
  ),
298
  question_hints=[
299
  "which seas touch {anchor_name}?",
@@ -679,7 +679,7 @@ TEMPLATES = [
679
  sql_difficulty="hard",
680
  anchor_source="divisions_area",
681
  num_anchors=1,
682
- target_subtype="country",
683
  sql_template=(
684
  "WITH region AS ("
685
  " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
@@ -723,7 +723,7 @@ TEMPLATES = [
723
  " AND ST_Within(b.geometry, region.geometry)"
724
  " AND EXISTS ("
725
  " SELECT 1 FROM read_parquet('natural_earth') AS n"
726
- " WHERE n.subtype IN ('Terrain area', 'Island group', 'Peninsula')"
727
  " AND ST_Intersects(b.geometry, n.geometry)"
728
  " )"
729
  ),
@@ -1301,12 +1301,12 @@ TEMPLATES = [
1301
  " AND subtype = '{target_subtype}'"
1302
  ),
1303
  question_hints=[
1304
- "island territories of {anchor_name}",
1305
- "overseas island {target_subtype}s belonging to {anchor_name}",
1306
- "which islands are part of {anchor_name}?",
1307
- "land territories of {anchor_name}",
1308
- "island possessions of {anchor_name}",
1309
- "{anchor_name}'s island {target_subtype}s",
1310
  ],
1311
  ),
1312
 
@@ -1339,20 +1339,20 @@ TEMPLATES = [
1339
  sql_difficulty="medium",
1340
  anchor_source="divisions_area",
1341
  num_anchors=1,
1342
- target_subtype="locality",
1343
  sql_template=(
1344
  "SELECT id, names.\"primary\" AS name, subtype, country,"
1345
  " ST_AsGeoJSON(geometry) AS geometry"
1346
  " FROM read_parquet('divisions_area')"
1347
  " WHERE country = '{country}'"
1348
  " AND subtype = '{target_subtype}'"
1349
- " AND is_land = TRUE"
1350
  ),
1351
  question_hints=[
1352
- "land-based {target_subtype}s of {anchor_name}",
1353
- "{target_subtype}s on the mainland of {anchor_name}",
1354
- "all {target_subtype}s on land in {anchor_name}",
1355
- "non-island {target_subtype}s of {anchor_name}",
1356
  ],
1357
  ),
1358
 
@@ -1403,7 +1403,7 @@ TEMPLATES = [
1403
  " SELECT n.id, n.names.\"primary\" AS name, n.subtype,"
1404
  " ST_AsGeoJSON(n.geometry) AS geometry"
1405
  " FROM read_parquet('natural_earth') AS n, a"
1406
- " WHERE n.subtype IN ('Range/Mts', 'Terrain area', 'Peninsula', 'Depression')"
1407
  " AND ST_Intersects(a.geometry, n.geometry)"
1408
  ),
1409
  question_hints=[
@@ -1533,7 +1533,7 @@ TEMPLATES = [
1533
  " AND ST_Within(b.geometry, region.geometry)"
1534
  " AND EXISTS ("
1535
  " SELECT 1 FROM read_parquet('natural_earth') AS n"
1536
- " WHERE n.subtype IN ('Range/Mts', 'Depression')"
1537
  " AND ST_Intersects(b.geometry, n.geometry)"
1538
  " )"
1539
  ),
@@ -1666,7 +1666,7 @@ TEMPLATES = [
1666
  " AND ST_Within(b.geometry, region.geometry)"
1667
  " AND EXISTS ("
1668
  " SELECT 1 FROM read_parquet('natural_earth') AS n"
1669
- " WHERE n.subtype IN ('Range/Mts', 'Depression')"
1670
  " AND ST_Intersects(b.geometry, n.geometry)"
1671
  " )"
1672
  ),
 
293
  " ST_AsGeoJSON(n.geometry) AS geometry"
294
  " FROM read_parquet('natural_earth') AS n, a"
295
  " WHERE n.subtype IN ('ocean', 'sea')"
296
+ " AND ST_Intersects(a.geometry, n.geometry)"
297
  ),
298
  question_hints=[
299
  "which seas touch {anchor_name}?",
 
679
  sql_difficulty="hard",
680
  anchor_source="divisions_area",
681
  num_anchors=1,
682
+ target_subtype="locality",
683
  sql_template=(
684
  "WITH region AS ("
685
  " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
 
723
  " AND ST_Within(b.geometry, region.geometry)"
724
  " AND EXISTS ("
725
  " SELECT 1 FROM read_parquet('natural_earth') AS n"
726
+ " WHERE n.subtype IN ('Range/mtn', 'Island group', 'Peninsula', 'Depression')"
727
  " AND ST_Intersects(b.geometry, n.geometry)"
728
  " )"
729
  ),
 
1301
  " AND subtype = '{target_subtype}'"
1302
  ),
1303
  question_hints=[
1304
+ "land {target_subtype}s of {anchor_name}",
1305
+ "dependencies of {anchor_name} that are on land",
1306
+ "which land dependencies belong to {anchor_name}?",
1307
+ "{anchor_name}'s land {target_subtype}s",
1308
+ "dependencies of {anchor_name} with land area",
1309
+ "show the land dependencies of {anchor_name}",
1310
  ],
1311
  ),
1312
 
 
1339
  sql_difficulty="medium",
1340
  anchor_source="divisions_area",
1341
  num_anchors=1,
1342
+ target_subtype="region",
1343
  sql_template=(
1344
  "SELECT id, names.\"primary\" AS name, subtype, country,"
1345
  " ST_AsGeoJSON(geometry) AS geometry"
1346
  " FROM read_parquet('divisions_area')"
1347
  " WHERE country = '{country}'"
1348
  " AND subtype = '{target_subtype}'"
1349
+ " AND is_land = FALSE"
1350
  ),
1351
  question_hints=[
1352
+ "offshore {target_subtype}s of {anchor_name}",
1353
+ "{target_subtype}s of {anchor_name} that are not on land",
1354
+ "water-associated {target_subtype}s of {anchor_name}",
1355
+ "marine or offshore {target_subtype}s of {anchor_name}",
1356
  ],
1357
  ),
1358
 
 
1403
  " SELECT n.id, n.names.\"primary\" AS name, n.subtype,"
1404
  " ST_AsGeoJSON(n.geometry) AS geometry"
1405
  " FROM read_parquet('natural_earth') AS n, a"
1406
+ " WHERE n.subtype IN ('Range/mtn', 'Peninsula', 'Depression')"
1407
  " AND ST_Intersects(a.geometry, n.geometry)"
1408
  ),
1409
  question_hints=[
 
1533
  " AND ST_Within(b.geometry, region.geometry)"
1534
  " AND EXISTS ("
1535
  " SELECT 1 FROM read_parquet('natural_earth') AS n"
1536
+ " WHERE n.subtype IN ('Range/mtn', 'Depression')"
1537
  " AND ST_Intersects(b.geometry, n.geometry)"
1538
  " )"
1539
  ),
 
1666
  " AND ST_Within(b.geometry, region.geometry)"
1667
  " AND EXISTS ("
1668
  " SELECT 1 FROM read_parquet('natural_earth') AS n"
1669
+ " WHERE n.subtype IN ('Range/mtn', 'Depression')"
1670
  " AND ST_Intersects(b.geometry, n.geometry)"
1671
  " )"
1672
  ),
dataset/scripts/validate_dataset.py CHANGED
@@ -44,8 +44,8 @@ def _resolve_paths(sql: str) -> str:
44
  "read_parquet('natural_earth')", f"read_parquet('{NATURAL_EARTH_PATH}')"
45
  )
46
  # Legacy fixed Docker paths from earlier dataset versions
47
- sql = sql.replace("/data/overture/division_area/*.parquet", DIVISIONS_AREA_PATH)
48
- sql = sql.replace("/data/overture/divisions_area/*.parquet", DIVISIONS_AREA_PATH)
49
  sql = sql.replace("/data/natural_earth_geoparquet/ne_geography.parquet", NATURAL_EARTH_PATH)
50
  return sql
51
 
 
44
  "read_parquet('natural_earth')", f"read_parquet('{NATURAL_EARTH_PATH}')"
45
  )
46
  # Legacy fixed Docker paths from earlier dataset versions
47
+ sql = sql.replace("/data/overture/division_area/*.parquet", DIVISIONS_AREA_PATH)
48
+ sql = sql.replace("/data/overture/divisions_area/*.parquet", DIVISIONS_AREA_PATH)
49
  sql = sql.replace("/data/natural_earth_geoparquet/ne_geography.parquet", NATURAL_EARTH_PATH)
50
  return sql
51
 
finetune/README.md CHANGED
@@ -26,7 +26,7 @@ SQL samples are long (schema + candidates + SQL), places samples are short.
26
 
27
  ```bash
28
  modal run finetune/check_token_lengths.py
29
- modal run finetune/check_token_lengths.py --run-dir /mnt/gazet/data/v1
30
  ```
31
 
32
  This prints per-split statistics (min, max, P95, P99) and recommends a
@@ -66,7 +66,7 @@ overridden, `lora_alpha` is automatically set to `2 * r`.
66
 
67
  ```
68
  base_model: unsloth/Qwen3.5-0.8B
69
- run_dir: /mnt/gazet/data/v1
70
  lora_r: 16
71
  lora_alpha: 32 (2 * r, Unsloth recommendation for Qwen)
72
  lora_dropout: 0.0
@@ -106,19 +106,19 @@ for local inference with llama-server.
106
 
107
  ```bash
108
  # Download from Modal volume
109
- modal volume get gazet checkpoints/qwen35-v1/merged ./finetune/models/merged
110
 
111
  # Convert to GGUF (requires llama.cpp repo)
112
- uv run \
113
  --no-project \
114
  --with transformers \
115
  --with sentencepiece \
116
  --with protobuf \
117
  --with torch \
118
  python convert_hf_to_gguf.py \
119
- ../gazet/finetune/models/qwen-base/merged \
120
  --outtype q8_0 \
121
- --outfile ../gazet/finetune/models/qwen-base/ckpt-001.gguf
122
  ```
123
 
124
  ---
@@ -129,7 +129,7 @@ uv run \
129
 
130
  ```bash
131
  llama-server \
132
- -m finetune/models/qwen-base/ckpt-001.gguf \
133
  -ngl 99 \
134
  --port 9000 \
135
  --ctx-size 2048
@@ -151,7 +151,7 @@ docker run \
151
  -v $(pwd)/finetune/models:/models \
152
  -p 9000:9000 \
153
  ghcr.io/ggml-org/llama.cpp:server \
154
- -m /models/qwen-base/ckpt-001.gguf \
155
  --port 9000 --host 0.0.0.0 \
156
  --ctx-size 2048 -t 2 -v
157
  ```
@@ -211,7 +211,7 @@ All batch CLI args:
211
  | `--label` | `local-gguf` | Label used in the output filename |
212
  | `--task` | `sql` | `sql` or `places` |
213
  | `--split` | `val` | Data split to evaluate (`val`, `test`) |
214
- | `--run-dir` | `dataset/output/runs/v1` | Directory with `{task}/{split}.jsonl` |
215
  | `--max-samples` | all | Cap the number of samples |
216
  | `--output` | `eval-{label}-{task}.json` | Output JSON path |
217
  | `--workers` | `4` | Concurrent requests; match llama-server `--parallel` |
@@ -250,6 +250,20 @@ GAZET_EVAL_DIR=/path/to/results streamlit run finetune/eval_demo.py
250
  ```
251
 
252
  Set `GAZET_DATA_DIR` if your parquet data is not in the default `data/` directory.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
  ---
255
 
@@ -287,5 +301,9 @@ loss is computed only on the completion tokens.
287
 
288
  SQL in the training data uses symbolic path placeholders
289
  (`read_parquet('divisions_area')`) instead of real file paths. At inference
290
- time, `src/gazet/sql.py` replaces these with actual runtime paths before
291
- executing against DuckDB.
 
 
 
 
 
26
 
27
  ```bash
28
  modal run finetune/check_token_lengths.py
29
+ modal run finetune/check_token_lengths.py --run-dir /mnt/gazet/data/smalltest-v1
30
  ```
31
 
32
  This prints per-split statistics (min, max, P95, P99) and recommends a
 
66
 
67
  ```
68
  base_model: unsloth/Qwen3.5-0.8B
69
+ run_dir: /mnt/gazet/data/v1 # override to your exported run, e.g. /mnt/gazet/data/smalltest-v1
70
  lora_r: 16
71
  lora_alpha: 32 (2 * r, Unsloth recommendation for Qwen)
72
  lora_dropout: 0.0
 
106
 
107
  ```bash
108
  # Download from Modal volume
109
+ modal volume get gazet checkpoints/qwen35-v1/merged ./finetune/models/qwen35-v1-merged
110
 
111
  # Convert to GGUF (requires llama.cpp repo)
112
+ uv run \
113
  --no-project \
114
  --with transformers \
115
  --with sentencepiece \
116
  --with protobuf \
117
  --with torch \
118
  python convert_hf_to_gguf.py \
119
+ ./finetune/models/qwen35-v1-merged \
120
  --outtype q8_0 \
121
+ --outfile ./finetune/models/qwen35-v1-q8_0.gguf
122
  ```
123
 
124
  ---
 
129
 
130
  ```bash
131
  llama-server \
132
+ -m finetune/models/qwen35-v1-q8_0.gguf \
133
  -ngl 99 \
134
  --port 9000 \
135
  --ctx-size 2048
 
151
  -v $(pwd)/finetune/models:/models \
152
  -p 9000:9000 \
153
  ghcr.io/ggml-org/llama.cpp:server \
154
+ -m /models/qwen35-v1-q8_0.gguf \
155
  --port 9000 --host 0.0.0.0 \
156
  --ctx-size 2048 -t 2 -v
157
  ```
 
211
  | `--label` | `local-gguf` | Label used in the output filename |
212
  | `--task` | `sql` | `sql` or `places` |
213
  | `--split` | `val` | Data split to evaluate (`val`, `test`) |
214
+ | `--run-dir` | `dataset/output/runs/v1` | Directory with `{task}/{split}.jsonl`; override to your exported run, e.g. `dataset/output/runs/smalltest-v1` |
215
  | `--max-samples` | all | Cap the number of samples |
216
  | `--output` | `eval-{label}-{task}.json` | Output JSON path |
217
  | `--workers` | `4` | Concurrent requests; match llama-server `--parallel` |
 
250
  ```
251
 
252
  Set `GAZET_DATA_DIR` if your parquet data is not in the default `data/` directory.
253
+ This only affects the visual SQL viewer (`eval_demo.py`), which executes SQL
254
+ against DuckDB; `eval_cli.py` does not read parquet files directly.
255
+
256
+ The eval viewer resolves parquet paths through `gazet.config`, which now
257
+ prefers normalized copies automatically when present:
258
+
259
+ - `data/overture_normalized/divisions_area/*.parquet`
260
+ - `data/natural_earth_normalized/ne_geography.parquet`
261
+
262
+ Disable that fallback only if needed with:
263
+
264
+ ```bash
265
+ GAZET_USE_NORMALIZED_DATA=0 streamlit run finetune/eval_demo.py
266
+ ```
267
 
268
  ---
269
 
 
301
 
302
  SQL in the training data uses symbolic path placeholders
303
  (`read_parquet('divisions_area')`) instead of real file paths. At inference
304
+ and eval time, `src/gazet/sql.py` / `finetune/eval_demo.py` replace these with
305
+ actual runtime paths before executing against DuckDB. When normalized parquet
306
+ copies are present, `gazet.config` prefers:
307
+
308
+ - `data/overture_normalized/divisions_area/*.parquet`
309
+ - `data/natural_earth_normalized/ne_geography.parquet`
finetune/eval_demo.py CHANGED
@@ -16,6 +16,8 @@ import pydeck as pdk
16
  import sqlparse
17
  import streamlit as st
18
 
 
 
19
  PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
20
  DATA_DIR = pathlib.Path(
21
  os.environ.get("GAZET_DATA_DIR", str(PROJECT_ROOT / "data"))
@@ -31,13 +33,17 @@ def load_eval_results(path):
31
 
32
 
33
  def rewrite_data_paths(sql):
34
- """Replace symbolic and legacy paths with actual local data paths."""
35
- # Legacy fixed Docker paths must be replaced first to avoid double-expansion
 
 
 
 
 
 
36
  sql = sql.replace("/data/", f"{DATA_DIR}/")
37
- div_path = str(DATA_DIR / "overture" / "divisions_area" / "*.parquet")
38
- ne_path = str(DATA_DIR / "natural_earth_geoparquet" / "ne_geography.parquet")
39
- sql = sql.replace("read_parquet('divisions_area')", f"read_parquet('{div_path}')")
40
- sql = sql.replace("read_parquet('natural_earth')", f"read_parquet('{ne_path}')")
41
  return sql
42
 
43
 
 
16
  import sqlparse
17
  import streamlit as st
18
 
19
+ from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
20
+
21
  PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
22
  DATA_DIR = pathlib.Path(
23
  os.environ.get("GAZET_DATA_DIR", str(PROJECT_ROOT / "data"))
 
33
 
34
 
35
  def rewrite_data_paths(sql):
36
+ """Replace symbolic and legacy paths with the configured runtime data paths."""
37
+ # Legacy fixed Docker paths must be replaced first to avoid double-expansion.
38
+ sql = sql.replace("/data/overture/division_area/*.parquet", DIVISIONS_AREA_PATH)
39
+ sql = sql.replace("/data/overture/divisions_area/*.parquet", DIVISIONS_AREA_PATH)
40
+ sql = sql.replace(
41
+ "/data/natural_earth_geoparquet/ne_geography.parquet",
42
+ NATURAL_EARTH_PATH,
43
+ )
44
  sql = sql.replace("/data/", f"{DATA_DIR}/")
45
+ sql = sql.replace("read_parquet('divisions_area')", f"read_parquet('{DIVISIONS_AREA_PATH}')")
46
+ sql = sql.replace("read_parquet('natural_earth')", f"read_parquet('{NATURAL_EARTH_PATH}')")
 
 
47
  return sql
48
 
49
 
finetune/train_modal_qwen35.py CHANGED
@@ -101,7 +101,7 @@ class Qwen35Config:
101
  # Logging / saving
102
  logging_steps: int = 10
103
  save_strategy: str = "steps"
104
- save_steps: int = 400
105
  eval_strategy: str = "steps"
106
  eval_steps: int = 200
107
  report_to: str = "trackio"
 
101
  # Logging / saving
102
  logging_steps: int = 10
103
  save_strategy: str = "steps"
104
+ save_steps: int = 1000
105
  eval_strategy: str = "steps"
106
  eval_steps: int = 200
107
  report_to: str = "trackio"
gazet_demo.py CHANGED
@@ -68,24 +68,58 @@ def view_state_for_bbox(bbox, padding_zoom=0.8):
68
  return pdk.ViewState(latitude=lat, longitude=lng, zoom=zoom)
69
 
70
 
 
 
 
 
 
 
 
 
 
 
71
  def _render_map(geojson, placeholder):
72
- n = len(geojson.get("features", []))
 
73
  if pdk and n:
 
74
  layer = pdk.Layer(
75
  "GeoJsonLayer",
76
  data=geojson,
77
- get_fill_color=[40, 180, 160, 200],
78
- get_line_color=[125, 211, 192, 255],
79
- get_line_width=2,
 
 
 
80
  pickable=True,
81
  )
82
- bbox = bbox_from_geojson(geojson)
83
- view = (
84
- view_state_for_bbox(bbox)
85
- if bbox
86
- else pdk.ViewState(latitude=0, longitude=0, zoom=1)
87
- )
88
  with placeholder.container():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  st.pydeck_chart(
90
  pdk.Deck(
91
  layers=[layer],
@@ -128,6 +162,8 @@ st.sidebar.caption(
128
 
129
  if "run_q" not in st.session_state:
130
  st.session_state.run_q = None
 
 
131
 
132
  col1, col2 = st.columns([1, 2])
133
  with col1:
@@ -150,6 +186,7 @@ with col2:
150
  to_run = st.session_state.run_q
151
  if to_run:
152
  st.session_state.run_q = None
 
153
 
154
  status_ph = st.empty()
155
  map_ph = st.empty()
@@ -159,6 +196,8 @@ with col2:
159
 
160
  status_ph.info("Extracting places…")
161
 
 
 
162
  try:
163
  with requests.get(
164
  f"{API}/search/stream", params={"q": to_run, "backend": backend}, stream=True, timeout=120
@@ -173,6 +212,7 @@ with col2:
173
 
174
  if t == "places":
175
  places = event["data"].get("places", [])
 
176
  status_ph.info("Fuzzy-matching candidates…")
177
  if places:
178
  with places_ph.container():
@@ -192,6 +232,7 @@ with col2:
192
  )
193
 
194
  elif t == "candidates":
 
195
  status_ph.info("Generating SQL…")
196
  with candidates_ph.container():
197
  with st.expander("Candidate datasets", expanded=True):
@@ -203,6 +244,7 @@ with col2:
203
 
204
  elif t == "sql_attempt":
205
  iteration = event.get("iteration", "")
 
206
  status_ph.info(f"Running SQL (attempt {iteration})…")
207
  with sql_ph.container():
208
  with st.expander("SQL", expanded=True):
@@ -216,6 +258,7 @@ with col2:
216
 
217
  elif t == "geojson":
218
  geojson = event["data"]
 
219
  n = len(geojson.get("features", []))
220
  status_ph.success(f"**{to_run}** → {n} feature(s)")
221
  _render_map(geojson, map_ph)
@@ -227,3 +270,31 @@ with col2:
227
  status_ph.error(
228
  f"API error: {e}. Is the API running? `uv run uvicorn gazet.api:app --reload`"
229
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  return pdk.ViewState(latitude=lat, longitude=lng, zoom=zoom)
69
 
70
 
71
+ def _has_line_geometries(features):
72
+ """Return True if features are predominantly line/point (non-polygon) geometries."""
73
+ line_types = {"LineString", "MultiLineString", "Point", "MultiPoint"}
74
+ count = sum(
75
+ 1 for f in features
76
+ if f.get("geometry", {}).get("type") in line_types
77
+ )
78
+ return count > len(features) / 2
79
+
80
+
81
  def _render_map(geojson, placeholder):
82
+ features = geojson.get("features", [])
83
+ n = len(features)
84
  if pdk and n:
85
+ is_linear = _has_line_geometries(features)
86
  layer = pdk.Layer(
87
  "GeoJsonLayer",
88
  data=geojson,
89
+ stroked=True,
90
+ filled=not is_linear,
91
+ get_fill_color=[40, 180, 160, 120],
92
+ get_line_color=[0, 140, 255, 255] if is_linear else [10, 50, 46, 255],
93
+ get_line_width=500 if is_linear else 80,
94
+ line_width_min_pixels=2 if is_linear else 1,
95
  pickable=True,
96
  )
 
 
 
 
 
 
97
  with placeholder.container():
98
+ selected_idx = None
99
+ if n > 1:
100
+ names = [
101
+ f.get("properties", {}).get("name", f"Feature {i}")
102
+ for i, f in enumerate(features)
103
+ ]
104
+ choice = st.selectbox(
105
+ "Zoom to feature",
106
+ ["All features"] + names,
107
+ key="feature_zoom",
108
+ )
109
+ if choice != "All features":
110
+ selected_idx = names.index(choice)
111
+
112
+ if selected_idx is not None:
113
+ single = {"type": "FeatureCollection", "features": [features[selected_idx]]}
114
+ bbox = bbox_from_geojson(single)
115
+ else:
116
+ bbox = bbox_from_geojson(geojson)
117
+
118
+ view = (
119
+ view_state_for_bbox(bbox)
120
+ if bbox
121
+ else pdk.ViewState(latitude=0, longitude=0, zoom=1)
122
+ )
123
  st.pydeck_chart(
124
  pdk.Deck(
125
  layers=[layer],
 
162
 
163
  if "run_q" not in st.session_state:
164
  st.session_state.run_q = None
165
+ if "last_result" not in st.session_state:
166
+ st.session_state.last_result = None
167
 
168
  col1, col2 = st.columns([1, 2])
169
  with col1:
 
186
  to_run = st.session_state.run_q
187
  if to_run:
188
  st.session_state.run_q = None
189
+ st.session_state.last_result = None
190
 
191
  status_ph = st.empty()
192
  map_ph = st.empty()
 
196
 
197
  status_ph.info("Extracting places…")
198
 
199
+ result = {"query": to_run, "places": None, "candidates": None, "sql": None, "geojson": None}
200
+
201
  try:
202
  with requests.get(
203
  f"{API}/search/stream", params={"q": to_run, "backend": backend}, stream=True, timeout=120
 
212
 
213
  if t == "places":
214
  places = event["data"].get("places", [])
215
+ result["places"] = places
216
  status_ph.info("Fuzzy-matching candidates…")
217
  if places:
218
  with places_ph.container():
 
232
  )
233
 
234
  elif t == "candidates":
235
+ result["candidates"] = event["data"]
236
  status_ph.info("Generating SQL…")
237
  with candidates_ph.container():
238
  with st.expander("Candidate datasets", expanded=True):
 
244
 
245
  elif t == "sql_attempt":
246
  iteration = event.get("iteration", "")
247
+ result["sql"] = event["data"]
248
  status_ph.info(f"Running SQL (attempt {iteration})…")
249
  with sql_ph.container():
250
  with st.expander("SQL", expanded=True):
 
258
 
259
  elif t == "geojson":
260
  geojson = event["data"]
261
+ result["geojson"] = geojson
262
  n = len(geojson.get("features", []))
263
  status_ph.success(f"**{to_run}** → {n} feature(s)")
264
  _render_map(geojson, map_ph)
 
270
  status_ph.error(
271
  f"API error: {e}. Is the API running? `uv run uvicorn gazet.api:app --reload`"
272
  )
273
+
274
+ st.session_state.last_result = result
275
+
276
+ elif st.session_state.last_result:
277
+ result = st.session_state.last_result
278
+ query = result["query"]
279
+ n_feat = len((result["geojson"] or {}).get("features", []))
280
+ st.success(f"**{query}** -> {n_feat} feature(s)")
281
+ _render_map(result["geojson"], st.empty())
282
+ if result["places"]:
283
+ with st.expander("Extracted place names"):
284
+ st.dataframe(
285
+ pd.DataFrame(result["places"]).rename(
286
+ columns={"place": "Place", "country": "Country", "subtype": "Subtype"}
287
+ ),
288
+ use_container_width=True,
289
+ hide_index=True,
290
+ )
291
+ if result["candidates"]:
292
+ with st.expander("Candidate datasets"):
293
+ st.dataframe(
294
+ pd.DataFrame(result["candidates"]),
295
+ use_container_width=True,
296
+ hide_index=True,
297
+ )
298
+ if result["sql"]:
299
+ with st.expander("SQL"):
300
+ st.code(result["sql"], language="sql")
ingest/convert_natural_earth.py CHANGED
@@ -107,7 +107,7 @@ def _load_shapefile(src: pathlib.Path, source_key: str) -> gpd.GeoDataFrame:
107
 
108
  # subtype: featurecla or source key
109
  if "featurecla" in gdf.columns:
110
- subtype = gdf["featurecla"]
111
  else:
112
  subtype = pd.Series([source_key] * n)
113
 
 
107
 
108
  # subtype: featurecla or source key
109
  if "featurecla" in gdf.columns:
110
+ subtype = gdf["featurecla"].str.lower()
111
  else:
112
  subtype = pd.Series([source_key] * n)
113
 
src/gazet/config.py CHANGED
@@ -6,8 +6,33 @@ import pathlib
6
  _DATA_DIR = pathlib.Path(os.environ.get("GAZET_DATA_DIR", str(
7
  pathlib.Path(__file__).resolve().parent.parent.parent / "data"
8
  )))
9
- DIVISIONS_AREA_PATH = str(_DATA_DIR / "overture/divisions_area/*.parquet")
10
- NATURAL_EARTH_PATH = str(_DATA_DIR / "natural_earth_geoparquet/ne_geography.parquet")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  # MODEL = "qwen3.5:cloud"
13
  # MODEL = "granite4:350m"
 
6
  _DATA_DIR = pathlib.Path(os.environ.get("GAZET_DATA_DIR", str(
7
  pathlib.Path(__file__).resolve().parent.parent.parent / "data"
8
  )))
9
+
10
+
11
+ def _prefer_normalized(path_normalized: pathlib.Path, path_original: pathlib.Path) -> pathlib.Path:
12
+ """Prefer normalized geodata copies when present."""
13
+ use_normalized = os.environ.get("GAZET_USE_NORMALIZED_DATA", "1") != "0"
14
+ if use_normalized:
15
+ parent = path_normalized.parent
16
+ if "*" in path_normalized.name:
17
+ if parent.exists() and any(parent.glob(path_normalized.name)):
18
+ return path_normalized
19
+ elif path_normalized.exists():
20
+ return path_normalized
21
+ return path_original
22
+
23
+
24
+ DIVISIONS_AREA_PATH = str(
25
+ _prefer_normalized(
26
+ _DATA_DIR / "overture_normalized/divisions_area/*.parquet",
27
+ _DATA_DIR / "overture/divisions_area/*.parquet",
28
+ )
29
+ )
30
+ NATURAL_EARTH_PATH = str(
31
+ _prefer_normalized(
32
+ _DATA_DIR / "natural_earth_normalized/ne_geography.parquet",
33
+ _DATA_DIR / "natural_earth_geoparquet/ne_geography.parquet",
34
+ )
35
+ )
36
 
37
  # MODEL = "qwen3.5:cloud"
38
  # MODEL = "granite4:350m"
src/gazet/sql.py CHANGED
@@ -74,6 +74,28 @@ def _rewrite_data_paths(sql: str) -> str:
74
  return sql
75
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  def _strip_fences(sql: Optional[str]) -> str:
78
  """Remove markdown code fences that the LM may wrap the SQL in."""
79
  if not sql:
@@ -143,6 +165,7 @@ def run_geo_sql_gguf(
143
  return
144
 
145
  sql = _rewrite_data_paths(sql)
 
146
  print(f"\n[SQL·GGUF] Generated:\n{sql}\n")
147
  yield {"type": "sql_attempt", "sql": sql, "iteration": 1}
148
  yield from _execute_sql(con, sql, "SQL·GGUF", iteration=1)
@@ -183,6 +206,8 @@ def run_geo_sql_dspy(
183
  execution_error=error,
184
  )
185
  sql = _strip_fences(pred.sql)
 
 
186
  except Exception as exc:
187
  error = f"LM generation failed: {exc}"
188
  print(f"Generation error: {error}")
 
74
  return sql
75
 
76
 
77
+ # Title-cased NE subtype literals the trained model may emit.
78
+ # Data is now fully lowercased, so we normalise at query time.
79
+ _NE_SUBTYPE_FIXES = {
80
+ "'River'": "'river'",
81
+ "'Lake'": "'lake'",
82
+ "'Basin'": "'basin'",
83
+ "'Range/mtn'": "'range/mtn'",
84
+ "'Peninsula'": "'peninsula'",
85
+ "'Depression'": "'depression'",
86
+ "'Island group'": "'island group'",
87
+ "'Ocean'": "'ocean'",
88
+ "'Sea'": "'sea'",
89
+ }
90
+
91
+
92
+ def _normalize_ne_subtypes(sql: str) -> str:
93
+ """Lowercase known NE subtype literals so they match the normalised data."""
94
+ for old, new in _NE_SUBTYPE_FIXES.items():
95
+ sql = sql.replace(old, new)
96
+ return sql
97
+
98
+
99
  def _strip_fences(sql: Optional[str]) -> str:
100
  """Remove markdown code fences that the LM may wrap the SQL in."""
101
  if not sql:
 
165
  return
166
 
167
  sql = _rewrite_data_paths(sql)
168
+ sql = _normalize_ne_subtypes(sql)
169
  print(f"\n[SQL·GGUF] Generated:\n{sql}\n")
170
  yield {"type": "sql_attempt", "sql": sql, "iteration": 1}
171
  yield from _execute_sql(con, sql, "SQL·GGUF", iteration=1)
 
206
  execution_error=error,
207
  )
208
  sql = _strip_fences(pred.sql)
209
+ sql = _rewrite_data_paths(sql)
210
+ sql = _normalize_ne_subtypes(sql)
211
  except Exception as exc:
212
  error = f"LM generation failed: {exc}"
213
  print(f"Generation error: {error}")