Daniel Wiesmann commited on
Commit
af5cded
·
unverified ·
2 Parent(s): 5e72207789bf58

Merge pull request #1 from developmentseed/slm-qwen3.5

Browse files
.dockerignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .git
2
+ .venv
3
+ __pycache__
4
+ data/
5
+ finetune/models/
6
+ dataset/output/
7
+ dataset/intermediate/
8
+ results/
9
+ *.gguf
10
+ *.safetensors
11
+ .windsurf/
12
+ .claude/
.gitignore CHANGED
@@ -133,6 +133,26 @@ dmypy.json
133
  # Pyre type checker
134
  .pyre/
135
 
136
-
137
  data/
138
- output/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  # Pyre type checker
134
  .pyre/
135
 
136
+ # Dataset
137
  data/
138
+ output/
139
+ *.parquet
140
+ *.jsonl
141
+
142
+ # Eval results
143
+ results/
144
+
145
+ # IDE
146
+ .windsurf/
147
+
148
+ # Local notes
149
+ notes.md
150
+
151
+ # Model
152
+ models/
153
+ *.gguf
154
+ *.safetensors
155
+ *.bin
156
+ *.pt
157
+ *.pth
158
+ *.ckpt
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.13-slim
2
+
3
+ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
4
+
5
+ WORKDIR /app
6
+
7
+ # Install dependencies first (cache layer)
8
+ COPY pyproject.toml uv.lock ./
9
+ RUN uv sync --frozen --no-install-project --extra demo
10
+
11
+ # Copy application code
12
+ COPY src/ src/
13
+ COPY gazet_demo.py .
14
+
15
+ # Install the project itself
16
+ RUN uv sync --frozen --extra demo
17
+
18
+ ENV PATH="/app/.venv/bin:$PATH"
19
+
20
+ EXPOSE 8000 8501
README.md CHANGED
@@ -26,14 +26,14 @@ uv sync --extra dev --extra demo
26
  Example for downloading overture
27
 
28
  ```bash
29
- aws s3 sync
30
- s3 sync s3://overturemaps-us-west-2/release/2026-02-18.0/theme=divisions/type=division_area/ data/overture/divisions_area
31
  ```
32
 
33
  Example for running conversion script for natural earth
34
 
35
  ```bash
36
- python -m ingest.convert_natural_earth ~/Downloads/10m_physical
 
37
  ```
38
 
39
  ### Based on ollama
@@ -61,7 +61,7 @@ uv run streamlit run gazet_demo.py # demo UI
61
  | Module | Contents |
62
  | --- | --- |
63
  | `config.py` | data paths, model name, SQL schema description |
64
- | `types.py` | `SUBTYPES`, `COUNTRIES`, `Place`, `PlacesResult` |
65
  | `lm.py` | DSPy signatures + LM init (`extract`, `write_sql`) |
66
  | `search.py` | fuzzy search against `divisions_area` / `natural_earth` |
67
  | `sql.py` | code-act SQL generation loop |
 
26
  Example for downloading overture
27
 
28
  ```bash
29
+ aws s3 sync s3://overturemaps-us-west-2/release/2026-02-18.0/theme=divisions/type=division_area/ data/overture/divisions_area
 
30
  ```
31
 
32
  Example for running conversion script for natural earth
33
 
34
  ```bash
35
+ unzip ~/Downloads/10m_physical.zip -d data/natural_earth
36
+ python -m ingest.convert_natural_earth data/natural_earth
37
  ```
38
 
39
  ### Based on ollama
 
61
  | Module | Contents |
62
  | --- | --- |
63
  | `config.py` | data paths, model name, SQL schema description |
64
+ | `schemas.py` | `SUBTYPES`, `COUNTRIES`, `Place`, `PlacesResult` |
65
  | `lm.py` | DSPy signatures + LM init (`extract`, `write_sql`) |
66
  | `search.py` | fuzzy search against `divisions_area` / `natural_earth` |
67
  | `sql.py` | code-act SQL generation loop |
dataset/README.md ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Gazet Dataset Generation
2
+
3
+ Generates synthetic training data for fine-tuning the geocoding model.
4
+ Two datasets come out of one pipeline run:
5
+
6
+ - **SQL generation** — `(question + candidates) -> DuckDB SQL`
7
+ - **Place extraction** — `question -> place names JSON`
8
+
9
+ Both tasks export in **conversation format** (`messages` list of
10
+ system/user/assistant turns), ready for chat-template fine-tuning.
11
+
12
+ ---
13
+
14
+ ## Prerequisites
15
+
16
+ ```bash
17
+ uv sync
18
+ ```
19
+
20
+ You need the Overture and Natural Earth parquet files under `data/` locally,
21
+ or on a Modal volume if running in the cloud.
22
+
23
+ ---
24
+
25
+ ## Option A — Run locally (small datasets, development)
26
+
27
+ Use this when you want to iterate quickly on a laptop with a subset of countries.
28
+
29
+ **Step 1 — Pick a run name and countries in `config.yaml`**
30
+
31
+ ```yaml
32
+ run_name: "v1" # change this every time you generate fresh data
33
+
34
+ countries:
35
+ - IN # India
36
+ - BR # Brazil
37
+ - US # United States
38
+ # add more, or use "- all" for every country (slow locally)
39
+ ```
40
+
41
+ **Step 2 — Run the full pipeline**
42
+
43
+ ```bash
44
+ gazet-dataset full-pipeline --config dataset/config.yaml
45
+ ```
46
+
47
+ That's it. It runs all four steps in order and puts the results in
48
+ `dataset/output/runs/my-run-001/`.
49
+
50
+ If you want to run steps individually (e.g. to re-export without regenerating):
51
+
52
+ ```bash
53
+ gazet-dataset build-relations --config dataset/config.yaml # ~5 min
54
+ gazet-dataset generate-samples --config dataset/config.yaml # ~15 min
55
+ gazet-dataset validate --config dataset/config.yaml # ~5 min
56
+ gazet-dataset export --config dataset/config.yaml # <1 min
57
+ ```
58
+
59
+ ---
60
+
61
+ ## Option B — Run on Modal (large datasets, production)
62
+
63
+ Use this when you need 10 K+ samples or want to use all countries. Modal
64
+ distributes generation across many containers in parallel.
65
+
66
+ **Step 1 — One-time setup**
67
+
68
+ ```bash
69
+ modal setup # authenticate with Modal (one time)
70
+ gazet-dataset modal-upload --config dataset/config.yaml # upload parquet data to Modal volume
71
+ ```
72
+
73
+ **Step 2 — Set run name and targets in `config.yaml`**
74
+
75
+ ```yaml
76
+ run_name: "v1"
77
+
78
+ countries:
79
+ - all
80
+
81
+ sample_targets:
82
+ adjacency: 1250
83
+ containment: 1250
84
+ # ... see config.yaml for all families
85
+ ```
86
+
87
+ **Step 3 — Run on Modal**
88
+
89
+ ```bash
90
+ gazet-dataset modal-generate --config dataset/config.yaml
91
+ ```
92
+
93
+ This builds relations, generates samples, validates, and exports — same as
94
+ `full-pipeline` but distributed across 100 cloud containers.
95
+
96
+ If relations are already built from a previous run (same countries, same
97
+ template version), skip rebuilding them:
98
+
99
+ ```bash
100
+ gazet-dataset modal-generate --config dataset/config.yaml --skip-relations
101
+ ```
102
+
103
+ ---
104
+
105
+ ## Output
106
+
107
+ After running, your training files are at:
108
+
109
+ ```
110
+ dataset/output/runs/{run_name}/
111
+ sql/
112
+ train.jsonl <- fine-tune the SQL generation model
113
+ val.jsonl
114
+ test.jsonl
115
+ places/
116
+ train.jsonl <- fine-tune the place extraction model
117
+ val.jsonl
118
+ test.jsonl
119
+ stats.json <- sample counts by family
120
+ ```
121
+
122
+ Each JSONL row is a conversation-format dict:
123
+
124
+ ```json
125
+ {
126
+ "messages": [
127
+ {"role": "system", "content": "..."},
128
+ {"role": "user", "content": "..."},
129
+ {"role": "assistant", "content": "..."}
130
+ ]
131
+ }
132
+ ```
133
+
134
+ **SQL task**: the system prompt includes the full two-table schema inside
135
+ `<SCHEMA>` tags. The user prompt contains only `<CANDIDATES>` CSV and
136
+ `<USER_QUERY>`. The assistant response is pretty-printed SQL (via `sqlparse`).
137
+ All parquet paths are symbolic (`divisions_area` / `natural_earth`), never
138
+ runtime-specific.
139
+
140
+ **Places task**: the system prompt includes output format, extraction rules,
141
+ and the full list of Overture subtypes. The assistant response is a JSON
142
+ object with a `places` array.
143
+
144
+ ---
145
+
146
+ ## When to regenerate from scratch
147
+
148
+ Change `run_name` and regenerate from scratch whenever you:
149
+
150
+ - Change any SQL templates (`sql_templates.py`)
151
+ - Add new template families
152
+ - Change the candidate format or count
153
+ - Change the system/user prompt structure or content
154
+ - Change the export format
155
+
156
+ For local runs, the default is a clean run. For Modal, `modal-generate` appends
157
+ by default; pass `--fresh` to overwrite existing samples.
158
+
159
+ ---
160
+
161
+ ## Troubleshooting
162
+
163
+ **Very few samples generated for a family**
164
+ The generation loop tries `retry_multiplier × target` and discards SQL that
165
+ returns empty results. Some families (e.g. `multi_adjacency`, `chained`) have
166
+ a lower success rate. Increase `sample_targets` for those families or increase
167
+ `retry_multiplier` in `config.yaml`.
168
+
169
+ **Relations step is slow**
170
+ Normal for `countries: [all]` — it's a spatial self-join over millions of
171
+ features. Use a country subset for development. Relations only need to be
172
+ rebuilt when you add countries or change template families.
173
+
174
+ **Validate step drops many samples**
175
+ The validate step re-executes every SQL query and discards ones that return
176
+ empty results. This is expected — check `output/runs/{run_name}/stats.json`
177
+ for per-family counts after export.
dataset/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Synthetic dataset generation package."""
dataset/config.yaml ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dataset Generation Configuration
2
+ # This config controls which countries to process and how many samples to generate
3
+
4
+ # Countries to include in relation building
5
+ # Use ISO 3166-1 alpha-2 codes, or "all" to include every country
6
+ countries:
7
+ - all
8
+ # Or specify a subset:
9
+ # - IN # India
10
+ # - PK # Pakistan
11
+ # - EC # Ecuador
12
+ # - BE # Belgium
13
+ # - KE # Kenya
14
+
15
+ # Sample generation targets per family
16
+ # Relation limits are auto-calculated from these targets
17
+ sample_targets:
18
+ direct_lookup: 500
19
+ adjacency: 750
20
+ multi_adjacency: 300
21
+ containment: 750
22
+ intersection: 500
23
+ buffer: 500
24
+ chained: 750 # coastal / landlocked variants
25
+ difference: 300
26
+ border_corridor: 300
27
+ set_operations: 500
28
+ partial_selection: 500
29
+ aggregation: 500
30
+ window_function: 300
31
+ attribute_filter: 300
32
+
33
+ # Generation settings
34
+ generation:
35
+ max_workers: 8 # Number of parallel workers
36
+ retry_multiplier: 2 # Generate 2x samples to account for failures
37
+ append_mode: false # Set false for clean regeneration after template/format changes
38
+
39
+ # Auto-scaling configuration
40
+ # Relation limits are automatically calculated: target * retry_multiplier * safety_factor
41
+ auto_scaling:
42
+ safety_factor: 1.5 # Extra buffer to ensure enough unique pairs
43
+
44
+ # Manual overrides (optional) - uncomment to override auto-calculated limits
45
+ manual_limits: {}
46
+ # adjacency: 10000 # Uncomment to manually set
47
+ # containment: 2000
48
+ # intersection: 1000
49
+ # cross_source: 500
50
+
51
+ # Modal configuration for distributed generation
52
+ modal:
53
+ volume_name: "gazet-data" # Modal Volume for parquet data
54
+ app_name: "gazet-dataset" # Modal app name
55
+ num_containers: 100 # Number of parallel containers for sample generation
56
+ container_cpu: 2 # CPUs per container
57
+ container_memory: 4096 # Memory (MB) per container
58
+ timeout: 3600 # Per-container timeout in seconds
59
+
60
+ # Run name — used to version exported splits so re-runs never overwrite previous data.
61
+ # Change this whenever you regenerate from scratch (e.g. after template changes).
62
+ # Exported files land in: output/runs/{run_name}/
63
+ run_name: "v1"
dataset/modal_app.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modal app for distributed dataset generation."""
2
+
3
+
4
+ import modal
5
+
6
+ app = modal.App("gazet-dataset")
7
+
8
+ VOLUME_MOUNT = "/data"
9
+ INTERMEDIATE_MOUNT = "/intermediate"
10
+
11
+ volume = modal.Volume.from_name("gazet-data", create_if_missing=True)
12
+ intermediate_volume = modal.Volume.from_name(
13
+ "gazet-intermediate", create_if_missing=True
14
+ )
15
+
16
+ image = (
17
+ modal.Image.debian_slim(python_version="3.12")
18
+ .pip_install(
19
+ "duckdb>=1.4.4",
20
+ "dspy>=3.1.3",
21
+ "fastapi>=0.100",
22
+ "pandas>=2.2",
23
+ "pydantic>=2.0",
24
+ "pyarrow>=17.0.0",
25
+ "pyyaml>=6.0",
26
+ )
27
+ .env({"GAZET_DATA_DIR": VOLUME_MOUNT, "PYTHONPATH": "/root"})
28
+ .add_local_dir("src/gazet", "/root/gazet")
29
+ .add_local_dir("dataset", "/root/dataset")
30
+ )
31
+
32
+
33
+ @app.function(
34
+ image=image,
35
+ volumes={VOLUME_MOUNT: volume, INTERMEDIATE_MOUNT: intermediate_volume},
36
+ timeout=300,
37
+ cpu=2,
38
+ memory=4096,
39
+ )
40
+ def build_inventory_remote():
41
+ """Build entity inventory from parquet files on the volume."""
42
+ from pathlib import Path
43
+ from dataset.scripts.build_inventory import build_inventory_to_dir
44
+
45
+ result = build_inventory_to_dir(Path(INTERMEDIATE_MOUNT))
46
+ intermediate_volume.commit()
47
+ return result
48
+
49
+
50
+ @app.function(
51
+ image=image,
52
+ volumes={VOLUME_MOUNT: volume, INTERMEDIATE_MOUNT: intermediate_volume},
53
+ timeout=3600,
54
+ cpu=4,
55
+ memory=32768,
56
+ )
57
+ def build_relation_remote(relation_type: str, countries: list, limit: int):
58
+ """Compute one relation type and save to intermediate volume."""
59
+ from pathlib import Path
60
+ from dataset.scripts.build_relations import compute_single_relation
61
+
62
+ count = compute_single_relation(
63
+ relation_type=relation_type,
64
+ countries=countries,
65
+ limit=limit,
66
+ output_dir=Path(INTERMEDIATE_MOUNT),
67
+ )
68
+ intermediate_volume.commit()
69
+ return {"relation_type": relation_type, "count": count}
70
+
71
+
72
+ @app.function(
73
+ image=image,
74
+ volumes={VOLUME_MOUNT: volume, INTERMEDIATE_MOUNT: intermediate_volume},
75
+ timeout=3600,
76
+ cpu=2,
77
+ memory=4096,
78
+ )
79
+ def generate_batch_remote(work_items: list) -> list:
80
+ """Process a batch of work items on a Modal container."""
81
+ from dataset.scripts.generate_samples import generate_batch_core
82
+
83
+ results = generate_batch_core(
84
+ work_items=work_items,
85
+ intermediate_dir=INTERMEDIATE_MOUNT,
86
+ )
87
+
88
+ print(f"Batch complete: {sum(1 for r in results if r['sample'])} success / "
89
+ f"{sum(1 for r in results if not r['sample'])} failed out of {len(work_items)}")
90
+
91
+ return results
92
+
93
+
94
+ @app.local_entrypoint()
95
+ def run_pipeline(
96
+ config_path: str = "dataset/config.yaml",
97
+ num_containers: int = 0,
98
+ skip_inventory: bool = False,
99
+ skip_relations: bool = False,
100
+ fresh: bool = False,
101
+ ):
102
+ """Run the full distributed pipeline."""
103
+ import yaml
104
+ from pathlib import Path
105
+
106
+ config = yaml.safe_load(Path(config_path).read_text())
107
+ countries = config["countries"]
108
+ sample_targets = config["sample_targets"]
109
+ modal_cfg = config.get("modal", {})
110
+ n_containers = num_containers or modal_cfg.get("num_containers", 50)
111
+ retry_multiplier = config["generation"]["retry_multiplier"]
112
+
113
+ print(f"Countries: {countries}")
114
+ print(f"Sample targets: {sample_targets}")
115
+ print(f"Containers: {n_containers}")
116
+
117
+ if not skip_inventory:
118
+ print("Building inventory...")
119
+ result = build_inventory_remote.remote()
120
+ print(f" Inventory: {result}")
121
+
122
+ if not skip_relations:
123
+ print("Building relations...")
124
+
125
+ from dataset.scripts.cli import calculate_relation_limits
126
+
127
+ relation_needs = calculate_relation_limits(config)
128
+
129
+ handles = []
130
+ for rel_type, limit in relation_needs.items():
131
+ h = build_relation_remote.spawn(rel_type, countries, max(limit, 500))
132
+ handles.append((rel_type, h))
133
+
134
+ for rel_type, h in handles:
135
+ result = h.get()
136
+ print(f" {rel_type}: {result['count']} pairs")
137
+
138
+ print(f"Generating samples across {n_containers} containers...")
139
+
140
+ import json
141
+ from dataset.scripts.generate_samples import prepare_work_items
142
+
143
+ output_dir = Path("dataset/output")
144
+ output_dir.mkdir(exist_ok=True, parents=True)
145
+ output_file = output_dir / "dataset_raw.jsonl"
146
+
147
+ existing_samples = []
148
+ sample_counter = 1
149
+ if not fresh and output_file.exists():
150
+ with open(output_file) as f:
151
+ for line in f:
152
+ if line.strip():
153
+ existing_samples.append(json.loads(line))
154
+ if existing_samples:
155
+ max_id = max(
156
+ int(s["id"].split("_")[1])
157
+ for s in existing_samples
158
+ if s["id"].startswith("sample_")
159
+ )
160
+ sample_counter = max_id + 1
161
+ print(f" Appending to {len(existing_samples)} existing samples")
162
+
163
+ work_items = prepare_work_items(
164
+ target_counts=sample_targets,
165
+ retry_multiplier=retry_multiplier,
166
+ start_counter=sample_counter,
167
+ intermediate_dir_str="",
168
+ )
169
+
170
+ total_work = len(work_items)
171
+ print(f" Total work items: {total_work}")
172
+
173
+ batch_size = max(1, (total_work + n_containers - 1) // n_containers)
174
+ batches = [
175
+ work_items[i : i + batch_size]
176
+ for i in range(0, total_work, batch_size)
177
+ ]
178
+ print(f" Batches: {len(batches)} x ~{batch_size} items")
179
+
180
+ new_sample_count = 0
181
+ failed_batches = 0
182
+ family_progress = {}
183
+
184
+ write_mode = "w" if fresh else "a"
185
+ fout = open(output_file, write_mode)
186
+
187
+ try:
188
+ for batch_results in generate_batch_remote.map(
189
+ batches, return_exceptions=True
190
+ ):
191
+ if isinstance(batch_results, Exception):
192
+ failed_batches += 1
193
+ print(f" Batch failed: {batch_results}")
194
+ continue
195
+
196
+ batch_samples = []
197
+ for r in batch_results:
198
+ fam = r["family"]
199
+ if fam not in family_progress:
200
+ family_progress[fam] = {"success": 0, "failed": 0}
201
+ if r["sample"]:
202
+ batch_samples.append(r["sample"])
203
+ family_progress[fam]["success"] += 1
204
+ else:
205
+ family_progress[fam]["failed"] += 1
206
+
207
+ for sample in batch_samples:
208
+ fout.write(json.dumps(sample) + "\n")
209
+ fout.flush()
210
+ new_sample_count += len(batch_samples)
211
+
212
+ done = sum(p["success"] + p["failed"] for p in family_progress.values())
213
+ print(f" Progress: {done}/{total_work} items | {new_sample_count} saved | {failed_batches} batch errors")
214
+
215
+ except Exception as e:
216
+ print(f" Map interrupted: {e}")
217
+ finally:
218
+ fout.close()
219
+
220
+ print(f"\nResults by family:")
221
+ for fam in sorted(family_progress.keys()):
222
+ s = family_progress[fam]["success"]
223
+ f = family_progress[fam]["failed"]
224
+ total = s + f
225
+ rate = (s / total * 100) if total > 0 else 0
226
+ target = sample_targets.get(fam, 0)
227
+ print(
228
+ f" {fam:20s}: {s:4d} success / {f:4d} failed "
229
+ f"({rate:5.1f}%, target: {target})"
230
+ )
231
+
232
+ total_samples = len(existing_samples) + new_sample_count
233
+ status = "COMPLETE" if failed_batches == 0 else "PARTIAL"
234
+ print(f"\nGeneration {status}: {new_sample_count} new, {total_samples} total")
235
+ if failed_batches:
236
+ print(f" Failed batches: {failed_batches}/{len(batches)}")
237
+ print(f" Output: {output_file}")
238
+
239
+
240
+ @app.local_entrypoint()
241
+ def upload_data(data_dir: str = "data"):
242
+ """Upload local data directory to the Modal volume."""
243
+ import os
244
+ from pathlib import Path
245
+
246
+ data_path = Path(data_dir)
247
+ if not data_path.exists():
248
+ print(f"Error: {data_path} does not exist")
249
+ return
250
+
251
+ print(f"Uploading {data_path} to Modal volume 'gazet-data'...")
252
+
253
+ file_count = 0
254
+ total_size = 0
255
+
256
+ for root, dirs, files in os.walk(data_path):
257
+ for f in files:
258
+ local_path = os.path.join(root, f)
259
+ # Relative path within data_dir becomes the volume path
260
+ rel = os.path.relpath(local_path, data_path)
261
+ size = os.path.getsize(local_path)
262
+ total_size += size
263
+ file_count += 1
264
+ print(f" {rel} ({size / (1024*1024):.1f} MB)")
265
+
266
+ print(f" {file_count} files, {total_size / (1024*1024):.1f} MB")
267
+
268
+ vol = modal.Volume.from_name("gazet-data", create_if_missing=True)
269
+ with vol.batch_upload() as batch:
270
+ batch.put_directory(str(data_path), "/")
271
+
272
+ print("Upload complete")
dataset/scripts/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Dataset generation scripts package."""
dataset/scripts/build_inventory.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Build entity inventory from divisions_area and natural_earth parquet files.
3
+
4
+ This script creates compact inventory tables containing only the fields needed
5
+ for candidate sampling and distractor generation.
6
+
7
+ Output:
8
+ - intermediate/divisions_area_inventory.parquet
9
+ - intermediate/natural_earth_inventory.parquet
10
+ """
11
+
12
+ import duckdb
13
+ import pandas as pd
14
+ from pathlib import Path
15
+
16
+ from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
17
+
18
+
19
+ def build_divisions_area_inventory(con: duckdb.DuckDBPyConnection) -> pd.DataFrame:
20
+ """Extract compact inventory from divisions_area."""
21
+ query = """
22
+ SELECT
23
+ 'divisions_area' AS source,
24
+ id,
25
+ names."primary" AS name,
26
+ subtype,
27
+ country,
28
+ region,
29
+ admin_level,
30
+ class,
31
+ is_land,
32
+ is_territorial,
33
+ division_id,
34
+ ST_Area(geometry) AS area_sq_deg,
35
+ ST_XMin(geometry) AS xmin,
36
+ ST_YMin(geometry) AS ymin,
37
+ ST_XMax(geometry) AS xmax,
38
+ ST_YMax(geometry) AS ymax
39
+ FROM read_parquet(?)
40
+ WHERE names."primary" IS NOT NULL
41
+ AND trim(names."primary") != ''
42
+ """
43
+
44
+ df = con.execute(query, [DIVISIONS_AREA_PATH]).fetchdf()
45
+ print(f"Divisions area inventory: {len(df)} entities")
46
+ print(f"Subtypes: {df['subtype'].value_counts().to_dict()}")
47
+ print(f"Countries: {df['country'].nunique()} unique")
48
+
49
+ return df
50
+
51
+
52
+ def build_natural_earth_inventory(con: duckdb.DuckDBPyConnection) -> pd.DataFrame:
53
+ """Extract compact inventory from natural_earth."""
54
+ query = """
55
+ SELECT
56
+ 'natural_earth' AS source,
57
+ id,
58
+ names."primary" AS name,
59
+ subtype,
60
+ country,
61
+ region,
62
+ admin_level,
63
+ class,
64
+ is_land,
65
+ is_territorial,
66
+ ST_Area(geometry) AS area_sq_deg,
67
+ ST_XMin(geometry) AS xmin,
68
+ ST_YMin(geometry) AS ymin,
69
+ ST_XMax(geometry) AS xmax,
70
+ ST_YMax(geometry) AS ymax
71
+ FROM read_parquet(?)
72
+ WHERE names."primary" IS NOT NULL
73
+ AND trim(names."primary") != ''
74
+ """
75
+
76
+ df = con.execute(query, [NATURAL_EARTH_PATH]).fetchdf()
77
+ print(f"\nNatural earth inventory: {len(df)} entities")
78
+ print(f"Subtypes: {df['subtype'].value_counts().to_dict()}")
79
+
80
+ return df
81
+
82
+
83
+ def build_inventory_to_dir(output_dir: Path) -> dict:
84
+ """Build and save all inventory tables to output_dir.
85
+
86
+ Reusable entry point for both local CLI and Modal.
87
+
88
+ Returns:
89
+ Dict with counts: {"divisions_area": int, "natural_earth": int}
90
+ """
91
+ output_dir.mkdir(exist_ok=True, parents=True)
92
+
93
+ con = duckdb.connect()
94
+ con.execute("INSTALL spatial")
95
+ con.execute("LOAD spatial")
96
+
97
+ print("Building divisions_area inventory...")
98
+ divisions_df = build_divisions_area_inventory(con)
99
+ divisions_path = output_dir / "divisions_area_inventory.parquet"
100
+ divisions_df.to_parquet(divisions_path, index=False)
101
+ print(f"Saved to {divisions_path}")
102
+
103
+ print("\nBuilding natural_earth inventory...")
104
+ natural_earth_df = build_natural_earth_inventory(con)
105
+ natural_earth_path = output_dir / "natural_earth_inventory.parquet"
106
+ natural_earth_df.to_parquet(natural_earth_path, index=False)
107
+ print(f"Saved to {natural_earth_path}")
108
+
109
+ con.close()
110
+
111
+ total = len(divisions_df) + len(natural_earth_df)
112
+ print(f"\nInventory build complete")
113
+ print(f" Total entities: {total}")
114
+ return {"divisions_area": len(divisions_df), "natural_earth": len(natural_earth_df)}
115
+
116
+
117
+ def main():
118
+ """Build and save inventory tables."""
119
+ output_dir = Path(__file__).parent.parent / "intermediate"
120
+ build_inventory_to_dir(output_dir)
121
+
122
+
123
+ if __name__ == "__main__":
124
+ main()
dataset/scripts/build_relations.py ADDED
@@ -0,0 +1,557 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Precompute spatial relation tables for efficient anchor sampling.
3
+
4
+ This script computes:
5
+ - Adjacency pairs (touching features)
6
+ - Containment pairs (features within other features)
7
+ - Intersection pairs (overlapping features)
8
+ - Cross-source relations (divisions_area ↔ natural_earth)
9
+
10
+ Output:
11
+ - intermediate/adjacency_pairs.parquet
12
+ - intermediate/containment_pairs.parquet
13
+ - intermediate/intersection_pairs.parquet
14
+ - intermediate/cross_source_relations.parquet
15
+ """
16
+
17
+ import duckdb
18
+ import pandas as pd
19
+ from pathlib import Path
20
+ from concurrent.futures import ThreadPoolExecutor, as_completed
21
+
22
+ from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
23
+
24
+
25
+ # Subtypes too granular for spatial self-joins at global scale
26
+ _EXCLUDED_SUBTYPES_FOR_GLOBAL = ("locality", "neighborhood", "microhood", "macrohood")
27
+
28
+
29
+ def _country_filter(countries: list) -> tuple[str, list]:
30
+ """Return (SQL WHERE clause, params) handling 'all' sentinel."""
31
+ if countries == ["all"]:
32
+ return "", []
33
+ return "WHERE country IN (SELECT unnest(?))", [countries]
34
+
35
+
36
+ def _country_filter_for_join(countries: list) -> tuple[str, list]:
37
+ """Like _country_filter but also excludes fine-grained subtypes for global runs.
38
+
39
+ When joining all 1M+ entities, localities/neighborhoods/microhoods cause
40
+ OOM. Excluding them keeps ~110K higher-level admin entities.
41
+ """
42
+ excluded = "', '".join(_EXCLUDED_SUBTYPES_FOR_GLOBAL)
43
+ subtype_clause = f"AND subtype NOT IN ('{excluded}')"
44
+ if countries == ["all"]:
45
+ return f"WHERE 1=1 {subtype_clause}", []
46
+ return f"WHERE country IN (SELECT unnest(?)) {subtype_clause}", [countries]
47
+
48
+
49
+ def compute_adjacency_pairs(
50
+ con: duckdb.DuckDBPyConnection,
51
+ countries: list,
52
+ limit: int
53
+ ) -> pd.DataFrame:
54
+ """Find all pairs of features that touch (share a boundary)."""
55
+ print("Computing adjacency pairs (optimized with spatial index)...")
56
+
57
+ cfilter, cparams = _country_filter_for_join(countries)
58
+
59
+ # Use bounding box pre-filter to avoid full cartesian product
60
+ query = f"""
61
+ WITH features AS (
62
+ SELECT
63
+ id,
64
+ names."primary" AS name,
65
+ subtype,
66
+ country,
67
+ admin_level,
68
+ geometry,
69
+ ST_Envelope(geometry) AS bbox
70
+ FROM read_parquet(?)
71
+ {cfilter}
72
+ )
73
+ SELECT
74
+ a.id AS anchor_id,
75
+ a.name AS anchor_name,
76
+ a.subtype AS anchor_subtype,
77
+ a.country AS anchor_country,
78
+ b.id AS target_id,
79
+ b.name AS target_name,
80
+ b.subtype AS target_subtype,
81
+ b.country AS target_country,
82
+ 'adjacency' AS relation_type
83
+ FROM features AS a
84
+ JOIN features AS b ON (
85
+ a.id < b.id
86
+ AND ST_Intersects(a.bbox, b.bbox)
87
+ AND ST_Touches(a.geometry, b.geometry)
88
+ )
89
+ LIMIT ?
90
+ """
91
+
92
+ df = con.execute(query, [DIVISIONS_AREA_PATH] + cparams + [limit]).fetchdf()
93
+ print(f"Found {len(df)} adjacency pairs")
94
+
95
+ return df
96
+
97
+
98
+ def compute_containment_pairs(
99
+ con: duckdb.DuckDBPyConnection,
100
+ countries: list,
101
+ limit: int
102
+ ) -> pd.DataFrame:
103
+ """Find all pairs where one feature contains another."""
104
+ print("\nComputing containment pairs (optimized)...")
105
+
106
+ cfilter, cparams = _country_filter(countries)
107
+
108
+ query = f"""
109
+ WITH features AS (
110
+ SELECT
111
+ id,
112
+ names."primary" AS name,
113
+ subtype,
114
+ country,
115
+ admin_level,
116
+ geometry,
117
+ ST_Envelope(geometry) AS bbox
118
+ FROM read_parquet(?)
119
+ {cfilter}
120
+ )
121
+ SELECT
122
+ a.id AS container_id,
123
+ a.name AS container_name,
124
+ a.subtype AS container_subtype,
125
+ b.id AS contained_id,
126
+ b.name AS contained_name,
127
+ b.subtype AS contained_subtype,
128
+ 'containment' AS relation_type
129
+ FROM features AS a
130
+ JOIN features AS b ON (
131
+ a.id != b.id
132
+ AND a.admin_level < b.admin_level
133
+ AND ST_Intersects(a.bbox, b.bbox)
134
+ AND ST_Within(b.geometry, a.geometry)
135
+ )
136
+ LIMIT ?
137
+ """
138
+
139
+ df = con.execute(query, [DIVISIONS_AREA_PATH] + cparams + [limit]).fetchdf()
140
+ print(f"Found {len(df)} containment pairs")
141
+
142
+ return df
143
+
144
+
145
+ def compute_intersection_pairs(
146
+ con: duckdb.DuckDBPyConnection,
147
+ countries: list,
148
+ limit: int
149
+ ) -> pd.DataFrame:
150
+ """Find pairs that intersect but don't touch or contain."""
151
+ print("\nComputing intersection pairs (optimized)...")
152
+
153
+ cfilter, cparams = _country_filter_for_join(countries)
154
+
155
+ query = f"""
156
+ WITH features AS (
157
+ SELECT
158
+ id,
159
+ names."primary" AS name,
160
+ subtype,
161
+ country,
162
+ admin_level,
163
+ geometry,
164
+ ST_Envelope(geometry) AS bbox
165
+ FROM read_parquet(?)
166
+ {cfilter}
167
+ )
168
+ SELECT
169
+ a.id AS anchor_id,
170
+ a.name AS anchor_name,
171
+ a.subtype AS anchor_subtype,
172
+ b.id AS target_id,
173
+ b.name AS target_name,
174
+ b.subtype AS target_subtype,
175
+ 'intersection' AS relation_type
176
+ FROM features AS a
177
+ JOIN features AS b ON (
178
+ a.id < b.id
179
+ AND ST_Intersects(a.bbox, b.bbox)
180
+ AND ST_Intersects(a.geometry, b.geometry)
181
+ AND NOT ST_Touches(a.geometry, b.geometry)
182
+ AND NOT ST_Within(a.geometry, b.geometry)
183
+ AND NOT ST_Within(b.geometry, a.geometry)
184
+ )
185
+ LIMIT ?
186
+ """
187
+
188
+ df = con.execute(query, [DIVISIONS_AREA_PATH] + cparams + [limit]).fetchdf()
189
+ print(f"Found {len(df)} same-source intersection pairs")
190
+
191
+ return df
192
+
193
+
194
+ def compute_cross_source_relations(
195
+ con: duckdb.DuckDBPyConnection,
196
+ countries: list,
197
+ limit: int
198
+ ) -> pd.DataFrame:
199
+ """Find relations between divisions_area and natural_earth.
200
+
201
+ Covers all natural_earth subtypes that appear in SQL templates:
202
+ seas/oceans (adjacency, buffer, chained), terrain areas and island
203
+ groups (chained_03, intersect_02, buffer_03/04).
204
+ """
205
+ print("\nComputing cross-source relations...")
206
+
207
+ cfilter, cparams = _country_filter(countries)
208
+
209
+ query = f"""
210
+ WITH divisions AS (
211
+ SELECT
212
+ id,
213
+ names."primary" AS name,
214
+ subtype,
215
+ country,
216
+ geometry
217
+ FROM read_parquet(?)
218
+ {cfilter}
219
+ ),
220
+ natural_features AS (
221
+ SELECT
222
+ id,
223
+ names."primary" AS name,
224
+ subtype,
225
+ ST_SetCRS(geometry, 'OGC:CRS84') AS geometry
226
+ FROM read_parquet(?)
227
+ WHERE subtype IN (
228
+ 'sea', 'ocean', 'Lake', 'River', 'Basin', 'gulf', 'bay',
229
+ 'Terrain area', 'Island group', 'Peninsula', 'Strait',
230
+ 'Reef', 'Range/Mts', 'Depression'
231
+ )
232
+ LIMIT 500
233
+ )
234
+ SELECT
235
+ d.id AS division_id,
236
+ d.name AS division_name,
237
+ d.subtype AS division_subtype,
238
+ d.country AS division_country,
239
+ n.id AS natural_id,
240
+ n.name AS natural_name,
241
+ n.subtype AS natural_subtype,
242
+ CASE
243
+ WHEN ST_Touches(d.geometry, n.geometry) THEN 'touches'
244
+ WHEN ST_Within(d.geometry, n.geometry) THEN 'within'
245
+ WHEN ST_Contains(d.geometry, n.geometry) THEN 'contains'
246
+ WHEN ST_Intersects(d.geometry, n.geometry) THEN 'intersects'
247
+ END AS relation_type
248
+ FROM divisions AS d
249
+ JOIN natural_features AS n ON ST_Intersects(d.geometry, n.geometry)
250
+ LIMIT ?
251
+ """
252
+
253
+ df = con.execute(
254
+ query, [DIVISIONS_AREA_PATH] + cparams + [NATURAL_EARTH_PATH, limit]
255
+ ).fetchdf()
256
+ print(f"Found {len(df)} cross-source relations")
257
+ return df
258
+
259
+
260
+ def compute_coastal_containment_pairs(
261
+ con: duckdb.DuckDBPyConnection,
262
+ countries: list,
263
+ limit: int,
264
+ ) -> pd.DataFrame:
265
+ """Containment pairs where the container is in a coastal country.
266
+
267
+ Used by chained_01 (coastal towns of X) to ensure sampled containment
268
+ anchors actually have sea-adjacent sub-features, keeping the SQL
269
+ verification step from constantly returning empty results.
270
+
271
+ Strategy: find countries whose geometry intersects any ocean/sea in
272
+ natural_earth, then filter containment_pairs to those countries.
273
+ """
274
+ print("\nComputing coastal containment pairs...")
275
+
276
+ cfilter, cparams = _country_filter(countries)
277
+
278
+ query = f"""
279
+ WITH coastal_countries AS (
280
+ SELECT DISTINCT d.country
281
+ FROM read_parquet(?) AS d
282
+ JOIN read_parquet(?) AS n
283
+ ON ST_Intersects(d.geometry, ST_SetCRS(n.geometry, 'OGC:CRS84'))
284
+ WHERE d.subtype = 'country'
285
+ AND n.subtype IN ('sea', 'ocean')
286
+ ),
287
+ features AS (
288
+ SELECT
289
+ id,
290
+ names."primary" AS name,
291
+ subtype,
292
+ country,
293
+ admin_level,
294
+ geometry,
295
+ ST_Envelope(geometry) AS bbox
296
+ FROM read_parquet(?)
297
+ {cfilter}
298
+ )
299
+ SELECT
300
+ a.id AS container_id,
301
+ a.name AS container_name,
302
+ a.subtype AS container_subtype,
303
+ b.id AS contained_id,
304
+ b.name AS contained_name,
305
+ b.subtype AS contained_subtype,
306
+ a.country AS container_country,
307
+ 'coastal_containment' AS relation_type
308
+ FROM features AS a
309
+ JOIN features AS b ON (
310
+ a.id != b.id
311
+ AND a.admin_level < b.admin_level
312
+ AND ST_Intersects(a.bbox, b.bbox)
313
+ AND ST_Within(b.geometry, a.geometry)
314
+ )
315
+ WHERE a.country IN (SELECT country FROM coastal_countries)
316
+ LIMIT ?
317
+ """
318
+
319
+ df = con.execute(
320
+ query,
321
+ [DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH] + cparams + [DIVISIONS_AREA_PATH, limit],
322
+ ).fetchdf()
323
+ print(f"Found {len(df)} coastal containment pairs")
324
+ return df
325
+
326
+
327
+ def compute_landlocked_containment_pairs(
328
+ con: duckdb.DuckDBPyConnection,
329
+ countries: list,
330
+ limit: int,
331
+ ) -> pd.DataFrame:
332
+ """Containment pairs where the container is in a landlocked country.
333
+
334
+ Used by chained_02 (landlocked regions within X) to ensure sampled
335
+ anchors genuinely have no sea access, keeping SQL verification from
336
+ always returning empty.
337
+ """
338
+ print("\nComputing landlocked containment pairs...")
339
+
340
+ cfilter, cparams = _country_filter(countries)
341
+
342
+ query = f"""
343
+ WITH coastal_countries AS (
344
+ SELECT DISTINCT d.country
345
+ FROM read_parquet(?) AS d
346
+ JOIN read_parquet(?) AS n
347
+ ON ST_Intersects(d.geometry, ST_SetCRS(n.geometry, 'OGC:CRS84'))
348
+ WHERE d.subtype = 'country'
349
+ AND n.subtype IN ('sea', 'ocean')
350
+ ),
351
+ features AS (
352
+ SELECT
353
+ id,
354
+ names."primary" AS name,
355
+ subtype,
356
+ country,
357
+ admin_level,
358
+ geometry,
359
+ ST_Envelope(geometry) AS bbox
360
+ FROM read_parquet(?)
361
+ {cfilter}
362
+ )
363
+ SELECT
364
+ a.id AS container_id,
365
+ a.name AS container_name,
366
+ a.subtype AS container_subtype,
367
+ b.id AS contained_id,
368
+ b.name AS contained_name,
369
+ b.subtype AS contained_subtype,
370
+ a.country AS container_country,
371
+ 'landlocked_containment' AS relation_type
372
+ FROM features AS a
373
+ JOIN features AS b ON (
374
+ a.id != b.id
375
+ AND a.admin_level < b.admin_level
376
+ AND ST_Intersects(a.bbox, b.bbox)
377
+ AND ST_Within(b.geometry, a.geometry)
378
+ )
379
+ WHERE a.country NOT IN (SELECT country FROM coastal_countries)
380
+ LIMIT ?
381
+ """
382
+
383
+ df = con.execute(
384
+ query,
385
+ [DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH] + cparams + [DIVISIONS_AREA_PATH, limit],
386
+ ).fetchdf()
387
+ print(f"Found {len(df)} landlocked containment pairs")
388
+ return df
389
+
390
+
391
+ def compute_common_neighbor_pairs(
392
+ con: duckdb.DuckDBPyConnection,
393
+ countries: list,
394
+ limit: int,
395
+ ) -> pd.DataFrame:
396
+ """Pairs of anchors that share at least one common touching neighbour.
397
+
398
+ Used by multi_adj_01 (borders both X and Y) so that the generated SQL
399
+ is guaranteed to return at least one result rather than failing constantly
400
+ on random pairs that have no common neighbour.
401
+
402
+ Derived by self-joining adjacency_pairs on the shared target_id.
403
+ """
404
+ print("\nComputing common-neighbor pairs...")
405
+
406
+ adj_path = Path(__file__).parent.parent / "intermediate" / "adjacency_pairs.parquet"
407
+ if not adj_path.exists():
408
+ print(" adjacency_pairs.parquet not found — skipping (run adjacency first)")
409
+ return pd.DataFrame(columns=[
410
+ "anchor_id_1", "anchor_name_1", "anchor_id_2", "anchor_name_2",
411
+ "shared_neighbor_id", "shared_neighbor_name",
412
+ ])
413
+
414
+ query = """
415
+ SELECT DISTINCT
416
+ a1.anchor_id AS anchor_id_1,
417
+ a1.anchor_name AS anchor_name_1,
418
+ a2.anchor_id AS anchor_id_2,
419
+ a2.anchor_name AS anchor_name_2,
420
+ a1.target_id AS shared_neighbor_id,
421
+ a1.target_name AS shared_neighbor_name
422
+ FROM read_parquet(?) AS a1
423
+ JOIN read_parquet(?) AS a2
424
+ ON a1.target_id = a2.target_id
425
+ AND a1.anchor_id < a2.anchor_id
426
+ LIMIT ?
427
+ """
428
+
429
+ df = con.execute(query, [str(adj_path), str(adj_path), limit]).fetchdf()
430
+ print(f"Found {len(df)} common-neighbor pairs")
431
+ return df
432
+
433
+
434
+ def _make_connection():
435
+ """Create a new DuckDB connection with spatial extension loaded."""
436
+ con = duckdb.connect()
437
+ con.execute("INSTALL spatial")
438
+ con.execute("LOAD spatial")
439
+ con.execute("SET memory_limit='24GB'")
440
+ con.execute("SET temp_directory='/tmp/duckdb_tmp'")
441
+ con.execute("SET threads=4")
442
+ return con
443
+
444
+
445
+ def _compute_and_save(compute_fn, countries, limit, output_path):
446
+ """Compute a relation table and save it to parquet. Uses its own DuckDB connection."""
447
+ con = _make_connection()
448
+ try:
449
+ df = compute_fn(con, countries, limit)
450
+ df.to_parquet(output_path, index=False)
451
+ print(f"Saved to {output_path}")
452
+ return df
453
+ finally:
454
+ con.close()
455
+
456
+
457
+ RELATION_FUNCTIONS = {
458
+ "adjacency": compute_adjacency_pairs,
459
+ "containment": compute_containment_pairs,
460
+ "intersection": compute_intersection_pairs,
461
+ "cross_source": compute_cross_source_relations,
462
+ "coastal_containment": compute_coastal_containment_pairs,
463
+ "landlocked_containment": compute_landlocked_containment_pairs,
464
+ "common_neighbor": compute_common_neighbor_pairs,
465
+ }
466
+
467
+
468
+ def compute_single_relation(
469
+ relation_type: str,
470
+ countries: list,
471
+ limit: int,
472
+ output_dir: Path,
473
+ ) -> int:
474
+ """Compute one relation type and save to output_dir.
475
+
476
+ Returns the number of rows computed. Usable from Modal or locally.
477
+ """
478
+ compute_fn = RELATION_FUNCTIONS.get(relation_type)
479
+ if compute_fn is None:
480
+ raise ValueError(
481
+ f"Unknown relation type: {relation_type}. "
482
+ f"Expected one of {list(RELATION_FUNCTIONS)}"
483
+ )
484
+ output_dir.mkdir(exist_ok=True, parents=True)
485
+ output_path = output_dir / f"{relation_type}_pairs.parquet"
486
+ df = _compute_and_save(compute_fn, countries, limit, output_path)
487
+ return len(df)
488
+
489
+
490
+ def main(countries: list = None, relation_limits: dict = None):
491
+ """Compute and save all relation tables in parallel.
492
+
493
+ Args:
494
+ countries: List of country codes to process
495
+ relation_limits: Dict with keys: adjacency, containment, intersection, cross_source
496
+ """
497
+ # Defaults
498
+ if countries is None:
499
+ countries = ['EC', 'BE', 'KE', 'AE', 'SG', 'CH']
500
+ if relation_limits is None:
501
+ relation_limits = {
502
+ 'adjacency': 50000,
503
+ 'containment': 1000,
504
+ 'intersection': 500,
505
+ 'cross_source': 500,
506
+ 'coastal_containment': 1000,
507
+ 'landlocked_containment': 500,
508
+ 'common_neighbor': 5000,
509
+ }
510
+
511
+ output_dir = Path(__file__).parent.parent / "intermediate"
512
+ output_dir.mkdir(exist_ok=True, parents=True)
513
+
514
+ # Define all relation tasks.
515
+ # common_neighbor depends on adjacency_pairs so it runs after adjacency.
516
+ tasks = [
517
+ ("adjacency", compute_adjacency_pairs, relation_limits['adjacency'], output_dir / "adjacency_pairs.parquet"),
518
+ ("containment", compute_containment_pairs, relation_limits['containment'], output_dir / "containment_pairs.parquet"),
519
+ ("intersection", compute_intersection_pairs, relation_limits['intersection'], output_dir / "intersection_pairs.parquet"),
520
+ ("cross_source", compute_cross_source_relations, relation_limits['cross_source'], output_dir / "cross_source_relations.parquet"),
521
+ ("coastal_containment", compute_coastal_containment_pairs, relation_limits['coastal_containment'], output_dir / "coastal_containment_pairs.parquet"),
522
+ ("landlocked_containment", compute_landlocked_containment_pairs, relation_limits['landlocked_containment'], output_dir / "landlocked_containment_pairs.parquet"),
523
+ ("common_neighbor", compute_common_neighbor_pairs, relation_limits['common_neighbor'], output_dir / "common_neighbor_pairs.parquet"),
524
+ ]
525
+
526
+ # common_neighbor reads adjacency_pairs.parquet so it must run after
527
+ # adjacency finishes. Split into two waves.
528
+ independent_tasks = [t for t in tasks if t[0] != "common_neighbor"]
529
+ dependent_tasks = [t for t in tasks if t[0] == "common_neighbor"]
530
+
531
+ print(f"Computing {len(independent_tasks)} relation types in parallel...")
532
+ with ThreadPoolExecutor(max_workers=len(independent_tasks)) as executor:
533
+ futures = {
534
+ executor.submit(_compute_and_save, compute_fn, countries, limit, path): name
535
+ for name, compute_fn, limit, path in independent_tasks
536
+ }
537
+ for future in as_completed(futures):
538
+ name = futures[future]
539
+ try:
540
+ future.result()
541
+ except Exception as e:
542
+ print(f"ERROR computing {name}: {e}")
543
+ raise
544
+
545
+ for name, compute_fn, limit, path in dependent_tasks:
546
+ print(f"\nComputing {name} (depends on adjacency)...")
547
+ try:
548
+ _compute_and_save(compute_fn, countries, limit, path)
549
+ except Exception as e:
550
+ print(f"ERROR computing {name}: {e}")
551
+ raise
552
+
553
+ print("\nRelation tables build complete")
554
+
555
+
556
+ if __name__ == "__main__":
557
+ main()
dataset/scripts/cli.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ CLI for synthetic dataset generation.
4
+
5
+ Usage:
6
+ python cli.py build-relations --config ../config.yaml
7
+ python cli.py generate-samples --config ../config.yaml
8
+ python cli.py generate-samples --config ../config.yaml --append
9
+ python cli.py full-pipeline --config ../config.yaml
10
+ """
11
+
12
+ import argparse
13
+ import subprocess
14
+ import sys
15
+ from pathlib import Path
16
+ from typing import Dict
17
+
18
+ import pandas as pd
19
+ import yaml
20
+
21
+
22
+ def load_config(config_path: Path) -> dict:
23
+ """Load configuration from YAML file."""
24
+ with open(config_path) as f:
25
+ return yaml.safe_load(f)
26
+
27
+
28
+ def should_rebuild_relations(config: dict, intermediate_dir: Path, append: bool) -> bool:
29
+ """Check if relation tables need to be rebuilt.
30
+
31
+ Returns True if:
32
+ - Not in append mode (always rebuild)
33
+ - Relation tables don't exist
34
+ - Countries in config differ from countries in existing relation tables
35
+ """
36
+ if not append:
37
+ return True
38
+
39
+ # Check if relation tables exist
40
+ adjacency_file = intermediate_dir / "adjacency_pairs.parquet"
41
+ if not adjacency_file.exists():
42
+ print("WARNING: Relation tables not found, will rebuild despite append mode")
43
+ return True
44
+
45
+ # Check if countries have changed
46
+ try:
47
+ df = pd.read_parquet(adjacency_file)
48
+ if 'anchor_country' in df.columns:
49
+ existing_countries = set(df['anchor_country'].unique())
50
+ config_countries = set(config['countries'])
51
+
52
+ if existing_countries != config_countries:
53
+ print(f"WARNING: Countries changed:")
54
+ print(f" Previous: {sorted(existing_countries)}")
55
+ print(f" New: {sorted(config_countries)}")
56
+ print(f" Will rebuild relation tables to include new countries")
57
+ return True
58
+ else:
59
+ print(f"Countries unchanged: {sorted(config_countries)}")
60
+ return False
61
+ else:
62
+ # Can't determine countries, rebuild to be safe
63
+ print("WARNING: Cannot determine countries from existing tables, will rebuild")
64
+ return True
65
+ except Exception as e:
66
+ print(f"WARNING: Error reading existing relation tables: {e}")
67
+ print(" Will rebuild to be safe")
68
+ return True
69
+
70
+
71
+ def calculate_relation_limits(config: dict) -> Dict[str, int]:
72
+ """Auto-calculate relation limits based on sample targets."""
73
+ sample_targets = config['sample_targets']
74
+ retry_mult = config['generation']['retry_multiplier']
75
+ safety = config.get('auto_scaling', {}).get('safety_factor', 1.5)
76
+
77
+ # Map each task family to the relation tables it draws anchors from.
78
+ # A family can need multiple relation types.
79
+ family_to_relations = {
80
+ 'direct_lookup': [],
81
+ 'adjacency': ['adjacency'],
82
+ 'multi_adjacency': ['adjacency', 'common_neighbor'],
83
+ 'containment': ['containment'],
84
+ 'intersection': ['intersection', 'cross_source'],
85
+ 'buffer': ['adjacency'],
86
+ 'chained': ['coastal_containment', 'landlocked_containment', 'containment'],
87
+ 'difference': ['containment', 'cross_source'],
88
+ 'border_corridor': ['adjacency'],
89
+ 'set_operations': ['containment', 'cross_source'],
90
+ 'partial_selection': ['containment', 'cross_source'],
91
+ 'aggregation': ['containment'],
92
+ 'window_function': [],
93
+ 'attribute_filter': [],
94
+ }
95
+
96
+ relation_needs: Dict[str, int] = {}
97
+ for family, target in sample_targets.items():
98
+ for rel_type in family_to_relations.get(family, []):
99
+ needed = int(target * retry_mult * safety)
100
+ relation_needs[rel_type] = relation_needs.get(rel_type, 0) + needed
101
+
102
+ # common_neighbor is derived from adjacency — keep its limit proportional
103
+ if 'common_neighbor' not in relation_needs and 'adjacency' in relation_needs:
104
+ relation_needs['common_neighbor'] = relation_needs['adjacency'] * 3
105
+
106
+ # Apply manual overrides if specified
107
+ manual = config.get('auto_scaling', {}).get('manual_limits', {})
108
+ relation_needs.update(manual)
109
+
110
+ return relation_needs
111
+
112
+
113
+ def build_relations(config_path: Path):
114
+ """Run relation building with config."""
115
+ config = load_config(config_path)
116
+
117
+ # Auto-calculate relation limits
118
+ relation_limits = calculate_relation_limits(config)
119
+
120
+ print("=" * 60)
121
+ print("STEP 1: Building Relation Tables")
122
+ print("=" * 60)
123
+ print(f"Countries: {', '.join(config['countries'])}")
124
+ print(f"\nAuto-calculated relation limits:")
125
+ for rel_type, limit in relation_limits.items():
126
+ print(f" {rel_type:20s}: {limit:,}")
127
+ print()
128
+
129
+ # Import and run the relation builder
130
+ from dataset.scripts import build_relations
131
+
132
+ # Run with config parameters
133
+ build_relations.main(
134
+ countries=config['countries'],
135
+ relation_limits=relation_limits
136
+ )
137
+
138
+ print("\nRelation tables built successfully")
139
+
140
+
141
+ def generate_samples(config_path: Path, append: bool = False):
142
+ """Run sample generation with config."""
143
+ config = load_config(config_path)
144
+
145
+ print("=" * 60)
146
+ print("STEP 2: Generating Samples")
147
+ print("=" * 60)
148
+ print(f"Targets: {config['sample_targets']}")
149
+ print(f"Workers: {config['generation']['max_workers']}")
150
+ print(f"Append mode: {append or config['generation']['append_mode']}")
151
+ print()
152
+
153
+ # Simple import - no number prefixes needed
154
+ from dataset.scripts import generate_samples as gs_module
155
+
156
+ # Override config values
157
+ gs_module.TARGET_COUNTS = config['sample_targets']
158
+ gs_module.MAX_WORKERS = config['generation']['max_workers']
159
+ gs_module.RETRY_MULTIPLIER = config['generation']['retry_multiplier']
160
+ gs_module.APPEND_MODE = append or config['generation']['append_mode']
161
+
162
+ # Run the main function
163
+ gs_module.main()
164
+
165
+ print("\nSamples generated successfully")
166
+
167
+
168
+ def validate_dataset(config_path: Path):
169
+ """Run dataset validation."""
170
+ print("=" * 60)
171
+ print("STEP 3: Validating Dataset")
172
+ print("=" * 60)
173
+
174
+ script_dir = Path(__file__).parent
175
+ result = subprocess.run(
176
+ [sys.executable, str(script_dir / "validate_dataset.py")],
177
+ check=True
178
+ )
179
+
180
+ print("\nDataset validated successfully")
181
+
182
+
183
+ def export_dataset(config_path: Path):
184
+ """Run dataset export for both SQL generation and place extraction tasks."""
185
+ print("=" * 60)
186
+ print("STEP 4: Exporting Dataset")
187
+ print("=" * 60)
188
+
189
+ from dataset.scripts.export_training_data import main as export_main
190
+ export_main(config_path=config_path)
191
+
192
+ print("\nDataset exported successfully")
193
+
194
+
195
+ def modal_upload(config_path: Path):
196
+ """Upload local data to Modal volume."""
197
+ subprocess.run(
198
+ [sys.executable, "-m", "modal", "run",
199
+ "dataset/modal_app.py::upload_data"],
200
+ check=True
201
+ )
202
+
203
+
204
+ def modal_generate(config_path: Path, num_containers: int = 0,
205
+ skip_inventory: bool = False, skip_relations: bool = False,
206
+ fresh: bool = False):
207
+ """Run distributed generation on Modal (appends by default)."""
208
+ cmd = [
209
+ sys.executable, "-m", "modal", "run",
210
+ "dataset/modal_app.py::run_pipeline",
211
+ "--config-path", str(config_path),
212
+ ]
213
+ if num_containers > 0:
214
+ cmd.extend(["--num-containers", str(num_containers)])
215
+ if skip_inventory:
216
+ cmd.append("--skip-inventory")
217
+ if skip_relations:
218
+ cmd.append("--skip-relations")
219
+ if fresh:
220
+ cmd.append("--fresh")
221
+
222
+ subprocess.run(cmd, check=True)
223
+ validate_dataset(config_path)
224
+ export_dataset(config_path)
225
+
226
+
227
+ def full_pipeline(config_path: Path, append: bool = False):
228
+ """Run the full pipeline."""
229
+ print("Running full dataset generation pipeline")
230
+
231
+ config = load_config(config_path)
232
+
233
+ # Check if inventory exists, create if not
234
+ script_dir = Path(__file__).parent
235
+ intermediate_dir = script_dir.parent / "intermediate"
236
+ inventory_files = [
237
+ intermediate_dir / "divisions_area_inventory.parquet",
238
+ intermediate_dir / "natural_earth_inventory.parquet"
239
+ ]
240
+
241
+ inventory_missing = any(not f.exists() for f in inventory_files)
242
+
243
+ if inventory_missing:
244
+ print("=" * 60)
245
+ print("STEP 0: Building Entity Inventory")
246
+ print("=" * 60)
247
+ print("Inventory files not found, building...")
248
+ from dataset.scripts import build_inventory
249
+ build_inventory.main()
250
+
251
+ # Check if we need to rebuild relations
252
+ need_rebuild = should_rebuild_relations(config, intermediate_dir, append)
253
+
254
+ if need_rebuild:
255
+ build_relations(config_path)
256
+ else:
257
+ print("Using existing relation tables (append mode, same countries)")
258
+
259
+ generate_samples(config_path, append=append)
260
+ validate_dataset(config_path)
261
+ export_dataset(config_path)
262
+
263
+ print("\nPipeline complete")
264
+
265
+
266
+ def main():
267
+ parser = argparse.ArgumentParser(
268
+ description="Synthetic dataset generation CLI",
269
+ formatter_class=argparse.RawDescriptionHelpFormatter,
270
+ epilog="""
271
+ Examples:
272
+ # Build relation tables only
273
+ python cli.py build-relations --config ../config.yaml
274
+
275
+ # Generate samples only
276
+ python cli.py generate-samples --config ../config.yaml
277
+
278
+ # Generate and append to existing dataset
279
+ python cli.py generate-samples --config ../config.yaml --append
280
+
281
+ # Run full pipeline
282
+ python cli.py full-pipeline --config ../config.yaml
283
+
284
+ # Run full pipeline in append mode (skip relation building)
285
+ python cli.py full-pipeline --config ../config.yaml --append
286
+
287
+ # Upload data to Modal volume (one-time)
288
+ python cli.py modal-upload --config ../config.yaml
289
+
290
+ # Run distributed generation on Modal
291
+ python cli.py modal-generate --config ../config.yaml
292
+ python cli.py modal-generate --config ../config.yaml --num-containers 100
293
+ python cli.py modal-generate --config ../config.yaml --skip-inventory --skip-relations
294
+ """
295
+ )
296
+
297
+ parser.add_argument(
298
+ 'command',
299
+ choices=['build-relations', 'generate-samples', 'validate', 'export',
300
+ 'full-pipeline', 'modal-upload', 'modal-generate'],
301
+ help='Command to run'
302
+ )
303
+
304
+ parser.add_argument(
305
+ '--config',
306
+ type=Path,
307
+ required=True,
308
+ help='Path to config YAML file'
309
+ )
310
+
311
+ parser.add_argument(
312
+ '--append',
313
+ action='store_true',
314
+ help='Append to existing dataset instead of overwriting'
315
+ )
316
+
317
+ parser.add_argument(
318
+ '--num-containers',
319
+ type=int,
320
+ default=0,
321
+ help='Number of Modal containers (0 = use config default)'
322
+ )
323
+
324
+ parser.add_argument(
325
+ '--skip-inventory',
326
+ action='store_true',
327
+ help='Skip inventory building on Modal'
328
+ )
329
+
330
+ parser.add_argument(
331
+ '--skip-relations',
332
+ action='store_true',
333
+ help='Skip relation building on Modal'
334
+ )
335
+
336
+ parser.add_argument(
337
+ '--fresh',
338
+ action='store_true',
339
+ help='Overwrite existing dataset instead of appending (Modal only)'
340
+ )
341
+
342
+ args = parser.parse_args()
343
+
344
+ # Validate config file exists
345
+ if not args.config.exists():
346
+ print(f"Error: Config file not found: {args.config}")
347
+ sys.exit(1)
348
+
349
+ # Run the appropriate command
350
+ try:
351
+ if args.command == 'build-relations':
352
+ build_relations(args.config)
353
+ elif args.command == 'generate-samples':
354
+ generate_samples(args.config, args.append)
355
+ elif args.command == 'validate':
356
+ validate_dataset(args.config)
357
+ elif args.command == 'export':
358
+ export_dataset(args.config)
359
+ elif args.command == 'full-pipeline':
360
+ full_pipeline(args.config, args.append)
361
+ elif args.command == 'modal-upload':
362
+ modal_upload(args.config)
363
+ elif args.command == 'modal-generate':
364
+ modal_generate(
365
+ args.config,
366
+ num_containers=args.num_containers,
367
+ skip_inventory=args.skip_inventory,
368
+ skip_relations=args.skip_relations,
369
+ fresh=args.fresh,
370
+ )
371
+ except Exception as e:
372
+ print(f"\nError: {e}")
373
+ sys.exit(1)
374
+
375
+
376
+ if __name__ == "__main__":
377
+ main()
dataset/scripts/export_training_data.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Export validated dataset to train/val/test splits.
3
+
4
+ Produces two task datasets from the same source samples:
5
+
6
+ 1. SQL generation (prompt = question + candidates CSV, completion = SQL)
7
+ 2. Place extraction (prompt = question only, completion = PlacesResult JSON)
8
+
9
+ Place extraction pairs are derived automatically: for each SQL sample the
10
+ selected_candidates give us the correct place names, subtypes, and country
11
+ codes that the extractor should return.
12
+
13
+ Output layout (all paths relative to dataset/):
14
+ output/runs/{run_name}/sql/train.jsonl
15
+ output/runs/{run_name}/sql/val.jsonl
16
+ output/runs/{run_name}/sql/test.jsonl
17
+ output/runs/{run_name}/places/train.jsonl
18
+ output/runs/{run_name}/places/val.jsonl
19
+ output/runs/{run_name}/places/test.jsonl
20
+ output/runs/{run_name}/stats.json
21
+ """
22
+
23
+ import json
24
+ import random
25
+ import sys
26
+ from collections import defaultdict
27
+ from pathlib import Path
28
+ from typing import Any, Dict, List, Optional, Tuple
29
+
30
+ import sqlparse
31
+ import yaml
32
+
33
+ from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
34
+
35
+
36
+ # ---------------------------------------------------------------------------
37
+ # Loading
38
+ # ---------------------------------------------------------------------------
39
+
40
+ def load_samples(jsonl_path: Path) -> List[Dict[str, Any]]:
41
+ samples = []
42
+ with open(jsonl_path) as f:
43
+ for line in f:
44
+ line = line.strip()
45
+ if line:
46
+ samples.append(json.loads(line))
47
+ return samples
48
+
49
+
50
+ def load_run_name(config_path: Optional[Path]) -> str:
51
+ if config_path and config_path.exists():
52
+ with open(config_path) as f:
53
+ cfg = yaml.safe_load(f)
54
+ return cfg.get("run_name", "default")
55
+ return "default"
56
+
57
+
58
+ # ---------------------------------------------------------------------------
59
+ # Splitting
60
+ # ---------------------------------------------------------------------------
61
+
62
+ def stratified_split(
63
+ samples: List[Dict[str, Any]],
64
+ train_ratio: float = 0.8,
65
+ val_ratio: float = 0.1,
66
+ seed: int = 42,
67
+ ) -> Tuple[List[Dict], List[Dict], List[Dict]]:
68
+ """Split stratified by task_family so every family is represented in each split."""
69
+ random.seed(seed)
70
+ by_family: Dict[str, List] = defaultdict(list)
71
+ for s in samples:
72
+ by_family[s["metadata"]["task_family"]].append(s)
73
+
74
+ train, val, test = [], [], []
75
+ for family_samples in by_family.values():
76
+ random.shuffle(family_samples)
77
+ n = len(family_samples)
78
+ n_train = int(n * train_ratio)
79
+ n_val = int(n * val_ratio)
80
+ train.extend(family_samples[:n_train])
81
+ val.extend(family_samples[n_train : n_train + n_val])
82
+ test.extend(family_samples[n_train + n_val :])
83
+
84
+ random.shuffle(train)
85
+ random.shuffle(val)
86
+ random.shuffle(test)
87
+ return train, val, test
88
+
89
+
90
+ # ---------------------------------------------------------------------------
91
+ # SQL generation format
92
+ # Conversational prompt-completion: model sees system + user, generates SQL.
93
+ # ---------------------------------------------------------------------------
94
+
95
+ _SQL_SYSTEM = """You are a text to SQL query translator that helps in natural language geocoding.
96
+
97
+ You have access to two DuckDB parquet tables. Given a set of candidate entities and a user query, generate the SQL to retrieve the desired geometry.
98
+
99
+ <SCHEMA>
100
+ 1. divisions_area -- Overture polygon/multipolygon admin boundaries
101
+ query: read_parquet('divisions_area')
102
+ columns:
103
+ id VARCHAR -- unique feature id
104
+ names STRUCT("primary" VARCHAR, ...)
105
+ country VARCHAR -- ISO 3166-1 alpha-2
106
+ subtype VARCHAR -- country | region | dependency | county | localadmin |
107
+ locality | macrohood | neighborhood | microhood
108
+ class VARCHAR
109
+ region VARCHAR
110
+ admin_level INTEGER
111
+ division_id VARCHAR
112
+ is_land BOOLEAN
113
+ is_territorial BOOLEAN
114
+ geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
115
+
116
+ 2. natural_earth -- Natural Earth geography polygons (oceans, seas, rivers, terrain)
117
+ query: read_parquet('natural_earth')
118
+ columns:
119
+ id VARCHAR -- unique feature id prefixed 'ne_'
120
+ names STRUCT("primary" VARCHAR, ...)
121
+ country VARCHAR
122
+ subtype VARCHAR -- e.g. 'ocean', 'sea', 'bay', 'Terrain area', 'Island group'
123
+ class VARCHAR
124
+ region VARCHAR
125
+ admin_level INTEGER
126
+ is_land BOOLEAN
127
+ is_territorial BOOLEAN
128
+ geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
129
+ </SCHEMA>
130
+
131
+ The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
132
+ Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly.
133
+ Use ST_AsGeoJSON(geometry) for all geometry outputs."""
134
+
135
+ _CANDIDATES_COLS = [
136
+ "source", "id", "name", "subtype", "country", "region",
137
+ "admin_level", "similarity",
138
+ ]
139
+
140
+
141
+ def _candidates_csv(candidates: List[Dict]) -> str:
142
+ import io
143
+ import csv
144
+ rows = []
145
+ for c in candidates:
146
+ row = {col: c.get(col, "") for col in _CANDIDATES_COLS if col in c}
147
+ rows.append(row)
148
+ if not rows:
149
+ return ""
150
+ buf = io.StringIO()
151
+ writer = csv.DictWriter(buf, fieldnames=[k for k in _CANDIDATES_COLS if k in rows[0]])
152
+ writer.writeheader()
153
+ writer.writerows(rows)
154
+ return buf.getvalue().strip()
155
+
156
+
157
+ def _to_symbolic_sql(sql: str) -> str:
158
+ """Normalize any hardcoded or runtime paths back to symbolic names."""
159
+ sql = sql.replace(DIVISIONS_AREA_PATH, "divisions_area")
160
+ sql = sql.replace(NATURAL_EARTH_PATH, "natural_earth")
161
+ sql = sql.replace("/data/overture/division_area/*.parquet", "divisions_area")
162
+ sql = sql.replace("/data/overture/divisions_area/*.parquet", "divisions_area")
163
+ sql = sql.replace("/data/natural_earth_geoparquet/ne_geography.parquet", "natural_earth")
164
+ return sql
165
+
166
+
167
+ def _format_sql(sql: str) -> str:
168
+ """Pretty-print SQL so the model learns clean, readable style."""
169
+ return sqlparse.format(
170
+ sql,
171
+ reindent=True,
172
+ keyword_case="upper",
173
+ indent_width=4,
174
+ ).strip()
175
+
176
+
177
+ def sample_to_sql_pair(sample: Dict[str, Any]) -> Optional[Dict]:
178
+ """Convert a raw sample to a conversational prompt-completion pair for SQL generation."""
179
+ sql = sample.get("target", {}).get("sql", "").strip()
180
+ if not sql:
181
+ return None
182
+ sql = _format_sql(_to_symbolic_sql(sql))
183
+
184
+ user_content = (
185
+ f"<CANDIDATES>\n{_candidates_csv(sample.get('candidates', []))}\n</CANDIDATES>\n\n"
186
+ f"<USER_QUERY>\n{sample['question']}\n</USER_QUERY>"
187
+ )
188
+
189
+ return {
190
+ "messages": [
191
+ {"role": "system", "content": _SQL_SYSTEM},
192
+ {"role": "user", "content": user_content},
193
+ {"role": "assistant", "content": sql},
194
+ ],
195
+ "metadata": sample.get("metadata", {}),
196
+ }
197
+
198
+
199
+ # ---------------------------------------------------------------------------
200
+ # Place extraction format
201
+ # Derived from the same SQL samples: selected_candidates → PlacesResult JSON.
202
+ # ---------------------------------------------------------------------------
203
+
204
+ _PLACE_SYSTEM = """You are a geographic entity extractor. Extract place names from the user query and return valid JSON only.
205
+
206
+ OUTPUT FORMAT:
207
+ {"places": [{"place": "<name>", "country": "<ISO-2>", "subtype": "<subtype>"}]}
208
+ "country" and "subtype" are optional; omit if not applicable.
209
+
210
+ RULES:
211
+ - Only extract places explicitly mentioned. Never infer or expand (e.g. "states of India" -> extract "India" only).
212
+ - No duplicate place names.
213
+ - "country": ISO 3166-1 alpha-2. Include only if explicitly mentioned or unambiguous.
214
+ - "subtype": include only when the geographic level is clear from the query.
215
+
216
+ SUBTYPES:
217
+ country, dependency, region, county, localadmin, locality, macrohood, neighborhood, microhood
218
+ - Default to locality for cities/towns; omit for physical features (oceans, rivers, mountains)."""
219
+
220
+ # Overture division subtypes — used to filter out natural_earth candidates
221
+ # from the place extraction output (NE features don't have these subtypes).
222
+ _DIVISION_SUBTYPES = {
223
+ "country", "region", "dependency", "county", "localadmin",
224
+ "locality", "macrohood", "neighborhood", "microhood",
225
+ }
226
+
227
+
228
+ def _candidate_to_place(c: Dict) -> Optional[Dict]:
229
+ """Convert a selected candidate to a Place dict for PlacesResult."""
230
+ name = c.get("name", "").strip()
231
+ if not name:
232
+ return None
233
+
234
+ place: Dict[str, Any] = {"place": name}
235
+
236
+ subtype = c.get("subtype", "")
237
+ if subtype in _DIVISION_SUBTYPES:
238
+ place["subtype"] = subtype
239
+
240
+ country = c.get("country", "")
241
+ if country and len(country) == 2:
242
+ place["country"] = country
243
+
244
+ return place
245
+
246
+
247
+ def sample_to_place_pair(sample: Dict[str, Any]) -> Optional[Dict]:
248
+ """Convert a raw sample to a conversational prompt-completion pair for place extraction.
249
+
250
+ Uses selected_candidates to determine the correct PlacesResult output.
251
+ Skips samples where no valid places can be derived.
252
+ """
253
+ selected_ids = set(sample.get("target", {}).get("selected_candidates", []))
254
+ if not selected_ids:
255
+ return None
256
+
257
+ id_to_candidate = {c["candidate_id"]: c for c in sample.get("candidates", [])}
258
+ places = []
259
+ seen_names: set = set()
260
+
261
+ for cid in selected_ids:
262
+ c = id_to_candidate.get(cid)
263
+ if not c:
264
+ continue
265
+ place = _candidate_to_place(c)
266
+ if place and place["place"].lower() not in seen_names:
267
+ places.append(place)
268
+ seen_names.add(place["place"].lower())
269
+
270
+ if not places:
271
+ return None
272
+
273
+ completion_json = json.dumps({"places": places}, ensure_ascii=False)
274
+
275
+ return {
276
+ "messages": [
277
+ {"role": "system", "content": _PLACE_SYSTEM},
278
+ {"role": "user", "content": sample["question"]},
279
+ {"role": "assistant", "content": completion_json},
280
+ ],
281
+ "metadata": sample.get("metadata", {}),
282
+ }
283
+
284
+
285
+ # ---------------------------------------------------------------------------
286
+ # I/O helpers
287
+ # ---------------------------------------------------------------------------
288
+
289
+ def save_jsonl(records: List[Dict], path: Path) -> None:
290
+ path.parent.mkdir(parents=True, exist_ok=True)
291
+ with open(path, "w") as f:
292
+ for r in records:
293
+ f.write(json.dumps(r, ensure_ascii=False) + "\n")
294
+
295
+
296
+ def split_stats(samples: List[Dict]) -> Dict[str, int]:
297
+ counts: Dict[str, int] = defaultdict(int)
298
+ for s in samples:
299
+ counts[s.get("metadata", {}).get("task_family", "unknown")] += 1
300
+ return dict(sorted(counts.items()))
301
+
302
+
303
+ # ---------------------------------------------------------------------------
304
+ # Main
305
+ # ---------------------------------------------------------------------------
306
+
307
+ def main(config_path: Optional[Path] = None) -> None:
308
+ script_dir = Path(__file__).parent
309
+ dataset_dir = script_dir.parent
310
+ output_dir = dataset_dir / "output"
311
+
312
+ run_name = load_run_name(config_path or dataset_dir / "config.yaml")
313
+
314
+ validated_file = output_dir / "dataset_validated.jsonl"
315
+ if not validated_file.exists():
316
+ print(f"Error: {validated_file} not found. Run validate first.")
317
+ sys.exit(1)
318
+
319
+ run_dir = output_dir / "runs" / run_name
320
+ sql_dir = run_dir / "sql"
321
+ places_dir = run_dir / "places"
322
+
323
+ print(f"Run name : {run_name}")
324
+ print(f"Output dir : {run_dir}")
325
+
326
+ # Load
327
+ print("\nLoading validated samples...")
328
+ samples = load_samples(validated_file)
329
+ print(f" {len(samples):,} samples loaded")
330
+
331
+ # Split once, reuse for both tasks
332
+ print("\nSplitting 80 / 10 / 10 (stratified by task family)...")
333
+ train_raw, val_raw, test_raw = stratified_split(samples)
334
+ print(f" train={len(train_raw):,} val={len(val_raw):,} test={len(test_raw):,}")
335
+
336
+ # --- SQL generation ---
337
+ print("\nBuilding SQL generation splits...")
338
+ sql_stats: Dict = {}
339
+ for split_name, raw in [("train", train_raw), ("val", val_raw), ("test", test_raw)]:
340
+ pairs = [p for s in raw if (p := sample_to_sql_pair(s)) is not None]
341
+ save_jsonl(pairs, sql_dir / f"{split_name}.jsonl")
342
+ sql_stats[split_name] = {"total": len(pairs), "by_family": split_stats(pairs)}
343
+ print(f" sql/{split_name}.jsonl — {len(pairs):,} pairs")
344
+
345
+ # --- Place extraction ---
346
+ print("\nBuilding place extraction splits...")
347
+ place_stats: Dict = {}
348
+ for split_name, raw in [("train", train_raw), ("val", val_raw), ("test", test_raw)]:
349
+ pairs = [p for s in raw if (p := sample_to_place_pair(s)) is not None]
350
+ save_jsonl(pairs, places_dir / f"{split_name}.jsonl")
351
+ place_stats[split_name] = {"total": len(pairs), "by_family": split_stats(pairs)}
352
+ print(f" places/{split_name}.jsonl — {len(pairs):,} pairs")
353
+
354
+ # --- Stats ---
355
+ stats = {
356
+ "run_name": run_name,
357
+ "total_samples": len(samples),
358
+ "sql_generation": sql_stats,
359
+ "place_extraction": place_stats,
360
+ }
361
+ stats_path = run_dir / "stats.json"
362
+ with open(stats_path, "w") as f:
363
+ json.dump(stats, f, indent=2)
364
+
365
+ print(f"\nStats written to {stats_path}")
366
+ print("\nDone. Training-ready files:")
367
+ print(f" SQL generation : {sql_dir}/{{train,val,test}}.jsonl")
368
+ print(f" Place extraction: {places_dir}/{{train,val,test}}.jsonl")
369
+
370
+
371
+ if __name__ == "__main__":
372
+ main()
dataset/scripts/generate_samples.py ADDED
@@ -0,0 +1,1560 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generate synthetic training samples for text-to-SQL task.
3
+
4
+ This script:
5
+ 1. Loads relation tables and entity inventories
6
+ 2. For each SQL template, samples valid anchors
7
+ 3. Renders and executes SQL to verify it works
8
+ 4. Builds candidate lists with controlled distractors
9
+ 5. Generates natural language questions using LLM
10
+ 6. Saves complete training samples
11
+
12
+ Output:
13
+ - output/samples/sample_*.json (individual samples)
14
+ - output/dataset_raw.jsonl (all samples)
15
+ """
16
+
17
+ import json
18
+ import random
19
+ import warnings
20
+ from pathlib import Path
21
+ from typing import List, Dict, Any, Optional
22
+ from concurrent.futures import ProcessPoolExecutor, as_completed
23
+ from functools import partial
24
+
25
+ import duckdb
26
+ import pandas as pd
27
+ from pydantic import BaseModel
28
+
29
+ # Suppress warnings
30
+ warnings.filterwarnings('ignore')
31
+
32
+ from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
33
+
34
+ # Fixed paths embedded in every training SQL string.
35
+ # The model learns these short, stable strings rather than machine-specific
36
+ # local paths. At inference, sql.py's _rewrite_data_paths substitutes them
37
+ # with the actual runtime paths from gazet.config.
38
+ _DIVISIONS_SQL_PATH = 'divisions_area'
39
+ _NATURAL_EARTH_SQL_PATH = 'natural_earth'
40
+
41
+
42
+ def _for_execution(sql: str) -> str:
43
+ """Replace symbolic placeholder paths with actual local paths for verification."""
44
+ return (
45
+ sql
46
+ .replace("read_parquet('divisions_area')", f"read_parquet('{DIVISIONS_AREA_PATH}')")
47
+ .replace("read_parquet('natural_earth')", f"read_parquet('{NATURAL_EARTH_PATH}')")
48
+ )
49
+
50
+ # Configurable parameters (can be overridden by CLI)
51
+ TARGET_COUNTS = None # Will be set in main() or by CLI
52
+ MAX_WORKERS = 8
53
+ RETRY_MULTIPLIER = 2
54
+ APPEND_MODE = False
55
+
56
+ # Import templates from same directory
57
+ from . import sql_templates
58
+ TEMPLATES = sql_templates.TEMPLATES
59
+ SQLTemplate = sql_templates.SQLTemplate
60
+ get_templates_by_family = sql_templates.get_templates_by_family
61
+
62
+
63
+ class Candidate(BaseModel):
64
+ """Candidate entity for grounding."""
65
+ candidate_id: str
66
+ source: str
67
+ id: str
68
+ name: str
69
+ subtype: Optional[str] = None
70
+ country: Optional[str] = None
71
+ region: Optional[str] = None
72
+ admin_level: Optional[int] = None
73
+ similarity: float = 0.0
74
+
75
+
76
+ class TrainingSample(BaseModel):
77
+ """Complete training sample."""
78
+ id: str
79
+ question: str
80
+ candidates: List[Candidate]
81
+ target: Dict[str, Any]
82
+ metadata: Dict[str, Any]
83
+
84
+
85
+ def load_relation_tables(intermediate_dir: Path, quiet: bool = False) -> Dict[str, pd.DataFrame]:
86
+ """Load all precomputed relation tables."""
87
+ tables = {}
88
+
89
+ for file in intermediate_dir.glob("*.parquet"):
90
+ name = file.stem
91
+ tables[name] = pd.read_parquet(file)
92
+ if not quiet:
93
+ print(f" {name}: {len(tables[name])} rows")
94
+
95
+ return tables
96
+
97
+
98
+ def sample_adjacency_anchor(adjacency_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
99
+ """Sample a random adjacency pair."""
100
+ if adjacency_df.empty:
101
+ return None
102
+
103
+ row = adjacency_df.sample(n=1).iloc[0]
104
+ return {
105
+ 'anchor_id': row['anchor_id'],
106
+ 'anchor_name': row['anchor_name'],
107
+ 'anchor_subtype': row['anchor_subtype'],
108
+ 'anchor_country': row.get('anchor_country'), # May not exist in all tables
109
+ 'target_subtype': row.get('target_subtype')
110
+ }
111
+
112
+
113
+ def sample_intersection_anchor(intersection_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
114
+ """Sample a random intersection pair."""
115
+ if intersection_df.empty:
116
+ return None
117
+
118
+ row = intersection_df.sample(n=1).iloc[0]
119
+ return {
120
+ 'anchor_id': row['anchor_id'],
121
+ 'anchor_name': row['anchor_name'],
122
+ 'anchor_subtype': row['anchor_subtype'],
123
+ 'target_id': row.get('target_id'),
124
+ 'target_name': row.get('target_name'),
125
+ 'target_subtype': row.get('target_subtype')
126
+ }
127
+
128
+
129
+ def sample_containment_anchor(containment_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
130
+ """Sample a random containment pair."""
131
+ if containment_df.empty:
132
+ return None
133
+
134
+ row = containment_df.sample(n=1).iloc[0]
135
+ return {
136
+ 'container_id': row['container_id'],
137
+ 'container_name': row['container_name'],
138
+ 'container_subtype': row['container_subtype'],
139
+ 'contained_subtype': row['contained_subtype']
140
+ }
141
+
142
+
143
+ def sample_cross_source_anchor(cross_source_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
144
+ """Sample a random cross-source relation."""
145
+ if cross_source_df.empty:
146
+ return None
147
+
148
+ row = cross_source_df.sample(n=1).iloc[0]
149
+ return {
150
+ 'division_id': row['division_id'],
151
+ 'division_name': row['division_name'],
152
+ 'division_subtype': row['division_subtype'],
153
+ 'natural_id': row['natural_id'],
154
+ 'natural_name': row['natural_name'],
155
+ 'natural_subtype': row['natural_subtype'],
156
+ 'relation_type': row['relation_type']
157
+ }
158
+
159
+
160
+ def _merge_candidate_lists(
161
+ *lists: List[Candidate],
162
+ max_total: int = 10,
163
+ ) -> List[Candidate]:
164
+ """Merge N candidate lists, deduplicate by id, reassign candidate_ids.
165
+
166
+ Interleaves the lists so each anchor is represented before any anchor
167
+ gets a second candidate — matching the grouped-then-interleaved order
168
+ that inference produces.
169
+ """
170
+ from itertools import zip_longest
171
+
172
+ seen: set = set()
173
+ merged: List[Candidate] = []
174
+ for row in zip_longest(*lists):
175
+ for c in row:
176
+ if c is None:
177
+ continue
178
+ if c.id not in seen:
179
+ merged.append(c)
180
+ seen.add(c.id)
181
+ if len(merged) >= max_total:
182
+ break
183
+ if len(merged) >= max_total:
184
+ break
185
+ for i, c in enumerate(merged, 1):
186
+ c.candidate_id = f"c{i}"
187
+ return merged
188
+
189
+
190
+ def build_candidate_list(
191
+ con: duckdb.DuckDBPyConnection,
192
+ anchor_id: str,
193
+ anchor_name: str,
194
+ anchor_source: str,
195
+ num_candidates: int = 10,
196
+ difficulty: str = "medium"
197
+ ) -> List[Candidate]:
198
+ """Build candidate list with true anchor + distractors."""
199
+
200
+ # Helper to convert pandas NA to None
201
+ def safe_get(row, key, default=None):
202
+ val = row.get(key, default)
203
+ return None if pd.isna(val) else val
204
+
205
+ # Get the true anchor
206
+ if anchor_source == "divisions_area":
207
+ query = """
208
+ SELECT
209
+ id,
210
+ names."primary" AS name,
211
+ subtype,
212
+ country,
213
+ region,
214
+ admin_level
215
+ FROM read_parquet(?)
216
+ WHERE id = ?
217
+ """
218
+ anchor_row = con.execute(query, [DIVISIONS_AREA_PATH, anchor_id]).fetchdf().iloc[0]
219
+ else:
220
+ query = """
221
+ SELECT
222
+ id,
223
+ names."primary" AS name,
224
+ subtype
225
+ FROM read_parquet(?)
226
+ WHERE id = ?
227
+ """
228
+ anchor_row = con.execute(query, [NATURAL_EARTH_PATH, anchor_id]).fetchdf().iloc[0]
229
+
230
+ # Build true candidate
231
+ true_candidate = Candidate(
232
+ candidate_id="c1",
233
+ source=anchor_source,
234
+ id=anchor_id,
235
+ name=safe_get(anchor_row, 'name'),
236
+ subtype=safe_get(anchor_row, 'subtype'),
237
+ country=safe_get(anchor_row, 'country'),
238
+ region=safe_get(anchor_row, 'region'),
239
+ admin_level=safe_get(anchor_row, 'admin_level'),
240
+ similarity=1.0
241
+ )
242
+
243
+ # Build distractors based on difficulty
244
+ distractors = build_distractors(
245
+ con,
246
+ anchor_name,
247
+ anchor_source,
248
+ anchor_id,
249
+ num_candidates - 1,
250
+ difficulty
251
+ )
252
+
253
+ # Order: true anchor first, then same-source distractors, then cross-source
254
+ # distractors. This mirrors inference order (anchor at top by similarity,
255
+ # same source grouped before the other source).
256
+ candidates = [true_candidate] + distractors
257
+
258
+ # Reassign candidate IDs in order
259
+ for i, cand in enumerate(candidates, 1):
260
+ cand.candidate_id = f"c{i}"
261
+
262
+ return candidates
263
+
264
+
265
+ def build_distractors(
266
+ con: duckdb.DuckDBPyConnection,
267
+ anchor_name: str,
268
+ anchor_source: str,
269
+ exclude_id: str,
270
+ num_distractors: int,
271
+ difficulty: str,
272
+ cross_source_ratio: float = 0.5,
273
+ ) -> List[Candidate]:
274
+ """Build distractor candidates using fuzzy search.
275
+
276
+ Always includes candidates from both sources so the model sees mixed
277
+ ``source`` values in every training example — matching the inference
278
+ behaviour where search.py queries divisions_area AND natural_earth equally
279
+ (5 results each per place).
280
+
281
+ Args:
282
+ cross_source_ratio: Fraction of distractors drawn from the *other*
283
+ source. Defaults to 0.5 (50/50 split) to match inference exactly.
284
+ """
285
+
286
+ def safe_get(row, key, default=None):
287
+ val = row.get(key, default)
288
+ return None if pd.isna(val) else val
289
+
290
+ def _query_source(path: str, src_name: str, n: int, excl_id: str) -> List[Candidate]:
291
+ query = """
292
+ SELECT
293
+ id,
294
+ names."primary" AS name,
295
+ subtype,
296
+ country,
297
+ region,
298
+ admin_level,
299
+ jaro_winkler_similarity(lower(names."primary"), lower(?)) AS similarity
300
+ FROM read_parquet(?)
301
+ WHERE id != ?
302
+ AND names."primary" IS NOT NULL
303
+ ORDER BY similarity DESC
304
+ LIMIT ?
305
+ """
306
+ df = con.execute(query, [anchor_name, path, excl_id, n]).fetchdf()
307
+ results = []
308
+ for _, row in df.iterrows():
309
+ results.append(Candidate(
310
+ candidate_id="temp",
311
+ source=src_name,
312
+ id=row["id"],
313
+ name=safe_get(row, "name"),
314
+ subtype=safe_get(row, "subtype"),
315
+ country=safe_get(row, "country"),
316
+ region=safe_get(row, "region"),
317
+ admin_level=safe_get(row, "admin_level"),
318
+ similarity=float(row["similarity"]),
319
+ ))
320
+ return results
321
+
322
+ cross_n = max(1, round(num_distractors * cross_source_ratio))
323
+ same_n = num_distractors - cross_n
324
+
325
+ if anchor_source == "divisions_area":
326
+ same = _query_source(DIVISIONS_AREA_PATH, "divisions_area", same_n, exclude_id)
327
+ cross = _query_source(NATURAL_EARTH_PATH, "natural_earth", cross_n, "")
328
+ else:
329
+ same = _query_source(NATURAL_EARTH_PATH, "natural_earth", same_n, exclude_id)
330
+ cross = _query_source(DIVISIONS_AREA_PATH, "divisions_area", cross_n, "")
331
+
332
+ return same + cross
333
+
334
+
335
+ def generate_adjacency_sample(
336
+ con: duckdb.DuckDBPyConnection,
337
+ adjacency_df: pd.DataFrame,
338
+ sample_id: str
339
+ ) -> Optional[TrainingSample]:
340
+ """Generate a sample for adjacency task."""
341
+
342
+ anchor = sample_adjacency_anchor(adjacency_df)
343
+ if not anchor:
344
+ return None
345
+
346
+ # Build SQL
347
+ sql = f"""WITH a AS (
348
+ SELECT geometry FROM read_parquet('divisions_area')
349
+ WHERE id = '{anchor['anchor_id']}'
350
+ )
351
+ SELECT b.id, b.names."primary" AS name, b.geometry
352
+ FROM read_parquet('divisions_area') AS b, a
353
+ WHERE b.id != '{anchor['anchor_id']}'
354
+ AND b.subtype = '{anchor['target_subtype']}'
355
+ AND ST_Touches(a.geometry, b.geometry)"""
356
+
357
+ # Execute to verify
358
+ try:
359
+ result = con.execute(_for_execution(sql)).fetchdf()
360
+ if result.empty:
361
+ return None
362
+ except Exception as e:
363
+ print(f"SQL execution failed: {e}")
364
+ return None
365
+
366
+ # Build candidates
367
+ candidates = build_candidate_list(
368
+ con,
369
+ anchor['anchor_id'],
370
+ anchor['anchor_name'],
371
+ "divisions_area",
372
+ num_candidates=10,
373
+ difficulty="medium"
374
+ )
375
+
376
+ # Find which candidate is the true anchor
377
+ selected_candidate_ids = [c.candidate_id for c in candidates if c.id == anchor['anchor_id']]
378
+
379
+ # Generate question
380
+ question = f"Which {anchor['target_subtype']}s border {anchor['anchor_name']}?"
381
+
382
+ return TrainingSample(
383
+ id=sample_id,
384
+ question=question,
385
+ candidates=candidates,
386
+ target={
387
+ "selected_candidates": selected_candidate_ids,
388
+ "sql": sql
389
+ },
390
+ metadata={
391
+ "task_family": "adjacency",
392
+ "sql_difficulty": "medium",
393
+ "grounding_difficulty": "medium",
394
+ "template_id": "adj_02",
395
+ "num_candidates": len(candidates),
396
+ "anchor_source": "divisions_area",
397
+ "sql_verified": True
398
+ }
399
+ )
400
+
401
+
402
+ def generate_containment_sample(
403
+ con: duckdb.DuckDBPyConnection,
404
+ containment_df: pd.DataFrame,
405
+ sample_id: str
406
+ ) -> Optional[TrainingSample]:
407
+ """Generate a sample for containment task."""
408
+
409
+ anchor = sample_containment_anchor(containment_df)
410
+ if not anchor:
411
+ return None
412
+
413
+ # Build SQL
414
+ sql = f"""WITH a AS (
415
+ SELECT geometry FROM read_parquet('divisions_area')
416
+ WHERE id = '{anchor['container_id']}'
417
+ )
418
+ SELECT b.id, b.names."primary" AS name, b.geometry
419
+ FROM read_parquet('divisions_area') AS b, a
420
+ WHERE b.id != '{anchor['container_id']}'
421
+ AND b.subtype = '{anchor['contained_subtype']}'
422
+ AND ST_Within(b.geometry, a.geometry)"""
423
+
424
+ # Execute to verify
425
+ try:
426
+ result = con.execute(_for_execution(sql)).fetchdf()
427
+ if result.empty:
428
+ return None
429
+ except Exception as e:
430
+ print(f"SQL execution failed: {e}")
431
+ return None
432
+
433
+ # Build candidates
434
+ candidates = build_candidate_list(
435
+ con,
436
+ anchor['container_id'],
437
+ anchor['container_name'],
438
+ "divisions_area",
439
+ num_candidates=10,
440
+ difficulty="medium"
441
+ )
442
+
443
+ # Find which candidate is the true anchor
444
+ selected_candidate_ids = [c.candidate_id for c in candidates if c.id == anchor['container_id']]
445
+
446
+ # Generate question
447
+ question = f"What {anchor['contained_subtype']}s are in {anchor['container_name']}?"
448
+
449
+ return TrainingSample(
450
+ id=sample_id,
451
+ question=question,
452
+ candidates=candidates,
453
+ target={
454
+ "selected_candidates": selected_candidate_ids,
455
+ "sql": sql
456
+ },
457
+ metadata={
458
+ "task_family": "containment",
459
+ "sql_difficulty": "medium",
460
+ "grounding_difficulty": "medium",
461
+ "template_id": "contain_01",
462
+ "num_candidates": len(candidates),
463
+ "anchor_source": "divisions_area",
464
+ "sql_verified": True
465
+ }
466
+ )
467
+
468
+
469
+ def sample_random_entity(
470
+ con: duckdb.DuckDBPyConnection,
471
+ inventory_df: pd.DataFrame,
472
+ source: str
473
+ ) -> Optional[Dict[str, Any]]:
474
+ """Sample a random entity from inventory."""
475
+ if inventory_df.empty:
476
+ return None
477
+
478
+ row = inventory_df.sample(n=1).iloc[0]
479
+ return {
480
+ 'id': row['id'],
481
+ 'name': row['name'],
482
+ 'subtype': row.get('subtype'),
483
+ 'country': row.get('country'),
484
+ 'source': source
485
+ }
486
+
487
+
488
+ def generate_template_based_sample(
489
+ con: duckdb.DuckDBPyConnection,
490
+ template: SQLTemplate,
491
+ tables: Dict[str, pd.DataFrame],
492
+ sample_id: str
493
+ ) -> Optional[TrainingSample]:
494
+ """Generate a sample based on a SQL template."""
495
+
496
+ # Sample anchor based on template requirements
497
+ if template.family == "direct_lookup":
498
+ # Just pick a random entity
499
+ if template.anchor_source == "divisions_area":
500
+ anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
501
+ else:
502
+ anchor = sample_random_entity(con, tables['natural_earth_inventory'], 'natural_earth')
503
+
504
+ if not anchor:
505
+ return None
506
+
507
+ # Render SQL
508
+ sql = template.sql_template.format(
509
+ anchor_id=anchor['id']
510
+ )
511
+
512
+ # Build candidates
513
+ candidates = build_candidate_list(
514
+ con, anchor['id'], anchor['name'], anchor['source'],
515
+ num_candidates=10, difficulty="easy"
516
+ )
517
+
518
+ # Question
519
+ question = random.choice(template.question_hints).format(anchor_name=anchor['name'])
520
+
521
+ elif template.family == "adjacency":
522
+ anchor = sample_adjacency_anchor(tables['adjacency_pairs'])
523
+ if not anchor:
524
+ return None
525
+
526
+ sql = template.sql_template.format(
527
+ anchor_id=anchor['anchor_id'],
528
+ target_subtype=anchor['target_subtype']
529
+ )
530
+
531
+ candidates = build_candidate_list(
532
+ con, anchor['anchor_id'], anchor['anchor_name'], 'divisions_area',
533
+ num_candidates=10, difficulty="medium"
534
+ )
535
+
536
+ question = random.choice(template.question_hints).format(
537
+ anchor_name=anchor['anchor_name'],
538
+ target_subtype=anchor['target_subtype']
539
+ )
540
+
541
+ elif template.family == "containment":
542
+ anchor = sample_containment_anchor(tables['containment_pairs'])
543
+ if not anchor:
544
+ return None
545
+
546
+ sql = template.sql_template.format(
547
+ anchor_id=anchor['container_id'],
548
+ target_subtype=anchor['contained_subtype']
549
+ )
550
+
551
+ candidates = build_candidate_list(
552
+ con, anchor['container_id'], anchor['container_name'], 'divisions_area',
553
+ num_candidates=10, difficulty="medium"
554
+ )
555
+
556
+ question = random.choice(template.question_hints).format(
557
+ anchor_name=anchor['container_name'],
558
+ target_subtype=anchor['contained_subtype']
559
+ )
560
+
561
+ elif template.family == "intersection":
562
+ if template.anchor_source == "natural_earth":
563
+ anchor = sample_cross_source_anchor(tables['cross_source_relations'])
564
+ if not anchor:
565
+ return None
566
+
567
+ sql = template.sql_template.format(
568
+ anchor_id=anchor['natural_id'],
569
+ target_subtype='country'
570
+ )
571
+
572
+ candidates = build_candidate_list(
573
+ con, anchor['natural_id'], anchor['natural_name'], 'natural_earth',
574
+ num_candidates=10, difficulty="medium"
575
+ )
576
+
577
+ question = random.choice(template.question_hints).format(
578
+ anchor_name=anchor['natural_name'],
579
+ target_subtype='country'
580
+ )
581
+ else:
582
+ # Same-source intersection
583
+ anchor = sample_intersection_anchor(tables['intersection_pairs'])
584
+ if not anchor:
585
+ return None
586
+
587
+ # Use a generic subtype if not available
588
+ target_subtype = anchor.get('target_subtype') or 'region'
589
+
590
+ sql = template.sql_template.format(
591
+ anchor_id=anchor['anchor_id'],
592
+ target_subtype=target_subtype
593
+ )
594
+
595
+ candidates = build_candidate_list(
596
+ con, anchor['anchor_id'], anchor['anchor_name'], 'divisions_area',
597
+ num_candidates=10, difficulty="medium"
598
+ )
599
+
600
+ question = random.choice(template.question_hints).format(
601
+ anchor_name=anchor['anchor_name'],
602
+ target_subtype=target_subtype
603
+ )
604
+
605
+ elif template.family == "set_operations":
606
+ if template.template_id == "union_03":
607
+ # 3-anchor union by ID — candidates: 3 per anchor (9 total)
608
+ anchors = [
609
+ sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
610
+ for _ in range(3)
611
+ ]
612
+ if any(a is None for a in anchors):
613
+ return None
614
+ anchor1, anchor2, anchor3 = anchors
615
+
616
+ sql = template.sql_template.format(
617
+ anchor_id_1=anchor1['id'],
618
+ anchor_id_2=anchor2['id'],
619
+ anchor_id_3=anchor3['id'],
620
+ )
621
+
622
+ per_anchor = 3
623
+ cands = [
624
+ build_candidate_list(con, a['id'], a['name'], 'divisions_area',
625
+ num_candidates=per_anchor, difficulty="medium")
626
+ for a in anchors
627
+ ]
628
+ candidates = _merge_candidate_lists(*cands, max_total=9)
629
+
630
+ question = random.choice(template.question_hints).format(
631
+ anchor_1_name=anchor1['name'],
632
+ anchor_2_name=anchor2['name'],
633
+ anchor_3_name=anchor3['name'],
634
+ )
635
+
636
+ elif template.template_id in ("contain_multi_01", "contain_multi_02"):
637
+ # country IN clause — 2 or 3 anchors, each contributes its country code
638
+ num_a = 3 if template.template_id == "contain_multi_02" else 2
639
+ anchors = [
640
+ sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
641
+ for _ in range(num_a)
642
+ ]
643
+ if any(a is None for a in anchors):
644
+ return None
645
+
646
+ countries = [a.get('country') or 'US' for a in anchors]
647
+ target_subtype = random.choice(['region', 'locality'])
648
+ per_anchor = 3 if num_a == 3 else 4
649
+
650
+ fmt_kwargs = dict(
651
+ target_subtype=target_subtype,
652
+ )
653
+ for i, c in enumerate(countries, 1):
654
+ fmt_kwargs[f'country_{i}'] = c
655
+
656
+ sql = template.sql_template.format(**fmt_kwargs)
657
+
658
+ cands = [
659
+ build_candidate_list(con, a['id'], a['name'], 'divisions_area',
660
+ num_candidates=per_anchor, difficulty="medium")
661
+ for a in anchors
662
+ ]
663
+ candidates = _merge_candidate_lists(*cands, max_total=num_a * per_anchor)
664
+
665
+ q_kwargs = dict(target_subtype=target_subtype)
666
+ for i, a in enumerate(anchors, 1):
667
+ q_kwargs[f'anchor_{i}_name'] = a['name']
668
+
669
+ question = random.choice(template.question_hints).format(**q_kwargs)
670
+
671
+ elif template.template_id == "union_02":
672
+ # Filtered union: ST_Union_Agg of contained sub-features
673
+ pair = sample_containment_anchor(tables['containment_pairs'])
674
+ if not pair:
675
+ return None
676
+
677
+ target_subtype = pair.get('contained_subtype', 'locality')
678
+ sql = template.sql_template.format(
679
+ anchor_id=pair['container_id'],
680
+ target_subtype=target_subtype,
681
+ )
682
+
683
+ candidates = build_candidate_list(
684
+ con, pair['container_id'], pair['container_name'], 'divisions_area',
685
+ num_candidates=10, difficulty="medium"
686
+ )
687
+
688
+ question = random.choice(template.question_hints).format(
689
+ anchor_name=pair['container_name'],
690
+ target_subtype=target_subtype,
691
+ )
692
+
693
+ else:
694
+ # union_01: 2-anchor union by ID — candidates: 5 per anchor
695
+ anchor1 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
696
+ anchor2 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
697
+ if not anchor1 or not anchor2:
698
+ return None
699
+
700
+ sql = template.sql_template.format(
701
+ anchor_id_1=anchor1['id'],
702
+ anchor_id_2=anchor2['id'],
703
+ )
704
+
705
+ cands1 = build_candidate_list(
706
+ con, anchor1['id'], anchor1['name'], 'divisions_area',
707
+ num_candidates=5, difficulty="medium"
708
+ )
709
+ cands2 = build_candidate_list(
710
+ con, anchor2['id'], anchor2['name'], 'divisions_area',
711
+ num_candidates=5, difficulty="medium"
712
+ )
713
+ candidates = _merge_candidate_lists(cands1, cands2, max_total=10)
714
+
715
+ question = random.choice(template.question_hints).format(
716
+ anchor_1_name=anchor1['name'],
717
+ anchor_2_name=anchor2['name'],
718
+ )
719
+
720
+ elif template.family == "buffer":
721
+ # Buffer operations
722
+ # Kilometre distances used by buffer_01 and buffer_03 templates.
723
+ # Metre distances used by buffer_02 and buffer_04 templates.
724
+ # The template SQL divides by 111 320 to convert to degrees.
725
+ _buffer_km_choices = [1, 2, 5, 10, 25, 50, 100, 200]
726
+ _buffer_m_choices = [100, 250, 500, 1000, 2000, 5000]
727
+
728
+ if template.num_anchors == 1:
729
+ if template.anchor_source == "natural_earth":
730
+ anchor = sample_random_entity(con, tables['natural_earth_inventory'], 'natural_earth')
731
+ else:
732
+ anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
733
+ if not anchor:
734
+ return None
735
+
736
+ # Choose unit based on which placeholder the template uses.
737
+ uses_km = "{buffer_km}" in template.sql_template
738
+ if uses_km:
739
+ buffer_val = random.choice(_buffer_km_choices)
740
+ fmt_kwargs = dict(
741
+ anchor_id=anchor['id'],
742
+ buffer_km=buffer_val,
743
+ )
744
+ q_kwargs = dict(anchor_name=anchor['name'], buffer_km=buffer_val)
745
+ else:
746
+ buffer_val = random.choice(_buffer_m_choices)
747
+ fmt_kwargs = dict(
748
+ anchor_id=anchor['id'],
749
+ buffer_m=buffer_val,
750
+ )
751
+ q_kwargs = dict(anchor_name=anchor['name'], buffer_m=buffer_val)
752
+
753
+ sql = template.sql_template.format(**fmt_kwargs)
754
+
755
+ candidates = build_candidate_list(
756
+ con, anchor['id'], anchor['name'], anchor['source'],
757
+ num_candidates=10, difficulty="medium"
758
+ )
759
+
760
+ question = random.choice(template.question_hints).format(**q_kwargs)
761
+ else:
762
+ # Two anchor buffer (union / set-op style) — kept for completeness.
763
+ anchor1 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
764
+ anchor2 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
765
+
766
+ if not anchor1 or not anchor2:
767
+ return None
768
+
769
+ buffer_val = random.choice(_buffer_km_choices[:4]) # smaller range for two-anchor
770
+ sql = template.sql_template.format(
771
+ anchor_id_1=anchor1['id'],
772
+ anchor_id_2=anchor2['id'],
773
+ buffer_km=buffer_val,
774
+ )
775
+
776
+ candidates1 = build_candidate_list(
777
+ con, anchor1['id'], anchor1['name'], 'divisions_area',
778
+ num_candidates=5, difficulty="medium"
779
+ )
780
+ candidates2 = build_candidate_list(
781
+ con, anchor2['id'], anchor2['name'], 'divisions_area',
782
+ num_candidates=5, difficulty="medium"
783
+ )
784
+
785
+ candidates = candidates1 + candidates2
786
+ seen_ids = set()
787
+ unique_candidates = []
788
+ for c in candidates:
789
+ if c.id not in seen_ids:
790
+ unique_candidates.append(c)
791
+ seen_ids.add(c.id)
792
+ candidates = unique_candidates[:10]
793
+
794
+ for i, c in enumerate(candidates, 1):
795
+ c.candidate_id = f"c{i}"
796
+
797
+ question = random.choice(template.question_hints).format(
798
+ anchor_1_name=anchor1['name'],
799
+ anchor_2_name=anchor2['name'],
800
+ buffer_km=buffer_val,
801
+ )
802
+
803
+ elif template.family == "partial_selection":
804
+ # Partial selection (northern half, clipping, etc.)
805
+ anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
806
+ if not anchor:
807
+ return None
808
+
809
+ if template.num_anchors == 1:
810
+ sql = template.sql_template.format(
811
+ anchor_id=anchor['id'],
812
+ )
813
+ question = random.choice(template.question_hints).format(
814
+ anchor_name=anchor['name'],
815
+ )
816
+ candidates = build_candidate_list(
817
+ con, anchor['id'], anchor['name'], 'divisions_area',
818
+ num_candidates=10, difficulty="hard",
819
+ )
820
+ else:
821
+ # Mixed-source clip: division intersected with a natural_earth feature.
822
+ # Use cross_source_relations so the pair is guaranteed to intersect —
823
+ # random sampling almost never produces an intersecting pair.
824
+ cs_df = tables.get('cross_source_relations', pd.DataFrame())
825
+ if cs_df.empty:
826
+ return None
827
+ row = cs_df.sample(n=1).iloc[0]
828
+ clip_feature = {
829
+ 'id': row['natural_id'],
830
+ 'name': row['natural_name'],
831
+ 'source': 'natural_earth',
832
+ }
833
+ # Override the division anchor with the paired division so the
834
+ # ST_Intersects check in the SQL is guaranteed to pass.
835
+ anchor = {
836
+ 'id': row['division_id'],
837
+ 'name': row['division_name'],
838
+ 'source': 'divisions_area',
839
+ }
840
+
841
+ sql = template.sql_template.format(
842
+ anchor_id=anchor['id'],
843
+ clip_feature_id=clip_feature['id'],
844
+ )
845
+ question = random.choice(template.question_hints).format(
846
+ anchor_name=anchor['name'],
847
+ clip_feature_name=clip_feature['name'],
848
+ )
849
+ # Build candidates for BOTH anchors so the model sees both IDs
850
+ # in context and learns to pick the right one for each placeholder.
851
+ div_cands = build_candidate_list(
852
+ con, anchor['id'], anchor['name'], 'divisions_area',
853
+ num_candidates=5, difficulty="hard",
854
+ )
855
+ ne_cands = build_candidate_list(
856
+ con, clip_feature['id'], clip_feature['name'], 'natural_earth',
857
+ num_candidates=5, difficulty="hard",
858
+ )
859
+ candidates = _merge_candidate_lists(div_cands, ne_cands, max_total=10)
860
+
861
+ elif template.family == "aggregation":
862
+ top_n = random.choice([3, 5, 10])
863
+ target_subtype = random.choice(['locality', 'region'])
864
+
865
+ if template.template_id in ['agg_03', 'agg_04']:
866
+ # Country-level aggregation: SQL uses country code, not anchor id.
867
+ anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
868
+ if not anchor:
869
+ return None
870
+
871
+ country = anchor.get('country') or 'US'
872
+
873
+ sql = template.sql_template.format(
874
+ country=country,
875
+ target_subtype=target_subtype,
876
+ top_n=top_n,
877
+ )
878
+
879
+ candidates = build_candidate_list(
880
+ con, anchor['id'], anchor['name'], 'divisions_area',
881
+ num_candidates=10, difficulty="hard"
882
+ )
883
+
884
+ question = random.choice(template.question_hints).format(
885
+ top_n=top_n,
886
+ target_subtype=target_subtype,
887
+ anchor_name=anchor['name'],
888
+ )
889
+ else:
890
+ # Containment-based aggregation: anchor is the container region.
891
+ anchor = sample_containment_anchor(tables['containment_pairs'])
892
+ if not anchor:
893
+ return None
894
+
895
+ sql = template.sql_template.format(
896
+ anchor_id=anchor['container_id'],
897
+ target_subtype=target_subtype,
898
+ top_n=top_n,
899
+ )
900
+
901
+ candidates = build_candidate_list(
902
+ con, anchor['container_id'], anchor['container_name'], 'divisions_area',
903
+ num_candidates=10, difficulty="hard"
904
+ )
905
+
906
+ question = random.choice(template.question_hints).format(
907
+ top_n=top_n,
908
+ target_subtype=target_subtype,
909
+ anchor_name=anchor['container_name'],
910
+ )
911
+
912
+ elif template.family == "chained":
913
+ # Use pre-filtered coastal/landlocked containment pairs so the SQL
914
+ # verification step doesn't constantly return empty results.
915
+ if template.template_id == "chained_01":
916
+ table_key = 'coastal_containment_pairs'
917
+ elif template.template_id == "chained_02":
918
+ table_key = 'landlocked_containment_pairs'
919
+ else:
920
+ table_key = 'containment_pairs'
921
+ anchor = sample_containment_anchor(tables.get(table_key, tables['containment_pairs']))
922
+ if not anchor:
923
+ return None
924
+
925
+ target_subtype = anchor.get('contained_subtype', 'locality')
926
+
927
+ sql = template.sql_template.format(
928
+ anchor_id=anchor['container_id'],
929
+ target_subtype=target_subtype,
930
+ )
931
+
932
+ candidates = build_candidate_list(
933
+ con, anchor['container_id'], anchor['container_name'], 'divisions_area',
934
+ num_candidates=10, difficulty="hard"
935
+ )
936
+
937
+ question = random.choice(template.question_hints).format(
938
+ anchor_name=anchor['container_name'],
939
+ target_subtype=target_subtype,
940
+ )
941
+
942
+ elif template.family == "multi_adjacency":
943
+ # Use common_neighbor_pairs so anchor1 and anchor2 are guaranteed to
944
+ # share at least one touching neighbour — SQL will return non-empty.
945
+ cn_df = tables.get('common_neighbor_pairs', pd.DataFrame())
946
+ if cn_df.empty:
947
+ return None
948
+ row = cn_df.sample(n=1).iloc[0]
949
+ anchor1 = {'id': row['anchor_id_1'], 'name': row['anchor_name_1'], 'source': 'divisions_area'}
950
+ anchor2 = {'id': row['anchor_id_2'], 'name': row['anchor_name_2'], 'source': 'divisions_area'}
951
+
952
+ sql = template.sql_template.format(
953
+ anchor_id_1=anchor1['id'],
954
+ anchor_id_2=anchor2['id'],
955
+ )
956
+
957
+ candidates1 = build_candidate_list(
958
+ con, anchor1['id'], anchor1['name'], 'divisions_area',
959
+ num_candidates=5, difficulty="medium"
960
+ )
961
+ candidates2 = build_candidate_list(
962
+ con, anchor2['id'], anchor2['name'], 'divisions_area',
963
+ num_candidates=5, difficulty="medium"
964
+ )
965
+ candidates = _merge_candidate_lists(candidates1, candidates2)
966
+
967
+ question = random.choice(template.question_hints).format(
968
+ anchor_1_name=anchor1['name'],
969
+ anchor_2_name=anchor2['name'],
970
+ )
971
+
972
+ elif template.family == "difference":
973
+ if template.anchor_source == "mixed":
974
+ # divisions_area anchor differenced against a natural_earth feature.
975
+ # Use cross_source_relations so the pair is guaranteed to intersect
976
+ # (ST_Difference on non-intersecting geometries is always equal to
977
+ # the original geometry — a trivial and uninformative sample).
978
+ cs_df = tables.get('cross_source_relations', pd.DataFrame())
979
+ if cs_df.empty:
980
+ return None
981
+ row = cs_df.sample(n=1).iloc[0]
982
+ anchor = {
983
+ 'id': row['division_id'],
984
+ 'name': row['division_name'],
985
+ 'source': 'divisions_area',
986
+ }
987
+ clip_feature = {
988
+ 'id': row['natural_id'],
989
+ 'name': row['natural_name'],
990
+ 'source': 'natural_earth',
991
+ }
992
+
993
+ sql = template.sql_template.format(
994
+ anchor_id=anchor['id'],
995
+ clip_feature_id=clip_feature['id'],
996
+ )
997
+ question = random.choice(template.question_hints).format(
998
+ anchor_name=anchor['name'],
999
+ clip_feature_name=clip_feature['name'],
1000
+ )
1001
+ # Build candidates for BOTH anchors — model must see both IDs
1002
+ # to correctly assign anchor_id vs clip_feature_id in the SQL.
1003
+ div_cands = build_candidate_list(
1004
+ con, anchor['id'], anchor['name'], 'divisions_area',
1005
+ num_candidates=5, difficulty="hard",
1006
+ )
1007
+ ne_cands = build_candidate_list(
1008
+ con, clip_feature['id'], clip_feature['name'], 'natural_earth',
1009
+ num_candidates=5, difficulty="hard",
1010
+ )
1011
+ candidates = _merge_candidate_lists(div_cands, ne_cands, max_total=10)
1012
+
1013
+ else:
1014
+ # Two divisions_area anchors — use containment pairs so the
1015
+ # smaller (contained) is guaranteed to intersect the larger.
1016
+ pair = sample_containment_anchor(tables['containment_pairs'])
1017
+ if not pair:
1018
+ return None
1019
+
1020
+ anchor1 = {'id': pair['container_id'], 'name': pair['container_name']}
1021
+ anchor2_row = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
1022
+ if not anchor2_row:
1023
+ return None
1024
+ anchor2 = anchor2_row
1025
+
1026
+ sql = template.sql_template.format(
1027
+ anchor_id_1=anchor1['id'],
1028
+ anchor_id_2=anchor2['id'],
1029
+ )
1030
+
1031
+ candidates1 = build_candidate_list(
1032
+ con, anchor1['id'], anchor1['name'], 'divisions_area',
1033
+ num_candidates=5, difficulty="medium"
1034
+ )
1035
+ candidates2 = build_candidate_list(
1036
+ con, anchor2['id'], anchor2['name'], 'divisions_area',
1037
+ num_candidates=5, difficulty="medium"
1038
+ )
1039
+ candidates = _merge_candidate_lists(candidates1, candidates2)
1040
+
1041
+ question = random.choice(template.question_hints).format(
1042
+ anchor_1_name=anchor1['name'],
1043
+ anchor_2_name=anchor2['name'],
1044
+ )
1045
+
1046
+ elif template.family == "border_corridor":
1047
+ # Buffered border zone — needs two anchors that actually touch.
1048
+ pair = sample_adjacency_anchor(tables['adjacency_pairs'])
1049
+ if not pair:
1050
+ return None
1051
+
1052
+ # The adjacency table only records one direction; sample a second
1053
+ # anchor that is known to be adjacent to the first.
1054
+ anchor1 = {'id': pair['anchor_id'], 'name': pair['anchor_name']}
1055
+
1056
+ # Find a random neighbour of anchor1 from adjacency pairs
1057
+ neighbours = tables['adjacency_pairs']
1058
+ neighbours = neighbours[neighbours['anchor_id'] == anchor1['id']]
1059
+ if neighbours.empty:
1060
+ return None
1061
+ nb_row = neighbours.sample(n=1).iloc[0]
1062
+ anchor2 = {'id': nb_row.get('target_id', nb_row['anchor_id']), 'name': nb_row.get('target_name', nb_row['anchor_name'])}
1063
+ if anchor1['id'] == anchor2['id']:
1064
+ return None
1065
+
1066
+ buffer_val = random.choice([5, 10, 25, 50])
1067
+
1068
+ sql = template.sql_template.format(
1069
+ anchor_id_1=anchor1['id'],
1070
+ anchor_id_2=anchor2['id'],
1071
+ buffer_km=buffer_val,
1072
+ )
1073
+
1074
+ candidates1 = build_candidate_list(
1075
+ con, anchor1['id'], anchor1['name'], 'divisions_area',
1076
+ num_candidates=5, difficulty="medium"
1077
+ )
1078
+ candidates2 = build_candidate_list(
1079
+ con, anchor2['id'], anchor2['name'], 'divisions_area',
1080
+ num_candidates=5, difficulty="medium"
1081
+ )
1082
+ candidates = _merge_candidate_lists(candidates1, candidates2)
1083
+
1084
+ question = random.choice(template.question_hints).format(
1085
+ anchor_1_name=anchor1['name'],
1086
+ anchor_2_name=anchor2['name'],
1087
+ buffer_km=buffer_val,
1088
+ )
1089
+
1090
+ elif template.family == "window_function":
1091
+ anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
1092
+ if not anchor:
1093
+ return None
1094
+
1095
+ country = anchor.get('country') or 'US'
1096
+ target_subtype = random.choice(['locality', 'neighborhood'])
1097
+
1098
+ sql = template.sql_template.format(
1099
+ country=country,
1100
+ target_subtype=target_subtype,
1101
+ )
1102
+
1103
+ candidates = build_candidate_list(
1104
+ con, anchor['id'], anchor['name'], 'divisions_area',
1105
+ num_candidates=10, difficulty="hard"
1106
+ )
1107
+
1108
+ question = random.choice(template.question_hints).format(
1109
+ anchor_name=anchor['name'],
1110
+ target_subtype=target_subtype,
1111
+ )
1112
+
1113
+ elif template.family == "attribute_filter":
1114
+ anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
1115
+ if not anchor:
1116
+ return None
1117
+
1118
+ country = anchor.get('country') or 'US'
1119
+ target_subtype = template.target_subtype or random.choice(['dependency', 'region', 'locality'])
1120
+
1121
+ sql = template.sql_template.format(
1122
+ country=country,
1123
+ target_subtype=target_subtype,
1124
+ )
1125
+
1126
+ candidates = build_candidate_list(
1127
+ con, anchor['id'], anchor['name'], 'divisions_area',
1128
+ num_candidates=10, difficulty="medium"
1129
+ )
1130
+
1131
+ question = random.choice(template.question_hints).format(
1132
+ anchor_name=anchor['name'],
1133
+ target_subtype=target_subtype,
1134
+ country=country,
1135
+ )
1136
+
1137
+ else:
1138
+ # Skip unsupported families
1139
+ return None
1140
+
1141
+ # Execute SQL to verify
1142
+ try:
1143
+ result = con.execute(_for_execution(sql)).fetchdf()
1144
+ if result.empty:
1145
+ return None
1146
+ except Exception as e:
1147
+ # Errors are tracked in worker return, no need to print
1148
+ return None
1149
+
1150
+ # Collect every anchor ID that appears in the generated SQL so we can
1151
+ # mark them as the "selected" candidates in the training sample.
1152
+ _multi_anchor_families = {"set_operations", "multi_adjacency", "difference", "border_corridor"}
1153
+
1154
+ # Mixed partial_selection (partial_05) and mixed difference (diff_02) each
1155
+ # have two anchors from different sources — both must be marked selected.
1156
+ _is_mixed_two_anchor = (
1157
+ template.anchor_source == "mixed" and template.num_anchors == 2
1158
+ )
1159
+
1160
+ if template.family in _multi_anchor_families and template.num_anchors >= 2:
1161
+ anchor_ids: set = set()
1162
+ for var in ("anchor1", "anchor2", "anchor3"):
1163
+ obj = locals().get(var)
1164
+ if obj:
1165
+ anchor_ids.add(obj.get("id", ""))
1166
+ if "anchors" in locals():
1167
+ for a in locals()["anchors"]:
1168
+ if a:
1169
+ anchor_ids.add(a.get("id", ""))
1170
+ selected_candidate_ids = [c.candidate_id for c in candidates if c.id in anchor_ids]
1171
+
1172
+ elif _is_mixed_two_anchor:
1173
+ # partial_05 / diff_02: anchor (division) + clip_feature (natural_earth)
1174
+ mixed_ids = {anchor.get("id", ""), clip_feature.get("id", "")}
1175
+ selected_candidate_ids = [c.candidate_id for c in candidates if c.id in mixed_ids]
1176
+
1177
+ else:
1178
+ anchor_id_to_find = (
1179
+ anchor.get('anchor_id')
1180
+ or anchor.get('container_id')
1181
+ or anchor.get('natural_id')
1182
+ or anchor.get('id')
1183
+ )
1184
+ selected_candidate_ids = [c.candidate_id for c in candidates if c.id == anchor_id_to_find]
1185
+
1186
+ return TrainingSample(
1187
+ id=sample_id,
1188
+ question=question,
1189
+ candidates=candidates,
1190
+ target={
1191
+ "selected_candidates": selected_candidate_ids,
1192
+ "sql": sql
1193
+ },
1194
+ metadata={
1195
+ "task_family": template.family,
1196
+ "sql_difficulty": template.sql_difficulty,
1197
+ "grounding_difficulty": "medium",
1198
+ "template_id": template.template_id,
1199
+ "num_candidates": len(candidates),
1200
+ "anchor_source": template.anchor_source,
1201
+ "sql_verified": True
1202
+ }
1203
+ )
1204
+
1205
+
1206
+ def generate_cross_source_sample(
1207
+ con: duckdb.DuckDBPyConnection,
1208
+ cross_source_df: pd.DataFrame,
1209
+ sample_id: str
1210
+ ) -> Optional[TrainingSample]:
1211
+ """Generate a sample for cross-source intersection task."""
1212
+
1213
+ anchor = sample_cross_source_anchor(cross_source_df)
1214
+ if not anchor:
1215
+ return None
1216
+
1217
+ # Build SQL (natural feature -> divisions)
1218
+ sql = f"""WITH a AS (
1219
+ SELECT geometry FROM read_parquet('natural_earth')
1220
+ WHERE id = '{anchor['natural_id']}'
1221
+ )
1222
+ SELECT b.id, b.names."primary" AS name, b.geometry
1223
+ FROM read_parquet('divisions_area') AS b, a
1224
+ WHERE b.subtype = 'country'
1225
+ AND ST_Intersects(b.geometry, a.geometry)"""
1226
+
1227
+ # Execute to verify
1228
+ try:
1229
+ result = con.execute(_for_execution(sql)).fetchdf()
1230
+ if result.empty:
1231
+ return None
1232
+ except Exception as e:
1233
+ print(f"SQL execution failed: {e}")
1234
+ return None
1235
+
1236
+ # Build candidates for natural feature
1237
+ candidates = build_candidate_list(
1238
+ con,
1239
+ anchor['natural_id'],
1240
+ anchor['natural_name'],
1241
+ "natural_earth",
1242
+ num_candidates=10,
1243
+ difficulty="medium"
1244
+ )
1245
+
1246
+ # Find which candidate is the true anchor
1247
+ selected_candidate_ids = [c.candidate_id for c in candidates if c.id == anchor['natural_id']]
1248
+
1249
+ # Generate question
1250
+ question = f"Which countries intersect the {anchor['natural_name']}?"
1251
+
1252
+ return TrainingSample(
1253
+ id=sample_id,
1254
+ question=question,
1255
+ candidates=candidates,
1256
+ target={
1257
+ "selected_candidates": selected_candidate_ids,
1258
+ "sql": sql
1259
+ },
1260
+ metadata={
1261
+ "task_family": "intersection",
1262
+ "sql_difficulty": "medium-hard",
1263
+ "grounding_difficulty": "medium",
1264
+ "template_id": "intersect_02",
1265
+ "num_candidates": len(candidates),
1266
+ "anchor_source": "natural_earth",
1267
+ "sql_verified": True
1268
+ }
1269
+ )
1270
+
1271
+
1272
+ def generate_sample_batch_worker(args):
1273
+ """Worker function that processes a batch of work items with a single DuckDB connection.
1274
+
1275
+ Initializes DuckDB, spatial extension, templates module, and relation tables
1276
+ ONCE per batch, then processes all items sequentially.
1277
+ """
1278
+ from pathlib import Path
1279
+
1280
+ work_items, intermediate_dir_str = args
1281
+
1282
+ # Convert string back to Path
1283
+ intermediate_dir = Path(intermediate_dir_str)
1284
+
1285
+ # Initialize DuckDB ONCE for the entire batch
1286
+ con = duckdb.connect()
1287
+ con.execute("SET enable_progress_bar=false")
1288
+ con.execute("INSTALL spatial")
1289
+ con.execute("LOAD spatial")
1290
+
1291
+ # Load relation tables ONCE
1292
+ tables = load_relation_tables(intermediate_dir, quiet=True)
1293
+
1294
+ # Process all items in batch
1295
+ results = []
1296
+ for family, template_dict, sample_id, _ in work_items:
1297
+ # Reconstruct template from dict (sql_templates is already imported at module level)
1298
+ template = sql_templates.SQLTemplate(**template_dict)
1299
+ try:
1300
+ sample = generate_template_based_sample(con, template, tables, sample_id)
1301
+ if sample:
1302
+ results.append((sample, family, template.template_id, None))
1303
+ else:
1304
+ results.append((None, family, template.template_id, "Empty result"))
1305
+ except Exception as e:
1306
+ results.append((None, family, template_dict.get('template_id', 'unknown'), str(e)))
1307
+
1308
+ con.close()
1309
+ return results
1310
+
1311
+
1312
+ def generate_batch_core(
1313
+ work_items: List[tuple],
1314
+ intermediate_dir: str,
1315
+ ) -> List[Dict[str, Any]]:
1316
+ """Standalone batch worker usable from Modal or any remote context.
1317
+
1318
+ Data paths are resolved via GAZET_DATA_DIR env var (set in Modal image).
1319
+
1320
+ Args:
1321
+ work_items: List of (family, template_dict, sample_id, _) tuples
1322
+ intermediate_dir: Path to intermediate dir with relation parquets
1323
+
1324
+ Returns:
1325
+ List of dicts with keys: sample (dict or None), family, template_id, error
1326
+ """
1327
+ from pathlib import Path as _Path
1328
+ intermediate = _Path(intermediate_dir)
1329
+
1330
+ con = duckdb.connect()
1331
+ con.execute("SET enable_progress_bar=false")
1332
+ con.execute("INSTALL spatial")
1333
+ con.execute("LOAD spatial")
1334
+
1335
+ tables = load_relation_tables(intermediate, quiet=True)
1336
+
1337
+ results = []
1338
+ for family, template_dict, sample_id, _ in work_items:
1339
+ template = sql_templates.SQLTemplate(**template_dict)
1340
+ try:
1341
+ sample = generate_template_based_sample(con, template, tables, sample_id)
1342
+ if sample:
1343
+ results.append({
1344
+ "sample": sample.model_dump(),
1345
+ "family": family,
1346
+ "template_id": template.template_id,
1347
+ "error": None,
1348
+ })
1349
+ else:
1350
+ results.append({
1351
+ "sample": None,
1352
+ "family": family,
1353
+ "template_id": template.template_id,
1354
+ "error": "Empty result",
1355
+ })
1356
+ except Exception as e:
1357
+ results.append({
1358
+ "sample": None,
1359
+ "family": family,
1360
+ "template_id": template_dict.get('template_id', 'unknown'),
1361
+ "error": str(e),
1362
+ })
1363
+
1364
+ con.close()
1365
+ return results
1366
+
1367
+
1368
+ def prepare_work_items(
1369
+ target_counts: Dict[str, int],
1370
+ retry_multiplier: int = 2,
1371
+ start_counter: int = 1,
1372
+ intermediate_dir_str: str = "",
1373
+ ) -> List[tuple]:
1374
+ """Prepare shuffled work items for sample generation.
1375
+
1376
+ Returns list of (family, template_dict, sample_id, intermediate_dir_str) tuples.
1377
+ Reusable by both local main() and Modal orchestrator.
1378
+ """
1379
+ work_items = []
1380
+ sample_counter = start_counter
1381
+
1382
+ for family, target_count in target_counts.items():
1383
+ if target_count == 0:
1384
+ continue
1385
+
1386
+ family_templates = [t for t in TEMPLATES if t.family == family]
1387
+ if not family_templates:
1388
+ print(f"No templates found for {family}, skipping...")
1389
+ continue
1390
+
1391
+ for _ in range(target_count * retry_multiplier):
1392
+ template = random.choice(family_templates)
1393
+ template_dict = {
1394
+ 'template_id': template.template_id,
1395
+ 'family': template.family,
1396
+ 'sql_difficulty': template.sql_difficulty,
1397
+ 'anchor_source': template.anchor_source,
1398
+ 'num_anchors': template.num_anchors,
1399
+ 'sql_template': template.sql_template,
1400
+ 'question_hints': template.question_hints,
1401
+ 'target_subtype': template.target_subtype,
1402
+ 'requires_buffer': template.requires_buffer,
1403
+ 'requires_aggregation': template.requires_aggregation
1404
+ }
1405
+ work_items.append((
1406
+ family,
1407
+ template_dict,
1408
+ f"sample_{sample_counter:06d}",
1409
+ intermediate_dir_str,
1410
+ ))
1411
+ sample_counter += 1
1412
+
1413
+ random.shuffle(work_items)
1414
+ return work_items
1415
+
1416
+
1417
+ def main():
1418
+ """Generate training samples."""
1419
+ global TARGET_COUNTS, MAX_WORKERS, RETRY_MULTIPLIER, APPEND_MODE
1420
+
1421
+ # Setup paths
1422
+ script_dir = Path(__file__).parent
1423
+ intermediate_dir = script_dir.parent / "intermediate"
1424
+ output_dir = script_dir.parent / "output"
1425
+
1426
+ output_dir.mkdir(exist_ok=True, parents=True)
1427
+
1428
+ # Load relation tables once to check availability
1429
+ print("Loading relation tables...")
1430
+ tables = load_relation_tables(intermediate_dir, quiet=False)
1431
+
1432
+ # Use configured target counts or defaults
1433
+ if TARGET_COUNTS is None:
1434
+ target_counts = {
1435
+ 'direct_lookup': 100,
1436
+ 'adjacency': 150,
1437
+ 'multi_adjacency': 75,
1438
+ 'containment': 100,
1439
+ 'intersection': 100,
1440
+ 'buffer': 100,
1441
+ 'chained': 150,
1442
+ 'difference': 75,
1443
+ 'border_corridor': 75,
1444
+ 'set_operations': 150,
1445
+ 'partial_selection': 75,
1446
+ 'aggregation': 100,
1447
+ 'window_function': 75,
1448
+ 'attribute_filter': 75,
1449
+ }
1450
+ else:
1451
+ target_counts = TARGET_COUNTS
1452
+
1453
+ # Load existing samples if in append mode
1454
+ existing_samples = []
1455
+ existing_sample_ids = set()
1456
+ jsonl_file = output_dir / "dataset_raw.jsonl"
1457
+
1458
+ if APPEND_MODE and jsonl_file.exists():
1459
+ print(f"\nAppend mode: Loading existing samples from {jsonl_file}")
1460
+ with open(jsonl_file, 'r') as f:
1461
+ for line in f:
1462
+ if line.strip():
1463
+ sample_data = json.loads(line)
1464
+ existing_samples.append(sample_data)
1465
+ existing_sample_ids.add(sample_data['id'])
1466
+ print(f" Found {len(existing_samples)} existing samples")
1467
+
1468
+ # Determine starting sample counter
1469
+ max_existing_id = max([int(s['id'].split('_')[1]) for s in existing_samples if s['id'].startswith('sample_')], default=0)
1470
+ sample_counter = max_existing_id + 1
1471
+ else:
1472
+ sample_counter = 1
1473
+
1474
+ # Prepare work items using shared helper
1475
+ work_items = prepare_work_items(
1476
+ target_counts=target_counts,
1477
+ retry_multiplier=RETRY_MULTIPLIER,
1478
+ start_counter=sample_counter,
1479
+ intermediate_dir_str=str(intermediate_dir),
1480
+ )
1481
+ starting_sample_counter = sample_counter
1482
+
1483
+ # Partition work items into batches (one per worker)
1484
+ num_workers = min(MAX_WORKERS, len(work_items))
1485
+ if num_workers == 0:
1486
+ print("No work items to process")
1487
+ return
1488
+ batch_size = (len(work_items) + num_workers - 1) // num_workers
1489
+ batches = []
1490
+ for i in range(0, len(work_items), batch_size):
1491
+ batch = work_items[i:i + batch_size]
1492
+ batches.append((batch, str(intermediate_dir)))
1493
+
1494
+ # Generate samples in parallel (one batch per worker)
1495
+ active_families = len([f for f in target_counts.values() if f > 0])
1496
+ print(f"\nGenerating {len(work_items)} samples across {active_families} families...")
1497
+ print(f" Split into {len(batches)} batches of ~{batch_size} items (1 DuckDB init per batch)")
1498
+ if APPEND_MODE and existing_samples:
1499
+ print(f"Appending: starting from sample_{starting_sample_counter:03d}")
1500
+
1501
+ all_samples = []
1502
+ family_progress = {f: {'success': 0, 'failed': 0} for f in target_counts.keys() if target_counts[f] > 0}
1503
+
1504
+ with ProcessPoolExecutor(max_workers=num_workers) as executor:
1505
+ # Submit one batch per worker
1506
+ futures = {executor.submit(generate_sample_batch_worker, batch): i for i, batch in enumerate(batches)}
1507
+
1508
+ # Collect results as batches complete
1509
+ batches_done = 0
1510
+ for future in as_completed(futures):
1511
+ try:
1512
+ batch_results = future.result()
1513
+ for sample, family, template_id, error in batch_results:
1514
+ if sample:
1515
+ all_samples.append(sample)
1516
+ family_progress[family]['success'] += 1
1517
+ else:
1518
+ family_progress[family]['failed'] += 1
1519
+ except Exception as e:
1520
+ print(f"\n Batch failed: {e}")
1521
+
1522
+ batches_done += 1
1523
+ total_done = sum(p['success'] + p['failed'] for p in family_progress.values())
1524
+ print(f"\r Progress: {total_done}/{len(work_items)} samples ({batches_done}/{len(batches)} batches) ", end='', flush=True)
1525
+
1526
+ print() # New line after progress
1527
+
1528
+ # Show distribution (keep all samples, no filtering)
1529
+ print("\nResults by family:")
1530
+ for family in sorted(family_progress.keys()):
1531
+ success = family_progress[family]['success']
1532
+ failed = family_progress[family]['failed']
1533
+ target = target_counts.get(family, 0)
1534
+ total = success + failed
1535
+ success_rate = (success / total * 100) if total > 0 else 0
1536
+ print(f" {family:20s}: {success:3d} success / {failed:3d} failed ({success_rate:5.1f}% success rate, target: {target})")
1537
+
1538
+ # Save combined JSONL (skip individual JSON files for speed at scale)
1539
+ print(f"\nSaving {len(all_samples)} new samples...")
1540
+ if APPEND_MODE and existing_samples:
1541
+ # Append to existing dataset
1542
+ print(f"Appending to existing dataset ({len(existing_samples)} existing samples)")
1543
+ with open(jsonl_file, 'a') as f:
1544
+ for sample in all_samples:
1545
+ f.write(json.dumps(sample.model_dump()) + '\n')
1546
+ total_samples = len(existing_samples) + len(all_samples)
1547
+ else:
1548
+ # Overwrite dataset
1549
+ with open(jsonl_file, 'w') as f:
1550
+ for sample in all_samples:
1551
+ f.write(json.dumps(sample.model_dump()) + '\n')
1552
+ total_samples = len(all_samples)
1553
+
1554
+ print(f"\nGenerated {len(all_samples)} new samples")
1555
+ print(f"Total dataset size: {total_samples} samples")
1556
+ print(f" Dataset: {jsonl_file}")
1557
+
1558
+
1559
+ if __name__ == "__main__":
1560
+ main()
dataset/scripts/sql_templates.py ADDED
@@ -0,0 +1,1651 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SQL template definitions for synthetic data generation.
3
+
4
+ Geometry output convention
5
+ --------------------------
6
+ Every final SELECT wraps geometry with ST_AsGeoJSON():
7
+ ST_AsGeoJSON(geometry) AS geometry
8
+ This returns a GeoJSON string instead of raw WKB bytes, which is directly
9
+ JSON-serialisable and matches what the serving stack expects.
10
+
11
+ CTEs that compute intermediate geometries (used only for spatial predicates
12
+ or ST_Area) keep the column as raw GEOMETRY so DuckDB spatial functions work.
13
+
14
+ Buffer distance convention
15
+ --------------------------
16
+ All buffer templates use {buffer_km} or {buffer_m} (never degrees).
17
+ SQL converts to degrees: metres / 111_320.
18
+
19
+ Mixed-source candidates
20
+ -----------------------
21
+ generate_samples.py pads every candidate list with 50 % cross-source
22
+ distractors so the model always sees both source values and learns the
23
+ correct parquet path from the candidates table.
24
+
25
+ Template families
26
+ -----------------
27
+ direct_lookup Simple single-feature fetch by ID.
28
+ adjacency ST_Touches — features sharing a border.
29
+ multi_adjacency Features that simultaneously touch TWO anchors.
30
+ containment ST_Within / ST_Contains — hierarchical nesting.
31
+ intersection ST_Intersects — overlapping or crossing features.
32
+ buffer ST_Buffer — proximity zones in km or metres.
33
+ chained Containment + EXISTS/NOT EXISTS sea predicate.
34
+ difference ST_Difference — geometry subtraction.
35
+ border_corridor Buffered ST_Intersection of a shared border.
36
+ set_operations ST_Union_Agg — merging multiple geometries.
37
+ partial_selection Bbox clipping — directional halves or feature clips.
38
+ aggregation TOP-N by area with ORDER BY.
39
+ window_function ROW_NUMBER() OVER (PARTITION BY) — per-group ranking.
40
+ attribute_filter Pure attribute predicates: is_land, country, etc.
41
+ """
42
+
43
+ from dataclasses import dataclass
44
+ from typing import List, Literal
45
+
46
+
47
+ @dataclass
48
+ class SQLTemplate:
49
+ """SQL template for synthetic data generation."""
50
+
51
+ template_id: str
52
+ family: str
53
+ sql_difficulty: Literal["easy", "medium", "medium-hard", "hard"]
54
+ anchor_source: Literal["divisions_area", "natural_earth", "mixed"]
55
+ num_anchors: int
56
+ sql_template: str
57
+ question_hints: List[str]
58
+ target_subtype: str = None
59
+ requires_buffer: bool = False
60
+ requires_aggregation: bool = False
61
+
62
+
63
+ # ---------------------------------------------------------------------------
64
+ # Template catalog
65
+ # ---------------------------------------------------------------------------
66
+
67
+ TEMPLATES = [
68
+
69
+ # ── DIRECT LOOKUP ────────────────────────────────────────────────────────
70
+
71
+ SQLTemplate(
72
+ template_id="lookup_01",
73
+ family="direct_lookup",
74
+ sql_difficulty="easy",
75
+ anchor_source="divisions_area",
76
+ num_anchors=1,
77
+ sql_template=(
78
+ "SELECT ST_AsGeoJSON(geometry) AS geometry,"
79
+ " names.\"primary\" AS name, id, subtype, country"
80
+ " FROM read_parquet('divisions_area')"
81
+ " WHERE id = '{anchor_id}'"
82
+ ),
83
+ question_hints=[
84
+ "Show me {anchor_name}",
85
+ "Get the boundary of {anchor_name}",
86
+ "Find {anchor_name}",
87
+ "Where is {anchor_name}?",
88
+ "Give me the outline of {anchor_name}",
89
+ "Display {anchor_name} on a map",
90
+ "What does {anchor_name} look like?",
91
+ "I need the shape of {anchor_name}",
92
+ "Pull up {anchor_name}",
93
+ "Can you show {anchor_name}?",
94
+ "Map of {anchor_name}",
95
+ "{anchor_name} boundary",
96
+ "Locate {anchor_name} for me",
97
+ ],
98
+ ),
99
+
100
+ SQLTemplate(
101
+ template_id="lookup_02",
102
+ family="direct_lookup",
103
+ sql_difficulty="easy",
104
+ anchor_source="natural_earth",
105
+ num_anchors=1,
106
+ sql_template=(
107
+ "SELECT ST_AsGeoJSON(geometry) AS geometry,"
108
+ " names.\"primary\" AS name, id, subtype"
109
+ " FROM read_parquet('natural_earth')"
110
+ " WHERE id = '{anchor_id}'"
111
+ ),
112
+ question_hints=[
113
+ "Show me the {anchor_name}",
114
+ "Get {anchor_name}",
115
+ "Find the {anchor_name}",
116
+ "Where is the {anchor_name}?",
117
+ "Show the extent of the {anchor_name}",
118
+ "Give me the geometry of the {anchor_name}",
119
+ "Display the {anchor_name}",
120
+ "Pull up the {anchor_name}",
121
+ "I want to see the {anchor_name}",
122
+ "Map the {anchor_name}",
123
+ "How big is the {anchor_name}?",
124
+ "Outline of the {anchor_name}",
125
+ ],
126
+ ),
127
+
128
+ # ── ADJACENCY ────────────────────────────────────────────────────────────
129
+
130
+ SQLTemplate(
131
+ template_id="adj_01",
132
+ family="adjacency",
133
+ sql_difficulty="medium",
134
+ anchor_source="divisions_area",
135
+ num_anchors=1,
136
+ sql_template=(
137
+ "WITH a AS ("
138
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
139
+ ")"
140
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
141
+ " ST_AsGeoJSON(b.geometry) AS geometry"
142
+ " FROM read_parquet('divisions_area') AS b, a"
143
+ " WHERE b.id != '{anchor_id}'"
144
+ " AND ST_Touches(a.geometry, b.geometry)"
145
+ ),
146
+ question_hints=[
147
+ "Which regions border {anchor_name}?",
148
+ "What administrative units touch {anchor_name}?",
149
+ "List all places adjacent to {anchor_name}",
150
+ "What shares a border with {anchor_name}?",
151
+ "Neighbours of {anchor_name}",
152
+ "What is adjacent to {anchor_name}?",
153
+ "What surrounds {anchor_name}?",
154
+ "Places next to {anchor_name}",
155
+ "Everything bordering {anchor_name}",
156
+ ],
157
+ ),
158
+
159
+ SQLTemplate(
160
+ template_id="adj_02",
161
+ family="adjacency",
162
+ sql_difficulty="medium",
163
+ anchor_source="divisions_area",
164
+ num_anchors=1,
165
+ target_subtype="region",
166
+ sql_template=(
167
+ "WITH a AS ("
168
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
169
+ ")"
170
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
171
+ " ST_AsGeoJSON(b.geometry) AS geometry"
172
+ " FROM read_parquet('divisions_area') AS b, a"
173
+ " WHERE b.id != '{anchor_id}'"
174
+ " AND b.subtype = '{target_subtype}'"
175
+ " AND ST_Touches(a.geometry, b.geometry)"
176
+ ),
177
+ question_hints=[
178
+ "Which {target_subtype}s border {anchor_name}?",
179
+ "What {target_subtype}s share a border with {anchor_name}?",
180
+ "{target_subtype}s that touch {anchor_name}",
181
+ "Neighbouring {target_subtype}s of {anchor_name}",
182
+ "Which {target_subtype}s are adjacent to {anchor_name}?",
183
+ "{target_subtype}s along the {anchor_name} border",
184
+ "Find {target_subtype}s next to {anchor_name}",
185
+ ],
186
+ ),
187
+
188
+ SQLTemplate(
189
+ template_id="adj_03",
190
+ family="adjacency",
191
+ sql_difficulty="medium",
192
+ anchor_source="divisions_area",
193
+ num_anchors=1,
194
+ target_subtype="sea",
195
+ sql_template=(
196
+ "WITH a AS ("
197
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
198
+ ")"
199
+ " SELECT n.id, n.names.\"primary\" AS name, n.subtype,"
200
+ " ST_AsGeoJSON(n.geometry) AS geometry"
201
+ " FROM read_parquet('natural_earth') AS n, a"
202
+ " WHERE n.subtype IN ('ocean', 'sea')"
203
+ " AND ST_Touches(a.geometry, n.geometry)"
204
+ ),
205
+ question_hints=[
206
+ "Which seas touch {anchor_name}?",
207
+ "What seas border {anchor_name}?",
208
+ "Which bodies of water is {anchor_name} adjacent to?",
209
+ "What ocean or sea borders {anchor_name}?",
210
+ "Which oceans touch {anchor_name}?",
211
+ "What coastline does {anchor_name} have?",
212
+ "Which water bodies does {anchor_name} border?",
213
+ "Does {anchor_name} have access to the sea?",
214
+ "What ocean is {anchor_name} on?",
215
+ ],
216
+ ),
217
+
218
+ # ── MULTI-ADJACENCY ──────────────────────────────────────────────────────
219
+
220
+ SQLTemplate(
221
+ template_id="multi_adj_01",
222
+ family="multi_adjacency",
223
+ sql_difficulty="hard",
224
+ anchor_source="divisions_area",
225
+ num_anchors=2,
226
+ sql_template=(
227
+ "WITH a AS ("
228
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id_1}'"
229
+ "),"
230
+ " b AS ("
231
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id_2}'"
232
+ ")"
233
+ " SELECT c.id, c.names.\"primary\" AS name, c.subtype, c.country,"
234
+ " ST_AsGeoJSON(c.geometry) AS geometry"
235
+ " FROM read_parquet('divisions_area') AS c, a, b"
236
+ " WHERE c.id NOT IN ('{anchor_id_1}', '{anchor_id_2}')"
237
+ " AND ST_Touches(c.geometry, a.geometry)"
238
+ " AND ST_Touches(c.geometry, b.geometry)"
239
+ ),
240
+ question_hints=[
241
+ "Which regions border both {anchor_1_name} and {anchor_2_name}?",
242
+ "What places touch both {anchor_1_name} and {anchor_2_name}?",
243
+ "Regions adjacent to both {anchor_1_name} and {anchor_2_name}",
244
+ "What lies between {anchor_1_name} and {anchor_2_name}?",
245
+ "Common neighbours of {anchor_1_name} and {anchor_2_name}",
246
+ ],
247
+ ),
248
+
249
+ # ── CONTAINMENT ──────────────────────────────────────────────────────────
250
+
251
+ SQLTemplate(
252
+ template_id="contain_01",
253
+ family="containment",
254
+ sql_difficulty="medium",
255
+ anchor_source="divisions_area",
256
+ num_anchors=1,
257
+ target_subtype="locality",
258
+ sql_template=(
259
+ "WITH a AS ("
260
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
261
+ ")"
262
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
263
+ " ST_AsGeoJSON(b.geometry) AS geometry"
264
+ " FROM read_parquet('divisions_area') AS b, a"
265
+ " WHERE b.id != '{anchor_id}'"
266
+ " AND b.subtype = '{target_subtype}'"
267
+ " AND ST_Within(b.geometry, a.geometry)"
268
+ ),
269
+ question_hints=[
270
+ "What {target_subtype}s are in {anchor_name}?",
271
+ "Which {target_subtype}s fall within {anchor_name}?",
272
+ "List all {target_subtype}s inside {anchor_name}",
273
+ "{target_subtype}s contained by {anchor_name}",
274
+ "All {target_subtype}s within the boundaries of {anchor_name}",
275
+ "{target_subtype}s of {anchor_name}",
276
+ "Show every {target_subtype} in {anchor_name}",
277
+ ],
278
+ ),
279
+
280
+ SQLTemplate(
281
+ template_id="contain_02",
282
+ family="containment",
283
+ sql_difficulty="medium",
284
+ anchor_source="divisions_area",
285
+ num_anchors=1,
286
+ target_subtype="country",
287
+ sql_template=(
288
+ "WITH a AS ("
289
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
290
+ ")"
291
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype,"
292
+ " ST_AsGeoJSON(b.geometry) AS geometry"
293
+ " FROM read_parquet('divisions_area') AS b, a"
294
+ " WHERE b.id != '{anchor_id}'"
295
+ " AND b.subtype = '{target_subtype}'"
296
+ " AND ST_Contains(b.geometry, a.geometry)"
297
+ ),
298
+ question_hints=[
299
+ "What country contains {anchor_name}?",
300
+ "Which country is {anchor_name} in?",
301
+ "What country does {anchor_name} belong to?",
302
+ "Which nation contains {anchor_name}?",
303
+ "{anchor_name} is part of which country?",
304
+ "Where does {anchor_name} fall geographically?",
305
+ "What country is {anchor_name} located in?",
306
+ ],
307
+ ),
308
+
309
+ SQLTemplate(
310
+ template_id="contain_03",
311
+ family="containment",
312
+ sql_difficulty="medium",
313
+ anchor_source="natural_earth",
314
+ num_anchors=1,
315
+ target_subtype="region",
316
+ sql_template=(
317
+ "WITH a AS ("
318
+ " SELECT geometry FROM read_parquet('natural_earth') WHERE id = '{anchor_id}'"
319
+ ")"
320
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
321
+ " ST_AsGeoJSON(b.geometry) AS geometry"
322
+ " FROM read_parquet('divisions_area') AS b, a"
323
+ " WHERE b.subtype = '{target_subtype}'"
324
+ " AND ST_Within(b.geometry, a.geometry)"
325
+ ),
326
+ question_hints=[
327
+ "Which {target_subtype}s are in the {anchor_name}?",
328
+ "What {target_subtype}s fall within the {anchor_name}?",
329
+ "{target_subtype}s inside the {anchor_name}",
330
+ "Administrative {target_subtype}s within the {anchor_name}",
331
+ "All regions contained by the {anchor_name}",
332
+ "What {target_subtype}s does the {anchor_name} contain?",
333
+ "{target_subtype}s covered by the {anchor_name}",
334
+ ],
335
+ ),
336
+
337
+ # ── INTERSECTION ─────────────────────────────────────────────────────────
338
+
339
+ SQLTemplate(
340
+ template_id="intersect_01",
341
+ family="intersection",
342
+ sql_difficulty="medium-hard",
343
+ anchor_source="divisions_area",
344
+ num_anchors=1,
345
+ target_subtype="region",
346
+ sql_template=(
347
+ "WITH a AS ("
348
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
349
+ ")"
350
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
351
+ " ST_AsGeoJSON(b.geometry) AS geometry"
352
+ " FROM read_parquet('divisions_area') AS b, a"
353
+ " WHERE b.id != '{anchor_id}'"
354
+ " AND b.subtype = '{target_subtype}'"
355
+ " AND ST_Intersects(b.geometry, a.geometry)"
356
+ ),
357
+ question_hints=[
358
+ "Which {target_subtype}s intersect {anchor_name}?",
359
+ "What {target_subtype}s overlap with {anchor_name}?",
360
+ "{target_subtype}s that cross into {anchor_name}",
361
+ "Which {target_subtype}s overlap {anchor_name}?",
362
+ "{target_subtype}s partially inside {anchor_name}",
363
+ "What {target_subtype}s extend into {anchor_name}?",
364
+ ],
365
+ ),
366
+
367
+ SQLTemplate(
368
+ template_id="intersect_02",
369
+ family="intersection",
370
+ sql_difficulty="medium-hard",
371
+ anchor_source="natural_earth",
372
+ num_anchors=1,
373
+ target_subtype="country",
374
+ sql_template=(
375
+ "WITH a AS ("
376
+ " SELECT geometry FROM read_parquet('natural_earth') WHERE id = '{anchor_id}'"
377
+ ")"
378
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype,"
379
+ " ST_AsGeoJSON(b.geometry) AS geometry"
380
+ " FROM read_parquet('divisions_area') AS b, a"
381
+ " WHERE b.subtype = '{target_subtype}'"
382
+ " AND ST_Intersects(b.geometry, a.geometry)"
383
+ ),
384
+ question_hints=[
385
+ "Which countries intersect the {anchor_name}?",
386
+ "What countries does the {anchor_name} pass through?",
387
+ "Countries that overlap with the {anchor_name}",
388
+ "Which countries touch the {anchor_name}?",
389
+ "Nations intersected by the {anchor_name}",
390
+ "Which nations does the {anchor_name} cross?",
391
+ "Countries along the {anchor_name}",
392
+ "What countries does the {anchor_name} cover?",
393
+ "Countries that the {anchor_name} spans across",
394
+ ],
395
+ ),
396
+
397
+ # ── BUFFER ───────────────────────────────────────────────────────────────
398
+ # CTE computes the buffered geometry (raw) for the spatial join.
399
+ # Final SELECT wraps the result features with ST_AsGeoJSON.
400
+
401
+ SQLTemplate(
402
+ template_id="buffer_01",
403
+ family="buffer",
404
+ sql_difficulty="hard",
405
+ anchor_source="divisions_area",
406
+ num_anchors=1,
407
+ requires_buffer=True,
408
+ sql_template=(
409
+ "WITH a AS ("
410
+ " SELECT ST_Buffer(geometry, {buffer_km} * 1000.0 / 111320.0) AS geom"
411
+ " FROM read_parquet('divisions_area')"
412
+ " WHERE id = '{anchor_id}'"
413
+ ")"
414
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
415
+ " ST_AsGeoJSON(b.geometry) AS geometry"
416
+ " FROM read_parquet('divisions_area') AS b, a"
417
+ " WHERE b.id != '{anchor_id}'"
418
+ " AND ST_Intersects(b.geometry, a.geom)"
419
+ ),
420
+ question_hints=[
421
+ "What is within {buffer_km} km of {anchor_name}?",
422
+ "Administrative units within {buffer_km} km of {anchor_name}",
423
+ "Features within a {buffer_km} km radius of {anchor_name}",
424
+ "Places within {buffer_km} kilometers of {anchor_name}",
425
+ "{buffer_km} km buffer around {anchor_name}",
426
+ "What falls within {buffer_km} km of {anchor_name}?",
427
+ "Everything within {buffer_km} km of {anchor_name}",
428
+ ],
429
+ ),
430
+
431
+ SQLTemplate(
432
+ template_id="buffer_02",
433
+ family="buffer",
434
+ sql_difficulty="hard",
435
+ anchor_source="divisions_area",
436
+ num_anchors=1,
437
+ requires_buffer=True,
438
+ sql_template=(
439
+ "WITH a AS ("
440
+ " SELECT ST_Buffer(geometry, {buffer_m} / 111320.0) AS geom"
441
+ " FROM read_parquet('divisions_area')"
442
+ " WHERE id = '{anchor_id}'"
443
+ ")"
444
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
445
+ " ST_AsGeoJSON(b.geometry) AS geometry"
446
+ " FROM read_parquet('divisions_area') AS b, a"
447
+ " WHERE b.id != '{anchor_id}'"
448
+ " AND ST_Intersects(b.geometry, a.geom)"
449
+ ),
450
+ question_hints=[
451
+ "What is within {buffer_m} meters of {anchor_name}?",
452
+ "Features within {buffer_m} m of {anchor_name}",
453
+ "Places within {buffer_m} metres of {anchor_name}",
454
+ "{buffer_m} meter buffer around {anchor_name}",
455
+ "What falls within {buffer_m} m of {anchor_name}?",
456
+ "Administrative units within {buffer_m} metres of {anchor_name}",
457
+ ],
458
+ ),
459
+
460
+ SQLTemplate(
461
+ template_id="buffer_03",
462
+ family="buffer",
463
+ sql_difficulty="hard",
464
+ anchor_source="natural_earth",
465
+ num_anchors=1,
466
+ requires_buffer=True,
467
+ sql_template=(
468
+ "WITH a AS ("
469
+ " SELECT ST_Buffer(geometry, {buffer_km} * 1000.0 / 111320.0) AS geom"
470
+ " FROM read_parquet('natural_earth')"
471
+ " WHERE id = '{anchor_id}'"
472
+ ")"
473
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
474
+ " ST_AsGeoJSON(b.geometry) AS geometry"
475
+ " FROM read_parquet('divisions_area') AS b, a"
476
+ " WHERE ST_Intersects(b.geometry, a.geom)"
477
+ ),
478
+ question_hints=[
479
+ "What administrative units are within {buffer_km} km of the {anchor_name}?",
480
+ "Countries within {buffer_km} km of the {anchor_name}",
481
+ "Regions within {buffer_km} km of the {anchor_name}",
482
+ "What falls within {buffer_km} km of the {anchor_name}?",
483
+ "Administrative divisions within a {buffer_km} km radius of the {anchor_name}",
484
+ "Places within {buffer_km} kilometers of the {anchor_name}",
485
+ ],
486
+ ),
487
+
488
+ SQLTemplate(
489
+ template_id="buffer_04",
490
+ family="buffer",
491
+ sql_difficulty="hard",
492
+ anchor_source="natural_earth",
493
+ num_anchors=1,
494
+ requires_buffer=True,
495
+ sql_template=(
496
+ "WITH a AS ("
497
+ " SELECT ST_Buffer(geometry, {buffer_m} / 111320.0) AS geom"
498
+ " FROM read_parquet('natural_earth')"
499
+ " WHERE id = '{anchor_id}'"
500
+ ")"
501
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
502
+ " ST_AsGeoJSON(b.geometry) AS geometry"
503
+ " FROM read_parquet('divisions_area') AS b, a"
504
+ " WHERE ST_Intersects(b.geometry, a.geom)"
505
+ ),
506
+ question_hints=[
507
+ "What is within {buffer_m} meters of the {anchor_name}?",
508
+ "Administrative units within {buffer_m} m of the {anchor_name}",
509
+ "Places within {buffer_m} metres of the {anchor_name}",
510
+ "{buffer_m} meter buffer around the {anchor_name}",
511
+ ],
512
+ ),
513
+
514
+ # ── CHAINED ──────────────────────────────────────────────────────────────
515
+ # Containment + EXISTS/NOT EXISTS ocean/sea.
516
+ # CTE holds raw geometry for ST_Within; final SELECT wraps with ST_AsGeoJSON.
517
+
518
+ SQLTemplate(
519
+ template_id="chained_01",
520
+ family="chained",
521
+ sql_difficulty="hard",
522
+ anchor_source="divisions_area",
523
+ num_anchors=1,
524
+ target_subtype="locality",
525
+ sql_template=(
526
+ "WITH region AS ("
527
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
528
+ ")"
529
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
530
+ " ST_AsGeoJSON(b.geometry) AS geometry"
531
+ " FROM read_parquet('divisions_area') AS b, region"
532
+ " WHERE b.subtype = '{target_subtype}'"
533
+ " AND ST_Within(b.geometry, region.geometry)"
534
+ " AND EXISTS ("
535
+ " SELECT 1 FROM read_parquet('natural_earth') AS n"
536
+ " WHERE n.subtype IN ('ocean', 'sea')"
537
+ " AND ST_Intersects(b.geometry, n.geometry)"
538
+ " )"
539
+ ),
540
+ question_hints=[
541
+ "Coastal {target_subtype}s of {anchor_name}",
542
+ "{target_subtype}s in {anchor_name} with sea access",
543
+ "Which {target_subtype}s in {anchor_name} are on the coast?",
544
+ "Seaside {target_subtype}s within {anchor_name}",
545
+ "{target_subtype}s in {anchor_name} bordering the sea",
546
+ "Oceanfront {target_subtype}s in {anchor_name}",
547
+ "Which {target_subtype}s in {anchor_name} have a coastline?",
548
+ ],
549
+ ),
550
+
551
+ SQLTemplate(
552
+ template_id="chained_02",
553
+ family="chained",
554
+ sql_difficulty="hard",
555
+ anchor_source="divisions_area",
556
+ num_anchors=1,
557
+ target_subtype="country",
558
+ sql_template=(
559
+ "WITH region AS ("
560
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
561
+ ")"
562
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
563
+ " ST_AsGeoJSON(b.geometry) AS geometry"
564
+ " FROM read_parquet('divisions_area') AS b, region"
565
+ " WHERE b.subtype = '{target_subtype}'"
566
+ " AND ST_Intersects(b.geometry, region.geometry)"
567
+ " AND NOT EXISTS ("
568
+ " SELECT 1 FROM read_parquet('natural_earth') AS n"
569
+ " WHERE n.subtype IN ('ocean', 'sea')"
570
+ " AND ST_Intersects(b.geometry, n.geometry)"
571
+ " )"
572
+ ),
573
+ question_hints=[
574
+ "Landlocked {target_subtype}s in {anchor_name}",
575
+ "Which {target_subtype}s in {anchor_name} have no sea access?",
576
+ "{target_subtype}s in {anchor_name} that are landlocked",
577
+ "{target_subtype}s in {anchor_name} with no coastline",
578
+ "Which {target_subtype}s within {anchor_name} are landlocked?",
579
+ "Interior {target_subtype}s of {anchor_name} with no ocean border",
580
+ ],
581
+ ),
582
+
583
+ SQLTemplate(
584
+ template_id="chained_03",
585
+ family="chained",
586
+ sql_difficulty="hard",
587
+ anchor_source="divisions_area",
588
+ num_anchors=1,
589
+ target_subtype="locality",
590
+ sql_template=(
591
+ "WITH region AS ("
592
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
593
+ ")"
594
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
595
+ " ST_AsGeoJSON(b.geometry) AS geometry"
596
+ " FROM read_parquet('divisions_area') AS b, region"
597
+ " WHERE b.subtype = '{target_subtype}'"
598
+ " AND ST_Within(b.geometry, region.geometry)"
599
+ " AND EXISTS ("
600
+ " SELECT 1 FROM read_parquet('natural_earth') AS n"
601
+ " WHERE n.subtype IN ('Terrain area', 'Island group', 'Peninsula')"
602
+ " AND ST_Intersects(b.geometry, n.geometry)"
603
+ " )"
604
+ ),
605
+ question_hints=[
606
+ "{target_subtype}s in {anchor_name} on a terrain feature or island",
607
+ "{target_subtype}s of {anchor_name} on a peninsula or island group",
608
+ "{target_subtype}s within {anchor_name} on notable landforms",
609
+ "Island and peninsula {target_subtype}s of {anchor_name}",
610
+ ],
611
+ ),
612
+
613
+ # ── DIFFERENCE ───────────────────────────────────────────────────────────
614
+ # CTEs hold raw geometry; ST_Difference result wrapped with ST_AsGeoJSON.
615
+
616
+ SQLTemplate(
617
+ template_id="diff_01",
618
+ family="difference",
619
+ sql_difficulty="hard",
620
+ anchor_source="divisions_area",
621
+ num_anchors=2,
622
+ sql_template=(
623
+ "WITH a AS ("
624
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id_1}'"
625
+ "),"
626
+ " b AS ("
627
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id_2}'"
628
+ ")"
629
+ " SELECT ST_AsGeoJSON(ST_Difference(a.geometry, b.geometry)) AS geometry"
630
+ " FROM a, b"
631
+ " WHERE ST_Intersects(a.geometry, b.geometry)"
632
+ ),
633
+ question_hints=[
634
+ "{anchor_1_name} excluding {anchor_2_name}",
635
+ "{anchor_1_name} minus {anchor_2_name}",
636
+ "The part of {anchor_1_name} that is not in {anchor_2_name}",
637
+ "{anchor_1_name} without the {anchor_2_name} area",
638
+ "Remove {anchor_2_name} from {anchor_1_name}",
639
+ "{anchor_1_name} with {anchor_2_name} cut out",
640
+ "Subtract {anchor_2_name} from {anchor_1_name}",
641
+ "What is left of {anchor_1_name} after removing {anchor_2_name}?",
642
+ ],
643
+ ),
644
+
645
+ SQLTemplate(
646
+ template_id="diff_02",
647
+ family="difference",
648
+ sql_difficulty="hard",
649
+ anchor_source="mixed",
650
+ num_anchors=2,
651
+ sql_template=(
652
+ "WITH a AS ("
653
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
654
+ "),"
655
+ " b AS ("
656
+ " SELECT geometry FROM read_parquet('natural_earth') WHERE id = '{clip_feature_id}'"
657
+ ")"
658
+ " SELECT ST_AsGeoJSON(ST_Difference(a.geometry, b.geometry)) AS geometry"
659
+ " FROM a, b"
660
+ " WHERE ST_Intersects(a.geometry, b.geometry)"
661
+ ),
662
+ question_hints=[
663
+ "The part of {anchor_name} outside the {clip_feature_name}",
664
+ "{anchor_name} excluding the {clip_feature_name}",
665
+ "{anchor_name} minus the {clip_feature_name}",
666
+ "The land area of {anchor_name} not covered by the {clip_feature_name}",
667
+ "{anchor_name} with the {clip_feature_name} removed",
668
+ "What remains of {anchor_name} after removing the {clip_feature_name}?",
669
+ ],
670
+ ),
671
+
672
+ # ── BORDER CORRIDOR ──────────────────────────────────────────────────────
673
+ # Intermediate intersection kept raw; final buffer wrapped with ST_AsGeoJSON.
674
+
675
+ SQLTemplate(
676
+ template_id="corridor_01",
677
+ family="border_corridor",
678
+ sql_difficulty="hard",
679
+ anchor_source="divisions_area",
680
+ num_anchors=2,
681
+ requires_buffer=True,
682
+ sql_template=(
683
+ "WITH a AS ("
684
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id_1}'"
685
+ "),"
686
+ " b AS ("
687
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id_2}'"
688
+ "),"
689
+ " border AS ("
690
+ " SELECT ST_Intersection(a.geometry, b.geometry) AS line"
691
+ " FROM a, b"
692
+ " WHERE ST_Intersects(a.geometry, b.geometry)"
693
+ ")"
694
+ " SELECT ST_AsGeoJSON(ST_Buffer(border.line, {buffer_km} * 1000.0 / 111320.0)) AS geometry"
695
+ " FROM border"
696
+ " WHERE border.line IS NOT NULL"
697
+ ),
698
+ question_hints=[
699
+ "{buffer_km} km zone along the border between {anchor_1_name} and {anchor_2_name}",
700
+ "The {buffer_km} km border corridor between {anchor_1_name} and {anchor_2_name}",
701
+ "Area within {buffer_km} km of the {anchor_1_name}-{anchor_2_name} border",
702
+ "The region straddling the border of {anchor_1_name} and {anchor_2_name} within {buffer_km} km",
703
+ "{buffer_km} km on either side of the {anchor_1_name} and {anchor_2_name} border",
704
+ "Buffer the {anchor_1_name}-{anchor_2_name} boundary by {buffer_km} km",
705
+ ],
706
+ ),
707
+
708
+ # ── SET OPERATIONS ───────────────────────────────────────────────────────
709
+ # union_01 / union_02: 2-anchor and filtered-containment unions.
710
+ # union_03: 3-anchor union — trains the model on IN-clause with 3 IDs.
711
+ # contain_multi: subtype within multiple countries via country IN clause.
712
+
713
+ SQLTemplate(
714
+ template_id="union_01",
715
+ family="set_operations",
716
+ sql_difficulty="medium-hard",
717
+ anchor_source="divisions_area",
718
+ num_anchors=2,
719
+ sql_template=(
720
+ "SELECT ST_AsGeoJSON(ST_Union_Agg(geometry)) AS geometry,"
721
+ " array_agg(names.\"primary\") AS names"
722
+ " FROM read_parquet('divisions_area')"
723
+ " WHERE id IN ('{anchor_id_1}', '{anchor_id_2}')"
724
+ ),
725
+ question_hints=[
726
+ "The combined area of {anchor_1_name} and {anchor_2_name}",
727
+ "Union of {anchor_1_name} and {anchor_2_name}",
728
+ "Merge {anchor_1_name} and {anchor_2_name}",
729
+ "{anchor_1_name} and {anchor_2_name} together",
730
+ "Combined geometry of {anchor_1_name} and {anchor_2_name}",
731
+ ],
732
+ ),
733
+
734
+ SQLTemplate(
735
+ template_id="union_03",
736
+ family="set_operations",
737
+ sql_difficulty="medium-hard",
738
+ anchor_source="divisions_area",
739
+ num_anchors=3,
740
+ sql_template=(
741
+ "SELECT ST_AsGeoJSON(ST_Union_Agg(geometry)) AS geometry,"
742
+ " array_agg(names.\"primary\") AS names"
743
+ " FROM read_parquet('divisions_area')"
744
+ " WHERE id IN ('{anchor_id_1}', '{anchor_id_2}', '{anchor_id_3}')"
745
+ ),
746
+ question_hints=[
747
+ "Show me {anchor_1_name}, {anchor_2_name} and {anchor_3_name}",
748
+ "The combined area of {anchor_1_name}, {anchor_2_name} and {anchor_3_name}",
749
+ "Union of {anchor_1_name}, {anchor_2_name} and {anchor_3_name}",
750
+ "Merge {anchor_1_name}, {anchor_2_name} and {anchor_3_name}",
751
+ "{anchor_1_name}, {anchor_2_name} and {anchor_3_name} together",
752
+ "Display {anchor_1_name}, {anchor_2_name} and {anchor_3_name}",
753
+ ],
754
+ ),
755
+
756
+ SQLTemplate(
757
+ template_id="contain_multi_01",
758
+ family="set_operations",
759
+ sql_difficulty="medium-hard",
760
+ anchor_source="divisions_area",
761
+ num_anchors=2,
762
+ target_subtype="region",
763
+ sql_template=(
764
+ "SELECT id, names.\"primary\" AS name, subtype, country,"
765
+ " ST_AsGeoJSON(geometry) AS geometry"
766
+ " FROM read_parquet('divisions_area')"
767
+ " WHERE country IN ('{country_1}', '{country_2}')"
768
+ " AND subtype = '{target_subtype}'"
769
+ ),
770
+ question_hints=[
771
+ "{target_subtype}s of {anchor_1_name} and {anchor_2_name}",
772
+ "All {target_subtype}s in {anchor_1_name} and {anchor_2_name}",
773
+ "Show {target_subtype}s across {anchor_1_name} and {anchor_2_name}",
774
+ "{target_subtype}s belonging to {anchor_1_name} and {anchor_2_name}",
775
+ "List {target_subtype}s in both {anchor_1_name} and {anchor_2_name}",
776
+ ],
777
+ ),
778
+
779
+ SQLTemplate(
780
+ template_id="contain_multi_02",
781
+ family="set_operations",
782
+ sql_difficulty="medium-hard",
783
+ anchor_source="divisions_area",
784
+ num_anchors=3,
785
+ target_subtype="region",
786
+ sql_template=(
787
+ "SELECT id, names.\"primary\" AS name, subtype, country,"
788
+ " ST_AsGeoJSON(geometry) AS geometry"
789
+ " FROM read_parquet('divisions_area')"
790
+ " WHERE country IN ('{country_1}', '{country_2}', '{country_3}')"
791
+ " AND subtype = '{target_subtype}'"
792
+ ),
793
+ question_hints=[
794
+ "{target_subtype}s of {anchor_1_name}, {anchor_2_name} and {anchor_3_name}",
795
+ "All {target_subtype}s in {anchor_1_name}, {anchor_2_name} and {anchor_3_name}",
796
+ "Show {target_subtype}s across {anchor_1_name}, {anchor_2_name} and {anchor_3_name}",
797
+ "List {target_subtype}s in {anchor_1_name}, {anchor_2_name} and {anchor_3_name}",
798
+ ],
799
+ ),
800
+
801
+ SQLTemplate(
802
+ template_id="union_02",
803
+ family="set_operations",
804
+ sql_difficulty="hard",
805
+ anchor_source="divisions_area",
806
+ num_anchors=1,
807
+ target_subtype="locality",
808
+ requires_aggregation=True,
809
+ sql_template=(
810
+ "WITH a AS ("
811
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
812
+ ")"
813
+ " SELECT ST_AsGeoJSON(ST_Union_Agg(b.geometry)) AS geometry"
814
+ " FROM read_parquet('divisions_area') AS b, a"
815
+ " WHERE b.subtype = '{target_subtype}'"
816
+ " AND ST_Within(b.geometry, a.geometry)"
817
+ ),
818
+ question_hints=[
819
+ "Merge all {target_subtype}s in {anchor_name} into one geometry",
820
+ "Combined geometry of all {target_subtype}s in {anchor_name}",
821
+ "Union of all {target_subtype}s within {anchor_name}",
822
+ "All {target_subtype}s of {anchor_name} merged together",
823
+ "The overall extent of {target_subtype}s in {anchor_name}",
824
+ ],
825
+ ),
826
+
827
+ # ── PARTIAL SELECTION ────────────────────────────────────────────────────
828
+ # Bbox clip CTEs use raw geometry; ST_Intersection result wrapped.
829
+
830
+ SQLTemplate(
831
+ template_id="partial_01",
832
+ family="partial_selection",
833
+ sql_difficulty="hard",
834
+ anchor_source="divisions_area",
835
+ num_anchors=1,
836
+ sql_template=(
837
+ "WITH a AS ("
838
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
839
+ "),"
840
+ " bbox AS ("
841
+ " SELECT ST_XMin(geometry) AS xmin, ST_XMax(geometry) AS xmax,"
842
+ " ST_YMin(geometry) AS ymin, ST_YMax(geometry) AS ymax FROM a"
843
+ "),"
844
+ " clip AS ("
845
+ " SELECT ST_MakeEnvelope(xmin, (ymin + ymax) / 2.0, xmax, ymax) AS half_geom FROM bbox"
846
+ ")"
847
+ " SELECT ST_AsGeoJSON(ST_Intersection(a.geometry, clip.half_geom)) AS geometry"
848
+ " FROM a, clip"
849
+ ),
850
+ question_hints=[
851
+ "The northern half of {anchor_name}",
852
+ "Northern part of {anchor_name}",
853
+ "The top half of {anchor_name}",
854
+ "Northern portion of {anchor_name}",
855
+ ],
856
+ ),
857
+
858
+ SQLTemplate(
859
+ template_id="partial_02",
860
+ family="partial_selection",
861
+ sql_difficulty="hard",
862
+ anchor_source="divisions_area",
863
+ num_anchors=1,
864
+ sql_template=(
865
+ "WITH a AS ("
866
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
867
+ "),"
868
+ " bbox AS ("
869
+ " SELECT ST_XMin(geometry) AS xmin, ST_XMax(geometry) AS xmax,"
870
+ " ST_YMin(geometry) AS ymin, ST_YMax(geometry) AS ymax FROM a"
871
+ "),"
872
+ " clip AS ("
873
+ " SELECT ST_MakeEnvelope(xmin, ymin, xmax, (ymin + ymax) / 2.0) AS half_geom FROM bbox"
874
+ ")"
875
+ " SELECT ST_AsGeoJSON(ST_Intersection(a.geometry, clip.half_geom)) AS geometry"
876
+ " FROM a, clip"
877
+ ),
878
+ question_hints=[
879
+ "The southern half of {anchor_name}",
880
+ "Southern part of {anchor_name}",
881
+ "The bottom half of {anchor_name}",
882
+ "Southern portion of {anchor_name}",
883
+ ],
884
+ ),
885
+
886
+ SQLTemplate(
887
+ template_id="partial_03",
888
+ family="partial_selection",
889
+ sql_difficulty="hard",
890
+ anchor_source="divisions_area",
891
+ num_anchors=1,
892
+ sql_template=(
893
+ "WITH a AS ("
894
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
895
+ "),"
896
+ " bbox AS ("
897
+ " SELECT ST_XMin(geometry) AS xmin, ST_XMax(geometry) AS xmax,"
898
+ " ST_YMin(geometry) AS ymin, ST_YMax(geometry) AS ymax FROM a"
899
+ "),"
900
+ " clip AS ("
901
+ " SELECT ST_MakeEnvelope((xmin + xmax) / 2.0, ymin, xmax, ymax) AS half_geom FROM bbox"
902
+ ")"
903
+ " SELECT ST_AsGeoJSON(ST_Intersection(a.geometry, clip.half_geom)) AS geometry"
904
+ " FROM a, clip"
905
+ ),
906
+ question_hints=[
907
+ "The eastern half of {anchor_name}",
908
+ "Eastern part of {anchor_name}",
909
+ "The right half of {anchor_name}",
910
+ "Eastern portion of {anchor_name}",
911
+ ],
912
+ ),
913
+
914
+ SQLTemplate(
915
+ template_id="partial_04",
916
+ family="partial_selection",
917
+ sql_difficulty="hard",
918
+ anchor_source="divisions_area",
919
+ num_anchors=1,
920
+ sql_template=(
921
+ "WITH a AS ("
922
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
923
+ "),"
924
+ " bbox AS ("
925
+ " SELECT ST_XMin(geometry) AS xmin, ST_XMax(geometry) AS xmax,"
926
+ " ST_YMin(geometry) AS ymin, ST_YMax(geometry) AS ymax FROM a"
927
+ "),"
928
+ " clip AS ("
929
+ " SELECT ST_MakeEnvelope(xmin, ymin, (xmin + xmax) / 2.0, ymax) AS half_geom FROM bbox"
930
+ ")"
931
+ " SELECT ST_AsGeoJSON(ST_Intersection(a.geometry, clip.half_geom)) AS geometry"
932
+ " FROM a, clip"
933
+ ),
934
+ question_hints=[
935
+ "The western half of {anchor_name}",
936
+ "Western part of {anchor_name}",
937
+ "The left half of {anchor_name}",
938
+ "Western portion of {anchor_name}",
939
+ ],
940
+ ),
941
+
942
+ SQLTemplate(
943
+ template_id="partial_05",
944
+ family="partial_selection",
945
+ sql_difficulty="hard",
946
+ anchor_source="mixed",
947
+ num_anchors=2,
948
+ sql_template=(
949
+ "WITH a AS ("
950
+ " SELECT geometry AS g1 FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
951
+ "),"
952
+ " b AS ("
953
+ " SELECT geometry AS g2 FROM read_parquet('natural_earth') WHERE id = '{clip_feature_id}'"
954
+ ")"
955
+ " SELECT ST_AsGeoJSON(ST_Intersection(a.g1, b.g2)) AS geometry"
956
+ " FROM a, b"
957
+ " WHERE ST_Intersects(a.g1, b.g2)"
958
+ ),
959
+ question_hints=[
960
+ "The part of {anchor_name} that overlaps the {clip_feature_name}",
961
+ "{anchor_name} within the {clip_feature_name}",
962
+ "The portion of {anchor_name} inside the {clip_feature_name}",
963
+ "Clip {anchor_name} to the {clip_feature_name}",
964
+ ],
965
+ ),
966
+
967
+ # ── AGGREGATION ──────────────────────────────────────────────────────────
968
+ # ST_Area uses raw geometry in the ORDER BY; final SELECT wraps output.
969
+
970
+ SQLTemplate(
971
+ template_id="agg_01",
972
+ family="aggregation",
973
+ sql_difficulty="hard",
974
+ anchor_source="divisions_area",
975
+ num_anchors=1,
976
+ target_subtype=None, # filled at generation time: locality or region
977
+ requires_aggregation=True,
978
+ sql_template=(
979
+ "WITH a AS ("
980
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
981
+ ")"
982
+ " SELECT b.id, b.names.\"primary\" AS name,"
983
+ " ST_AsGeoJSON(b.geometry) AS geometry,"
984
+ " ST_Area(b.geometry) AS area"
985
+ " FROM read_parquet('divisions_area') AS b, a"
986
+ " WHERE ST_Within(b.geometry, a.geometry)"
987
+ " AND b.subtype = '{target_subtype}'"
988
+ " ORDER BY area DESC"
989
+ " LIMIT {top_n}"
990
+ ),
991
+ question_hints=[
992
+ "Top {top_n} largest {target_subtype}s in {anchor_name}",
993
+ "Biggest {top_n} {target_subtype}s in {anchor_name}",
994
+ "{top_n} largest {target_subtype}s inside {anchor_name}",
995
+ "The {top_n} biggest {target_subtype}s within {anchor_name}",
996
+ "Largest {target_subtype} in {anchor_name}",
997
+ "Which {target_subtype} in {anchor_name} has the most area?",
998
+ ],
999
+ ),
1000
+
1001
+ SQLTemplate(
1002
+ template_id="agg_02",
1003
+ family="aggregation",
1004
+ sql_difficulty="hard",
1005
+ anchor_source="divisions_area",
1006
+ num_anchors=1,
1007
+ target_subtype=None, # filled at generation time: locality or region
1008
+ requires_aggregation=True,
1009
+ sql_template=(
1010
+ "WITH a AS ("
1011
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
1012
+ ")"
1013
+ " SELECT b.id, b.names.\"primary\" AS name,"
1014
+ " ST_AsGeoJSON(b.geometry) AS geometry,"
1015
+ " ST_Area(b.geometry) AS area"
1016
+ " FROM read_parquet('divisions_area') AS b, a"
1017
+ " WHERE ST_Within(b.geometry, a.geometry)"
1018
+ " AND b.subtype = '{target_subtype}'"
1019
+ " ORDER BY area ASC"
1020
+ " LIMIT {top_n}"
1021
+ ),
1022
+ question_hints=[
1023
+ "Top {top_n} smallest {target_subtype}s in {anchor_name}",
1024
+ "Smallest {top_n} {target_subtype}s in {anchor_name}",
1025
+ "{top_n} smallest {target_subtype}s inside {anchor_name}",
1026
+ "The {top_n} tiniest {target_subtype}s within {anchor_name}",
1027
+ "Smallest {target_subtype} in {anchor_name}",
1028
+ "Which {target_subtype} in {anchor_name} has the least area?",
1029
+ ],
1030
+ ),
1031
+
1032
+ SQLTemplate(
1033
+ template_id="agg_03",
1034
+ family="aggregation",
1035
+ sql_difficulty="hard",
1036
+ anchor_source="divisions_area",
1037
+ num_anchors=1,
1038
+ target_subtype=None, # filled at generation time: locality or region
1039
+ requires_aggregation=True,
1040
+ sql_template=(
1041
+ "SELECT id, names.\"primary\" AS name,"
1042
+ " ST_AsGeoJSON(geometry) AS geometry,"
1043
+ " ST_Area(geometry) AS area"
1044
+ " FROM read_parquet('divisions_area')"
1045
+ " WHERE country = '{country}'"
1046
+ " AND subtype = '{target_subtype}'"
1047
+ " ORDER BY area DESC"
1048
+ " LIMIT {top_n}"
1049
+ ),
1050
+ question_hints=[
1051
+ "Top {top_n} largest {target_subtype}s in {anchor_name}",
1052
+ "{top_n} biggest {target_subtype}s in {anchor_name}",
1053
+ "Largest {top_n} {target_subtype}s in {anchor_name}",
1054
+ "The {top_n} largest {target_subtype}s in {anchor_name}",
1055
+ "Biggest {target_subtype} in {anchor_name}",
1056
+ "Which {target_subtype} in {anchor_name} is the largest?",
1057
+ ],
1058
+ ),
1059
+
1060
+ SQLTemplate(
1061
+ template_id="agg_04",
1062
+ family="aggregation",
1063
+ sql_difficulty="hard",
1064
+ anchor_source="divisions_area",
1065
+ num_anchors=1,
1066
+ target_subtype=None, # filled at generation time: locality or region
1067
+ requires_aggregation=True,
1068
+ sql_template=(
1069
+ "SELECT id, names.\"primary\" AS name,"
1070
+ " ST_AsGeoJSON(geometry) AS geometry,"
1071
+ " ST_Area(geometry) AS area"
1072
+ " FROM read_parquet('divisions_area')"
1073
+ " WHERE country = '{country}'"
1074
+ " AND subtype = '{target_subtype}'"
1075
+ " ORDER BY area ASC"
1076
+ " LIMIT {top_n}"
1077
+ ),
1078
+ question_hints=[
1079
+ "Top {top_n} smallest {target_subtype}s in {anchor_name}",
1080
+ "{top_n} smallest {target_subtype}s in {anchor_name}",
1081
+ "Smallest {top_n} {target_subtype}s in {anchor_name}",
1082
+ "The {top_n} smallest {target_subtype}s in {anchor_name}",
1083
+ "Smallest {target_subtype} in {anchor_name}",
1084
+ "Which {target_subtype} in {anchor_name} is the smallest?",
1085
+ ],
1086
+ ),
1087
+
1088
+ # ── WINDOW FUNCTION ──────────────────────────────────────────────────────
1089
+ # CTE keeps raw geometry for ST_Area; final SELECT wraps with ST_AsGeoJSON.
1090
+
1091
+ SQLTemplate(
1092
+ template_id="window_01",
1093
+ family="window_function",
1094
+ sql_difficulty="hard",
1095
+ anchor_source="divisions_area",
1096
+ num_anchors=1,
1097
+ target_subtype="locality",
1098
+ requires_aggregation=True,
1099
+ sql_template=(
1100
+ "WITH ranked AS ("
1101
+ " SELECT id, names.\"primary\" AS name, subtype, country, region, geometry,"
1102
+ " ST_Area(geometry) AS area,"
1103
+ " ROW_NUMBER() OVER (PARTITION BY region ORDER BY ST_Area(geometry) DESC) AS rn"
1104
+ " FROM read_parquet('divisions_area')"
1105
+ " WHERE country = '{country}'"
1106
+ " AND subtype = '{target_subtype}'"
1107
+ ")"
1108
+ " SELECT id, name, subtype, country, region,"
1109
+ " ST_AsGeoJSON(geometry) AS geometry, area"
1110
+ " FROM ranked"
1111
+ " WHERE rn = 1"
1112
+ ),
1113
+ question_hints=[
1114
+ "The largest {target_subtype} in each region of {anchor_name}",
1115
+ "Biggest {target_subtype} per region in {anchor_name}",
1116
+ "Largest {target_subtype} for every region of {anchor_name}",
1117
+ "The biggest {target_subtype} in each province of {anchor_name}",
1118
+ ],
1119
+ ),
1120
+
1121
+ SQLTemplate(
1122
+ template_id="window_02",
1123
+ family="window_function",
1124
+ sql_difficulty="hard",
1125
+ anchor_source="divisions_area",
1126
+ num_anchors=1,
1127
+ target_subtype="locality",
1128
+ requires_aggregation=True,
1129
+ sql_template=(
1130
+ "WITH ranked AS ("
1131
+ " SELECT id, names.\"primary\" AS name, subtype, country, region, geometry,"
1132
+ " ST_Area(geometry) AS area,"
1133
+ " ROW_NUMBER() OVER (PARTITION BY region ORDER BY ST_Area(geometry) ASC) AS rn"
1134
+ " FROM read_parquet('divisions_area')"
1135
+ " WHERE country = '{country}'"
1136
+ " AND subtype = '{target_subtype}'"
1137
+ ")"
1138
+ " SELECT id, name, subtype, country, region,"
1139
+ " ST_AsGeoJSON(geometry) AS geometry, area"
1140
+ " FROM ranked"
1141
+ " WHERE rn = 1"
1142
+ ),
1143
+ question_hints=[
1144
+ "The smallest {target_subtype} in each region of {anchor_name}",
1145
+ "Smallest {target_subtype} per region in {anchor_name}",
1146
+ "Tiniest {target_subtype} for every region of {anchor_name}",
1147
+ "The smallest {target_subtype} in each province of {anchor_name}",
1148
+ ],
1149
+ ),
1150
+
1151
+ # ── ATTRIBUTE FILTER ─────────────────────────────────────────────────────
1152
+ # No spatial op — pure WHERE on is_land / is_territorial / country.
1153
+
1154
+ SQLTemplate(
1155
+ template_id="attr_01",
1156
+ family="attribute_filter",
1157
+ sql_difficulty="medium",
1158
+ anchor_source="divisions_area",
1159
+ num_anchors=1,
1160
+ target_subtype="dependency",
1161
+ sql_template=(
1162
+ "SELECT id, names.\"primary\" AS name, subtype, country,"
1163
+ " ST_AsGeoJSON(geometry) AS geometry"
1164
+ " FROM read_parquet('divisions_area')"
1165
+ " WHERE country = '{country}'"
1166
+ " AND is_land = TRUE"
1167
+ " AND subtype = '{target_subtype}'"
1168
+ ),
1169
+ question_hints=[
1170
+ "Island territories of {anchor_name}",
1171
+ "Overseas island {target_subtype}s belonging to {anchor_name}",
1172
+ "Which islands are part of {anchor_name}?",
1173
+ "Land territories of {anchor_name}",
1174
+ "Island possessions of {anchor_name}",
1175
+ "{anchor_name}'s island {target_subtype}s",
1176
+ ],
1177
+ ),
1178
+
1179
+ SQLTemplate(
1180
+ template_id="attr_02",
1181
+ family="attribute_filter",
1182
+ sql_difficulty="medium",
1183
+ anchor_source="divisions_area",
1184
+ num_anchors=1,
1185
+ target_subtype="region",
1186
+ sql_template=(
1187
+ "SELECT id, names.\"primary\" AS name, subtype, country,"
1188
+ " ST_AsGeoJSON(geometry) AS geometry"
1189
+ " FROM read_parquet('divisions_area')"
1190
+ " WHERE country = '{country}'"
1191
+ " AND is_territorial = TRUE"
1192
+ " AND subtype = '{target_subtype}'"
1193
+ ),
1194
+ question_hints=[
1195
+ "Territorial {target_subtype}s of {anchor_name}",
1196
+ "Official territorial divisions of {anchor_name}",
1197
+ "Recognised territorial {target_subtype}s belonging to {anchor_name}",
1198
+ "Which territorial regions does {anchor_name} have?",
1199
+ ],
1200
+ ),
1201
+
1202
+ SQLTemplate(
1203
+ template_id="attr_03",
1204
+ family="attribute_filter",
1205
+ sql_difficulty="medium",
1206
+ anchor_source="divisions_area",
1207
+ num_anchors=1,
1208
+ target_subtype="locality",
1209
+ sql_template=(
1210
+ "SELECT id, names.\"primary\" AS name, subtype, country,"
1211
+ " ST_AsGeoJSON(geometry) AS geometry"
1212
+ " FROM read_parquet('divisions_area')"
1213
+ " WHERE country = '{country}'"
1214
+ " AND subtype = '{target_subtype}'"
1215
+ " AND is_land = TRUE"
1216
+ ),
1217
+ question_hints=[
1218
+ "Land-based {target_subtype}s of {anchor_name}",
1219
+ "{target_subtype}s on the mainland of {anchor_name}",
1220
+ "All {target_subtype}s on land in {anchor_name}",
1221
+ "Non-island {target_subtype}s of {anchor_name}",
1222
+ ],
1223
+ ),
1224
+
1225
+ # ── NATURAL EARTH ADJACENCY ─────────────────────────────────────────────
1226
+ # Division anchor, natural_earth targets. Handler formats anchor_id and
1227
+ # target_subtype but the SQL hardcodes NE subtypes (like adj_03).
1228
+
1229
+ SQLTemplate(
1230
+ template_id="adj_04",
1231
+ family="adjacency",
1232
+ sql_difficulty="medium",
1233
+ anchor_source="divisions_area",
1234
+ num_anchors=1,
1235
+ target_subtype="river",
1236
+ sql_template=(
1237
+ "WITH a AS ("
1238
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
1239
+ ")"
1240
+ " SELECT n.id, n.names.\"primary\" AS name, n.subtype,"
1241
+ " ST_AsGeoJSON(n.geometry) AS geometry"
1242
+ " FROM read_parquet('natural_earth') AS n, a"
1243
+ " WHERE n.subtype IN ('River', 'Lake', 'Basin')"
1244
+ " AND ST_Intersects(a.geometry, n.geometry)"
1245
+ ),
1246
+ question_hints=[
1247
+ "What rivers or lakes are in {anchor_name}?",
1248
+ "Natural water features of {anchor_name}",
1249
+ "Which rivers flow through {anchor_name}?",
1250
+ "Lakes and rivers within {anchor_name}",
1251
+ "Water features inside {anchor_name}",
1252
+ "What bodies of water cross {anchor_name}?",
1253
+ "Rivers of {anchor_name}",
1254
+ "Show me the lakes in {anchor_name}",
1255
+ ],
1256
+ ),
1257
+
1258
+ SQLTemplate(
1259
+ template_id="adj_05",
1260
+ family="adjacency",
1261
+ sql_difficulty="medium",
1262
+ anchor_source="divisions_area",
1263
+ num_anchors=1,
1264
+ target_subtype="range",
1265
+ sql_template=(
1266
+ "WITH a AS ("
1267
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
1268
+ ")"
1269
+ " SELECT n.id, n.names.\"primary\" AS name, n.subtype,"
1270
+ " ST_AsGeoJSON(n.geometry) AS geometry"
1271
+ " FROM read_parquet('natural_earth') AS n, a"
1272
+ " WHERE n.subtype IN ('Range/Mts', 'Terrain area', 'Peninsula', 'Depression')"
1273
+ " AND ST_Intersects(a.geometry, n.geometry)"
1274
+ ),
1275
+ question_hints=[
1276
+ "What mountain ranges are in {anchor_name}?",
1277
+ "Terrain features of {anchor_name}",
1278
+ "Which mountain ranges cross {anchor_name}?",
1279
+ "Landforms inside {anchor_name}",
1280
+ "Peninsulas and ranges in {anchor_name}",
1281
+ "Geographic features within {anchor_name}",
1282
+ "Mountains of {anchor_name}",
1283
+ "What terrain does {anchor_name} contain?",
1284
+ ],
1285
+ ),
1286
+
1287
+ # ── NATURAL EARTH INTERSECTION ──────────────────────────────────────────
1288
+ # intersect_03: NE anchor, finding overlapping regions (vs countries in
1289
+ # intersect_02). Uses cross_source_relations handler.
1290
+ # intersect_04: division anchor, finding NE features that overlap it.
1291
+ # Uses intersection_pairs handler (extra NE subtypes ignored in SQL).
1292
+
1293
+ SQLTemplate(
1294
+ template_id="intersect_03",
1295
+ family="intersection",
1296
+ sql_difficulty="medium-hard",
1297
+ anchor_source="natural_earth",
1298
+ num_anchors=1,
1299
+ target_subtype="region",
1300
+ sql_template=(
1301
+ "WITH a AS ("
1302
+ " SELECT geometry FROM read_parquet('natural_earth') WHERE id = '{anchor_id}'"
1303
+ ")"
1304
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
1305
+ " ST_AsGeoJSON(b.geometry) AS geometry"
1306
+ " FROM read_parquet('divisions_area') AS b, a"
1307
+ " WHERE b.subtype = '{target_subtype}'"
1308
+ " AND ST_Intersects(b.geometry, a.geometry)"
1309
+ ),
1310
+ question_hints=[
1311
+ "Which regions does the {anchor_name} pass through?",
1312
+ "What administrative regions overlap with the {anchor_name}?",
1313
+ "Regions that the {anchor_name} crosses",
1314
+ "Administrative areas intersected by the {anchor_name}",
1315
+ "What provinces does the {anchor_name} span?",
1316
+ "Regions along the {anchor_name}",
1317
+ "Which provinces overlap the {anchor_name}?",
1318
+ ],
1319
+ ),
1320
+
1321
+ SQLTemplate(
1322
+ template_id="intersect_04",
1323
+ family="intersection",
1324
+ sql_difficulty="medium-hard",
1325
+ anchor_source="divisions_area",
1326
+ num_anchors=1,
1327
+ target_subtype="region",
1328
+ sql_template=(
1329
+ "WITH a AS ("
1330
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
1331
+ ")"
1332
+ " SELECT n.id, n.names.\"primary\" AS name, n.subtype,"
1333
+ " ST_AsGeoJSON(n.geometry) AS geometry"
1334
+ " FROM read_parquet('natural_earth') AS n, a"
1335
+ " WHERE ST_Intersects(n.geometry, a.geometry)"
1336
+ ),
1337
+ question_hints=[
1338
+ "What natural features intersect {anchor_name}?",
1339
+ "Natural earth features that overlap {anchor_name}",
1340
+ "Which geographic features cross {anchor_name}?",
1341
+ "Everything from natural earth that touches {anchor_name}",
1342
+ "What geographic features does {anchor_name} contain?",
1343
+ "Natural features within or crossing {anchor_name}",
1344
+ ],
1345
+ ),
1346
+
1347
+ # ── NATURAL EARTH CHAINED ───────────────────────────────────────────────
1348
+ # chained_04: localities in a region that intersect a river or lake.
1349
+ # chained_05: localities in a region that lie on a mountain range.
1350
+
1351
+ SQLTemplate(
1352
+ template_id="chained_04",
1353
+ family="chained",
1354
+ sql_difficulty="hard",
1355
+ anchor_source="divisions_area",
1356
+ num_anchors=1,
1357
+ target_subtype="locality",
1358
+ sql_template=(
1359
+ "WITH region AS ("
1360
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
1361
+ ")"
1362
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
1363
+ " ST_AsGeoJSON(b.geometry) AS geometry"
1364
+ " FROM read_parquet('divisions_area') AS b, region"
1365
+ " WHERE b.subtype = '{target_subtype}'"
1366
+ " AND ST_Within(b.geometry, region.geometry)"
1367
+ " AND EXISTS ("
1368
+ " SELECT 1 FROM read_parquet('natural_earth') AS n"
1369
+ " WHERE n.subtype IN ('River', 'Lake', 'Basin')"
1370
+ " AND ST_Intersects(b.geometry, n.geometry)"
1371
+ " )"
1372
+ ),
1373
+ question_hints=[
1374
+ "Riverside {target_subtype}s in {anchor_name}",
1375
+ "{target_subtype}s in {anchor_name} near a river or lake",
1376
+ "Which {target_subtype}s in {anchor_name} are on a waterway?",
1377
+ "Lakeside or riverside {target_subtype}s within {anchor_name}",
1378
+ "{target_subtype}s in {anchor_name} that touch a river",
1379
+ "Which {target_subtype}s in {anchor_name} are on a lake?",
1380
+ "Waterfront {target_subtype}s of {anchor_name}",
1381
+ ],
1382
+ ),
1383
+
1384
+ SQLTemplate(
1385
+ template_id="chained_05",
1386
+ family="chained",
1387
+ sql_difficulty="hard",
1388
+ anchor_source="divisions_area",
1389
+ num_anchors=1,
1390
+ target_subtype="locality",
1391
+ sql_template=(
1392
+ "WITH region AS ("
1393
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
1394
+ ")"
1395
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
1396
+ " ST_AsGeoJSON(b.geometry) AS geometry"
1397
+ " FROM read_parquet('divisions_area') AS b, region"
1398
+ " WHERE b.subtype = '{target_subtype}'"
1399
+ " AND ST_Within(b.geometry, region.geometry)"
1400
+ " AND EXISTS ("
1401
+ " SELECT 1 FROM read_parquet('natural_earth') AS n"
1402
+ " WHERE n.subtype IN ('Range/Mts', 'Depression')"
1403
+ " AND ST_Intersects(b.geometry, n.geometry)"
1404
+ " )"
1405
+ ),
1406
+ question_hints=[
1407
+ "Mountain {target_subtype}s in {anchor_name}",
1408
+ "{target_subtype}s in {anchor_name} on a mountain range",
1409
+ "Which {target_subtype}s in {anchor_name} are in the mountains?",
1410
+ "Highland {target_subtype}s within {anchor_name}",
1411
+ "{target_subtype}s of {anchor_name} in mountainous terrain",
1412
+ "{target_subtype}s in {anchor_name} near a mountain range",
1413
+ ],
1414
+ ),
1415
+
1416
+ # ── CHAINED (county-level) ──────────────────────────────────────────────
1417
+ # Same spatial patterns as chained_01..05 but targeting counties/districts
1418
+ # so the model learns "coastal districts of X", "riverside counties", etc.
1419
+
1420
+ SQLTemplate(
1421
+ template_id="chained_06",
1422
+ family="chained",
1423
+ sql_difficulty="hard",
1424
+ anchor_source="divisions_area",
1425
+ num_anchors=1,
1426
+ target_subtype="county",
1427
+ sql_template=(
1428
+ "WITH region AS ("
1429
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
1430
+ ")"
1431
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
1432
+ " ST_AsGeoJSON(b.geometry) AS geometry"
1433
+ " FROM read_parquet('divisions_area') AS b, region"
1434
+ " WHERE b.subtype = '{target_subtype}'"
1435
+ " AND ST_Within(b.geometry, region.geometry)"
1436
+ " AND EXISTS ("
1437
+ " SELECT 1 FROM read_parquet('natural_earth') AS n"
1438
+ " WHERE n.subtype IN ('ocean', 'sea')"
1439
+ " AND ST_Intersects(b.geometry, n.geometry)"
1440
+ " )"
1441
+ ),
1442
+ question_hints=[
1443
+ "Coastal {target_subtype}s of {anchor_name}",
1444
+ "Which districts of {anchor_name} are on the coast?",
1445
+ "{target_subtype}s in {anchor_name} that border the sea",
1446
+ "Seaside {target_subtype}s within {anchor_name}",
1447
+ "{target_subtype}s of {anchor_name} with ocean access",
1448
+ "Which {target_subtype}s in {anchor_name} touch the sea?",
1449
+ "Maritime {target_subtype}s of {anchor_name}",
1450
+ ],
1451
+ ),
1452
+
1453
+ SQLTemplate(
1454
+ template_id="chained_07",
1455
+ family="chained",
1456
+ sql_difficulty="hard",
1457
+ anchor_source="divisions_area",
1458
+ num_anchors=1,
1459
+ target_subtype="county",
1460
+ sql_template=(
1461
+ "WITH region AS ("
1462
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
1463
+ ")"
1464
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
1465
+ " ST_AsGeoJSON(b.geometry) AS geometry"
1466
+ " FROM read_parquet('divisions_area') AS b, region"
1467
+ " WHERE b.subtype = '{target_subtype}'"
1468
+ " AND ST_Within(b.geometry, region.geometry)"
1469
+ " AND NOT EXISTS ("
1470
+ " SELECT 1 FROM read_parquet('natural_earth') AS n"
1471
+ " WHERE n.subtype IN ('ocean', 'sea')"
1472
+ " AND ST_Intersects(b.geometry, n.geometry)"
1473
+ " )"
1474
+ ),
1475
+ question_hints=[
1476
+ "Landlocked {target_subtype}s of {anchor_name}",
1477
+ "Which districts of {anchor_name} have no coastline?",
1478
+ "Interior {target_subtype}s within {anchor_name}",
1479
+ "{target_subtype}s in {anchor_name} with no sea access",
1480
+ "Non-coastal {target_subtype}s of {anchor_name}",
1481
+ "Inland {target_subtype}s of {anchor_name}",
1482
+ ],
1483
+ ),
1484
+
1485
+ SQLTemplate(
1486
+ template_id="chained_08",
1487
+ family="chained",
1488
+ sql_difficulty="hard",
1489
+ anchor_source="divisions_area",
1490
+ num_anchors=1,
1491
+ target_subtype="county",
1492
+ sql_template=(
1493
+ "WITH region AS ("
1494
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
1495
+ ")"
1496
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
1497
+ " ST_AsGeoJSON(b.geometry) AS geometry"
1498
+ " FROM read_parquet('divisions_area') AS b, region"
1499
+ " WHERE b.subtype = '{target_subtype}'"
1500
+ " AND ST_Within(b.geometry, region.geometry)"
1501
+ " AND EXISTS ("
1502
+ " SELECT 1 FROM read_parquet('natural_earth') AS n"
1503
+ " WHERE n.subtype IN ('River', 'Lake', 'Basin')"
1504
+ " AND ST_Intersects(b.geometry, n.geometry)"
1505
+ " )"
1506
+ ),
1507
+ question_hints=[
1508
+ "Riverside {target_subtype}s of {anchor_name}",
1509
+ "Which districts of {anchor_name} have a river or lake?",
1510
+ "{target_subtype}s in {anchor_name} on a waterway",
1511
+ "Lakeside {target_subtype}s within {anchor_name}",
1512
+ "{target_subtype}s of {anchor_name} along a river",
1513
+ "Which {target_subtype}s in {anchor_name} border a lake?",
1514
+ ],
1515
+ ),
1516
+
1517
+ SQLTemplate(
1518
+ template_id="chained_09",
1519
+ family="chained",
1520
+ sql_difficulty="hard",
1521
+ anchor_source="divisions_area",
1522
+ num_anchors=1,
1523
+ target_subtype="county",
1524
+ sql_template=(
1525
+ "WITH region AS ("
1526
+ " SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
1527
+ ")"
1528
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
1529
+ " ST_AsGeoJSON(b.geometry) AS geometry"
1530
+ " FROM read_parquet('divisions_area') AS b, region"
1531
+ " WHERE b.subtype = '{target_subtype}'"
1532
+ " AND ST_Within(b.geometry, region.geometry)"
1533
+ " AND EXISTS ("
1534
+ " SELECT 1 FROM read_parquet('natural_earth') AS n"
1535
+ " WHERE n.subtype IN ('Range/Mts', 'Depression')"
1536
+ " AND ST_Intersects(b.geometry, n.geometry)"
1537
+ " )"
1538
+ ),
1539
+ question_hints=[
1540
+ "Mountain {target_subtype}s of {anchor_name}",
1541
+ "Which districts of {anchor_name} are in the mountains?",
1542
+ "{target_subtype}s in {anchor_name} on a mountain range",
1543
+ "Highland {target_subtype}s within {anchor_name}",
1544
+ "{target_subtype}s of {anchor_name} in mountainous terrain",
1545
+ "Which {target_subtype}s in {anchor_name} have mountain ranges?",
1546
+ ],
1547
+ ),
1548
+
1549
+ # ── NATURAL EARTH CONTAINMENT ───────────────────────────────────────────
1550
+ # contain_04: NE anchor (sea/gulf/bay), find countries that touch it.
1551
+ # Uses containment handler via containment_pairs.
1552
+
1553
+ SQLTemplate(
1554
+ template_id="contain_04",
1555
+ family="containment",
1556
+ sql_difficulty="medium",
1557
+ anchor_source="natural_earth",
1558
+ num_anchors=1,
1559
+ target_subtype="country",
1560
+ sql_template=(
1561
+ "WITH a AS ("
1562
+ " SELECT geometry FROM read_parquet('natural_earth') WHERE id = '{anchor_id}'"
1563
+ ")"
1564
+ " SELECT b.id, b.names.\"primary\" AS name, b.subtype,"
1565
+ " ST_AsGeoJSON(b.geometry) AS geometry"
1566
+ " FROM read_parquet('divisions_area') AS b, a"
1567
+ " WHERE b.subtype = '{target_subtype}'"
1568
+ " AND ST_Intersects(b.geometry, a.geometry)"
1569
+ ),
1570
+ question_hints=[
1571
+ "Which countries border the {anchor_name}?",
1572
+ "What countries are along the {anchor_name}?",
1573
+ "Countries surrounding the {anchor_name}",
1574
+ "Nations on the {anchor_name}",
1575
+ "Which countries touch the {anchor_name}?",
1576
+ "Countries with coastline on the {anchor_name}",
1577
+ "What nations lie on the {anchor_name}?",
1578
+ ],
1579
+ ),
1580
+
1581
+ # ── NATURAL EARTH BUFFER ────────────────────────────────────────────────
1582
+ # buffer_05: NE anchor, find other NE features within a buffer distance.
1583
+ # Uses buffer handler for natural_earth.
1584
+
1585
+ SQLTemplate(
1586
+ template_id="buffer_05",
1587
+ family="buffer",
1588
+ sql_difficulty="hard",
1589
+ anchor_source="natural_earth",
1590
+ num_anchors=1,
1591
+ requires_buffer=True,
1592
+ sql_template=(
1593
+ "WITH a AS ("
1594
+ " SELECT ST_Buffer(geometry, {buffer_km} * 1000.0 / 111320.0) AS geom"
1595
+ " FROM read_parquet('natural_earth')"
1596
+ " WHERE id = '{anchor_id}'"
1597
+ ")"
1598
+ " SELECT n.id, n.names.\"primary\" AS name, n.subtype,"
1599
+ " ST_AsGeoJSON(n.geometry) AS geometry"
1600
+ " FROM read_parquet('natural_earth') AS n, a"
1601
+ " WHERE ST_Intersects(n.geometry, a.geom)"
1602
+ ),
1603
+ question_hints=[
1604
+ "Natural features within {buffer_km} km of the {anchor_name}",
1605
+ "What is within {buffer_km} km of the {anchor_name}?",
1606
+ "Geographic features near the {anchor_name} within {buffer_km} km",
1607
+ "Everything within {buffer_km} km of the {anchor_name}",
1608
+ "What natural features are close to the {anchor_name}?",
1609
+ "{buffer_km} km radius around the {anchor_name}",
1610
+ ],
1611
+ ),
1612
+
1613
+ ]
1614
+
1615
+
1616
+ # ---------------------------------------------------------------------------
1617
+ # Helpers
1618
+ # ---------------------------------------------------------------------------
1619
+
1620
+ def get_templates_by_family(family: str) -> List[SQLTemplate]:
1621
+ """Return all templates for a specific task family."""
1622
+ return [t for t in TEMPLATES if t.family == family]
1623
+
1624
+
1625
+ def get_template_by_id(template_id: str) -> SQLTemplate:
1626
+ """Return a template by its ID, raising ValueError if not found."""
1627
+ for t in TEMPLATES:
1628
+ if t.template_id == template_id:
1629
+ return t
1630
+ raise ValueError(f"Template '{template_id}' not found")
1631
+
1632
+
1633
+ if __name__ == "__main__":
1634
+ families: dict = {}
1635
+ for t in TEMPLATES:
1636
+ families[t.family] = families.get(t.family, 0) + 1
1637
+
1638
+ print("SQL Template Catalog")
1639
+ print("=" * 60)
1640
+ for family, count in sorted(families.items()):
1641
+ print(f"{family:20s}: {count:2d} templates")
1642
+ print(f"{'TOTAL':20s}: {len(TEMPLATES):2d} templates")
1643
+
1644
+ # Verify every template's final SELECT wraps geometry with ST_AsGeoJSON
1645
+ print()
1646
+ print("Geometry output check (all should show ST_AsGeoJSON)")
1647
+ print("=" * 60)
1648
+ for t in TEMPLATES:
1649
+ has_geojson = "ST_AsGeoJSON" in t.sql_template
1650
+ status = "OK" if has_geojson else "MISSING"
1651
+ print(f" {t.template_id:20s}: {status}")
dataset/scripts/validate_dataset.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Validate and balance the generated dataset.
3
+
4
+ This script:
5
+ 1. Loads all generated samples
6
+ 2. Validates SQL executability
7
+ 3. Checks candidate list quality
8
+ 4. Balances across task families and difficulty
9
+ 5. Removes duplicates
10
+ 6. Generates dataset statistics
11
+
12
+ Output:
13
+ - output/dataset_validated.jsonl
14
+ - output/dataset_stats.json
15
+ """
16
+
17
+ import json
18
+ from pathlib import Path
19
+ from typing import List, Dict, Any, Tuple
20
+ from collections import Counter
21
+ from concurrent.futures import ProcessPoolExecutor, as_completed
22
+
23
+ import duckdb
24
+ import pandas as pd
25
+
26
+ from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
27
+
28
+
29
+ def load_samples(jsonl_path: Path) -> List[Dict[str, Any]]:
30
+ """Load samples from JSONL file."""
31
+ samples = []
32
+ with open(jsonl_path, 'r') as f:
33
+ for line in f:
34
+ samples.append(json.loads(line))
35
+ return samples
36
+
37
+
38
+ def _resolve_paths(sql: str) -> str:
39
+ """Replace symbolic placeholder paths with actual runtime paths for execution."""
40
+ sql = sql.replace(
41
+ "read_parquet('divisions_area')", f"read_parquet('{DIVISIONS_AREA_PATH}')"
42
+ )
43
+ sql = sql.replace(
44
+ "read_parquet('natural_earth')", f"read_parquet('{NATURAL_EARTH_PATH}')"
45
+ )
46
+ # Legacy fixed Docker paths from earlier dataset versions
47
+ sql = sql.replace("/data/overture/division_area/*.parquet", DIVISIONS_AREA_PATH)
48
+ sql = sql.replace("/data/overture/divisions_area/*.parquet", DIVISIONS_AREA_PATH)
49
+ sql = sql.replace("/data/natural_earth_geoparquet/ne_geography.parquet", NATURAL_EARTH_PATH)
50
+ return sql
51
+
52
+
53
+ def _to_symbolic_sql(sql: str) -> str:
54
+ """Normalize any hardcoded or runtime paths back to symbolic names for storage."""
55
+ # Current local runtime paths
56
+ sql = sql.replace(DIVISIONS_AREA_PATH, "divisions_area")
57
+ sql = sql.replace(NATURAL_EARTH_PATH, "natural_earth")
58
+ # Legacy Docker paths
59
+ sql = sql.replace("/data/overture/division_area/*.parquet", "divisions_area")
60
+ sql = sql.replace("/data/overture/divisions_area/*.parquet", "divisions_area")
61
+ sql = sql.replace("/data/natural_earth_geoparquet/ne_geography.parquet", "natural_earth")
62
+ return sql
63
+
64
+
65
+ def validate_sql(con: duckdb.DuckDBPyConnection, sql: str) -> tuple[bool, str]:
66
+ """Validate that SQL executes without error.
67
+
68
+ Resolves symbolic path placeholders to actual runtime paths before execution.
69
+ """
70
+ try:
71
+ result = con.execute(_resolve_paths(sql)).fetchdf()
72
+ if result.empty:
73
+ return False, "Empty result"
74
+ return True, "OK"
75
+ except Exception as e:
76
+ return False, str(e)
77
+
78
+
79
+ def validate_candidates(sample: Dict[str, Any]) -> tuple[bool, str]:
80
+ """Validate candidate list quality."""
81
+ candidates = sample['candidates']
82
+ selected = sample['target']['selected_candidates']
83
+
84
+ # Check we have candidates
85
+ if not candidates:
86
+ return False, "No candidates"
87
+
88
+ # Check selected candidates exist
89
+ candidate_ids = {c['candidate_id'] for c in candidates}
90
+ for sel_id in selected:
91
+ if sel_id not in candidate_ids:
92
+ return False, f"Selected candidate {sel_id} not in candidate list"
93
+
94
+ # Check for duplicates
95
+ ids = [c['id'] for c in candidates]
96
+ if len(ids) != len(set(ids)):
97
+ return False, "Duplicate candidates"
98
+
99
+ return True, "OK"
100
+
101
+
102
+ def validate_sample(con: duckdb.DuckDBPyConnection, sample: Dict[str, Any]) -> tuple[bool, List[str]]:
103
+ """Validate a single sample. Returns (is_valid, list_of_issues)."""
104
+ issues = []
105
+
106
+ # Skip SQL re-execution if already verified during generation
107
+ if not sample.get('metadata', {}).get('sql_verified', False):
108
+ sql_valid, sql_msg = validate_sql(con, sample['target']['sql'])
109
+ if not sql_valid:
110
+ issues.append(f"SQL: {sql_msg}")
111
+
112
+ # Validate candidates
113
+ cand_valid, cand_msg = validate_candidates(sample)
114
+ if not cand_valid:
115
+ issues.append(f"Candidates: {cand_msg}")
116
+
117
+ # Check question exists
118
+ if not sample.get('question') or len(sample['question'].strip()) == 0:
119
+ issues.append("Empty question")
120
+
121
+ return len(issues) == 0, issues
122
+
123
+
124
+ def validate_sample_worker(sample: Dict[str, Any]) -> Tuple[str, bool, List[str]]:
125
+ """Worker function for parallel validation. Returns (sample_id, is_valid, issues)."""
126
+ # Each worker creates its own DuckDB connection
127
+ con = duckdb.connect()
128
+ con.execute("SET enable_progress_bar=false")
129
+ con.execute("INSTALL spatial")
130
+ con.execute("LOAD spatial")
131
+
132
+ try:
133
+ is_valid, issues = validate_sample(con, sample)
134
+ con.close()
135
+ if is_valid:
136
+ sample['target']['sql'] = _to_symbolic_sql(sample['target']['sql'])
137
+ return (sample['id'], is_valid, issues, sample if is_valid else None)
138
+ except Exception as e:
139
+ con.close()
140
+ return (sample['id'], False, [f"Validation error: {str(e)}"], None)
141
+
142
+
143
+ def compute_statistics(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
144
+ """Compute dataset statistics."""
145
+
146
+ stats = {
147
+ 'total_samples': len(samples),
148
+ 'task_families': {},
149
+ 'sql_difficulty': {},
150
+ 'grounding_difficulty': {},
151
+ 'anchor_sources': {},
152
+ 'avg_candidates_per_sample': 0,
153
+ 'avg_question_length': 0,
154
+ 'countries_covered': set(),
155
+ 'subtypes_covered': set()
156
+ }
157
+
158
+ total_candidates = 0
159
+ total_question_length = 0
160
+
161
+ for sample in samples:
162
+ meta = sample['metadata']
163
+
164
+ # Count by family
165
+ family = meta['task_family']
166
+ stats['task_families'][family] = stats['task_families'].get(family, 0) + 1
167
+
168
+ # Count by SQL difficulty
169
+ sql_diff = meta['sql_difficulty']
170
+ stats['sql_difficulty'][sql_diff] = stats['sql_difficulty'].get(sql_diff, 0) + 1
171
+
172
+ # Count by grounding difficulty
173
+ ground_diff = meta['grounding_difficulty']
174
+ stats['grounding_difficulty'][ground_diff] = stats['grounding_difficulty'].get(ground_diff, 0) + 1
175
+
176
+ # Count by anchor source
177
+ anchor_src = meta['anchor_source']
178
+ stats['anchor_sources'][anchor_src] = stats['anchor_sources'].get(anchor_src, 0) + 1
179
+
180
+ # Candidates
181
+ total_candidates += len(sample['candidates'])
182
+
183
+ # Question length
184
+ total_question_length += len(sample['question'].split())
185
+
186
+ # Countries and subtypes (from selected/answer candidates only)
187
+ selected_ids = set(sample.get('target', {}).get('selected_candidates', []))
188
+ for cand in sample['candidates']:
189
+ if cand['candidate_id'] in selected_ids:
190
+ if cand.get('country'):
191
+ stats['countries_covered'].add(cand['country'])
192
+ if cand.get('subtype'):
193
+ stats['subtypes_covered'].add(cand['subtype'])
194
+
195
+ stats['avg_candidates_per_sample'] = total_candidates / len(samples) if samples else 0
196
+ stats['avg_question_length'] = total_question_length / len(samples) if samples else 0
197
+ stats['countries_covered'] = sorted(list(stats['countries_covered']))
198
+ stats['subtypes_covered'] = sorted(list(stats['subtypes_covered']))
199
+
200
+ return stats
201
+
202
+
203
+ def main():
204
+ """Validate and analyze dataset."""
205
+
206
+ script_dir = Path(__file__).parent
207
+ output_dir = script_dir.parent / "output"
208
+
209
+ raw_file = output_dir / "dataset_raw.jsonl"
210
+ validated_file = output_dir / "dataset_validated.jsonl"
211
+ stats_file = output_dir / "dataset_stats.json"
212
+
213
+ if not raw_file.exists():
214
+ print(f"Error: {raw_file} not found. Run generate_samples.py first.")
215
+ return
216
+
217
+ # Load samples
218
+ print("Loading samples...")
219
+ samples = load_samples(raw_file)
220
+ print(f"Loaded {len(samples)} samples")
221
+
222
+ # Validate samples in parallel
223
+ print("\nValidating samples in parallel...")
224
+ valid_samples = []
225
+ invalid_samples = []
226
+
227
+ with ProcessPoolExecutor(max_workers=8) as executor:
228
+ # Submit all validation tasks
229
+ futures = {executor.submit(validate_sample_worker, sample): sample for sample in samples}
230
+
231
+ # Collect results as they complete
232
+ completed = 0
233
+ for future in as_completed(futures):
234
+ sample_id, is_valid, issues, validated_sample = future.result()
235
+
236
+ if is_valid:
237
+ valid_samples.append(validated_sample)
238
+ else:
239
+ invalid_samples.append((sample_id, issues))
240
+
241
+ completed += 1
242
+ if completed % 50 == 0 or completed == len(samples):
243
+ print(f"\r Progress: {completed}/{len(samples)} ", end='', flush=True)
244
+
245
+ print() # New line after progress
246
+
247
+ print(f"\nValidation results:")
248
+ print(f" Valid: {len(valid_samples)}")
249
+ print(f" Invalid: {len(invalid_samples)}")
250
+
251
+ if invalid_samples and len(invalid_samples) <= 20:
252
+ print("\nInvalid samples:")
253
+ for sample_id, issues in invalid_samples[:20]:
254
+ print(f" {sample_id}: {', '.join(issues)}")
255
+ elif invalid_samples:
256
+ print(f"\n{len(invalid_samples)} invalid samples (showing first 20):")
257
+ for sample_id, issues in invalid_samples[:20]:
258
+ print(f" {sample_id}: {', '.join(issues)}")
259
+
260
+ # Save validated samples
261
+ if valid_samples:
262
+ with open(validated_file, 'w') as f:
263
+ for sample in valid_samples:
264
+ f.write(json.dumps(sample) + '\n')
265
+ print(f"\nSaved {len(valid_samples)} valid samples to {validated_file}")
266
+
267
+ # Compute statistics
268
+ print("\nComputing statistics...")
269
+ stats = compute_statistics(valid_samples)
270
+
271
+ # Save statistics
272
+ # Convert sets to lists for JSON serialization
273
+ stats_json = {k: (list(v) if isinstance(v, set) else v) for k, v in stats.items()}
274
+ with open(stats_file, 'w') as f:
275
+ json.dump(stats_json, f, indent=2)
276
+ print(f"Saved statistics to {stats_file}")
277
+
278
+ # Print summary
279
+ print("\n" + "=" * 60)
280
+ print("DATASET STATISTICS")
281
+ print("=" * 60)
282
+ print(f"\nTotal samples: {stats['total_samples']}")
283
+
284
+ print("\nTask families:")
285
+ for family, count in sorted(stats['task_families'].items()):
286
+ print(f" {family:20s}: {count:3d}")
287
+
288
+ print("\nSQL difficulty:")
289
+ for diff, count in sorted(stats['sql_difficulty'].items()):
290
+ print(f" {diff:20s}: {count:3d}")
291
+
292
+ print("\nGrounding difficulty:")
293
+ for diff, count in sorted(stats['grounding_difficulty'].items()):
294
+ print(f" {diff:20s}: {count:3d}")
295
+
296
+ print("\nAnchor sources:")
297
+ for src, count in sorted(stats['anchor_sources'].items()):
298
+ print(f" {src:20s}: {count:3d}")
299
+
300
+ print(f"\nAverage candidates per sample: {stats['avg_candidates_per_sample']:.1f}")
301
+ print(f"Average question length (words): {stats['avg_question_length']:.1f}")
302
+ print(f"Countries covered: {len(stats['countries_covered'])}")
303
+ print(f"Subtypes covered: {len(stats['subtypes_covered'])}")
304
+
305
+ print("\n✓ Validation complete")
306
+
307
+
308
+ if __name__ == "__main__":
309
+ main()
docker-compose.yml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ services:
2
+ llama:
3
+ image: ghcr.io/ggml-org/llama.cpp:server
4
+ volumes:
5
+ - ./finetune/models/qwen-base-run/ckpt-001.gguf:/models/model.gguf:ro
6
+ command: >
7
+ -m /models/model.gguf
8
+ --port 9000
9
+ --host 0.0.0.0
10
+ --ctx-size 2048
11
+ -t 4
12
+ healthcheck:
13
+ test: ["CMD", "curl", "-f", "http://localhost:9000/health"]
14
+ interval: 10s
15
+ timeout: 5s
16
+ retries: 30
17
+ start_period: 30s
18
+
19
+ app:
20
+ build: .
21
+ volumes:
22
+ - ./data:/data:ro
23
+ environment:
24
+ GAZET_DATA_DIR: /data
25
+ LLAMA_SERVER_URL: http://llama:9000
26
+ ports:
27
+ - "8000:8000"
28
+ command: uvicorn gazet.api:app --host 0.0.0.0 --port 8000
29
+ depends_on:
30
+ llama:
31
+ condition: service_healthy
32
+
33
+ demo:
34
+ build: .
35
+ environment:
36
+ GAZET_API_URL: http://app:8000
37
+ ports:
38
+ - "8501:8501"
39
+ command: streamlit run gazet_demo.py --server.port 8501 --server.address 0.0.0.0
40
+ depends_on:
41
+ - app
finetune/README.md ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fine-tuning and Inference
2
+
3
+ LoRA fine-tuning of Qwen3.5-0.8B (via Unsloth) to perform two geospatial
4
+ tasks (text-to-SQL and place extraction), then serving locally via
5
+ llama-server.
6
+
7
+ ---
8
+
9
+ ## End-to-end workflow
10
+
11
+ ```
12
+ 1. Generate dataset → dataset/ (see dataset/README.md)
13
+ 2. Check token lengths → check_token_lengths.py
14
+ 3. Train on Modal → train_modal_qwen35.py
15
+ 4. Convert to GGUF → llama.cpp
16
+ 5. Serve locally → llama-server
17
+ 6. Eval locally → eval_cli.py (interactive or batch) + eval_demo.py
18
+ ```
19
+
20
+ ---
21
+
22
+ ## Step 1 — Check token lengths
23
+
24
+ Before training, verify that your `max_length` setting covers the data.
25
+ SQL samples are long (schema + candidates + SQL), places samples are short.
26
+
27
+ ```bash
28
+ modal run finetune/check_token_lengths.py
29
+ modal run finetune/check_token_lengths.py --run-dir /mnt/gazet/data/v1
30
+ ```
31
+
32
+ This prints per-split statistics (min, max, P95, P99) and recommends a
33
+ `max_length` value. Adjust `--max-seq-length` in `train_modal_qwen35.py` accordingly.
34
+
35
+ ---
36
+
37
+ ## Step 2 — Train (Qwen3.5 + Unsloth)
38
+
39
+ Training runs on Modal with an A100-80GB GPU. The script loads both SQL and
40
+ places JSONL files from the run directory, applies the Qwen3.5 ChatML
41
+ template, and trains a LoRA adapter using Unsloth's
42
+ `train_on_responses_only` to mask non-assistant tokens.
43
+
44
+ ```bash
45
+ # Default settings (Qwen3.5-0.8B, r=16, 1 epoch)
46
+ modal run finetune/train_modal_qwen35.py --experiment-name qwen35-v1
47
+
48
+ # Override any config field from CLI
49
+ modal run finetune/train_modal_qwen35.py \
50
+ --experiment-name qwen35-v1 \
51
+ --base-model unsloth/Qwen3.5-0.8B \
52
+ --num-train-epochs 3 \
53
+ --lora-r 32 \
54
+ --max-seq-length 2048
55
+
56
+ # Quick smoke test
57
+ modal run finetune/train_modal_qwen35.py --experiment-name qwen35-v1 --max-train-samples 100
58
+ ```
59
+
60
+ All CLI overrides: `--base-model`, `--experiment-name`, `--run-dir`,
61
+ `--num-train-epochs`, `--per-device-train-batch-size`, `--max-train-samples`,
62
+ `--max-eval-samples`, `--lora-r`, `--max-seq-length`. When `--lora-r` is
63
+ overridden, `lora_alpha` is automatically set to `2 * r`.
64
+
65
+ ### Training config defaults (`Qwen35Config`)
66
+
67
+ ```
68
+ base_model: unsloth/Qwen3.5-0.8B
69
+ run_dir: /mnt/gazet/data/v1
70
+ lora_r: 16
71
+ lora_alpha: 32 (2 * r, Unsloth recommendation for Qwen)
72
+ lora_dropout: 0.0
73
+ num_train_epochs: 1
74
+ batch_size: 32 (x 1 gradient accumulation = 32 effective)
75
+ learning_rate: 1e-4
76
+ lr_scheduler: linear
77
+ optim: adamw_8bit
78
+ max_seq_length: 2048
79
+ ```
80
+
81
+ ### Output
82
+
83
+ Checkpoints and the merged model are saved to the Modal volume:
84
+
85
+ ```
86
+ /mnt/gazet/checkpoints/{experiment_name}/
87
+ adapter_config.json # LoRA adapter
88
+ adapter_model.safetensors
89
+ checkpoint-*/ # intermediate checkpoints
90
+ merged/ # full merged 16-bit model
91
+ model.safetensors
92
+ tokenizer.json
93
+ ```
94
+
95
+ Pass `--experiment-name` to set a human-readable name (e.g. `qwen35-v1`).
96
+ If omitted, it is auto-generated as `{model}-r{lora_r}-{timestamp}`.
97
+
98
+ Training metrics are logged to [trackio](https://huggingface.co/spaces/srmsoumya/gazet-trackio).
99
+
100
+ ---
101
+
102
+ ## Step 3 — Convert merged model to GGUF
103
+
104
+ After training, download the merged model from Modal and convert to GGUF
105
+ for local inference with llama-server.
106
+
107
+ ```bash
108
+ # Download from Modal volume
109
+ modal volume get gazet checkpoints/qwen35-v1/merged ./finetune/models/merged
110
+
111
+ # Convert to GGUF (requires llama.cpp repo)
112
+ uv run \
113
+ --no-project \
114
+ --with transformers \
115
+ --with sentencepiece \
116
+ --with protobuf \
117
+ --with torch \
118
+ python convert_hf_to_gguf.py \
119
+ ../gazet/finetune/models/qwen-base/merged \
120
+ --outtype q8_0 \
121
+ --outfile ../gazet/finetune/models/qwen-base/ckpt-001.gguf
122
+ ```
123
+
124
+ ---
125
+
126
+ ## Step 4 — Serve with llama-server
127
+
128
+ ### Local
129
+
130
+ ```bash
131
+ llama-server \
132
+ -m finetune/models/qwen-base/ckpt-001.gguf \
133
+ -ngl 99 \
134
+ --port 9000 \
135
+ --ctx-size 2048
136
+ ```
137
+
138
+ `--ctx-size` is the total KV cache shared across all parallel slots. SQL
139
+ prompts can be ~600 tokens; with `--parallel 4` and up to 2048 output
140
+ tokens, use at least `8192`. Match `--parallel` to `--workers` in
141
+ `eval_cli.py`.
142
+
143
+ ### Docker (CPU-only)
144
+
145
+ Useful for testing inference in a constrained environment. Adjust `--cpus`
146
+ and `--memory` to simulate deployment targets. Set `-t` to match `--cpus`.
147
+
148
+ ```bash
149
+ docker run \
150
+ --cpus="2" --memory="4g" \
151
+ -v $(pwd)/finetune/models:/models \
152
+ -p 9000:9000 \
153
+ ghcr.io/ggml-org/llama.cpp:server \
154
+ -m /models/qwen-base/ckpt-001.gguf \
155
+ --port 9000 --host 0.0.0.0 \
156
+ --ctx-size 2048 -t 2 -v
157
+ ```
158
+
159
+ Notes:
160
+ - `--host 0.0.0.0` is required so the port forward from Docker works
161
+ - `-v` (verbose) enables per-request timing logs (prompt eval t/s, generation t/s)
162
+ - `-ngl` is omitted since the default Docker image is CPU-only; for GPU use
163
+ the CUDA image (`ghcr.io/ggml-org/llama.cpp:server-cuda`) with `--gpus`
164
+ - The model is memory-mapped by default (`mmap = true`), so containers with
165
+ less RAM than the model size may still start but will be slow due to page
166
+ thrashing
167
+
168
+ The server exposes `/v1/chat/completions` (chat API) on
169
+ `http://localhost:9000`. All eval scripts use this endpoint.
170
+
171
+ ---
172
+
173
+ ## Step 5 — Evaluate
174
+
175
+ Two evaluation tools, both using a locally running llama-server.
176
+
177
+ ### Interactive or batch eval (`eval_cli.py`)
178
+
179
+ Requires llama-server running on port 9000 (see Step 4).
180
+
181
+ **Interactive** — spot-check individual samples:
182
+
183
+ ```bash
184
+ uv run finetune/eval_cli.py # prompts for sample index
185
+ uv run finetune/eval_cli.py 0 5 12 # run specific samples
186
+ uv run finetune/eval_cli.py --task places 0 5
187
+ uv run finetune/eval_cli.py -v 0 # print full prompt
188
+ ```
189
+
190
+ **Batch** — run the full split and save a JSON results file:
191
+
192
+ ```bash
193
+ # Full val set, SQL task
194
+ uv run finetune/eval_cli.py --all --label finetuned-qwen35
195
+
196
+ # Places task
197
+ uv run finetune/eval_cli.py --all --task places --label finetuned-places
198
+
199
+ # Limit samples, custom output path
200
+ uv run finetune/eval_cli.py --all --max-samples 100 --output results/eval-v5.json
201
+
202
+ # Evaluate test split instead of val
203
+ uv run finetune/eval_cli.py --all --split test --label finetuned-qwen35
204
+ ```
205
+
206
+ All batch CLI args:
207
+
208
+ | Arg | Default | Description |
209
+ |-----|---------|-------------|
210
+ | `--all` | off | Enable batch mode |
211
+ | `--label` | `local-gguf` | Label used in the output filename |
212
+ | `--task` | `sql` | `sql` or `places` |
213
+ | `--split` | `val` | Data split to evaluate (`val`, `test`) |
214
+ | `--run-dir` | `dataset/output/runs/v1` | Directory with `{task}/{split}.jsonl` |
215
+ | `--max-samples` | all | Cap the number of samples |
216
+ | `--output` | `eval-{label}-{task}.json` | Output JSON path |
217
+ | `--workers` | `4` | Concurrent requests; match llama-server `--parallel` |
218
+
219
+ Results are saved to `results/eval-{label}-{task}.json` with this structure:
220
+
221
+ ```json
222
+ {
223
+ "summary": {"label": "...", "task": "sql", "exact_match_rate": 0.85, ...},
224
+ "results": [
225
+ {"index": 0, "question": "...", "expected": "...", "predicted": "...", "exact_match": true},
226
+ ...
227
+ ]
228
+ }
229
+ ```
230
+
231
+ Config constants at the top of `eval_cli.py`: `SERVER_URL` (default
232
+ `http://localhost:9000`), `MAX_TOKENS` (2048), `TEMPERATURE` (0.6).
233
+
234
+ ### Visual eval (`eval_demo.py`)
235
+
236
+ Streamlit app that loads JSON results from `eval_cli.py --all` and displays
237
+ them interactively. For SQL results, it shows formatted SQL side-by-side,
238
+ a diff view for mismatches, and executes both queries against DuckDB to
239
+ render the geometry on a map. For places results, it shows expected vs
240
+ predicted JSON.
241
+
242
+ ```bash
243
+ streamlit run finetune/eval_demo.py
244
+ ```
245
+
246
+ Reads result files from `results/eval-*.json` by default. Override with:
247
+
248
+ ```bash
249
+ GAZET_EVAL_DIR=/path/to/results streamlit run finetune/eval_demo.py
250
+ ```
251
+
252
+ Set `GAZET_DATA_DIR` if your parquet data is not in the default `data/` directory.
253
+
254
+ ---
255
+
256
+ ## File reference
257
+
258
+ | File | What it does |
259
+ |---|---|
260
+ | `train_modal_qwen35.py` | Modal training script — Qwen3.5 LoRA fine-tuning with Unsloth |
261
+ | `check_token_lengths.py` | Modal script to analyze token length distribution before training |
262
+ | `eval_cli.py` | Local eval — interactive spot-check or full batch mode via llama-server |
263
+ | `eval_demo.py` | Streamlit app — visual diff + map rendering of `eval_cli.py --all` results |
264
+ | `models/` | GGUF model files for local llama-server inference |
265
+
266
+ ---
267
+
268
+ ## Data format
269
+
270
+ The Qwen3.5 training pipeline (`train_modal_qwen35.py`) expects data in
271
+ **messages format**:
272
+
273
+ ```json
274
+ {
275
+ "messages": [
276
+ {"role": "system", "content": "You are a text to SQL query translator..."},
277
+ {"role": "user", "content": "GIVEN the <SCHEMA_DETAILS>..."},
278
+ {"role": "assistant", "content": "SELECT ST_AsGeoJSON(geometry) ..."}
279
+ ]
280
+ }
281
+ ```
282
+
283
+ The Qwen3.5 chat template (ChatML) is applied by the tokenizer. Unsloth's
284
+ `train_on_responses_only` then masks everything before the assistant
285
+ response marker (`<|im_start|>assistant\n<think>\n\n</think>\n\n`), so
286
+ loss is computed only on the completion tokens.
287
+
288
+ SQL in the training data uses symbolic path placeholders
289
+ (`read_parquet('divisions_area')`) instead of real file paths. At inference
290
+ time, `src/gazet/sql.py` replaces these with actual runtime paths before
291
+ executing against DuckDB.
finetune/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
finetune/check_token_lengths.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Check token lengths of training samples to validate max_length setting.
2
+
3
+ Usage
4
+ -----
5
+ modal run finetune/check_token_lengths.py
6
+ modal run finetune/check_token_lengths.py --run-dir /mnt/gazet/data/v1
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import modal
12
+
13
+ app = modal.App("gazet-check-token-lengths")
14
+
15
+ check_image = (
16
+ modal.Image.debian_slim(python_version="3.11")
17
+ .pip_install(
18
+ "datasets>=3.0",
19
+ "transformers>=4.46",
20
+ "jinja2>=3.1",
21
+ )
22
+ .add_local_python_source("finetune", copy=True)
23
+ .env({"HF_HOME": "/mnt/gazet/model_cache"})
24
+ )
25
+
26
+ gazet_vol = modal.Volume.from_name("gazet", create_if_missing=True)
27
+
28
+ VOLUMES = {
29
+ "/mnt/gazet": gazet_vol,
30
+ }
31
+
32
+
33
+ @app.function(
34
+ image=check_image,
35
+ volumes=VOLUMES,
36
+ secrets=[modal.Secret.from_name("huggingface-secret")],
37
+ )
38
+ def analyze_token_lengths(run_dir: str, base_model: str):
39
+ import json
40
+ import pathlib
41
+
42
+ from datasets import Dataset, DatasetDict
43
+ from transformers import AutoTokenizer
44
+
45
+ def load_jsonl(path):
46
+ rows = []
47
+ with open(path) as f:
48
+ for line in f:
49
+ line = line.strip()
50
+ if line:
51
+ rows.append(json.loads(line))
52
+ return rows
53
+
54
+ print(f"Loading tokenizer: {base_model}")
55
+ tokenizer = AutoTokenizer.from_pretrained(base_model)
56
+
57
+ root = pathlib.Path(run_dir)
58
+ ds_dict = {}
59
+ for split in ("train", "val", "test"):
60
+ combined = []
61
+ for task in ("sql", "places"):
62
+ path = root / task / f"{split}.jsonl"
63
+ if path.exists():
64
+ combined.extend(load_jsonl(path))
65
+ if combined:
66
+ ds_dict[split] = Dataset.from_list(combined)
67
+ ds = DatasetDict(ds_dict)
68
+
69
+ def token_lengths(dataset):
70
+ lengths = []
71
+ for row in dataset:
72
+ msgs = row["messages"]
73
+ text = tokenizer.apply_chat_template(msgs, tokenize=False)
74
+ lengths.append(len(tokenizer.encode(text)))
75
+ return lengths
76
+
77
+ def report(split_name: str, lengths: list[int]):
78
+ lengths.sort()
79
+ n = len(lengths)
80
+ if not n:
81
+ print(f"\n{split_name}: empty")
82
+ return
83
+
84
+ print(f"\n{'='*60}")
85
+ print(f"{split_name} ({n:,} samples)")
86
+ print(f"{'='*60}")
87
+ print(f" Min: {min(lengths):,}")
88
+ print(f" Max: {max(lengths):,}")
89
+ print(f" Mean: {sum(lengths)/n:.0f}")
90
+ print(f" Median: {lengths[n//2]:,}")
91
+ print(f" P90: {lengths[int(n*0.90)]:,}")
92
+ print(f" P95: {lengths[int(n*0.95)]:,}")
93
+ print(f" P99: {lengths[int(n*0.99)]:,}")
94
+
95
+ buckets = [512, 1024, 2048, 4096, 8192]
96
+ print(f"\n Distribution:")
97
+ prev = 0
98
+ for limit in buckets:
99
+ count = sum(1 for l in lengths if prev < l <= limit)
100
+ pct = 100 * count / n
101
+ bar = "#" * int(pct / 2)
102
+ print(f" {prev+1:>5}-{limit:<5}: {count:5,} ({pct:5.1f}%) {bar}")
103
+ prev = limit
104
+ over = sum(1 for l in lengths if l > buckets[-1])
105
+ if over:
106
+ print(f" {buckets[-1]+1:>5}+ : {over:5,} ({100*over/n:5.1f}%)")
107
+
108
+ return lengths
109
+
110
+ all_lengths = []
111
+ for split in ("train", "val", "test"):
112
+ if split not in ds:
113
+ continue
114
+ lengths = token_lengths(ds[split])
115
+ report(split, lengths)
116
+ all_lengths.extend(lengths)
117
+
118
+ if all_lengths:
119
+ all_lengths.sort()
120
+ n = len(all_lengths)
121
+ max_len = max(all_lengths)
122
+ p99 = all_lengths[int(n * 0.99)]
123
+
124
+ print(f"\n{'='*60}")
125
+ print(f"RECOMMENDATION")
126
+ print(f"{'='*60}")
127
+ print(f" Total samples: {n:,}")
128
+ print(f" Max length: {max_len:,}")
129
+ print(f" P99: {p99:,}")
130
+
131
+ for threshold in [1024, 2048, 4096]:
132
+ over = sum(1 for l in all_lengths if l > threshold)
133
+ pct = 100 * over / n
134
+ print(f" > {threshold:5,}: {over:5,} ({pct:5.1f}%)")
135
+
136
+ if max_len <= 1024:
137
+ print(f"\n All samples fit in 1024 tokens. Use --max-length 1024.")
138
+ elif max_len <= 2048:
139
+ print(f"\n All samples fit in 2048 tokens. Use --max-length 2048.")
140
+ else:
141
+ over_2048 = sum(1 for l in all_lengths if l > 2048)
142
+ print(f"\n {over_2048} samples exceed 2048. Consider --max-length {max_len}")
143
+ print(f" or reduce candidate count to keep samples shorter.")
144
+
145
+
146
+ @app.local_entrypoint()
147
+ def main(
148
+ run_dir: str = "/mnt/gazet/data/v1",
149
+ base_model: str = "unsloth/Qwen3.5-0.8B",
150
+ ):
151
+ print(f"Checking token lengths:")
152
+ print(f" Model: {base_model}")
153
+ print(f" Run dir: {run_dir}")
154
+ analyze_token_lengths.remote(run_dir, base_model)
155
+ print("Analysis complete!")
finetune/eval_cli.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Interactive eval: run test samples through the local GGUF model.
2
+
3
+ Requires llama-server running on port 8080:
4
+ llama-server -m finetune/models/<model>.gguf -ngl 99 --port 8080 --ctx-size 4096 --log-disable
5
+
6
+ Uses the /v1/chat/completions endpoint with a messages list. The Qwen3 GGUF
7
+ embeds its chat template in metadata, so llama-server applies it automatically.
8
+
9
+ Usage
10
+ -----
11
+ uv run finetune/eval_cli.py # prompts for index
12
+ uv run finetune/eval_cli.py 5 # run sample at index 5
13
+ uv run finetune/eval_cli.py 5 12 20 # run multiple samples
14
+
15
+ Use --task places for place extraction:
16
+ uv run finetune/eval_cli.py --task places 0 5
17
+
18
+ Override run directory:
19
+ uv run finetune/eval_cli.py --run-dir dataset/output/runs/v1 0
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import argparse
25
+ import json
26
+ import sys
27
+ import urllib.error
28
+ import urllib.request
29
+ from concurrent.futures import ThreadPoolExecutor, as_completed
30
+ from datetime import datetime
31
+ from pathlib import Path
32
+
33
+ SERVER_URL = "http://localhost:9000"
34
+ MAX_TOKENS = 2048
35
+ TEMPERATURE = 0.6
36
+
37
+ DEFAULT_RUN_DIR = Path("dataset/output/runs/v1")
38
+
39
+
40
+ def postprocess_sql(text: str) -> str:
41
+ cleaned = text.strip()
42
+ if "```sql" in cleaned:
43
+ cleaned = cleaned.split("```sql", 1)[1]
44
+ if cleaned.startswith("```"):
45
+ cleaned = cleaned[3:]
46
+ if "```" in cleaned:
47
+ cleaned = cleaned.split("```", 1)[0]
48
+ return cleaned.strip()
49
+
50
+
51
+ def check_server() -> bool:
52
+ try:
53
+ urllib.request.urlopen(f"{SERVER_URL}/health", timeout=2)
54
+ return True
55
+ except Exception:
56
+ return False
57
+
58
+
59
+ def chat_complete(messages: list[dict]) -> str:
60
+ """Call llama-server /v1/chat/completions with a messages list."""
61
+ payload = json.dumps({
62
+ "messages": messages,
63
+ "n_predict": MAX_TOKENS,
64
+ "temperature": TEMPERATURE,
65
+ "chat_template_kwargs": {"enable_thinking": False},
66
+ }).encode()
67
+
68
+ req = urllib.request.Request(
69
+ f"{SERVER_URL}/v1/chat/completions",
70
+ data=payload,
71
+ headers={"Content-Type": "application/json"},
72
+ )
73
+ with urllib.request.urlopen(req, timeout=60) as resp:
74
+ return json.loads(resp.read())["choices"][0]["message"]["content"]
75
+
76
+
77
+ def load_samples(run_dir: Path, task: str, split: str = "val") -> list[dict]:
78
+ path = run_dir / task / f"{split}.jsonl"
79
+ if not path.exists():
80
+ print(f"Error: {path} not found")
81
+ sys.exit(1)
82
+ print(f"Loading {task} samples from: {path}")
83
+ with path.open() as f:
84
+ return [json.loads(line) for line in f if line.strip()]
85
+
86
+
87
+ def build_raw_prompt(sample: dict) -> str:
88
+ """Reconstruct the plain prompt string from messages format (all turns except assistant)."""
89
+ return "\n\n".join(m["content"] for m in sample["messages"][:-1])
90
+
91
+
92
+ def eval_sample(sample: dict, task: str) -> dict:
93
+ """Run a single sample through the server and return a result dict."""
94
+ expected = sample["messages"][-1]["content"]
95
+ messages = sample["messages"][:-1]
96
+
97
+ user_content = sample["messages"][-2]["content"]
98
+ if "<USER_QUERY>" in user_content:
99
+ question = user_content.split("<USER_QUERY>")[-1].split("</USER_QUERY>")[0].strip()
100
+ else:
101
+ question = user_content[:120]
102
+
103
+ raw = chat_complete(messages)
104
+ predicted = postprocess_sql(raw) if task == "sql" else raw.strip()
105
+ return {
106
+ "question": question,
107
+ "expected": expected,
108
+ "predicted": predicted,
109
+ "exact_match": predicted.strip() == expected.strip(),
110
+ }
111
+
112
+
113
+ def run_sample(sample: dict, task: str, total: int, index: int, verbose: bool = False) -> None:
114
+ user_content = sample["messages"][-2]["content"]
115
+ if "<USER_QUERY>" in user_content:
116
+ question = user_content.split("<USER_QUERY>")[-1].split("</USER_QUERY>")[0].strip()
117
+ else:
118
+ question = user_content[:120]
119
+
120
+ header = f" Sample {index}/{total-1} | {task} "
121
+ print(f"\n{'━' * 60}")
122
+ print(f"{'━' * ((60 - len(header)) // 2)}{header}{'━' * ((60 - len(header)) // 2)}")
123
+ print(f"{'━' * 60}")
124
+ print(f"\nQuestion: {question}\n")
125
+
126
+ if verbose:
127
+ prompt = build_raw_prompt(sample)
128
+ print(f"{'─' * 60}")
129
+ print(f"Full prompt ({len(prompt)} chars, ~{len(prompt.split())} words):")
130
+ print(f"{'─' * 60}")
131
+ print(prompt)
132
+
133
+ result = eval_sample(sample, task)
134
+
135
+ print(f"{'─' * 60}")
136
+ print("Expected:")
137
+ print(f"{'─' * 60}")
138
+ print(result["expected"])
139
+
140
+ print(f"\n{'─' * 60}")
141
+ print("Generated:")
142
+ print(f"{'─' * 60}")
143
+ print(result["predicted"])
144
+
145
+ print(f"\n{'─' * 60}")
146
+ print(f"Match: {'YES' if result['exact_match'] else 'NO'}")
147
+
148
+
149
+ def run_batch(
150
+ samples: list[dict],
151
+ task: str,
152
+ label: str,
153
+ output_path: Path,
154
+ workers: int = 8,
155
+ ) -> None:
156
+ """Run all samples concurrently and save results to a JSON file."""
157
+ total = len(samples)
158
+ results = [None] * total
159
+ completed = 0
160
+
161
+ with ThreadPoolExecutor(max_workers=workers) as executor:
162
+ futures = {executor.submit(eval_sample, s, task): i for i, s in enumerate(samples)}
163
+ for future in as_completed(futures):
164
+ i = futures[future]
165
+ result = future.result()
166
+ results[i] = {"index": i, **result}
167
+ completed += 1
168
+ if completed % 50 == 0 or completed == total:
169
+ print(f"{completed}/{total} done", flush=True)
170
+
171
+ matches = sum(1 for r in results if r["exact_match"])
172
+ exact_match_rate = matches / total if total else 0
173
+
174
+ output = {
175
+ "summary": {
176
+ "label": label,
177
+ "task": task,
178
+ "num_samples": total,
179
+ "exact_matches": matches,
180
+ "exact_match_rate": exact_match_rate,
181
+ "timestamp": datetime.now().isoformat(),
182
+ },
183
+ "results": results,
184
+ }
185
+
186
+ output_path.parent.mkdir(parents=True, exist_ok=True)
187
+ with output_path.open("w") as f:
188
+ json.dump(output, f, indent=2)
189
+
190
+ print(f"\n{'=' * 60}")
191
+ print(f"[{label}] {matches}/{total} exact matches ({100 * exact_match_rate:.1f}%)")
192
+ print(f"Results saved to {output_path}")
193
+ print(f"{'=' * 60}")
194
+
195
+
196
+ def main() -> None:
197
+ parser = argparse.ArgumentParser(description="Interactive eval against llama-server")
198
+ parser.add_argument("indices", nargs="*", type=int, help="Sample indices to evaluate")
199
+ parser.add_argument("--task", default="sql", choices=["sql", "places"])
200
+ parser.add_argument(
201
+ "--run-dir",
202
+ type=Path,
203
+ default=DEFAULT_RUN_DIR,
204
+ help="Run directory containing {task}/{split}.jsonl files",
205
+ )
206
+ parser.add_argument("--split", default="val", choices=["val", "test"], help="Dataset split")
207
+ parser.add_argument("--verbose", "-v", action="store_true", help="Print full prompt sent to the model")
208
+ parser.add_argument("--all", dest="run_all", action="store_true", help="Run all samples in batch mode")
209
+ parser.add_argument("--max-samples", type=int, default=None, help="Limit number of samples (batch mode)")
210
+ parser.add_argument("--label", default="local-gguf", help="Label for batch output file")
211
+ parser.add_argument("--output", type=Path, default=None, help="Output JSON path (batch mode)")
212
+ parser.add_argument("--workers", type=int, default=4, help="Concurrent requests; match llama-server --parallel (default 4)")
213
+ args = parser.parse_args()
214
+
215
+ if not check_server():
216
+ print("llama-server not running. Start it with:")
217
+ print("llama-server -m finetune/models/<model>.gguf -ngl 99 --port 9000 --ctx-size 2048 --log-disable")
218
+ sys.exit(1)
219
+
220
+ samples = load_samples(args.run_dir, args.task, args.split)
221
+ total = len(samples)
222
+
223
+ if args.run_all:
224
+ if args.max_samples:
225
+ samples = samples[: args.max_samples]
226
+ output_path = args.output or Path(f"eval-{args.label}-{args.task}.json")
227
+ print(f"Running batch eval: {len(samples)} samples, {args.workers} workers")
228
+ run_batch(samples, args.task, args.label, output_path, workers=args.workers)
229
+ return
230
+
231
+ if not args.indices:
232
+ print(f"Test set has {total} {args.task} samples (0-{total-1})")
233
+ raw = input("Enter index (or press Enter for 0): ").strip()
234
+ indices = [int(raw) if raw else 0]
235
+ else:
236
+ indices = args.indices
237
+
238
+ for idx in indices:
239
+ if not (0 <= idx < total):
240
+ print(f"Index {idx} out of range (0-{total-1}), skipping")
241
+ continue
242
+ run_sample(samples[idx], args.task, total, idx, verbose=args.verbose)
243
+
244
+ print(f"\n{'━' * 60}\n")
245
+
246
+
247
+ if __name__ == "__main__":
248
+ main()
finetune/eval_demo.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Streamlit eval viewer: compare expected vs predicted SQL and view results on a map.
2
+
3
+ Usage: streamlit run finetune/eval_demo.py
4
+ """
5
+
6
+ import difflib
7
+ import json
8
+ import math
9
+ import os
10
+ import pathlib
11
+
12
+ import duckdb
13
+ import numpy as np
14
+ import pandas as pd
15
+ import pydeck as pdk
16
+ import sqlparse
17
+ import streamlit as st
18
+
19
+ PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
20
+ DATA_DIR = pathlib.Path(
21
+ os.environ.get("GAZET_DATA_DIR", str(PROJECT_ROOT / "data"))
22
+ )
23
+ EVAL_DIR = pathlib.Path(
24
+ os.environ.get("GAZET_EVAL_DIR", str(PROJECT_ROOT / "results"))
25
+ )
26
+
27
+
28
+ def load_eval_results(path):
29
+ with open(path) as f:
30
+ return json.load(f)
31
+
32
+
33
+ def rewrite_data_paths(sql):
34
+ """Replace symbolic and legacy paths with actual local data paths."""
35
+ # Legacy fixed Docker paths must be replaced first to avoid double-expansion
36
+ sql = sql.replace("/data/", f"{DATA_DIR}/")
37
+ div_path = str(DATA_DIR / "overture" / "divisions_area" / "*.parquet")
38
+ ne_path = str(DATA_DIR / "natural_earth_geoparquet" / "ne_geography.parquet")
39
+ sql = sql.replace("read_parquet('divisions_area')", f"read_parquet('{div_path}')")
40
+ sql = sql.replace("read_parquet('natural_earth')", f"read_parquet('{ne_path}')")
41
+ return sql
42
+
43
+
44
+ def format_sql(sql):
45
+ """Pretty-print SQL with sqlparse."""
46
+ return sqlparse.format(sql, reindent=True, keyword_case="upper")
47
+
48
+
49
+ def sql_diff_html(expected, predicted):
50
+ """Return an HTML diff of two SQL strings."""
51
+ expected_lines = format_sql(expected).splitlines()
52
+ predicted_lines = format_sql(predicted).splitlines()
53
+ diff = difflib.HtmlDiff(tabsize=2, wrapcolumn=80)
54
+ return diff.make_table(
55
+ expected_lines, predicted_lines,
56
+ fromdesc="Expected", todesc="Predicted",
57
+ context=False,
58
+ )
59
+
60
+
61
+ def get_duckdb_connection():
62
+ con = duckdb.connect()
63
+ con.execute("INSTALL spatial")
64
+ con.execute("LOAD spatial")
65
+ return con
66
+
67
+
68
+ def execute_sql(con, sql):
69
+ """Execute SQL, converting geometry columns to simplified GeoJSON strings."""
70
+ rel = con.sql(sql)
71
+ cols = rel.columns
72
+ types = [str(t) for t in rel.dtypes]
73
+
74
+ select_parts = []
75
+ for col, dtype in zip(cols, types):
76
+ if "GEOMETRY" in dtype.upper():
77
+ select_parts.append(
78
+ f'ST_AsGeoJSON(ST_SimplifyPreserveTopology("{col}", 0.001)) AS "{col}"'
79
+ )
80
+ else:
81
+ select_parts.append(f'"{col}"')
82
+
83
+ wrapped = f"SELECT {', '.join(select_parts)} FROM ({sql})"
84
+ return con.execute(wrapped).fetchdf()
85
+
86
+
87
+ def _is_notna(val):
88
+ """Check if a value is not NA, handling arrays/lists/numpy arrays safely."""
89
+ if isinstance(val, (list, tuple, np.ndarray)):
90
+ return len(val) > 0
91
+ return pd.notna(val)
92
+
93
+
94
+ def _to_python(val):
95
+ """Convert numpy/pandas types to native Python for JSON serialization."""
96
+ if isinstance(val, (np.integer,)):
97
+ return int(val)
98
+ if isinstance(val, (np.floating,)):
99
+ return float(val)
100
+ if isinstance(val, np.ndarray):
101
+ return val.tolist()
102
+ if isinstance(val, (np.bool_,)):
103
+ return bool(val)
104
+ return val
105
+
106
+
107
+ def to_feature_collection(result_df):
108
+ """Build GeoJSON FeatureCollection from a DataFrame with GeoJSON string columns."""
109
+ geom_cols = []
110
+ for c in result_df.columns:
111
+ vals = [v for v in result_df[c].head(5) if isinstance(v, str)]
112
+ if vals and all(v.lstrip().startswith('{"type":') for v in vals):
113
+ geom_cols.append(c)
114
+
115
+ prop_cols = [c for c in result_df.columns if c not in geom_cols]
116
+ features = []
117
+ for _, row in result_df.iterrows():
118
+ geometry = None
119
+ if geom_cols:
120
+ raw = row[geom_cols[0]]
121
+ if raw and isinstance(raw, str):
122
+ geometry = json.loads(raw)
123
+ properties = {}
124
+ for c in prop_cols:
125
+ val = row[c]
126
+ if _is_notna(val):
127
+ properties[c] = _to_python(val)
128
+ features.append(
129
+ {"type": "Feature", "geometry": geometry, "properties": properties}
130
+ )
131
+ return {"type": "FeatureCollection", "features": features}
132
+
133
+
134
+ def bbox_from_geojson(geojson):
135
+ lngs, lats = [], []
136
+ for f in geojson.get("features", []):
137
+ geom = f.get("geometry")
138
+ if geom:
139
+ for coord in _extract_coords(geom):
140
+ lngs.append(coord[0])
141
+ lats.append(coord[1])
142
+ if not lngs:
143
+ return None
144
+ return min(lngs), min(lats), max(lngs), max(lats)
145
+
146
+
147
+ def _extract_coords(geom):
148
+ t = geom.get("type", "")
149
+ coords = geom.get("coordinates", [])
150
+ if t == "Point":
151
+ yield coords
152
+ elif t in ("LineString", "MultiPoint"):
153
+ yield from coords
154
+ elif t == "Polygon":
155
+ for ring in coords:
156
+ yield from ring
157
+ elif t in ("MultiLineString", "MultiPolygon"):
158
+ for part in coords:
159
+ if t == "MultiLineString":
160
+ yield from part
161
+ else:
162
+ for ring in part:
163
+ yield from ring
164
+ elif t == "GeometryCollection":
165
+ for g in geom.get("geometries", []):
166
+ yield from _extract_coords(g)
167
+
168
+
169
+ def _centroids_from_geojson(geojson):
170
+ """Extract centroid [lng, lat] for each feature to use as scatter markers."""
171
+ centroids = []
172
+ for f in geojson.get("features", []):
173
+ geom = f.get("geometry")
174
+ if not geom:
175
+ continue
176
+ lngs, lats = [], []
177
+ for coord in _extract_coords(geom):
178
+ lngs.append(coord[0])
179
+ lats.append(coord[1])
180
+ if lngs:
181
+ centroids.append({"lng": sum(lngs) / len(lngs), "lat": sum(lats) / len(lats)})
182
+ return centroids
183
+
184
+
185
+ def render_map(geojson, color, key):
186
+ n = len(geojson.get("features", []))
187
+ if not n:
188
+ st.info("Query returned no features.")
189
+ return
190
+
191
+ layers = [
192
+ pdk.Layer(
193
+ "GeoJsonLayer",
194
+ data=geojson,
195
+ get_fill_color=color,
196
+ get_line_color=[100, 100, 100, 200],
197
+ get_line_width=2,
198
+ pickable=True,
199
+ ),
200
+ ]
201
+
202
+ bbox = bbox_from_geojson(geojson)
203
+ if bbox:
204
+ min_lng, min_lat, max_lng, max_lat = bbox
205
+ span = max(max_lng - min_lng, max_lat - min_lat, 1e-6)
206
+ zoom = max(0, min(18, math.log2(360 / span) - 0.8))
207
+
208
+ # Add scatter markers when polygons would be too small to see
209
+ if zoom < 4:
210
+ centroids = _centroids_from_geojson(geojson)
211
+ if centroids:
212
+ layers.append(
213
+ pdk.Layer(
214
+ "ScatterplotLayer",
215
+ data=centroids,
216
+ get_position=["lng", "lat"],
217
+ get_fill_color=color[:3] + [220],
218
+ get_radius=50000,
219
+ radius_min_pixels=6,
220
+ pickable=True,
221
+ )
222
+ )
223
+
224
+ view = pdk.ViewState(
225
+ latitude=(min_lat + max_lat) / 2,
226
+ longitude=(min_lng + max_lng) / 2,
227
+ zoom=zoom,
228
+ )
229
+ else:
230
+ view = pdk.ViewState(latitude=0, longitude=0, zoom=1)
231
+
232
+ st.pydeck_chart(
233
+ pdk.Deck(layers=layers, initial_view_state=view, map_style=None),
234
+ width="stretch",
235
+ height=400,
236
+ key=key,
237
+ )
238
+
239
+
240
+ # --- App ---
241
+
242
+ st.set_page_config(page_title="Eval Viewer", layout="wide")
243
+ st.title("Eval Viewer")
244
+
245
+ eval_files = sorted(EVAL_DIR.glob("eval-*.json"))
246
+ if not eval_files:
247
+ st.error(f"No eval result files found in {EVAL_DIR}")
248
+ st.stop()
249
+
250
+ selected_file = st.sidebar.selectbox(
251
+ "Eval file",
252
+ eval_files,
253
+ format_func=lambda p: p.stem,
254
+ )
255
+
256
+ data = load_eval_results(selected_file)
257
+ summary = data["summary"]
258
+ results = data["results"]
259
+
260
+ st.sidebar.markdown(f"""
261
+ **Model**: `{summary.get('label', '')}`
262
+ **Exact match**: {summary['exact_matches']}/{summary['num_samples']} ({summary['exact_match_rate']:.1%})
263
+ """)
264
+
265
+ filter_option = st.sidebar.radio("Filter", ["All", "Matches only", "Mismatches only"])
266
+ if filter_option == "Matches only":
267
+ results = [r for r in results if r["exact_match"]]
268
+ elif filter_option == "Mismatches only":
269
+ results = [r for r in results if not r["exact_match"]]
270
+
271
+ if not results:
272
+ st.warning("No results match the current filter.")
273
+ st.stop()
274
+
275
+ questions = [
276
+ f"[{r['index']}] {r.get('question', 'Sample ' + str(r['index']))}"
277
+ for r in results
278
+ ]
279
+ selected_idx = st.selectbox("Select a query", range(len(questions)), format_func=lambda i: questions[i])
280
+ row = results[selected_idx]
281
+
282
+ match_label = "MATCH" if row["exact_match"] else "MISMATCH"
283
+ match_color = "green" if row["exact_match"] else "red"
284
+ st.markdown(f"### :{match_color}[{match_label}]")
285
+
286
+ is_sql = summary.get("task", "sql") == "sql"
287
+ expected = row["expected"]
288
+ predicted = row["predicted"]
289
+
290
+ # Formatted output side-by-side
291
+ col_expected, col_predicted = st.columns(2)
292
+ with col_expected:
293
+ st.markdown("**Expected**")
294
+ if is_sql:
295
+ st.code(format_sql(expected), language="sql")
296
+ else:
297
+ st.code(expected, language="json")
298
+ with col_predicted:
299
+ st.markdown("**Predicted**")
300
+ if is_sql:
301
+ st.code(format_sql(predicted), language="sql")
302
+ else:
303
+ st.code(predicted, language="json")
304
+
305
+ # Diff view
306
+ if not row["exact_match"]:
307
+ with st.expander("Diff", expanded=True):
308
+ diff_html = sql_diff_html(expected, predicted)
309
+ diff_css = """
310
+ <style>
311
+ .diff_add { background-color: rgba(40, 167, 69, 0.15); }
312
+ .diff_sub { background-color: rgba(220, 53, 69, 0.15); }
313
+ .diff_chg { background-color: rgba(255, 193, 7, 0.15); }
314
+ .diff_header { background-color: rgba(128, 128, 128, 0.1); font-weight: bold; }
315
+ table.diff { border-collapse: collapse; width: 100%; font-family: monospace; color: inherit; }
316
+ table.diff td, table.diff th { padding: 4px 8px; border: 1px solid rgba(128, 128, 128, 0.2); }
317
+ </style>
318
+ """
319
+ st.html(f"{diff_css}<div style='overflow-x:auto; font-size:13px;'>{diff_html}</div>")
320
+
321
+ # Auto-execute SQL and show maps (only for sql task)
322
+ if is_sql:
323
+ con = get_duckdb_connection()
324
+
325
+ map_col1, map_col2 = st.columns(2)
326
+
327
+ with map_col1:
328
+ st.markdown("**Expected result**")
329
+ sql = rewrite_data_paths(expected)
330
+ try:
331
+ df = execute_sql(con, sql)
332
+ geojson = to_feature_collection(df)
333
+ render_map(geojson, [40, 180, 160, 140], key="map_expected")
334
+ with st.expander("Result table"):
335
+ st.dataframe(df, width="stretch")
336
+ except Exception as e:
337
+ st.error(f"Execution error: {e}")
338
+
339
+ with map_col2:
340
+ st.markdown("**Predicted result**")
341
+ sql = rewrite_data_paths(predicted)
342
+ try:
343
+ df = execute_sql(con, sql)
344
+ geojson = to_feature_collection(df)
345
+ render_map(geojson, [180, 80, 60, 140], key="map_predicted")
346
+ with st.expander("Result table"):
347
+ st.dataframe(df, width="stretch")
348
+ except Exception as e:
349
+ st.error(f"Execution error: {e}")
350
+
351
+ con.close()
finetune/train_modal_qwen35.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modal training script for gazet Qwen3.5 LoRA fine-tuning with Unsloth.
2
+
3
+ Key differences from train_modal.py (Gemma):
4
+ - Uses Unsloth's FastLanguageModel for memory-efficient training
5
+ - Applies Qwen3.5 chat template to format data (not plain prompt+completion strings)
6
+ - Uses train_on_responses_only with ChatML markers to mask non-assistant tokens
7
+ - Saves merged 16-bit model via unsloth's save_pretrained_merged
8
+
9
+ Usage
10
+ -----
11
+ modal run finetune/train_modal_qwen35.py
12
+ modal run finetune/train_modal_qwen35.py --base-model unsloth/Qwen3.5-0.8B
13
+ modal run finetune/train_modal_qwen35.py --run-dir /mnt/gazet/data/v3-symbolic-paths
14
+ modal run finetune/train_modal_qwen35.py --num-train-epochs 5 --lora-r 32
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import pathlib
20
+ from dataclasses import dataclass
21
+ from datetime import datetime
22
+ from typing import Optional
23
+
24
+ import modal
25
+
26
+ app = modal.App("gazet-nlg-qwen35-finetune-v2")
27
+
28
+ GPU_TYPE = "A100-80GB"
29
+ TIMEOUT_HOURS = 24
30
+ MAX_RETRIES = 1
31
+
32
+ train_image = (
33
+ modal.Image.debian_slim(python_version="3.11")
34
+ .pip_install(
35
+ # Use unsloth's bundled CUDA+torch extra so bitsandbytes, xformers,
36
+ # and trl are all resolved together against the same CUDA/torch build.
37
+ # Mirrors the approach in https://modal.com/docs/examples/unsloth_finetune
38
+ "unsloth[cu129-torch280]",
39
+ "unsloth_zoo",
40
+ "transformers~=5.2.0",
41
+ "hf-transfer==0.1.9",
42
+ "trackio[gpu]==0.21.1",
43
+ "datasets",
44
+ "pandas",
45
+ )
46
+ .add_local_python_source("finetune", copy=True)
47
+ .env({
48
+ "HF_HOME": "/mnt/gazet/model_cache",
49
+ "HF_HUB_ENABLE_HF_TRANSFER": "1",
50
+ })
51
+ )
52
+
53
+ with train_image.imports():
54
+ from unsloth import FastLanguageModel
55
+ from unsloth.chat_templates import train_on_responses_only
56
+ from trl import SFTConfig, SFTTrainer
57
+ from transformers import set_seed
58
+
59
+ gazet_vol = modal.Volume.from_name("gazet", create_if_missing=True)
60
+
61
+ VOLUMES = {
62
+ "/mnt/gazet": gazet_vol,
63
+ }
64
+
65
+ # ChatML response markers for Qwen3.5 — the empty <think> block is how Qwen3.5
66
+ # formats non-thinking responses. We train only on tokens after this prefix.
67
+ INSTRUCTION_PART = "<|im_start|>user\n"
68
+ RESPONSE_PART = "<|im_start|>assistant\n<think>\n\n</think>\n\n"
69
+
70
+
71
+ @dataclass
72
+ class Qwen35Config:
73
+ # Model
74
+ base_model: str = "unsloth/Qwen3.5-0.8B"
75
+
76
+ # Dataset — path to run dir with {task}/{split}.jsonl files
77
+ run_dir: str = "/mnt/gazet/data/v1"
78
+ max_train_samples: Optional[int] = None
79
+ max_eval_samples: Optional[int] = None
80
+
81
+ # Sequence
82
+ max_seq_length: int = 2048
83
+
84
+ # LoRA — alpha=2*r follows unsloth recommendation for Qwen models
85
+ lora_r: int = 16
86
+ lora_alpha: int = 32
87
+ lora_dropout: float = 0.0
88
+
89
+ # Training
90
+ num_train_epochs: int = 1
91
+ per_device_train_batch_size: int = 32
92
+ per_device_eval_batch_size: int = 16
93
+ gradient_accumulation_steps: int = 1 # effective batch = 48
94
+ learning_rate: float = 1e-4
95
+ max_grad_norm: float = 1.0
96
+ warmup_steps: int = 50
97
+ lr_scheduler_type: str = "linear"
98
+ weight_decay: float = 0.01
99
+ optim: str = "adamw_8bit"
100
+
101
+ # Logging / saving
102
+ logging_steps: int = 10
103
+ save_strategy: str = "steps"
104
+ save_steps: int = 400
105
+ eval_strategy: str = "steps"
106
+ eval_steps: int = 200
107
+ report_to: str = "trackio"
108
+ trackio_space_id: Optional[str] = "srmsoumya/gazet-trackio"
109
+ project: str = "gazet-nlg-qwen35"
110
+
111
+ # Experiment
112
+ seed: int = 42
113
+ experiment_name: Optional[str] = None
114
+
115
+ def __post_init__(self):
116
+ if self.experiment_name is None:
117
+ timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
118
+ model_short = self.base_model.split("/")[-1]
119
+ self.experiment_name = f"{model_short}-r{self.lora_r}-{timestamp}"
120
+
121
+
122
+ def _load_data(run_dir: str, tokenizer, max_train_samples=None, max_eval_samples=None):
123
+ """Load JSONL data and apply Qwen3.5 chat template.
124
+
125
+ Each sample must have:
126
+ messages: list of {role, content} dicts (system + user + assistant)
127
+
128
+ The chat template produces the full ChatML string including the assistant turn.
129
+ train_on_responses_only then masks everything except the assistant response.
130
+ """
131
+ import json
132
+ from datasets import Dataset, DatasetDict
133
+
134
+ def load_jsonl(path: pathlib.Path) -> list[dict]:
135
+ rows = []
136
+ with open(path) as f:
137
+ for line in f:
138
+ line = line.strip()
139
+ if line:
140
+ rows.append(json.loads(line))
141
+ return rows
142
+
143
+ def to_message(sample: dict) -> dict:
144
+ text = tokenizer.apply_chat_template(
145
+ sample["messages"],
146
+ tokenize=False,
147
+ add_generation_prompt=False,
148
+ )
149
+ return {"messages": text}
150
+
151
+ run_dir = pathlib.Path(run_dir)
152
+ tasks = ("sql", "places")
153
+ splits = ("train", "val")
154
+ ds_dict: dict = {}
155
+
156
+ for split in splits:
157
+ combined: list[dict] = []
158
+ for task in tasks:
159
+ path = run_dir / task / f"{split}.jsonl"
160
+ if not path.exists():
161
+ print(f"Missing {path} — skipping")
162
+ continue
163
+ rows = load_jsonl(path)
164
+ flattened = [to_message(r) for r in rows]
165
+ combined.extend(flattened)
166
+ print(f"Loaded {len(rows):,} {task}/{split} rows")
167
+
168
+ if combined:
169
+ ds_dict[split] = Dataset.from_list(combined)
170
+ print(f"{split} split: {len(combined):,} total rows")
171
+
172
+ ds = DatasetDict(ds_dict).shuffle(seed=42)
173
+
174
+ if max_train_samples is not None and "train" in ds:
175
+ ds["train"] = ds["train"].select(range(min(max_train_samples, len(ds["train"]))))
176
+ if max_eval_samples is not None and "val" in ds:
177
+ ds["val"] = ds["val"].select(range(min(max_eval_samples, len(ds["val"]))))
178
+
179
+ return ds
180
+
181
+
182
+ def _find_latest_checkpoint(checkpoint_dir: pathlib.Path) -> str | None:
183
+ if not checkpoint_dir.exists():
184
+ return None
185
+ checkpoints = list(checkpoint_dir.glob("checkpoint-*"))
186
+ if not checkpoints:
187
+ return None
188
+ latest = max(checkpoints, key=lambda p: int(p.name.split("-")[1]))
189
+ print(f"Found existing checkpoint: {latest}")
190
+ return str(latest)
191
+
192
+
193
+ @app.function(
194
+ image=train_image,
195
+ gpu=GPU_TYPE,
196
+ volumes=VOLUMES,
197
+ secrets=[modal.Secret.from_name("huggingface-secret")],
198
+ timeout=TIMEOUT_HOURS * 60 * 60,
199
+ retries=modal.Retries(initial_delay=0.0, max_retries=MAX_RETRIES),
200
+ )
201
+ def finetune(config_dict: dict):
202
+ """Run Qwen3.5 LoRA SFT training with Unsloth inside a Modal container."""
203
+ config = Qwen35Config(**config_dict)
204
+ set_seed(config.seed)
205
+
206
+ experiment_dir = pathlib.Path("/mnt/gazet/checkpoints") / config.experiment_name
207
+ experiment_dir.mkdir(parents=True, exist_ok=True)
208
+
209
+ print(f"Experiment: {config.experiment_name}")
210
+ print(f"Model: {config.base_model}")
211
+ print(f"Run dir: {config.run_dir}")
212
+
213
+ # Load base model with unsloth — gradient checkpointing is handled internally
214
+ model, processor = FastLanguageModel.from_pretrained(
215
+ config.base_model,
216
+ max_seq_length=config.max_seq_length,
217
+ load_in_4bit=False,
218
+ use_gradient_checkpointing="unsloth",
219
+ fast_inference=False,
220
+ )
221
+ tokenizer = processor.tokenizer
222
+
223
+ # Apply LoRA adapters — let unsloth select target modules via finetune_* flags
224
+ model = FastLanguageModel.get_peft_model(
225
+ model,
226
+ r=config.lora_r,
227
+ lora_alpha=config.lora_alpha,
228
+ lora_dropout=config.lora_dropout,
229
+ finetune_vision_layers=False,
230
+ finetune_language_layers=True,
231
+ finetune_attention_modules=True,
232
+ finetune_mlp_modules=True,
233
+ bias="none",
234
+ random_state=config.seed,
235
+ use_gradient_checkpointing=False, # already set in from_pretrained
236
+ )
237
+
238
+ total_params = sum(p.numel() for p in model.parameters())
239
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
240
+ print(f"Total parameters: {total_params:,}")
241
+ print(f"Trainable parameters: {trainable_params:,}")
242
+
243
+ ds = _load_data(
244
+ config.run_dir,
245
+ tokenizer,
246
+ max_train_samples=config.max_train_samples,
247
+ max_eval_samples=config.max_eval_samples,
248
+ )
249
+ if "train" not in ds:
250
+ raise RuntimeError(
251
+ f"No training data found in {config.run_dir}. "
252
+ "Run the dataset pipeline and upload exported data to the volume first."
253
+ )
254
+ print(f"Train samples: {len(ds['train']):,}")
255
+ if "val" in ds:
256
+ print(f"Val samples: {len(ds['val']):,}")
257
+ effective_batch = config.per_device_train_batch_size * config.gradient_accumulation_steps
258
+ print(f"Effective batch: {effective_batch}")
259
+
260
+ sft_args = SFTConfig(
261
+ output_dir=str(experiment_dir),
262
+ dataset_text_field="messages",
263
+ max_seq_length=config.max_seq_length,
264
+ num_train_epochs=config.num_train_epochs,
265
+ per_device_train_batch_size=config.per_device_train_batch_size,
266
+ per_device_eval_batch_size=config.per_device_eval_batch_size,
267
+ gradient_accumulation_steps=config.gradient_accumulation_steps,
268
+ learning_rate=config.learning_rate,
269
+ max_grad_norm=config.max_grad_norm,
270
+ warmup_steps=config.warmup_steps,
271
+ lr_scheduler_type=config.lr_scheduler_type,
272
+ weight_decay=config.weight_decay,
273
+ optim=config.optim,
274
+ bf16=True,
275
+ logging_steps=config.logging_steps,
276
+ save_strategy=config.save_strategy,
277
+ save_steps=config.save_steps,
278
+ eval_strategy=config.eval_strategy,
279
+ eval_steps=config.eval_steps,
280
+ report_to=config.report_to,
281
+ trackio_space_id=config.trackio_space_id,
282
+ project=config.project,
283
+ dataset_num_proc=8,
284
+ seed=config.seed,
285
+ )
286
+
287
+ trainer = SFTTrainer(
288
+ model=model,
289
+ tokenizer=tokenizer,
290
+ train_dataset=ds["train"],
291
+ eval_dataset=ds.get("val"),
292
+ args=sft_args,
293
+ )
294
+
295
+ # Mask all tokens except the assistant response — train on completions only
296
+ trainer = train_on_responses_only(
297
+ trainer,
298
+ instruction_part=INSTRUCTION_PART,
299
+ response_part=RESPONSE_PART,
300
+ )
301
+
302
+ resume_from = _find_latest_checkpoint(experiment_dir)
303
+ if resume_from:
304
+ print(f"Resuming from {resume_from}")
305
+
306
+ trainer.train(resume_from_checkpoint=resume_from)
307
+
308
+ # Save LoRA adapter + tokenizer (lightweight, for future merging)
309
+ print(f"Saving LoRA adapter to {experiment_dir}")
310
+ model.save_pretrained(str(experiment_dir))
311
+ tokenizer.save_pretrained(str(experiment_dir))
312
+
313
+ # Save merged 16-bit model (full weights, ready for inference / GGUF conversion)
314
+ merged_dir = experiment_dir / "merged"
315
+ merged_dir.mkdir(parents=True, exist_ok=True)
316
+ print(f"Saving merged 16-bit model to {merged_dir}")
317
+ model.save_pretrained_merged(str(merged_dir), tokenizer, save_method="merged_16bit")
318
+
319
+ gazet_vol.commit()
320
+ print(f"Training complete: {config.experiment_name}")
321
+ return config.experiment_name
322
+
323
+
324
+ @app.local_entrypoint()
325
+ def main(
326
+ base_model: Optional[str] = None,
327
+ experiment_name: Optional[str] = None,
328
+ run_dir: Optional[str] = None,
329
+ num_train_epochs: Optional[int] = None,
330
+ per_device_train_batch_size: Optional[int] = None,
331
+ max_train_samples: Optional[int] = None,
332
+ max_eval_samples: Optional[int] = None,
333
+ lora_r: Optional[int] = None,
334
+ max_seq_length: Optional[int] = None,
335
+ ):
336
+ overrides = {
337
+ k: v for k, v in dict(
338
+ base_model=base_model,
339
+ experiment_name=experiment_name,
340
+ run_dir=run_dir,
341
+ num_train_epochs=num_train_epochs,
342
+ per_device_train_batch_size=per_device_train_batch_size,
343
+ max_train_samples=max_train_samples,
344
+ max_eval_samples=max_eval_samples,
345
+ lora_r=lora_r,
346
+ max_seq_length=max_seq_length,
347
+ ).items() if v is not None
348
+ }
349
+
350
+ config = Qwen35Config(**overrides)
351
+ # lora_alpha follows r if r was overridden and alpha wasn't
352
+ if lora_r is not None:
353
+ config.lora_alpha = 2 * config.lora_r
354
+
355
+ print(f"Starting experiment: {config.experiment_name}")
356
+ print(f"Model: {config.base_model}")
357
+ print(f"Run dir: {config.run_dir}")
358
+ print(f"LoRA: r={config.lora_r}, alpha={config.lora_alpha}")
359
+ effective_batch = config.per_device_train_batch_size * config.gradient_accumulation_steps
360
+ print(f"Effective batch: {effective_batch}")
361
+
362
+ result = finetune.remote(config.__dict__)
363
+ print(f"Training complete: {result}")
gazet_demo.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  import json
4
  import math
 
5
 
6
  import pandas as pd
7
  import requests
@@ -100,7 +101,7 @@ def _render_map(geojson, placeholder):
100
  st.json(geojson)
101
 
102
 
103
- API = "http://127.0.0.1:8000"
104
  EXAMPLES = [
105
  "Angola and Mozambique",
106
  "Mediterranean Sea",
@@ -114,6 +115,17 @@ st.set_page_config(page_title="Gazet", page_icon="🌍", layout="wide")
114
  st.title("Gazet")
115
  st.caption("Natural-language geo search · click an example or type your own")
116
 
 
 
 
 
 
 
 
 
 
 
 
117
  if "run_q" not in st.session_state:
118
  st.session_state.run_q = None
119
 
@@ -149,7 +161,7 @@ with col2:
149
 
150
  try:
151
  with requests.get(
152
- f"{API}/search/stream", params={"q": to_run}, stream=True, timeout=120
153
  ) as r:
154
  r.raise_for_status()
155
 
 
2
 
3
  import json
4
  import math
5
+ import os
6
 
7
  import pandas as pd
8
  import requests
 
101
  st.json(geojson)
102
 
103
 
104
+ API = os.environ.get("GAZET_API_URL", "http://127.0.0.1:8000")
105
  EXAMPLES = [
106
  "Angola and Mozambique",
107
  "Mediterranean Sea",
 
115
  st.title("Gazet")
116
  st.caption("Natural-language geo search · click an example or type your own")
117
 
118
+ backend = st.sidebar.radio(
119
+ "SQL Backend",
120
+ ["gguf", "dspy"],
121
+ index=0,
122
+ format_func=lambda x: "⚡ GGUF (llama-server)" if x == "gguf" else "🧠 DSPy (cloud LM)",
123
+ )
124
+ st.sidebar.caption(
125
+ "**gguf** → finetuned Qwen3.5 via llama-server\n\n"
126
+ "**dspy** → Ollama / cloud LM with retry loop"
127
+ )
128
+
129
  if "run_q" not in st.session_state:
130
  st.session_state.run_q = None
131
 
 
161
 
162
  try:
163
  with requests.get(
164
+ f"{API}/search/stream", params={"q": to_run, "backend": backend}, stream=True, timeout=120
165
  ) as r:
166
  r.raise_for_status()
167
 
pyproject.toml CHANGED
@@ -16,8 +16,18 @@ dependencies = [
16
  "pydantic>=2.0",
17
  "pyarrow>=17.0.0",
18
  "geopandas>=1.1.2",
 
 
19
  ]
20
  optional-dependencies = { demo = ["streamlit", "requests", "pydeck"], dev = ["ruff"] }
21
 
 
 
 
22
  [tool.hatch.build.targets.wheel]
23
- packages = ["src/gazet"]
 
 
 
 
 
 
16
  "pydantic>=2.0",
17
  "pyarrow>=17.0.0",
18
  "geopandas>=1.1.2",
19
+ "httpx>=0.28.1",
20
+ "sqlparse>=0.5.5",
21
  ]
22
  optional-dependencies = { demo = ["streamlit", "requests", "pydeck"], dev = ["ruff"] }
23
 
24
+ [project.scripts]
25
+ gazet-dataset = "dataset.scripts.cli:main"
26
+
27
  [tool.hatch.build.targets.wheel]
28
+ packages = ["src/gazet", "dataset"]
29
+
30
+ [dependency-groups]
31
+ dataset = [
32
+ "modal>=1.4.0",
33
+ ]
src/gazet/api.py CHANGED
@@ -1,4 +1,5 @@
1
  import json
 
2
  from typing import Any, Generator
3
 
4
  import duckdb
@@ -7,19 +8,33 @@ from fastapi import FastAPI, HTTPException
7
  from fastapi.responses import StreamingResponse
8
 
9
  from .export import to_feature_collection
10
- from .lm import extract
11
- from .search import search_divisions_area, search_natural_earth
12
- from .sql import run_geo_sql_loop
13
 
14
  app = FastAPI()
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def _df_to_records(df: pd.DataFrame) -> list[dict[str, Any]]:
18
  """Convert DataFrame to list of dicts for JSON; handle non-JSON-serializable types."""
19
  return df.replace({float("nan"): None}).to_dict(orient="records")
20
 
21
 
22
- def _run_stream(query: str) -> Generator[str, None, None]:
23
  """Yield NDJSON lines as each stage of the search completes.
24
 
25
  Event ``type`` values (in order of emission):
@@ -30,9 +45,12 @@ def _run_stream(query: str) -> Generator[str, None, None]:
30
  - ``geojson`` – final FeatureCollection
31
  - ``error`` – fatal error (no result)
32
  """
33
- pred = extract(query=query)
34
- print("extract result:", pred.result)
35
- places_result = pred.result
 
 
 
36
 
37
  yield json.dumps({"type": "places", "data": places_result.model_dump()}) + "\n"
38
 
@@ -41,12 +59,10 @@ def _run_stream(query: str) -> Generator[str, None, None]:
41
  con.execute("LOAD spatial")
42
 
43
  try:
 
44
  all_candidates: list[pd.DataFrame] = []
45
  for place in places_result.places:
46
- for search_fn in (search_divisions_area, search_natural_earth):
47
- df = search_fn(con, place)
48
- if not df.empty:
49
- all_candidates.append(df)
50
 
51
  if not all_candidates:
52
  yield json.dumps({"type": "error", "data": "No candidates found"}) + "\n"
@@ -64,8 +80,9 @@ def _run_stream(query: str) -> Generator[str, None, None]:
64
  + "\n"
65
  )
66
 
 
67
  result_df: pd.DataFrame | None = None
68
- for event in run_geo_sql_loop(con, query, candidates_df):
69
  if event["type"] == "sql_attempt":
70
  yield (
71
  json.dumps(
@@ -105,13 +122,13 @@ def _run_stream(query: str) -> Generator[str, None, None]:
105
 
106
 
107
  @app.get("/search/stream")
108
- def search_stream(q: str) -> StreamingResponse:
109
  """Stream search progress as NDJSON (one JSON object per line)."""
110
- return StreamingResponse(_run_stream(q), media_type="application/x-ndjson")
111
 
112
 
113
  @app.get("/search", response_model=None)
114
- def search(q: str) -> dict[str, Any]:
115
  """Run geo search for natural-language query (non-streaming).
116
 
117
  Returns GeoJSON FeatureCollection, the executed SQL, and the identified
@@ -122,7 +139,7 @@ def search(q: str) -> dict[str, Any]:
122
  sql = ""
123
  geojson: dict | None = None
124
 
125
- for line in _run_stream(q):
126
  if not line.strip():
127
  continue
128
  event = json.loads(line)
 
1
  import json
2
+ import math
3
  from typing import Any, Generator
4
 
5
  import duckdb
 
8
  from fastapi.responses import StreamingResponse
9
 
10
  from .export import to_feature_collection
11
+ from .lm import extract, generate_places
12
+ from .search import search_candidates
13
+ from .sql import run_geo_sql_dspy, run_geo_sql_gguf
14
 
15
  app = FastAPI()
16
 
17
 
18
+ def _per_source_limit(num_places: int) -> int:
19
+ """Candidates to fetch per source per place, scaled by number of places.
20
+
21
+ Keeps the total candidate count in the prompt manageable:
22
+ 1 place → 5 per source → 10 total
23
+ 2 places → 4 per source → 16 total
24
+ 3 places → 3 per source → 18 total
25
+ 4 places → 2 per source → 16 total
26
+ 5 places → 2 per source → 20 total
27
+ """
28
+ table = {1: 5, 2: 4, 3: 3, 4: 2, 5: 2}
29
+ return table.get(num_places, max(1, math.ceil(5 / num_places)))
30
+
31
+
32
  def _df_to_records(df: pd.DataFrame) -> list[dict[str, Any]]:
33
  """Convert DataFrame to list of dicts for JSON; handle non-JSON-serializable types."""
34
  return df.replace({float("nan"): None}).to_dict(orient="records")
35
 
36
 
37
+ def _run_stream(query: str, backend: str = "gguf") -> Generator[str, None, None]:
38
  """Yield NDJSON lines as each stage of the search completes.
39
 
40
  Event ``type`` values (in order of emission):
 
45
  - ``geojson`` – final FeatureCollection
46
  - ``error`` – fatal error (no result)
47
  """
48
+ if backend == "gguf":
49
+ places_result = generate_places(query)
50
+ else:
51
+ pred = extract(query=query)
52
+ places_result = pred.result
53
+ print("places:", places_result)
54
 
55
  yield json.dumps({"type": "places", "data": places_result.model_dump()}) + "\n"
56
 
 
59
  con.execute("LOAD spatial")
60
 
61
  try:
62
+ limit = _per_source_limit(len(places_result.places))
63
  all_candidates: list[pd.DataFrame] = []
64
  for place in places_result.places:
65
+ all_candidates.extend(search_candidates(con, place, limit=limit))
 
 
 
66
 
67
  if not all_candidates:
68
  yield json.dumps({"type": "error", "data": "No candidates found"}) + "\n"
 
80
  + "\n"
81
  )
82
 
83
+ sql_fn = run_geo_sql_gguf if backend == "gguf" else run_geo_sql_dspy
84
  result_df: pd.DataFrame | None = None
85
+ for event in sql_fn(con, query, candidates_df):
86
  if event["type"] == "sql_attempt":
87
  yield (
88
  json.dumps(
 
122
 
123
 
124
  @app.get("/search/stream")
125
+ def search_stream(q: str, backend: str = "gguf") -> StreamingResponse:
126
  """Stream search progress as NDJSON (one JSON object per line)."""
127
+ return StreamingResponse(_run_stream(q, backend), media_type="application/x-ndjson")
128
 
129
 
130
  @app.get("/search", response_model=None)
131
+ def search(q: str, backend: str = "gguf") -> dict[str, Any]:
132
  """Run geo search for natural-language query (non-streaming).
133
 
134
  Returns GeoJSON FeatureCollection, the executed SQL, and the identified
 
139
  sql = ""
140
  geojson: dict | None = None
141
 
142
+ for line in _run_stream(q, backend):
143
  if not line.strip():
144
  continue
145
  event = json.loads(line)
src/gazet/config.py CHANGED
@@ -1,7 +1,11 @@
 
1
  import pathlib
2
 
3
- # Data lives at project root (gazet/data/), not inside the package
4
- _DATA_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "data"
 
 
 
5
  DIVISIONS_AREA_PATH = str(_DATA_DIR / "overture/divisions_area/*.parquet")
6
  NATURAL_EARTH_PATH = str(_DATA_DIR / "natural_earth_geoparquet/ne_geography.parquet")
7
 
@@ -9,18 +13,29 @@ NATURAL_EARTH_PATH = str(_DATA_DIR / "natural_earth_geoparquet/ne_geography.parq
9
  # MODEL = "granite4:350m"
10
  # MODEL = "gemma3:12b-cloud"
11
  # MODEL = "qwen3.5:397b-cloud"
12
- MODEL = "gpt-oss:20b-cloud"
13
  # MODEL = "qwen3:4b"
14
  # MODEL = "qwen3-coder-next:cloud"
15
  # MODEL = "deepseek-coder:1.3b"
 
 
 
 
 
 
16
 
17
  MAX_SQL_ITERATIONS = 5
18
 
19
- SCHEMA_INFO = f"""
 
 
 
 
 
20
  Available DuckDB datasets (read via read_parquet):
21
 
22
  1. divisions_area — Overture polygon/multipolygon admin boundaries
23
- path: '{DIVISIONS_AREA_PATH}'
24
  columns:
25
  id VARCHAR -- unique feature id (use this to filter precisely)
26
  names STRUCT("primary" VARCHAR, ...)
@@ -36,7 +51,7 @@ Available DuckDB datasets (read via read_parquet):
36
  geometry GEOMETRY -- boundary polygon/multipolygon (WKB, spatial ext loaded)
37
 
38
  2. natural_earth — Natural Earth geography polygons (oceans, seas, terrain regions, islands)
39
- path: '{NATURAL_EARTH_PATH}'
40
  columns:
41
  id VARCHAR -- unique feature id prefixed 'ne_'
42
  names STRUCT("primary" VARCHAR, ...)
@@ -49,26 +64,26 @@ Available DuckDB datasets (read via read_parquet):
49
  is_territorial BOOLEAN
50
  geometry GEOMETRY -- polygon/multipolygon (WKB, spatial ext loaded)
51
 
52
- Spatial extension is already loaded — use ST_AsGeoJSON(geometry) or ST_AsText(geometry).
53
  To access names use: names."primary"
54
 
55
  The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
56
- Use the matching path for each candidate's source when querying.
57
 
58
  Example patterns:
59
  -- single region boundary from divisions_area
60
- SELECT id, names."primary" AS name, ST_AsGeoJSON(geometry) AS geojson
61
- FROM read_parquet('{DIVISIONS_AREA_PATH}')
62
  WHERE id = '<candidate_id>'
63
 
64
  -- feature from natural_earth
65
- SELECT id, names."primary" AS name, ST_AsGeoJSON(geometry) AS geojson
66
- FROM read_parquet('{NATURAL_EARTH_PATH}')
67
  WHERE id = '<candidate_id>'
68
 
69
  -- shared border between two adjacent regions
70
- WITH a AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '<id_a>'),
71
- b AS (SELECT geometry FROM read_parquet('{DIVISIONS_AREA_PATH}') WHERE id = '<id_b>')
72
- SELECT ST_AsGeoJSON(ST_Intersection(a.geometry, b.geometry)) AS border
73
  FROM a, b
74
  """
 
1
+ import os
2
  import pathlib
3
 
4
+ # Data lives at project root (gazet/data/), not inside the package.
5
+ # Override with GAZET_DATA_DIR env var for remote execution (e.g. Modal volume at /data).
6
+ _DATA_DIR = pathlib.Path(os.environ.get("GAZET_DATA_DIR", str(
7
+ pathlib.Path(__file__).resolve().parent.parent.parent / "data"
8
+ )))
9
  DIVISIONS_AREA_PATH = str(_DATA_DIR / "overture/divisions_area/*.parquet")
10
  NATURAL_EARTH_PATH = str(_DATA_DIR / "natural_earth_geoparquet/ne_geography.parquet")
11
 
 
13
  # MODEL = "granite4:350m"
14
  # MODEL = "gemma3:12b-cloud"
15
  # MODEL = "qwen3.5:397b-cloud"
16
+ # MODEL = "gpt-oss:20b-cloud"
17
  # MODEL = "qwen3:4b"
18
  # MODEL = "qwen3-coder-next:cloud"
19
  # MODEL = "deepseek-coder:1.3b"
20
+ # MODEL = "qwen3.5:2b"
21
+ # MODEL = "qwen3.5:0.8b"
22
+ # MODEL = "qwen2.5-coder:1.5b"
23
+
24
+ PLACE_EXTRACTION_MODEL = "gpt-oss:20b-cloud"
25
+ SQL_GENERATION_MODEL = "gpt-oss:20b-cloud"
26
 
27
  MAX_SQL_ITERATIONS = 5
28
 
29
+ # ── GGUF / llama-server config ────────────────────────────────────────────────
30
+ LLAMA_SERVER_URL = os.environ.get("LLAMA_SERVER_URL", "http://localhost:9000")
31
+ LLAMA_MAX_TOKENS = int(os.environ.get("LLAMA_MAX_TOKENS", "2048"))
32
+ LLAMA_TEMPERATURE = float(os.environ.get("LLAMA_TEMPERATURE", "0"))
33
+
34
+ SCHEMA_INFO = """
35
  Available DuckDB datasets (read via read_parquet):
36
 
37
  1. divisions_area — Overture polygon/multipolygon admin boundaries
38
+ query: read_parquet('divisions_area')
39
  columns:
40
  id VARCHAR -- unique feature id (use this to filter precisely)
41
  names STRUCT("primary" VARCHAR, ...)
 
51
  geometry GEOMETRY -- boundary polygon/multipolygon (WKB, spatial ext loaded)
52
 
53
  2. natural_earth — Natural Earth geography polygons (oceans, seas, terrain regions, islands)
54
+ query: read_parquet('natural_earth')
55
  columns:
56
  id VARCHAR -- unique feature id prefixed 'ne_'
57
  names STRUCT("primary" VARCHAR, ...)
 
64
  is_territorial BOOLEAN
65
  geometry GEOMETRY -- polygon/multipolygon (WKB, spatial ext loaded)
66
 
67
+ Spatial extension is already loaded — use ST_AsGeoJSON(geometry) for geometry outputs.
68
  To access names use: names."primary"
69
 
70
  The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
71
+ Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly.
72
 
73
  Example patterns:
74
  -- single region boundary from divisions_area
75
+ SELECT id, names."primary" AS name, ST_AsGeoJSON(geometry) AS geometry
76
+ FROM read_parquet('divisions_area')
77
  WHERE id = '<candidate_id>'
78
 
79
  -- feature from natural_earth
80
+ SELECT id, names."primary" AS name, ST_AsGeoJSON(geometry) AS geometry
81
+ FROM read_parquet('natural_earth')
82
  WHERE id = '<candidate_id>'
83
 
84
  -- shared border between two adjacent regions
85
+ WITH a AS (SELECT geometry FROM read_parquet('divisions_area') WHERE id = '<id_a>'),
86
+ b AS (SELECT geometry FROM read_parquet('divisions_area') WHERE id = '<id_b>')
87
+ SELECT ST_AsGeoJSON(ST_Intersection(a.geometry, b.geometry)) AS geometry
88
  FROM a, b
89
  """
src/gazet/export.py CHANGED
@@ -2,9 +2,25 @@ import json
2
  import pathlib
3
  import re
4
 
 
5
  import pandas as pd
6
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def _is_geojson_col(series: pd.Series) -> bool:
9
  """Heuristic: a string column whose non-null values start with '{"type":'."""
10
  sample = series.dropna().head(5)
@@ -16,6 +32,36 @@ def _is_geojson_col(series: pd.Series) -> bool:
16
  )
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def save_geojson(
20
  result_df: pd.DataFrame, query: str, output_dir: pathlib.Path | str = "."
21
  ) -> pathlib.Path:
@@ -43,22 +89,33 @@ def to_feature_collection(result_df: pd.DataFrame) -> dict:
43
 
44
 
45
  def _to_feature_collection(result_df: pd.DataFrame) -> dict:
46
- geom_cols = [c for c in result_df.columns if _is_geojson_col(result_df[c])]
 
 
47
  prop_cols = [c for c in result_df.columns if c not in geom_cols]
48
  features = []
49
  for _, row in result_df.iterrows():
50
  geometry = None
51
- if geom_cols:
52
- raw = row[geom_cols[0]]
53
  if raw and isinstance(raw, str):
54
  try:
55
  geometry = json.loads(raw)
56
  except json.JSONDecodeError:
57
  pass
58
- properties = {c: row[c] for c in prop_cols if pd.notna(row[c])}
59
- for c in geom_cols[1:]:
60
- if pd.notna(row[c]):
61
- properties[c] = row[c]
 
 
 
 
 
 
 
 
 
62
  features.append(
63
  {"type": "Feature", "geometry": geometry, "properties": properties}
64
  )
 
2
  import pathlib
3
  import re
4
 
5
+ import numpy as np
6
  import pandas as pd
7
 
8
 
9
+ def _to_serializable(val):
10
+ """Convert a value to a JSON-serializable Python type."""
11
+ if isinstance(val, (bytearray, bytes)):
12
+ return None
13
+ if isinstance(val, np.ndarray):
14
+ return val.tolist()
15
+ if isinstance(val, (np.integer,)):
16
+ return int(val)
17
+ if isinstance(val, (np.floating,)):
18
+ return float(val)
19
+ if isinstance(val, (np.bool_,)):
20
+ return bool(val)
21
+ return val
22
+
23
+
24
  def _is_geojson_col(series: pd.Series) -> bool:
25
  """Heuristic: a string column whose non-null values start with '{"type":'."""
26
  sample = series.dropna().head(5)
 
32
  )
33
 
34
 
35
+ def _is_wkb_col(series: pd.Series) -> bool:
36
+ """Heuristic: a column whose non-null values are bytearray or bytes (WKB geometry)."""
37
+ sample = series.dropna().head(5)
38
+ return (
39
+ sample.apply(lambda v: isinstance(v, (bytearray, bytes))).all()
40
+ and len(sample) > 0
41
+ )
42
+
43
+
44
+ def _wkb_to_geojson(wkb: bytearray | bytes) -> dict | None:
45
+ """Convert WKB geometry to GeoJSON dict via DuckDB."""
46
+ import duckdb
47
+
48
+ con = duckdb.connect()
49
+ try:
50
+ con.execute("INSTALL spatial")
51
+ con.execute("LOAD spatial")
52
+ result = con.execute(
53
+ "SELECT ST_AsGeoJSON(ST_GeomFromWKB(?::BLOB)) AS geojson",
54
+ [bytes(wkb)],
55
+ ).fetchone()
56
+ if result and result[0]:
57
+ return json.loads(result[0])
58
+ except Exception:
59
+ pass
60
+ finally:
61
+ con.close()
62
+ return None
63
+
64
+
65
  def save_geojson(
66
  result_df: pd.DataFrame, query: str, output_dir: pathlib.Path | str = "."
67
  ) -> pathlib.Path:
 
89
 
90
 
91
  def _to_feature_collection(result_df: pd.DataFrame) -> dict:
92
+ geojson_cols = [c for c in result_df.columns if _is_geojson_col(result_df[c])]
93
+ wkb_cols = [c for c in result_df.columns if _is_wkb_col(result_df[c])]
94
+ geom_cols = geojson_cols + wkb_cols
95
  prop_cols = [c for c in result_df.columns if c not in geom_cols]
96
  features = []
97
  for _, row in result_df.iterrows():
98
  geometry = None
99
+ if geojson_cols:
100
+ raw = row[geojson_cols[0]]
101
  if raw and isinstance(raw, str):
102
  try:
103
  geometry = json.loads(raw)
104
  except json.JSONDecodeError:
105
  pass
106
+ elif wkb_cols:
107
+ raw = row[wkb_cols[0]]
108
+ if raw and isinstance(raw, (bytearray, bytes)):
109
+ geometry = _wkb_to_geojson(raw)
110
+ properties = {}
111
+ for c in prop_cols:
112
+ v = row[c]
113
+ try:
114
+ if not pd.notna(v):
115
+ continue
116
+ except ValueError:
117
+ pass # pd.notna fails on arrays — treat as present
118
+ properties[c] = _to_serializable(v)
119
  features.append(
120
  {"type": "Feature", "geometry": geometry, "properties": properties}
121
  )
src/gazet/lm.py CHANGED
@@ -1,7 +1,20 @@
 
 
 
1
  import dspy
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- from .config import MODEL
4
- from .schemas import PlacesResult
5
 
6
 
7
  class ExtractPlaces(dspy.Signature):
@@ -20,6 +33,13 @@ class ExtractPlaces(dspy.Signature):
20
 
21
  Where possible and relevant, also extract the ISO country code for each place.
22
 
 
 
 
 
 
 
 
23
  Do not repeat the same place name in the result.
24
 
25
  If the user does not explicitly mention a country, dont add the country code to the result.
@@ -103,10 +123,213 @@ class WriteGeoSQL(dspy.Signature):
103
  )
104
 
105
 
106
- lm = dspy.LM(
107
- f"ollama_chat/{MODEL}", api_base="http://localhost:11434", api_key="", temperature=0.1, cache=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  )
109
- dspy.configure(lm=lm)
110
 
111
- extract = dspy.Predict(ExtractPlaces)
112
- write_sql = dspy.Predict(WriteGeoSQL)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+
4
  import dspy
5
+ import httpx
6
+ import pandas as pd
7
+
8
+ from .config import (
9
+ LLAMA_MAX_TOKENS,
10
+ LLAMA_SERVER_URL,
11
+ LLAMA_TEMPERATURE,
12
+ PLACE_EXTRACTION_MODEL,
13
+ SQL_GENERATION_MODEL,
14
+ )
15
+ from .schemas import Place, PlacesResult
16
 
17
+ logger = logging.getLogger(__name__)
 
18
 
19
 
20
  class ExtractPlaces(dspy.Signature):
 
33
 
34
  Where possible and relevant, also extract the ISO country code for each place.
35
 
36
+ Only extract place names that are explicitly mentioned in the query.
37
+ Do NOT generate or infer place names from your own knowledge.
38
+ For example:
39
+ - "north half of India" -> extract "India", NOT individual state names
40
+ - "coastal cities of France" -> extract "France", NOT city names
41
+ - "neighbouring states of Odisha" -> extract "Odisha", NOT neighbouring state names
42
+
43
  Do not repeat the same place name in the result.
44
 
45
  If the user does not explicitly mention a country, dont add the country code to the result.
 
123
  )
124
 
125
 
126
+ place_extraction_lm = dspy.LM(
127
+ f"ollama_chat/{PLACE_EXTRACTION_MODEL}",
128
+ api_base="http://localhost:11434",
129
+ api_key="",
130
+ temperature=0.1,
131
+ cache=False,
132
+ )
133
+
134
+ sql_generation_lm = dspy.LM(
135
+ f"ollama_chat/{SQL_GENERATION_MODEL}",
136
+ api_base="http://localhost:11434",
137
+ api_key="",
138
+ temperature=0.1,
139
+ cache=False,
140
+ think=False
141
  )
 
142
 
143
+
144
+ class PlaceExtractor(dspy.Module):
145
+ def __init__(self, lm):
146
+ super().__init__()
147
+ self.lm = lm
148
+ self.predictor = dspy.Predict(ExtractPlaces)
149
+
150
+ def forward(self, query: str):
151
+ with dspy.context(lm=self.lm):
152
+ return self.predictor(query=query)
153
+
154
+
155
+ class SQLWriter(dspy.Module):
156
+ def __init__(self, lm):
157
+ super().__init__()
158
+ self.lm = lm
159
+ self.predictor = dspy.Predict(WriteGeoSQL)
160
+
161
+ def forward(self, user_query: str, schema: str, candidates: str,
162
+ previous_sql: str = "", execution_error: str = ""):
163
+ with dspy.context(lm=self.lm):
164
+ return self.predictor(
165
+ user_query=user_query,
166
+ schema=schema,
167
+ candidates=candidates,
168
+ previous_sql=previous_sql,
169
+ execution_error=execution_error
170
+ )
171
+
172
+
173
+ extract = PlaceExtractor(lm=place_extraction_lm)
174
+ write_sql = SQLWriter(lm=sql_generation_lm)
175
+
176
+
177
+ # ── GGUF SQL generation via llama-server ──────────────────────────────────────
178
+
179
+ _SYSTEM_PROMPT = """You are a text to SQL query translator that helps in natural language geocoding.
180
+
181
+ You have access to two DuckDB parquet tables. Given a set of candidate entities and a user query, generate the SQL to retrieve the desired geometry.
182
+
183
+ <SCHEMA>
184
+ 1. divisions_area -- Overture polygon/multipolygon admin boundaries
185
+ query: read_parquet('divisions_area')
186
+ columns:
187
+ id VARCHAR -- unique feature id
188
+ names STRUCT("primary" VARCHAR, ...)
189
+ country VARCHAR -- ISO 3166-1 alpha-2
190
+ subtype VARCHAR -- country | region | dependency | county | localadmin |
191
+ locality | macrohood | neighborhood | microhood
192
+ class VARCHAR
193
+ region VARCHAR
194
+ admin_level INTEGER
195
+ division_id VARCHAR
196
+ is_land BOOLEAN
197
+ is_territorial BOOLEAN
198
+ geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
199
+
200
+ 2. natural_earth -- Natural Earth geography polygons (oceans, seas, rivers, terrain)
201
+ query: read_parquet('natural_earth')
202
+ columns:
203
+ id VARCHAR -- unique feature id prefixed 'ne_'
204
+ names STRUCT("primary" VARCHAR, ...)
205
+ country VARCHAR
206
+ subtype VARCHAR -- e.g. 'ocean', 'sea', 'bay', 'Terrain area', 'Island group'
207
+ class VARCHAR
208
+ region VARCHAR
209
+ admin_level INTEGER
210
+ is_land BOOLEAN
211
+ is_territorial BOOLEAN
212
+ geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
213
+ </SCHEMA>
214
+
215
+ The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
216
+ Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly.
217
+ Use ST_AsGeoJSON(geometry) for all geometry outputs."""
218
+
219
+ _USER_PROMPT_TEMPLATE = """<CANDIDATES>
220
+ {candidates_csv}
221
+ </CANDIDATES>
222
+
223
+ <USER_QUERY>
224
+ {question}
225
+ </USER_QUERY>
226
+ """
227
+
228
+
229
+ def _postprocess_sql(text: str) -> str:
230
+ """Strip markdown fences and whitespace from generated SQL."""
231
+ cleaned = text.strip()
232
+ if "```sql" in cleaned:
233
+ cleaned = cleaned.split("```sql", 1)[1]
234
+ if cleaned.startswith("```"):
235
+ cleaned = cleaned[3:]
236
+ if "```" in cleaned:
237
+ cleaned = cleaned.split("```", 1)[0]
238
+ return cleaned.strip()
239
+
240
+
241
+ def is_llama_server_available() -> bool:
242
+ """Check if the llama-server is running and healthy."""
243
+ try:
244
+ resp = httpx.get(f"{LLAMA_SERVER_URL}/health", timeout=2)
245
+ return resp.status_code == 200
246
+ except (httpx.ConnectError, httpx.TimeoutException):
247
+ return False
248
+
249
+
250
+ def _llama_chat_complete(messages: list[dict]) -> str:
251
+ """Call llama-server /v1/chat/completions with a messages list."""
252
+ resp = httpx.post(
253
+ f"{LLAMA_SERVER_URL}/v1/chat/completions",
254
+ json={
255
+ "messages": messages,
256
+ "n_predict": LLAMA_MAX_TOKENS,
257
+ "temperature": LLAMA_TEMPERATURE,
258
+ "chat_template_kwargs": {"enable_thinking": False},
259
+ },
260
+ timeout=60,
261
+ )
262
+ if resp.status_code != 200:
263
+ logger.error("llama-server %s: %s", resp.status_code, resp.text[:500])
264
+ resp.raise_for_status()
265
+ return resp.json()["choices"][0]["message"]["content"]
266
+
267
+
268
+ _PLACES_SYSTEM_PROMPT = """You are a geographic entity extractor. Extract place names from the user query and return valid JSON only.
269
+
270
+ OUTPUT FORMAT:
271
+ {"places": [{"place": "<name>", "country": "<ISO-2>", "subtype": "<subtype>"}]}
272
+ "country" and "subtype" are optional; omit if not applicable.
273
+
274
+ RULES:
275
+ - Only extract places explicitly mentioned. Never infer or expand (e.g. "states of India" -> extract "India" only).
276
+ - No duplicate place names.
277
+ - "country": ISO 3166-1 alpha-2. Include only if explicitly mentioned or unambiguous.
278
+ - "subtype": include only when the geographic level is clear from the query.
279
+
280
+ SUBTYPES:
281
+ country, dependency, region, county, localadmin, locality, macrohood, neighborhood, microhood
282
+ - Default to locality for cities/towns; omit for physical features (oceans, rivers, mountains)."""
283
+
284
+
285
+ def generate_places(user_query: str) -> PlacesResult:
286
+ """Extract place names from a query using the finetuned GGUF model.
287
+
288
+ Uses the same prompt format the model was trained on.
289
+ Returns a PlacesResult; falls back to an empty result on parse failure.
290
+ """
291
+ messages = [
292
+ {"role": "system", "content": _PLACES_SYSTEM_PROMPT},
293
+ {"role": "user", "content": user_query},
294
+ ]
295
+ raw_output = _llama_chat_complete(messages).strip()
296
+
297
+ # Strip markdown fences if the model wrapped the JSON
298
+ if raw_output.startswith("```"):
299
+ raw_output = raw_output.split("```")[1]
300
+ if raw_output.startswith("json"):
301
+ raw_output = raw_output[4:]
302
+ raw_output = raw_output.strip()
303
+
304
+ try:
305
+ data = json.loads(raw_output)
306
+ return PlacesResult.model_validate(data)
307
+ except Exception as exc:
308
+ logger.warning("generate_places: failed to parse output %r: %s", raw_output, exc)
309
+ # Best-effort: treat entire query as a single unnamed place
310
+ return PlacesResult(places=[Place(place=user_query)])
311
+
312
+
313
+ def generate_sql(user_query: str, candidates_df: pd.DataFrame) -> str:
314
+ """Generate SQL from a natural language query using the finetuned GGUF model.
315
+
316
+ Uses the same prompt format the model was trained on:
317
+ SYSTEM_PROMPT (includes schema) + USER_PROMPT_TEMPLATE with candidates CSV and question.
318
+ Single-shot — no retry loop (the finetuned model can't improve from error feedback).
319
+ """
320
+ # Keep only columns the model was trained on
321
+ keep_cols = ["source", "id", "name", "subtype", "country", "region", "admin_level", "similarity"]
322
+ cols = [c for c in keep_cols if c in candidates_df.columns]
323
+ candidates_csv = candidates_df[cols].to_csv(index=False)
324
+
325
+ user_prompt = _USER_PROMPT_TEMPLATE.format(
326
+ candidates_csv=candidates_csv.strip(),
327
+ question=user_query.strip(),
328
+ )
329
+
330
+ messages = [
331
+ {"role": "system", "content": _SYSTEM_PROMPT},
332
+ {"role": "user", "content": user_prompt},
333
+ ]
334
+ raw_output = _llama_chat_complete(messages)
335
+ return _postprocess_sql(raw_output)
src/gazet/search.py CHANGED
@@ -5,31 +5,16 @@ from .config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
5
  from .schemas import Place
6
 
7
 
8
- def _fuzzy_search(
9
  con: duckdb.DuckDBPyConnection,
10
  path: str,
11
  source: str,
12
  place: Place,
13
  extra_select: str = "",
14
  limit: int = 5,
15
- is_overture: bool = False,
16
  ) -> pd.DataFrame:
17
- """Generic Levenshtein fuzzy search against any parquet with a names.primary column."""
18
- country_filter = ""
19
- country_params: list = []
20
- if is_overture and place.country:
21
- country_filter = "AND country = ?"
22
- country_params = [place.country]
23
-
24
- subtype_filter = ""
25
- subtype_params: list = []
26
- if is_overture and place.subtype:
27
- subtype_filter = "AND subtype = ?"
28
- subtype_params = [place.subtype]
29
-
30
- params = (
31
- [place.place, place.place, path] + country_params + subtype_params + [limit]
32
- )
33
 
34
  extra_clause = f", {extra_select}" if extra_select else ""
35
  rel = con.execute(
@@ -44,12 +29,9 @@ def _fuzzy_search(
44
  admin_level,
45
  is_land,
46
  is_territorial{extra_clause},
47
- 1.0 - (levenshtein(lower(names."primary"), lower(?))::float
48
- / greatest(length(names."primary"), length(?), 1)) AS similarity
49
  FROM read_parquet(?)
50
  WHERE names."primary" IS NOT NULL AND trim(names."primary") != ''
51
- {country_filter}
52
- {subtype_filter}
53
  ORDER BY similarity DESC, admin_level ASC
54
  LIMIT ?
55
  """,
@@ -57,11 +39,10 @@ def _fuzzy_search(
57
  )
58
  df = rel.fetchdf()
59
  df.insert(0, "source", source)
60
- label = f'"{place.place}"' + (f" [{place.country}]" if place.country else "")
61
  if df.empty:
62
- print(f"\n{source} {label}: no matches")
63
  else:
64
- print(f"\n{source} {label} (top {len(df)} by name similarity):")
65
  print(df.to_string(index=False))
66
  return df
67
 
@@ -70,14 +51,13 @@ def search_divisions_area(
70
  con: duckdb.DuckDBPyConnection, place: Place, limit: int = 5
71
  ) -> pd.DataFrame:
72
  """Fuzzy-match a place against divisions_area (Overture admin boundaries)."""
73
- return _fuzzy_search(
74
  con,
75
  DIVISIONS_AREA_PATH,
76
  "divisions_area",
77
  place,
78
  extra_select="division_id",
79
  limit=limit,
80
- is_overture=True,
81
  )
82
 
83
 
@@ -85,10 +65,26 @@ def search_natural_earth(
85
  con: duckdb.DuckDBPyConnection, place: Place, limit: int = 5
86
  ) -> pd.DataFrame:
87
  """Fuzzy-match a place against Natural Earth geography polygons."""
88
- return _fuzzy_search(
89
  con,
90
  NATURAL_EARTH_PATH,
91
  "natural_earth",
92
  place,
93
  limit=limit,
94
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from .schemas import Place
6
 
7
 
8
+ def simple_fuzzy_search(
9
  con: duckdb.DuckDBPyConnection,
10
  path: str,
11
  source: str,
12
  place: Place,
13
  extra_select: str = "",
14
  limit: int = 5,
 
15
  ) -> pd.DataFrame:
16
+ """Jaro-Winkler similarity search using only the place name."""
17
+ params = [place.place, path, limit]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  extra_clause = f", {extra_select}" if extra_select else ""
20
  rel = con.execute(
 
29
  admin_level,
30
  is_land,
31
  is_territorial{extra_clause},
32
+ jaro_winkler_similarity(lower(names."primary"), lower(?)) AS similarity
 
33
  FROM read_parquet(?)
34
  WHERE names."primary" IS NOT NULL AND trim(names."primary") != ''
 
 
35
  ORDER BY similarity DESC, admin_level ASC
36
  LIMIT ?
37
  """,
 
39
  )
40
  df = rel.fetchdf()
41
  df.insert(0, "source", source)
 
42
  if df.empty:
43
+ print(f"\n{source} - \"{place.place}\": no matches")
44
  else:
45
+ print(f"\n{source} - \"{place.place}\" (top {len(df)} by Jaro-Winkler):")
46
  print(df.to_string(index=False))
47
  return df
48
 
 
51
  con: duckdb.DuckDBPyConnection, place: Place, limit: int = 5
52
  ) -> pd.DataFrame:
53
  """Fuzzy-match a place against divisions_area (Overture admin boundaries)."""
54
+ return simple_fuzzy_search(
55
  con,
56
  DIVISIONS_AREA_PATH,
57
  "divisions_area",
58
  place,
59
  extra_select="division_id",
60
  limit=limit,
 
61
  )
62
 
63
 
 
65
  con: duckdb.DuckDBPyConnection, place: Place, limit: int = 5
66
  ) -> pd.DataFrame:
67
  """Fuzzy-match a place against Natural Earth geography polygons."""
68
+ return simple_fuzzy_search(
69
  con,
70
  NATURAL_EARTH_PATH,
71
  "natural_earth",
72
  place,
73
  limit=limit,
74
  )
75
+
76
+
77
+ def search_candidates(
78
+ con: duckdb.DuckDBPyConnection, place: Place, limit: int = 5
79
+ ) -> list[pd.DataFrame]:
80
+ """Return candidate DataFrames for a place from both sources.
81
+
82
+ Always searches divisions_area and natural_earth to avoid missing
83
+ natural features when the model assigns an incorrect admin subtype.
84
+ """
85
+ results = []
86
+ for fn in (search_divisions_area, search_natural_earth):
87
+ df = fn(con, place, limit=limit)
88
+ if not df.empty:
89
+ results.append(df)
90
+ return results
src/gazet/sql.py CHANGED
@@ -4,8 +4,40 @@ from typing import Any, Generator, Optional
4
  import duckdb
5
  import pandas as pd
6
 
7
- from .config import MAX_SQL_ITERATIONS, SCHEMA_INFO
8
- from .lm import write_sql
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
 
11
  def _strip_fences(sql: Optional[str]) -> str:
@@ -17,13 +49,38 @@ def _strip_fences(sql: Optional[str]) -> str:
17
  return sql.strip()
18
 
19
 
20
- def run_geo_sql_loop(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  con: duckdb.DuckDBPyConnection,
22
  user_query: str,
23
  candidates_df: pd.DataFrame,
24
- max_iterations: int = MAX_SQL_ITERATIONS,
25
  ) -> Generator[dict[str, Any], None, None]:
26
- """Code-act loop yielding progress events.
27
 
28
  Event types:
29
  - ``sql_attempt`` – ``{"type": "sql_attempt", "sql": str, "iteration": int}``
@@ -31,7 +88,46 @@ def run_geo_sql_loop(
31
  - ``result`` – ``{"type": "result", "df": DataFrame | None, "sql": str}``
32
  """
33
  if candidates_df.empty:
34
- print("\n[SQL-Act] No candidates to work with — skipping.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  yield {"type": "result", "df": None, "sql": ""}
36
  return
37
 
@@ -41,7 +137,7 @@ def run_geo_sql_loop(
41
 
42
  for iteration in range(1, max_iterations + 1):
43
  print(f"\n{'=' * 60}")
44
- print(f"[SQL-Act] Iteration {iteration}/{max_iterations}")
45
 
46
  try:
47
  pred = write_sql(
@@ -86,6 +182,6 @@ def run_geo_sql_loop(
86
  yield {"type": "sql_error", "error": error, "iteration": iteration}
87
 
88
  print(
89
- f"\n[SQL-Act] Exhausted {max_iterations} iterations without a successful query."
90
  )
91
  yield {"type": "result", "df": None, "sql": ""}
 
4
  import duckdb
5
  import pandas as pd
6
 
7
+ from .config import DIVISIONS_AREA_PATH, MAX_SQL_ITERATIONS, NATURAL_EARTH_PATH, SCHEMA_INFO
8
+ from .lm import generate_sql, write_sql
9
+
10
+
11
+ def _rewrite_data_paths(sql: str) -> str:
12
+ """Replace any read_parquet table reference with the correct runtime path.
13
+
14
+ Handles three generations of model output:
15
+ - Symbolic: read_parquet('divisions_area')
16
+ - Old paths: read_parquet('/data/overture/division_area/...')
17
+ - Hallucinated variants: any quoted path containing 'division' or 'natural_earth'
18
+
19
+ Legacy replacements run FIRST so the absolute path is never re-matched.
20
+ """
21
+ # Any quoted path that looks like a divisions_area reference
22
+ sql = re.sub(
23
+ r"read_parquet\(['\"][^'\"]*(?:division_area|divisions_area)[^'\"]*['\"]\)",
24
+ f"read_parquet('{DIVISIONS_AREA_PATH}')",
25
+ sql,
26
+ )
27
+ # Any quoted path that looks like a natural_earth reference
28
+ sql = re.sub(
29
+ r"read_parquet\(['\"][^'\"]*natural_earth[^'\"]*['\"]\)",
30
+ f"read_parquet('{NATURAL_EARTH_PATH}')",
31
+ sql,
32
+ )
33
+ # Symbolic names (current training format)
34
+ sql = sql.replace(
35
+ "read_parquet('divisions_area')", f"read_parquet('{DIVISIONS_AREA_PATH}')"
36
+ )
37
+ sql = sql.replace(
38
+ "read_parquet('natural_earth')", f"read_parquet('{NATURAL_EARTH_PATH}')"
39
+ )
40
+ return sql
41
 
42
 
43
  def _strip_fences(sql: Optional[str]) -> str:
 
49
  return sql.strip()
50
 
51
 
52
+ def _execute_sql(
53
+ con: duckdb.DuckDBPyConnection,
54
+ sql: str,
55
+ label: str,
56
+ iteration: int,
57
+ ) -> Generator[dict[str, Any], None, None]:
58
+ """Execute SQL and yield result/error events. Shared by both paths."""
59
+ try:
60
+ result_df = con.execute(sql).fetchdf()
61
+ if result_df.empty:
62
+ print(f"[{label}] Query returned no rows.")
63
+ yield {"type": "sql_error", "error": "Query returned no rows", "iteration": iteration}
64
+ yield {"type": "result", "df": None, "sql": sql}
65
+ else:
66
+ print(f"[{label}] Result ({len(result_df)} row(s))")
67
+ yield {"type": "result", "df": result_df, "sql": sql}
68
+ except Exception as exc:
69
+ error = str(exc)
70
+ print(f"[{label}] Execution error: {error}")
71
+ yield {"type": "sql_error", "error": error, "iteration": iteration}
72
+ yield {"type": "result", "df": None, "sql": sql}
73
+
74
+
75
+ # ── GGUF path: finetuned model via llama-server (single-shot) ─────────────────
76
+
77
+
78
+ def run_geo_sql_gguf(
79
  con: duckdb.DuckDBPyConnection,
80
  user_query: str,
81
  candidates_df: pd.DataFrame,
 
82
  ) -> Generator[dict[str, Any], None, None]:
83
+ """Single-shot text-to-SQL via the finetuned GGUF model (llama-server).
84
 
85
  Event types:
86
  - ``sql_attempt`` – ``{"type": "sql_attempt", "sql": str, "iteration": int}``
 
88
  - ``result`` – ``{"type": "result", "df": DataFrame | None, "sql": str}``
89
  """
90
  if candidates_df.empty:
91
+ print("\n[SQL·GGUF] No candidates to work with — skipping.")
92
+ yield {"type": "result", "df": None, "sql": ""}
93
+ return
94
+
95
+ try:
96
+ sql = generate_sql(user_query, candidates_df)
97
+ except Exception as exc:
98
+ error = f"GGUF generation failed: {exc}"
99
+ print(f"[SQL·GGUF] {error}")
100
+ yield {"type": "sql_error", "error": error, "iteration": 1}
101
+ yield {"type": "result", "df": None, "sql": ""}
102
+ return
103
+
104
+ if not sql:
105
+ print("[SQL·GGUF] Model returned empty SQL.")
106
+ yield {"type": "sql_error", "error": "Empty SQL response", "iteration": 1}
107
+ yield {"type": "result", "df": None, "sql": ""}
108
+ return
109
+
110
+ sql = _rewrite_data_paths(sql)
111
+ print(f"\n[SQL·GGUF] Generated:\n{sql}\n")
112
+ yield {"type": "sql_attempt", "sql": sql, "iteration": 1}
113
+ yield from _execute_sql(con, sql, "SQL·GGUF", iteration=1)
114
+
115
+
116
+ # ── DSPy path: cloud/local LM with retry loop ────────────────────────────────
117
+
118
+
119
+ def run_geo_sql_dspy(
120
+ con: duckdb.DuckDBPyConnection,
121
+ user_query: str,
122
+ candidates_df: pd.DataFrame,
123
+ max_iterations: int = MAX_SQL_ITERATIONS,
124
+ ) -> Generator[dict[str, Any], None, None]:
125
+ """Code-act retry loop using the DSPy SQL writer (Ollama / cloud LM).
126
+
127
+ Same event types as ``run_geo_sql_gguf``.
128
+ """
129
+ if candidates_df.empty:
130
+ print("\n[SQL·DSPy] No candidates to work with — skipping.")
131
  yield {"type": "result", "df": None, "sql": ""}
132
  return
133
 
 
137
 
138
  for iteration in range(1, max_iterations + 1):
139
  print(f"\n{'=' * 60}")
140
+ print(f"[SQL·DSPy] Iteration {iteration}/{max_iterations}")
141
 
142
  try:
143
  pred = write_sql(
 
182
  yield {"type": "sql_error", "error": error, "iteration": iteration}
183
 
184
  print(
185
+ f"\n[SQL·DSPy] Exhausted {max_iterations} iterations without a successful query."
186
  )
187
  yield {"type": "result", "df": None, "sql": ""}
uv.lock CHANGED
The diff for this file is too large to render. See raw diff