Spaces:
Running
Running
Merge pull request #1 from developmentseed/slm-qwen3.5
Browse files- .dockerignore +12 -0
- .gitignore +22 -2
- Dockerfile +20 -0
- README.md +4 -4
- dataset/README.md +177 -0
- dataset/__init__.py +1 -0
- dataset/config.yaml +63 -0
- dataset/modal_app.py +272 -0
- dataset/scripts/__init__.py +1 -0
- dataset/scripts/build_inventory.py +124 -0
- dataset/scripts/build_relations.py +557 -0
- dataset/scripts/cli.py +377 -0
- dataset/scripts/export_training_data.py +372 -0
- dataset/scripts/generate_samples.py +1560 -0
- dataset/scripts/sql_templates.py +1651 -0
- dataset/scripts/validate_dataset.py +309 -0
- docker-compose.yml +41 -0
- finetune/README.md +291 -0
- finetune/__init__.py +1 -0
- finetune/check_token_lengths.py +155 -0
- finetune/eval_cli.py +248 -0
- finetune/eval_demo.py +351 -0
- finetune/train_modal_qwen35.py +363 -0
- gazet_demo.py +14 -2
- pyproject.toml +11 -1
- src/gazet/api.py +33 -16
- src/gazet/config.py +30 -15
- src/gazet/export.py +64 -7
- src/gazet/lm.py +230 -7
- src/gazet/search.py +24 -28
- src/gazet/sql.py +104 -8
- uv.lock +0 -0
.dockerignore
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
.venv
|
| 3 |
+
__pycache__
|
| 4 |
+
data/
|
| 5 |
+
finetune/models/
|
| 6 |
+
dataset/output/
|
| 7 |
+
dataset/intermediate/
|
| 8 |
+
results/
|
| 9 |
+
*.gguf
|
| 10 |
+
*.safetensors
|
| 11 |
+
.windsurf/
|
| 12 |
+
.claude/
|
.gitignore
CHANGED
|
@@ -133,6 +133,26 @@ dmypy.json
|
|
| 133 |
# Pyre type checker
|
| 134 |
.pyre/
|
| 135 |
|
| 136 |
-
|
| 137 |
data/
|
| 138 |
-
output/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
# Pyre type checker
|
| 134 |
.pyre/
|
| 135 |
|
| 136 |
+
# Dataset
|
| 137 |
data/
|
| 138 |
+
output/
|
| 139 |
+
*.parquet
|
| 140 |
+
*.jsonl
|
| 141 |
+
|
| 142 |
+
# Eval results
|
| 143 |
+
results/
|
| 144 |
+
|
| 145 |
+
# IDE
|
| 146 |
+
.windsurf/
|
| 147 |
+
|
| 148 |
+
# Local notes
|
| 149 |
+
notes.md
|
| 150 |
+
|
| 151 |
+
# Model
|
| 152 |
+
models/
|
| 153 |
+
*.gguf
|
| 154 |
+
*.safetensors
|
| 155 |
+
*.bin
|
| 156 |
+
*.pt
|
| 157 |
+
*.pth
|
| 158 |
+
*.ckpt
|
Dockerfile
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.13-slim
|
| 2 |
+
|
| 3 |
+
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
| 4 |
+
|
| 5 |
+
WORKDIR /app
|
| 6 |
+
|
| 7 |
+
# Install dependencies first (cache layer)
|
| 8 |
+
COPY pyproject.toml uv.lock ./
|
| 9 |
+
RUN uv sync --frozen --no-install-project --extra demo
|
| 10 |
+
|
| 11 |
+
# Copy application code
|
| 12 |
+
COPY src/ src/
|
| 13 |
+
COPY gazet_demo.py .
|
| 14 |
+
|
| 15 |
+
# Install the project itself
|
| 16 |
+
RUN uv sync --frozen --extra demo
|
| 17 |
+
|
| 18 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 19 |
+
|
| 20 |
+
EXPOSE 8000 8501
|
README.md
CHANGED
|
@@ -26,14 +26,14 @@ uv sync --extra dev --extra demo
|
|
| 26 |
Example for downloading overture
|
| 27 |
|
| 28 |
```bash
|
| 29 |
-
aws s3 sync
|
| 30 |
-
s3 sync s3://overturemaps-us-west-2/release/2026-02-18.0/theme=divisions/type=division_area/ data/overture/divisions_area
|
| 31 |
```
|
| 32 |
|
| 33 |
Example for running conversion script for natural earth
|
| 34 |
|
| 35 |
```bash
|
| 36 |
-
|
|
|
|
| 37 |
```
|
| 38 |
|
| 39 |
### Based on ollama
|
|
@@ -61,7 +61,7 @@ uv run streamlit run gazet_demo.py # demo UI
|
|
| 61 |
| Module | Contents |
|
| 62 |
| --- | --- |
|
| 63 |
| `config.py` | data paths, model name, SQL schema description |
|
| 64 |
-
| `
|
| 65 |
| `lm.py` | DSPy signatures + LM init (`extract`, `write_sql`) |
|
| 66 |
| `search.py` | fuzzy search against `divisions_area` / `natural_earth` |
|
| 67 |
| `sql.py` | code-act SQL generation loop |
|
|
|
|
| 26 |
Example for downloading overture
|
| 27 |
|
| 28 |
```bash
|
| 29 |
+
aws s3 sync s3://overturemaps-us-west-2/release/2026-02-18.0/theme=divisions/type=division_area/ data/overture/divisions_area
|
|
|
|
| 30 |
```
|
| 31 |
|
| 32 |
Example for running conversion script for natural earth
|
| 33 |
|
| 34 |
```bash
|
| 35 |
+
unzip ~/Downloads/10m_physical.zip -d data/natural_earth
|
| 36 |
+
python -m ingest.convert_natural_earth data/natural_earth
|
| 37 |
```
|
| 38 |
|
| 39 |
### Based on ollama
|
|
|
|
| 61 |
| Module | Contents |
|
| 62 |
| --- | --- |
|
| 63 |
| `config.py` | data paths, model name, SQL schema description |
|
| 64 |
+
| `schemas.py` | `SUBTYPES`, `COUNTRIES`, `Place`, `PlacesResult` |
|
| 65 |
| `lm.py` | DSPy signatures + LM init (`extract`, `write_sql`) |
|
| 66 |
| `search.py` | fuzzy search against `divisions_area` / `natural_earth` |
|
| 67 |
| `sql.py` | code-act SQL generation loop |
|
dataset/README.md
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Gazet Dataset Generation
|
| 2 |
+
|
| 3 |
+
Generates synthetic training data for fine-tuning the geocoding model.
|
| 4 |
+
Two datasets come out of one pipeline run:
|
| 5 |
+
|
| 6 |
+
- **SQL generation** — `(question + candidates) -> DuckDB SQL`
|
| 7 |
+
- **Place extraction** — `question -> place names JSON`
|
| 8 |
+
|
| 9 |
+
Both tasks export in **conversation format** (`messages` list of
|
| 10 |
+
system/user/assistant turns), ready for chat-template fine-tuning.
|
| 11 |
+
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
## Prerequisites
|
| 15 |
+
|
| 16 |
+
```bash
|
| 17 |
+
uv sync
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
You need the Overture and Natural Earth parquet files under `data/` locally,
|
| 21 |
+
or on a Modal volume if running in the cloud.
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
## Option A — Run locally (small datasets, development)
|
| 26 |
+
|
| 27 |
+
Use this when you want to iterate quickly on a laptop with a subset of countries.
|
| 28 |
+
|
| 29 |
+
**Step 1 — Pick a run name and countries in `config.yaml`**
|
| 30 |
+
|
| 31 |
+
```yaml
|
| 32 |
+
run_name: "v1" # change this every time you generate fresh data
|
| 33 |
+
|
| 34 |
+
countries:
|
| 35 |
+
- IN # India
|
| 36 |
+
- BR # Brazil
|
| 37 |
+
- US # United States
|
| 38 |
+
# add more, or use "- all" for every country (slow locally)
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
**Step 2 — Run the full pipeline**
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
gazet-dataset full-pipeline --config dataset/config.yaml
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
That's it. It runs all four steps in order and puts the results in
|
| 48 |
+
`dataset/output/runs/{run_name}/` (for example `dataset/output/runs/v1/`).
|
| 49 |
+
|
| 50 |
+
If you want to run steps individually (e.g. to re-export without regenerating):
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
gazet-dataset build-relations --config dataset/config.yaml # ~5 min
|
| 54 |
+
gazet-dataset generate-samples --config dataset/config.yaml # ~15 min
|
| 55 |
+
gazet-dataset validate --config dataset/config.yaml # ~5 min
|
| 56 |
+
gazet-dataset export --config dataset/config.yaml # <1 min
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
---
|
| 60 |
+
|
| 61 |
+
## Option B — Run on Modal (large datasets, production)
|
| 62 |
+
|
| 63 |
+
Use this when you need 10 K+ samples or want to use all countries. Modal
|
| 64 |
+
distributes generation across many containers in parallel.
|
| 65 |
+
|
| 66 |
+
**Step 1 — One-time setup**
|
| 67 |
+
|
| 68 |
+
```bash
|
| 69 |
+
modal setup # authenticate with Modal (one time)
|
| 70 |
+
gazet-dataset modal-upload --config dataset/config.yaml # upload parquet data to Modal volume
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
**Step 2 — Set run name and targets in `config.yaml`**
|
| 74 |
+
|
| 75 |
+
```yaml
|
| 76 |
+
run_name: "v1"
|
| 77 |
+
|
| 78 |
+
countries:
|
| 79 |
+
- all
|
| 80 |
+
|
| 81 |
+
sample_targets:
|
| 82 |
+
adjacency: 1250
|
| 83 |
+
containment: 1250
|
| 84 |
+
# ... see config.yaml for all families
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
**Step 3 — Run on Modal**
|
| 88 |
+
|
| 89 |
+
```bash
|
| 90 |
+
gazet-dataset modal-generate --config dataset/config.yaml
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
This builds relations, generates samples, validates, and exports — same as
|
| 94 |
+
`full-pipeline` but distributed across 100 cloud containers.
|
| 95 |
+
|
| 96 |
+
If relations are already built from a previous run (same countries, same
|
| 97 |
+
template version), skip rebuilding them:
|
| 98 |
+
|
| 99 |
+
```bash
|
| 100 |
+
gazet-dataset modal-generate --config dataset/config.yaml --skip-relations
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
---
|
| 104 |
+
|
| 105 |
+
## Output
|
| 106 |
+
|
| 107 |
+
After running, your training files are at:
|
| 108 |
+
|
| 109 |
+
```
|
| 110 |
+
dataset/output/runs/{run_name}/
|
| 111 |
+
sql/
|
| 112 |
+
train.jsonl <- fine-tune the SQL generation model
|
| 113 |
+
val.jsonl
|
| 114 |
+
test.jsonl
|
| 115 |
+
places/
|
| 116 |
+
train.jsonl <- fine-tune the place extraction model
|
| 117 |
+
val.jsonl
|
| 118 |
+
test.jsonl
|
| 119 |
+
stats.json <- sample counts by family
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
Each JSONL row is a conversation-format dict:
|
| 123 |
+
|
| 124 |
+
```json
|
| 125 |
+
{
|
| 126 |
+
"messages": [
|
| 127 |
+
{"role": "system", "content": "..."},
|
| 128 |
+
{"role": "user", "content": "..."},
|
| 129 |
+
{"role": "assistant", "content": "..."}
|
| 130 |
+
]
|
| 131 |
+
}
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
**SQL task**: the system prompt includes the full two-table schema inside
|
| 135 |
+
`<SCHEMA>` tags. The user prompt contains only `<CANDIDATES>` CSV and
|
| 136 |
+
`<USER_QUERY>`. The assistant response is pretty-printed SQL (via `sqlparse`).
|
| 137 |
+
All parquet paths are symbolic (`divisions_area` / `natural_earth`), never
|
| 138 |
+
runtime-specific.
|
| 139 |
+
|
| 140 |
+
**Places task**: the system prompt includes output format, extraction rules,
|
| 141 |
+
and the full list of Overture subtypes. The assistant response is a JSON
|
| 142 |
+
object with a `places` array.
|
| 143 |
+
|
| 144 |
+
---
|
| 145 |
+
|
| 146 |
+
## When to regenerate from scratch
|
| 147 |
+
|
| 148 |
+
Change `run_name` and regenerate from scratch whenever you:
|
| 149 |
+
|
| 150 |
+
- Change any SQL templates (`sql_templates.py`)
|
| 151 |
+
- Add new template families
|
| 152 |
+
- Change the candidate format or count
|
| 153 |
+
- Change the system/user prompt structure or content
|
| 154 |
+
- Change the export format
|
| 155 |
+
|
| 156 |
+
For local runs, the default is a clean run. For Modal, `modal-generate` appends
|
| 157 |
+
by default; pass `--fresh` to overwrite existing samples.
|
| 158 |
+
|
| 159 |
+
---
|
| 160 |
+
|
| 161 |
+
## Troubleshooting
|
| 162 |
+
|
| 163 |
+
**Very few samples generated for a family**
|
| 164 |
+
The generation loop tries `retry_multiplier × target` and discards SQL that
|
| 165 |
+
returns empty results. Some families (e.g. `multi_adjacency`, `chained`) have
|
| 166 |
+
a lower success rate. Increase `sample_targets` for those families or increase
|
| 167 |
+
`retry_multiplier` in `config.yaml`.
|
| 168 |
+
|
| 169 |
+
**Relations step is slow**
|
| 170 |
+
Normal for `countries: [all]` — it's a spatial self-join over millions of
|
| 171 |
+
features. Use a country subset for development. Relations only need to be
|
| 172 |
+
rebuilt when you add countries or change template families.
|
| 173 |
+
|
| 174 |
+
**Validate step drops many samples**
|
| 175 |
+
The validate step re-executes every SQL query and discards ones that return
|
| 176 |
+
empty results. This is expected — check `dataset/output/runs/{run_name}/stats.json`
|
| 177 |
+
for per-family counts after export.
|
dataset/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Synthetic dataset generation package."""
|
dataset/config.yaml
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dataset Generation Configuration
|
| 2 |
+
# This config controls which countries to process and how many samples to generate
|
| 3 |
+
|
| 4 |
+
# Countries to include in relation building
|
| 5 |
+
# Use ISO 3166-1 alpha-2 codes, or "all" to include every country
|
| 6 |
+
countries:
|
| 7 |
+
- all
|
| 8 |
+
# Or specify a subset:
|
| 9 |
+
# - IN # India
|
| 10 |
+
# - PK # Pakistan
|
| 11 |
+
# - EC # Ecuador
|
| 12 |
+
# - BE # Belgium
|
| 13 |
+
# - KE # Kenya
|
| 14 |
+
|
| 15 |
+
# Sample generation targets per family
|
| 16 |
+
# Relation limits are auto-calculated from these targets
|
| 17 |
+
sample_targets:
|
| 18 |
+
direct_lookup: 500
|
| 19 |
+
adjacency: 750
|
| 20 |
+
multi_adjacency: 300
|
| 21 |
+
containment: 750
|
| 22 |
+
intersection: 500
|
| 23 |
+
buffer: 500
|
| 24 |
+
chained: 750 # coastal / landlocked variants
|
| 25 |
+
difference: 300
|
| 26 |
+
border_corridor: 300
|
| 27 |
+
set_operations: 500
|
| 28 |
+
partial_selection: 500
|
| 29 |
+
aggregation: 500
|
| 30 |
+
window_function: 300
|
| 31 |
+
attribute_filter: 300
|
| 32 |
+
|
| 33 |
+
# Generation settings
|
| 34 |
+
generation:
|
| 35 |
+
max_workers: 8 # Number of parallel workers
|
| 36 |
+
retry_multiplier: 2 # Generate 2x samples to account for failures
|
| 37 |
+
append_mode: false # Set false for clean regeneration after template/format changes
|
| 38 |
+
|
| 39 |
+
# Auto-scaling configuration
|
| 40 |
+
# Relation limits are automatically calculated: target * retry_multiplier * safety_factor
|
| 41 |
+
auto_scaling:
|
| 42 |
+
safety_factor: 1.5 # Extra buffer to ensure enough unique pairs
|
| 43 |
+
|
| 44 |
+
# Manual overrides (optional) - to override auto-calculated limits, replace `{}` below with a mapping such as the commented entries
|
| 45 |
+
manual_limits: {}
|
| 46 |
+
# adjacency: 10000 # Uncomment to manually set
|
| 47 |
+
# containment: 2000
|
| 48 |
+
# intersection: 1000
|
| 49 |
+
# cross_source: 500
|
| 50 |
+
|
| 51 |
+
# Modal configuration for distributed generation
|
| 52 |
+
modal:
|
| 53 |
+
volume_name: "gazet-data" # Modal Volume for parquet data
|
| 54 |
+
app_name: "gazet-dataset" # Modal app name
|
| 55 |
+
num_containers: 100 # Number of parallel containers for sample generation
|
| 56 |
+
container_cpu: 2 # CPUs per container
|
| 57 |
+
container_memory: 4096 # Memory (MB) per container
|
| 58 |
+
timeout: 3600 # Per-container timeout in seconds
|
| 59 |
+
|
| 60 |
+
# Run name — used to version exported splits so re-runs never overwrite previous data.
|
| 61 |
+
# Change this whenever you regenerate from scratch (e.g. after template changes).
|
| 62 |
+
# Exported files land in: output/runs/{run_name}/
|
| 63 |
+
run_name: "v1"
|
dataset/modal_app.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Modal app for distributed dataset generation."""
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import modal
|
| 5 |
+
|
| 6 |
+
# Modal application object; the functions and entrypoints in this module
# register themselves on it through its decorators.
app = modal.App("gazet-dataset")

# Paths at which the two volumes are mounted inside the remote containers.
VOLUME_MOUNT = "/data"
INTERMEDIATE_MOUNT = "/intermediate"

# Named Modal volumes, created on first use if they do not exist yet.
volume = modal.Volume.from_name("gazet-data", create_if_missing=True)
intermediate_volume = modal.Volume.from_name(
    "gazet-intermediate",
    create_if_missing=True,
)

# Python packages installed into the container image.
_PIP_REQUIREMENTS = (
    "duckdb>=1.4.4",
    "dspy>=3.1.3",
    "fastapi>=0.100",
    "pandas>=2.2",
    "pydantic>=2.0",
    "pyarrow>=17.0.0",
    "pyyaml>=6.0",
)

# Container image: slim Debian + dependencies, with the local `gazet` and
# `dataset` packages copied under /root (which is put on PYTHONPATH).
_base_image = modal.Image.debian_slim(python_version="3.12")
_image_with_deps = _base_image.pip_install(*_PIP_REQUIREMENTS)
_image_with_env = _image_with_deps.env(
    {"GAZET_DATA_DIR": VOLUME_MOUNT, "PYTHONPATH": "/root"}
)
image = _image_with_env.add_local_dir("src/gazet", "/root/gazet").add_local_dir(
    "dataset", "/root/dataset"
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@app.function(
    image=image,
    volumes={VOLUME_MOUNT: volume, INTERMEDIATE_MOUNT: intermediate_volume},
    timeout=300,
    cpu=2,
    memory=4096,
)
def build_inventory_remote():
    """Build the entity inventory inside a Modal container.

    Delegates to ``build_inventory_to_dir`` with the intermediate volume's
    mount point as the output directory, commits the volume, and returns
    whatever that helper returned.
    """
    from pathlib import Path

    from dataset.scripts.build_inventory import build_inventory_to_dir

    inventory_counts = build_inventory_to_dir(Path(INTERMEDIATE_MOUNT))
    # Commit so the files written under the mount are saved to the volume.
    intermediate_volume.commit()
    return inventory_counts
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@app.function(
    image=image,
    volumes={VOLUME_MOUNT: volume, INTERMEDIATE_MOUNT: intermediate_volume},
    timeout=3600,
    cpu=4,
    memory=32768,
)
def build_relation_remote(relation_type: str, countries: list, limit: int):
    """Compute a single relation type inside a Modal container.

    Args:
        relation_type: Name of the relation to compute.
        countries: Country selection forwarded unchanged to the helper.
        limit: Limit forwarded unchanged to the helper.

    Returns:
        ``{"relation_type": relation_type, "count": <helper result>}``.
    """
    from pathlib import Path

    from dataset.scripts.build_relations import compute_single_relation

    target_dir = Path(INTERMEDIATE_MOUNT)
    pair_count = compute_single_relation(
        relation_type=relation_type,
        countries=countries,
        limit=limit,
        output_dir=target_dir,
    )
    # Commit so anything written under the mount is saved to the volume.
    intermediate_volume.commit()

    summary = {"relation_type": relation_type, "count": pair_count}
    return summary
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@app.function(
    image=image,
    volumes={VOLUME_MOUNT: volume, INTERMEDIATE_MOUNT: intermediate_volume},
    timeout=3600,
    cpu=2,
    memory=4096,
)
def generate_batch_remote(work_items: list) -> list:
    """Run one batch of work items inside a Modal container.

    Args:
        work_items: Work items passed straight to ``generate_batch_core``.

    Returns:
        The list produced by ``generate_batch_core``; each entry is indexed
        by ``"sample"``, which is falsy when generation failed for that item.
    """
    from dataset.scripts.generate_samples import generate_batch_core

    outcomes = generate_batch_core(
        work_items=work_items,
        intermediate_dir=INTERMEDIATE_MOUNT,
    )

    succeeded = sum(1 for outcome in outcomes if outcome["sample"])
    failed = sum(1 for outcome in outcomes if not outcome["sample"])
    print(
        f"Batch complete: {succeeded} success / "
        f"{failed} failed out of {len(work_items)}"
    )

    return outcomes
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _scan_existing_samples(output_file):
    """Return ``(existing_count, next_counter)`` for a raw JSONL sample file.

    ``next_counter`` is one more than the largest numeric suffix among ids of
    the form ``sample_<n>``; it is 1 when no such id is present, so a file
    holding only foreign ids no longer crashes the caller.
    """
    import json

    existing_count = 0
    max_id = 0
    with open(output_file, encoding="utf-8") as existing:
        for line in existing:
            if not line.strip():
                continue
            existing_count += 1
            sample_id = json.loads(line)["id"]
            if sample_id.startswith("sample_"):
                max_id = max(max_id, int(sample_id.split("_")[1]))
    return existing_count, max_id + 1


@app.local_entrypoint()
def run_pipeline(
    config_path: str = "dataset/config.yaml",
    num_containers: int = 0,
    skip_inventory: bool = False,
    skip_relations: bool = False,
    fresh: bool = False,
):
    """Run the full distributed pipeline.

    Steps: optionally build the inventory, optionally build every relation
    type in parallel, then fan sample generation out over Modal containers
    and stream successful samples into ``dataset/output/dataset_raw.jsonl``.

    Args:
        config_path: Path to the YAML config (``countries``,
            ``sample_targets``, ``generation.retry_multiplier`` and the
            optional ``modal`` section are read).
        num_containers: Number of batches to split the work into; ``0`` means
            use ``modal.num_containers`` from the config (default 50).
        skip_inventory: Skip the inventory step.
        skip_relations: Skip the relation-building step.
        fresh: Overwrite the raw output file instead of appending to it.
    """
    import json
    import yaml
    from pathlib import Path

    config = yaml.safe_load(Path(config_path).read_text())
    countries = config["countries"]
    sample_targets = config["sample_targets"]
    modal_cfg = config.get("modal", {})
    n_containers = num_containers or modal_cfg.get("num_containers", 50)
    retry_multiplier = config["generation"]["retry_multiplier"]

    print(f"Countries: {countries}")
    print(f"Sample targets: {sample_targets}")
    print(f"Containers: {n_containers}")

    if not skip_inventory:
        print("Building inventory...")
        result = build_inventory_remote.remote()
        print(f" Inventory: {result}")

    if not skip_relations:
        print("Building relations...")

        from dataset.scripts.cli import calculate_relation_limits

        relation_needs = calculate_relation_limits(config)

        # Spawn every relation job first so they run concurrently, then
        # collect the results in submission order.
        handles = [
            (rel_type, build_relation_remote.spawn(rel_type, countries, max(limit, 500)))
            for rel_type, limit in relation_needs.items()
        ]
        for rel_type, handle in handles:
            result = handle.get()
            print(f" {rel_type}: {result['count']} pairs")

    print(f"Generating samples across {n_containers} containers...")

    from dataset.scripts.generate_samples import prepare_work_items

    output_dir = Path("dataset/output")
    output_dir.mkdir(exist_ok=True, parents=True)
    output_file = output_dir / "dataset_raw.jsonl"

    existing_count = 0
    sample_counter = 1
    if not fresh and output_file.exists():
        existing_count, sample_counter = _scan_existing_samples(output_file)
        if existing_count:
            print(f" Appending to {existing_count} existing samples")

    work_items = prepare_work_items(
        target_counts=sample_targets,
        retry_multiplier=retry_multiplier,
        start_counter=sample_counter,
        intermediate_dir_str="",
    )

    total_work = len(work_items)
    print(f" Total work items: {total_work}")

    # Ceiling division: at most n_containers batches, each with >= 1 item.
    batch_size = max(1, (total_work + n_containers - 1) // n_containers)
    batches = [
        work_items[i : i + batch_size]
        for i in range(0, total_work, batch_size)
    ]
    print(f" Batches: {len(batches)} x ~{batch_size} items")

    new_sample_count = 0
    failed_batches = 0
    processed = 0
    interrupted = False
    family_progress = {}

    with open(output_file, "w" if fresh else "a", encoding="utf-8") as fout:
        try:
            for batch_results in generate_batch_remote.map(
                batches, return_exceptions=True
            ):
                if isinstance(batch_results, Exception):
                    failed_batches += 1
                    print(f" Batch failed: {batch_results}")
                    continue

                batch_samples = []
                for r in batch_results:
                    progress = family_progress.setdefault(
                        r["family"], {"success": 0, "failed": 0}
                    )
                    if r["sample"]:
                        batch_samples.append(r["sample"])
                        progress["success"] += 1
                    else:
                        progress["failed"] += 1

                # Flush after every batch so finished work survives a crash.
                fout.writelines(json.dumps(sample) + "\n" for sample in batch_samples)
                fout.flush()
                new_sample_count += len(batch_samples)
                processed += len(batch_results)

                print(f" Progress: {processed}/{total_work} items | {new_sample_count} saved | {failed_batches} batch errors")

        except Exception as e:
            # Best effort: keep what was written, but remember that the run
            # did not finish so the summary below is not reported as complete.
            interrupted = True
            print(f" Map interrupted: {e}")

    print("\nResults by family:")
    for fam in sorted(family_progress):
        succeeded = family_progress[fam]["success"]
        failed = family_progress[fam]["failed"]
        total = succeeded + failed
        rate = (succeeded / total * 100) if total > 0 else 0
        target = sample_targets.get(fam, 0)
        print(
            f" {fam:20s}: {succeeded:4d} success / {failed:4d} failed "
            f"({rate:5.1f}%, target: {target})"
        )

    total_samples = existing_count + new_sample_count
    status = "COMPLETE" if failed_batches == 0 and not interrupted else "PARTIAL"
    print(f"\nGeneration {status}: {new_sample_count} new, {total_samples} total")
    if failed_batches:
        print(f" Failed batches: {failed_batches}/{len(batches)}")
    print(f" Output: {output_file}")
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
@app.local_entrypoint()
def upload_data(data_dir: str = "data"):
    """Upload a local data directory to the ``gazet-data`` Modal volume.

    Prints each file found under ``data_dir`` with its size, then uploads the
    whole directory to the volume root. Prints an error and returns without
    uploading when ``data_dir`` does not exist.
    """
    import os
    from pathlib import Path

    source_root = Path(data_dir)
    if not source_root.exists():
        print(f"Error: {source_root} does not exist")
        return

    print(f"Uploading {source_root} to Modal volume 'gazet-data'...")

    bytes_per_mb = 1024 * 1024
    sizes = []
    for current_dir, _subdirs, filenames in os.walk(source_root):
        for filename in filenames:
            full_path = os.path.join(current_dir, filename)
            # Path relative to data_dir, i.e. where the file lands on the volume.
            relative = os.path.relpath(full_path, source_root)
            nbytes = os.path.getsize(full_path)
            sizes.append(nbytes)
            print(f" {relative} ({nbytes / bytes_per_mb:.1f} MB)")

    print(f" {len(sizes)} files, {sum(sizes) / bytes_per_mb:.1f} MB")

    target_volume = modal.Volume.from_name("gazet-data", create_if_missing=True)
    with target_volume.batch_upload() as uploader:
        uploader.put_directory(str(source_root), "/")

    print("Upload complete")
|
dataset/scripts/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Dataset generation scripts package."""
|
dataset/scripts/build_inventory.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Build entity inventory from divisions_area and natural_earth parquet files.
|
| 3 |
+
|
| 4 |
+
This script creates compact inventory tables containing only the fields needed
|
| 5 |
+
for candidate sampling and distractor generation.
|
| 6 |
+
|
| 7 |
+
Output:
|
| 8 |
+
- intermediate/divisions_area_inventory.parquet
|
| 9 |
+
- intermediate/natural_earth_inventory.parquet
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import duckdb
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def build_divisions_area_inventory(con: duckdb.DuckDBPyConnection) -> pd.DataFrame:
    """Return a compact inventory DataFrame read from the divisions_area parquet.

    Rows whose primary name is NULL or blank are excluded. Each row carries
    identifying attributes plus the geometry's area and bounding box.
    Summary statistics are printed as a side effect.
    """
    sql = """
        SELECT
            'divisions_area' AS source,
            id,
            names."primary" AS name,
            subtype,
            country,
            region,
            admin_level,
            class,
            is_land,
            is_territorial,
            division_id,
            ST_Area(geometry) AS area_sq_deg,
            ST_XMin(geometry) AS xmin,
            ST_YMin(geometry) AS ymin,
            ST_XMax(geometry) AS xmax,
            ST_YMax(geometry) AS ymax
        FROM read_parquet(?)
        WHERE names."primary" IS NOT NULL
          AND trim(names."primary") != ''
    """

    inventory = con.execute(sql, [DIVISIONS_AREA_PATH]).fetchdf()

    subtype_counts = inventory["subtype"].value_counts().to_dict()
    country_count = inventory["country"].nunique()
    print(f"Divisions area inventory: {len(inventory)} entities")
    print(f"Subtypes: {subtype_counts}")
    print(f"Countries: {country_count} unique")

    return inventory
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def build_natural_earth_inventory(con: duckdb.DuckDBPyConnection) -> pd.DataFrame:
    """Return a compact inventory DataFrame read from the natural_earth parquet.

    Rows whose primary name is NULL or blank are excluded. Each row carries
    identifying attributes plus the geometry's area and bounding box.
    Summary statistics are printed as a side effect.
    """
    sql = """
        SELECT
            'natural_earth' AS source,
            id,
            names."primary" AS name,
            subtype,
            country,
            region,
            admin_level,
            class,
            is_land,
            is_territorial,
            ST_Area(geometry) AS area_sq_deg,
            ST_XMin(geometry) AS xmin,
            ST_YMin(geometry) AS ymin,
            ST_XMax(geometry) AS xmax,
            ST_YMax(geometry) AS ymax
        FROM read_parquet(?)
        WHERE names."primary" IS NOT NULL
          AND trim(names."primary") != ''
    """

    inventory = con.execute(sql, [NATURAL_EARTH_PATH]).fetchdf()

    subtype_counts = inventory["subtype"].value_counts().to_dict()
    print(f"\nNatural earth inventory: {len(inventory)} entities")
    print(f"Subtypes: {subtype_counts}")

    return inventory
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def build_inventory_to_dir(output_dir: Path) -> dict:
    """Build and save all inventory tables to output_dir.

    Reusable entry point for both local CLI and Modal.

    Writes ``divisions_area_inventory.parquet`` and
    ``natural_earth_inventory.parquet`` into ``output_dir`` (created if
    missing, including parents).

    Args:
        output_dir: Directory that receives the two parquet files.

    Returns:
        Dict with counts: {"divisions_area": int, "natural_earth": int}
    """
    output_dir.mkdir(exist_ok=True, parents=True)

    con = duckdb.connect()
    # Close the connection even when a query or a parquet write fails;
    # previously an exception left it open.
    try:
        con.execute("INSTALL spatial")
        con.execute("LOAD spatial")

        print("Building divisions_area inventory...")
        divisions_df = build_divisions_area_inventory(con)
        divisions_path = output_dir / "divisions_area_inventory.parquet"
        divisions_df.to_parquet(divisions_path, index=False)
        print(f"Saved to {divisions_path}")

        print("\nBuilding natural_earth inventory...")
        natural_earth_df = build_natural_earth_inventory(con)
        natural_earth_path = output_dir / "natural_earth_inventory.parquet"
        natural_earth_df.to_parquet(natural_earth_path, index=False)
        print(f"Saved to {natural_earth_path}")
    finally:
        con.close()

    counts = {
        "divisions_area": len(divisions_df),
        "natural_earth": len(natural_earth_df),
    }
    print("\nInventory build complete")
    print(f"  Total entities: {sum(counts.values())}")
    return counts
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def main():
    """Build the inventory tables into the sibling ``intermediate`` directory.

    The target is ``<directory of this file>/../intermediate``.
    """
    scripts_dir = Path(__file__).parent
    build_inventory_to_dir(scripts_dir.parent / "intermediate")


if __name__ == "__main__":
    main()
|
dataset/scripts/build_relations.py
ADDED
|
@@ -0,0 +1,557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Precompute spatial relation tables for efficient anchor sampling.
|
| 3 |
+
|
| 4 |
+
This script computes:
|
| 5 |
+
- Adjacency pairs (touching features)
|
| 6 |
+
- Containment pairs (features within other features)
|
| 7 |
+
- Intersection pairs (overlapping features)
|
| 8 |
+
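- Cross-source relations (divisions_area ↔ natural_earth)
- Coastal and landlocked containment pairs
- Common-neighbor pairs (derived from the adjacency pairs)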
|
| 9 |
+
|
| 10 |
+
Output:
|
| 11 |
+
- intermediate/adjacency_pairs.parquet
|
| 12 |
+
- intermediate/containment_pairs.parquet
|
| 13 |
+
- intermediate/intersection_pairs.parquet
|
| 14 |
+
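- intermediate/cross_source_relations.parquet
- intermediate/coastal_containment_pairs.parquet
- intermediate/landlocked_containment_pairs.parquet
- intermediate/common_neighbor_pairs.parquet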
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import duckdb
|
| 18 |
+
import pandas as pd
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 21 |
+
|
| 22 |
+
from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Subtypes too granular for spatial self-joins at global scale
|
| 26 |
+
_EXCLUDED_SUBTYPES_FOR_GLOBAL = ("locality", "neighborhood", "microhood", "macrohood")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _country_filter(countries: list) -> tuple[str, list]:
|
| 30 |
+
"""Return (SQL WHERE clause, params) handling 'all' sentinel."""
|
| 31 |
+
if countries == ["all"]:
|
| 32 |
+
return "", []
|
| 33 |
+
return "WHERE country IN (SELECT unnest(?))", [countries]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _country_filter_for_join(countries: list) -> tuple[str, list]:
    """Country filter for self-joins that also drops fine-grained subtypes.

    Like ``_country_filter``, but the returned clause always contains
    ``subtype NOT IN (...)`` built from ``_EXCLUDED_SUBTYPES_FOR_GLOBAL``.
    The exclusion is applied for the ``["all"]`` sentinel and for explicit
    country lists alike.

    Original rationale: joining all 1M+ entities makes localities,
    neighborhoods and microhoods cause OOM; excluding them keeps ~110K
    higher-level admin entities.

    Returns:
        ``(where_clause, params)``; ``params`` is empty for ``["all"]`` and
        ``[countries]`` otherwise.
    """
    # The subtype names come from a module constant, not user input, so they
    # are inlined as quoted SQL literals.
    quoted_subtypes = "'" + "', '".join(_EXCLUDED_SUBTYPES_FOR_GLOBAL) + "'"
    subtype_clause = f"AND subtype NOT IN ({quoted_subtypes})"

    if countries != ["all"]:
        return f"WHERE country IN (SELECT unnest(?)) {subtype_clause}", [countries]
    # "1=1" gives the AND clause something to attach to when there is no
    # country predicate.
    return f"WHERE 1=1 {subtype_clause}", []
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def compute_adjacency_pairs(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int
) -> pd.DataFrame:
    """Return up to ``limit`` pairs of divisions_area features that touch.

    Features are read from ``DIVISIONS_AREA_PATH`` and narrowed with
    ``_country_filter_for_join``. Each unordered pair is reported once
    (``a.id < b.id``); ``a`` is the anchor and ``b`` the target.

    Args:
        con: DuckDB connection with the spatial extension loaded.
        countries: Country codes, or ``["all"]`` for no country filter.
        limit: Maximum number of rows returned (SQL ``LIMIT``).
    """
    print("Computing adjacency pairs (optimized with spatial index)...")

    where_clause, where_params = _country_filter_for_join(countries)

    # The envelope test is a cheap pre-filter evaluated before the exact
    # ST_Touches predicate.
    adjacency_sql = f"""
        WITH features AS (
            SELECT
                id,
                names."primary" AS name,
                subtype,
                country,
                admin_level,
                geometry,
                ST_Envelope(geometry) AS bbox
            FROM read_parquet(?)
            {where_clause}
        )
        SELECT
            a.id AS anchor_id,
            a.name AS anchor_name,
            a.subtype AS anchor_subtype,
            a.country AS anchor_country,
            b.id AS target_id,
            b.name AS target_name,
            b.subtype AS target_subtype,
            b.country AS target_country,
            'adjacency' AS relation_type
        FROM features AS a
        JOIN features AS b ON (
            a.id < b.id
            AND ST_Intersects(a.bbox, b.bbox)
            AND ST_Touches(a.geometry, b.geometry)
        )
        LIMIT ?
    """

    # Parameters follow the textual order of the placeholders.
    bound = [DIVISIONS_AREA_PATH, *where_params, limit]
    pairs = con.execute(adjacency_sql, bound).fetchdf()
    print(f"Found {len(pairs)} adjacency pairs")

    return pairs
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def compute_containment_pairs(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int
) -> pd.DataFrame:
    """Return up to ``limit`` (container, contained) feature pairs.

    A pair qualifies when the container has a strictly lower
    ``admin_level`` than the contained feature and the contained geometry
    is ``ST_Within`` the container geometry. Features come from
    ``DIVISIONS_AREA_PATH``, filtered by ``_country_filter``.

    Args:
        con: DuckDB connection with the spatial extension loaded.
        countries: Country codes, or ``["all"]`` for no country filter.
        limit: Maximum number of rows returned (SQL ``LIMIT``).
    """
    print("\nComputing containment pairs (optimized)...")

    where_clause, where_params = _country_filter(countries)

    containment_sql = f"""
        WITH features AS (
            SELECT
                id,
                names."primary" AS name,
                subtype,
                country,
                admin_level,
                geometry,
                ST_Envelope(geometry) AS bbox
            FROM read_parquet(?)
            {where_clause}
        )
        SELECT
            a.id AS container_id,
            a.name AS container_name,
            a.subtype AS container_subtype,
            b.id AS contained_id,
            b.name AS contained_name,
            b.subtype AS contained_subtype,
            'containment' AS relation_type
        FROM features AS a
        JOIN features AS b ON (
            a.id != b.id
            AND a.admin_level < b.admin_level
            AND ST_Intersects(a.bbox, b.bbox)
            AND ST_Within(b.geometry, a.geometry)
        )
        LIMIT ?
    """

    # Parameters follow the textual order of the placeholders.
    bound = [DIVISIONS_AREA_PATH, *where_params, limit]
    pairs = con.execute(containment_sql, bound).fetchdf()
    print(f"Found {len(pairs)} containment pairs")

    return pairs
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def compute_intersection_pairs(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int
) -> pd.DataFrame:
    """Return up to ``limit`` pairs that overlap without touching or nesting.

    A pair qualifies when the geometries intersect, do not merely touch,
    and neither is within the other. Each unordered pair is reported once
    (``a.id < b.id``). Features come from ``DIVISIONS_AREA_PATH``, filtered
    by ``_country_filter_for_join``.

    Args:
        con: DuckDB connection with the spatial extension loaded.
        countries: Country codes, or ``["all"]`` for no country filter.
        limit: Maximum number of rows returned (SQL ``LIMIT``).
    """
    print("\nComputing intersection pairs (optimized)...")

    where_clause, where_params = _country_filter_for_join(countries)

    intersection_sql = f"""
        WITH features AS (
            SELECT
                id,
                names."primary" AS name,
                subtype,
                country,
                admin_level,
                geometry,
                ST_Envelope(geometry) AS bbox
            FROM read_parquet(?)
            {where_clause}
        )
        SELECT
            a.id AS anchor_id,
            a.name AS anchor_name,
            a.subtype AS anchor_subtype,
            b.id AS target_id,
            b.name AS target_name,
            b.subtype AS target_subtype,
            'intersection' AS relation_type
        FROM features AS a
        JOIN features AS b ON (
            a.id < b.id
            AND ST_Intersects(a.bbox, b.bbox)
            AND ST_Intersects(a.geometry, b.geometry)
            AND NOT ST_Touches(a.geometry, b.geometry)
            AND NOT ST_Within(a.geometry, b.geometry)
            AND NOT ST_Within(b.geometry, a.geometry)
        )
        LIMIT ?
    """

    # Parameters follow the textual order of the placeholders.
    bound = [DIVISIONS_AREA_PATH, *where_params, limit]
    pairs = con.execute(intersection_sql, bound).fetchdf()
    print(f"Found {len(pairs)} same-source intersection pairs")

    return pairs
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def compute_cross_source_relations(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int
) -> pd.DataFrame:
    """Relate divisions_area features to natural_earth features.

    Divisions (filtered by ``_country_filter``) are joined to at most 500
    natural_earth features of the listed subtypes wherever the geometries
    intersect. Each row is labelled with the first matching relation in
    the order: touches, within, contains, intersects.

    The subtype list is meant to cover the natural_earth subtypes used by
    the SQL templates: seas/oceans (adjacency, buffer, chained), terrain
    areas and island groups (chained_03, intersect_02, buffer_03/04).

    NOTE(review): the inner ``LIMIT 500`` has no ``ORDER BY``, so which
    natural features are kept is up to the engine — confirm this is
    acceptable for reproducibility.

    Args:
        con: DuckDB connection with the spatial extension loaded.
        countries: Country codes, or ``["all"]`` for no country filter.
        limit: Maximum number of rows returned (outer SQL ``LIMIT``).
    """
    print("\nComputing cross-source relations...")

    where_clause, where_params = _country_filter(countries)

    cross_source_sql = f"""
        WITH divisions AS (
            SELECT
                id,
                names."primary" AS name,
                subtype,
                country,
                geometry
            FROM read_parquet(?)
            {where_clause}
        ),
        natural_features AS (
            SELECT
                id,
                names."primary" AS name,
                subtype,
                ST_SetCRS(geometry, 'OGC:CRS84') AS geometry
            FROM read_parquet(?)
            WHERE subtype IN (
                'sea', 'ocean', 'Lake', 'River', 'Basin', 'gulf', 'bay',
                'Terrain area', 'Island group', 'Peninsula', 'Strait',
                'Reef', 'Range/Mts', 'Depression'
            )
            LIMIT 500
        )
        SELECT
            d.id AS division_id,
            d.name AS division_name,
            d.subtype AS division_subtype,
            d.country AS division_country,
            n.id AS natural_id,
            n.name AS natural_name,
            n.subtype AS natural_subtype,
            CASE
                WHEN ST_Touches(d.geometry, n.geometry) THEN 'touches'
                WHEN ST_Within(d.geometry, n.geometry) THEN 'within'
                WHEN ST_Contains(d.geometry, n.geometry) THEN 'contains'
                WHEN ST_Intersects(d.geometry, n.geometry) THEN 'intersects'
            END AS relation_type
        FROM divisions AS d
        JOIN natural_features AS n ON ST_Intersects(d.geometry, n.geometry)
        LIMIT ?
    """

    # Placeholder order: divisions file, optional country list,
    # natural_earth file, outer LIMIT.
    bound = [DIVISIONS_AREA_PATH, *where_params, NATURAL_EARTH_PATH, limit]
    relations = con.execute(cross_source_sql, bound).fetchdf()
    print(f"Found {len(relations)} cross-source relations")
    return relations
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def compute_coastal_containment_pairs(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int,
) -> pd.DataFrame:
    """Containment pairs where the container is in a coastal country.

    Used by chained_01 (coastal towns of X) to ensure sampled containment
    anchors actually have sea-adjacent sub-features, keeping the SQL
    verification step from constantly returning empty results.

    Strategy: find countries whose geometry intersects any ocean/sea in
    natural_earth, then keep only containment pairs whose container
    belongs to one of those countries.

    Args:
        con: DuckDB connection with the spatial extension loaded.
        countries: Country codes, or ``["all"]`` for no country filter.
        limit: Maximum number of rows returned (SQL ``LIMIT``).

    Returns:
        DataFrame of container/contained pairs tagged
        ``'coastal_containment'``.
    """
    print("\nComputing coastal containment pairs...")

    cfilter, cparams = _country_filter(countries)

    query = f"""
        WITH coastal_countries AS (
            SELECT DISTINCT d.country
            FROM read_parquet(?) AS d
            JOIN read_parquet(?) AS n
                ON ST_Intersects(d.geometry, ST_SetCRS(n.geometry, 'OGC:CRS84'))
            WHERE d.subtype = 'country'
                AND n.subtype IN ('sea', 'ocean')
        ),
        features AS (
            SELECT
                id,
                names."primary" AS name,
                subtype,
                country,
                admin_level,
                geometry,
                ST_Envelope(geometry) AS bbox
            FROM read_parquet(?)
            {cfilter}
        )
        SELECT
            a.id AS container_id,
            a.name AS container_name,
            a.subtype AS container_subtype,
            b.id AS contained_id,
            b.name AS contained_name,
            b.subtype AS contained_subtype,
            a.country AS container_country,
            'coastal_containment' AS relation_type
        FROM features AS a
        JOIN features AS b ON (
            a.id != b.id
            AND a.admin_level < b.admin_level
            AND ST_Intersects(a.bbox, b.bbox)
            AND ST_Within(b.geometry, a.geometry)
        )
        WHERE a.country IN (SELECT country FROM coastal_countries)
        LIMIT ?
    """

    # "?" placeholders bind in textual order: the two coastal_countries
    # sources, the features source, the optional country list (inside
    # {cfilter}), then LIMIT. The country list must therefore come AFTER
    # the third path; the previous order swapped the two, so any explicit
    # country list failed.
    params = [
        DIVISIONS_AREA_PATH,
        NATURAL_EARTH_PATH,
        DIVISIONS_AREA_PATH,
        *cparams,
        limit,
    ]
    df = con.execute(query, params).fetchdf()
    print(f"Found {len(df)} coastal containment pairs")
    return df
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def compute_landlocked_containment_pairs(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int,
) -> pd.DataFrame:
    """Containment pairs where the container is in a landlocked country.

    Used by chained_02 (landlocked regions within X) to ensure sampled
    anchors genuinely have no sea access, keeping SQL verification from
    always returning empty.

    A country counts as landlocked here when it is NOT among the countries
    whose geometry intersects a natural_earth sea or ocean.

    Args:
        con: DuckDB connection with the spatial extension loaded.
        countries: Country codes, or ``["all"]`` for no country filter.
        limit: Maximum number of rows returned (SQL ``LIMIT``).

    Returns:
        DataFrame of container/contained pairs tagged
        ``'landlocked_containment'``.
    """
    print("\nComputing landlocked containment pairs...")

    cfilter, cparams = _country_filter(countries)

    query = f"""
        WITH coastal_countries AS (
            SELECT DISTINCT d.country
            FROM read_parquet(?) AS d
            JOIN read_parquet(?) AS n
                ON ST_Intersects(d.geometry, ST_SetCRS(n.geometry, 'OGC:CRS84'))
            WHERE d.subtype = 'country'
                AND n.subtype IN ('sea', 'ocean')
        ),
        features AS (
            SELECT
                id,
                names."primary" AS name,
                subtype,
                country,
                admin_level,
                geometry,
                ST_Envelope(geometry) AS bbox
            FROM read_parquet(?)
            {cfilter}
        )
        SELECT
            a.id AS container_id,
            a.name AS container_name,
            a.subtype AS container_subtype,
            b.id AS contained_id,
            b.name AS contained_name,
            b.subtype AS contained_subtype,
            a.country AS container_country,
            'landlocked_containment' AS relation_type
        FROM features AS a
        JOIN features AS b ON (
            a.id != b.id
            AND a.admin_level < b.admin_level
            AND ST_Intersects(a.bbox, b.bbox)
            AND ST_Within(b.geometry, a.geometry)
        )
        WHERE a.country NOT IN (SELECT country FROM coastal_countries)
        LIMIT ?
    """

    # "?" placeholders bind in textual order: the two coastal_countries
    # sources, the features source, the optional country list (inside
    # {cfilter}), then LIMIT. The country list must therefore come AFTER
    # the third path; the previous order swapped the two, so any explicit
    # country list failed.
    params = [
        DIVISIONS_AREA_PATH,
        NATURAL_EARTH_PATH,
        DIVISIONS_AREA_PATH,
        *cparams,
        limit,
    ]
    df = con.execute(query, params).fetchdf()
    print(f"Found {len(df)} landlocked containment pairs")
    return df
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def compute_common_neighbor_pairs(
    con: duckdb.DuckDBPyConnection,
    countries: list,
    limit: int,
) -> pd.DataFrame:
    """Pairs of anchors that share at least one common touching neighbour.

    Used by multi_adj_01 (borders both X and Y) so the generated SQL is
    guaranteed to return at least one result instead of failing on random
    pairs without a common neighbour.

    The pairs are derived by self-joining ``adjacency_pairs.parquet`` on
    ``target_id``. That file is always looked up in
    ``<directory of this file>/../intermediate``; when it is missing an
    empty frame with the expected columns is returned.

    Args:
        con: DuckDB connection used to run the self-join.
        countries: Not used by this function; present so the signature
            matches the other ``compute_*`` functions.
        limit: Maximum number of rows returned (SQL ``LIMIT``).
    """
    print("\nComputing common-neighbor pairs...")

    intermediate_dir = Path(__file__).parent.parent / "intermediate"
    adjacency_file = intermediate_dir / "adjacency_pairs.parquet"

    if not adjacency_file.exists():
        print("  adjacency_pairs.parquet not found — skipping (run adjacency first)")
        empty_columns = [
            "anchor_id_1", "anchor_name_1", "anchor_id_2", "anchor_name_2",
            "shared_neighbor_id", "shared_neighbor_name",
        ]
        return pd.DataFrame(columns=empty_columns)

    common_neighbor_sql = """
        SELECT DISTINCT
            a1.anchor_id AS anchor_id_1,
            a1.anchor_name AS anchor_name_1,
            a2.anchor_id AS anchor_id_2,
            a2.anchor_name AS anchor_name_2,
            a1.target_id AS shared_neighbor_id,
            a1.target_name AS shared_neighbor_name
        FROM read_parquet(?) AS a1
        JOIN read_parquet(?) AS a2
            ON a1.target_id = a2.target_id
            AND a1.anchor_id < a2.anchor_id
        LIMIT ?
    """

    # The same file is bound twice: once for each side of the self-join.
    source = str(adjacency_file)
    pairs = con.execute(common_neighbor_sql, [source, source, limit]).fetchdf()
    print(f"Found {len(pairs)} common-neighbor pairs")
    return pairs
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def _make_connection():
    """Open a fresh DuckDB connection configured for the relation queries.

    Installs and loads the spatial extension, then applies fixed session
    settings: a 24GB memory limit, ``/tmp/duckdb_tmp`` as the temp
    directory and 4 threads.
    """
    setup_statements = (
        "INSTALL spatial",
        "LOAD spatial",
        "SET memory_limit='24GB'",
        "SET temp_directory='/tmp/duckdb_tmp'",
        "SET threads=4",
    )
    con = duckdb.connect()
    for statement in setup_statements:
        con.execute(statement)
    return con
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def _compute_and_save(compute_fn, countries, limit, output_path):
    """Run one relation computation and write its result to parquet.

    A dedicated connection from ``_make_connection`` is opened for the call
    and closed afterwards, whether or not the computation succeeds.

    Args:
        compute_fn: Callable invoked as ``compute_fn(con, countries, limit)``
            that returns a DataFrame.
        countries: Passed through to ``compute_fn``.
        limit: Passed through to ``compute_fn``.
        output_path: Destination parquet path.

    Returns:
        The DataFrame produced by ``compute_fn``.
    """
    con = _make_connection()
    try:
        result = compute_fn(con, countries, limit)
        result.to_parquet(output_path, index=False)
        print(f"Saved to {output_path}")
    finally:
        con.close()
    return result
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
# Relation type name -> function that computes it. Every function is called
# as ``fn(con, countries, limit)`` and returns a DataFrame.
RELATION_FUNCTIONS = dict(
    adjacency=compute_adjacency_pairs,
    containment=compute_containment_pairs,
    intersection=compute_intersection_pairs,
    cross_source=compute_cross_source_relations,
    coastal_containment=compute_coastal_containment_pairs,
    landlocked_containment=compute_landlocked_containment_pairs,
    common_neighbor=compute_common_neighbor_pairs,
)
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def compute_single_relation(
    relation_type: str,
    countries: list,
    limit: int,
    output_dir: Path,
) -> int:
    """Compute one relation type and save it to output_dir.

    Usable from Modal or locally.

    Args:
        relation_type: A key of ``RELATION_FUNCTIONS``.
        countries: Country codes, or ``["all"]`` for no country filter.
        limit: Maximum number of rows to compute.
        output_dir: Directory that receives the parquet file (created if
            missing, including parents).

    Returns:
        The number of rows computed.

    Raises:
        ValueError: If ``relation_type`` is not a known relation type.
    """
    compute_fn = RELATION_FUNCTIONS.get(relation_type)
    if compute_fn is None:
        raise ValueError(
            f"Unknown relation type: {relation_type}. "
            f"Expected one of {list(RELATION_FUNCTIONS)}"
        )

    # main() and the module docstring name the cross-source table
    # "cross_source_relations.parquet"; every other table follows the
    # "<type>_pairs.parquet" pattern. Use the same names here so both entry
    # points produce the same set of files.
    filename_overrides = {"cross_source": "cross_source_relations.parquet"}
    filename = filename_overrides.get(relation_type, f"{relation_type}_pairs.parquet")

    output_dir.mkdir(exist_ok=True, parents=True)
    df = _compute_and_save(compute_fn, countries, limit, output_dir / filename)
    return len(df)
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
def main(countries: list = None, relation_limits: dict = None):
    """Compute and save all relation tables in parallel.

    All tables are written to ``<directory of this file>/../intermediate``.
    The six independent relations run concurrently, each on its own
    connection; ``common_neighbor`` runs afterwards because it reads
    ``adjacency_pairs.parquet``.

    Args:
        countries: List of country codes to process. Defaults to
            ``['EC', 'BE', 'KE', 'AE', 'SG', 'CH']``.
        relation_limits: Dict mapping relation type to row limit. Known
            keys: adjacency, containment, intersection, cross_source,
            coastal_containment, landlocked_containment, common_neighbor.
            Missing keys fall back to the built-in defaults.

    Raises:
        Exception: Any error raised by a relation computation is printed
            and re-raised.
    """
    default_limits = {
        'adjacency': 50000,
        'containment': 1000,
        'intersection': 500,
        'cross_source': 500,
        'coastal_containment': 1000,
        'landlocked_containment': 500,
        'common_neighbor': 5000,
    }

    if countries is None:
        countries = ['EC', 'BE', 'KE', 'AE', 'SG', 'CH']
    # Merge over the defaults so a partial dict no longer raises KeyError
    # for the relation types it leaves out.
    limits = {**default_limits, **(relation_limits or {})}

    output_dir = Path(__file__).parent.parent / "intermediate"
    output_dir.mkdir(exist_ok=True, parents=True)

    # (name, compute function, row limit, output path) for every relation.
    tasks = [
        ("adjacency", compute_adjacency_pairs, limits['adjacency'], output_dir / "adjacency_pairs.parquet"),
        ("containment", compute_containment_pairs, limits['containment'], output_dir / "containment_pairs.parquet"),
        ("intersection", compute_intersection_pairs, limits['intersection'], output_dir / "intersection_pairs.parquet"),
        ("cross_source", compute_cross_source_relations, limits['cross_source'], output_dir / "cross_source_relations.parquet"),
        ("coastal_containment", compute_coastal_containment_pairs, limits['coastal_containment'], output_dir / "coastal_containment_pairs.parquet"),
        ("landlocked_containment", compute_landlocked_containment_pairs, limits['landlocked_containment'], output_dir / "landlocked_containment_pairs.parquet"),
        ("common_neighbor", compute_common_neighbor_pairs, limits['common_neighbor'], output_dir / "common_neighbor_pairs.parquet"),
    ]

    # common_neighbor reads adjacency_pairs.parquet so it must run after
    # adjacency finishes. Split into two waves.
    independent_tasks = [t for t in tasks if t[0] != "common_neighbor"]
    dependent_tasks = [t for t in tasks if t[0] == "common_neighbor"]

    print(f"Computing {len(independent_tasks)} relation types in parallel...")
    with ThreadPoolExecutor(max_workers=len(independent_tasks)) as executor:
        futures = {
            executor.submit(_compute_and_save, compute_fn, countries, limit, path): name
            for name, compute_fn, limit, path in independent_tasks
        }
        for future in as_completed(futures):
            name = futures[future]
            try:
                future.result()
            except Exception as e:
                print(f"ERROR computing {name}: {e}")
                raise

    for name, compute_fn, limit, path in dependent_tasks:
        print(f"\nComputing {name} (depends on adjacency)...")
        try:
            _compute_and_save(compute_fn, countries, limit, path)
        except Exception as e:
            print(f"ERROR computing {name}: {e}")
            raise

    print("\nRelation tables build complete")


if __name__ == "__main__":
    main()
|
dataset/scripts/cli.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
CLI for synthetic dataset generation.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python cli.py build-relations --config ../config.yaml
|
| 7 |
+
python cli.py generate-samples --config ../config.yaml
|
| 8 |
+
python cli.py generate-samples --config ../config.yaml --append
|
| 9 |
+
python cli.py full-pipeline --config ../config.yaml
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import subprocess
|
| 14 |
+
import sys
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Dict
|
| 17 |
+
|
| 18 |
+
import pandas as pd
|
| 19 |
+
import yaml
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def load_config(config_path: Path) -> dict:
    """Load configuration from a YAML file.

    The file is decoded as UTF-8 explicitly so parsing does not depend on
    the platform's default encoding.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document.
    """
    with open(config_path, encoding="utf-8") as f:
        return yaml.safe_load(f)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def should_rebuild_relations(config: dict, intermediate_dir: Path, append: bool) -> bool:
    """Decide whether the relation tables have to be rebuilt.

    Returns True when any of the following holds:
    - ``append`` is False (a rebuild always happens outside append mode)
    - ``adjacency_pairs.parquet`` is missing from ``intermediate_dir``
    - the file has no ``anchor_country`` column
    - the set of ``anchor_country`` values differs from ``config['countries']``
    - reading or comparing the existing table raises any exception

    Returns False only when the existing countries match the configured
    ones exactly.

    NOTE(review): the comparison is a plain set equality, so a config of
    ``['all']`` or a configured country without any adjacency row looks
    like a change — confirm that is intended.
    """
    if not append:
        return True

    adjacency_file = intermediate_dir / "adjacency_pairs.parquet"
    if not adjacency_file.exists():
        print("WARNING: Relation tables not found, will rebuild despite append mode")
        return True

    try:
        existing_table = pd.read_parquet(adjacency_file)

        if 'anchor_country' not in existing_table.columns:
            # Without the column there is nothing to compare against.
            print("WARNING: Cannot determine countries from existing tables, will rebuild")
            return True

        previous = set(existing_table['anchor_country'].unique())
        requested = set(config['countries'])

        if previous == requested:
            print(f"Countries unchanged: {sorted(requested)}")
            return False

        print(f"WARNING: Countries changed:")
        print(f"   Previous: {sorted(previous)}")
        print(f"   New: {sorted(requested)}")
        print(f"   Will rebuild relation tables to include new countries")
        return True
    except Exception as e:
        # Any failure while inspecting the old tables falls back to a rebuild.
        print(f"WARNING: Error reading existing relation tables: {e}")
        print("   Will rebuild to be safe")
        return True
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def calculate_relation_limits(config: dict) -> Dict[str, int]:
    """Derive per-relation row limits from the configured sample targets.

    Each task family in ``config['sample_targets']`` contributes
    ``int(target * retry_multiplier * safety_factor)`` to every relation
    table it draws anchors from; contributions to the same relation add up.
    Families not listed in the internal mapping contribute nothing.

    Args:
        config: Parsed configuration. Requires ``sample_targets`` and
            ``generation.retry_multiplier``. ``auto_scaling.safety_factor``
            (default 1.5) and ``auto_scaling.manual_limits`` are optional.

    Returns:
        Mapping of relation type to row limit, with manual overrides applied
        last.

    Raises:
        KeyError: If ``sample_targets`` or ``generation.retry_multiplier``
            is missing.
    """
    sample_targets = config['sample_targets']
    retry_mult = config['generation']['retry_multiplier']

    # A YAML key written without a value (e.g. a bare "auto_scaling:") parses
    # as None, so normalise to an empty mapping before reading from it.
    auto_scaling = config.get('auto_scaling') or {}
    safety = auto_scaling.get('safety_factor', 1.5)

    # Map each task family to the relation tables it draws anchors from.
    # A family can need multiple relation types.
    family_to_relations = {
        'direct_lookup': [],
        'adjacency': ['adjacency'],
        'multi_adjacency': ['adjacency', 'common_neighbor'],
        'containment': ['containment'],
        'intersection': ['intersection', 'cross_source'],
        'buffer': ['adjacency'],
        'chained': ['coastal_containment', 'landlocked_containment', 'containment'],
        'difference': ['containment', 'cross_source'],
        'border_corridor': ['adjacency'],
        'set_operations': ['containment', 'cross_source'],
        'partial_selection': ['containment', 'cross_source'],
        'aggregation': ['containment'],
        'window_function': [],
        'attribute_filter': [],
    }

    relation_needs: Dict[str, int] = {}
    for family, target in sample_targets.items():
        # The per-family requirement does not depend on the relation type.
        needed = int(target * retry_mult * safety)
        for rel_type in family_to_relations.get(family, []):
            relation_needs[rel_type] = relation_needs.get(rel_type, 0) + needed

    # common_neighbor is derived from adjacency — keep its limit proportional
    if 'common_neighbor' not in relation_needs and 'adjacency' in relation_needs:
        relation_needs['common_neighbor'] = relation_needs['adjacency'] * 3

    # Apply manual overrides if specified (an empty key parses as None).
    manual = auto_scaling.get('manual_limits') or {}
    relation_needs.update(manual)

    return relation_needs
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def build_relations(config_path: Path):
    """Build the relation tables for the configured countries (step 1).

    Loads the config, derives the relation limits from it, prints a summary
    and hands both the country list and the limits to the relation builder
    module's ``main``.
    """
    config = load_config(config_path)

    # Limits are derived from the sample targets in the config.
    relation_limits = calculate_relation_limits(config)

    banner = "=" * 60
    print(banner)
    print("STEP 1: Building Relation Tables")
    print(banner)
    countries = config['countries']
    print(f"Countries: {', '.join(countries)}")
    print("\nAuto-calculated relation limits:")
    for rel_type, limit in relation_limits.items():
        print(f" {rel_type:20s}: {limit:,}")
    print()

    # Imported lazily and aliased so the module does not shadow this
    # function's own name inside its body.
    from dataset.scripts import build_relations as relations_builder

    relations_builder.main(
        countries=countries,
        relation_limits=relation_limits,
    )

    print("\nRelation tables built successfully")
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def generate_samples(config_path: Path, append: bool = False):
    """Generate samples using the settings from the config file (step 2).

    The generator module is configured by overwriting its module-level
    settings (``TARGET_COUNTS``, ``MAX_WORKERS``, ``RETRY_MULTIPLIER``,
    ``APPEND_MODE``) before its ``main`` is called. Append mode is active when
    either the ``append`` argument or ``generation.append_mode`` is truthy.
    """
    config = load_config(config_path)

    rule = "=" * 60
    print(rule)
    print("STEP 2: Generating Samples")
    print(rule)
    targets = config['sample_targets']
    print(f"Targets: {targets}")
    generation = config['generation']
    workers = generation['max_workers']
    print(f"Workers: {workers}")
    # The CLI flag wins; otherwise fall back to the configured default.
    append_mode = append or generation['append_mode']
    print(f"Append mode: {append_mode}")
    print()

    # Lazy import of the generator module.
    from dataset.scripts import generate_samples as gs_module

    # Push the config values into the generator's module-level settings.
    gs_module.TARGET_COUNTS = targets
    gs_module.MAX_WORKERS = workers
    gs_module.RETRY_MULTIPLIER = generation['retry_multiplier']
    gs_module.APPEND_MODE = append_mode

    gs_module.main()

    print("\nSamples generated successfully")
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def validate_dataset(config_path: Path):
    """Run the dataset validation script in a child interpreter (step 3).

    ``config_path`` is accepted for a uniform command signature; this
    function does not read it. A non-zero exit status of the script raises
    ``subprocess.CalledProcessError`` because ``check=True`` is used.
    """
    divider = "=" * 60
    print(divider)
    print("STEP 3: Validating Dataset")
    print(divider)

    # The validator lives next to this file.
    validator = Path(__file__).parent / "validate_dataset.py"
    subprocess.run([sys.executable, str(validator)], check=True)

    print("\nDataset validated successfully")
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def export_dataset(config_path: Path):
    """Export the dataset for the SQL generation and place extraction tasks (step 4).

    Delegates to ``dataset.scripts.export_training_data.main``, passing the
    config path through.
    """
    separator = "=" * 60
    print(separator)
    print("STEP 4: Exporting Dataset")
    print(separator)

    # Lazy import of the exporter module.
    from dataset.scripts import export_training_data

    export_training_data.main(config_path=config_path)

    print("\nDataset exported successfully")
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def modal_upload(config_path: Path):
    """Upload local data to the Modal volume.

    Runs ``modal run dataset/modal_app.py::upload_data`` through the current
    interpreter; the app path is relative to the working directory.
    ``config_path`` is not used here. A failing command raises
    ``subprocess.CalledProcessError``.
    """
    command = [
        sys.executable,
        "-m",
        "modal",
        "run",
        "dataset/modal_app.py::upload_data",
    ]
    subprocess.run(command, check=True)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def modal_generate(config_path: Path, num_containers: int = 0,
                   skip_inventory: bool = False, skip_relations: bool = False,
                   fresh: bool = False):
    """Run distributed generation on Modal, then validate and export locally.

    Args:
        config_path: Config file forwarded as ``--config-path``.
        num_containers: Forwarded as ``--num-containers`` only when positive.
        skip_inventory: Adds ``--skip-inventory`` when true.
        skip_relations: Adds ``--skip-relations`` when true.
        fresh: Adds ``--fresh`` when true.

    Raises:
        subprocess.CalledProcessError: If the Modal command fails.
    """
    command = [
        sys.executable, "-m", "modal", "run",
        "dataset/modal_app.py::run_pipeline",
        "--config-path", str(config_path),
    ]
    if num_containers > 0:
        command += ["--num-containers", str(num_containers)]

    # Boolean switches are appended in a fixed order when enabled.
    switches = (
        (skip_inventory, "--skip-inventory"),
        (skip_relations, "--skip-relations"),
        (fresh, "--fresh"),
    )
    command += [flag for enabled, flag in switches if enabled]

    subprocess.run(command, check=True)
    validate_dataset(config_path)
    export_dataset(config_path)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def full_pipeline(config_path: Path, append: bool = False):
    """Run every pipeline stage in order.

    Builds the entity inventory when either inventory file is missing,
    rebuilds the relation tables when ``should_rebuild_relations`` says so,
    then generates, validates and exports the dataset.
    """
    print("Running full dataset generation pipeline")

    config = load_config(config_path)

    # Intermediate artefacts live in the "intermediate" directory next to
    # this file's parent directory.
    intermediate_dir = Path(__file__).parent.parent / "intermediate"
    inventory_names = (
        "divisions_area_inventory.parquet",
        "natural_earth_inventory.parquet",
    )
    inventory_ready = all(
        (intermediate_dir / name).exists() for name in inventory_names
    )

    if not inventory_ready:
        print("=" * 60)
        print("STEP 0: Building Entity Inventory")
        print("=" * 60)
        print("Inventory files not found, building...")
        from dataset.scripts import build_inventory
        build_inventory.main()

    # Relations are reused only in append mode with an unchanged country set.
    if should_rebuild_relations(config, intermediate_dir, append):
        build_relations(config_path)
    else:
        print("Using existing relation tables (append mode, same countries)")

    generate_samples(config_path, append=append)
    validate_dataset(config_path)
    export_dataset(config_path)

    print("\nPipeline complete")
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def main():
    """Parse the command line and run the selected pipeline command.

    Exits with status 1 when the config file does not exist or when the
    selected command raises an exception (the message is printed first).
    """
    parser = argparse.ArgumentParser(
        description="Synthetic dataset generation CLI",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Build relation tables only
python cli.py build-relations --config ../config.yaml

# Generate samples only
python cli.py generate-samples --config ../config.yaml

# Generate and append to existing dataset
python cli.py generate-samples --config ../config.yaml --append

# Run full pipeline
python cli.py full-pipeline --config ../config.yaml

# Run full pipeline in append mode (skip relation building)
python cli.py full-pipeline --config ../config.yaml --append

# Upload data to Modal volume (one-time)
python cli.py modal-upload --config ../config.yaml

# Run distributed generation on Modal
python cli.py modal-generate --config ../config.yaml
python cli.py modal-generate --config ../config.yaml --num-containers 100
python cli.py modal-generate --config ../config.yaml --skip-inventory --skip-relations
"""
    )

    parser.add_argument(
        'command',
        choices=['build-relations', 'generate-samples', 'validate', 'export',
                 'full-pipeline', 'modal-upload', 'modal-generate'],
        help='Command to run',
    )
    parser.add_argument(
        '--config', type=Path, required=True,
        help='Path to config YAML file',
    )
    parser.add_argument(
        '--append', action='store_true',
        help='Append to existing dataset instead of overwriting',
    )
    parser.add_argument(
        '--num-containers', type=int, default=0,
        help='Number of Modal containers (0 = use config default)',
    )
    parser.add_argument(
        '--skip-inventory', action='store_true',
        help='Skip inventory building on Modal',
    )
    parser.add_argument(
        '--skip-relations', action='store_true',
        help='Skip relation building on Modal',
    )
    parser.add_argument(
        '--fresh', action='store_true',
        help='Overwrite existing dataset instead of appending (Modal only)',
    )

    args = parser.parse_args()

    # Validate config file exists
    if not args.config.exists():
        print(f"Error: Config file not found: {args.config}")
        sys.exit(1)

    # One entry per allowed command; argparse has already restricted
    # args.command to these keys.
    handlers = {
        'build-relations': lambda: build_relations(args.config),
        'generate-samples': lambda: generate_samples(args.config, args.append),
        'validate': lambda: validate_dataset(args.config),
        'export': lambda: export_dataset(args.config),
        'full-pipeline': lambda: full_pipeline(args.config, args.append),
        'modal-upload': lambda: modal_upload(args.config),
        'modal-generate': lambda: modal_generate(
            args.config,
            num_containers=args.num_containers,
            skip_inventory=args.skip_inventory,
            skip_relations=args.skip_relations,
            fresh=args.fresh,
        ),
    }

    try:
        handlers[args.command]()
    except Exception as e:
        print(f"\nError: {e}")
        sys.exit(1)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
if __name__ == "__main__":
|
| 377 |
+
main()
|
dataset/scripts/export_training_data.py
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Export validated dataset to train/val/test splits.
|
| 3 |
+
|
| 4 |
+
Produces two task datasets from the same source samples:
|
| 5 |
+
|
| 6 |
+
1. SQL generation (prompt = question + candidates CSV, completion = SQL)
|
| 7 |
+
2. Place extraction (prompt = question only, completion = PlacesResult JSON)
|
| 8 |
+
|
| 9 |
+
Place extraction pairs are derived automatically: for each SQL sample the
|
| 10 |
+
selected_candidates give us the correct place names, subtypes, and country
|
| 11 |
+
codes that the extractor should return.
|
| 12 |
+
|
| 13 |
+
Output layout (all paths relative to dataset/):
|
| 14 |
+
output/runs/{run_name}/sql/train.jsonl
|
| 15 |
+
output/runs/{run_name}/sql/val.jsonl
|
| 16 |
+
output/runs/{run_name}/sql/test.jsonl
|
| 17 |
+
output/runs/{run_name}/places/train.jsonl
|
| 18 |
+
output/runs/{run_name}/places/val.jsonl
|
| 19 |
+
output/runs/{run_name}/places/test.jsonl
|
| 20 |
+
output/runs/{run_name}/stats.json
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import json
|
| 24 |
+
import random
|
| 25 |
+
import sys
|
| 26 |
+
from collections import defaultdict
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 29 |
+
|
| 30 |
+
import sqlparse
|
| 31 |
+
import yaml
|
| 32 |
+
|
| 33 |
+
from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
# Loading
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
|
| 40 |
+
def load_samples(jsonl_path: Path) -> List[Dict[str, Any]]:
    """Read a JSON Lines file into a list of sample dicts.

    Blank and whitespace-only lines are skipped. The file is decoded as
    UTF-8 regardless of the platform locale, so samples containing non-ASCII
    text load the same way everywhere.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(jsonl_path, encoding="utf-8") as f:
        return [json.loads(text) for text in map(str.strip, f) if text]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def load_run_name(config_path: Optional[Path]) -> str:
    """Return the ``run_name`` from the config file, or ``"default"``.

    ``"default"`` is returned when no path is given, the file does not
    exist, the file is empty, or ``run_name`` is missing or empty. Any other
    value is converted to ``str`` so it can be used as a path component
    (YAML parses an unquoted ``2024`` as an int).
    """
    if not config_path or not config_path.exists():
        return "default"

    with open(config_path, encoding="utf-8") as f:
        # safe_load returns None for an empty document.
        cfg = yaml.safe_load(f) or {}

    run_name = cfg.get("run_name")
    if run_name is None or run_name == "":
        return "default"
    return str(run_name)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ---------------------------------------------------------------------------
|
| 59 |
+
# Splitting
|
| 60 |
+
# ---------------------------------------------------------------------------
|
| 61 |
+
|
| 62 |
+
def stratified_split(
    samples: List[Dict[str, Any]],
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Split samples into train/val/test, stratified by task family.

    Each family (``sample["metadata"]["task_family"]``) is shuffled and cut
    independently: the first ``int(n * train_ratio)`` samples go to train,
    the next ``int(n * val_ratio)`` to val and the remainder to test. Because
    the sizes are truncated, a very small family can be missing from train
    or val (a single-sample family lands entirely in test).

    A private ``random.Random(seed)`` is used, so the result is reproducible
    without reseeding the global ``random`` module state.

    Args:
        samples: Raw samples; the input list itself is not reordered.
        train_ratio: Fraction of each family assigned to train.
        val_ratio: Fraction of each family assigned to val.
        seed: Seed for the shuffles.

    Returns:
        ``(train, val, test)`` lists, each shuffled.

    Raises:
        KeyError: If a sample lacks ``metadata`` or ``task_family``.
    """
    rng = random.Random(seed)

    by_family: Dict[str, List] = defaultdict(list)
    for s in samples:
        by_family[s["metadata"]["task_family"]].append(s)

    train, val, test = [], [], []
    for family_samples in by_family.values():
        rng.shuffle(family_samples)
        n = len(family_samples)
        n_train = int(n * train_ratio)
        n_val = int(n * val_ratio)
        train.extend(family_samples[:n_train])
        val.extend(family_samples[n_train : n_train + n_val])
        test.extend(family_samples[n_train + n_val :])

    # Mix the families within each split.
    rng.shuffle(train)
    rng.shuffle(val)
    rng.shuffle(test)
    return train, val, test
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# ---------------------------------------------------------------------------
|
| 91 |
+
# SQL generation format
|
| 92 |
+
# Conversational prompt-completion: model sees system + user, generates SQL.
|
| 93 |
+
# ---------------------------------------------------------------------------
|
| 94 |
+
|
| 95 |
+
_SQL_SYSTEM = """You are a text to SQL query translator that helps in natural language geocoding.
|
| 96 |
+
|
| 97 |
+
You have access to two DuckDB parquet tables. Given a set of candidate entities and a user query, generate the SQL to retrieve the desired geometry.
|
| 98 |
+
|
| 99 |
+
<SCHEMA>
|
| 100 |
+
1. divisions_area -- Overture polygon/multipolygon admin boundaries
|
| 101 |
+
query: read_parquet('divisions_area')
|
| 102 |
+
columns:
|
| 103 |
+
id VARCHAR -- unique feature id
|
| 104 |
+
names STRUCT("primary" VARCHAR, ...)
|
| 105 |
+
country VARCHAR -- ISO 3166-1 alpha-2
|
| 106 |
+
subtype VARCHAR -- country | region | dependency | county | localadmin |
|
| 107 |
+
locality | macrohood | neighborhood | microhood
|
| 108 |
+
class VARCHAR
|
| 109 |
+
region VARCHAR
|
| 110 |
+
admin_level INTEGER
|
| 111 |
+
division_id VARCHAR
|
| 112 |
+
is_land BOOLEAN
|
| 113 |
+
is_territorial BOOLEAN
|
| 114 |
+
geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
|
| 115 |
+
|
| 116 |
+
2. natural_earth -- Natural Earth geography polygons (oceans, seas, rivers, terrain)
|
| 117 |
+
query: read_parquet('natural_earth')
|
| 118 |
+
columns:
|
| 119 |
+
id VARCHAR -- unique feature id prefixed 'ne_'
|
| 120 |
+
names STRUCT("primary" VARCHAR, ...)
|
| 121 |
+
country VARCHAR
|
| 122 |
+
subtype VARCHAR -- e.g. 'ocean', 'sea', 'bay', 'Terrain area', 'Island group'
|
| 123 |
+
class VARCHAR
|
| 124 |
+
region VARCHAR
|
| 125 |
+
admin_level INTEGER
|
| 126 |
+
is_land BOOLEAN
|
| 127 |
+
is_territorial BOOLEAN
|
| 128 |
+
geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
|
| 129 |
+
</SCHEMA>
|
| 130 |
+
|
| 131 |
+
The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
|
| 132 |
+
Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly.
|
| 133 |
+
Use ST_AsGeoJSON(geometry) for all geometry outputs."""
|
| 134 |
+
|
| 135 |
+
_CANDIDATES_COLS = [
|
| 136 |
+
"source", "id", "name", "subtype", "country", "region",
|
| 137 |
+
"admin_level", "similarity",
|
| 138 |
+
]
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _candidates_csv(candidates: List[Dict]) -> str:
    """Render candidates as CSV text limited to the known candidate columns.

    The header contains every column from ``_CANDIDATES_COLS`` that occurs in
    at least one candidate, in ``_CANDIDATES_COLS`` order. A candidate that
    lacks one of those columns gets an empty cell. Taking the header from
    all candidates rather than only the first avoids the ``ValueError`` that
    ``csv.DictWriter`` raises when a later row carries a field the first row
    did not have.

    Returns:
        The CSV text with surrounding whitespace stripped, or ``""`` when
        there are no candidates.
    """
    import csv
    import io

    if not candidates:
        return ""

    present = {key for c in candidates for key in c}
    fieldnames = [col for col in _CANDIDATES_COLS if col in present]
    rows = [{col: c[col] for col in fieldnames if col in c} for c in candidates]

    buf = io.StringIO()
    writer = csv.DictWriter(buf, fieldnames=fieldnames, restval="")
    writer.writeheader()
    writer.writerows(rows)
    return buf.getvalue().strip()
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _to_symbolic_sql(sql: str) -> str:
    """Rewrite concrete parquet paths in ``sql`` to their symbolic table names.

    The configured runtime paths are replaced first, followed by a few
    hard-coded ``/data/...`` paths.
    """
    # Order is significant: replacements are applied top to bottom.
    substitutions = (
        (DIVISIONS_AREA_PATH, "divisions_area"),
        (NATURAL_EARTH_PATH, "natural_earth"),
        ("/data/overture/division_area/*.parquet", "divisions_area"),
        ("/data/overture/divisions_area/*.parquet", "divisions_area"),
        ("/data/natural_earth_geoparquet/ne_geography.parquet", "natural_earth"),
    )
    for concrete, symbolic in substitutions:
        sql = sql.replace(concrete, symbolic)
    return sql
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def _format_sql(sql: str) -> str:
    """Return ``sql`` reformatted by sqlparse, for consistent training targets.

    Formatting options: re-indentation on, upper-case keywords, indent width
    of 4. Surrounding whitespace is stripped from the result.
    """
    options = {
        "reindent": True,
        "keyword_case": "upper",
        "indent_width": 4,
    }
    formatted = sqlparse.format(sql, **options)
    return formatted.strip()
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def sample_to_sql_pair(sample: Dict[str, Any]) -> Optional[Dict]:
    """Build the chat-style training record for SQL generation from one sample.

    Returns ``None`` when the sample has no (non-blank) target SQL. Otherwise
    the record has a ``messages`` list (system prompt, user turn with the
    candidates CSV and the question, assistant turn with the normalised SQL)
    and the sample's ``metadata``.
    """
    raw_sql = sample.get("target", {}).get("sql", "").strip()
    if not raw_sql:
        return None

    # Symbolic table names first, then pretty-printing.
    completion = _format_sql(_to_symbolic_sql(raw_sql))

    candidates_block = _candidates_csv(sample.get('candidates', []))
    question = sample['question']
    user_content = (
        f"<CANDIDATES>\n{candidates_block}\n</CANDIDATES>\n\n"
        f"<USER_QUERY>\n{question}\n</USER_QUERY>"
    )

    messages = [
        {"role": "system", "content": _SQL_SYSTEM},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": completion},
    ]
    return {"messages": messages, "metadata": sample.get("metadata", {})}
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
# ---------------------------------------------------------------------------
|
| 200 |
+
# Place extraction format
|
| 201 |
+
# Derived from the same SQL samples: selected_candidates → PlacesResult JSON.
|
| 202 |
+
# ---------------------------------------------------------------------------
|
| 203 |
+
|
| 204 |
+
_PLACE_SYSTEM = """You are a geographic entity extractor. Extract place names from the user query and return valid JSON only.
|
| 205 |
+
|
| 206 |
+
OUTPUT FORMAT:
|
| 207 |
+
{"places": [{"place": "<name>", "country": "<ISO-2>", "subtype": "<subtype>"}]}
|
| 208 |
+
"country" and "subtype" are optional; omit if not applicable.
|
| 209 |
+
|
| 210 |
+
RULES:
|
| 211 |
+
- Only extract places explicitly mentioned. Never infer or expand (e.g. "states of India" -> extract "India" only).
|
| 212 |
+
- No duplicate place names.
|
| 213 |
+
- "country": ISO 3166-1 alpha-2. Include only if explicitly mentioned or unambiguous.
|
| 214 |
+
- "subtype": include only when the geographic level is clear from the query.
|
| 215 |
+
|
| 216 |
+
SUBTYPES:
|
| 217 |
+
country, dependency, region, county, localadmin, locality, macrohood, neighborhood, microhood
|
| 218 |
+
- Default to locality for cities/towns; omit for physical features (oceans, rivers, mountains)."""
|
| 219 |
+
|
| 220 |
+
# Overture division subtypes — used to filter out natural_earth candidates
|
| 221 |
+
# from the place extraction output (NE features don't have these subtypes).
|
| 222 |
+
_DIVISION_SUBTYPES = {
|
| 223 |
+
"country", "region", "dependency", "county", "localadmin",
|
| 224 |
+
"locality", "macrohood", "neighborhood", "microhood",
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def _candidate_to_place(c: Dict) -> Optional[Dict]:
    """Convert a selected candidate into a place dict for the extraction target.

    Args:
        c: Candidate record; ``name``, ``subtype`` and ``country`` are read.

    Returns:
        ``{"place": name}`` plus ``subtype`` (only when it is one of
        ``_DIVISION_SUBTYPES``) and ``country`` (only when it is a
        two-character value), or ``None`` when the candidate has no usable
        name.
    """
    # dict.get only falls back to its default for a missing key; an explicit
    # null name would otherwise reach .strip() as None and raise.
    name = (c.get("name") or "").strip()
    if not name:
        return None

    place: Dict[str, Any] = {"place": name}

    subtype = c.get("subtype", "")
    if subtype in _DIVISION_SUBTYPES:
        place["subtype"] = subtype

    country = c.get("country", "")
    if country and len(country) == 2:
        place["country"] = country

    return place
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def sample_to_place_pair(sample: Dict[str, Any]) -> Optional[Dict]:
    """Build the chat-style training record for place extraction from one sample.

    The expected output is derived from the sample's selected candidates.
    Places are emitted in the order their ids appear in
    ``target.selected_candidates`` (duplicates dropped), so the completion is
    reproducible across runs rather than depending on set iteration order.
    Places whose names differ only in letter case are emitted once.

    Returns:
        A record with ``messages`` (system prompt, the question, the JSON
        completion) and ``metadata``, or ``None`` when no ids are selected or
        no valid place can be derived.

    Raises:
        KeyError: If a candidate lacks ``candidate_id``, or a sample that
            yields places lacks ``question``.
    """
    selected = sample.get("target", {}).get("selected_candidates") or []
    # dict.fromkeys removes duplicate ids while keeping first-seen order.
    selected_ids = list(dict.fromkeys(selected))
    if not selected_ids:
        return None

    id_to_candidate = {c["candidate_id"]: c for c in sample.get("candidates", [])}
    places = []
    seen_names: set = set()

    for cid in selected_ids:
        c = id_to_candidate.get(cid)
        if not c:
            continue
        place = _candidate_to_place(c)
        if not place:
            continue
        key = place["place"].lower()
        if key not in seen_names:
            places.append(place)
            seen_names.add(key)

    if not places:
        return None

    completion_json = json.dumps({"places": places}, ensure_ascii=False)

    return {
        "messages": [
            {"role": "system", "content": _PLACE_SYSTEM},
            {"role": "user", "content": sample["question"]},
            {"role": "assistant", "content": completion_json},
        ],
        "metadata": sample.get("metadata", {}),
    }
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
# ---------------------------------------------------------------------------
|
| 286 |
+
# I/O helpers
|
| 287 |
+
# ---------------------------------------------------------------------------
|
| 288 |
+
|
| 289 |
+
def save_jsonl(records: List[Dict], path: Path) -> None:
    """Write records to ``path`` as JSON Lines, creating parent directories.

    Records are serialised with ``ensure_ascii=False``, so non-ASCII
    characters are written verbatim. The file is therefore opened as UTF-8
    explicitly; relying on the locale default could raise
    ``UnicodeEncodeError`` or produce a file that is not valid UTF-8.
    An existing file is overwritten.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r, ensure_ascii=False) + "\n" for r in records)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def split_stats(samples: List[Dict]) -> Dict[str, int]:
    """Count samples per task family.

    The family is read from ``sample["metadata"]["task_family"]``; samples
    without it are counted under ``"unknown"``. The returned dict is ordered
    by family name.
    """
    tally: Dict[str, int] = {}
    for sample in samples:
        family = sample.get("metadata", {}).get("task_family", "unknown")
        tally[family] = tally.get(family, 0) + 1
    return {family: tally[family] for family in sorted(tally)}
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
# ---------------------------------------------------------------------------
|
| 304 |
+
# Main
|
| 305 |
+
# ---------------------------------------------------------------------------
|
| 306 |
+
|
| 307 |
+
def main(config_path: Optional[Path] = None) -> None:
    """Export the validated dataset into per-task train/val/test JSONL files.

    Reads ``output/dataset_validated.jsonl`` (relative to the dataset
    directory), splits it once, and writes the SQL generation and place
    extraction variants of each split under ``output/runs/<run_name>/``
    together with a ``stats.json`` summary. Exits with status 1 when the
    validated file is missing.

    Args:
        config_path: Config file used to look up the run name; defaults to
            ``config.yaml`` in the dataset directory.
    """
    dataset_dir = Path(__file__).parent.parent
    output_dir = dataset_dir / "output"

    run_name = load_run_name(config_path or dataset_dir / "config.yaml")

    validated_file = output_dir / "dataset_validated.jsonl"
    if not validated_file.exists():
        print(f"Error: {validated_file} not found. Run validate first.")
        sys.exit(1)

    run_dir = output_dir / "runs" / run_name
    sql_dir = run_dir / "sql"
    places_dir = run_dir / "places"

    print(f"Run name : {run_name}")
    print(f"Output dir : {run_dir}")

    # Load
    print("\nLoading validated samples...")
    samples = load_samples(validated_file)
    print(f" {len(samples):,} samples loaded")

    # Split once, reuse for both tasks
    print("\nSplitting 80 / 10 / 10 (stratified by task family)...")
    train_raw, val_raw, test_raw = stratified_split(samples)
    print(f" train={len(train_raw):,} val={len(val_raw):,} test={len(test_raw):,}")
    splits = (("train", train_raw), ("val", val_raw), ("test", test_raw))

    # Both tasks follow the same convert / save / count routine; only the
    # banner, output folder and converter differ.
    tasks = (
        ("\nBuilding SQL generation splits...", "sql", sql_dir, sample_to_sql_pair),
        ("\nBuilding place extraction splits...", "places", places_dir, sample_to_place_pair),
    )
    stats_by_task: Dict[str, Dict] = {}
    for banner, label, out_dir, convert in tasks:
        print(banner)
        task_stats: Dict = {}
        for split_name, raw in splits:
            # Converters return None for samples they cannot use.
            pairs = [pair for pair in map(convert, raw) if pair is not None]
            save_jsonl(pairs, out_dir / f"{split_name}.jsonl")
            task_stats[split_name] = {
                "total": len(pairs),
                "by_family": split_stats(pairs),
            }
            print(f" {label}/{split_name}.jsonl — {len(pairs):,} pairs")
        stats_by_task[label] = task_stats

    # --- Stats ---
    stats = {
        "run_name": run_name,
        "total_samples": len(samples),
        "sql_generation": stats_by_task["sql"],
        "place_extraction": stats_by_task["places"],
    }
    stats_path = run_dir / "stats.json"
    with open(stats_path, "w") as f:
        json.dump(stats, f, indent=2)

    print(f"\nStats written to {stats_path}")
    print("\nDone. Training-ready files:")
    print(f" SQL generation : {sql_dir}/{{train,val,test}}.jsonl")
    print(f" Place extraction: {places_dir}/{{train,val,test}}.jsonl")
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
if __name__ == "__main__":
|
| 372 |
+
main()
|
dataset/scripts/generate_samples.py
ADDED
|
@@ -0,0 +1,1560 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Generate synthetic training samples for the text-to-SQL task.
|
| 3 |
+
|
| 4 |
+
This script:
|
| 5 |
+
1. Loads relation tables and entity inventories
|
| 6 |
+
2. For each SQL template, samples valid anchors
|
| 7 |
+
3. Renders and executes SQL to verify it works
|
| 8 |
+
4. Builds candidate lists with controlled distractors
|
| 9 |
+
5. Generates natural language questions using an LLM
|
| 10 |
+
6. Saves complete training samples
|
| 11 |
+
|
| 12 |
+
Output:
|
| 13 |
+
- output/samples/sample_*.json (individual samples)
|
| 14 |
+
- output/dataset_raw.jsonl (all samples)
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import random
|
| 19 |
+
import warnings
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import List, Dict, Any, Optional
|
| 22 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 23 |
+
from functools import partial
|
| 24 |
+
|
| 25 |
+
import duckdb
|
| 26 |
+
import pandas as pd
|
| 27 |
+
from pydantic import BaseModel
|
| 28 |
+
|
| 29 |
+
# Suppress warnings
|
| 30 |
+
warnings.filterwarnings('ignore')
|
| 31 |
+
|
| 32 |
+
from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
|
| 33 |
+
|
| 34 |
+
# Fixed paths embedded in every training SQL string.
|
| 35 |
+
# The model learns these short, stable strings rather than machine-specific
|
| 36 |
+
# local paths. At inference, sql.py's _rewrite_data_paths substitutes them
|
| 37 |
+
# with the actual runtime paths from gazet.config.
|
| 38 |
+
_DIVISIONS_SQL_PATH = 'divisions_area'
|
| 39 |
+
_NATURAL_EARTH_SQL_PATH = 'natural_earth'
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _for_execution(sql: str) -> str:
    """Return *sql* with the symbolic parquet placeholders swapped for real paths.

    Training SQL refers to the datasets through the short placeholder names
    held in ``_DIVISIONS_SQL_PATH`` and ``_NATURAL_EARTH_SQL_PATH``. Before a
    query is executed locally for verification, every
    ``read_parquet('<placeholder>')`` call is rewritten to point at
    ``DIVISIONS_AREA_PATH`` / ``NATURAL_EARTH_PATH``.

    Args:
        sql: SQL text containing placeholder ``read_parquet`` calls.

    Returns:
        The SQL text with both placeholders substituted. Text that contains
        no placeholder is returned unchanged.
    """
    # Derive the search strings from the module constants so the placeholder
    # names are spelled in exactly one place.
    substitutions = (
        (_DIVISIONS_SQL_PATH, DIVISIONS_AREA_PATH),
        (_NATURAL_EARTH_SQL_PATH, NATURAL_EARTH_PATH),
    )
    for placeholder, real_path in substitutions:
        sql = sql.replace(
            f"read_parquet('{placeholder}')",
            f"read_parquet('{real_path}')",
        )
    return sql
|
| 49 |
+
|
| 50 |
+
# Configurable parameters (can be overridden by CLI)
|
| 51 |
+
TARGET_COUNTS = None # Will be set in main() or by CLI
|
| 52 |
+
MAX_WORKERS = 8
|
| 53 |
+
RETRY_MULTIPLIER = 2
|
| 54 |
+
APPEND_MODE = False
|
| 55 |
+
|
| 56 |
+
# Import templates from same directory
|
| 57 |
+
from . import sql_templates
|
| 58 |
+
TEMPLATES = sql_templates.TEMPLATES
|
| 59 |
+
SQLTemplate = sql_templates.SQLTemplate
|
| 60 |
+
get_templates_by_family = sql_templates.get_templates_by_family
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class Candidate(BaseModel):
    """Candidate entity for grounding."""

    # Position label within a candidate list ("c1", "c2", ...); the list
    # builders in this module reassign it after ordering.
    candidate_id: str
    # Dataset the entity was read from; the builders in this module use
    # "divisions_area" and "natural_earth".
    source: str
    # Entity identifier within its source dataset.
    id: str
    name: str
    # Optional descriptive attributes; None when the source row has no value.
    subtype: Optional[str] = None
    country: Optional[str] = None
    region: Optional[str] = None
    admin_level: Optional[int] = None
    # Name-similarity score: the builders set 1.0 for the true anchor and a
    # Jaro-Winkler score for distractors.
    similarity: float = 0.0
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class TrainingSample(BaseModel):
    """Complete training sample."""

    id: str
    question: str
    candidates: List[Candidate]
    # Expected model output; the generators in this module store
    # "selected_candidates" (candidate ids) and "sql" here.
    target: Dict[str, Any]
    # Bookkeeping written by the generators, e.g. task family, difficulty
    # labels, template id, candidate count, anchor source, verification flag.
    metadata: Dict[str, Any]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def load_relation_tables(intermediate_dir: Path, quiet: bool = False) -> Dict[str, pd.DataFrame]:
    """Load every precomputed relation table found in *intermediate_dir*.

    Each ``*.parquet`` file directly inside the directory becomes one entry,
    keyed by the file name without its extension.

    Args:
        intermediate_dir: Directory holding the parquet files.
        quiet: When True, suppress the per-table row-count printout.

    Returns:
        Mapping of table name to its DataFrame; empty when the directory
        contains no parquet files.
    """
    tables: Dict[str, pd.DataFrame] = {}

    # Path.glob yields entries in arbitrary, filesystem-dependent order; sort
    # so the mapping order and the log output are identical on every run.
    for parquet_file in sorted(intermediate_dir.glob("*.parquet")):
        name = parquet_file.stem
        tables[name] = pd.read_parquet(parquet_file)
        if not quiet:
            print(f" {name}: {len(tables[name])} rows")

    return tables
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def sample_adjacency_anchor(adjacency_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
    """Pick one row of the adjacency table at random and return it as an anchor dict.

    Returns None for an empty table. ``anchor_id``, ``anchor_name`` and
    ``anchor_subtype`` must be present as columns; ``anchor_country`` and
    ``target_subtype`` are optional and come back as None when the column
    is missing.
    """
    if adjacency_df.empty:
        return None

    picked = adjacency_df.sample(n=1).iloc[0]

    anchor: Dict[str, Any] = {
        key: picked[key] for key in ('anchor_id', 'anchor_name', 'anchor_subtype')
    }
    # Not every adjacency table carries these columns, so fall back to None.
    for key in ('anchor_country', 'target_subtype'):
        anchor[key] = picked.get(key)
    return anchor
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def sample_intersection_anchor(intersection_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
    """Pick one row of the intersection table at random and return it as a dict.

    Returns None for an empty table. The three ``anchor_*`` columns are
    required; the three ``target_*`` values are None when their column is
    absent.
    """
    if intersection_df.empty:
        return None

    chosen = intersection_df.sample(n=1).iloc[0]

    required = {
        field: chosen[field]
        for field in ('anchor_id', 'anchor_name', 'anchor_subtype')
    }
    optional = {
        field: chosen.get(field)
        for field in ('target_id', 'target_name', 'target_subtype')
    }
    return {**required, **optional}
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def sample_containment_anchor(containment_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
    """Pick one row of the containment table at random and return it as a dict.

    Returns None for an empty table. All four columns
    (``container_id``, ``container_name``, ``container_subtype``,
    ``contained_subtype``) are required.
    """
    if containment_df.empty:
        return None

    pair = containment_df.sample(n=1).iloc[0]
    wanted = ('container_id', 'container_name', 'container_subtype', 'contained_subtype')
    return {column: pair[column] for column in wanted}
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def sample_cross_source_anchor(cross_source_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
    """Pick one division/natural-feature relation at random and return it as a dict.

    Returns None for an empty table. Every listed column is required.
    """
    if cross_source_df.empty:
        return None

    relation = cross_source_df.sample(n=1).iloc[0]

    columns = (
        'division_id',
        'division_name',
        'division_subtype',
        'natural_id',
        'natural_name',
        'natural_subtype',
        'relation_type',
    )
    anchor: Dict[str, Any] = {}
    for column in columns:
        anchor[column] = relation[column]
    return anchor
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _merge_candidate_lists(
    *lists: List[Candidate],
    max_total: int = 10,
) -> List[Candidate]:
    """Merge N candidate lists, deduplicate by id, reassign candidate_ids.

    The lists are interleaved round-robin: the first entry of every list,
    then the second entry of every list, and so on, so each anchor is
    represented before any anchor gets a second candidate. Per the original
    author's note this is meant to match the order produced at inference.

    Args:
        *lists: Candidate lists to merge, one per anchor.
        max_total: Maximum number of candidates to keep. Values <= 0 yield
            an empty list.

    Returns:
        At most ``max_total`` unique candidates. The returned objects are the
        input objects themselves; their ``candidate_id`` is overwritten in
        place with ``c1``..``cN`` in merged order.
    """
    from itertools import chain, zip_longest

    merged: List[Candidate] = []
    # The size check below only runs after an append, so a non-positive
    # limit has to be rejected up front or one candidate would slip through.
    if max_total <= 0:
        return merged

    seen: set = set()
    # zip_longest pads exhausted lists with None; chaining the rows gives a
    # single round-robin stream without nested loops and paired breaks.
    for cand in chain.from_iterable(zip_longest(*lists)):
        if cand is None or cand.id in seen:
            continue
        seen.add(cand.id)
        merged.append(cand)
        if len(merged) >= max_total:
            break

    for position, cand in enumerate(merged, 1):
        cand.candidate_id = f"c{position}"
    return merged
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def build_candidate_list(
    con: duckdb.DuckDBPyConnection,
    anchor_id: str,
    anchor_name: str,
    anchor_source: str,
    num_candidates: int = 10,
    difficulty: str = "medium"
) -> List[Candidate]:
    """Build candidate list with true anchor + distractors.

    The anchor row is looked up by id in the parquet file for its source,
    then ``build_distractors`` is asked for ``num_candidates - 1`` entries.
    The true anchor is always first and ids are assigned ``c1``..``cN`` in
    list order.

    Args:
        con: Open DuckDB connection used for the lookups.
        anchor_id: Id of the true entity.
        anchor_name: Name passed to ``build_distractors`` for the similarity
            search.
        anchor_source: ``"divisions_area"``; any other value is looked up in
            the natural-earth parquet.
        num_candidates: Requested list size, true anchor included.
        difficulty: Forwarded unchanged to ``build_distractors``.

    Returns:
        The true anchor followed by whatever ``build_distractors`` returns.

    Raises:
        IndexError: If ``anchor_id`` is not found in the source parquet.
    """

    # Helper to convert pandas NA to None
    def safe_get(row, key, default=None):
        val = row.get(key, default)
        return None if pd.isna(val) else val

    # Pick the parquet file and projection for the anchor's source. The
    # natural-earth lookup selects only id/name/subtype; the remaining
    # Candidate fields resolve to None through safe_get.
    if anchor_source == "divisions_area":
        parquet_path = DIVISIONS_AREA_PATH
        query = """
SELECT
id,
names."primary" AS name,
subtype,
country,
region,
admin_level
FROM read_parquet(?)
WHERE id = ?
"""
    else:
        parquet_path = NATURAL_EARTH_PATH
        query = """
SELECT
id,
names."primary" AS name,
subtype
FROM read_parquet(?)
WHERE id = ?
"""

    anchor_df = con.execute(query, [parquet_path, anchor_id]).fetchdf()
    if anchor_df.empty:
        # Same exception type that .iloc[0] raises on an empty frame, but
        # with a message that says which anchor could not be found.
        raise IndexError(
            f"Anchor id {anchor_id!r} not found in source {anchor_source!r}"
        )
    anchor_row = anchor_df.iloc[0]

    # Build true candidate
    true_candidate = Candidate(
        candidate_id="c1",
        source=anchor_source,
        id=anchor_id,
        name=safe_get(anchor_row, 'name'),
        subtype=safe_get(anchor_row, 'subtype'),
        country=safe_get(anchor_row, 'country'),
        region=safe_get(anchor_row, 'region'),
        admin_level=safe_get(anchor_row, 'admin_level'),
        similarity=1.0
    )

    # Build distractors based on difficulty
    distractors = build_distractors(
        con,
        anchor_name,
        anchor_source,
        anchor_id,
        num_candidates - 1,
        difficulty
    )

    # Order: true anchor first, then same-source distractors, then cross-source
    # distractors. This mirrors inference order (anchor at top by similarity,
    # same source grouped before the other source).
    candidates = [true_candidate] + distractors

    # Reassign candidate IDs in order
    for i, cand in enumerate(candidates, 1):
        cand.candidate_id = f"c{i}"

    return candidates
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def build_distractors(
    con: duckdb.DuckDBPyConnection,
    anchor_name: str,
    anchor_source: str,
    exclude_id: str,
    num_distractors: int,
    difficulty: str,
    cross_source_ratio: float = 0.5,
) -> List[Candidate]:
    """Build distractor candidates using fuzzy search.

    Entities are ranked by Jaro-Winkler similarity between their lowercased
    primary name and the lowercased ``anchor_name``. Distractors are drawn
    from both sources so a training example shows mixed ``source`` values;
    per the original author's note this is meant to mirror inference, where
    both datasets are searched.

    Args:
        con: Open DuckDB connection used for the searches.
        anchor_name: Name to compare candidate names against.
        anchor_source: ``"divisions_area"``; any other value is treated as
            the natural-earth source.
        exclude_id: Id of the true anchor, excluded from the same-source
            search.
        num_distractors: Total number of distractors requested. Values <= 0
            yield an empty list.
        difficulty: Accepted for interface compatibility; not used by this
            implementation.
        cross_source_ratio: Fraction of distractors drawn from the *other*
            source. At least one cross-source distractor is requested
            whenever ``num_distractors`` is positive.

    Returns:
        Same-source distractors followed by cross-source distractors, each
        with the placeholder ``candidate_id`` ``"temp"``. Fewer than
        ``num_distractors`` entries are returned when the searches yield
        fewer rows.
    """
    # Without this guard a request for zero distractors would still ask for
    # one cross-source row and issue a same-source query with LIMIT -1.
    if num_distractors <= 0:
        return []

    def safe_get(row, key, default=None):
        val = row.get(key, default)
        return None if pd.isna(val) else val

    def _query_source(path: str, src_name: str, n: int, excl_id: str) -> List[Candidate]:
        # Nothing requested from this source: skip the parquet scan entirely.
        if n <= 0:
            return []
        query = """
SELECT
id,
names."primary" AS name,
subtype,
country,
region,
admin_level,
jaro_winkler_similarity(lower(names."primary"), lower(?)) AS similarity
FROM read_parquet(?)
WHERE id != ?
AND names."primary" IS NOT NULL
ORDER BY similarity DESC
LIMIT ?
"""
        # Parameter order follows the '?' placeholders in the query text.
        df = con.execute(query, [anchor_name, path, excl_id, n]).fetchdf()
        return [
            Candidate(
                candidate_id="temp",
                source=src_name,
                id=row["id"],
                name=safe_get(row, "name"),
                subtype=safe_get(row, "subtype"),
                country=safe_get(row, "country"),
                region=safe_get(row, "region"),
                admin_level=safe_get(row, "admin_level"),
                similarity=float(row["similarity"]),
            )
            for _, row in df.iterrows()
        ]

    # Clamp to num_distractors so a ratio above 1.0 cannot push same_n
    # negative.
    cross_n = min(num_distractors, max(1, round(num_distractors * cross_source_ratio)))
    same_n = num_distractors - cross_n

    # The anchor id only needs excluding in its own source; the other source
    # is queried with an empty exclusion id.
    if anchor_source == "divisions_area":
        same = _query_source(DIVISIONS_AREA_PATH, "divisions_area", same_n, exclude_id)
        cross = _query_source(NATURAL_EARTH_PATH, "natural_earth", cross_n, "")
    else:
        same = _query_source(NATURAL_EARTH_PATH, "natural_earth", same_n, exclude_id)
        cross = _query_source(DIVISIONS_AREA_PATH, "divisions_area", cross_n, "")

    return same + cross
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def generate_adjacency_sample(
    con: duckdb.DuckDBPyConnection,
    adjacency_df: pd.DataFrame,
    sample_id: str
) -> Optional[TrainingSample]:
    """Produce one adjacency training sample, or None if none can be built.

    An anchor pair is drawn from ``adjacency_df``, a query selecting
    divisions of the target subtype whose geometry touches the anchor is
    rendered and executed as a check, and a 10-entry candidate list is built
    around the anchor. None is returned when no anchor can be sampled, the
    query raises, or the query returns no rows.
    """
    anchor = sample_adjacency_anchor(adjacency_df)
    if not anchor:
        return None

    # Training SQL uses the symbolic 'divisions_area' path; it is swapped
    # for the real path only for the verification run below.
    sql = f"""WITH a AS (
SELECT geometry FROM read_parquet('divisions_area')
WHERE id = '{anchor['anchor_id']}'
)
SELECT b.id, b.names."primary" AS name, b.geometry
FROM read_parquet('divisions_area') AS b, a
WHERE b.id != '{anchor['anchor_id']}'
AND b.subtype = '{anchor['target_subtype']}'
AND ST_Touches(a.geometry, b.geometry)"""

    # Keep the sample only if the query runs and yields at least one row.
    try:
        verification = con.execute(_for_execution(sql)).fetchdf()
    except Exception as e:
        print(f"SQL execution failed: {e}")
        return None
    if verification.empty:
        return None

    candidates = build_candidate_list(
        con,
        anchor['anchor_id'],
        anchor['anchor_name'],
        "divisions_area",
        num_candidates=10,
        difficulty="medium"
    )

    # Candidate id(s) of the entry that is the true anchor.
    selected_candidate_ids = [
        cand.candidate_id for cand in candidates if cand.id == anchor['anchor_id']
    ]

    question = f"Which {anchor['target_subtype']}s border {anchor['anchor_name']}?"

    target = {
        "selected_candidates": selected_candidate_ids,
        "sql": sql
    }
    metadata = {
        "task_family": "adjacency",
        "sql_difficulty": "medium",
        "grounding_difficulty": "medium",
        "template_id": "adj_02",
        "num_candidates": len(candidates),
        "anchor_source": "divisions_area",
        "sql_verified": True
    }
    return TrainingSample(
        id=sample_id,
        question=question,
        candidates=candidates,
        target=target,
        metadata=metadata
    )
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def generate_containment_sample(
    con: duckdb.DuckDBPyConnection,
    containment_df: pd.DataFrame,
    sample_id: str
) -> Optional[TrainingSample]:
    """Produce one containment training sample, or None if none can be built.

    A container is drawn from ``containment_df``, a query selecting
    divisions of the contained subtype whose geometry lies within the
    container is rendered and executed as a check, and a 10-entry candidate
    list is built around the container. None is returned when no anchor can
    be sampled, the query raises, or the query returns no rows.
    """
    anchor = sample_containment_anchor(containment_df)
    if not anchor:
        return None

    # Training SQL uses the symbolic 'divisions_area' path; it is swapped
    # for the real path only for the verification run below.
    sql = f"""WITH a AS (
SELECT geometry FROM read_parquet('divisions_area')
WHERE id = '{anchor['container_id']}'
)
SELECT b.id, b.names."primary" AS name, b.geometry
FROM read_parquet('divisions_area') AS b, a
WHERE b.id != '{anchor['container_id']}'
AND b.subtype = '{anchor['contained_subtype']}'
AND ST_Within(b.geometry, a.geometry)"""

    # Keep the sample only if the query runs and yields at least one row.
    try:
        verification = con.execute(_for_execution(sql)).fetchdf()
    except Exception as e:
        print(f"SQL execution failed: {e}")
        return None
    if verification.empty:
        return None

    candidates = build_candidate_list(
        con,
        anchor['container_id'],
        anchor['container_name'],
        "divisions_area",
        num_candidates=10,
        difficulty="medium"
    )

    # Candidate id(s) of the entry that is the true container.
    selected_candidate_ids = [
        cand.candidate_id for cand in candidates if cand.id == anchor['container_id']
    ]

    question = f"What {anchor['contained_subtype']}s are in {anchor['container_name']}?"

    target = {
        "selected_candidates": selected_candidate_ids,
        "sql": sql
    }
    metadata = {
        "task_family": "containment",
        "sql_difficulty": "medium",
        "grounding_difficulty": "medium",
        "template_id": "contain_01",
        "num_candidates": len(candidates),
        "anchor_source": "divisions_area",
        "sql_verified": True
    }
    return TrainingSample(
        id=sample_id,
        question=question,
        candidates=candidates,
        target=target,
        metadata=metadata
    )
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def sample_random_entity(
    con: duckdb.DuckDBPyConnection,
    inventory_df: pd.DataFrame,
    source: str
) -> Optional[Dict[str, Any]]:
    """Pick one inventory row at random and describe it as an entity dict.

    Returns None for an empty inventory. ``id`` and ``name`` must be present
    as columns; ``subtype`` and ``country`` are None when the column is
    missing. ``source`` is copied into the result unchanged. ``con`` is not
    used by this function.
    """
    if inventory_df.empty:
        return None

    picked = inventory_df.sample(n=1).iloc[0]

    entity: Dict[str, Any] = {'id': picked['id'], 'name': picked['name']}
    entity.update(
        subtype=picked.get('subtype'),
        country=picked.get('country'),
        source=source,
    )
    return entity
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def generate_template_based_sample(
    con: duckdb.DuckDBPyConnection,
    template: SQLTemplate,
    tables: Dict[str, pd.DataFrame],
    sample_id: str
) -> Optional[TrainingSample]:
    """Generate one SQL-verified training sample from a SQL template.

    Dispatches on ``template.family`` (and, for some families, on
    ``template.template_id``, ``template.anchor_source`` or
    ``template.num_anchors``) to sample anchor entities, render the SQL and
    the question text, and build the candidate list. The rendered SQL is then
    executed; the sample is discarded when the query fails or returns no rows.

    Each branch records the IDs of the entities it grounded the sample on in
    ``selected_ids``. Those IDs decide which candidates are reported as
    ``selected_candidates`` in the sample target.

    Args:
        con: DuckDB connection passed to the candidate builder and used to
            execute the rendered SQL.
        template: Template supplying ``sql_template``, ``question_hints`` and
            the dispatch attributes listed above.
        tables: DataFrames keyed by table name (inventories and relation
            tables such as ``'containment_pairs'`` or ``'adjacency_pairs'``).
        sample_id: Identifier stored on the returned sample.

    Returns:
        A ``TrainingSample``, or ``None`` when the family is unsupported, a
        required anchor could not be sampled, or SQL verification failed or
        produced an empty result.
    """
    # IDs of the entities this sample is grounded on; filled in by each branch.
    selected_ids: set = set()

    # Sample anchor based on template requirements
    if template.family == "direct_lookup":
        # Just pick a random entity
        if template.anchor_source == "divisions_area":
            anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
        else:
            anchor = sample_random_entity(con, tables['natural_earth_inventory'], 'natural_earth')

        if not anchor:
            return None

        # Render SQL
        sql = template.sql_template.format(
            anchor_id=anchor['id']
        )

        # Build candidates
        candidates = build_candidate_list(
            con, anchor['id'], anchor['name'], anchor['source'],
            num_candidates=10, difficulty="easy"
        )

        # Question
        question = random.choice(template.question_hints).format(anchor_name=anchor['name'])
        selected_ids = {anchor['id']}

    elif template.family == "adjacency":
        anchor = sample_adjacency_anchor(tables['adjacency_pairs'])
        if not anchor:
            return None

        sql = template.sql_template.format(
            anchor_id=anchor['anchor_id'],
            target_subtype=anchor['target_subtype']
        )

        candidates = build_candidate_list(
            con, anchor['anchor_id'], anchor['anchor_name'], 'divisions_area',
            num_candidates=10, difficulty="medium"
        )

        question = random.choice(template.question_hints).format(
            anchor_name=anchor['anchor_name'],
            target_subtype=anchor['target_subtype']
        )
        selected_ids = {anchor['anchor_id']}

    elif template.family == "containment":
        anchor = sample_containment_anchor(tables['containment_pairs'])
        if not anchor:
            return None

        sql = template.sql_template.format(
            anchor_id=anchor['container_id'],
            target_subtype=anchor['contained_subtype']
        )

        candidates = build_candidate_list(
            con, anchor['container_id'], anchor['container_name'], 'divisions_area',
            num_candidates=10, difficulty="medium"
        )

        question = random.choice(template.question_hints).format(
            anchor_name=anchor['container_name'],
            target_subtype=anchor['contained_subtype']
        )
        selected_ids = {anchor['container_id']}

    elif template.family == "intersection":
        if template.anchor_source == "natural_earth":
            anchor = sample_cross_source_anchor(tables['cross_source_relations'])
            if not anchor:
                return None

            sql = template.sql_template.format(
                anchor_id=anchor['natural_id'],
                target_subtype='country'
            )

            candidates = build_candidate_list(
                con, anchor['natural_id'], anchor['natural_name'], 'natural_earth',
                num_candidates=10, difficulty="medium"
            )

            question = random.choice(template.question_hints).format(
                anchor_name=anchor['natural_name'],
                target_subtype='country'
            )
            selected_ids = {anchor['natural_id']}
        else:
            # Same-source intersection
            anchor = sample_intersection_anchor(tables['intersection_pairs'])
            if not anchor:
                return None

            # Use a generic subtype if not available
            target_subtype = anchor.get('target_subtype') or 'region'

            sql = template.sql_template.format(
                anchor_id=anchor['anchor_id'],
                target_subtype=target_subtype
            )

            candidates = build_candidate_list(
                con, anchor['anchor_id'], anchor['anchor_name'], 'divisions_area',
                num_candidates=10, difficulty="medium"
            )

            question = random.choice(template.question_hints).format(
                anchor_name=anchor['anchor_name'],
                target_subtype=target_subtype
            )
            selected_ids = {anchor['anchor_id']}

    elif template.family == "set_operations":
        if template.template_id == "union_03":
            # 3-anchor union by ID — candidates: 3 per anchor (9 total)
            anchors = [
                sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
                for _ in range(3)
            ]
            if any(a is None for a in anchors):
                return None
            anchor1, anchor2, anchor3 = anchors

            sql = template.sql_template.format(
                anchor_id_1=anchor1['id'],
                anchor_id_2=anchor2['id'],
                anchor_id_3=anchor3['id'],
            )

            per_anchor = 3
            cands = [
                build_candidate_list(con, a['id'], a['name'], 'divisions_area',
                                     num_candidates=per_anchor, difficulty="medium")
                for a in anchors
            ]
            candidates = _merge_candidate_lists(*cands, max_total=9)

            question = random.choice(template.question_hints).format(
                anchor_1_name=anchor1['name'],
                anchor_2_name=anchor2['name'],
                anchor_3_name=anchor3['name'],
            )
            selected_ids = {a['id'] for a in anchors}

        elif template.template_id in ("contain_multi_01", "contain_multi_02"):
            # country IN clause — 2 or 3 anchors, each contributes its country code
            num_a = 3 if template.template_id == "contain_multi_02" else 2
            anchors = [
                sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
                for _ in range(num_a)
            ]
            if any(a is None for a in anchors):
                return None

            countries = [a.get('country') or 'US' for a in anchors]
            target_subtype = random.choice(['region', 'locality'])
            per_anchor = 3 if num_a == 3 else 4

            fmt_kwargs = dict(
                target_subtype=target_subtype,
            )
            for i, c in enumerate(countries, 1):
                fmt_kwargs[f'country_{i}'] = c

            sql = template.sql_template.format(**fmt_kwargs)

            cands = [
                build_candidate_list(con, a['id'], a['name'], 'divisions_area',
                                     num_candidates=per_anchor, difficulty="medium")
                for a in anchors
            ]
            candidates = _merge_candidate_lists(*cands, max_total=num_a * per_anchor)

            q_kwargs = dict(target_subtype=target_subtype)
            for i, a in enumerate(anchors, 1):
                q_kwargs[f'anchor_{i}_name'] = a['name']

            question = random.choice(template.question_hints).format(**q_kwargs)
            # The SQL filters by country code, but the anchors that supplied
            # those codes are the entities named in the question.
            selected_ids = {a['id'] for a in anchors}

        elif template.template_id == "union_02":
            # Filtered union: ST_Union_Agg of contained sub-features
            pair = sample_containment_anchor(tables['containment_pairs'])
            if not pair:
                return None

            target_subtype = pair.get('contained_subtype', 'locality')
            sql = template.sql_template.format(
                anchor_id=pair['container_id'],
                target_subtype=target_subtype,
            )

            candidates = build_candidate_list(
                con, pair['container_id'], pair['container_name'], 'divisions_area',
                num_candidates=10, difficulty="medium"
            )

            question = random.choice(template.question_hints).format(
                anchor_name=pair['container_name'],
                target_subtype=target_subtype,
            )
            # The container is the only anchor; record it explicitly because
            # this branch binds `pair`, not `anchor`.
            selected_ids = {pair['container_id']}

        else:
            # union_01: 2-anchor union by ID — candidates: 5 per anchor
            anchor1 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
            anchor2 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
            if not anchor1 or not anchor2:
                return None

            sql = template.sql_template.format(
                anchor_id_1=anchor1['id'],
                anchor_id_2=anchor2['id'],
            )

            cands1 = build_candidate_list(
                con, anchor1['id'], anchor1['name'], 'divisions_area',
                num_candidates=5, difficulty="medium"
            )
            cands2 = build_candidate_list(
                con, anchor2['id'], anchor2['name'], 'divisions_area',
                num_candidates=5, difficulty="medium"
            )
            candidates = _merge_candidate_lists(cands1, cands2, max_total=10)

            question = random.choice(template.question_hints).format(
                anchor_1_name=anchor1['name'],
                anchor_2_name=anchor2['name'],
            )
            selected_ids = {anchor1['id'], anchor2['id']}

    elif template.family == "buffer":
        # Buffer operations
        # Kilometre distances used by buffer_01 and buffer_03 templates.
        # Metre distances used by buffer_02 and buffer_04 templates.
        # The template SQL divides by 111 320 to convert to degrees.
        _buffer_km_choices = [1, 2, 5, 10, 25, 50, 100, 200]
        _buffer_m_choices = [100, 250, 500, 1000, 2000, 5000]

        if template.num_anchors == 1:
            if template.anchor_source == "natural_earth":
                anchor = sample_random_entity(con, tables['natural_earth_inventory'], 'natural_earth')
            else:
                anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
            if not anchor:
                return None

            # Choose unit based on which placeholder the template uses.
            uses_km = "{buffer_km}" in template.sql_template
            if uses_km:
                buffer_val = random.choice(_buffer_km_choices)
                fmt_kwargs = dict(
                    anchor_id=anchor['id'],
                    buffer_km=buffer_val,
                )
                q_kwargs = dict(anchor_name=anchor['name'], buffer_km=buffer_val)
            else:
                buffer_val = random.choice(_buffer_m_choices)
                fmt_kwargs = dict(
                    anchor_id=anchor['id'],
                    buffer_m=buffer_val,
                )
                q_kwargs = dict(anchor_name=anchor['name'], buffer_m=buffer_val)

            sql = template.sql_template.format(**fmt_kwargs)

            candidates = build_candidate_list(
                con, anchor['id'], anchor['name'], anchor['source'],
                num_candidates=10, difficulty="medium"
            )

            question = random.choice(template.question_hints).format(**q_kwargs)
            selected_ids = {anchor['id']}
        else:
            # Two anchor buffer (union / set-op style) — kept for completeness.
            anchor1 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
            anchor2 = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')

            if not anchor1 or not anchor2:
                return None

            buffer_val = random.choice(_buffer_km_choices[:4])  # smaller range for two-anchor
            sql = template.sql_template.format(
                anchor_id_1=anchor1['id'],
                anchor_id_2=anchor2['id'],
                buffer_km=buffer_val,
            )

            candidates1 = build_candidate_list(
                con, anchor1['id'], anchor1['name'], 'divisions_area',
                num_candidates=5, difficulty="medium"
            )
            candidates2 = build_candidate_list(
                con, anchor2['id'], anchor2['name'], 'divisions_area',
                num_candidates=5, difficulty="medium"
            )

            # De-duplicate by entity id (first occurrence wins), cap at 10,
            # then renumber so candidate ids are contiguous.
            candidates = candidates1 + candidates2
            seen_ids = set()
            unique_candidates = []
            for c in candidates:
                if c.id not in seen_ids:
                    unique_candidates.append(c)
                    seen_ids.add(c.id)
            candidates = unique_candidates[:10]

            for i, c in enumerate(candidates, 1):
                c.candidate_id = f"c{i}"

            question = random.choice(template.question_hints).format(
                anchor_1_name=anchor1['name'],
                anchor_2_name=anchor2['name'],
                buffer_km=buffer_val,
            )
            # Both anchors appear in the SQL, so both are selected.
            selected_ids = {anchor1['id'], anchor2['id']}

    elif template.family == "partial_selection":
        # Partial selection (northern half, clipping, etc.)
        anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
        if not anchor:
            return None

        if template.num_anchors == 1:
            sql = template.sql_template.format(
                anchor_id=anchor['id'],
            )
            question = random.choice(template.question_hints).format(
                anchor_name=anchor['name'],
            )
            candidates = build_candidate_list(
                con, anchor['id'], anchor['name'], 'divisions_area',
                num_candidates=10, difficulty="hard",
            )
            selected_ids = {anchor['id']}
        else:
            # Mixed-source clip: division intersected with a natural_earth feature.
            # Use cross_source_relations so the pair is guaranteed to intersect —
            # random sampling almost never produces an intersecting pair.
            cs_df = tables.get('cross_source_relations', pd.DataFrame())
            if cs_df.empty:
                return None
            row = cs_df.sample(n=1).iloc[0]
            clip_feature = {
                'id': row['natural_id'],
                'name': row['natural_name'],
                'source': 'natural_earth',
            }
            # Override the division anchor with the paired division so the
            # ST_Intersects check in the SQL is guaranteed to pass.
            anchor = {
                'id': row['division_id'],
                'name': row['division_name'],
                'source': 'divisions_area',
            }

            sql = template.sql_template.format(
                anchor_id=anchor['id'],
                clip_feature_id=clip_feature['id'],
            )
            question = random.choice(template.question_hints).format(
                anchor_name=anchor['name'],
                clip_feature_name=clip_feature['name'],
            )
            # Build candidates for BOTH anchors so the model sees both IDs
            # in context and learns to pick the right one for each placeholder.
            div_cands = build_candidate_list(
                con, anchor['id'], anchor['name'], 'divisions_area',
                num_candidates=5, difficulty="hard",
            )
            ne_cands = build_candidate_list(
                con, clip_feature['id'], clip_feature['name'], 'natural_earth',
                num_candidates=5, difficulty="hard",
            )
            candidates = _merge_candidate_lists(div_cands, ne_cands, max_total=10)
            selected_ids = {anchor['id'], clip_feature['id']}

    elif template.family == "aggregation":
        top_n = random.choice([3, 5, 10])
        target_subtype = random.choice(['locality', 'region'])

        if template.template_id in ['agg_03', 'agg_04']:
            # Country-level aggregation: SQL uses country code, not anchor id.
            anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
            if not anchor:
                return None

            country = anchor.get('country') or 'US'

            sql = template.sql_template.format(
                country=country,
                target_subtype=target_subtype,
                top_n=top_n,
            )

            candidates = build_candidate_list(
                con, anchor['id'], anchor['name'], 'divisions_area',
                num_candidates=10, difficulty="hard"
            )

            question = random.choice(template.question_hints).format(
                top_n=top_n,
                target_subtype=target_subtype,
                anchor_name=anchor['name'],
            )
            selected_ids = {anchor['id']}
        else:
            # Containment-based aggregation: anchor is the container region.
            anchor = sample_containment_anchor(tables['containment_pairs'])
            if not anchor:
                return None

            sql = template.sql_template.format(
                anchor_id=anchor['container_id'],
                target_subtype=target_subtype,
                top_n=top_n,
            )

            candidates = build_candidate_list(
                con, anchor['container_id'], anchor['container_name'], 'divisions_area',
                num_candidates=10, difficulty="hard"
            )

            question = random.choice(template.question_hints).format(
                top_n=top_n,
                target_subtype=target_subtype,
                anchor_name=anchor['container_name'],
            )
            selected_ids = {anchor['container_id']}

    elif template.family == "chained":
        # Use pre-filtered coastal/landlocked containment pairs so the SQL
        # verification step doesn't constantly return empty results.
        if template.template_id == "chained_01":
            table_key = 'coastal_containment_pairs'
        elif template.template_id == "chained_02":
            table_key = 'landlocked_containment_pairs'
        else:
            table_key = 'containment_pairs'
        # Fall back to the unfiltered pairs when the pre-filtered table is absent.
        anchor = sample_containment_anchor(tables.get(table_key, tables['containment_pairs']))
        if not anchor:
            return None

        target_subtype = anchor.get('contained_subtype', 'locality')

        sql = template.sql_template.format(
            anchor_id=anchor['container_id'],
            target_subtype=target_subtype,
        )

        candidates = build_candidate_list(
            con, anchor['container_id'], anchor['container_name'], 'divisions_area',
            num_candidates=10, difficulty="hard"
        )

        question = random.choice(template.question_hints).format(
            anchor_name=anchor['container_name'],
            target_subtype=target_subtype,
        )
        selected_ids = {anchor['container_id']}

    elif template.family == "multi_adjacency":
        # Use common_neighbor_pairs so anchor1 and anchor2 are guaranteed to
        # share at least one touching neighbour — SQL will return non-empty.
        cn_df = tables.get('common_neighbor_pairs', pd.DataFrame())
        if cn_df.empty:
            return None
        row = cn_df.sample(n=1).iloc[0]
        anchor1 = {'id': row['anchor_id_1'], 'name': row['anchor_name_1'], 'source': 'divisions_area'}
        anchor2 = {'id': row['anchor_id_2'], 'name': row['anchor_name_2'], 'source': 'divisions_area'}

        sql = template.sql_template.format(
            anchor_id_1=anchor1['id'],
            anchor_id_2=anchor2['id'],
        )

        candidates1 = build_candidate_list(
            con, anchor1['id'], anchor1['name'], 'divisions_area',
            num_candidates=5, difficulty="medium"
        )
        candidates2 = build_candidate_list(
            con, anchor2['id'], anchor2['name'], 'divisions_area',
            num_candidates=5, difficulty="medium"
        )
        candidates = _merge_candidate_lists(candidates1, candidates2)

        question = random.choice(template.question_hints).format(
            anchor_1_name=anchor1['name'],
            anchor_2_name=anchor2['name'],
        )
        selected_ids = {anchor1['id'], anchor2['id']}

    elif template.family == "difference":
        if template.anchor_source == "mixed":
            # divisions_area anchor differenced against a natural_earth feature.
            # Use cross_source_relations so the pair is guaranteed to intersect
            # (ST_Difference on non-intersecting geometries is always equal to
            # the original geometry — a trivial and uninformative sample).
            cs_df = tables.get('cross_source_relations', pd.DataFrame())
            if cs_df.empty:
                return None
            row = cs_df.sample(n=1).iloc[0]
            anchor = {
                'id': row['division_id'],
                'name': row['division_name'],
                'source': 'divisions_area',
            }
            clip_feature = {
                'id': row['natural_id'],
                'name': row['natural_name'],
                'source': 'natural_earth',
            }

            sql = template.sql_template.format(
                anchor_id=anchor['id'],
                clip_feature_id=clip_feature['id'],
            )
            question = random.choice(template.question_hints).format(
                anchor_name=anchor['name'],
                clip_feature_name=clip_feature['name'],
            )
            # Build candidates for BOTH anchors — model must see both IDs
            # to correctly assign anchor_id vs clip_feature_id in the SQL.
            div_cands = build_candidate_list(
                con, anchor['id'], anchor['name'], 'divisions_area',
                num_candidates=5, difficulty="hard",
            )
            ne_cands = build_candidate_list(
                con, clip_feature['id'], clip_feature['name'], 'natural_earth',
                num_candidates=5, difficulty="hard",
            )
            candidates = _merge_candidate_lists(div_cands, ne_cands, max_total=10)
            # Both the division and the clip feature must be marked selected.
            selected_ids = {anchor['id'], clip_feature['id']}

        else:
            # Two divisions_area anchors: the first is the container of a
            # sampled containment pair, the second a random division.
            # NOTE(review): the second anchor is NOT taken from the pair, so
            # the two geometries are not guaranteed to intersect — confirm
            # whether the contained entity of `pair` was intended here.
            pair = sample_containment_anchor(tables['containment_pairs'])
            if not pair:
                return None

            anchor1 = {'id': pair['container_id'], 'name': pair['container_name']}
            anchor2_row = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
            if not anchor2_row:
                return None
            anchor2 = anchor2_row

            sql = template.sql_template.format(
                anchor_id_1=anchor1['id'],
                anchor_id_2=anchor2['id'],
            )

            candidates1 = build_candidate_list(
                con, anchor1['id'], anchor1['name'], 'divisions_area',
                num_candidates=5, difficulty="medium"
            )
            candidates2 = build_candidate_list(
                con, anchor2['id'], anchor2['name'], 'divisions_area',
                num_candidates=5, difficulty="medium"
            )
            candidates = _merge_candidate_lists(candidates1, candidates2)

            question = random.choice(template.question_hints).format(
                anchor_1_name=anchor1['name'],
                anchor_2_name=anchor2['name'],
            )
            selected_ids = {anchor1['id'], anchor2['id']}

    elif template.family == "border_corridor":
        # Buffered border zone — needs two anchors that actually touch.
        pair = sample_adjacency_anchor(tables['adjacency_pairs'])
        if not pair:
            return None

        # The adjacency table only records one direction; sample a second
        # anchor that is known to be adjacent to the first.
        anchor1 = {'id': pair['anchor_id'], 'name': pair['anchor_name']}

        # Find a random neighbour of anchor1 from adjacency pairs
        neighbours = tables['adjacency_pairs']
        neighbours = neighbours[neighbours['anchor_id'] == anchor1['id']]
        if neighbours.empty:
            return None
        nb_row = neighbours.sample(n=1).iloc[0]
        # Without target columns the lookup falls back to the anchor itself,
        # which the identity check below then rejects.
        anchor2 = {'id': nb_row.get('target_id', nb_row['anchor_id']), 'name': nb_row.get('target_name', nb_row['anchor_name'])}
        if anchor1['id'] == anchor2['id']:
            return None

        buffer_val = random.choice([5, 10, 25, 50])

        sql = template.sql_template.format(
            anchor_id_1=anchor1['id'],
            anchor_id_2=anchor2['id'],
            buffer_km=buffer_val,
        )

        candidates1 = build_candidate_list(
            con, anchor1['id'], anchor1['name'], 'divisions_area',
            num_candidates=5, difficulty="medium"
        )
        candidates2 = build_candidate_list(
            con, anchor2['id'], anchor2['name'], 'divisions_area',
            num_candidates=5, difficulty="medium"
        )
        candidates = _merge_candidate_lists(candidates1, candidates2)

        question = random.choice(template.question_hints).format(
            anchor_1_name=anchor1['name'],
            anchor_2_name=anchor2['name'],
            buffer_km=buffer_val,
        )
        selected_ids = {anchor1['id'], anchor2['id']}

    elif template.family == "window_function":
        anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
        if not anchor:
            return None

        country = anchor.get('country') or 'US'
        target_subtype = random.choice(['locality', 'neighborhood'])

        sql = template.sql_template.format(
            country=country,
            target_subtype=target_subtype,
        )

        candidates = build_candidate_list(
            con, anchor['id'], anchor['name'], 'divisions_area',
            num_candidates=10, difficulty="hard"
        )

        question = random.choice(template.question_hints).format(
            anchor_name=anchor['name'],
            target_subtype=target_subtype,
        )
        selected_ids = {anchor['id']}

    elif template.family == "attribute_filter":
        anchor = sample_random_entity(con, tables['divisions_area_inventory'], 'divisions_area')
        if not anchor:
            return None

        country = anchor.get('country') or 'US'
        target_subtype = template.target_subtype or random.choice(['dependency', 'region', 'locality'])

        sql = template.sql_template.format(
            country=country,
            target_subtype=target_subtype,
        )

        candidates = build_candidate_list(
            con, anchor['id'], anchor['name'], 'divisions_area',
            num_candidates=10, difficulty="medium"
        )

        question = random.choice(template.question_hints).format(
            anchor_name=anchor['name'],
            target_subtype=target_subtype,
            country=country,
        )
        selected_ids = {anchor['id']}

    else:
        # Skip unsupported families
        return None

    # Execute SQL to verify
    try:
        result = con.execute(_for_execution(sql)).fetchdf()
        if result.empty:
            return None
    except Exception:
        # Errors are tracked in worker return, no need to print
        return None

    # Mark as "selected" exactly the candidates whose entity id was recorded
    # by the branch above. Tracking the ids explicitly (instead of probing
    # local variable names after the fact) keeps every branch correct even
    # when it binds its anchors under different names.
    selected_candidate_ids = [c.candidate_id for c in candidates if c.id in selected_ids]

    return TrainingSample(
        id=sample_id,
        question=question,
        candidates=candidates,
        target={
            "selected_candidates": selected_candidate_ids,
            "sql": sql
        },
        metadata={
            "task_family": template.family,
            "sql_difficulty": template.sql_difficulty,
            "grounding_difficulty": "medium",
            "template_id": template.template_id,
            "num_candidates": len(candidates),
            "anchor_source": template.anchor_source,
            "sql_verified": True
        }
    )
|
| 1204 |
+
|
| 1205 |
+
|
| 1206 |
+
def generate_cross_source_sample(
    con: duckdb.DuckDBPyConnection,
    cross_source_df: pd.DataFrame,
    sample_id: str
) -> Optional[TrainingSample]:
    """Build an "intersect_02" sample: countries intersecting a natural feature.

    Args:
        con: DuckDB connection used to verify the SQL and passed to the
            candidate builder.
        cross_source_df: Cross-source relation rows to sample the anchor from.
        sample_id: Identifier stored on the returned sample.

    Returns:
        A ``TrainingSample``, or ``None`` when no anchor could be sampled or
        the SQL failed or returned no rows.
    """
    relation = sample_cross_source_anchor(cross_source_df)
    if not relation:
        return None

    natural_id = relation['natural_id']
    natural_name = relation['natural_name']

    # Natural feature -> intersecting country divisions.
    sql = f"""WITH a AS (
SELECT geometry FROM read_parquet('natural_earth')
WHERE id = '{natural_id}'
)
SELECT b.id, b.names."primary" AS name, b.geometry
FROM read_parquet('divisions_area') AS b, a
WHERE b.subtype = 'country'
AND ST_Intersects(b.geometry, a.geometry)"""

    # Keep only samples whose query runs and yields at least one row.
    try:
        rows = con.execute(_for_execution(sql)).fetchdf()
    except Exception as e:
        print(f"SQL execution failed: {e}")
        return None
    if rows.empty:
        return None

    # Candidate list is built around the natural feature.
    candidates = build_candidate_list(
        con,
        natural_id,
        natural_name,
        "natural_earth",
        num_candidates=10,
        difficulty="medium"
    )

    # Candidate(s) that correspond to the true anchor.
    selected_candidate_ids = []
    for cand in candidates:
        if cand.id == natural_id:
            selected_candidate_ids.append(cand.candidate_id)

    question = f"Which countries intersect the {natural_name}?"

    target = {
        "selected_candidates": selected_candidate_ids,
        "sql": sql
    }
    metadata = {
        "task_family": "intersection",
        "sql_difficulty": "medium-hard",
        "grounding_difficulty": "medium",
        "template_id": "intersect_02",
        "num_candidates": len(candidates),
        "anchor_source": "natural_earth",
        "sql_verified": True
    }
    return TrainingSample(
        id=sample_id,
        question=question,
        candidates=candidates,
        target=target,
        metadata=metadata
    )
|
| 1270 |
+
|
| 1271 |
+
|
| 1272 |
+
def generate_sample_batch_worker(args):
    """Process a batch of work items with a single DuckDB connection.

    DuckDB, the spatial extension and the relation tables are initialised
    ONCE per batch; the work items are then processed sequentially.

    Args:
        args: ``(work_items, intermediate_dir_str)`` tuple, packed into one
            argument so the function can be submitted to an executor.
            ``work_items`` is a list of
            ``(family, template_dict, sample_id, _)`` tuples and
            ``intermediate_dir_str`` is the directory handed to
            ``load_relation_tables``.

    Returns:
        One ``(sample, family, template_id, error)`` tuple per work item.
        ``sample`` is ``None`` and ``error`` is a message when the item
        produced nothing or raised.
    """
    from pathlib import Path

    work_items, intermediate_dir_str = args

    # Convert string back to Path
    intermediate_dir = Path(intermediate_dir_str)

    # Initialize DuckDB ONCE for the entire batch
    con = duckdb.connect()
    try:
        con.execute("SET enable_progress_bar=false")
        con.execute("INSTALL spatial")
        con.execute("LOAD spatial")

        # Load relation tables ONCE
        tables = load_relation_tables(intermediate_dir, quiet=True)

        results = []
        for family, template_dict, sample_id, _ in work_items:
            try:
                # Rebuilding the template happens inside the try so that a
                # malformed dict fails this item only, not the whole batch.
                template = sql_templates.SQLTemplate(**template_dict)
                sample = generate_template_based_sample(con, template, tables, sample_id)
            except Exception as e:
                results.append((None, family, template_dict.get('template_id', 'unknown'), str(e)))
                continue

            if sample:
                results.append((sample, family, template.template_id, None))
            else:
                results.append((None, family, template.template_id, "Empty result"))

        return results
    finally:
        # Release the connection even when setup or table loading fails.
        con.close()
|
| 1310 |
+
|
| 1311 |
+
|
| 1312 |
+
def generate_batch_core(
    work_items: List[tuple],
    intermediate_dir: str,
) -> List[Dict[str, Any]]:
    """Standalone batch worker usable from Modal or any remote context.

    Data paths are resolved via GAZET_DATA_DIR env var (set in Modal image).

    Args:
        work_items: List of (family, template_dict, sample_id, _) tuples
        intermediate_dir: Path to intermediate dir with relation parquets

    Returns:
        List of dicts with keys: sample (dict or None), family, template_id, error
    """
    from pathlib import Path as _Path

    def _record(sample, family, template_id, error):
        # Single place that defines the shape of a result entry.
        return {
            "sample": sample,
            "family": family,
            "template_id": template_id,
            "error": error,
        }

    intermediate = _Path(intermediate_dir)

    con = duckdb.connect()
    try:
        con.execute("SET enable_progress_bar=false")
        con.execute("INSTALL spatial")
        con.execute("LOAD spatial")

        tables = load_relation_tables(intermediate, quiet=True)

        results = []
        for family, template_dict, sample_id, _ in work_items:
            try:
                # Rebuilding the template happens inside the try so that a
                # malformed dict fails this item only, not the whole batch.
                template = sql_templates.SQLTemplate(**template_dict)
                sample = generate_template_based_sample(con, template, tables, sample_id)
            except Exception as e:
                results.append(
                    _record(None, family, template_dict.get('template_id', 'unknown'), str(e))
                )
                continue

            if sample:
                results.append(_record(sample.model_dump(), family, template.template_id, None))
            else:
                results.append(_record(None, family, template.template_id, "Empty result"))

        return results
    finally:
        # Release the connection even when setup or table loading fails.
        con.close()
|
| 1366 |
+
|
| 1367 |
+
|
| 1368 |
+
def prepare_work_items(
    target_counts: Dict[str, int],
    retry_multiplier: int = 2,
    start_counter: int = 1,
    intermediate_dir_str: str = "",
) -> List[tuple]:
    """Prepare shuffled work items for sample generation.

    Reusable by both local main() and Modal orchestrator.

    Args:
        target_counts: Desired number of samples per template family.
            Families with a non-positive count are skipped.
        retry_multiplier: Each family gets ``target_count * retry_multiplier``
            work items.
        start_counter: Number used for the first generated sample id.
        intermediate_dir_str: Carried verbatim as the fourth tuple element.

    Returns:
        Shuffled list of
        (family, template_dict, sample_id, intermediate_dir_str) tuples.
        ``template_dict`` holds every field of the chosen template so it can
        be rebuilt with ``SQLTemplate(**template_dict)``.
    """
    # Imported locally so this function does not depend on the file's
    # top-level import block.
    from dataclasses import asdict

    work_items = []
    sample_counter = start_counter

    for family, target_count in target_counts.items():
        if target_count <= 0:
            continue

        family_templates = [t for t in TEMPLATES if t.family == family]
        if not family_templates:
            print(f"No templates found for {family}, skipping...")
            continue

        for _ in range(target_count * retry_multiplier):
            template = random.choice(family_templates)
            # asdict() copies all dataclass fields, so fields added to
            # SQLTemplate later are carried along automatically.
            work_items.append((
                family,
                asdict(template),
                f"sample_{sample_counter:06d}",
                intermediate_dir_str,
            ))
            sample_counter += 1

    random.shuffle(work_items)
    return work_items
|
| 1415 |
+
|
| 1416 |
+
|
| 1417 |
+
def main():
    """Generate training samples and write them to ``output/dataset_raw.jsonl``.

    Reads the module-level settings TARGET_COUNTS, MAX_WORKERS,
    RETRY_MULTIPLIER and APPEND_MODE. Work items are split into one batch per
    worker process; each batch is handled by generate_sample_batch_worker.
    In append mode, existing samples are kept and new ids continue after the
    highest existing ``sample_<n>`` id; otherwise the file is overwritten.
    """
    # Setup paths
    script_dir = Path(__file__).parent
    intermediate_dir = script_dir.parent / "intermediate"
    output_dir = script_dir.parent / "output"

    output_dir.mkdir(exist_ok=True, parents=True)

    # Load relation tables once to check availability. The result is not
    # needed here: every worker loads its own copy.
    print("Loading relation tables...")
    load_relation_tables(intermediate_dir, quiet=False)

    # Use configured target counts or defaults
    if TARGET_COUNTS is None:
        target_counts = {
            'direct_lookup': 100,
            'adjacency': 150,
            'multi_adjacency': 75,
            'containment': 100,
            'intersection': 100,
            'buffer': 100,
            'chained': 150,
            'difference': 75,
            'border_corridor': 75,
            'set_operations': 150,
            'partial_selection': 75,
            'aggregation': 100,
            'window_function': 75,
            'attribute_filter': 75,
        }
    else:
        target_counts = TARGET_COUNTS

    # Load existing samples if in append mode
    existing_samples = []
    jsonl_file = output_dir / "dataset_raw.jsonl"
    sample_counter = 1

    if APPEND_MODE and jsonl_file.exists():
        print(f"\nAppend mode: Loading existing samples from {jsonl_file}")
        with open(jsonl_file, 'r', encoding="utf-8") as f:
            existing_samples = [json.loads(line) for line in f if line.strip()]
        print(f" Found {len(existing_samples)} existing samples")

        # Continue numbering after the highest existing sample id
        max_existing_id = max(
            (
                int(s['id'].split('_')[1])
                for s in existing_samples
                if s['id'].startswith('sample_')
            ),
            default=0,
        )
        sample_counter = max_existing_id + 1

    appending = bool(APPEND_MODE and existing_samples)

    # Prepare work items using shared helper
    work_items = prepare_work_items(
        target_counts=target_counts,
        retry_multiplier=RETRY_MULTIPLIER,
        start_counter=sample_counter,
        intermediate_dir_str=str(intermediate_dir),
    )

    # Partition work items into batches (one per worker)
    num_workers = min(MAX_WORKERS, len(work_items))
    if num_workers == 0:
        print("No work items to process")
        return
    batch_size = (len(work_items) + num_workers - 1) // num_workers  # ceil division
    batches = [
        (work_items[i:i + batch_size], str(intermediate_dir))
        for i in range(0, len(work_items), batch_size)
    ]

    # Generate samples in parallel (one batch per worker)
    active_families = sum(1 for count in target_counts.values() if count > 0)
    print(f"\nGenerating {len(work_items)} samples across {active_families} families...")
    print(f" Split into {len(batches)} batches of ~{batch_size} items (1 DuckDB init per batch)")
    if appending:
        # Same zero padding as the ids produced by prepare_work_items
        print(f"Appending: starting from sample_{sample_counter:06d}")

    all_samples = []
    family_progress = {
        family: {'success': 0, 'failed': 0}
        for family, count in target_counts.items()
        if count > 0
    }

    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        # Submit one batch per worker
        futures = {
            executor.submit(generate_sample_batch_worker, batch): i
            for i, batch in enumerate(batches)
        }

        # Collect results as batches complete
        batches_done = 0
        for future in as_completed(futures):
            try:
                batch_results = future.result()
            except Exception as e:
                print(f"\n Batch failed: {e}")
                # Count every item of the lost batch as failed so the
                # per-family statistics and the progress total stay accurate.
                for family, *_ in batches[futures[future]][0]:
                    family_progress[family]['failed'] += 1
            else:
                for sample, family, _template_id, _error in batch_results:
                    if sample:
                        all_samples.append(sample)
                        family_progress[family]['success'] += 1
                    else:
                        family_progress[family]['failed'] += 1

            batches_done += 1
            total_done = sum(p['success'] + p['failed'] for p in family_progress.values())
            print(f"\r Progress: {total_done}/{len(work_items)} samples ({batches_done}/{len(batches)} batches) ", end='', flush=True)

    print()  # New line after progress

    # Show distribution (keep all samples, no filtering)
    print("\nResults by family:")
    for family in sorted(family_progress):
        success = family_progress[family]['success']
        failed = family_progress[family]['failed']
        target = target_counts.get(family, 0)
        total = success + failed
        success_rate = (success / total * 100) if total > 0 else 0
        print(f" {family:20s}: {success:3d} success / {failed:3d} failed ({success_rate:5.1f}% success rate, target: {target})")

    # Save combined JSONL (skip individual JSON files for speed at scale)
    print(f"\nSaving {len(all_samples)} new samples...")
    if appending:
        print(f"Appending to existing dataset ({len(existing_samples)} existing samples)")
    # Append to the existing dataset, or overwrite it when not appending
    with open(jsonl_file, 'a' if appending else 'w', encoding="utf-8") as f:
        f.writelines(json.dumps(sample.model_dump()) + '\n' for sample in all_samples)
    # existing_samples is empty whenever the file was overwritten
    total_samples = len(existing_samples) + len(all_samples)

    print(f"\nGenerated {len(all_samples)} new samples")
    print(f"Total dataset size: {total_samples} samples")
    print(f" Dataset: {jsonl_file}")
|
| 1557 |
+
|
| 1558 |
+
|
| 1559 |
+
if __name__ == "__main__":
|
| 1560 |
+
main()
|
dataset/scripts/sql_templates.py
ADDED
|
@@ -0,0 +1,1651 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SQL template definitions for synthetic data generation.
|
| 3 |
+
|
| 4 |
+
Geometry output convention
|
| 5 |
+
--------------------------
|
| 6 |
+
Every final SELECT wraps geometry with ST_AsGeoJSON():
|
| 7 |
+
ST_AsGeoJSON(geometry) AS geometry
|
| 8 |
+
This returns a GeoJSON string instead of raw WKB bytes, which is directly
|
| 9 |
+
JSON-serialisable and matches what the serving stack expects.
|
| 10 |
+
|
| 11 |
+
CTEs that compute intermediate geometries (used only for spatial predicates
|
| 12 |
+
or ST_Area) keep the column as raw GEOMETRY so DuckDB spatial functions work.
|
| 13 |
+
|
| 14 |
+
Buffer distance convention
|
| 15 |
+
--------------------------
|
| 16 |
+
All buffer templates use {buffer_km} or {buffer_m} (never degrees).
|
| 17 |
+
SQL converts to degrees: metres / 111_320.
|
| 18 |
+
|
| 19 |
+
Mixed-source candidates
|
| 20 |
+
-----------------------
|
| 21 |
+
generate_samples.py pads every candidate list with 50 % cross-source
|
| 22 |
+
distractors so the model always sees both source values and learns the
|
| 23 |
+
correct parquet path from the candidates table.
|
| 24 |
+
|
| 25 |
+
Template families
|
| 26 |
+
-----------------
|
| 27 |
+
direct_lookup Simple single-feature fetch by ID.
|
| 28 |
+
adjacency ST_Touches — features sharing a border.
|
| 29 |
+
multi_adjacency Features that simultaneously touch TWO anchors.
|
| 30 |
+
containment ST_Within / ST_Contains — hierarchical nesting.
|
| 31 |
+
intersection ST_Intersects — overlapping or crossing features.
|
| 32 |
+
buffer ST_Buffer — proximity zones in km or metres.
|
| 33 |
+
chained Containment + EXISTS/NOT EXISTS sea predicate.
|
| 34 |
+
difference ST_Difference — geometry subtraction.
|
| 35 |
+
border_corridor Buffered ST_Intersection of a shared border.
|
| 36 |
+
set_operations ST_Union_Agg — merging multiple geometries.
|
| 37 |
+
partial_selection Bbox clipping — directional halves or feature clips.
|
| 38 |
+
aggregation TOP-N by area with ORDER BY.
|
| 39 |
+
window_function ROW_NUMBER() OVER (PARTITION BY) — per-group ranking.
|
| 40 |
+
attribute_filter Pure attribute predicates: is_land, country, etc.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
from dataclasses import dataclass
from typing import List, Literal, Optional
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
class SQLTemplate:
    """SQL template for synthetic data generation."""

    # Unique identifier of the template, e.g. "lookup_01".
    template_id: str
    # Template family name, e.g. "direct_lookup" or "adjacency".
    family: str
    sql_difficulty: Literal["easy", "medium", "medium-hard", "hard"]
    # Dataset the anchor feature(s) are read from.
    anchor_source: Literal["divisions_area", "natural_earth", "mixed"]
    # Number of anchor features the template references.
    num_anchors: int
    # SQL text with {placeholder} fields such as {anchor_id}.
    sql_template: str
    # Question phrasings with {placeholder} fields such as {anchor_name}.
    question_hints: List[str]
    # Subtype of the returned features; None when the template sets none.
    target_subtype: Optional[str] = None
    requires_buffer: bool = False
    requires_aggregation: bool = False
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ---------------------------------------------------------------------------
|
| 64 |
+
# Template catalog
|
| 65 |
+
# ---------------------------------------------------------------------------
|
| 66 |
+
|
| 67 |
+
TEMPLATES = [
|
| 68 |
+
|
| 69 |
+
# ── DIRECT LOOKUP ────────────────────────────────────────────────────────
|
| 70 |
+
|
| 71 |
+
SQLTemplate(
|
| 72 |
+
template_id="lookup_01",
|
| 73 |
+
family="direct_lookup",
|
| 74 |
+
sql_difficulty="easy",
|
| 75 |
+
anchor_source="divisions_area",
|
| 76 |
+
num_anchors=1,
|
| 77 |
+
sql_template=(
|
| 78 |
+
"SELECT ST_AsGeoJSON(geometry) AS geometry,"
|
| 79 |
+
" names.\"primary\" AS name, id, subtype, country"
|
| 80 |
+
" FROM read_parquet('divisions_area')"
|
| 81 |
+
" WHERE id = '{anchor_id}'"
|
| 82 |
+
),
|
| 83 |
+
question_hints=[
|
| 84 |
+
"Show me {anchor_name}",
|
| 85 |
+
"Get the boundary of {anchor_name}",
|
| 86 |
+
"Find {anchor_name}",
|
| 87 |
+
"Where is {anchor_name}?",
|
| 88 |
+
"Give me the outline of {anchor_name}",
|
| 89 |
+
"Display {anchor_name} on a map",
|
| 90 |
+
"What does {anchor_name} look like?",
|
| 91 |
+
"I need the shape of {anchor_name}",
|
| 92 |
+
"Pull up {anchor_name}",
|
| 93 |
+
"Can you show {anchor_name}?",
|
| 94 |
+
"Map of {anchor_name}",
|
| 95 |
+
"{anchor_name} boundary",
|
| 96 |
+
"Locate {anchor_name} for me",
|
| 97 |
+
],
|
| 98 |
+
),
|
| 99 |
+
|
| 100 |
+
SQLTemplate(
|
| 101 |
+
template_id="lookup_02",
|
| 102 |
+
family="direct_lookup",
|
| 103 |
+
sql_difficulty="easy",
|
| 104 |
+
anchor_source="natural_earth",
|
| 105 |
+
num_anchors=1,
|
| 106 |
+
sql_template=(
|
| 107 |
+
"SELECT ST_AsGeoJSON(geometry) AS geometry,"
|
| 108 |
+
" names.\"primary\" AS name, id, subtype"
|
| 109 |
+
" FROM read_parquet('natural_earth')"
|
| 110 |
+
" WHERE id = '{anchor_id}'"
|
| 111 |
+
),
|
| 112 |
+
question_hints=[
|
| 113 |
+
"Show me the {anchor_name}",
|
| 114 |
+
"Get {anchor_name}",
|
| 115 |
+
"Find the {anchor_name}",
|
| 116 |
+
"Where is the {anchor_name}?",
|
| 117 |
+
"Show the extent of the {anchor_name}",
|
| 118 |
+
"Give me the geometry of the {anchor_name}",
|
| 119 |
+
"Display the {anchor_name}",
|
| 120 |
+
"Pull up the {anchor_name}",
|
| 121 |
+
"I want to see the {anchor_name}",
|
| 122 |
+
"Map the {anchor_name}",
|
| 123 |
+
"How big is the {anchor_name}?",
|
| 124 |
+
"Outline of the {anchor_name}",
|
| 125 |
+
],
|
| 126 |
+
),
|
| 127 |
+
|
| 128 |
+
# ── ADJACENCY ────────────────────────────────────────────────────────────
|
| 129 |
+
|
| 130 |
+
SQLTemplate(
|
| 131 |
+
template_id="adj_01",
|
| 132 |
+
family="adjacency",
|
| 133 |
+
sql_difficulty="medium",
|
| 134 |
+
anchor_source="divisions_area",
|
| 135 |
+
num_anchors=1,
|
| 136 |
+
sql_template=(
|
| 137 |
+
"WITH a AS ("
|
| 138 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 139 |
+
")"
|
| 140 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 141 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 142 |
+
" FROM read_parquet('divisions_area') AS b, a"
|
| 143 |
+
" WHERE b.id != '{anchor_id}'"
|
| 144 |
+
" AND ST_Touches(a.geometry, b.geometry)"
|
| 145 |
+
),
|
| 146 |
+
question_hints=[
|
| 147 |
+
"Which regions border {anchor_name}?",
|
| 148 |
+
"What administrative units touch {anchor_name}?",
|
| 149 |
+
"List all places adjacent to {anchor_name}",
|
| 150 |
+
"What shares a border with {anchor_name}?",
|
| 151 |
+
"Neighbours of {anchor_name}",
|
| 152 |
+
"What is adjacent to {anchor_name}?",
|
| 153 |
+
"What surrounds {anchor_name}?",
|
| 154 |
+
"Places next to {anchor_name}",
|
| 155 |
+
"Everything bordering {anchor_name}",
|
| 156 |
+
],
|
| 157 |
+
),
|
| 158 |
+
|
| 159 |
+
SQLTemplate(
|
| 160 |
+
template_id="adj_02",
|
| 161 |
+
family="adjacency",
|
| 162 |
+
sql_difficulty="medium",
|
| 163 |
+
anchor_source="divisions_area",
|
| 164 |
+
num_anchors=1,
|
| 165 |
+
target_subtype="region",
|
| 166 |
+
sql_template=(
|
| 167 |
+
"WITH a AS ("
|
| 168 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 169 |
+
")"
|
| 170 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 171 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 172 |
+
" FROM read_parquet('divisions_area') AS b, a"
|
| 173 |
+
" WHERE b.id != '{anchor_id}'"
|
| 174 |
+
" AND b.subtype = '{target_subtype}'"
|
| 175 |
+
" AND ST_Touches(a.geometry, b.geometry)"
|
| 176 |
+
),
|
| 177 |
+
question_hints=[
|
| 178 |
+
"Which {target_subtype}s border {anchor_name}?",
|
| 179 |
+
"What {target_subtype}s share a border with {anchor_name}?",
|
| 180 |
+
"{target_subtype}s that touch {anchor_name}",
|
| 181 |
+
"Neighbouring {target_subtype}s of {anchor_name}",
|
| 182 |
+
"Which {target_subtype}s are adjacent to {anchor_name}?",
|
| 183 |
+
"{target_subtype}s along the {anchor_name} border",
|
| 184 |
+
"Find {target_subtype}s next to {anchor_name}",
|
| 185 |
+
],
|
| 186 |
+
),
|
| 187 |
+
|
| 188 |
+
SQLTemplate(
|
| 189 |
+
template_id="adj_03",
|
| 190 |
+
family="adjacency",
|
| 191 |
+
sql_difficulty="medium",
|
| 192 |
+
anchor_source="divisions_area",
|
| 193 |
+
num_anchors=1,
|
| 194 |
+
target_subtype="sea",
|
| 195 |
+
sql_template=(
|
| 196 |
+
"WITH a AS ("
|
| 197 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 198 |
+
")"
|
| 199 |
+
" SELECT n.id, n.names.\"primary\" AS name, n.subtype,"
|
| 200 |
+
" ST_AsGeoJSON(n.geometry) AS geometry"
|
| 201 |
+
" FROM read_parquet('natural_earth') AS n, a"
|
| 202 |
+
" WHERE n.subtype IN ('ocean', 'sea')"
|
| 203 |
+
" AND ST_Touches(a.geometry, n.geometry)"
|
| 204 |
+
),
|
| 205 |
+
question_hints=[
|
| 206 |
+
"Which seas touch {anchor_name}?",
|
| 207 |
+
"What seas border {anchor_name}?",
|
| 208 |
+
"Which bodies of water is {anchor_name} adjacent to?",
|
| 209 |
+
"What ocean or sea borders {anchor_name}?",
|
| 210 |
+
"Which oceans touch {anchor_name}?",
|
| 211 |
+
"What coastline does {anchor_name} have?",
|
| 212 |
+
"Which water bodies does {anchor_name} border?",
|
| 213 |
+
"Does {anchor_name} have access to the sea?",
|
| 214 |
+
"What ocean is {anchor_name} on?",
|
| 215 |
+
],
|
| 216 |
+
),
|
| 217 |
+
|
| 218 |
+
# ── MULTI-ADJACENCY ──────────────────────────────────────────────────────
|
| 219 |
+
|
| 220 |
+
SQLTemplate(
|
| 221 |
+
template_id="multi_adj_01",
|
| 222 |
+
family="multi_adjacency",
|
| 223 |
+
sql_difficulty="hard",
|
| 224 |
+
anchor_source="divisions_area",
|
| 225 |
+
num_anchors=2,
|
| 226 |
+
sql_template=(
|
| 227 |
+
"WITH a AS ("
|
| 228 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id_1}'"
|
| 229 |
+
"),"
|
| 230 |
+
" b AS ("
|
| 231 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id_2}'"
|
| 232 |
+
")"
|
| 233 |
+
" SELECT c.id, c.names.\"primary\" AS name, c.subtype, c.country,"
|
| 234 |
+
" ST_AsGeoJSON(c.geometry) AS geometry"
|
| 235 |
+
" FROM read_parquet('divisions_area') AS c, a, b"
|
| 236 |
+
" WHERE c.id NOT IN ('{anchor_id_1}', '{anchor_id_2}')"
|
| 237 |
+
" AND ST_Touches(c.geometry, a.geometry)"
|
| 238 |
+
" AND ST_Touches(c.geometry, b.geometry)"
|
| 239 |
+
),
|
| 240 |
+
question_hints=[
|
| 241 |
+
"Which regions border both {anchor_1_name} and {anchor_2_name}?",
|
| 242 |
+
"What places touch both {anchor_1_name} and {anchor_2_name}?",
|
| 243 |
+
"Regions adjacent to both {anchor_1_name} and {anchor_2_name}",
|
| 244 |
+
"What lies between {anchor_1_name} and {anchor_2_name}?",
|
| 245 |
+
"Common neighbours of {anchor_1_name} and {anchor_2_name}",
|
| 246 |
+
],
|
| 247 |
+
),
|
| 248 |
+
|
| 249 |
+
# ── CONTAINMENT ──────────────────────────────────────────────────────────
|
| 250 |
+
|
| 251 |
+
SQLTemplate(
|
| 252 |
+
template_id="contain_01",
|
| 253 |
+
family="containment",
|
| 254 |
+
sql_difficulty="medium",
|
| 255 |
+
anchor_source="divisions_area",
|
| 256 |
+
num_anchors=1,
|
| 257 |
+
target_subtype="locality",
|
| 258 |
+
sql_template=(
|
| 259 |
+
"WITH a AS ("
|
| 260 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 261 |
+
")"
|
| 262 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 263 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 264 |
+
" FROM read_parquet('divisions_area') AS b, a"
|
| 265 |
+
" WHERE b.id != '{anchor_id}'"
|
| 266 |
+
" AND b.subtype = '{target_subtype}'"
|
| 267 |
+
" AND ST_Within(b.geometry, a.geometry)"
|
| 268 |
+
),
|
| 269 |
+
question_hints=[
|
| 270 |
+
"What {target_subtype}s are in {anchor_name}?",
|
| 271 |
+
"Which {target_subtype}s fall within {anchor_name}?",
|
| 272 |
+
"List all {target_subtype}s inside {anchor_name}",
|
| 273 |
+
"{target_subtype}s contained by {anchor_name}",
|
| 274 |
+
"All {target_subtype}s within the boundaries of {anchor_name}",
|
| 275 |
+
"{target_subtype}s of {anchor_name}",
|
| 276 |
+
"Show every {target_subtype} in {anchor_name}",
|
| 277 |
+
],
|
| 278 |
+
),
|
| 279 |
+
|
| 280 |
+
SQLTemplate(
|
| 281 |
+
template_id="contain_02",
|
| 282 |
+
family="containment",
|
| 283 |
+
sql_difficulty="medium",
|
| 284 |
+
anchor_source="divisions_area",
|
| 285 |
+
num_anchors=1,
|
| 286 |
+
target_subtype="country",
|
| 287 |
+
sql_template=(
|
| 288 |
+
"WITH a AS ("
|
| 289 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 290 |
+
")"
|
| 291 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype,"
|
| 292 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 293 |
+
" FROM read_parquet('divisions_area') AS b, a"
|
| 294 |
+
" WHERE b.id != '{anchor_id}'"
|
| 295 |
+
" AND b.subtype = '{target_subtype}'"
|
| 296 |
+
" AND ST_Contains(b.geometry, a.geometry)"
|
| 297 |
+
),
|
| 298 |
+
question_hints=[
|
| 299 |
+
"What country contains {anchor_name}?",
|
| 300 |
+
"Which country is {anchor_name} in?",
|
| 301 |
+
"What country does {anchor_name} belong to?",
|
| 302 |
+
"Which nation contains {anchor_name}?",
|
| 303 |
+
"{anchor_name} is part of which country?",
|
| 304 |
+
"Where does {anchor_name} fall geographically?",
|
| 305 |
+
"What country is {anchor_name} located in?",
|
| 306 |
+
],
|
| 307 |
+
),
|
| 308 |
+
|
| 309 |
+
SQLTemplate(
|
| 310 |
+
template_id="contain_03",
|
| 311 |
+
family="containment",
|
| 312 |
+
sql_difficulty="medium",
|
| 313 |
+
anchor_source="natural_earth",
|
| 314 |
+
num_anchors=1,
|
| 315 |
+
target_subtype="region",
|
| 316 |
+
sql_template=(
|
| 317 |
+
"WITH a AS ("
|
| 318 |
+
" SELECT geometry FROM read_parquet('natural_earth') WHERE id = '{anchor_id}'"
|
| 319 |
+
")"
|
| 320 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 321 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 322 |
+
" FROM read_parquet('divisions_area') AS b, a"
|
| 323 |
+
" WHERE b.subtype = '{target_subtype}'"
|
| 324 |
+
" AND ST_Within(b.geometry, a.geometry)"
|
| 325 |
+
),
|
| 326 |
+
question_hints=[
|
| 327 |
+
"Which {target_subtype}s are in the {anchor_name}?",
|
| 328 |
+
"What {target_subtype}s fall within the {anchor_name}?",
|
| 329 |
+
"{target_subtype}s inside the {anchor_name}",
|
| 330 |
+
"Administrative {target_subtype}s within the {anchor_name}",
|
| 331 |
+
"All regions contained by the {anchor_name}",
|
| 332 |
+
"What {target_subtype}s does the {anchor_name} contain?",
|
| 333 |
+
"{target_subtype}s covered by the {anchor_name}",
|
| 334 |
+
],
|
| 335 |
+
),
|
| 336 |
+
|
| 337 |
+
# ── INTERSECTION ─────────────────────────────────────────────────────────
|
| 338 |
+
|
| 339 |
+
SQLTemplate(
|
| 340 |
+
template_id="intersect_01",
|
| 341 |
+
family="intersection",
|
| 342 |
+
sql_difficulty="medium-hard",
|
| 343 |
+
anchor_source="divisions_area",
|
| 344 |
+
num_anchors=1,
|
| 345 |
+
target_subtype="region",
|
| 346 |
+
sql_template=(
|
| 347 |
+
"WITH a AS ("
|
| 348 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 349 |
+
")"
|
| 350 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 351 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 352 |
+
" FROM read_parquet('divisions_area') AS b, a"
|
| 353 |
+
" WHERE b.id != '{anchor_id}'"
|
| 354 |
+
" AND b.subtype = '{target_subtype}'"
|
| 355 |
+
" AND ST_Intersects(b.geometry, a.geometry)"
|
| 356 |
+
),
|
| 357 |
+
question_hints=[
|
| 358 |
+
"Which {target_subtype}s intersect {anchor_name}?",
|
| 359 |
+
"What {target_subtype}s overlap with {anchor_name}?",
|
| 360 |
+
"{target_subtype}s that cross into {anchor_name}",
|
| 361 |
+
"Which {target_subtype}s overlap {anchor_name}?",
|
| 362 |
+
"{target_subtype}s partially inside {anchor_name}",
|
| 363 |
+
"What {target_subtype}s extend into {anchor_name}?",
|
| 364 |
+
],
|
| 365 |
+
),
|
| 366 |
+
|
| 367 |
+
SQLTemplate(
|
| 368 |
+
template_id="intersect_02",
|
| 369 |
+
family="intersection",
|
| 370 |
+
sql_difficulty="medium-hard",
|
| 371 |
+
anchor_source="natural_earth",
|
| 372 |
+
num_anchors=1,
|
| 373 |
+
target_subtype="country",
|
| 374 |
+
sql_template=(
|
| 375 |
+
"WITH a AS ("
|
| 376 |
+
" SELECT geometry FROM read_parquet('natural_earth') WHERE id = '{anchor_id}'"
|
| 377 |
+
")"
|
| 378 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype,"
|
| 379 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 380 |
+
" FROM read_parquet('divisions_area') AS b, a"
|
| 381 |
+
" WHERE b.subtype = '{target_subtype}'"
|
| 382 |
+
" AND ST_Intersects(b.geometry, a.geometry)"
|
| 383 |
+
),
|
| 384 |
+
question_hints=[
|
| 385 |
+
"Which countries intersect the {anchor_name}?",
|
| 386 |
+
"What countries does the {anchor_name} pass through?",
|
| 387 |
+
"Countries that overlap with the {anchor_name}",
|
| 388 |
+
"Which countries touch the {anchor_name}?",
|
| 389 |
+
"Nations intersected by the {anchor_name}",
|
| 390 |
+
"Which nations does the {anchor_name} cross?",
|
| 391 |
+
"Countries along the {anchor_name}",
|
| 392 |
+
"What countries does the {anchor_name} cover?",
|
| 393 |
+
"Countries that the {anchor_name} spans across",
|
| 394 |
+
],
|
| 395 |
+
),
|
| 396 |
+
|
| 397 |
+
# ── BUFFER ───────────────────────────────────────────────────────────────
|
| 398 |
+
# CTE computes the buffered geometry (raw) for the spatial join.
|
| 399 |
+
# Final SELECT wraps the result features with ST_AsGeoJSON.
|
| 400 |
+
|
| 401 |
+
SQLTemplate(
|
| 402 |
+
template_id="buffer_01",
|
| 403 |
+
family="buffer",
|
| 404 |
+
sql_difficulty="hard",
|
| 405 |
+
anchor_source="divisions_area",
|
| 406 |
+
num_anchors=1,
|
| 407 |
+
requires_buffer=True,
|
| 408 |
+
sql_template=(
|
| 409 |
+
"WITH a AS ("
|
| 410 |
+
" SELECT ST_Buffer(geometry, {buffer_km} * 1000.0 / 111320.0) AS geom"
|
| 411 |
+
" FROM read_parquet('divisions_area')"
|
| 412 |
+
" WHERE id = '{anchor_id}'"
|
| 413 |
+
")"
|
| 414 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 415 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 416 |
+
" FROM read_parquet('divisions_area') AS b, a"
|
| 417 |
+
" WHERE b.id != '{anchor_id}'"
|
| 418 |
+
" AND ST_Intersects(b.geometry, a.geom)"
|
| 419 |
+
),
|
| 420 |
+
question_hints=[
|
| 421 |
+
"What is within {buffer_km} km of {anchor_name}?",
|
| 422 |
+
"Administrative units within {buffer_km} km of {anchor_name}",
|
| 423 |
+
"Features within a {buffer_km} km radius of {anchor_name}",
|
| 424 |
+
"Places within {buffer_km} kilometers of {anchor_name}",
|
| 425 |
+
"{buffer_km} km buffer around {anchor_name}",
|
| 426 |
+
"What falls within {buffer_km} km of {anchor_name}?",
|
| 427 |
+
"Everything within {buffer_km} km of {anchor_name}",
|
| 428 |
+
],
|
| 429 |
+
),
|
| 430 |
+
|
| 431 |
+
SQLTemplate(
|
| 432 |
+
template_id="buffer_02",
|
| 433 |
+
family="buffer",
|
| 434 |
+
sql_difficulty="hard",
|
| 435 |
+
anchor_source="divisions_area",
|
| 436 |
+
num_anchors=1,
|
| 437 |
+
requires_buffer=True,
|
| 438 |
+
sql_template=(
|
| 439 |
+
"WITH a AS ("
|
| 440 |
+
" SELECT ST_Buffer(geometry, {buffer_m} / 111320.0) AS geom"
|
| 441 |
+
" FROM read_parquet('divisions_area')"
|
| 442 |
+
" WHERE id = '{anchor_id}'"
|
| 443 |
+
")"
|
| 444 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 445 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 446 |
+
" FROM read_parquet('divisions_area') AS b, a"
|
| 447 |
+
" WHERE b.id != '{anchor_id}'"
|
| 448 |
+
" AND ST_Intersects(b.geometry, a.geom)"
|
| 449 |
+
),
|
| 450 |
+
question_hints=[
|
| 451 |
+
"What is within {buffer_m} meters of {anchor_name}?",
|
| 452 |
+
"Features within {buffer_m} m of {anchor_name}",
|
| 453 |
+
"Places within {buffer_m} metres of {anchor_name}",
|
| 454 |
+
"{buffer_m} meter buffer around {anchor_name}",
|
| 455 |
+
"What falls within {buffer_m} m of {anchor_name}?",
|
| 456 |
+
"Administrative units within {buffer_m} metres of {anchor_name}",
|
| 457 |
+
],
|
| 458 |
+
),
|
| 459 |
+
|
| 460 |
+
SQLTemplate(
|
| 461 |
+
template_id="buffer_03",
|
| 462 |
+
family="buffer",
|
| 463 |
+
sql_difficulty="hard",
|
| 464 |
+
anchor_source="natural_earth",
|
| 465 |
+
num_anchors=1,
|
| 466 |
+
requires_buffer=True,
|
| 467 |
+
sql_template=(
|
| 468 |
+
"WITH a AS ("
|
| 469 |
+
" SELECT ST_Buffer(geometry, {buffer_km} * 1000.0 / 111320.0) AS geom"
|
| 470 |
+
" FROM read_parquet('natural_earth')"
|
| 471 |
+
" WHERE id = '{anchor_id}'"
|
| 472 |
+
")"
|
| 473 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 474 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 475 |
+
" FROM read_parquet('divisions_area') AS b, a"
|
| 476 |
+
" WHERE ST_Intersects(b.geometry, a.geom)"
|
| 477 |
+
),
|
| 478 |
+
question_hints=[
|
| 479 |
+
"What administrative units are within {buffer_km} km of the {anchor_name}?",
|
| 480 |
+
"Countries within {buffer_km} km of the {anchor_name}",
|
| 481 |
+
"Regions within {buffer_km} km of the {anchor_name}",
|
| 482 |
+
"What falls within {buffer_km} km of the {anchor_name}?",
|
| 483 |
+
"Administrative divisions within a {buffer_km} km radius of the {anchor_name}",
|
| 484 |
+
"Places within {buffer_km} kilometers of the {anchor_name}",
|
| 485 |
+
],
|
| 486 |
+
),
|
| 487 |
+
|
| 488 |
+
SQLTemplate(
|
| 489 |
+
template_id="buffer_04",
|
| 490 |
+
family="buffer",
|
| 491 |
+
sql_difficulty="hard",
|
| 492 |
+
anchor_source="natural_earth",
|
| 493 |
+
num_anchors=1,
|
| 494 |
+
requires_buffer=True,
|
| 495 |
+
sql_template=(
|
| 496 |
+
"WITH a AS ("
|
| 497 |
+
" SELECT ST_Buffer(geometry, {buffer_m} / 111320.0) AS geom"
|
| 498 |
+
" FROM read_parquet('natural_earth')"
|
| 499 |
+
" WHERE id = '{anchor_id}'"
|
| 500 |
+
")"
|
| 501 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 502 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 503 |
+
" FROM read_parquet('divisions_area') AS b, a"
|
| 504 |
+
" WHERE ST_Intersects(b.geometry, a.geom)"
|
| 505 |
+
),
|
| 506 |
+
question_hints=[
|
| 507 |
+
"What is within {buffer_m} meters of the {anchor_name}?",
|
| 508 |
+
"Administrative units within {buffer_m} m of the {anchor_name}",
|
| 509 |
+
"Places within {buffer_m} metres of the {anchor_name}",
|
| 510 |
+
"{buffer_m} meter buffer around the {anchor_name}",
|
| 511 |
+
],
|
| 512 |
+
),
|
| 513 |
+
|
| 514 |
+
# ── CHAINED ──────────────────────────────────────────────────────────────
|
| 515 |
+
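# Spatial filter (ST_Within, or ST_Intersects in chained_02) + EXISTS/NOT EXISTS against natural_earth features (ocean/sea, or landforms in chained_03).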
|
| 516 |
+
# CTE holds the anchor's raw geometry for the spatial predicate (ST_Within / ST_Intersects); final SELECT wraps with ST_AsGeoJSON.
|
| 517 |
+
|
| 518 |
+
SQLTemplate(
|
| 519 |
+
template_id="chained_01",
|
| 520 |
+
family="chained",
|
| 521 |
+
sql_difficulty="hard",
|
| 522 |
+
anchor_source="divisions_area",
|
| 523 |
+
num_anchors=1,
|
| 524 |
+
target_subtype="locality",
|
| 525 |
+
sql_template=(
|
| 526 |
+
"WITH region AS ("
|
| 527 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 528 |
+
")"
|
| 529 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 530 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 531 |
+
" FROM read_parquet('divisions_area') AS b, region"
|
| 532 |
+
" WHERE b.subtype = '{target_subtype}'"
|
| 533 |
+
" AND ST_Within(b.geometry, region.geometry)"
|
| 534 |
+
" AND EXISTS ("
|
| 535 |
+
" SELECT 1 FROM read_parquet('natural_earth') AS n"
|
| 536 |
+
" WHERE n.subtype IN ('ocean', 'sea')"
|
| 537 |
+
" AND ST_Intersects(b.geometry, n.geometry)"
|
| 538 |
+
" )"
|
| 539 |
+
),
|
| 540 |
+
question_hints=[
|
| 541 |
+
"Coastal {target_subtype}s of {anchor_name}",
|
| 542 |
+
"{target_subtype}s in {anchor_name} with sea access",
|
| 543 |
+
"Which {target_subtype}s in {anchor_name} are on the coast?",
|
| 544 |
+
"Seaside {target_subtype}s within {anchor_name}",
|
| 545 |
+
"{target_subtype}s in {anchor_name} bordering the sea",
|
| 546 |
+
"Oceanfront {target_subtype}s in {anchor_name}",
|
| 547 |
+
"Which {target_subtype}s in {anchor_name} have a coastline?",
|
| 548 |
+
],
|
| 549 |
+
),
|
| 550 |
+
|
| 551 |
+
SQLTemplate(
|
| 552 |
+
template_id="chained_02",
|
| 553 |
+
family="chained",
|
| 554 |
+
sql_difficulty="hard",
|
| 555 |
+
anchor_source="divisions_area",
|
| 556 |
+
num_anchors=1,
|
| 557 |
+
target_subtype="country",
|
| 558 |
+
sql_template=(
|
| 559 |
+
"WITH region AS ("
|
| 560 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 561 |
+
")"
|
| 562 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 563 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 564 |
+
" FROM read_parquet('divisions_area') AS b, region"
|
| 565 |
+
" WHERE b.subtype = '{target_subtype}'"
|
| 566 |
+
" AND ST_Intersects(b.geometry, region.geometry)"
|
| 567 |
+
" AND NOT EXISTS ("
|
| 568 |
+
" SELECT 1 FROM read_parquet('natural_earth') AS n"
|
| 569 |
+
" WHERE n.subtype IN ('ocean', 'sea')"
|
| 570 |
+
" AND ST_Intersects(b.geometry, n.geometry)"
|
| 571 |
+
" )"
|
| 572 |
+
),
|
| 573 |
+
question_hints=[
|
| 574 |
+
"Landlocked {target_subtype}s in {anchor_name}",
|
| 575 |
+
"Which {target_subtype}s in {anchor_name} have no sea access?",
|
| 576 |
+
"{target_subtype}s in {anchor_name} that are landlocked",
|
| 577 |
+
"{target_subtype}s in {anchor_name} with no coastline",
|
| 578 |
+
"Which {target_subtype}s within {anchor_name} are landlocked?",
|
| 579 |
+
"Interior {target_subtype}s of {anchor_name} with no ocean border",
|
| 580 |
+
],
|
| 581 |
+
),
|
| 582 |
+
|
| 583 |
+
SQLTemplate(
|
| 584 |
+
template_id="chained_03",
|
| 585 |
+
family="chained",
|
| 586 |
+
sql_difficulty="hard",
|
| 587 |
+
anchor_source="divisions_area",
|
| 588 |
+
num_anchors=1,
|
| 589 |
+
target_subtype="locality",
|
| 590 |
+
sql_template=(
|
| 591 |
+
"WITH region AS ("
|
| 592 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 593 |
+
")"
|
| 594 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 595 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 596 |
+
" FROM read_parquet('divisions_area') AS b, region"
|
| 597 |
+
" WHERE b.subtype = '{target_subtype}'"
|
| 598 |
+
" AND ST_Within(b.geometry, region.geometry)"
|
| 599 |
+
" AND EXISTS ("
|
| 600 |
+
" SELECT 1 FROM read_parquet('natural_earth') AS n"
|
| 601 |
+
" WHERE n.subtype IN ('Terrain area', 'Island group', 'Peninsula')"
|
| 602 |
+
" AND ST_Intersects(b.geometry, n.geometry)"
|
| 603 |
+
" )"
|
| 604 |
+
),
|
| 605 |
+
question_hints=[
|
| 606 |
+
"{target_subtype}s in {anchor_name} on a terrain feature or island",
|
| 607 |
+
"{target_subtype}s of {anchor_name} on a peninsula or island group",
|
| 608 |
+
"{target_subtype}s within {anchor_name} on notable landforms",
|
| 609 |
+
"Island and peninsula {target_subtype}s of {anchor_name}",
|
| 610 |
+
],
|
| 611 |
+
),
|
| 612 |
+
|
| 613 |
+
# ── DIFFERENCE ───────────────────────────────────────────────────────────
|
| 614 |
+
# CTEs hold raw geometry; ST_Difference result wrapped with ST_AsGeoJSON.
|
| 615 |
+
|
| 616 |
+
SQLTemplate(
|
| 617 |
+
template_id="diff_01",
|
| 618 |
+
family="difference",
|
| 619 |
+
sql_difficulty="hard",
|
| 620 |
+
anchor_source="divisions_area",
|
| 621 |
+
num_anchors=2,
|
| 622 |
+
sql_template=(
|
| 623 |
+
"WITH a AS ("
|
| 624 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id_1}'"
|
| 625 |
+
"),"
|
| 626 |
+
" b AS ("
|
| 627 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id_2}'"
|
| 628 |
+
")"
|
| 629 |
+
" SELECT ST_AsGeoJSON(ST_Difference(a.geometry, b.geometry)) AS geometry"
|
| 630 |
+
" FROM a, b"
|
| 631 |
+
" WHERE ST_Intersects(a.geometry, b.geometry)"
|
| 632 |
+
),
|
| 633 |
+
question_hints=[
|
| 634 |
+
"{anchor_1_name} excluding {anchor_2_name}",
|
| 635 |
+
"{anchor_1_name} minus {anchor_2_name}",
|
| 636 |
+
"The part of {anchor_1_name} that is not in {anchor_2_name}",
|
| 637 |
+
"{anchor_1_name} without the {anchor_2_name} area",
|
| 638 |
+
"Remove {anchor_2_name} from {anchor_1_name}",
|
| 639 |
+
"{anchor_1_name} with {anchor_2_name} cut out",
|
| 640 |
+
"Subtract {anchor_2_name} from {anchor_1_name}",
|
| 641 |
+
"What is left of {anchor_1_name} after removing {anchor_2_name}?",
|
| 642 |
+
],
|
| 643 |
+
),
|
| 644 |
+
|
| 645 |
+
SQLTemplate(
|
| 646 |
+
template_id="diff_02",
|
| 647 |
+
family="difference",
|
| 648 |
+
sql_difficulty="hard",
|
| 649 |
+
anchor_source="mixed",
|
| 650 |
+
num_anchors=2,
|
| 651 |
+
sql_template=(
|
| 652 |
+
"WITH a AS ("
|
| 653 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 654 |
+
"),"
|
| 655 |
+
" b AS ("
|
| 656 |
+
" SELECT geometry FROM read_parquet('natural_earth') WHERE id = '{clip_feature_id}'"
|
| 657 |
+
")"
|
| 658 |
+
" SELECT ST_AsGeoJSON(ST_Difference(a.geometry, b.geometry)) AS geometry"
|
| 659 |
+
" FROM a, b"
|
| 660 |
+
" WHERE ST_Intersects(a.geometry, b.geometry)"
|
| 661 |
+
),
|
| 662 |
+
question_hints=[
|
| 663 |
+
"The part of {anchor_name} outside the {clip_feature_name}",
|
| 664 |
+
"{anchor_name} excluding the {clip_feature_name}",
|
| 665 |
+
"{anchor_name} minus the {clip_feature_name}",
|
| 666 |
+
"The land area of {anchor_name} not covered by the {clip_feature_name}",
|
| 667 |
+
"{anchor_name} with the {clip_feature_name} removed",
|
| 668 |
+
"What remains of {anchor_name} after removing the {clip_feature_name}?",
|
| 669 |
+
],
|
| 670 |
+
),
|
| 671 |
+
|
| 672 |
+
# ── BORDER CORRIDOR ──────────────────────────────────────────────────────
|
| 673 |
+
# Intermediate intersection kept raw; final buffer wrapped with ST_AsGeoJSON.
|
| 674 |
+
|
| 675 |
+
SQLTemplate(
|
| 676 |
+
template_id="corridor_01",
|
| 677 |
+
family="border_corridor",
|
| 678 |
+
sql_difficulty="hard",
|
| 679 |
+
anchor_source="divisions_area",
|
| 680 |
+
num_anchors=2,
|
| 681 |
+
requires_buffer=True,
|
| 682 |
+
sql_template=(
|
| 683 |
+
"WITH a AS ("
|
| 684 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id_1}'"
|
| 685 |
+
"),"
|
| 686 |
+
" b AS ("
|
| 687 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id_2}'"
|
| 688 |
+
"),"
|
| 689 |
+
" border AS ("
|
| 690 |
+
" SELECT ST_Intersection(a.geometry, b.geometry) AS line"
|
| 691 |
+
" FROM a, b"
|
| 692 |
+
" WHERE ST_Intersects(a.geometry, b.geometry)"
|
| 693 |
+
")"
|
| 694 |
+
" SELECT ST_AsGeoJSON(ST_Buffer(border.line, {buffer_km} * 1000.0 / 111320.0)) AS geometry"
|
| 695 |
+
" FROM border"
|
| 696 |
+
" WHERE border.line IS NOT NULL"
|
| 697 |
+
),
|
| 698 |
+
question_hints=[
|
| 699 |
+
"{buffer_km} km zone along the border between {anchor_1_name} and {anchor_2_name}",
|
| 700 |
+
"The {buffer_km} km border corridor between {anchor_1_name} and {anchor_2_name}",
|
| 701 |
+
"Area within {buffer_km} km of the {anchor_1_name}-{anchor_2_name} border",
|
| 702 |
+
"The region straddling the border of {anchor_1_name} and {anchor_2_name} within {buffer_km} km",
|
| 703 |
+
"{buffer_km} km on either side of the {anchor_1_name} and {anchor_2_name} border",
|
| 704 |
+
"Buffer the {anchor_1_name}-{anchor_2_name} boundary by {buffer_km} km",
|
| 705 |
+
],
|
| 706 |
+
),
|
| 707 |
+
|
| 708 |
+
# ── SET OPERATIONS ───────────────────────────────────────────────────────
|
| 709 |
+
# union_01 / union_02: 2-anchor and filtered-containment unions.
|
| 710 |
+
# union_03: 3-anchor union — trains the model on IN-clause with 3 IDs.
|
| 711 |
+
# contain_multi: subtype within multiple countries via country IN clause.
|
| 712 |
+
|
| 713 |
+
SQLTemplate(
|
| 714 |
+
template_id="union_01",
|
| 715 |
+
family="set_operations",
|
| 716 |
+
sql_difficulty="medium-hard",
|
| 717 |
+
anchor_source="divisions_area",
|
| 718 |
+
num_anchors=2,
|
| 719 |
+
sql_template=(
|
| 720 |
+
"SELECT ST_AsGeoJSON(ST_Union_Agg(geometry)) AS geometry,"
|
| 721 |
+
" array_agg(names.\"primary\") AS names"
|
| 722 |
+
" FROM read_parquet('divisions_area')"
|
| 723 |
+
" WHERE id IN ('{anchor_id_1}', '{anchor_id_2}')"
|
| 724 |
+
),
|
| 725 |
+
question_hints=[
|
| 726 |
+
"The combined area of {anchor_1_name} and {anchor_2_name}",
|
| 727 |
+
"Union of {anchor_1_name} and {anchor_2_name}",
|
| 728 |
+
"Merge {anchor_1_name} and {anchor_2_name}",
|
| 729 |
+
"{anchor_1_name} and {anchor_2_name} together",
|
| 730 |
+
"Combined geometry of {anchor_1_name} and {anchor_2_name}",
|
| 731 |
+
],
|
| 732 |
+
),
|
| 733 |
+
|
| 734 |
+
SQLTemplate(
|
| 735 |
+
template_id="union_03",
|
| 736 |
+
family="set_operations",
|
| 737 |
+
sql_difficulty="medium-hard",
|
| 738 |
+
anchor_source="divisions_area",
|
| 739 |
+
num_anchors=3,
|
| 740 |
+
sql_template=(
|
| 741 |
+
"SELECT ST_AsGeoJSON(ST_Union_Agg(geometry)) AS geometry,"
|
| 742 |
+
" array_agg(names.\"primary\") AS names"
|
| 743 |
+
" FROM read_parquet('divisions_area')"
|
| 744 |
+
" WHERE id IN ('{anchor_id_1}', '{anchor_id_2}', '{anchor_id_3}')"
|
| 745 |
+
),
|
| 746 |
+
question_hints=[
|
| 747 |
+
"Show me {anchor_1_name}, {anchor_2_name} and {anchor_3_name}",
|
| 748 |
+
"The combined area of {anchor_1_name}, {anchor_2_name} and {anchor_3_name}",
|
| 749 |
+
"Union of {anchor_1_name}, {anchor_2_name} and {anchor_3_name}",
|
| 750 |
+
"Merge {anchor_1_name}, {anchor_2_name} and {anchor_3_name}",
|
| 751 |
+
"{anchor_1_name}, {anchor_2_name} and {anchor_3_name} together",
|
| 752 |
+
"Display {anchor_1_name}, {anchor_2_name} and {anchor_3_name}",
|
| 753 |
+
],
|
| 754 |
+
),
|
| 755 |
+
|
| 756 |
+
SQLTemplate(
|
| 757 |
+
template_id="contain_multi_01",
|
| 758 |
+
family="set_operations",
|
| 759 |
+
sql_difficulty="medium-hard",
|
| 760 |
+
anchor_source="divisions_area",
|
| 761 |
+
num_anchors=2,
|
| 762 |
+
target_subtype="region",
|
| 763 |
+
sql_template=(
|
| 764 |
+
"SELECT id, names.\"primary\" AS name, subtype, country,"
|
| 765 |
+
" ST_AsGeoJSON(geometry) AS geometry"
|
| 766 |
+
" FROM read_parquet('divisions_area')"
|
| 767 |
+
" WHERE country IN ('{country_1}', '{country_2}')"
|
| 768 |
+
" AND subtype = '{target_subtype}'"
|
| 769 |
+
),
|
| 770 |
+
question_hints=[
|
| 771 |
+
"{target_subtype}s of {anchor_1_name} and {anchor_2_name}",
|
| 772 |
+
"All {target_subtype}s in {anchor_1_name} and {anchor_2_name}",
|
| 773 |
+
"Show {target_subtype}s across {anchor_1_name} and {anchor_2_name}",
|
| 774 |
+
"{target_subtype}s belonging to {anchor_1_name} and {anchor_2_name}",
|
| 775 |
+
"List {target_subtype}s in both {anchor_1_name} and {anchor_2_name}",
|
| 776 |
+
],
|
| 777 |
+
),
|
| 778 |
+
|
| 779 |
+
SQLTemplate(
|
| 780 |
+
template_id="contain_multi_02",
|
| 781 |
+
family="set_operations",
|
| 782 |
+
sql_difficulty="medium-hard",
|
| 783 |
+
anchor_source="divisions_area",
|
| 784 |
+
num_anchors=3,
|
| 785 |
+
target_subtype="region",
|
| 786 |
+
sql_template=(
|
| 787 |
+
"SELECT id, names.\"primary\" AS name, subtype, country,"
|
| 788 |
+
" ST_AsGeoJSON(geometry) AS geometry"
|
| 789 |
+
" FROM read_parquet('divisions_area')"
|
| 790 |
+
" WHERE country IN ('{country_1}', '{country_2}', '{country_3}')"
|
| 791 |
+
" AND subtype = '{target_subtype}'"
|
| 792 |
+
),
|
| 793 |
+
question_hints=[
|
| 794 |
+
"{target_subtype}s of {anchor_1_name}, {anchor_2_name} and {anchor_3_name}",
|
| 795 |
+
"All {target_subtype}s in {anchor_1_name}, {anchor_2_name} and {anchor_3_name}",
|
| 796 |
+
"Show {target_subtype}s across {anchor_1_name}, {anchor_2_name} and {anchor_3_name}",
|
| 797 |
+
"List {target_subtype}s in {anchor_1_name}, {anchor_2_name} and {anchor_3_name}",
|
| 798 |
+
],
|
| 799 |
+
),
|
| 800 |
+
|
| 801 |
+
SQLTemplate(
|
| 802 |
+
template_id="union_02",
|
| 803 |
+
family="set_operations",
|
| 804 |
+
sql_difficulty="hard",
|
| 805 |
+
anchor_source="divisions_area",
|
| 806 |
+
num_anchors=1,
|
| 807 |
+
target_subtype="locality",
|
| 808 |
+
requires_aggregation=True,
|
| 809 |
+
sql_template=(
|
| 810 |
+
"WITH a AS ("
|
| 811 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 812 |
+
")"
|
| 813 |
+
" SELECT ST_AsGeoJSON(ST_Union_Agg(b.geometry)) AS geometry"
|
| 814 |
+
" FROM read_parquet('divisions_area') AS b, a"
|
| 815 |
+
" WHERE b.subtype = '{target_subtype}'"
|
| 816 |
+
" AND ST_Within(b.geometry, a.geometry)"
|
| 817 |
+
),
|
| 818 |
+
question_hints=[
|
| 819 |
+
"Merge all {target_subtype}s in {anchor_name} into one geometry",
|
| 820 |
+
"Combined geometry of all {target_subtype}s in {anchor_name}",
|
| 821 |
+
"Union of all {target_subtype}s within {anchor_name}",
|
| 822 |
+
"All {target_subtype}s of {anchor_name} merged together",
|
| 823 |
+
"The overall extent of {target_subtype}s in {anchor_name}",
|
| 824 |
+
],
|
| 825 |
+
),
|
| 826 |
+
|
| 827 |
+
# ── PARTIAL SELECTION ────────────────────────────────────────────────────
|
| 828 |
+
# Clip CTEs (bbox halves in partial_01-04, a natural_earth feature in partial_05) use raw geometry; the ST_Intersection result is wrapped with ST_AsGeoJSON.
|
| 829 |
+
|
| 830 |
+
SQLTemplate(
|
| 831 |
+
template_id="partial_01",
|
| 832 |
+
family="partial_selection",
|
| 833 |
+
sql_difficulty="hard",
|
| 834 |
+
anchor_source="divisions_area",
|
| 835 |
+
num_anchors=1,
|
| 836 |
+
sql_template=(
|
| 837 |
+
"WITH a AS ("
|
| 838 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 839 |
+
"),"
|
| 840 |
+
" bbox AS ("
|
| 841 |
+
" SELECT ST_XMin(geometry) AS xmin, ST_XMax(geometry) AS xmax,"
|
| 842 |
+
" ST_YMin(geometry) AS ymin, ST_YMax(geometry) AS ymax FROM a"
|
| 843 |
+
"),"
|
| 844 |
+
" clip AS ("
|
| 845 |
+
" SELECT ST_MakeEnvelope(xmin, (ymin + ymax) / 2.0, xmax, ymax) AS half_geom FROM bbox"
|
| 846 |
+
")"
|
| 847 |
+
" SELECT ST_AsGeoJSON(ST_Intersection(a.geometry, clip.half_geom)) AS geometry"
|
| 848 |
+
" FROM a, clip"
|
| 849 |
+
),
|
| 850 |
+
question_hints=[
|
| 851 |
+
"The northern half of {anchor_name}",
|
| 852 |
+
"Northern part of {anchor_name}",
|
| 853 |
+
"The top half of {anchor_name}",
|
| 854 |
+
"Northern portion of {anchor_name}",
|
| 855 |
+
],
|
| 856 |
+
),
|
| 857 |
+
|
| 858 |
+
SQLTemplate(
|
| 859 |
+
template_id="partial_02",
|
| 860 |
+
family="partial_selection",
|
| 861 |
+
sql_difficulty="hard",
|
| 862 |
+
anchor_source="divisions_area",
|
| 863 |
+
num_anchors=1,
|
| 864 |
+
sql_template=(
|
| 865 |
+
"WITH a AS ("
|
| 866 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 867 |
+
"),"
|
| 868 |
+
" bbox AS ("
|
| 869 |
+
" SELECT ST_XMin(geometry) AS xmin, ST_XMax(geometry) AS xmax,"
|
| 870 |
+
" ST_YMin(geometry) AS ymin, ST_YMax(geometry) AS ymax FROM a"
|
| 871 |
+
"),"
|
| 872 |
+
" clip AS ("
|
| 873 |
+
" SELECT ST_MakeEnvelope(xmin, ymin, xmax, (ymin + ymax) / 2.0) AS half_geom FROM bbox"
|
| 874 |
+
")"
|
| 875 |
+
" SELECT ST_AsGeoJSON(ST_Intersection(a.geometry, clip.half_geom)) AS geometry"
|
| 876 |
+
" FROM a, clip"
|
| 877 |
+
),
|
| 878 |
+
question_hints=[
|
| 879 |
+
"The southern half of {anchor_name}",
|
| 880 |
+
"Southern part of {anchor_name}",
|
| 881 |
+
"The bottom half of {anchor_name}",
|
| 882 |
+
"Southern portion of {anchor_name}",
|
| 883 |
+
],
|
| 884 |
+
),
|
| 885 |
+
|
| 886 |
+
SQLTemplate(
|
| 887 |
+
template_id="partial_03",
|
| 888 |
+
family="partial_selection",
|
| 889 |
+
sql_difficulty="hard",
|
| 890 |
+
anchor_source="divisions_area",
|
| 891 |
+
num_anchors=1,
|
| 892 |
+
sql_template=(
|
| 893 |
+
"WITH a AS ("
|
| 894 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 895 |
+
"),"
|
| 896 |
+
" bbox AS ("
|
| 897 |
+
" SELECT ST_XMin(geometry) AS xmin, ST_XMax(geometry) AS xmax,"
|
| 898 |
+
" ST_YMin(geometry) AS ymin, ST_YMax(geometry) AS ymax FROM a"
|
| 899 |
+
"),"
|
| 900 |
+
" clip AS ("
|
| 901 |
+
" SELECT ST_MakeEnvelope((xmin + xmax) / 2.0, ymin, xmax, ymax) AS half_geom FROM bbox"
|
| 902 |
+
")"
|
| 903 |
+
" SELECT ST_AsGeoJSON(ST_Intersection(a.geometry, clip.half_geom)) AS geometry"
|
| 904 |
+
" FROM a, clip"
|
| 905 |
+
),
|
| 906 |
+
question_hints=[
|
| 907 |
+
"The eastern half of {anchor_name}",
|
| 908 |
+
"Eastern part of {anchor_name}",
|
| 909 |
+
"The right half of {anchor_name}",
|
| 910 |
+
"Eastern portion of {anchor_name}",
|
| 911 |
+
],
|
| 912 |
+
),
|
| 913 |
+
|
| 914 |
+
SQLTemplate(
|
| 915 |
+
template_id="partial_04",
|
| 916 |
+
family="partial_selection",
|
| 917 |
+
sql_difficulty="hard",
|
| 918 |
+
anchor_source="divisions_area",
|
| 919 |
+
num_anchors=1,
|
| 920 |
+
sql_template=(
|
| 921 |
+
"WITH a AS ("
|
| 922 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 923 |
+
"),"
|
| 924 |
+
" bbox AS ("
|
| 925 |
+
" SELECT ST_XMin(geometry) AS xmin, ST_XMax(geometry) AS xmax,"
|
| 926 |
+
" ST_YMin(geometry) AS ymin, ST_YMax(geometry) AS ymax FROM a"
|
| 927 |
+
"),"
|
| 928 |
+
" clip AS ("
|
| 929 |
+
" SELECT ST_MakeEnvelope(xmin, ymin, (xmin + xmax) / 2.0, ymax) AS half_geom FROM bbox"
|
| 930 |
+
")"
|
| 931 |
+
" SELECT ST_AsGeoJSON(ST_Intersection(a.geometry, clip.half_geom)) AS geometry"
|
| 932 |
+
" FROM a, clip"
|
| 933 |
+
),
|
| 934 |
+
question_hints=[
|
| 935 |
+
"The western half of {anchor_name}",
|
| 936 |
+
"Western part of {anchor_name}",
|
| 937 |
+
"The left half of {anchor_name}",
|
| 938 |
+
"Western portion of {anchor_name}",
|
| 939 |
+
],
|
| 940 |
+
),
|
| 941 |
+
|
| 942 |
+
SQLTemplate(
|
| 943 |
+
template_id="partial_05",
|
| 944 |
+
family="partial_selection",
|
| 945 |
+
sql_difficulty="hard",
|
| 946 |
+
anchor_source="mixed",
|
| 947 |
+
num_anchors=2,
|
| 948 |
+
sql_template=(
|
| 949 |
+
"WITH a AS ("
|
| 950 |
+
" SELECT geometry AS g1 FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 951 |
+
"),"
|
| 952 |
+
" b AS ("
|
| 953 |
+
" SELECT geometry AS g2 FROM read_parquet('natural_earth') WHERE id = '{clip_feature_id}'"
|
| 954 |
+
")"
|
| 955 |
+
" SELECT ST_AsGeoJSON(ST_Intersection(a.g1, b.g2)) AS geometry"
|
| 956 |
+
" FROM a, b"
|
| 957 |
+
" WHERE ST_Intersects(a.g1, b.g2)"
|
| 958 |
+
),
|
| 959 |
+
question_hints=[
|
| 960 |
+
"The part of {anchor_name} that overlaps the {clip_feature_name}",
|
| 961 |
+
"{anchor_name} within the {clip_feature_name}",
|
| 962 |
+
"The portion of {anchor_name} inside the {clip_feature_name}",
|
| 963 |
+
"Clip {anchor_name} to the {clip_feature_name}",
|
| 964 |
+
],
|
| 965 |
+
),
|
| 966 |
+
|
| 967 |
+
# ── AGGREGATION ──────────────────────────────────────────────────────────
|
| 968 |
+
# ST_Area is computed on the raw geometry and aliased as "area" for the ORDER BY; the SELECT wraps the output geometry with ST_AsGeoJSON.
|
| 969 |
+
|
| 970 |
+
SQLTemplate(
|
| 971 |
+
template_id="agg_01",
|
| 972 |
+
family="aggregation",
|
| 973 |
+
sql_difficulty="hard",
|
| 974 |
+
anchor_source="divisions_area",
|
| 975 |
+
num_anchors=1,
|
| 976 |
+
target_subtype=None, # filled at generation time: locality or region
|
| 977 |
+
requires_aggregation=True,
|
| 978 |
+
sql_template=(
|
| 979 |
+
"WITH a AS ("
|
| 980 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 981 |
+
")"
|
| 982 |
+
" SELECT b.id, b.names.\"primary\" AS name,"
|
| 983 |
+
" ST_AsGeoJSON(b.geometry) AS geometry,"
|
| 984 |
+
" ST_Area(b.geometry) AS area"
|
| 985 |
+
" FROM read_parquet('divisions_area') AS b, a"
|
| 986 |
+
" WHERE ST_Within(b.geometry, a.geometry)"
|
| 987 |
+
" AND b.subtype = '{target_subtype}'"
|
| 988 |
+
" ORDER BY area DESC"
|
| 989 |
+
" LIMIT {top_n}"
|
| 990 |
+
),
|
| 991 |
+
question_hints=[
|
| 992 |
+
"Top {top_n} largest {target_subtype}s in {anchor_name}",
|
| 993 |
+
"Biggest {top_n} {target_subtype}s in {anchor_name}",
|
| 994 |
+
"{top_n} largest {target_subtype}s inside {anchor_name}",
|
| 995 |
+
"The {top_n} biggest {target_subtype}s within {anchor_name}",
|
| 996 |
+
"Largest {target_subtype} in {anchor_name}",
|
| 997 |
+
"Which {target_subtype} in {anchor_name} has the most area?",
|
| 998 |
+
],
|
| 999 |
+
),
|
| 1000 |
+
|
| 1001 |
+
SQLTemplate(
|
| 1002 |
+
template_id="agg_02",
|
| 1003 |
+
family="aggregation",
|
| 1004 |
+
sql_difficulty="hard",
|
| 1005 |
+
anchor_source="divisions_area",
|
| 1006 |
+
num_anchors=1,
|
| 1007 |
+
target_subtype=None, # filled at generation time: locality or region
|
| 1008 |
+
requires_aggregation=True,
|
| 1009 |
+
sql_template=(
|
| 1010 |
+
"WITH a AS ("
|
| 1011 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 1012 |
+
")"
|
| 1013 |
+
" SELECT b.id, b.names.\"primary\" AS name,"
|
| 1014 |
+
" ST_AsGeoJSON(b.geometry) AS geometry,"
|
| 1015 |
+
" ST_Area(b.geometry) AS area"
|
| 1016 |
+
" FROM read_parquet('divisions_area') AS b, a"
|
| 1017 |
+
" WHERE ST_Within(b.geometry, a.geometry)"
|
| 1018 |
+
" AND b.subtype = '{target_subtype}'"
|
| 1019 |
+
" ORDER BY area ASC"
|
| 1020 |
+
" LIMIT {top_n}"
|
| 1021 |
+
),
|
| 1022 |
+
question_hints=[
|
| 1023 |
+
"Top {top_n} smallest {target_subtype}s in {anchor_name}",
|
| 1024 |
+
"Smallest {top_n} {target_subtype}s in {anchor_name}",
|
| 1025 |
+
"{top_n} smallest {target_subtype}s inside {anchor_name}",
|
| 1026 |
+
"The {top_n} tiniest {target_subtype}s within {anchor_name}",
|
| 1027 |
+
"Smallest {target_subtype} in {anchor_name}",
|
| 1028 |
+
"Which {target_subtype} in {anchor_name} has the least area?",
|
| 1029 |
+
],
|
| 1030 |
+
),
|
| 1031 |
+
|
| 1032 |
+
SQLTemplate(
|
| 1033 |
+
template_id="agg_03",
|
| 1034 |
+
family="aggregation",
|
| 1035 |
+
sql_difficulty="hard",
|
| 1036 |
+
anchor_source="divisions_area",
|
| 1037 |
+
num_anchors=1,
|
| 1038 |
+
target_subtype=None, # filled at generation time: locality or region
|
| 1039 |
+
requires_aggregation=True,
|
| 1040 |
+
sql_template=(
|
| 1041 |
+
"SELECT id, names.\"primary\" AS name,"
|
| 1042 |
+
" ST_AsGeoJSON(geometry) AS geometry,"
|
| 1043 |
+
" ST_Area(geometry) AS area"
|
| 1044 |
+
" FROM read_parquet('divisions_area')"
|
| 1045 |
+
" WHERE country = '{country}'"
|
| 1046 |
+
" AND subtype = '{target_subtype}'"
|
| 1047 |
+
" ORDER BY area DESC"
|
| 1048 |
+
" LIMIT {top_n}"
|
| 1049 |
+
),
|
| 1050 |
+
question_hints=[
|
| 1051 |
+
"Top {top_n} largest {target_subtype}s in {anchor_name}",
|
| 1052 |
+
"{top_n} biggest {target_subtype}s in {anchor_name}",
|
| 1053 |
+
"Largest {top_n} {target_subtype}s in {anchor_name}",
|
| 1054 |
+
"The {top_n} largest {target_subtype}s in {anchor_name}",
|
| 1055 |
+
"Biggest {target_subtype} in {anchor_name}",
|
| 1056 |
+
"Which {target_subtype} in {anchor_name} is the largest?",
|
| 1057 |
+
],
|
| 1058 |
+
),
|
| 1059 |
+
|
| 1060 |
+
SQLTemplate(
|
| 1061 |
+
template_id="agg_04",
|
| 1062 |
+
family="aggregation",
|
| 1063 |
+
sql_difficulty="hard",
|
| 1064 |
+
anchor_source="divisions_area",
|
| 1065 |
+
num_anchors=1,
|
| 1066 |
+
target_subtype=None, # filled at generation time: locality or region
|
| 1067 |
+
requires_aggregation=True,
|
| 1068 |
+
sql_template=(
|
| 1069 |
+
"SELECT id, names.\"primary\" AS name,"
|
| 1070 |
+
" ST_AsGeoJSON(geometry) AS geometry,"
|
| 1071 |
+
" ST_Area(geometry) AS area"
|
| 1072 |
+
" FROM read_parquet('divisions_area')"
|
| 1073 |
+
" WHERE country = '{country}'"
|
| 1074 |
+
" AND subtype = '{target_subtype}'"
|
| 1075 |
+
" ORDER BY area ASC"
|
| 1076 |
+
" LIMIT {top_n}"
|
| 1077 |
+
),
|
| 1078 |
+
question_hints=[
|
| 1079 |
+
"Top {top_n} smallest {target_subtype}s in {anchor_name}",
|
| 1080 |
+
"{top_n} smallest {target_subtype}s in {anchor_name}",
|
| 1081 |
+
"Smallest {top_n} {target_subtype}s in {anchor_name}",
|
| 1082 |
+
"The {top_n} smallest {target_subtype}s in {anchor_name}",
|
| 1083 |
+
"Smallest {target_subtype} in {anchor_name}",
|
| 1084 |
+
"Which {target_subtype} in {anchor_name} is the smallest?",
|
| 1085 |
+
],
|
| 1086 |
+
),
|
| 1087 |
+
|
| 1088 |
+
# ── WINDOW FUNCTION ──────────────────────────────────────────────────────
|
| 1089 |
+
# CTE keeps raw geometry for ST_Area; final SELECT wraps with ST_AsGeoJSON.
|
| 1090 |
+
|
| 1091 |
+
SQLTemplate(
|
| 1092 |
+
template_id="window_01",
|
| 1093 |
+
family="window_function",
|
| 1094 |
+
sql_difficulty="hard",
|
| 1095 |
+
anchor_source="divisions_area",
|
| 1096 |
+
num_anchors=1,
|
| 1097 |
+
target_subtype="locality",
|
| 1098 |
+
requires_aggregation=True,
|
| 1099 |
+
sql_template=(
|
| 1100 |
+
"WITH ranked AS ("
|
| 1101 |
+
" SELECT id, names.\"primary\" AS name, subtype, country, region, geometry,"
|
| 1102 |
+
" ST_Area(geometry) AS area,"
|
| 1103 |
+
" ROW_NUMBER() OVER (PARTITION BY region ORDER BY ST_Area(geometry) DESC) AS rn"
|
| 1104 |
+
" FROM read_parquet('divisions_area')"
|
| 1105 |
+
" WHERE country = '{country}'"
|
| 1106 |
+
" AND subtype = '{target_subtype}'"
|
| 1107 |
+
")"
|
| 1108 |
+
" SELECT id, name, subtype, country, region,"
|
| 1109 |
+
" ST_AsGeoJSON(geometry) AS geometry, area"
|
| 1110 |
+
" FROM ranked"
|
| 1111 |
+
" WHERE rn = 1"
|
| 1112 |
+
),
|
| 1113 |
+
question_hints=[
|
| 1114 |
+
"The largest {target_subtype} in each region of {anchor_name}",
|
| 1115 |
+
"Biggest {target_subtype} per region in {anchor_name}",
|
| 1116 |
+
"Largest {target_subtype} for every region of {anchor_name}",
|
| 1117 |
+
"The biggest {target_subtype} in each province of {anchor_name}",
|
| 1118 |
+
],
|
| 1119 |
+
),
|
| 1120 |
+
|
| 1121 |
+
SQLTemplate(
|
| 1122 |
+
template_id="window_02",
|
| 1123 |
+
family="window_function",
|
| 1124 |
+
sql_difficulty="hard",
|
| 1125 |
+
anchor_source="divisions_area",
|
| 1126 |
+
num_anchors=1,
|
| 1127 |
+
target_subtype="locality",
|
| 1128 |
+
requires_aggregation=True,
|
| 1129 |
+
sql_template=(
|
| 1130 |
+
"WITH ranked AS ("
|
| 1131 |
+
" SELECT id, names.\"primary\" AS name, subtype, country, region, geometry,"
|
| 1132 |
+
" ST_Area(geometry) AS area,"
|
| 1133 |
+
" ROW_NUMBER() OVER (PARTITION BY region ORDER BY ST_Area(geometry) ASC) AS rn"
|
| 1134 |
+
" FROM read_parquet('divisions_area')"
|
| 1135 |
+
" WHERE country = '{country}'"
|
| 1136 |
+
" AND subtype = '{target_subtype}'"
|
| 1137 |
+
")"
|
| 1138 |
+
" SELECT id, name, subtype, country, region,"
|
| 1139 |
+
" ST_AsGeoJSON(geometry) AS geometry, area"
|
| 1140 |
+
" FROM ranked"
|
| 1141 |
+
" WHERE rn = 1"
|
| 1142 |
+
),
|
| 1143 |
+
question_hints=[
|
| 1144 |
+
"The smallest {target_subtype} in each region of {anchor_name}",
|
| 1145 |
+
"Smallest {target_subtype} per region in {anchor_name}",
|
| 1146 |
+
"Tiniest {target_subtype} for every region of {anchor_name}",
|
| 1147 |
+
"The smallest {target_subtype} in each province of {anchor_name}",
|
| 1148 |
+
],
|
| 1149 |
+
),
|
| 1150 |
+
|
| 1151 |
+
# ── ATTRIBUTE FILTER ─────────────────────────────────────────────────────
|
| 1152 |
+
# No spatial op — pure WHERE on is_land / is_territorial / country.
|
| 1153 |
+
|
| 1154 |
+
SQLTemplate(
|
| 1155 |
+
template_id="attr_01",
|
| 1156 |
+
family="attribute_filter",
|
| 1157 |
+
sql_difficulty="medium",
|
| 1158 |
+
anchor_source="divisions_area",
|
| 1159 |
+
num_anchors=1,
|
| 1160 |
+
target_subtype="dependency",
|
| 1161 |
+
sql_template=(
|
| 1162 |
+
"SELECT id, names.\"primary\" AS name, subtype, country,"
|
| 1163 |
+
" ST_AsGeoJSON(geometry) AS geometry"
|
| 1164 |
+
" FROM read_parquet('divisions_area')"
|
| 1165 |
+
" WHERE country = '{country}'"
|
| 1166 |
+
" AND is_land = TRUE"
|
| 1167 |
+
" AND subtype = '{target_subtype}'"
|
| 1168 |
+
),
|
| 1169 |
+
question_hints=[
|
| 1170 |
+
"Island territories of {anchor_name}",
|
| 1171 |
+
"Overseas island {target_subtype}s belonging to {anchor_name}",
|
| 1172 |
+
"Which islands are part of {anchor_name}?",
|
| 1173 |
+
"Land territories of {anchor_name}",
|
| 1174 |
+
"Island possessions of {anchor_name}",
|
| 1175 |
+
"{anchor_name}'s island {target_subtype}s",
|
| 1176 |
+
],
|
| 1177 |
+
),
|
| 1178 |
+
|
| 1179 |
+
SQLTemplate(
|
| 1180 |
+
template_id="attr_02",
|
| 1181 |
+
family="attribute_filter",
|
| 1182 |
+
sql_difficulty="medium",
|
| 1183 |
+
anchor_source="divisions_area",
|
| 1184 |
+
num_anchors=1,
|
| 1185 |
+
target_subtype="region",
|
| 1186 |
+
sql_template=(
|
| 1187 |
+
"SELECT id, names.\"primary\" AS name, subtype, country,"
|
| 1188 |
+
" ST_AsGeoJSON(geometry) AS geometry"
|
| 1189 |
+
" FROM read_parquet('divisions_area')"
|
| 1190 |
+
" WHERE country = '{country}'"
|
| 1191 |
+
" AND is_territorial = TRUE"
|
| 1192 |
+
" AND subtype = '{target_subtype}'"
|
| 1193 |
+
),
|
| 1194 |
+
question_hints=[
|
| 1195 |
+
"Territorial {target_subtype}s of {anchor_name}",
|
| 1196 |
+
"Official territorial divisions of {anchor_name}",
|
| 1197 |
+
"Recognised territorial {target_subtype}s belonging to {anchor_name}",
|
| 1198 |
+
"Which territorial regions does {anchor_name} have?",
|
| 1199 |
+
],
|
| 1200 |
+
),
|
| 1201 |
+
|
| 1202 |
+
SQLTemplate(
|
| 1203 |
+
template_id="attr_03",
|
| 1204 |
+
family="attribute_filter",
|
| 1205 |
+
sql_difficulty="medium",
|
| 1206 |
+
anchor_source="divisions_area",
|
| 1207 |
+
num_anchors=1,
|
| 1208 |
+
target_subtype="locality",
|
| 1209 |
+
sql_template=(
|
| 1210 |
+
"SELECT id, names.\"primary\" AS name, subtype, country,"
|
| 1211 |
+
" ST_AsGeoJSON(geometry) AS geometry"
|
| 1212 |
+
" FROM read_parquet('divisions_area')"
|
| 1213 |
+
" WHERE country = '{country}'"
|
| 1214 |
+
" AND subtype = '{target_subtype}'"
|
| 1215 |
+
" AND is_land = TRUE"
|
| 1216 |
+
),
|
| 1217 |
+
question_hints=[
|
| 1218 |
+
"Land-based {target_subtype}s of {anchor_name}",
|
| 1219 |
+
"{target_subtype}s on the mainland of {anchor_name}",
|
| 1220 |
+
"All {target_subtype}s on land in {anchor_name}",
|
| 1221 |
+
"Non-island {target_subtype}s of {anchor_name}",
|
| 1222 |
+
],
|
| 1223 |
+
),
|
| 1224 |
+
|
| 1225 |
+
# ── NATURAL EARTH ADJACENCY ─────────────────────────────────────────────
|
| 1226 |
+
# Division anchor, natural_earth targets. Handler formats anchor_id and
|
| 1227 |
+
# target_subtype but the SQL hardcodes NE subtypes (like adj_03).
|
| 1228 |
+
|
| 1229 |
+
SQLTemplate(
|
| 1230 |
+
template_id="adj_04",
|
| 1231 |
+
family="adjacency",
|
| 1232 |
+
sql_difficulty="medium",
|
| 1233 |
+
anchor_source="divisions_area",
|
| 1234 |
+
num_anchors=1,
|
| 1235 |
+
target_subtype="river",
|
| 1236 |
+
sql_template=(
|
| 1237 |
+
"WITH a AS ("
|
| 1238 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 1239 |
+
")"
|
| 1240 |
+
" SELECT n.id, n.names.\"primary\" AS name, n.subtype,"
|
| 1241 |
+
" ST_AsGeoJSON(n.geometry) AS geometry"
|
| 1242 |
+
" FROM read_parquet('natural_earth') AS n, a"
|
| 1243 |
+
" WHERE n.subtype IN ('River', 'Lake', 'Basin')"
|
| 1244 |
+
" AND ST_Intersects(a.geometry, n.geometry)"
|
| 1245 |
+
),
|
| 1246 |
+
question_hints=[
|
| 1247 |
+
"What rivers or lakes are in {anchor_name}?",
|
| 1248 |
+
"Natural water features of {anchor_name}",
|
| 1249 |
+
"Which rivers flow through {anchor_name}?",
|
| 1250 |
+
"Lakes and rivers within {anchor_name}",
|
| 1251 |
+
"Water features inside {anchor_name}",
|
| 1252 |
+
"What bodies of water cross {anchor_name}?",
|
| 1253 |
+
"Rivers of {anchor_name}",
|
| 1254 |
+
"Show me the lakes in {anchor_name}",
|
| 1255 |
+
],
|
| 1256 |
+
),
|
| 1257 |
+
|
| 1258 |
+
SQLTemplate(
|
| 1259 |
+
template_id="adj_05",
|
| 1260 |
+
family="adjacency",
|
| 1261 |
+
sql_difficulty="medium",
|
| 1262 |
+
anchor_source="divisions_area",
|
| 1263 |
+
num_anchors=1,
|
| 1264 |
+
target_subtype="range",
|
| 1265 |
+
sql_template=(
|
| 1266 |
+
"WITH a AS ("
|
| 1267 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 1268 |
+
")"
|
| 1269 |
+
" SELECT n.id, n.names.\"primary\" AS name, n.subtype,"
|
| 1270 |
+
" ST_AsGeoJSON(n.geometry) AS geometry"
|
| 1271 |
+
" FROM read_parquet('natural_earth') AS n, a"
|
| 1272 |
+
" WHERE n.subtype IN ('Range/Mts', 'Terrain area', 'Peninsula', 'Depression')"
|
| 1273 |
+
" AND ST_Intersects(a.geometry, n.geometry)"
|
| 1274 |
+
),
|
| 1275 |
+
question_hints=[
|
| 1276 |
+
"What mountain ranges are in {anchor_name}?",
|
| 1277 |
+
"Terrain features of {anchor_name}",
|
| 1278 |
+
"Which mountain ranges cross {anchor_name}?",
|
| 1279 |
+
"Landforms inside {anchor_name}",
|
| 1280 |
+
"Peninsulas and ranges in {anchor_name}",
|
| 1281 |
+
"Geographic features within {anchor_name}",
|
| 1282 |
+
"Mountains of {anchor_name}",
|
| 1283 |
+
"What terrain does {anchor_name} contain?",
|
| 1284 |
+
],
|
| 1285 |
+
),
|
| 1286 |
+
|
| 1287 |
+
# ── NATURAL EARTH INTERSECTION ──────────────────────────────────────────
|
| 1288 |
+
# intersect_03: NE anchor, finding overlapping regions (vs countries in
|
| 1289 |
+
# intersect_02). Uses cross_source_relations handler.
|
| 1290 |
+
# intersect_04: division anchor, finding NE features that overlap it.
|
| 1291 |
+
# Uses intersection_pairs handler (extra NE subtypes ignored in SQL).
|
| 1292 |
+
|
| 1293 |
+
SQLTemplate(
|
| 1294 |
+
template_id="intersect_03",
|
| 1295 |
+
family="intersection",
|
| 1296 |
+
sql_difficulty="medium-hard",
|
| 1297 |
+
anchor_source="natural_earth",
|
| 1298 |
+
num_anchors=1,
|
| 1299 |
+
target_subtype="region",
|
| 1300 |
+
sql_template=(
|
| 1301 |
+
"WITH a AS ("
|
| 1302 |
+
" SELECT geometry FROM read_parquet('natural_earth') WHERE id = '{anchor_id}'"
|
| 1303 |
+
")"
|
| 1304 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 1305 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 1306 |
+
" FROM read_parquet('divisions_area') AS b, a"
|
| 1307 |
+
" WHERE b.subtype = '{target_subtype}'"
|
| 1308 |
+
" AND ST_Intersects(b.geometry, a.geometry)"
|
| 1309 |
+
),
|
| 1310 |
+
question_hints=[
|
| 1311 |
+
"Which regions does the {anchor_name} pass through?",
|
| 1312 |
+
"What administrative regions overlap with the {anchor_name}?",
|
| 1313 |
+
"Regions that the {anchor_name} crosses",
|
| 1314 |
+
"Administrative areas intersected by the {anchor_name}",
|
| 1315 |
+
"What provinces does the {anchor_name} span?",
|
| 1316 |
+
"Regions along the {anchor_name}",
|
| 1317 |
+
"Which provinces overlap the {anchor_name}?",
|
| 1318 |
+
],
|
| 1319 |
+
),
|
| 1320 |
+
|
| 1321 |
+
SQLTemplate(
|
| 1322 |
+
template_id="intersect_04",
|
| 1323 |
+
family="intersection",
|
| 1324 |
+
sql_difficulty="medium-hard",
|
| 1325 |
+
anchor_source="divisions_area",
|
| 1326 |
+
num_anchors=1,
|
| 1327 |
+
target_subtype="region",
|
| 1328 |
+
sql_template=(
|
| 1329 |
+
"WITH a AS ("
|
| 1330 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 1331 |
+
")"
|
| 1332 |
+
" SELECT n.id, n.names.\"primary\" AS name, n.subtype,"
|
| 1333 |
+
" ST_AsGeoJSON(n.geometry) AS geometry"
|
| 1334 |
+
" FROM read_parquet('natural_earth') AS n, a"
|
| 1335 |
+
" WHERE ST_Intersects(n.geometry, a.geometry)"
|
| 1336 |
+
),
|
| 1337 |
+
question_hints=[
|
| 1338 |
+
"What natural features intersect {anchor_name}?",
|
| 1339 |
+
"Natural earth features that overlap {anchor_name}",
|
| 1340 |
+
"Which geographic features cross {anchor_name}?",
|
| 1341 |
+
"Everything from natural earth that touches {anchor_name}",
|
| 1342 |
+
"What geographic features does {anchor_name} contain?",
|
| 1343 |
+
"Natural features within or crossing {anchor_name}",
|
| 1344 |
+
],
|
| 1345 |
+
),
|
| 1346 |
+
|
| 1347 |
+
# ── NATURAL EARTH CHAINED ───────────────────────────────────────────────
|
| 1348 |
+
# chained_04: localities in a region that intersect a river or lake.
|
| 1349 |
+
# chained_05: localities in a region that lie on a mountain range.
|
| 1350 |
+
|
| 1351 |
+
SQLTemplate(
|
| 1352 |
+
template_id="chained_04",
|
| 1353 |
+
family="chained",
|
| 1354 |
+
sql_difficulty="hard",
|
| 1355 |
+
anchor_source="divisions_area",
|
| 1356 |
+
num_anchors=1,
|
| 1357 |
+
target_subtype="locality",
|
| 1358 |
+
sql_template=(
|
| 1359 |
+
"WITH region AS ("
|
| 1360 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 1361 |
+
")"
|
| 1362 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 1363 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 1364 |
+
" FROM read_parquet('divisions_area') AS b, region"
|
| 1365 |
+
" WHERE b.subtype = '{target_subtype}'"
|
| 1366 |
+
" AND ST_Within(b.geometry, region.geometry)"
|
| 1367 |
+
" AND EXISTS ("
|
| 1368 |
+
" SELECT 1 FROM read_parquet('natural_earth') AS n"
|
| 1369 |
+
" WHERE n.subtype IN ('River', 'Lake', 'Basin')"
|
| 1370 |
+
" AND ST_Intersects(b.geometry, n.geometry)"
|
| 1371 |
+
" )"
|
| 1372 |
+
),
|
| 1373 |
+
question_hints=[
|
| 1374 |
+
"Riverside {target_subtype}s in {anchor_name}",
|
| 1375 |
+
"{target_subtype}s in {anchor_name} near a river or lake",
|
| 1376 |
+
"Which {target_subtype}s in {anchor_name} are on a waterway?",
|
| 1377 |
+
"Lakeside or riverside {target_subtype}s within {anchor_name}",
|
| 1378 |
+
"{target_subtype}s in {anchor_name} that touch a river",
|
| 1379 |
+
"Which {target_subtype}s in {anchor_name} are on a lake?",
|
| 1380 |
+
"Waterfront {target_subtype}s of {anchor_name}",
|
| 1381 |
+
],
|
| 1382 |
+
),
|
| 1383 |
+
|
| 1384 |
+
SQLTemplate(
|
| 1385 |
+
template_id="chained_05",
|
| 1386 |
+
family="chained",
|
| 1387 |
+
sql_difficulty="hard",
|
| 1388 |
+
anchor_source="divisions_area",
|
| 1389 |
+
num_anchors=1,
|
| 1390 |
+
target_subtype="locality",
|
| 1391 |
+
sql_template=(
|
| 1392 |
+
"WITH region AS ("
|
| 1393 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 1394 |
+
")"
|
| 1395 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 1396 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 1397 |
+
" FROM read_parquet('divisions_area') AS b, region"
|
| 1398 |
+
" WHERE b.subtype = '{target_subtype}'"
|
| 1399 |
+
" AND ST_Within(b.geometry, region.geometry)"
|
| 1400 |
+
" AND EXISTS ("
|
| 1401 |
+
" SELECT 1 FROM read_parquet('natural_earth') AS n"
|
| 1402 |
+
" WHERE n.subtype IN ('Range/Mts', 'Depression')"
|
| 1403 |
+
" AND ST_Intersects(b.geometry, n.geometry)"
|
| 1404 |
+
" )"
|
| 1405 |
+
),
|
| 1406 |
+
question_hints=[
|
| 1407 |
+
"Mountain {target_subtype}s in {anchor_name}",
|
| 1408 |
+
"{target_subtype}s in {anchor_name} on a mountain range",
|
| 1409 |
+
"Which {target_subtype}s in {anchor_name} are in the mountains?",
|
| 1410 |
+
"Highland {target_subtype}s within {anchor_name}",
|
| 1411 |
+
"{target_subtype}s of {anchor_name} in mountainous terrain",
|
| 1412 |
+
"{target_subtype}s in {anchor_name} near a mountain range",
|
| 1413 |
+
],
|
| 1414 |
+
),
|
| 1415 |
+
|
| 1416 |
+
# ── CHAINED (county-level) ──────────────────────────────────────────────
|
| 1417 |
+
# Same spatial patterns as chained_01..05 but targeting counties/districts
|
| 1418 |
+
# so the model learns "coastal districts of X", "riverside counties", etc.
|
| 1419 |
+
|
| 1420 |
+
SQLTemplate(
|
| 1421 |
+
template_id="chained_06",
|
| 1422 |
+
family="chained",
|
| 1423 |
+
sql_difficulty="hard",
|
| 1424 |
+
anchor_source="divisions_area",
|
| 1425 |
+
num_anchors=1,
|
| 1426 |
+
target_subtype="county",
|
| 1427 |
+
sql_template=(
|
| 1428 |
+
"WITH region AS ("
|
| 1429 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 1430 |
+
")"
|
| 1431 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 1432 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 1433 |
+
" FROM read_parquet('divisions_area') AS b, region"
|
| 1434 |
+
" WHERE b.subtype = '{target_subtype}'"
|
| 1435 |
+
" AND ST_Within(b.geometry, region.geometry)"
|
| 1436 |
+
" AND EXISTS ("
|
| 1437 |
+
" SELECT 1 FROM read_parquet('natural_earth') AS n"
|
| 1438 |
+
" WHERE n.subtype IN ('ocean', 'sea')"
|
| 1439 |
+
" AND ST_Intersects(b.geometry, n.geometry)"
|
| 1440 |
+
" )"
|
| 1441 |
+
),
|
| 1442 |
+
question_hints=[
|
| 1443 |
+
"Coastal {target_subtype}s of {anchor_name}",
|
| 1444 |
+
"Which districts of {anchor_name} are on the coast?",
|
| 1445 |
+
"{target_subtype}s in {anchor_name} that border the sea",
|
| 1446 |
+
"Seaside {target_subtype}s within {anchor_name}",
|
| 1447 |
+
"{target_subtype}s of {anchor_name} with ocean access",
|
| 1448 |
+
"Which {target_subtype}s in {anchor_name} touch the sea?",
|
| 1449 |
+
"Maritime {target_subtype}s of {anchor_name}",
|
| 1450 |
+
],
|
| 1451 |
+
),
|
| 1452 |
+
|
| 1453 |
+
SQLTemplate(
|
| 1454 |
+
template_id="chained_07",
|
| 1455 |
+
family="chained",
|
| 1456 |
+
sql_difficulty="hard",
|
| 1457 |
+
anchor_source="divisions_area",
|
| 1458 |
+
num_anchors=1,
|
| 1459 |
+
target_subtype="county",
|
| 1460 |
+
sql_template=(
|
| 1461 |
+
"WITH region AS ("
|
| 1462 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 1463 |
+
")"
|
| 1464 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 1465 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 1466 |
+
" FROM read_parquet('divisions_area') AS b, region"
|
| 1467 |
+
" WHERE b.subtype = '{target_subtype}'"
|
| 1468 |
+
" AND ST_Within(b.geometry, region.geometry)"
|
| 1469 |
+
" AND NOT EXISTS ("
|
| 1470 |
+
" SELECT 1 FROM read_parquet('natural_earth') AS n"
|
| 1471 |
+
" WHERE n.subtype IN ('ocean', 'sea')"
|
| 1472 |
+
" AND ST_Intersects(b.geometry, n.geometry)"
|
| 1473 |
+
" )"
|
| 1474 |
+
),
|
| 1475 |
+
question_hints=[
|
| 1476 |
+
"Landlocked {target_subtype}s of {anchor_name}",
|
| 1477 |
+
"Which districts of {anchor_name} have no coastline?",
|
| 1478 |
+
"Interior {target_subtype}s within {anchor_name}",
|
| 1479 |
+
"{target_subtype}s in {anchor_name} with no sea access",
|
| 1480 |
+
"Non-coastal {target_subtype}s of {anchor_name}",
|
| 1481 |
+
"Inland {target_subtype}s of {anchor_name}",
|
| 1482 |
+
],
|
| 1483 |
+
),
|
| 1484 |
+
|
| 1485 |
+
SQLTemplate(
|
| 1486 |
+
template_id="chained_08",
|
| 1487 |
+
family="chained",
|
| 1488 |
+
sql_difficulty="hard",
|
| 1489 |
+
anchor_source="divisions_area",
|
| 1490 |
+
num_anchors=1,
|
| 1491 |
+
target_subtype="county",
|
| 1492 |
+
sql_template=(
|
| 1493 |
+
"WITH region AS ("
|
| 1494 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 1495 |
+
")"
|
| 1496 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 1497 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 1498 |
+
" FROM read_parquet('divisions_area') AS b, region"
|
| 1499 |
+
" WHERE b.subtype = '{target_subtype}'"
|
| 1500 |
+
" AND ST_Within(b.geometry, region.geometry)"
|
| 1501 |
+
" AND EXISTS ("
|
| 1502 |
+
" SELECT 1 FROM read_parquet('natural_earth') AS n"
|
| 1503 |
+
" WHERE n.subtype IN ('River', 'Lake', 'Basin')"
|
| 1504 |
+
" AND ST_Intersects(b.geometry, n.geometry)"
|
| 1505 |
+
" )"
|
| 1506 |
+
),
|
| 1507 |
+
question_hints=[
|
| 1508 |
+
"Riverside {target_subtype}s of {anchor_name}",
|
| 1509 |
+
"Which districts of {anchor_name} have a river or lake?",
|
| 1510 |
+
"{target_subtype}s in {anchor_name} on a waterway",
|
| 1511 |
+
"Lakeside {target_subtype}s within {anchor_name}",
|
| 1512 |
+
"{target_subtype}s of {anchor_name} along a river",
|
| 1513 |
+
"Which {target_subtype}s in {anchor_name} border a lake?",
|
| 1514 |
+
],
|
| 1515 |
+
),
|
| 1516 |
+
|
| 1517 |
+
SQLTemplate(
|
| 1518 |
+
template_id="chained_09",
|
| 1519 |
+
family="chained",
|
| 1520 |
+
sql_difficulty="hard",
|
| 1521 |
+
anchor_source="divisions_area",
|
| 1522 |
+
num_anchors=1,
|
| 1523 |
+
target_subtype="county",
|
| 1524 |
+
sql_template=(
|
| 1525 |
+
"WITH region AS ("
|
| 1526 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 1527 |
+
")"
|
| 1528 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 1529 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 1530 |
+
" FROM read_parquet('divisions_area') AS b, region"
|
| 1531 |
+
" WHERE b.subtype = '{target_subtype}'"
|
| 1532 |
+
" AND ST_Within(b.geometry, region.geometry)"
|
| 1533 |
+
" AND EXISTS ("
|
| 1534 |
+
" SELECT 1 FROM read_parquet('natural_earth') AS n"
|
| 1535 |
+
" WHERE n.subtype IN ('Range/Mts', 'Depression')"
|
| 1536 |
+
" AND ST_Intersects(b.geometry, n.geometry)"
|
| 1537 |
+
" )"
|
| 1538 |
+
),
|
| 1539 |
+
question_hints=[
|
| 1540 |
+
"Mountain {target_subtype}s of {anchor_name}",
|
| 1541 |
+
"Which districts of {anchor_name} are in the mountains?",
|
| 1542 |
+
"{target_subtype}s in {anchor_name} on a mountain range",
|
| 1543 |
+
"Highland {target_subtype}s within {anchor_name}",
|
| 1544 |
+
"{target_subtype}s of {anchor_name} in mountainous terrain",
|
| 1545 |
+
"Which {target_subtype}s in {anchor_name} have mountain ranges?",
|
| 1546 |
+
],
|
| 1547 |
+
),
|
| 1548 |
+
|
| 1549 |
+
# ── NATURAL EARTH CONTAINMENT ───────────────────────────────────────────
|
| 1550 |
+
# contain_04: NE anchor (sea/gulf/bay), find countries that touch it.
|
| 1551 |
+
# Uses containment handler via containment_pairs.
|
| 1552 |
+
|
| 1553 |
+
SQLTemplate(
|
| 1554 |
+
template_id="contain_04",
|
| 1555 |
+
family="containment",
|
| 1556 |
+
sql_difficulty="medium",
|
| 1557 |
+
anchor_source="natural_earth",
|
| 1558 |
+
num_anchors=1,
|
| 1559 |
+
target_subtype="country",
|
| 1560 |
+
sql_template=(
|
| 1561 |
+
"WITH a AS ("
|
| 1562 |
+
" SELECT geometry FROM read_parquet('natural_earth') WHERE id = '{anchor_id}'"
|
| 1563 |
+
")"
|
| 1564 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype,"
|
| 1565 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 1566 |
+
" FROM read_parquet('divisions_area') AS b, a"
|
| 1567 |
+
" WHERE b.subtype = '{target_subtype}'"
|
| 1568 |
+
" AND ST_Intersects(b.geometry, a.geometry)"
|
| 1569 |
+
),
|
| 1570 |
+
question_hints=[
|
| 1571 |
+
"Which countries border the {anchor_name}?",
|
| 1572 |
+
"What countries are along the {anchor_name}?",
|
| 1573 |
+
"Countries surrounding the {anchor_name}",
|
| 1574 |
+
"Nations on the {anchor_name}",
|
| 1575 |
+
"Which countries touch the {anchor_name}?",
|
| 1576 |
+
"Countries with coastline on the {anchor_name}",
|
| 1577 |
+
"What nations lie on the {anchor_name}?",
|
| 1578 |
+
],
|
| 1579 |
+
),
|
| 1580 |
+
|
| 1581 |
+
# ── NATURAL EARTH BUFFER ────────────────────────────────────────────────
|
| 1582 |
+
# buffer_05: NE anchor, find other NE features within a buffer distance.
|
| 1583 |
+
# Uses buffer handler for natural_earth.
|
| 1584 |
+
|
| 1585 |
+
SQLTemplate(
|
| 1586 |
+
template_id="buffer_05",
|
| 1587 |
+
family="buffer",
|
| 1588 |
+
sql_difficulty="hard",
|
| 1589 |
+
anchor_source="natural_earth",
|
| 1590 |
+
num_anchors=1,
|
| 1591 |
+
requires_buffer=True,
|
| 1592 |
+
sql_template=(
|
| 1593 |
+
"WITH a AS ("
|
| 1594 |
+
" SELECT ST_Buffer(geometry, {buffer_km} * 1000.0 / 111320.0) AS geom"
|
| 1595 |
+
" FROM read_parquet('natural_earth')"
|
| 1596 |
+
" WHERE id = '{anchor_id}'"
|
| 1597 |
+
")"
|
| 1598 |
+
" SELECT n.id, n.names.\"primary\" AS name, n.subtype,"
|
| 1599 |
+
" ST_AsGeoJSON(n.geometry) AS geometry"
|
| 1600 |
+
" FROM read_parquet('natural_earth') AS n, a"
|
| 1601 |
+
" WHERE ST_Intersects(n.geometry, a.geom)"
|
| 1602 |
+
),
|
| 1603 |
+
question_hints=[
|
| 1604 |
+
"Natural features within {buffer_km} km of the {anchor_name}",
|
| 1605 |
+
"What is within {buffer_km} km of the {anchor_name}?",
|
| 1606 |
+
"Geographic features near the {anchor_name} within {buffer_km} km",
|
| 1607 |
+
"Everything within {buffer_km} km of the {anchor_name}",
|
| 1608 |
+
"What natural features are close to the {anchor_name}?",
|
| 1609 |
+
"{buffer_km} km radius around the {anchor_name}",
|
| 1610 |
+
],
|
| 1611 |
+
),
|
| 1612 |
+
|
| 1613 |
+
]
|
| 1614 |
+
|
| 1615 |
+
|
| 1616 |
+
# ---------------------------------------------------------------------------
|
| 1617 |
+
# Helpers
|
| 1618 |
+
# ---------------------------------------------------------------------------
|
| 1619 |
+
|
| 1620 |
+
def get_templates_by_family(family: str) -> List[SQLTemplate]:
    """Collect every catalog entry whose ``family`` equals *family*.

    Args:
        family: Task-family name, compared for exact equality against each
            template's ``family`` attribute.

    Returns:
        A new list holding the matching templates in ``TEMPLATES`` order;
        the list is empty when no template matches.
    """
    matches = []
    for candidate in TEMPLATES:
        if candidate.family == family:
            matches.append(candidate)
    return matches
|
| 1623 |
+
|
| 1624 |
+
|
| 1625 |
+
def get_template_by_id(template_id: str) -> SQLTemplate:
    """Look up a single template in ``TEMPLATES`` by its ``template_id``.

    Args:
        template_id: Identifier to match exactly against each template's
            ``template_id`` attribute.

    Returns:
        The first template in ``TEMPLATES`` order whose ID matches.

    Raises:
        ValueError: If no template carries the requested ID.
    """
    found = next(
        (entry for entry in TEMPLATES if entry.template_id == template_id),
        None,
    )
    if found is None:
        raise ValueError(f"Template '{template_id}' not found")
    return found
|
| 1631 |
+
|
| 1632 |
+
|
| 1633 |
+
if __name__ == "__main__":
    # Script entry point: print a per-family count of the catalog, then a
    # quick per-template report on ST_AsGeoJSON usage.
    counts_by_family: dict = {}
    for template in TEMPLATES:
        counts_by_family.setdefault(template.family, 0)
        counts_by_family[template.family] += 1

    rule = "=" * 60

    print("SQL Template Catalog")
    print(rule)
    # Family names are unique dict keys, so iterating them in sorted order
    # lists the families alphabetically.
    for family_name in sorted(counts_by_family):
        print(f"{family_name:20s}: {counts_by_family[family_name]:2d} templates")
    print(f"{'TOTAL':20s}: {len(TEMPLATES):2d} templates")

    # Geometry output check: this is a substring test only -- it reports
    # whether "ST_AsGeoJSON" appears anywhere in the template's SQL text,
    # not specifically in the final SELECT.
    print()
    print("Geometry output check (all should show ST_AsGeoJSON)")
    print(rule)
    for template in TEMPLATES:
        if "ST_AsGeoJSON" not in template.sql_template:
            status = "MISSING"
        else:
            status = "OK"
        print(f" {template.template_id:20s}: {status}")
|
dataset/scripts/validate_dataset.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Validate and balance the generated dataset.
|
| 3 |
+
|
| 4 |
+
This script:
|
| 5 |
+
1. Loads all generated samples
|
| 6 |
+
2. Validates SQL executability
|
| 7 |
+
3. Checks candidate list quality
|
| 8 |
+
4. Balances across task families and difficulty
|
| 9 |
+
5. Removes duplicates
|
| 10 |
+
6. Generates dataset statistics
|
| 11 |
+
|
| 12 |
+
Output:
|
| 13 |
+
- output/dataset_validated.jsonl
|
| 14 |
+
- output/dataset_stats.json
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import List, Dict, Any, Tuple
|
| 20 |
+
from collections import Counter
|
| 21 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 22 |
+
|
| 23 |
+
import duckdb
|
| 24 |
+
import pandas as pd
|
| 25 |
+
|
| 26 |
+
from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_samples(jsonl_path: Path) -> List[Dict[str, Any]]:
    """Load samples from a JSONL file, one JSON document per non-blank line.

    Args:
        jsonl_path: Path of the JSONL file to read.

    Returns:
        The parsed samples, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    samples = []
    # Read as UTF-8 explicitly so the result does not depend on the platform's
    # default text encoding.
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            # A blank line (for example a trailing newline at end of file)
            # would make json.loads raise, so skip it.
            if line:
                samples.append(json.loads(line))
    return samples
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _resolve_paths(sql: str) -> str:
    """Swap symbolic dataset names in *sql* for the configured runtime paths.

    Substitutions are applied in order: first the symbolic
    ``read_parquet('<name>')`` calls, then the fixed Docker paths written by
    earlier dataset versions.
    """
    substitutions = (
        ("read_parquet('divisions_area')", f"read_parquet('{DIVISIONS_AREA_PATH}')"),
        ("read_parquet('natural_earth')", f"read_parquet('{NATURAL_EARTH_PATH}')"),
        # Legacy fixed Docker paths from earlier dataset versions
        ("/data/overture/division_area/*.parquet", DIVISIONS_AREA_PATH),
        ("/data/overture/divisions_area/*.parquet", DIVISIONS_AREA_PATH),
        ("/data/natural_earth_geoparquet/ne_geography.parquet", NATURAL_EARTH_PATH),
    )
    for needle, replacement in substitutions:
        sql = sql.replace(needle, replacement)
    return sql
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _to_symbolic_sql(sql: str) -> str:
    """Rewrite concrete parquet paths in *sql* back to their symbolic names.

    Handles both the currently configured runtime paths and the fixed Docker
    paths used by earlier dataset versions, so stored SQL is path-independent.
    """
    substitutions = (
        # Current local runtime paths
        (DIVISIONS_AREA_PATH, "divisions_area"),
        (NATURAL_EARTH_PATH, "natural_earth"),
        # Legacy Docker paths
        ("/data/overture/division_area/*.parquet", "divisions_area"),
        ("/data/overture/divisions_area/*.parquet", "divisions_area"),
        ("/data/natural_earth_geoparquet/ne_geography.parquet", "natural_earth"),
    )
    for concrete, symbolic in substitutions:
        sql = sql.replace(concrete, symbolic)
    return sql
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def validate_sql(con: duckdb.DuckDBPyConnection, sql: str) -> tuple[bool, str]:
    """Run *sql* on *con* and report whether it executed and returned rows.

    Symbolic path placeholders are resolved to runtime paths before execution.

    Returns:
        ``(True, "OK")`` when the query yields at least one row,
        ``(False, "Empty result")`` when it yields none, and
        ``(False, <error text>)`` when execution raises.
    """
    try:
        frame = con.execute(_resolve_paths(sql)).fetchdf()
        outcome = (False, "Empty result") if frame.empty else (True, "OK")
    except Exception as exc:
        # Any execution failure is reported as a validation result, not raised.
        outcome = (False, str(exc))
    return outcome
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def validate_candidates(sample: Dict[str, Any]) -> tuple[bool, str]:
    """Check the candidate list of *sample* for basic consistency.

    Three checks run in order and the first failure is reported:
    the candidate list is non-empty, every selected candidate id appears among
    the candidates' ``candidate_id`` values, and no two candidates share an
    ``id``.

    Returns:
        ``(True, "OK")`` on success, otherwise ``(False, <reason>)``.
    """
    pool = sample['candidates']
    chosen = sample['target']['selected_candidates']

    if not pool:
        return False, "No candidates"

    # Every selected id must refer to a listed candidate.
    known_ids = {entry['candidate_id'] for entry in pool}
    for sel_id in chosen:
        if sel_id in known_ids:
            continue
        return False, f"Selected candidate {sel_id} not in candidate list"

    # Fewer distinct ids than entries means at least one id repeats.
    distinct_ids = {entry['id'] for entry in pool}
    if len(distinct_ids) != len(pool):
        return False, "Duplicate candidates"

    return True, "OK"
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def validate_sample(con: duckdb.DuckDBPyConnection, sample: Dict[str, Any]) -> tuple[bool, List[str]]:
    """Run all per-sample checks and collect the problems found.

    The SQL is executed only when ``metadata.sql_verified`` is missing or
    false; the candidate list and the question text are always checked.

    Returns:
        ``(is_valid, issues)`` where ``is_valid`` is true exactly when
        ``issues`` is empty.
    """
    problems: List[str] = []

    # SQL already verified during generation is not re-executed.
    already_verified = sample.get('metadata', {}).get('sql_verified', False)
    if not already_verified:
        sql_ok, sql_detail = validate_sql(con, sample['target']['sql'])
        if not sql_ok:
            problems.append(f"SQL: {sql_detail}")

    candidates_ok, candidates_detail = validate_candidates(sample)
    if not candidates_ok:
        problems.append(f"Candidates: {candidates_detail}")

    # A missing, empty or whitespace-only question counts as empty.
    question = sample.get('question')
    if not question or not question.strip():
        problems.append("Empty question")

    return not problems, problems
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def validate_sample_worker(
    sample: Dict[str, Any],
) -> "Tuple[str, bool, List[str], Dict[str, Any] | None]":
    """Worker function for parallel validation.

    Opens a private DuckDB connection, loads the spatial extension and
    validates *sample*. When the sample is valid its ``target.sql`` is
    normalized to symbolic paths before it is returned.

    Args:
        sample: One dataset sample as loaded from the raw JSONL file.

    Returns:
        ``(sample_id, is_valid, issues, validated_sample)``. The last element
        is the (possibly rewritten) sample when valid and ``None`` otherwise.
        An exception raised during validation is reported as an invalid result
        with a single ``"Validation error: ..."`` issue.
    """
    # Each worker creates its own DuckDB connection
    con = duckdb.connect()
    try:
        con.execute("SET enable_progress_bar=false")
        con.execute("INSTALL spatial")
        con.execute("LOAD spatial")

        try:
            is_valid, issues = validate_sample(con, sample)
            if is_valid:
                sample['target']['sql'] = _to_symbolic_sql(sample['target']['sql'])
            return (sample['id'], is_valid, issues, sample if is_valid else None)
        except Exception as e:
            return (sample['id'], False, [f"Validation error: {str(e)}"], None)
    finally:
        # Close on every path, including a failure while loading the extension.
        con.close()
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def compute_statistics(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Summarize a list of samples into JSON-serializable statistics.

    Counts samples per task family, SQL difficulty, grounding difficulty and
    anchor source; averages the number of candidates and the whitespace-split
    word count of the question; and lists the distinct ``country`` and
    ``subtype`` values found on the selected candidates only.

    Returns:
        A dict with plain ``dict`` tallies, numeric averages (``0`` for an
        empty input) and sorted lists for the covered countries and subtypes.
    """
    by_family = Counter()
    by_sql_difficulty = Counter()
    by_grounding_difficulty = Counter()
    by_anchor_source = Counter()
    countries = set()
    subtypes = set()
    candidate_total = 0
    word_total = 0

    for sample in samples:
        meta = sample['metadata']
        by_family[meta['task_family']] += 1
        by_sql_difficulty[meta['sql_difficulty']] += 1
        by_grounding_difficulty[meta['grounding_difficulty']] += 1
        by_anchor_source[meta['anchor_source']] += 1

        candidate_total += len(sample['candidates'])
        word_total += len(sample['question'].split())

        # Coverage is taken from the selected (answer) candidates only.
        selected_ids = set(sample.get('target', {}).get('selected_candidates', []))
        for cand in sample['candidates']:
            if cand['candidate_id'] not in selected_ids:
                continue
            if cand.get('country'):
                countries.add(cand['country'])
            if cand.get('subtype'):
                subtypes.add(cand['subtype'])

    sample_count = len(samples)
    return {
        'total_samples': sample_count,
        # Convert to plain dicts; first-seen key order is preserved.
        'task_families': dict(by_family),
        'sql_difficulty': dict(by_sql_difficulty),
        'grounding_difficulty': dict(by_grounding_difficulty),
        'anchor_sources': dict(by_anchor_source),
        'avg_candidates_per_sample': candidate_total / sample_count if samples else 0,
        'avg_question_length': word_total / sample_count if samples else 0,
        'countries_covered': sorted(countries),
        'subtypes_covered': sorted(subtypes),
    }
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def main():
    """Validate and analyze dataset.

    Reads ``output/dataset_raw.jsonl`` (relative to the parent of this
    script's directory), validates every sample in a pool of 8 worker
    processes, writes the valid samples to ``output/dataset_validated.jsonl``
    and the statistics to ``output/dataset_stats.json``, then prints a
    summary. Returns early with a message when the raw file is missing.
    """
    script_dir = Path(__file__).parent
    output_dir = script_dir.parent / "output"

    raw_file = output_dir / "dataset_raw.jsonl"
    validated_file = output_dir / "dataset_validated.jsonl"
    stats_file = output_dir / "dataset_stats.json"

    if not raw_file.exists():
        print(f"Error: {raw_file} not found. Run generate_samples.py first.")
        return

    # Load samples
    print("Loading samples...")
    samples = load_samples(raw_file)
    print(f"Loaded {len(samples)} samples")

    # Validate samples in parallel
    print("\nValidating samples in parallel...")

    # Futures finish in arbitrary order. Each result is stored at its input
    # index so the validated file keeps the order of the raw file and is
    # reproducible from run to run.
    outcomes = [None] * len(samples)

    with ProcessPoolExecutor(max_workers=8) as executor:
        # Submit all validation tasks, remembering each sample's position
        futures = {
            executor.submit(validate_sample_worker, sample): index
            for index, sample in enumerate(samples)
        }

        # Collect results as they complete
        completed = 0
        for future in as_completed(futures):
            outcomes[futures[future]] = future.result()

            completed += 1
            if completed % 50 == 0 or completed == len(samples):
                print(f"\r Progress: {completed}/{len(samples)} ", end='', flush=True)

    print()  # New line after progress

    valid_samples = []
    invalid_samples = []
    for sample_id, is_valid, issues, validated_sample in outcomes:
        if is_valid:
            valid_samples.append(validated_sample)
        else:
            invalid_samples.append((sample_id, issues))

    print("\nValidation results:")
    print(f" Valid: {len(valid_samples)}")
    print(f" Invalid: {len(invalid_samples)}")

    # At most the first 20 invalid samples are listed; only the heading differs.
    if invalid_samples:
        if len(invalid_samples) <= 20:
            print("\nInvalid samples:")
        else:
            print(f"\n{len(invalid_samples)} invalid samples (showing first 20):")
        for sample_id, issues in invalid_samples[:20]:
            print(f" {sample_id}: {', '.join(issues)}")

    # Save validated samples
    if valid_samples:
        with open(validated_file, 'w', encoding='utf-8') as f:
            for sample in valid_samples:
                f.write(json.dumps(sample) + '\n')
        print(f"\nSaved {len(valid_samples)} valid samples to {validated_file}")

    # Compute statistics
    print("\nComputing statistics...")
    stats = compute_statistics(valid_samples)

    # Save statistics
    # Convert sets to lists for JSON serialization
    stats_json = {k: (list(v) if isinstance(v, set) else v) for k, v in stats.items()}
    with open(stats_file, 'w', encoding='utf-8') as f:
        json.dump(stats_json, f, indent=2)
    print(f"Saved statistics to {stats_file}")

    # Print summary
    print("\n" + "=" * 60)
    print("DATASET STATISTICS")
    print("=" * 60)
    print(f"\nTotal samples: {stats['total_samples']}")

    print("\nTask families:")
    for family, count in sorted(stats['task_families'].items()):
        print(f" {family:20s}: {count:3d}")

    print("\nSQL difficulty:")
    for diff, count in sorted(stats['sql_difficulty'].items()):
        print(f" {diff:20s}: {count:3d}")

    print("\nGrounding difficulty:")
    for diff, count in sorted(stats['grounding_difficulty'].items()):
        print(f" {diff:20s}: {count:3d}")

    print("\nAnchor sources:")
    for src, count in sorted(stats['anchor_sources'].items()):
        print(f" {src:20s}: {count:3d}")

    print(f"\nAverage candidates per sample: {stats['avg_candidates_per_sample']:.1f}")
    print(f"Average question length (words): {stats['avg_question_length']:.1f}")
    print(f"Countries covered: {len(stats['countries_covered'])}")
    print(f"Subtypes covered: {len(stats['subtypes_covered'])}")

    print("\n✓ Validation complete")
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
# Script entry point: run the validation pipeline only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Three services: the llama.cpp model server, the gazet API, and the
# Streamlit demo that talks to the API.
services:
  # llama.cpp HTTP server listening on port 9000. No host port is published;
  # the app service reaches it by service name.
  llama:
    image: ghcr.io/ggml-org/llama.cpp:server
    volumes:
      # The GGUF file is bind-mounted read-only at the path passed to -m below.
      - ./finetune/models/qwen-base-run/ckpt-001.gguf:/models/model.gguf:ro
    command: >
      -m /models/model.gguf
      --port 9000
      --host 0.0.0.0
      --ctx-size 2048
      -t 4
    healthcheck:
      # assumes curl is available inside the image — TODO confirm
      test: ["CMD", "curl", "-f", "http://localhost:9000/health"]
      interval: 10s
      timeout: 5s
      retries: 30
      start_period: 30s

  # gazet API served by uvicorn on port 8000, published to the host.
  app:
    build: .
    volumes:
      # Local ./data is mounted read-only at the path named by GAZET_DATA_DIR.
      - ./data:/data:ro
    environment:
      GAZET_DATA_DIR: /data
      LLAMA_SERVER_URL: http://llama:9000
    ports:
      - "8000:8000"
    command: uvicorn gazet.api:app --host 0.0.0.0 --port 8000
    depends_on:
      # Start only after the llama healthcheck reports healthy.
      llama:
        condition: service_healthy

  # Streamlit demo on port 8501, pointed at the app service.
  demo:
    build: .
    environment:
      GAZET_API_URL: http://app:8000
    ports:
      - "8501:8501"
    command: streamlit run gazet_demo.py --server.port 8501 --server.address 0.0.0.0
    depends_on:
      # Short form: waits for app to be started, not for it to be healthy.
      - app
|
finetune/README.md
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Fine-tuning and Inference
|
| 2 |
+
|
| 3 |
+
LoRA fine-tuning of Qwen3.5-0.8B (via Unsloth) to perform two geospatial
|
| 4 |
+
tasks (text-to-SQL and place extraction), then serving locally via
|
| 5 |
+
llama-server.
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## End-to-end workflow
|
| 10 |
+
|
| 11 |
+
```
|
| 12 |
+
1. Generate dataset → dataset/ (see dataset/README.md)
|
| 13 |
+
2. Check token lengths → check_token_lengths.py
|
| 14 |
+
3. Train on Modal → train_modal_qwen35.py
|
| 15 |
+
4. Convert to GGUF → llama.cpp
|
| 16 |
+
5. Serve locally → llama-server
|
| 17 |
+
6. Eval locally → eval_cli.py (interactive or batch) + eval_demo.py
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
---
|
| 21 |
+
|
| 22 |
+
## Step 1 — Check token lengths
|
| 23 |
+
|
| 24 |
+
Before training, verify that your `max_length` setting covers the data.
|
| 25 |
+
SQL samples are long (schema + candidates + SQL), places samples are short.
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
modal run finetune/check_token_lengths.py
|
| 29 |
+
modal run finetune/check_token_lengths.py --run-dir /mnt/gazet/data/v1
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
This prints per-split statistics (min, max, P95, P99) and recommends a
|
| 33 |
+
`max_length` value. Adjust `--max-seq-length` in `train_modal_qwen35.py` accordingly.
|
| 34 |
+
|
| 35 |
+
---
|
| 36 |
+
|
| 37 |
+
## Step 2 — Train (Qwen3.5 + Unsloth)
|
| 38 |
+
|
| 39 |
+
Training runs on Modal with an A100-80GB GPU. The script loads both SQL and
|
| 40 |
+
places JSONL files from the run directory, applies the Qwen3.5 ChatML
|
| 41 |
+
template, and trains a LoRA adapter using Unsloth's
|
| 42 |
+
`train_on_responses_only` to mask non-assistant tokens.
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
# Default settings (Qwen3.5-0.8B, r=16, 1 epoch)
|
| 46 |
+
modal run finetune/train_modal_qwen35.py --experiment-name qwen35-v1
|
| 47 |
+
|
| 48 |
+
# Override any config field from CLI
|
| 49 |
+
modal run finetune/train_modal_qwen35.py \
|
| 50 |
+
--experiment-name qwen35-v1 \
|
| 51 |
+
--base-model unsloth/Qwen3.5-0.8B \
|
| 52 |
+
--num-train-epochs 3 \
|
| 53 |
+
--lora-r 32 \
|
| 54 |
+
--max-seq-length 2048
|
| 55 |
+
|
| 56 |
+
# Quick smoke test
|
| 57 |
+
modal run finetune/train_modal_qwen35.py --experiment-name qwen35-v1 --max-train-samples 100
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
All CLI overrides: `--base-model`, `--experiment-name`, `--run-dir`,
|
| 61 |
+
`--num-train-epochs`, `--per-device-train-batch-size`, `--max-train-samples`,
|
| 62 |
+
`--max-eval-samples`, `--lora-r`, `--max-seq-length`. When `--lora-r` is
|
| 63 |
+
overridden, `lora_alpha` is automatically set to `2 * r`.
|
| 64 |
+
|
| 65 |
+
### Training config defaults (`Qwen35Config`)
|
| 66 |
+
|
| 67 |
+
```
|
| 68 |
+
base_model: unsloth/Qwen3.5-0.8B
|
| 69 |
+
run_dir: /mnt/gazet/data/v1
|
| 70 |
+
lora_r: 16
|
| 71 |
+
lora_alpha: 32 (2 * r, Unsloth recommendation for Qwen)
|
| 72 |
+
lora_dropout: 0.0
|
| 73 |
+
num_train_epochs: 1
|
| 74 |
+
batch_size: 32 (x 1 gradient accumulation = 32 effective)
|
| 75 |
+
learning_rate: 1e-4
|
| 76 |
+
lr_scheduler: linear
|
| 77 |
+
optim: adamw_8bit
|
| 78 |
+
max_seq_length: 2048
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
### Output
|
| 82 |
+
|
| 83 |
+
Checkpoints and the merged model are saved to the Modal volume:
|
| 84 |
+
|
| 85 |
+
```
|
| 86 |
+
/mnt/gazet/checkpoints/{experiment_name}/
|
| 87 |
+
adapter_config.json # LoRA adapter
|
| 88 |
+
adapter_model.safetensors
|
| 89 |
+
checkpoint-*/ # intermediate checkpoints
|
| 90 |
+
merged/ # full merged 16-bit model
|
| 91 |
+
model.safetensors
|
| 92 |
+
tokenizer.json
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
Pass `--experiment-name` to set a human-readable name (e.g. `qwen35-v1`).
|
| 96 |
+
If omitted, it is auto-generated as `{model}-r{lora_r}-{timestamp}`.
|
| 97 |
+
|
| 98 |
+
Training metrics are logged to [trackio](https://huggingface.co/spaces/srmsoumya/gazet-trackio).
|
| 99 |
+
|
| 100 |
+
---
|
| 101 |
+
|
| 102 |
+
## Step 3 — Convert merged model to GGUF
|
| 103 |
+
|
| 104 |
+
After training, download the merged model from Modal and convert to GGUF
|
| 105 |
+
for local inference with llama-server.
|
| 106 |
+
|
| 107 |
+
```bash
|
| 108 |
+
# Download from Modal volume
|
| 109 |
+
modal volume get gazet checkpoints/qwen35-v1/merged ./finetune/models/qwen-base/merged
|
| 110 |
+
|
| 111 |
+
# Convert to GGUF (run from inside a llama.cpp checkout that sits next to this repo)
|
| 112 |
+
uv run \
|
| 113 |
+
--no-project \
|
| 114 |
+
--with transformers \
|
| 115 |
+
--with sentencepiece \
|
| 116 |
+
--with protobuf \
|
| 117 |
+
--with torch \
|
| 118 |
+
python convert_hf_to_gguf.py \
|
| 119 |
+
../gazet/finetune/models/qwen-base/merged \
|
| 120 |
+
--outtype q8_0 \
|
| 121 |
+
--outfile ../gazet/finetune/models/qwen-base/ckpt-001.gguf
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
---
|
| 125 |
+
|
| 126 |
+
## Step 4 — Serve with llama-server
|
| 127 |
+
|
| 128 |
+
### Local
|
| 129 |
+
|
| 130 |
+
```bash
|
| 131 |
+
llama-server \
|
| 132 |
+
-m finetune/models/qwen-base/ckpt-001.gguf \
|
| 133 |
+
-ngl 99 \
|
| 134 |
+
--port 9000 \
|
| 135 |
+
--ctx-size 2048
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
`--ctx-size` is the total KV cache shared across all parallel slots. SQL
|
| 139 |
+
prompts can be ~600 tokens; with `--parallel 4` and up to 2048 output
|
| 140 |
+
tokens, use at least `8192`. Match `--parallel` to `--workers` in
|
| 141 |
+
`eval_cli.py`.
|
| 142 |
+
|
| 143 |
+
### Docker (CPU-only)
|
| 144 |
+
|
| 145 |
+
Useful for testing inference in a constrained environment. Adjust `--cpus`
|
| 146 |
+
and `--memory` to simulate deployment targets. Set `-t` to match `--cpus`.
|
| 147 |
+
|
| 148 |
+
```bash
|
| 149 |
+
docker run \
|
| 150 |
+
--cpus="2" --memory="4g" \
|
| 151 |
+
-v $(pwd)/finetune/models:/models \
|
| 152 |
+
-p 9000:9000 \
|
| 153 |
+
ghcr.io/ggml-org/llama.cpp:server \
|
| 154 |
+
-m /models/qwen-base/ckpt-001.gguf \
|
| 155 |
+
--port 9000 --host 0.0.0.0 \
|
| 156 |
+
--ctx-size 2048 -t 2 -v
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
Notes:
|
| 160 |
+
- `--host 0.0.0.0` is required so the port forward from Docker works
|
| 161 |
+
- `-v` (verbose) enables per-request timing logs (prompt eval t/s, generation t/s)
|
| 162 |
+
- `-ngl` is omitted since the default Docker image is CPU-only; for GPU use
|
| 163 |
+
the CUDA image (`ghcr.io/ggml-org/llama.cpp:server-cuda`) with `--gpus`
|
| 164 |
+
- The model is memory-mapped by default (`mmap = true`), so containers with
|
| 165 |
+
less RAM than the model size may still start but will be slow due to page
|
| 166 |
+
thrashing
|
| 167 |
+
|
| 168 |
+
The server exposes `/v1/chat/completions` (chat API) on
|
| 169 |
+
`http://localhost:9000`. All eval scripts use this endpoint.
|
| 170 |
+
|
| 171 |
+
---
|
| 172 |
+
|
| 173 |
+
## Step 5 — Evaluate
|
| 174 |
+
|
| 175 |
+
Two evaluation tools, both using a locally running llama-server.
|
| 176 |
+
|
| 177 |
+
### Interactive or batch eval (`eval_cli.py`)
|
| 178 |
+
|
| 179 |
+
Requires llama-server running on port 9000 (see Step 4).
|
| 180 |
+
|
| 181 |
+
**Interactive** — spot-check individual samples:
|
| 182 |
+
|
| 183 |
+
```bash
|
| 184 |
+
uv run finetune/eval_cli.py # prompts for sample index
|
| 185 |
+
uv run finetune/eval_cli.py 0 5 12 # run specific samples
|
| 186 |
+
uv run finetune/eval_cli.py --task places 0 5
|
| 187 |
+
uv run finetune/eval_cli.py -v 0 # print full prompt
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
**Batch** — run the full split and save a JSON results file:
|
| 191 |
+
|
| 192 |
+
```bash
|
| 193 |
+
# Full val set, SQL task
|
| 194 |
+
uv run finetune/eval_cli.py --all --label finetuned-qwen35
|
| 195 |
+
|
| 196 |
+
# Places task
|
| 197 |
+
uv run finetune/eval_cli.py --all --task places --label finetuned-places
|
| 198 |
+
|
| 199 |
+
# Limit samples, custom output path
|
| 200 |
+
uv run finetune/eval_cli.py --all --max-samples 100 --output results/eval-v5.json
|
| 201 |
+
|
| 202 |
+
# Evaluate test split instead of val
|
| 203 |
+
uv run finetune/eval_cli.py --all --split test --label finetuned-qwen35
|
| 204 |
+
```
|
| 205 |
+
|
| 206 |
+
All batch CLI args:
|
| 207 |
+
|
| 208 |
+
| Arg | Default | Description |
|
| 209 |
+
|-----|---------|-------------|
|
| 210 |
+
| `--all` | off | Enable batch mode |
|
| 211 |
+
| `--label` | `local-gguf` | Label used in the output filename |
|
| 212 |
+
| `--task` | `sql` | `sql` or `places` |
|
| 213 |
+
| `--split` | `val` | Data split to evaluate (`val`, `test`) |
|
| 214 |
+
| `--run-dir` | `dataset/output/runs/v1` | Directory with `{task}/{split}.jsonl` |
|
| 215 |
+
| `--max-samples` | all | Cap the number of samples |
|
| 216 |
+
| `--output` | `results/eval-{label}-{task}.json` | Output JSON path |
|
| 217 |
+
| `--workers` | `4` | Concurrent requests; match llama-server `--parallel` |
|
| 218 |
+
|
| 219 |
+
Results are saved to `results/eval-{label}-{task}.json` with this structure:
|
| 220 |
+
|
| 221 |
+
```json
|
| 222 |
+
{
|
| 223 |
+
"summary": {"label": "...", "task": "sql", "exact_match_rate": 0.85, ...},
|
| 224 |
+
"results": [
|
| 225 |
+
{"index": 0, "question": "...", "expected": "...", "predicted": "...", "exact_match": true},
|
| 226 |
+
...
|
| 227 |
+
]
|
| 228 |
+
}
|
| 229 |
+
```
|
| 230 |
+
|
| 231 |
+
Config constants at the top of `eval_cli.py`: `SERVER_URL` (default
|
| 232 |
+
`http://localhost:9000`), `MAX_TOKENS` (2048), `TEMPERATURE` (0.6).
|
| 233 |
+
|
| 234 |
+
### Visual eval (`eval_demo.py`)
|
| 235 |
+
|
| 236 |
+
Streamlit app that loads JSON results from `eval_cli.py --all` and displays
|
| 237 |
+
them interactively. For SQL results, it shows formatted SQL side-by-side,
|
| 238 |
+
a diff view for mismatches, and executes both queries against DuckDB to
|
| 239 |
+
render the geometry on a map. For places results, it shows expected vs
|
| 240 |
+
predicted JSON.
|
| 241 |
+
|
| 242 |
+
```bash
|
| 243 |
+
streamlit run finetune/eval_demo.py
|
| 244 |
+
```
|
| 245 |
+
|
| 246 |
+
Reads result files from `results/eval-*.json` by default. Override with:
|
| 247 |
+
|
| 248 |
+
```bash
|
| 249 |
+
GAZET_EVAL_DIR=/path/to/results streamlit run finetune/eval_demo.py
|
| 250 |
+
```
|
| 251 |
+
|
| 252 |
+
Set `GAZET_DATA_DIR` if your parquet data is not in the default `data/` directory.
|
| 253 |
+
|
| 254 |
+
---
|
| 255 |
+
|
| 256 |
+
## File reference
|
| 257 |
+
|
| 258 |
+
| File | What it does |
|
| 259 |
+
|---|---|
|
| 260 |
+
| `train_modal_qwen35.py` | Modal training script — Qwen3.5 LoRA fine-tuning with Unsloth |
|
| 261 |
+
| `check_token_lengths.py` | Modal script to analyze token length distribution before training |
|
| 262 |
+
| `eval_cli.py` | Local eval — interactive spot-check or full batch mode via llama-server |
|
| 263 |
+
| `eval_demo.py` | Streamlit app — visual diff + map rendering of `eval_cli.py --all` results |
|
| 264 |
+
| `models/` | GGUF model files for local llama-server inference |
|
| 265 |
+
|
| 266 |
+
---
|
| 267 |
+
|
| 268 |
+
## Data format
|
| 269 |
+
|
| 270 |
+
The Qwen3.5 training pipeline (`train_modal_qwen35.py`) expects data in
|
| 271 |
+
**messages format**:
|
| 272 |
+
|
| 273 |
+
```json
|
| 274 |
+
{
|
| 275 |
+
"messages": [
|
| 276 |
+
{"role": "system", "content": "You are a text to SQL query translator..."},
|
| 277 |
+
{"role": "user", "content": "GIVEN the <SCHEMA_DETAILS>..."},
|
| 278 |
+
{"role": "assistant", "content": "SELECT ST_AsGeoJSON(geometry) ..."}
|
| 279 |
+
]
|
| 280 |
+
}
|
| 281 |
+
```
|
| 282 |
+
|
| 283 |
+
The Qwen3.5 chat template (ChatML) is applied by the tokenizer. Unsloth's
|
| 284 |
+
`train_on_responses_only` then masks everything before the assistant
|
| 285 |
+
response marker (`<|im_start|>assistant\n<think>\n\n</think>\n\n`), so
|
| 286 |
+
loss is computed only on the completion tokens.
|
| 287 |
+
|
| 288 |
+
SQL in the training data uses symbolic path placeholders
|
| 289 |
+
(`read_parquet('divisions_area')`) instead of real file paths. At inference
|
| 290 |
+
time, `src/gazet/sql.py` replaces these with actual runtime paths before
|
| 291 |
+
executing against DuckDB.
|
finetune/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
finetune/check_token_lengths.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Check token lengths of training samples to validate max_length setting.
|
| 2 |
+
|
| 3 |
+
Usage
|
| 4 |
+
-----
|
| 5 |
+
modal run finetune/check_token_lengths.py
|
| 6 |
+
modal run finetune/check_token_lengths.py --run-dir /mnt/gazet/data/v1
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import modal
|
| 12 |
+
|
| 13 |
+
# Modal application object; the functions below are registered on it.
app = modal.App("gazet-check-token-lengths")

# Container image with the packages needed to load a tokenizer and render
# its chat template.
check_image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install(
        "datasets>=3.0",
        "transformers>=4.46",
        "jinja2>=3.1",
    )
    .add_local_python_source("finetune", copy=True)
    # Hugging Face cache directory is placed under the /mnt/gazet mount point.
    .env({"HF_HOME": "/mnt/gazet/model_cache"})
)

# Named Modal volume, created on first use if it does not exist yet.
gazet_vol = modal.Volume.from_name("gazet", create_if_missing=True)

# Mount point -> volume mapping handed to the function decorator.
VOLUMES = {
    "/mnt/gazet": gazet_vol,
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@app.function(
    image=check_image,
    volumes=VOLUMES,
    secrets=[modal.Secret.from_name("huggingface-secret")],
)
def analyze_token_lengths(run_dir: str, base_model: str):
    """Print token-length statistics for the training samples under *run_dir*.

    Rows are read from ``<run_dir>/{sql,places}/{train,val,test}.jsonl`` (files
    that do not exist are skipped). Each row's ``messages`` list is rendered
    with the chat template of *base_model*'s tokenizer and tokenized; the
    resulting lengths are summarised per split, followed by an overall
    ``--max-length`` recommendation. Everything is written to stdout.

    Args:
        run_dir: Directory holding the per-task JSONL files.
        base_model: Model id passed to ``AutoTokenizer.from_pretrained``.
    """
    import json
    import pathlib

    from datasets import Dataset, DatasetDict
    from transformers import AutoTokenizer

    def load_jsonl(path):
        """Return the parsed rows of a JSONL file, ignoring blank lines."""
        with open(path, encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]

    print(f"Loading tokenizer: {base_model}")
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    root = pathlib.Path(run_dir)
    ds_dict = {}
    for split in ("train", "val", "test"):
        combined = []
        for task in ("sql", "places"):
            path = root / task / f"{split}.jsonl"
            if path.exists():
                combined.extend(load_jsonl(path))
        if combined:
            ds_dict[split] = Dataset.from_list(combined)
    ds = DatasetDict(ds_dict)

    if not ds_dict:
        # Make a wrong --run-dir visible instead of finishing silently.
        print(f"No samples found under {root}; nothing to analyze.")
        return

    def token_lengths(dataset):
        """Return the token count of every chat-templated row in *dataset*."""
        lengths = []
        for row in dataset:
            text = tokenizer.apply_chat_template(row["messages"], tokenize=False)
            # The rendered template already contains the special tokens, so
            # the tokenizer must not add another set when encoding it.
            lengths.append(len(tokenizer.encode(text, add_special_tokens=False)))
        return lengths

    def report(split_name: str, lengths: list[int]):
        """Print summary statistics and a bucketed histogram for one split."""
        # Sort a copy so the caller's list is left untouched.
        lengths = sorted(lengths)
        n = len(lengths)
        if not n:
            print(f"\n{split_name}: empty")
            return

        print(f"\n{'='*60}")
        print(f"{split_name} ({n:,} samples)")
        print(f"{'='*60}")
        print(f" Min: {min(lengths):,}")
        print(f" Max: {max(lengths):,}")
        print(f" Mean: {sum(lengths)/n:.0f}")
        print(f" Median: {lengths[n//2]:,}")
        print(f" P90: {lengths[int(n*0.90)]:,}")
        print(f" P95: {lengths[int(n*0.95)]:,}")
        print(f" P99: {lengths[int(n*0.99)]:,}")

        buckets = [512, 1024, 2048, 4096, 8192]
        print("\n Distribution:")
        prev = 0
        for limit in buckets:
            count = sum(1 for l in lengths if prev < l <= limit)
            pct = 100 * count / n
            bar = "#" * int(pct / 2)
            print(f" {prev+1:>5}-{limit:<5}: {count:5,} ({pct:5.1f}%) {bar}")
            prev = limit
        over = sum(1 for l in lengths if l > buckets[-1])
        if over:
            print(f" {buckets[-1]+1:>5}+ : {over:5,} ({100*over/n:5.1f}%)")

        return lengths

    all_lengths = []
    for split in ("train", "val", "test"):
        if split not in ds:
            continue
        lengths = token_lengths(ds[split])
        report(split, lengths)
        all_lengths.extend(lengths)

    if all_lengths:
        all_lengths.sort()
        n = len(all_lengths)
        max_len = max(all_lengths)
        p99 = all_lengths[int(n * 0.99)]

        print(f"\n{'='*60}")
        print("RECOMMENDATION")
        print(f"{'='*60}")
        print(f" Total samples: {n:,}")
        print(f" Max length: {max_len:,}")
        print(f" P99: {p99:,}")

        for threshold in [1024, 2048, 4096]:
            over = sum(1 for l in all_lengths if l > threshold)
            pct = 100 * over / n
            print(f" > {threshold:5,}: {over:5,} ({pct:5.1f}%)")

        if max_len <= 1024:
            print("\n All samples fit in 1024 tokens. Use --max-length 1024.")
        elif max_len <= 2048:
            print("\n All samples fit in 2048 tokens. Use --max-length 2048.")
        else:
            over_2048 = sum(1 for l in all_lengths if l > 2048)
            print(f"\n {over_2048} samples exceed 2048. Consider --max-length {max_len}")
            print(" or reduce candidate count to keep samples shorter.")
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@app.local_entrypoint()
def main(
    run_dir: str = "/mnt/gazet/data/v1",
    base_model: str = "unsloth/Qwen3.5-0.8B",
):
    """Local entry point: echo the settings, then run the analysis remotely.

    Args:
        run_dir: Dataset run directory forwarded to the remote function.
        base_model: Tokenizer/model id forwarded to the remote function.
    """
    banner = (
        "Checking token lengths:",
        f" Model: {base_model}",
        f" Run dir: {run_dir}",
    )
    for line in banner:
        print(line)

    analyze_token_lengths.remote(run_dir, base_model)
    print("Analysis complete!")
|
finetune/eval_cli.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Interactive eval: run test samples through the local GGUF model.
|
| 2 |
+
|
| 3 |
+
Requires llama-server running on port 9000 (must match SERVER_URL below):
|
| 4 |
+
llama-server -m finetune/models/<model>.gguf -ngl 99 --port 9000 --ctx-size 4096 --log-disable
|
| 5 |
+
|
| 6 |
+
Uses the /v1/chat/completions endpoint with a messages list. The Qwen3 GGUF
|
| 7 |
+
embeds its chat template in metadata, so llama-server applies it automatically.
|
| 8 |
+
|
| 9 |
+
Usage
|
| 10 |
+
-----
|
| 11 |
+
uv run finetune/eval_cli.py # prompts for index
|
| 12 |
+
uv run finetune/eval_cli.py 5 # run sample at index 5
|
| 13 |
+
uv run finetune/eval_cli.py 5 12 20 # run multiple samples
|
| 14 |
+
|
| 15 |
+
Use --task places for place extraction:
|
| 16 |
+
uv run finetune/eval_cli.py --task places 0 5
|
| 17 |
+
|
| 18 |
+
Override run directory:
|
| 19 |
+
uv run finetune/eval_cli.py --run-dir dataset/output/runs/v1 0
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
import argparse
|
| 25 |
+
import json
|
| 26 |
+
import sys
|
| 27 |
+
import urllib.error
|
| 28 |
+
import urllib.request
|
| 29 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 30 |
+
from datetime import datetime
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
|
| 33 |
+
# Base URL of the llama-server instance this script talks to.
SERVER_URL = "http://localhost:9000"
# Sent as "n_predict" with every chat-completion request.
MAX_TOKENS = 2048
# Sampling temperature sent with every chat-completion request.
TEMPERATURE = 0.6

# Default --run-dir: directory expected to contain {task}/{split}.jsonl files.
DEFAULT_RUN_DIR = Path("dataset/output/runs/v1")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def postprocess_sql(text: str) -> str:
    """Strip Markdown code fences from a model reply and return the bare SQL.

    Everything up to and including the first ```` ```sql ```` marker is
    dropped, a leading bare fence is removed, and the text is cut at the next
    fence. Surrounding whitespace is trimmed.
    """
    fence = "```"
    body = text.strip()

    _, opener, remainder = body.partition("```sql")
    if opener:
        body = remainder
    if body.startswith(fence):
        body = body[len(fence):]

    # partition() returns the whole string when no closing fence exists.
    return body.partition(fence)[0].strip()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def check_server() -> bool:
    """Return True when ``GET {SERVER_URL}/health`` succeeds within 2 seconds.

    This is a best-effort probe: any failure (connection refused, timeout,
    HTTP error status, ...) is reported as False instead of being raised.
    """
    try:
        # The context manager closes the HTTP response; previously the
        # response object was dropped without being closed.
        with urllib.request.urlopen(f"{SERVER_URL}/health", timeout=2):
            return True
    except Exception:  # deliberate: "not reachable" for whatever reason
        return False
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def chat_complete(messages: list[dict]) -> str:
    """POST *messages* to llama-server's /v1/chat/completions endpoint.

    Returns the ``content`` of the first choice's message. The request asks
    the chat template to disable thinking and times out after 60 seconds.
    """
    body = {
        "messages": messages,
        "n_predict": MAX_TOKENS,
        "temperature": TEMPERATURE,
        "chat_template_kwargs": {"enable_thinking": False},
    }
    request = urllib.request.Request(
        f"{SERVER_URL}/v1/chat/completions",
        data=json.dumps(body).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=60) as response:
        reply = json.loads(response.read())
        return reply["choices"][0]["message"]["content"]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def load_samples(run_dir: Path, task: str, split: str = "val") -> list[dict]:
    """Load ``<run_dir>/<task>/<split>.jsonl`` as a list of dicts.

    Blank lines are skipped. If the file does not exist, an error is printed
    and the process exits with status 1.
    """
    source = run_dir / task / f"{split}.jsonl"
    if not source.exists():
        print(f"Error: {source} not found")
        sys.exit(1)

    print(f"Loading {task} samples from: {source}")
    records = []
    with source.open() as handle:
        for raw_line in handle:
            if raw_line.strip():
                records.append(json.loads(raw_line))
    return records
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def build_raw_prompt(sample: dict) -> str:
    """Join the content of every message except the last with blank lines.

    The final message is the assistant reference answer, so the result is
    the plain-text prompt portion of the sample.
    """
    prompt_turns = sample["messages"][:-1]
    contents = [turn["content"] for turn in prompt_turns]
    return "\n\n".join(contents)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def eval_sample(sample: dict, task: str) -> dict:
    """Send one sample to the server and compare the reply with the reference.

    The last message is the expected answer; all preceding messages form the
    prompt. For the ``sql`` task the reply is passed through
    ``postprocess_sql``; otherwise it is only stripped.

    Returns:
        Dict with ``question``, ``expected``, ``predicted`` and ``exact_match``.
    """
    turns = sample["messages"]
    expected = turns[-1]["content"]
    prompt_messages = turns[:-1]

    # Prefer the text inside the last <USER_QUERY> tag; otherwise fall back
    # to the first 120 characters of the user turn.
    user_content = turns[-2]["content"]
    if "<USER_QUERY>" not in user_content:
        question = user_content[:120]
    else:
        tail = user_content.rpartition("<USER_QUERY>")[2]
        question = tail.partition("</USER_QUERY>")[0].strip()

    raw = chat_complete(prompt_messages)
    if task == "sql":
        predicted = postprocess_sql(raw)
    else:
        predicted = raw.strip()

    is_match = predicted.strip() == expected.strip()
    return {
        "question": question,
        "expected": expected,
        "predicted": predicted,
        "exact_match": is_match,
    }
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def run_sample(sample: dict, task: str, total: int, index: int, verbose: bool = False) -> None:
    """Evaluate one sample and pretty-print question, expected and generated text.

    Args:
        sample: Dataset row with a ``messages`` list.
        task: Task name forwarded to ``eval_sample``.
        total: Number of samples in the split (used only in the header).
        index: Index of this sample (used only in the header).
        verbose: When true, also print the full reconstructed prompt.
    """
    heavy_rule = "━" * 60
    light_rule = "─" * 60

    user_content = sample["messages"][-2]["content"]
    if "<USER_QUERY>" not in user_content:
        question = user_content[:120]
    else:
        tail = user_content.rpartition("<USER_QUERY>")[2]
        question = tail.partition("</USER_QUERY>")[0].strip()

    header = f" Sample {index}/{total-1} | {task} "
    side = "━" * ((60 - len(header)) // 2)
    print(f"\n{heavy_rule}")
    print(f"{side}{header}{side}")
    print(heavy_rule)
    print(f"\nQuestion: {question}\n")

    if verbose:
        prompt = build_raw_prompt(sample)
        print(light_rule)
        print(f"Full prompt ({len(prompt)} chars, ~{len(prompt.split())} words):")
        print(light_rule)
        print(prompt)

    result = eval_sample(sample, task)

    print(light_rule)
    print("Expected:")
    print(light_rule)
    print(result["expected"])

    print(f"\n{light_rule}")
    print("Generated:")
    print(light_rule)
    print(result["predicted"])

    verdict = "YES" if result["exact_match"] else "NO"
    print(f"\n{light_rule}")
    print(f"Match: {verdict}")
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def _failure_record(sample: dict, exc: BaseException) -> dict:
    """Build a result entry for a sample whose evaluation raised *exc*."""
    try:
        turns = sample["messages"]
        expected = turns[-1]["content"]
        question = turns[-2]["content"][:120]
    except (KeyError, IndexError, TypeError):
        # The sample itself is malformed; record what we can.
        expected, question = "", ""
    return {
        "question": question,
        "expected": expected,
        "predicted": "",
        "exact_match": False,
        "error": f"{type(exc).__name__}: {exc}",
    }


def run_batch(
    samples: list[dict],
    task: str,
    label: str,
    output_path: Path,
    workers: int = 8,
) -> None:
    """Run all samples concurrently and save results to a JSON file.

    A sample whose request fails (timeout, connection error, malformed reply)
    is recorded as a non-match with an ``error`` field instead of aborting the
    whole batch and losing the results already collected.

    Args:
        samples: Dataset rows to evaluate.
        task: Task name forwarded to ``eval_sample``.
        label: Free-form label stored in the summary.
        output_path: Destination JSON file; parent directories are created.
        workers: Size of the thread pool issuing requests.
    """
    total = len(samples)
    results = [None] * total
    completed = 0
    failures = 0

    with ThreadPoolExecutor(max_workers=workers) as executor:
        futures = {executor.submit(eval_sample, s, task): i for i, s in enumerate(samples)}
        for future in as_completed(futures):
            i = futures[future]
            try:
                result = future.result()
            except Exception as exc:  # keep the batch alive; record the failure
                failures += 1
                result = _failure_record(samples[i], exc)
            # Store by original index so output order matches input order.
            results[i] = {"index": i, **result}
            completed += 1
            if completed % 50 == 0 or completed == total:
                print(f"{completed}/{total} done", flush=True)

    matches = sum(1 for r in results if r["exact_match"])
    exact_match_rate = matches / total if total else 0

    output = {
        "summary": {
            "label": label,
            "task": task,
            "num_samples": total,
            "exact_matches": matches,
            "exact_match_rate": exact_match_rate,
            "timestamp": datetime.now().isoformat(),
        },
        "results": results,
    }

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w") as f:
        json.dump(output, f, indent=2)

    print(f"\n{'=' * 60}")
    print(f"[{label}] {matches}/{total} exact matches ({100 * exact_match_rate:.1f}%)")
    if failures:
        print(f"{failures} sample(s) failed to evaluate; see the 'error' field in the results")
    print(f"Results saved to {output_path}")
    print(f"{'=' * 60}")
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def main() -> None:
    """Command-line entry point.

    Verifies that llama-server is reachable, loads the requested split, and
    then either runs the whole split in batch mode (``--all``) or evaluates
    the given indices one by one. Without indices the user is prompted for a
    single index.
    """
    parser = argparse.ArgumentParser(description="Interactive eval against llama-server")
    parser.add_argument("indices", nargs="*", type=int, help="Sample indices to evaluate")
    parser.add_argument("--task", default="sql", choices=["sql", "places"])
    parser.add_argument(
        "--run-dir",
        type=Path,
        default=DEFAULT_RUN_DIR,
        help="Run directory containing {task}/{split}.jsonl files",
    )
    parser.add_argument("--split", default="val", choices=["val", "test"], help="Dataset split")
    parser.add_argument("--verbose", "-v", action="store_true", help="Print full prompt sent to the model")
    parser.add_argument("--all", dest="run_all", action="store_true", help="Run all samples in batch mode")
    parser.add_argument("--max-samples", type=int, default=None, help="Limit number of samples (batch mode)")
    parser.add_argument("--label", default="local-gguf", help="Label for batch output file")
    parser.add_argument("--output", type=Path, default=None, help="Output JSON path (batch mode)")
    parser.add_argument("--workers", type=int, default=4, help="Concurrent requests; match llama-server --parallel (default 4)")
    args = parser.parse_args()

    if not check_server():
        print("llama-server not running. Start it with:")
        print("llama-server -m finetune/models/<model>.gguf -ngl 99 --port 9000 --ctx-size 2048 --log-disable")
        sys.exit(1)

    samples = load_samples(args.run_dir, args.task, args.split)
    total = len(samples)

    if args.run_all:
        if args.max_samples:
            samples = samples[: args.max_samples]
        output_path = args.output or Path(f"eval-{args.label}-{args.task}.json")
        print(f"Running batch eval: {len(samples)} samples, {args.workers} workers")
        run_batch(samples, args.task, args.label, output_path, workers=args.workers)
        return

    if total == 0:
        # Avoid prompting for an index in the nonsensical range "0--1".
        print(f"No {args.task} samples found in the '{args.split}' split.")
        return

    if not args.indices:
        # Name the split actually loaded (the default is "val", not "test").
        print(f"'{args.split}' split has {total} {args.task} samples (0-{total-1})")
        raw = input("Enter index (or press Enter for 0): ").strip()
        try:
            indices = [int(raw) if raw else 0]
        except ValueError:
            print(f"Invalid index {raw!r}: expected an integer")
            sys.exit(1)
    else:
        indices = args.indices

    for idx in indices:
        if not (0 <= idx < total):
            print(f"Index {idx} out of range (0-{total-1}), skipping")
            continue
        run_sample(samples[idx], args.task, total, idx, verbose=args.verbose)

    print(f"\n{'━' * 60}\n")


if __name__ == "__main__":
    main()
|
finetune/eval_demo.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Streamlit eval viewer: compare expected vs predicted SQL and view results on a map.
|
| 2 |
+
|
| 3 |
+
Usage: streamlit run finetune/eval_demo.py
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import difflib
|
| 7 |
+
import json
|
| 8 |
+
import math
|
| 9 |
+
import os
|
| 10 |
+
import pathlib
|
| 11 |
+
|
| 12 |
+
import duckdb
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pandas as pd
|
| 15 |
+
import pydeck as pdk
|
| 16 |
+
import sqlparse
|
| 17 |
+
import streamlit as st
|
| 18 |
+
|
| 19 |
+
# Parent of the directory that contains this file.
PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
# Data directory; the GAZET_DATA_DIR environment variable overrides <root>/data.
DATA_DIR = pathlib.Path(
    os.environ.get("GAZET_DATA_DIR", str(PROJECT_ROOT / "data"))
)
# Directory scanned for eval result files; GAZET_EVAL_DIR overrides <root>/results.
EVAL_DIR = pathlib.Path(
    os.environ.get("GAZET_EVAL_DIR", str(PROJECT_ROOT / "results"))
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def load_eval_results(path):
    """Read the JSON document stored at *path* and return the parsed object."""
    raw_text = pathlib.Path(path).read_text()
    return json.loads(raw_text)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def rewrite_data_paths(sql):
    """Substitute symbolic and legacy parquet locations with paths under DATA_DIR.

    Replacements are applied in order: the legacy ``/data/`` prefix first,
    then the ``divisions_area`` and ``natural_earth`` placeholders.
    """
    overture_glob = DATA_DIR / "overture" / "divisions_area" / "*.parquet"
    natural_earth_file = DATA_DIR / "natural_earth_geoparquet" / "ne_geography.parquet"

    # Order matters: the expanded paths may themselves contain "/data/", so
    # the legacy prefix has to be rewritten before they are inserted.
    substitutions = (
        ("/data/", f"{DATA_DIR}/"),
        ("read_parquet('divisions_area')", f"read_parquet('{overture_glob}')"),
        ("read_parquet('natural_earth')", f"read_parquet('{natural_earth_file}')"),
    )
    for needle, replacement in substitutions:
        sql = sql.replace(needle, replacement)
    return sql
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def format_sql(sql):
    """Return *sql* re-indented by sqlparse with keywords upper-cased."""
    options = {"reindent": True, "keyword_case": "upper"}
    return sqlparse.format(sql, **options)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def sql_diff_html(expected, predicted):
    """Build an HTML side-by-side diff table of the two formatted SQL strings."""
    left = format_sql(expected).splitlines()
    right = format_sql(predicted).splitlines()
    differ = difflib.HtmlDiff(tabsize=2, wrapcolumn=80)
    table = differ.make_table(
        left,
        right,
        fromdesc="Expected",
        todesc="Predicted",
        context=False,
    )
    return table
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_duckdb_connection():
    """Open a new DuckDB connection with the spatial extension installed and loaded."""
    connection = duckdb.connect()
    for statement in ("INSTALL spatial", "LOAD spatial"):
        connection.execute(statement)
    return connection
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def execute_sql(con, sql):
    """Execute *sql* and return a DataFrame with geometries as GeoJSON text.

    The query is first bound with ``con.sql`` to discover its output columns.
    It is then wrapped in an outer SELECT that passes every column through
    unchanged, except columns whose type name contains ``GEOMETRY``: those are
    simplified (tolerance 0.001) and serialized with ``ST_AsGeoJSON``.

    Args:
        con: Connection exposing ``sql`` and ``execute``.
        sql: A single query. Trailing semicolons and whitespace are tolerated.

    Returns:
        The result of ``fetchdf()`` on the wrapped query.
    """
    # A trailing ';' is fine for a standalone statement but is a syntax error
    # once the statement is embedded as a subquery, so drop it up front.
    inner = sql.rstrip("; \t\r\n")
    rel = con.sql(inner)

    select_parts = []
    for col, dtype in zip(rel.columns, rel.dtypes):
        # Double any embedded quote so the identifier stays well-formed.
        ident = '"' + str(col).replace('"', '""') + '"'
        if "GEOMETRY" in str(dtype).upper():
            select_parts.append(
                f"ST_AsGeoJSON(ST_SimplifyPreserveTopology({ident}, 0.001)) AS {ident}"
            )
        else:
            select_parts.append(ident)

    # The closing parenthesis sits on its own line so that a trailing "--"
    # comment in the inner query cannot comment it out.
    wrapped = f"SELECT {', '.join(select_parts)} FROM (\n{inner}\n)"
    return con.execute(wrapped).fetchdf()
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _is_notna(val):
|
| 88 |
+
"""Check if a value is not NA, handling arrays/lists/numpy arrays safely."""
|
| 89 |
+
if isinstance(val, (list, tuple, np.ndarray)):
|
| 90 |
+
return len(val) > 0
|
| 91 |
+
return pd.notna(val)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _to_python(val):
|
| 95 |
+
"""Convert numpy/pandas types to native Python for JSON serialization."""
|
| 96 |
+
if isinstance(val, (np.integer,)):
|
| 97 |
+
return int(val)
|
| 98 |
+
if isinstance(val, (np.floating,)):
|
| 99 |
+
return float(val)
|
| 100 |
+
if isinstance(val, np.ndarray):
|
| 101 |
+
return val.tolist()
|
| 102 |
+
if isinstance(val, (np.bool_,)):
|
| 103 |
+
return bool(val)
|
| 104 |
+
return val
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def to_feature_collection(result_df):
    """Convert a DataFrame into a GeoJSON FeatureCollection dict.

    A column is treated as geometry when the string values among its first
    five rows all start with ``{"type":``. The first such column supplies
    each feature's geometry; every non-geometry column with a present value
    becomes a property.
    """

    def looks_like_geojson(column):
        samples = [v for v in result_df[column].head(5) if isinstance(v, str)]
        return bool(samples) and all(s.lstrip().startswith('{"type":') for s in samples)

    geom_cols = [c for c in result_df.columns if looks_like_geojson(c)]
    prop_cols = [c for c in result_df.columns if c not in geom_cols]

    features = []
    for _, row in result_df.iterrows():
        geometry = None
        if geom_cols:
            raw = row[geom_cols[0]]
            if raw and isinstance(raw, str):
                geometry = json.loads(raw)

        properties = {}
        for name in prop_cols:
            value = row[name]
            if _is_notna(value):
                properties[name] = _to_python(value)

        feature = {"type": "Feature", "geometry": geometry, "properties": properties}
        features.append(feature)

    return {"type": "FeatureCollection", "features": features}
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def bbox_from_geojson(geojson):
    """Return ``(min_lng, min_lat, max_lng, max_lat)`` over all feature coordinates.

    Features without a geometry are ignored; ``None`` is returned when no
    coordinate is found.
    """
    geometries = (feature.get("geometry") for feature in geojson.get("features", []))
    points = [coord for geom in geometries if geom for coord in _extract_coords(geom)]
    if not points:
        return None

    lngs = [point[0] for point in points]
    lats = [point[1] for point in points]
    return min(lngs), min(lats), max(lngs), max(lats)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _extract_coords(geom):
|
| 148 |
+
t = geom.get("type", "")
|
| 149 |
+
coords = geom.get("coordinates", [])
|
| 150 |
+
if t == "Point":
|
| 151 |
+
yield coords
|
| 152 |
+
elif t in ("LineString", "MultiPoint"):
|
| 153 |
+
yield from coords
|
| 154 |
+
elif t == "Polygon":
|
| 155 |
+
for ring in coords:
|
| 156 |
+
yield from ring
|
| 157 |
+
elif t in ("MultiLineString", "MultiPolygon"):
|
| 158 |
+
for part in coords:
|
| 159 |
+
if t == "MultiLineString":
|
| 160 |
+
yield from part
|
| 161 |
+
else:
|
| 162 |
+
for ring in part:
|
| 163 |
+
yield from ring
|
| 164 |
+
elif t == "GeometryCollection":
|
| 165 |
+
for g in geom.get("geometries", []):
|
| 166 |
+
yield from _extract_coords(g)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def _centroids_from_geojson(geojson):
|
| 170 |
+
"""Extract centroid [lng, lat] for each feature to use as scatter markers."""
|
| 171 |
+
centroids = []
|
| 172 |
+
for f in geojson.get("features", []):
|
| 173 |
+
geom = f.get("geometry")
|
| 174 |
+
if not geom:
|
| 175 |
+
continue
|
| 176 |
+
lngs, lats = [], []
|
| 177 |
+
for coord in _extract_coords(geom):
|
| 178 |
+
lngs.append(coord[0])
|
| 179 |
+
lats.append(coord[1])
|
| 180 |
+
if lngs:
|
| 181 |
+
centroids.append({"lng": sum(lngs) / len(lngs), "lat": sum(lats) / len(lats)})
|
| 182 |
+
return centroids
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def render_map(geojson, color, key):
    """Draw *geojson* in a pydeck chart, fitted to the data's bounding box.

    Args:
        geojson: FeatureCollection dict to display.
        color: RGBA list used as fill colour; its first three entries are
            reused for the centroid markers.
        key: Streamlit element key for the chart.
    """
    if len(geojson.get("features", [])) == 0:
        st.info("Query returned no features.")
        return

    shape_layer = pdk.Layer(
        "GeoJsonLayer",
        data=geojson,
        get_fill_color=color,
        get_line_color=[100, 100, 100, 200],
        get_line_width=2,
        pickable=True,
    )
    layers = [shape_layer]

    bbox = bbox_from_geojson(geojson)
    if not bbox:
        view = pdk.ViewState(latitude=0, longitude=0, zoom=1)
    else:
        west, south, east, north = bbox
        # Clamp the extent away from zero so a single point does not divide by zero.
        extent = max(east - west, north - south, 1e-6)
        zoom = max(0, min(18, math.log2(360 / extent) - 0.8))

        # Add scatter markers when polygons would be too small to see
        if zoom < 4:
            markers = _centroids_from_geojson(geojson)
            if markers:
                marker_layer = pdk.Layer(
                    "ScatterplotLayer",
                    data=markers,
                    get_position=["lng", "lat"],
                    get_fill_color=color[:3] + [220],
                    get_radius=50000,
                    radius_min_pixels=6,
                    pickable=True,
                )
                layers.append(marker_layer)

        view = pdk.ViewState(
            latitude=(south + north) / 2,
            longitude=(west + east) / 2,
            zoom=zoom,
        )

    deck = pdk.Deck(layers=layers, initial_view_state=view, map_style=None)
    st.pydeck_chart(
        deck,
        width="stretch",
        height=400,
        key=key,
    )
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# --- App ---
# Top-level Streamlit script: pick an eval result file and a query, show the
# expected and predicted output side by side with a diff, and for the SQL
# task execute both queries and map their results.

st.set_page_config(page_title="Eval Viewer", layout="wide")
st.title("Eval Viewer")

# Result files are discovered by the "eval-*.json" naming pattern.
eval_files = sorted(EVAL_DIR.glob("eval-*.json"))
if not eval_files:
    st.error(f"No eval result files found in {EVAL_DIR}")
    st.stop()

selected_file = st.sidebar.selectbox(
    "Eval file",
    eval_files,
    format_func=lambda p: p.stem,
)

data = load_eval_results(selected_file)
summary = data["summary"]
results = data["results"]

st.sidebar.markdown(f"""
**Model**: `{summary.get('label', '')}`
**Exact match**: {summary['exact_matches']}/{summary['num_samples']} ({summary['exact_match_rate']:.1%})
""")

# Optional filter on the exact-match flag of each result.
filter_option = st.sidebar.radio("Filter", ["All", "Matches only", "Mismatches only"])
if filter_option == "Matches only":
    results = [r for r in results if r["exact_match"]]
elif filter_option == "Mismatches only":
    results = [r for r in results if not r["exact_match"]]

if not results:
    st.warning("No results match the current filter.")
    st.stop()

# Labels carry the original sample index, which survives filtering.
questions = [
    f"[{r['index']}] {r.get('question', 'Sample ' + str(r['index']))}"
    for r in results
]
selected_idx = st.selectbox("Select a query", range(len(questions)), format_func=lambda i: questions[i])
row = results[selected_idx]

match_label = "MATCH" if row["exact_match"] else "MISMATCH"
match_color = "green" if row["exact_match"] else "red"
st.markdown(f"### :{match_color}[{match_label}]")

# Summaries without a "task" entry are treated as SQL results.
is_sql = summary.get("task", "sql") == "sql"
expected = row["expected"]
predicted = row["predicted"]

# Formatted output side-by-side
col_expected, col_predicted = st.columns(2)
with col_expected:
    st.markdown("**Expected**")
    if is_sql:
        st.code(format_sql(expected), language="sql")
    else:
        st.code(expected, language="json")
with col_predicted:
    st.markdown("**Predicted**")
    if is_sql:
        st.code(format_sql(predicted), language="sql")
    else:
        st.code(predicted, language="json")

# Diff view
if not row["exact_match"]:
    with st.expander("Diff", expanded=True):
        diff_html = sql_diff_html(expected, predicted)
        # Styles for the CSS classes used in the difflib-generated table.
        diff_css = """
<style>
.diff_add { background-color: rgba(40, 167, 69, 0.15); }
.diff_sub { background-color: rgba(220, 53, 69, 0.15); }
.diff_chg { background-color: rgba(255, 193, 7, 0.15); }
.diff_header { background-color: rgba(128, 128, 128, 0.1); font-weight: bold; }
table.diff { border-collapse: collapse; width: 100%; font-family: monospace; color: inherit; }
table.diff td, table.diff th { padding: 4px 8px; border: 1px solid rgba(128, 128, 128, 0.2); }
</style>
"""
        st.html(f"{diff_css}<div style='overflow-x:auto; font-size:13px;'>{diff_html}</div>")

# Auto-execute SQL and show maps (only for sql task)
if is_sql:
    con = get_duckdb_connection()

    map_col1, map_col2 = st.columns(2)

    with map_col1:
        st.markdown("**Expected result**")
        sql = rewrite_data_paths(expected)
        # Execution errors are shown in the column instead of aborting the page.
        try:
            df = execute_sql(con, sql)
            geojson = to_feature_collection(df)
            render_map(geojson, [40, 180, 160, 140], key="map_expected")
            with st.expander("Result table"):
                st.dataframe(df, width="stretch")
        except Exception as e:
            st.error(f"Execution error: {e}")

    with map_col2:
        st.markdown("**Predicted result**")
        sql = rewrite_data_paths(predicted)
        # Execution errors are shown in the column instead of aborting the page.
        try:
            df = execute_sql(con, sql)
            geojson = to_feature_collection(df)
            render_map(geojson, [180, 80, 60, 140], key="map_predicted")
            with st.expander("Result table"):
                st.dataframe(df, width="stretch")
        except Exception as e:
            st.error(f"Execution error: {e}")

    con.close()
|
finetune/train_modal_qwen35.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Modal training script for gazet Qwen3.5 LoRA fine-tuning with Unsloth.
|
| 2 |
+
|
| 3 |
+
Key differences from train_modal.py (Gemma):
|
| 4 |
+
- Uses Unsloth's FastLanguageModel for memory-efficient training
|
| 5 |
+
- Applies Qwen3.5 chat template to format data (not plain prompt+completion strings)
|
| 6 |
+
- Uses train_on_responses_only with ChatML markers to mask non-assistant tokens
|
| 7 |
+
- Saves merged 16-bit model via unsloth's save_pretrained_merged
|
| 8 |
+
|
| 9 |
+
Usage
|
| 10 |
+
-----
|
| 11 |
+
modal run finetune/train_modal_qwen35.py
|
| 12 |
+
modal run finetune/train_modal_qwen35.py --base-model unsloth/Qwen3.5-0.8B
|
| 13 |
+
modal run finetune/train_modal_qwen35.py --run-dir /mnt/gazet/data/v3-symbolic-paths
|
| 14 |
+
modal run finetune/train_modal_qwen35.py --num-train-epochs 5 --lora-r 32
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import pathlib
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from datetime import datetime
|
| 22 |
+
from typing import Optional
|
| 23 |
+
|
| 24 |
+
import modal
|
| 25 |
+
|
| 26 |
+
app = modal.App("gazet-nlg-qwen35-finetune-v2")
|
| 27 |
+
|
| 28 |
+
GPU_TYPE = "A100-80GB"
|
| 29 |
+
TIMEOUT_HOURS = 24
|
| 30 |
+
MAX_RETRIES = 1
|
| 31 |
+
|
| 32 |
+
train_image = (
|
| 33 |
+
modal.Image.debian_slim(python_version="3.11")
|
| 34 |
+
.pip_install(
|
| 35 |
+
# Use unsloth's bundled CUDA+torch extra so bitsandbytes, xformers,
|
| 36 |
+
# and trl are all resolved together against the same CUDA/torch build.
|
| 37 |
+
# Mirrors the approach in https://modal.com/docs/examples/unsloth_finetune
|
| 38 |
+
"unsloth[cu129-torch280]",
|
| 39 |
+
"unsloth_zoo",
|
| 40 |
+
"transformers~=5.2.0",
|
| 41 |
+
"hf-transfer==0.1.9",
|
| 42 |
+
"trackio[gpu]==0.21.1",
|
| 43 |
+
"datasets",
|
| 44 |
+
"pandas",
|
| 45 |
+
)
|
| 46 |
+
.add_local_python_source("finetune", copy=True)
|
| 47 |
+
.env({
|
| 48 |
+
"HF_HOME": "/mnt/gazet/model_cache",
|
| 49 |
+
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
| 50 |
+
})
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
with train_image.imports():
|
| 54 |
+
from unsloth import FastLanguageModel
|
| 55 |
+
from unsloth.chat_templates import train_on_responses_only
|
| 56 |
+
from trl import SFTConfig, SFTTrainer
|
| 57 |
+
from transformers import set_seed
|
| 58 |
+
|
| 59 |
+
gazet_vol = modal.Volume.from_name("gazet", create_if_missing=True)
|
| 60 |
+
|
| 61 |
+
VOLUMES = {
|
| 62 |
+
"/mnt/gazet": gazet_vol,
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
# ChatML response markers for Qwen3.5 — the empty <think> block is how Qwen3.5
|
| 66 |
+
# formats non-thinking responses. We train only on tokens after this prefix.
|
| 67 |
+
INSTRUCTION_PART = "<|im_start|>user\n"
|
| 68 |
+
RESPONSE_PART = "<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
class Qwen35Config:
    """Hyperparameters and paths for one Qwen3.5 LoRA fine-tuning run.

    Every field has a default, so ``Qwen35Config()`` is a complete
    configuration. When ``experiment_name`` is left as ``None`` it is derived
    in ``__post_init__`` from the last path component of ``base_model``, the
    LoRA rank and the current local time.
    """

    # Model
    base_model: str = "unsloth/Qwen3.5-0.8B"

    # Dataset — path to run dir with {task}/{split}.jsonl files
    run_dir: str = "/mnt/gazet/data/v1"
    max_train_samples: Optional[int] = None
    max_eval_samples: Optional[int] = None

    # Sequence
    max_seq_length: int = 2048

    # LoRA — alpha=2*r follows unsloth recommendation for Qwen models
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.0

    # Training
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 32
    per_device_eval_batch_size: int = 16
    # Effective batch = per_device_train_batch_size * gradient_accumulation_steps
    # (32 * 1 = 32 with these defaults).
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    max_grad_norm: float = 1.0
    warmup_steps: int = 50
    lr_scheduler_type: str = "linear"
    weight_decay: float = 0.01
    optim: str = "adamw_8bit"

    # Logging / saving
    logging_steps: int = 10
    save_strategy: str = "steps"
    save_steps: int = 400
    eval_strategy: str = "steps"
    eval_steps: int = 200
    report_to: str = "trackio"
    trackio_space_id: Optional[str] = "srmsoumya/gazet-trackio"
    project: str = "gazet-nlg-qwen35"

    # Experiment
    seed: int = 42
    experiment_name: Optional[str] = None

    def __post_init__(self):
        """Fill in ``experiment_name`` when the caller did not provide one."""
        if self.experiment_name is None:
            timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
            # "unsloth/Qwen3.5-0.8B" -> "Qwen3.5-0.8B"
            model_short = self.base_model.split("/")[-1]
            self.experiment_name = f"{model_short}-r{self.lora_r}-{timestamp}"
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _load_data(run_dir: str, tokenizer, max_train_samples=None, max_eval_samples=None):
    """Load JSONL data and apply Qwen3.5 chat template.

    Each sample must have:
        messages: list of {role, content} dicts (system + user + assistant)

    The chat template produces the full ChatML string including the assistant turn.
    train_on_responses_only then masks everything except the assistant response.

    Args:
        run_dir: Directory holding ``{task}/{split}.jsonl`` files for the
            ``sql`` and ``places`` tasks and the ``train`` and ``val`` splits.
            Missing files are reported and skipped.
        tokenizer: Object exposing ``apply_chat_template``.
        max_train_samples: If given, keep at most this many shuffled train rows.
        max_eval_samples: If given, keep at most this many shuffled val rows.

    Returns:
        A ``DatasetDict`` with a ``train`` and/or ``val`` split (only the
        splits for which rows were found). Each row has a single ``messages``
        field containing the rendered chat-template string.
    """
    import json
    from datasets import Dataset, DatasetDict

    def load_jsonl(path: pathlib.Path) -> list[dict]:
        # Decode as UTF-8 explicitly so results do not depend on the
        # container's locale settings.
        with open(path, encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            return [json.loads(line) for line in stripped if line]

    def to_message(sample: dict) -> dict:
        # Render the whole conversation (assistant turn included) to text;
        # no generation prompt is appended because the answer is already there.
        text = tokenizer.apply_chat_template(
            sample["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )
        return {"messages": text}

    run_dir = pathlib.Path(run_dir)
    tasks = ("sql", "places")
    splits = ("train", "val")
    ds_dict: dict = {}

    for split in splits:
        combined: list[dict] = []
        for task in tasks:
            path = run_dir / task / f"{split}.jsonl"
            if not path.exists():
                print(f"Missing {path} — skipping")
                continue
            rows = load_jsonl(path)
            combined.extend(to_message(r) for r in rows)
            print(f"Loaded {len(rows):,} {task}/{split} rows")

        if combined:
            ds_dict[split] = Dataset.from_list(combined)
            print(f"{split} split: {len(combined):,} total rows")

    # Shuffle before truncating so the sample caps draw from both tasks.
    ds = DatasetDict(ds_dict).shuffle(seed=42)

    if max_train_samples is not None and "train" in ds:
        ds["train"] = ds["train"].select(range(min(max_train_samples, len(ds["train"]))))
    if max_eval_samples is not None and "val" in ds:
        ds["val"] = ds["val"].select(range(min(max_eval_samples, len(ds["val"]))))

    return ds
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _find_latest_checkpoint(checkpoint_dir: pathlib.Path) -> str | None:
|
| 183 |
+
if not checkpoint_dir.exists():
|
| 184 |
+
return None
|
| 185 |
+
checkpoints = list(checkpoint_dir.glob("checkpoint-*"))
|
| 186 |
+
if not checkpoints:
|
| 187 |
+
return None
|
| 188 |
+
latest = max(checkpoints, key=lambda p: int(p.name.split("-")[1]))
|
| 189 |
+
print(f"Found existing checkpoint: {latest}")
|
| 190 |
+
return str(latest)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
@app.function(
    image=train_image,
    gpu=GPU_TYPE,
    volumes=VOLUMES,
    secrets=[modal.Secret.from_name("huggingface-secret")],
    timeout=TIMEOUT_HOURS * 60 * 60,
    retries=modal.Retries(initial_delay=0.0, max_retries=MAX_RETRIES),
)
def finetune(config_dict: dict):
    """Run Qwen3.5 LoRA SFT training with Unsloth inside a Modal container.

    Loads the base model, attaches LoRA adapters, trains on the assistant
    responses only, then writes the adapter and a merged 16-bit model under
    ``/mnt/gazet/checkpoints/<experiment_name>`` and commits the volume.
    If that directory already contains a checkpoint, training resumes from
    the latest one.

    Args:
        config_dict: Keyword arguments for ``Qwen35Config``.

    Returns:
        The experiment name of the finished run.

    Raises:
        RuntimeError: If no training split could be loaded from
            ``config.run_dir``.
    """
    config = Qwen35Config(**config_dict)
    set_seed(config.seed)

    experiment_dir = pathlib.Path("/mnt/gazet/checkpoints") / config.experiment_name
    experiment_dir.mkdir(parents=True, exist_ok=True)

    print(f"Experiment: {config.experiment_name}")
    print(f"Model: {config.base_model}")
    print(f"Run dir: {config.run_dir}")

    # Load base model with unsloth — gradient checkpointing is handled internally
    model, processor = FastLanguageModel.from_pretrained(
        config.base_model,
        max_seq_length=config.max_seq_length,
        load_in_4bit=False,
        use_gradient_checkpointing="unsloth",
        fast_inference=False,
    )
    # NOTE(review): the second return value is treated as a processor that
    # wraps the tokenizer; fall back to the object itself in case a plain
    # tokenizer is returned for a given base model — confirm against unsloth.
    tokenizer = getattr(processor, "tokenizer", processor)

    # Apply LoRA adapters — let unsloth select target modules via finetune_* flags
    model = FastLanguageModel.get_peft_model(
        model,
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        finetune_vision_layers=False,
        finetune_language_layers=True,
        finetune_attention_modules=True,
        finetune_mlp_modules=True,
        bias="none",
        random_state=config.seed,
        use_gradient_checkpointing=False,  # already set in from_pretrained
    )

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    ds = _load_data(
        config.run_dir,
        tokenizer,
        max_train_samples=config.max_train_samples,
        max_eval_samples=config.max_eval_samples,
    )
    if "train" not in ds:
        raise RuntimeError(
            f"No training data found in {config.run_dir}. "
            "Run the dataset pipeline and upload exported data to the volume first."
        )
    has_val = "val" in ds
    print(f"Train samples: {len(ds['train']):,}")
    if has_val:
        print(f"Val samples: {len(ds['val']):,}")
    effective_batch = config.per_device_train_batch_size * config.gradient_accumulation_steps
    print(f"Effective batch: {effective_batch}")

    sft_args = SFTConfig(
        output_dir=str(experiment_dir),
        dataset_text_field="messages",
        max_seq_length=config.max_seq_length,
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=config.per_device_train_batch_size,
        per_device_eval_batch_size=config.per_device_eval_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        max_grad_norm=config.max_grad_norm,
        warmup_steps=config.warmup_steps,
        lr_scheduler_type=config.lr_scheduler_type,
        weight_decay=config.weight_decay,
        optim=config.optim,
        bf16=True,
        logging_steps=config.logging_steps,
        save_strategy=config.save_strategy,
        save_steps=config.save_steps,
        # The val split is optional; without an eval dataset the trainer
        # cannot evaluate, so periodic evaluation is switched off rather
        # than failing at trainer construction.
        eval_strategy=config.eval_strategy if has_val else "no",
        eval_steps=config.eval_steps,
        report_to=config.report_to,
        trackio_space_id=config.trackio_space_id,
        project=config.project,
        dataset_num_proc=8,
        seed=config.seed,
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds["train"],
        eval_dataset=ds.get("val"),
        args=sft_args,
    )

    # Mask all tokens except the assistant response — train on completions only
    trainer = train_on_responses_only(
        trainer,
        instruction_part=INSTRUCTION_PART,
        response_part=RESPONSE_PART,
    )

    # A previous attempt of the same experiment may have left checkpoints
    # behind; pick up from the latest one instead of starting over.
    resume_from = _find_latest_checkpoint(experiment_dir)
    if resume_from:
        print(f"Resuming from {resume_from}")

    trainer.train(resume_from_checkpoint=resume_from)

    # Save LoRA adapter + tokenizer (lightweight, for future merging)
    print(f"Saving LoRA adapter to {experiment_dir}")
    model.save_pretrained(str(experiment_dir))
    tokenizer.save_pretrained(str(experiment_dir))

    # Save merged 16-bit model (full weights, ready for inference / GGUF conversion)
    merged_dir = experiment_dir / "merged"
    merged_dir.mkdir(parents=True, exist_ok=True)
    print(f"Saving merged 16-bit model to {merged_dir}")
    model.save_pretrained_merged(str(merged_dir), tokenizer, save_method="merged_16bit")

    gazet_vol.commit()
    print(f"Training complete: {config.experiment_name}")
    return config.experiment_name
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
@app.local_entrypoint()
def main(
    base_model: Optional[str] = None,
    experiment_name: Optional[str] = None,
    run_dir: Optional[str] = None,
    num_train_epochs: Optional[int] = None,
    per_device_train_batch_size: Optional[int] = None,
    max_train_samples: Optional[int] = None,
    max_eval_samples: Optional[int] = None,
    lora_r: Optional[int] = None,
    max_seq_length: Optional[int] = None,
):
    """Build a ``Qwen35Config`` from the given options and launch training.

    Options left as ``None`` keep the ``Qwen35Config`` defaults. When
    ``lora_r`` is given, ``lora_alpha`` is set to twice that rank. A short
    summary is printed before ``finetune`` is invoked remotely, and the value
    it returns is printed afterwards.
    """
    cli_values = {
        "base_model": base_model,
        "experiment_name": experiment_name,
        "run_dir": run_dir,
        "num_train_epochs": num_train_epochs,
        "per_device_train_batch_size": per_device_train_batch_size,
        "max_train_samples": max_train_samples,
        "max_eval_samples": max_eval_samples,
        "lora_r": lora_r,
        "max_seq_length": max_seq_length,
    }
    # Forward only the options that were actually supplied.
    overrides = {}
    for field_name, value in cli_values.items():
        if value is None:
            continue
        overrides[field_name] = value

    config = Qwen35Config(**overrides)
    # lora_alpha follows r if r was overridden and alpha wasn't
    if lora_r is not None:
        config.lora_alpha = 2 * config.lora_r

    effective_batch = config.per_device_train_batch_size * config.gradient_accumulation_steps

    print(f"Starting experiment: {config.experiment_name}")
    print(f"Model: {config.base_model}")
    print(f"Run dir: {config.run_dir}")
    print(f"LoRA: r={config.lora_r}, alpha={config.lora_alpha}")
    print(f"Effective batch: {effective_batch}")

    # The config travels to the remote function as a plain dict of its fields.
    result = finetune.remote(vars(config))
    print(f"Training complete: {result}")
|
gazet_demo.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
| 2 |
|
| 3 |
import json
|
| 4 |
import math
|
|
|
|
| 5 |
|
| 6 |
import pandas as pd
|
| 7 |
import requests
|
|
@@ -100,7 +101,7 @@ def _render_map(geojson, placeholder):
|
|
| 100 |
st.json(geojson)
|
| 101 |
|
| 102 |
|
| 103 |
-
API = "http://127.0.0.1:8000"
|
| 104 |
EXAMPLES = [
|
| 105 |
"Angola and Mozambique",
|
| 106 |
"Mediterranean Sea",
|
|
@@ -114,6 +115,17 @@ st.set_page_config(page_title="Gazet", page_icon="🌍", layout="wide")
|
|
| 114 |
st.title("Gazet")
|
| 115 |
st.caption("Natural-language geo search · click an example or type your own")
|
| 116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
if "run_q" not in st.session_state:
|
| 118 |
st.session_state.run_q = None
|
| 119 |
|
|
@@ -149,7 +161,7 @@ with col2:
|
|
| 149 |
|
| 150 |
try:
|
| 151 |
with requests.get(
|
| 152 |
-
f"{API}/search/stream", params={"q": to_run}, stream=True, timeout=120
|
| 153 |
) as r:
|
| 154 |
r.raise_for_status()
|
| 155 |
|
|
|
|
| 2 |
|
| 3 |
import json
|
| 4 |
import math
|
| 5 |
+
import os
|
| 6 |
|
| 7 |
import pandas as pd
|
| 8 |
import requests
|
|
|
|
| 101 |
st.json(geojson)
|
| 102 |
|
| 103 |
|
| 104 |
+
API = os.environ.get("GAZET_API_URL", "http://127.0.0.1:8000")
|
| 105 |
EXAMPLES = [
|
| 106 |
"Angola and Mozambique",
|
| 107 |
"Mediterranean Sea",
|
|
|
|
| 115 |
st.title("Gazet")
|
| 116 |
st.caption("Natural-language geo search · click an example or type your own")
|
| 117 |
|
| 118 |
+
backend = st.sidebar.radio(
|
| 119 |
+
"SQL Backend",
|
| 120 |
+
["gguf", "dspy"],
|
| 121 |
+
index=0,
|
| 122 |
+
format_func=lambda x: "⚡ GGUF (llama-server)" if x == "gguf" else "🧠 DSPy (cloud LM)",
|
| 123 |
+
)
|
| 124 |
+
st.sidebar.caption(
|
| 125 |
+
"**gguf** → finetuned Qwen3.5 via llama-server\n\n"
|
| 126 |
+
"**dspy** → Ollama / cloud LM with retry loop"
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
if "run_q" not in st.session_state:
|
| 130 |
st.session_state.run_q = None
|
| 131 |
|
|
|
|
| 161 |
|
| 162 |
try:
|
| 163 |
with requests.get(
|
| 164 |
+
f"{API}/search/stream", params={"q": to_run, "backend": backend}, stream=True, timeout=120
|
| 165 |
) as r:
|
| 166 |
r.raise_for_status()
|
| 167 |
|
pyproject.toml
CHANGED
|
@@ -16,8 +16,18 @@ dependencies = [
|
|
| 16 |
"pydantic>=2.0",
|
| 17 |
"pyarrow>=17.0.0",
|
| 18 |
"geopandas>=1.1.2",
|
|
|
|
|
|
|
| 19 |
]
|
| 20 |
optional-dependencies = { demo = ["streamlit", "requests", "pydeck"], dev = ["ruff"] }
|
| 21 |
|
|
|
|
|
|
|
|
|
|
| 22 |
[tool.hatch.build.targets.wheel]
|
| 23 |
-
packages = ["src/gazet"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
"pydantic>=2.0",
|
| 17 |
"pyarrow>=17.0.0",
|
| 18 |
"geopandas>=1.1.2",
|
| 19 |
+
"httpx>=0.28.1",
|
| 20 |
+
"sqlparse>=0.5.5",
|
| 21 |
]
|
| 22 |
optional-dependencies = { demo = ["streamlit", "requests", "pydeck"], dev = ["ruff"] }
|
| 23 |
|
| 24 |
+
[project.scripts]
|
| 25 |
+
gazet-dataset = "dataset.scripts.cli:main"
|
| 26 |
+
|
| 27 |
[tool.hatch.build.targets.wheel]
|
| 28 |
+
packages = ["src/gazet", "dataset"]
|
| 29 |
+
|
| 30 |
+
[dependency-groups]
|
| 31 |
+
dataset = [
|
| 32 |
+
"modal>=1.4.0",
|
| 33 |
+
]
|
src/gazet/api.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import json
|
|
|
|
| 2 |
from typing import Any, Generator
|
| 3 |
|
| 4 |
import duckdb
|
|
@@ -7,19 +8,33 @@ from fastapi import FastAPI, HTTPException
|
|
| 7 |
from fastapi.responses import StreamingResponse
|
| 8 |
|
| 9 |
from .export import to_feature_collection
|
| 10 |
-
from .lm import extract
|
| 11 |
-
from .search import
|
| 12 |
-
from .sql import
|
| 13 |
|
| 14 |
app = FastAPI()
|
| 15 |
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
def _df_to_records(df: pd.DataFrame) -> list[dict[str, Any]]:
|
| 18 |
"""Convert DataFrame to list of dicts for JSON; handle non-JSON-serializable types."""
|
| 19 |
return df.replace({float("nan"): None}).to_dict(orient="records")
|
| 20 |
|
| 21 |
|
| 22 |
-
def _run_stream(query: str) -> Generator[str, None, None]:
|
| 23 |
"""Yield NDJSON lines as each stage of the search completes.
|
| 24 |
|
| 25 |
Event ``type`` values (in order of emission):
|
|
@@ -30,9 +45,12 @@ def _run_stream(query: str) -> Generator[str, None, None]:
|
|
| 30 |
- ``geojson`` – final FeatureCollection
|
| 31 |
- ``error`` – fatal error (no result)
|
| 32 |
"""
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
yield json.dumps({"type": "places", "data": places_result.model_dump()}) + "\n"
|
| 38 |
|
|
@@ -41,12 +59,10 @@ def _run_stream(query: str) -> Generator[str, None, None]:
|
|
| 41 |
con.execute("LOAD spatial")
|
| 42 |
|
| 43 |
try:
|
|
|
|
| 44 |
all_candidates: list[pd.DataFrame] = []
|
| 45 |
for place in places_result.places:
|
| 46 |
-
|
| 47 |
-
df = search_fn(con, place)
|
| 48 |
-
if not df.empty:
|
| 49 |
-
all_candidates.append(df)
|
| 50 |
|
| 51 |
if not all_candidates:
|
| 52 |
yield json.dumps({"type": "error", "data": "No candidates found"}) + "\n"
|
|
@@ -64,8 +80,9 @@ def _run_stream(query: str) -> Generator[str, None, None]:
|
|
| 64 |
+ "\n"
|
| 65 |
)
|
| 66 |
|
|
|
|
| 67 |
result_df: pd.DataFrame | None = None
|
| 68 |
-
for event in
|
| 69 |
if event["type"] == "sql_attempt":
|
| 70 |
yield (
|
| 71 |
json.dumps(
|
|
@@ -105,13 +122,13 @@ def _run_stream(query: str) -> Generator[str, None, None]:
|
|
| 105 |
|
| 106 |
|
| 107 |
@app.get("/search/stream")
|
| 108 |
-
def search_stream(q: str) -> StreamingResponse:
|
| 109 |
"""Stream search progress as NDJSON (one JSON object per line)."""
|
| 110 |
-
return StreamingResponse(_run_stream(q), media_type="application/x-ndjson")
|
| 111 |
|
| 112 |
|
| 113 |
@app.get("/search", response_model=None)
|
| 114 |
-
def search(q: str) -> dict[str, Any]:
|
| 115 |
"""Run geo search for natural-language query (non-streaming).
|
| 116 |
|
| 117 |
Returns GeoJSON FeatureCollection, the executed SQL, and the identified
|
|
@@ -122,7 +139,7 @@ def search(q: str) -> dict[str, Any]:
|
|
| 122 |
sql = ""
|
| 123 |
geojson: dict | None = None
|
| 124 |
|
| 125 |
-
for line in _run_stream(q):
|
| 126 |
if not line.strip():
|
| 127 |
continue
|
| 128 |
event = json.loads(line)
|
|
|
|
| 1 |
import json
|
| 2 |
+
import math
|
| 3 |
from typing import Any, Generator
|
| 4 |
|
| 5 |
import duckdb
|
|
|
|
| 8 |
from fastapi.responses import StreamingResponse
|
| 9 |
|
| 10 |
from .export import to_feature_collection
|
| 11 |
+
from .lm import extract, generate_places
|
| 12 |
+
from .search import search_candidates
|
| 13 |
+
from .sql import run_geo_sql_dspy, run_geo_sql_gguf
|
| 14 |
|
| 15 |
app = FastAPI()
|
| 16 |
|
| 17 |
|
| 18 |
+
def _per_source_limit(num_places: int) -> int:
|
| 19 |
+
"""Candidates to fetch per source per place, scaled by number of places.
|
| 20 |
+
|
| 21 |
+
Keeps the total candidate count in the prompt manageable:
|
| 22 |
+
1 place → 5 per source → 10 total
|
| 23 |
+
2 places → 4 per source → 16 total
|
| 24 |
+
3 places → 3 per source → 18 total
|
| 25 |
+
4 places → 2 per source → 16 total
|
| 26 |
+
5 places → 2 per source → 20 total
|
| 27 |
+
"""
|
| 28 |
+
table = {1: 5, 2: 4, 3: 3, 4: 2, 5: 2}
|
| 29 |
+
return table.get(num_places, max(1, math.ceil(5 / num_places)))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
def _df_to_records(df: pd.DataFrame) -> list[dict[str, Any]]:
|
| 33 |
"""Convert DataFrame to list of dicts for JSON; handle non-JSON-serializable types."""
|
| 34 |
return df.replace({float("nan"): None}).to_dict(orient="records")
|
| 35 |
|
| 36 |
|
| 37 |
+
def _run_stream(query: str, backend: str = "gguf") -> Generator[str, None, None]:
|
| 38 |
"""Yield NDJSON lines as each stage of the search completes.
|
| 39 |
|
| 40 |
Event ``type`` values (in order of emission):
|
|
|
|
| 45 |
- ``geojson`` – final FeatureCollection
|
| 46 |
- ``error`` – fatal error (no result)
|
| 47 |
"""
|
| 48 |
+
if backend == "gguf":
|
| 49 |
+
places_result = generate_places(query)
|
| 50 |
+
else:
|
| 51 |
+
pred = extract(query=query)
|
| 52 |
+
places_result = pred.result
|
| 53 |
+
print("places:", places_result)
|
| 54 |
|
| 55 |
yield json.dumps({"type": "places", "data": places_result.model_dump()}) + "\n"
|
| 56 |
|
|
|
|
| 59 |
con.execute("LOAD spatial")
|
| 60 |
|
| 61 |
try:
|
| 62 |
+
limit = _per_source_limit(len(places_result.places))
|
| 63 |
all_candidates: list[pd.DataFrame] = []
|
| 64 |
for place in places_result.places:
|
| 65 |
+
all_candidates.extend(search_candidates(con, place, limit=limit))
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
if not all_candidates:
|
| 68 |
yield json.dumps({"type": "error", "data": "No candidates found"}) + "\n"
|
|
|
|
| 80 |
+ "\n"
|
| 81 |
)
|
| 82 |
|
| 83 |
+
sql_fn = run_geo_sql_gguf if backend == "gguf" else run_geo_sql_dspy
|
| 84 |
result_df: pd.DataFrame | None = None
|
| 85 |
+
for event in sql_fn(con, query, candidates_df):
|
| 86 |
if event["type"] == "sql_attempt":
|
| 87 |
yield (
|
| 88 |
json.dumps(
|
|
|
|
| 122 |
|
| 123 |
|
| 124 |
@app.get("/search/stream")
def search_stream(q: str, backend: str = "gguf") -> StreamingResponse:
    """Serve the staged search events for ``q`` as an NDJSON stream.

    Each line of the response body is one JSON object produced by
    ``_run_stream`` for the requested ``backend``.
    """
    events = _run_stream(q, backend)
    return StreamingResponse(events, media_type="application/x-ndjson")
|
| 128 |
|
| 129 |
|
| 130 |
@app.get("/search", response_model=None)
|
| 131 |
+
def search(q: str, backend: str = "gguf") -> dict[str, Any]:
|
| 132 |
"""Run geo search for natural-language query (non-streaming).
|
| 133 |
|
| 134 |
Returns GeoJSON FeatureCollection, the executed SQL, and the identified
|
|
|
|
| 139 |
sql = ""
|
| 140 |
geojson: dict | None = None
|
| 141 |
|
| 142 |
+
for line in _run_stream(q, backend):
|
| 143 |
if not line.strip():
|
| 144 |
continue
|
| 145 |
event = json.loads(line)
|
src/gazet/config.py
CHANGED
|
@@ -1,7 +1,11 @@
|
|
|
|
|
| 1 |
import pathlib
|
| 2 |
|
| 3 |
-
# Data lives at project root (gazet/data/), not inside the package
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
| 5 |
DIVISIONS_AREA_PATH = str(_DATA_DIR / "overture/divisions_area/*.parquet")
|
| 6 |
NATURAL_EARTH_PATH = str(_DATA_DIR / "natural_earth_geoparquet/ne_geography.parquet")
|
| 7 |
|
|
@@ -9,18 +13,29 @@ NATURAL_EARTH_PATH = str(_DATA_DIR / "natural_earth_geoparquet/ne_geography.parq
|
|
| 9 |
# MODEL = "granite4:350m"
|
| 10 |
# MODEL = "gemma3:12b-cloud"
|
| 11 |
# MODEL = "qwen3.5:397b-cloud"
|
| 12 |
-
MODEL = "gpt-oss:20b-cloud"
|
| 13 |
# MODEL = "qwen3:4b"
|
| 14 |
# MODEL = "qwen3-coder-next:cloud"
|
| 15 |
# MODEL = "deepseek-coder:1.3b"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
MAX_SQL_ITERATIONS = 5
|
| 18 |
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
Available DuckDB datasets (read via read_parquet):
|
| 21 |
|
| 22 |
1. divisions_area — Overture polygon/multipolygon admin boundaries
|
| 23 |
-
|
| 24 |
columns:
|
| 25 |
id VARCHAR -- unique feature id (use this to filter precisely)
|
| 26 |
names STRUCT("primary" VARCHAR, ...)
|
|
@@ -36,7 +51,7 @@ Available DuckDB datasets (read via read_parquet):
|
|
| 36 |
geometry GEOMETRY -- boundary polygon/multipolygon (WKB, spatial ext loaded)
|
| 37 |
|
| 38 |
2. natural_earth — Natural Earth geography polygons (oceans, seas, terrain regions, islands)
|
| 39 |
-
|
| 40 |
columns:
|
| 41 |
id VARCHAR -- unique feature id prefixed 'ne_'
|
| 42 |
names STRUCT("primary" VARCHAR, ...)
|
|
@@ -49,26 +64,26 @@ Available DuckDB datasets (read via read_parquet):
|
|
| 49 |
is_territorial BOOLEAN
|
| 50 |
geometry GEOMETRY -- polygon/multipolygon (WKB, spatial ext loaded)
|
| 51 |
|
| 52 |
-
Spatial extension is already loaded — use ST_AsGeoJSON(geometry)
|
| 53 |
To access names use: names."primary"
|
| 54 |
|
| 55 |
The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
|
| 56 |
-
Use
|
| 57 |
|
| 58 |
Example patterns:
|
| 59 |
-- single region boundary from divisions_area
|
| 60 |
-
SELECT id, names."primary" AS name, ST_AsGeoJSON(geometry) AS
|
| 61 |
-
FROM read_parquet('
|
| 62 |
WHERE id = '<candidate_id>'
|
| 63 |
|
| 64 |
-- feature from natural_earth
|
| 65 |
-
SELECT id, names."primary" AS name, ST_AsGeoJSON(geometry) AS
|
| 66 |
-
FROM read_parquet('
|
| 67 |
WHERE id = '<candidate_id>'
|
| 68 |
|
| 69 |
-- shared border between two adjacent regions
|
| 70 |
-
WITH a AS (SELECT geometry FROM read_parquet('
|
| 71 |
-
b AS (SELECT geometry FROM read_parquet('
|
| 72 |
-
SELECT ST_AsGeoJSON(ST_Intersection(a.geometry, b.geometry)) AS
|
| 73 |
FROM a, b
|
| 74 |
"""
|
|
|
|
| 1 |
+
import os
|
| 2 |
import pathlib
|
| 3 |
|
| 4 |
+
# Data lives at project root (gazet/data/), not inside the package.
|
| 5 |
+
# Override with GAZET_DATA_DIR env var for remote execution (e.g. Modal volume at /data).
|
| 6 |
+
_DATA_DIR = pathlib.Path(os.environ.get("GAZET_DATA_DIR", str(
|
| 7 |
+
pathlib.Path(__file__).resolve().parent.parent.parent / "data"
|
| 8 |
+
)))
|
| 9 |
DIVISIONS_AREA_PATH = str(_DATA_DIR / "overture/divisions_area/*.parquet")
|
| 10 |
NATURAL_EARTH_PATH = str(_DATA_DIR / "natural_earth_geoparquet/ne_geography.parquet")
|
| 11 |
|
|
|
|
| 13 |
# MODEL = "granite4:350m"
|
| 14 |
# MODEL = "gemma3:12b-cloud"
|
| 15 |
# MODEL = "qwen3.5:397b-cloud"
|
| 16 |
+
# MODEL = "gpt-oss:20b-cloud"
|
| 17 |
# MODEL = "qwen3:4b"
|
| 18 |
# MODEL = "qwen3-coder-next:cloud"
|
| 19 |
# MODEL = "deepseek-coder:1.3b"
|
| 20 |
+
# MODEL = "qwen3.5:2b"
|
| 21 |
+
# MODEL = "qwen3.5:0.8b"
|
| 22 |
+
# MODEL = "qwen2.5-coder:1.5b"
|
| 23 |
+
|
| 24 |
+
PLACE_EXTRACTION_MODEL = "gpt-oss:20b-cloud"
|
| 25 |
+
SQL_GENERATION_MODEL = "gpt-oss:20b-cloud"
|
| 26 |
|
| 27 |
MAX_SQL_ITERATIONS = 5
|
| 28 |
|
| 29 |
+
# ── GGUF / llama-server config ────────────────────────────────────────────────
|
| 30 |
+
LLAMA_SERVER_URL = os.environ.get("LLAMA_SERVER_URL", "http://localhost:9000")
|
| 31 |
+
LLAMA_MAX_TOKENS = int(os.environ.get("LLAMA_MAX_TOKENS", "2048"))
|
| 32 |
+
LLAMA_TEMPERATURE = float(os.environ.get("LLAMA_TEMPERATURE", "0"))
|
| 33 |
+
|
| 34 |
+
SCHEMA_INFO = """
|
| 35 |
Available DuckDB datasets (read via read_parquet):
|
| 36 |
|
| 37 |
1. divisions_area — Overture polygon/multipolygon admin boundaries
|
| 38 |
+
query: read_parquet('divisions_area')
|
| 39 |
columns:
|
| 40 |
id VARCHAR -- unique feature id (use this to filter precisely)
|
| 41 |
names STRUCT("primary" VARCHAR, ...)
|
|
|
|
| 51 |
geometry GEOMETRY -- boundary polygon/multipolygon (WKB, spatial ext loaded)
|
| 52 |
|
| 53 |
2. natural_earth — Natural Earth geography polygons (oceans, seas, terrain regions, islands)
|
| 54 |
+
query: read_parquet('natural_earth')
|
| 55 |
columns:
|
| 56 |
id VARCHAR -- unique feature id prefixed 'ne_'
|
| 57 |
names STRUCT("primary" VARCHAR, ...)
|
|
|
|
| 64 |
is_territorial BOOLEAN
|
| 65 |
geometry GEOMETRY -- polygon/multipolygon (WKB, spatial ext loaded)
|
| 66 |
|
| 67 |
+
Spatial extension is already loaded — use ST_AsGeoJSON(geometry) for geometry outputs.
|
| 68 |
To access names use: names."primary"
|
| 69 |
|
| 70 |
The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
|
| 71 |
+
Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly.
|
| 72 |
|
| 73 |
Example patterns:
|
| 74 |
-- single region boundary from divisions_area
|
| 75 |
+
SELECT id, names."primary" AS name, ST_AsGeoJSON(geometry) AS geometry
|
| 76 |
+
FROM read_parquet('divisions_area')
|
| 77 |
WHERE id = '<candidate_id>'
|
| 78 |
|
| 79 |
-- feature from natural_earth
|
| 80 |
+
SELECT id, names."primary" AS name, ST_AsGeoJSON(geometry) AS geometry
|
| 81 |
+
FROM read_parquet('natural_earth')
|
| 82 |
WHERE id = '<candidate_id>'
|
| 83 |
|
| 84 |
-- shared border between two adjacent regions
|
| 85 |
+
WITH a AS (SELECT geometry FROM read_parquet('divisions_area') WHERE id = '<id_a>'),
|
| 86 |
+
b AS (SELECT geometry FROM read_parquet('divisions_area') WHERE id = '<id_b>')
|
| 87 |
+
SELECT ST_AsGeoJSON(ST_Intersection(a.geometry, b.geometry)) AS geometry
|
| 88 |
FROM a, b
|
| 89 |
"""
|
src/gazet/export.py
CHANGED
|
@@ -2,9 +2,25 @@ import json
|
|
| 2 |
import pathlib
|
| 3 |
import re
|
| 4 |
|
|
|
|
| 5 |
import pandas as pd
|
| 6 |
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
def _is_geojson_col(series: pd.Series) -> bool:
|
| 9 |
"""Heuristic: a string column whose non-null values start with '{"type":'."""
|
| 10 |
sample = series.dropna().head(5)
|
|
@@ -16,6 +32,36 @@ def _is_geojson_col(series: pd.Series) -> bool:
|
|
| 16 |
)
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
def save_geojson(
|
| 20 |
result_df: pd.DataFrame, query: str, output_dir: pathlib.Path | str = "."
|
| 21 |
) -> pathlib.Path:
|
|
@@ -43,22 +89,33 @@ def to_feature_collection(result_df: pd.DataFrame) -> dict:
|
|
| 43 |
|
| 44 |
|
| 45 |
def _to_feature_collection(result_df: pd.DataFrame) -> dict:
|
| 46 |
-
|
|
|
|
|
|
|
| 47 |
prop_cols = [c for c in result_df.columns if c not in geom_cols]
|
| 48 |
features = []
|
| 49 |
for _, row in result_df.iterrows():
|
| 50 |
geometry = None
|
| 51 |
-
if
|
| 52 |
-
raw = row[
|
| 53 |
if raw and isinstance(raw, str):
|
| 54 |
try:
|
| 55 |
geometry = json.loads(raw)
|
| 56 |
except json.JSONDecodeError:
|
| 57 |
pass
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
if
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
features.append(
|
| 63 |
{"type": "Feature", "geometry": geometry, "properties": properties}
|
| 64 |
)
|
|
|
|
| 2 |
import pathlib
|
| 3 |
import re
|
| 4 |
|
| 5 |
+
import numpy as np
|
| 6 |
import pandas as pd
|
| 7 |
|
| 8 |
|
| 9 |
+
def _to_serializable(val):
|
| 10 |
+
"""Convert a value to a JSON-serializable Python type."""
|
| 11 |
+
if isinstance(val, (bytearray, bytes)):
|
| 12 |
+
return None
|
| 13 |
+
if isinstance(val, np.ndarray):
|
| 14 |
+
return val.tolist()
|
| 15 |
+
if isinstance(val, (np.integer,)):
|
| 16 |
+
return int(val)
|
| 17 |
+
if isinstance(val, (np.floating,)):
|
| 18 |
+
return float(val)
|
| 19 |
+
if isinstance(val, (np.bool_,)):
|
| 20 |
+
return bool(val)
|
| 21 |
+
return val
|
| 22 |
+
|
| 23 |
+
|
| 24 |
def _is_geojson_col(series: pd.Series) -> bool:
|
| 25 |
"""Heuristic: a string column whose non-null values start with '{"type":'."""
|
| 26 |
sample = series.dropna().head(5)
|
|
|
|
| 32 |
)
|
| 33 |
|
| 34 |
|
| 35 |
+
def _is_wkb_col(series: pd.Series) -> bool:
|
| 36 |
+
"""Heuristic: a column whose non-null values are bytearray or bytes (WKB geometry)."""
|
| 37 |
+
sample = series.dropna().head(5)
|
| 38 |
+
return (
|
| 39 |
+
sample.apply(lambda v: isinstance(v, (bytearray, bytes))).all()
|
| 40 |
+
and len(sample) > 0
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _wkb_to_geojson(wkb: bytearray | bytes) -> dict | None:
|
| 45 |
+
"""Convert WKB geometry to GeoJSON dict via DuckDB."""
|
| 46 |
+
import duckdb
|
| 47 |
+
|
| 48 |
+
con = duckdb.connect()
|
| 49 |
+
try:
|
| 50 |
+
con.execute("INSTALL spatial")
|
| 51 |
+
con.execute("LOAD spatial")
|
| 52 |
+
result = con.execute(
|
| 53 |
+
"SELECT ST_AsGeoJSON(ST_GeomFromWKB(?::BLOB)) AS geojson",
|
| 54 |
+
[bytes(wkb)],
|
| 55 |
+
).fetchone()
|
| 56 |
+
if result and result[0]:
|
| 57 |
+
return json.loads(result[0])
|
| 58 |
+
except Exception:
|
| 59 |
+
pass
|
| 60 |
+
finally:
|
| 61 |
+
con.close()
|
| 62 |
+
return None
|
| 63 |
+
|
| 64 |
+
|
| 65 |
def save_geojson(
|
| 66 |
result_df: pd.DataFrame, query: str, output_dir: pathlib.Path | str = "."
|
| 67 |
) -> pathlib.Path:
|
|
|
|
| 89 |
|
| 90 |
|
| 91 |
def _to_feature_collection(result_df: pd.DataFrame) -> dict:
|
| 92 |
+
geojson_cols = [c for c in result_df.columns if _is_geojson_col(result_df[c])]
|
| 93 |
+
wkb_cols = [c for c in result_df.columns if _is_wkb_col(result_df[c])]
|
| 94 |
+
geom_cols = geojson_cols + wkb_cols
|
| 95 |
prop_cols = [c for c in result_df.columns if c not in geom_cols]
|
| 96 |
features = []
|
| 97 |
for _, row in result_df.iterrows():
|
| 98 |
geometry = None
|
| 99 |
+
if geojson_cols:
|
| 100 |
+
raw = row[geojson_cols[0]]
|
| 101 |
if raw and isinstance(raw, str):
|
| 102 |
try:
|
| 103 |
geometry = json.loads(raw)
|
| 104 |
except json.JSONDecodeError:
|
| 105 |
pass
|
| 106 |
+
elif wkb_cols:
|
| 107 |
+
raw = row[wkb_cols[0]]
|
| 108 |
+
if raw and isinstance(raw, (bytearray, bytes)):
|
| 109 |
+
geometry = _wkb_to_geojson(raw)
|
| 110 |
+
properties = {}
|
| 111 |
+
for c in prop_cols:
|
| 112 |
+
v = row[c]
|
| 113 |
+
try:
|
| 114 |
+
if not pd.notna(v):
|
| 115 |
+
continue
|
| 116 |
+
except ValueError:
|
| 117 |
+
pass # pd.notna fails on arrays — treat as present
|
| 118 |
+
properties[c] = _to_serializable(v)
|
| 119 |
features.append(
|
| 120 |
{"type": "Feature", "geometry": geometry, "properties": properties}
|
| 121 |
)
|
src/gazet/lm.py
CHANGED
|
@@ -1,7 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import dspy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
-
from .schemas import PlacesResult
|
| 5 |
|
| 6 |
|
| 7 |
class ExtractPlaces(dspy.Signature):
|
|
@@ -20,6 +33,13 @@ class ExtractPlaces(dspy.Signature):
|
|
| 20 |
|
| 21 |
Where possible and relevant, also extract the ISO country code for each place.
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
Do not repeat the same place name in the result.
|
| 24 |
|
| 25 |
If the user does not explicitly mention a country, dont add the country code to the result.
|
|
@@ -103,10 +123,213 @@ class WriteGeoSQL(dspy.Signature):
|
|
| 103 |
)
|
| 104 |
|
| 105 |
|
| 106 |
-
|
| 107 |
-
f"ollama_chat/{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
)
|
| 109 |
-
dspy.configure(lm=lm)
|
| 110 |
|
| 111 |
-
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
import dspy
|
| 5 |
+
import httpx
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
from .config import (
|
| 9 |
+
LLAMA_MAX_TOKENS,
|
| 10 |
+
LLAMA_SERVER_URL,
|
| 11 |
+
LLAMA_TEMPERATURE,
|
| 12 |
+
PLACE_EXTRACTION_MODEL,
|
| 13 |
+
SQL_GENERATION_MODEL,
|
| 14 |
+
)
|
| 15 |
+
from .schemas import Place, PlacesResult
|
| 16 |
|
| 17 |
+
logger = logging.getLogger(__name__)
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
class ExtractPlaces(dspy.Signature):
|
|
|
|
| 33 |
|
| 34 |
Where possible and relevant, also extract the ISO country code for each place.
|
| 35 |
|
| 36 |
+
Only extract place names that are explicitly mentioned in the query.
|
| 37 |
+
Do NOT generate or infer place names from your own knowledge.
|
| 38 |
+
For example:
|
| 39 |
+
- "north half of India" -> extract "India", NOT individual state names
|
| 40 |
+
- "coastal cities of France" -> extract "France", NOT city names
|
| 41 |
+
- "neighbouring states of Odisha" -> extract "Odisha", NOT neighbouring state names
|
| 42 |
+
|
| 43 |
Do not repeat the same place name in the result.
|
| 44 |
|
| 45 |
If the user does not explicitly mention a country, dont add the country code to the result.
|
|
|
|
| 123 |
)
|
| 124 |
|
| 125 |
|
| 126 |
+
place_extraction_lm = dspy.LM(
|
| 127 |
+
f"ollama_chat/{PLACE_EXTRACTION_MODEL}",
|
| 128 |
+
api_base="http://localhost:11434",
|
| 129 |
+
api_key="",
|
| 130 |
+
temperature=0.1,
|
| 131 |
+
cache=False,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
sql_generation_lm = dspy.LM(
|
| 135 |
+
f"ollama_chat/{SQL_GENERATION_MODEL}",
|
| 136 |
+
api_base="http://localhost:11434",
|
| 137 |
+
api_key="",
|
| 138 |
+
temperature=0.1,
|
| 139 |
+
cache=False,
|
| 140 |
+
think=False
|
| 141 |
)
|
|
|
|
| 142 |
|
| 143 |
+
|
| 144 |
+
class PlaceExtractor(dspy.Module):
    """DSPy module that runs the ``ExtractPlaces`` signature on a given LM."""

    def __init__(self, lm):
        """Store *lm* and build the ``ExtractPlaces`` predictor."""
        super().__init__()
        self.lm = lm
        self.predictor = dspy.Predict(ExtractPlaces)

    def forward(self, query: str):
        """Run the predictor on *query* with ``self.lm`` active for the call."""
        with dspy.context(lm=self.lm):
            prediction = self.predictor(query=query)
        return prediction
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class SQLWriter(dspy.Module):
    """DSPy module that runs the ``WriteGeoSQL`` signature on a given LM."""

    def __init__(self, lm):
        """Store *lm* and build the ``WriteGeoSQL`` predictor."""
        super().__init__()
        self.lm = lm
        self.predictor = dspy.Predict(WriteGeoSQL)

    def forward(
        self,
        user_query: str,
        schema: str,
        candidates: str,
        previous_sql: str = "",
        execution_error: str = "",
    ):
        """Run the predictor on the given inputs with ``self.lm`` active.

        ``previous_sql`` and ``execution_error`` default to empty strings and
        are forwarded to the predictor unchanged.
        """
        inputs = {
            "user_query": user_query,
            "schema": schema,
            "candidates": candidates,
            "previous_sql": previous_sql,
            "execution_error": execution_error,
        }
        with dspy.context(lm=self.lm):
            return self.predictor(**inputs)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
extract = PlaceExtractor(lm=place_extraction_lm)
|
| 174 |
+
write_sql = SQLWriter(lm=sql_generation_lm)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# ── GGUF SQL generation via llama-server ──────────────────────────────────────
|
| 178 |
+
|
| 179 |
+
_SYSTEM_PROMPT = """You are a text to SQL query translator that helps in natural language geocoding.
|
| 180 |
+
|
| 181 |
+
You have access to two DuckDB parquet tables. Given a set of candidate entities and a user query, generate the SQL to retrieve the desired geometry.
|
| 182 |
+
|
| 183 |
+
<SCHEMA>
|
| 184 |
+
1. divisions_area -- Overture polygon/multipolygon admin boundaries
|
| 185 |
+
query: read_parquet('divisions_area')
|
| 186 |
+
columns:
|
| 187 |
+
id VARCHAR -- unique feature id
|
| 188 |
+
names STRUCT("primary" VARCHAR, ...)
|
| 189 |
+
country VARCHAR -- ISO 3166-1 alpha-2
|
| 190 |
+
subtype VARCHAR -- country | region | dependency | county | localadmin |
|
| 191 |
+
locality | macrohood | neighborhood | microhood
|
| 192 |
+
class VARCHAR
|
| 193 |
+
region VARCHAR
|
| 194 |
+
admin_level INTEGER
|
| 195 |
+
division_id VARCHAR
|
| 196 |
+
is_land BOOLEAN
|
| 197 |
+
is_territorial BOOLEAN
|
| 198 |
+
geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
|
| 199 |
+
|
| 200 |
+
2. natural_earth -- Natural Earth geography polygons (oceans, seas, rivers, terrain)
|
| 201 |
+
query: read_parquet('natural_earth')
|
| 202 |
+
columns:
|
| 203 |
+
id VARCHAR -- unique feature id prefixed 'ne_'
|
| 204 |
+
names STRUCT("primary" VARCHAR, ...)
|
| 205 |
+
country VARCHAR
|
| 206 |
+
subtype VARCHAR -- e.g. 'ocean', 'sea', 'bay', 'Terrain area', 'Island group'
|
| 207 |
+
class VARCHAR
|
| 208 |
+
region VARCHAR
|
| 209 |
+
admin_level INTEGER
|
| 210 |
+
is_land BOOLEAN
|
| 211 |
+
is_territorial BOOLEAN
|
| 212 |
+
geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
|
| 213 |
+
</SCHEMA>
|
| 214 |
+
|
| 215 |
+
The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
|
| 216 |
+
Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly.
|
| 217 |
+
Use ST_AsGeoJSON(geometry) for all geometry outputs."""
|
| 218 |
+
|
| 219 |
+
_USER_PROMPT_TEMPLATE = """<CANDIDATES>
|
| 220 |
+
{candidates_csv}
|
| 221 |
+
</CANDIDATES>
|
| 222 |
+
|
| 223 |
+
<USER_QUERY>
|
| 224 |
+
{question}
|
| 225 |
+
</USER_QUERY>
|
| 226 |
+
"""
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def _postprocess_sql(text: str) -> str:
|
| 230 |
+
"""Strip markdown fences and whitespace from generated SQL."""
|
| 231 |
+
cleaned = text.strip()
|
| 232 |
+
if "```sql" in cleaned:
|
| 233 |
+
cleaned = cleaned.split("```sql", 1)[1]
|
| 234 |
+
if cleaned.startswith("```"):
|
| 235 |
+
cleaned = cleaned[3:]
|
| 236 |
+
if "```" in cleaned:
|
| 237 |
+
cleaned = cleaned.split("```", 1)[0]
|
| 238 |
+
return cleaned.strip()
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def is_llama_server_available() -> bool:
    """Check if the llama-server is running and healthy.

    Sends ``GET {LLAMA_SERVER_URL}/health`` with a 2 second timeout and
    returns ``True`` only for an HTTP 200 response. Any httpx request
    failure (connection refused, timeout, read or protocol error, ...) is
    reported as "not available" instead of being raised.
    """
    try:
        resp = httpx.get(f"{LLAMA_SERVER_URL}/health", timeout=2)
    except httpx.HTTPError:
        # Base class of ConnectError / TimeoutException and the other
        # transport errors; a health probe should never raise on them.
        return False
    return resp.status_code == 200
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def _llama_chat_complete(messages: list[dict]) -> str:
    """Call llama-server /v1/chat/completions with a messages list.

    Args:
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The ``content`` of the first choice's message. A null/missing-text
        content is returned as ``""`` so callers can always treat the result
        as a string.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status
            (the first 500 characters of the body are logged first).
        httpx.HTTPError: On transport failures or timeout (60 s).
    """
    resp = httpx.post(
        f"{LLAMA_SERVER_URL}/v1/chat/completions",
        json={
            "messages": messages,
            "n_predict": LLAMA_MAX_TOKENS,
            "temperature": LLAMA_TEMPERATURE,
            "chat_template_kwargs": {"enable_thinking": False},
        },
        timeout=60,
    )
    if resp.status_code != 200:
        logger.error("llama-server %s: %s", resp.status_code, resp.text[:500])
    resp.raise_for_status()
    content = resp.json()["choices"][0]["message"]["content"]
    # The chat-completions schema allows a null content; honour ``-> str``.
    return content or ""
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
_PLACES_SYSTEM_PROMPT = """You are a geographic entity extractor. Extract place names from the user query and return valid JSON only.
|
| 269 |
+
|
| 270 |
+
OUTPUT FORMAT:
|
| 271 |
+
{"places": [{"place": "<name>", "country": "<ISO-2>", "subtype": "<subtype>"}]}
|
| 272 |
+
"country" and "subtype" are optional; omit if not applicable.
|
| 273 |
+
|
| 274 |
+
RULES:
|
| 275 |
+
- Only extract places explicitly mentioned. Never infer or expand (e.g. "states of India" -> extract "India" only).
|
| 276 |
+
- No duplicate place names.
|
| 277 |
+
- "country": ISO 3166-1 alpha-2. Include only if explicitly mentioned or unambiguous.
|
| 278 |
+
- "subtype": include only when the geographic level is clear from the query.
|
| 279 |
+
|
| 280 |
+
SUBTYPES:
|
| 281 |
+
country, dependency, region, county, localadmin, locality, macrohood, neighborhood, microhood
|
| 282 |
+
- Default to locality for cities/towns; omit for physical features (oceans, rivers, mountains)."""
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def generate_places(user_query: str) -> PlacesResult:
    """Extract place names from a query using the finetuned GGUF model.

    Sends ``_PLACES_SYSTEM_PROMPT`` plus the raw query to llama-server and
    parses the reply as JSON into a ``PlacesResult``. Markdown fences and any
    text surrounding the JSON object are tolerated.

    On parse or validation failure a warning is logged and the function falls
    back to a result containing a single place whose name is the entire
    query. Errors from the llama-server call itself are not caught here.
    """
    messages = [
        {"role": "system", "content": _PLACES_SYSTEM_PROMPT},
        {"role": "user", "content": user_query},
    ]
    raw_output = _llama_chat_complete(messages).strip()

    # Strip markdown fences if the model wrapped the JSON
    if raw_output.startswith("```"):
        raw_output = raw_output.split("```")[1]
        if raw_output.startswith("json"):
            raw_output = raw_output[4:]
        raw_output = raw_output.strip()

    # Tolerate chatter around the payload: keep only the outermost {...}.
    start, end = raw_output.find("{"), raw_output.rfind("}")
    if 0 <= start < end:
        raw_output = raw_output[start : end + 1]

    try:
        data = json.loads(raw_output)
        return PlacesResult.model_validate(data)
    except Exception as exc:
        logger.warning("generate_places: failed to parse output %r: %s", raw_output, exc)
        # Best-effort: treat the entire query as a single place name
        return PlacesResult(places=[Place(place=user_query)])
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def generate_sql(user_query: str, candidates_df: pd.DataFrame) -> str:
    """Generate SQL from a natural language query using the finetuned GGUF model.

    The prompt mirrors the training format: ``_SYSTEM_PROMPT`` (which embeds
    the schema) as the system message, and ``_USER_PROMPT_TEMPLATE`` filled
    with the candidates CSV and the question as the user message.

    This is a single-shot call with no retry loop; the raw completion is
    passed through ``_postprocess_sql`` before being returned.
    """
    # Restrict the candidates to the columns the model was trained on,
    # keeping this order and skipping any that are absent.
    trained_columns = (
        "source",
        "id",
        "name",
        "subtype",
        "country",
        "region",
        "admin_level",
        "similarity",
    )
    present = [col for col in trained_columns if col in candidates_df.columns]
    candidates_csv = candidates_df[present].to_csv(index=False).strip()

    user_prompt = _USER_PROMPT_TEMPLATE.format(
        candidates_csv=candidates_csv,
        question=user_query.strip(),
    )
    completion = _llama_chat_complete(
        [
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]
    )
    return _postprocess_sql(completion)
|
src/gazet/search.py
CHANGED
|
@@ -5,31 +5,16 @@ from .config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
|
|
| 5 |
from .schemas import Place
|
| 6 |
|
| 7 |
|
| 8 |
-
def
|
| 9 |
con: duckdb.DuckDBPyConnection,
|
| 10 |
path: str,
|
| 11 |
source: str,
|
| 12 |
place: Place,
|
| 13 |
extra_select: str = "",
|
| 14 |
limit: int = 5,
|
| 15 |
-
is_overture: bool = False,
|
| 16 |
) -> pd.DataFrame:
|
| 17 |
-
"""
|
| 18 |
-
|
| 19 |
-
country_params: list = []
|
| 20 |
-
if is_overture and place.country:
|
| 21 |
-
country_filter = "AND country = ?"
|
| 22 |
-
country_params = [place.country]
|
| 23 |
-
|
| 24 |
-
subtype_filter = ""
|
| 25 |
-
subtype_params: list = []
|
| 26 |
-
if is_overture and place.subtype:
|
| 27 |
-
subtype_filter = "AND subtype = ?"
|
| 28 |
-
subtype_params = [place.subtype]
|
| 29 |
-
|
| 30 |
-
params = (
|
| 31 |
-
[place.place, place.place, path] + country_params + subtype_params + [limit]
|
| 32 |
-
)
|
| 33 |
|
| 34 |
extra_clause = f", {extra_select}" if extra_select else ""
|
| 35 |
rel = con.execute(
|
|
@@ -44,12 +29,9 @@ def _fuzzy_search(
|
|
| 44 |
admin_level,
|
| 45 |
is_land,
|
| 46 |
is_territorial{extra_clause},
|
| 47 |
-
|
| 48 |
-
/ greatest(length(names."primary"), length(?), 1)) AS similarity
|
| 49 |
FROM read_parquet(?)
|
| 50 |
WHERE names."primary" IS NOT NULL AND trim(names."primary") != ''
|
| 51 |
-
{country_filter}
|
| 52 |
-
{subtype_filter}
|
| 53 |
ORDER BY similarity DESC, admin_level ASC
|
| 54 |
LIMIT ?
|
| 55 |
""",
|
|
@@ -57,11 +39,10 @@ def _fuzzy_search(
|
|
| 57 |
)
|
| 58 |
df = rel.fetchdf()
|
| 59 |
df.insert(0, "source", source)
|
| 60 |
-
label = f'"{place.place}"' + (f" [{place.country}]" if place.country else "")
|
| 61 |
if df.empty:
|
| 62 |
-
print(f"\n{source}
|
| 63 |
else:
|
| 64 |
-
print(f"\n{source}
|
| 65 |
print(df.to_string(index=False))
|
| 66 |
return df
|
| 67 |
|
|
@@ -70,14 +51,13 @@ def search_divisions_area(
|
|
| 70 |
con: duckdb.DuckDBPyConnection, place: Place, limit: int = 5
|
| 71 |
) -> pd.DataFrame:
|
| 72 |
"""Fuzzy-match a place against divisions_area (Overture admin boundaries)."""
|
| 73 |
-
return
|
| 74 |
con,
|
| 75 |
DIVISIONS_AREA_PATH,
|
| 76 |
"divisions_area",
|
| 77 |
place,
|
| 78 |
extra_select="division_id",
|
| 79 |
limit=limit,
|
| 80 |
-
is_overture=True,
|
| 81 |
)
|
| 82 |
|
| 83 |
|
|
@@ -85,10 +65,26 @@ def search_natural_earth(
|
|
| 85 |
con: duckdb.DuckDBPyConnection, place: Place, limit: int = 5
|
| 86 |
) -> pd.DataFrame:
|
| 87 |
"""Fuzzy-match a place against Natural Earth geography polygons."""
|
| 88 |
-
return
|
| 89 |
con,
|
| 90 |
NATURAL_EARTH_PATH,
|
| 91 |
"natural_earth",
|
| 92 |
place,
|
| 93 |
limit=limit,
|
| 94 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from .schemas import Place
|
| 6 |
|
| 7 |
|
| 8 |
+
def simple_fuzzy_search(
|
| 9 |
con: duckdb.DuckDBPyConnection,
|
| 10 |
path: str,
|
| 11 |
source: str,
|
| 12 |
place: Place,
|
| 13 |
extra_select: str = "",
|
| 14 |
limit: int = 5,
|
|
|
|
| 15 |
) -> pd.DataFrame:
|
| 16 |
+
"""Jaro-Winkler similarity search using only the place name."""
|
| 17 |
+
params = [place.place, path, limit]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
extra_clause = f", {extra_select}" if extra_select else ""
|
| 20 |
rel = con.execute(
|
|
|
|
| 29 |
admin_level,
|
| 30 |
is_land,
|
| 31 |
is_territorial{extra_clause},
|
| 32 |
+
jaro_winkler_similarity(lower(names."primary"), lower(?)) AS similarity
|
|
|
|
| 33 |
FROM read_parquet(?)
|
| 34 |
WHERE names."primary" IS NOT NULL AND trim(names."primary") != ''
|
|
|
|
|
|
|
| 35 |
ORDER BY similarity DESC, admin_level ASC
|
| 36 |
LIMIT ?
|
| 37 |
""",
|
|
|
|
| 39 |
)
|
| 40 |
df = rel.fetchdf()
|
| 41 |
df.insert(0, "source", source)
|
|
|
|
| 42 |
if df.empty:
|
| 43 |
+
print(f"\n{source} - \"{place.place}\": no matches")
|
| 44 |
else:
|
| 45 |
+
print(f"\n{source} - \"{place.place}\" (top {len(df)} by Jaro-Winkler):")
|
| 46 |
print(df.to_string(index=False))
|
| 47 |
return df
|
| 48 |
|
|
|
|
| 51 |
con: duckdb.DuckDBPyConnection, place: Place, limit: int = 5
|
| 52 |
) -> pd.DataFrame:
|
| 53 |
"""Fuzzy-match a place against divisions_area (Overture admin boundaries)."""
|
| 54 |
+
return simple_fuzzy_search(
|
| 55 |
con,
|
| 56 |
DIVISIONS_AREA_PATH,
|
| 57 |
"divisions_area",
|
| 58 |
place,
|
| 59 |
extra_select="division_id",
|
| 60 |
limit=limit,
|
|
|
|
| 61 |
)
|
| 62 |
|
| 63 |
|
|
|
|
| 65 |
con: duckdb.DuckDBPyConnection, place: Place, limit: int = 5
|
| 66 |
) -> pd.DataFrame:
|
| 67 |
"""Fuzzy-match a place against Natural Earth geography polygons."""
|
| 68 |
+
return simple_fuzzy_search(
|
| 69 |
con,
|
| 70 |
NATURAL_EARTH_PATH,
|
| 71 |
"natural_earth",
|
| 72 |
place,
|
| 73 |
limit=limit,
|
| 74 |
)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def search_candidates(
    con: duckdb.DuckDBPyConnection, place: Place, limit: int = 5
) -> list[pd.DataFrame]:
    """Collect the non-empty candidate DataFrames for a place.

    Both sources are always queried, divisions_area first and natural_earth
    second, so that a natural feature is still found when the model has
    assigned it an incorrect admin subtype. Sources that return no rows are
    left out of the returned list.
    """
    # Lazy generator: each search runs as the list below consumes it, which
    # keeps the divisions_area -> natural_earth call order.
    frames = (
        search(con, place, limit=limit)
        for search in (search_divisions_area, search_natural_earth)
    )
    return [frame for frame in frames if not frame.empty]
|
src/gazet/sql.py
CHANGED
|
@@ -4,8 +4,40 @@ from typing import Any, Generator, Optional
|
|
| 4 |
import duckdb
|
| 5 |
import pandas as pd
|
| 6 |
|
| 7 |
-
from .config import MAX_SQL_ITERATIONS, SCHEMA_INFO
|
| 8 |
-
from .lm import write_sql
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
def _strip_fences(sql: Optional[str]) -> str:
|
|
@@ -17,13 +49,38 @@ def _strip_fences(sql: Optional[str]) -> str:
|
|
| 17 |
return sql.strip()
|
| 18 |
|
| 19 |
|
| 20 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
con: duckdb.DuckDBPyConnection,
|
| 22 |
user_query: str,
|
| 23 |
candidates_df: pd.DataFrame,
|
| 24 |
-
max_iterations: int = MAX_SQL_ITERATIONS,
|
| 25 |
) -> Generator[dict[str, Any], None, None]:
|
| 26 |
-
"""
|
| 27 |
|
| 28 |
Event types:
|
| 29 |
- ``sql_attempt`` – ``{"type": "sql_attempt", "sql": str, "iteration": int}``
|
|
@@ -31,7 +88,46 @@ def run_geo_sql_loop(
|
|
| 31 |
- ``result`` – ``{"type": "result", "df": DataFrame | None, "sql": str}``
|
| 32 |
"""
|
| 33 |
if candidates_df.empty:
|
| 34 |
-
print("\n[SQL
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
yield {"type": "result", "df": None, "sql": ""}
|
| 36 |
return
|
| 37 |
|
|
@@ -41,7 +137,7 @@ def run_geo_sql_loop(
|
|
| 41 |
|
| 42 |
for iteration in range(1, max_iterations + 1):
|
| 43 |
print(f"\n{'=' * 60}")
|
| 44 |
-
print(f"[SQL
|
| 45 |
|
| 46 |
try:
|
| 47 |
pred = write_sql(
|
|
@@ -86,6 +182,6 @@ def run_geo_sql_loop(
|
|
| 86 |
yield {"type": "sql_error", "error": error, "iteration": iteration}
|
| 87 |
|
| 88 |
print(
|
| 89 |
-
f"\n[SQL
|
| 90 |
)
|
| 91 |
yield {"type": "result", "df": None, "sql": ""}
|
|
|
|
| 4 |
import duckdb
|
| 5 |
import pandas as pd
|
| 6 |
|
| 7 |
+
from .config import DIVISIONS_AREA_PATH, MAX_SQL_ITERATIONS, NATURAL_EARTH_PATH, SCHEMA_INFO
|
| 8 |
+
from .lm import generate_sql, write_sql
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _rewrite_data_paths(sql: str) -> str:
    """Replace any read_parquet table reference with the correct runtime path.

    Handles three generations of model output:
    - Symbolic: read_parquet('divisions_area')
    - Old paths: read_parquet('/data/overture/division_area/...')
    - Hallucinated variants: any quoted path containing 'division_area',
      'divisions_area' or 'natural_earth'

    All references are rewritten in a single pass, so a path that has just
    been substituted is never re-matched by a later rule. A quoted path that
    mentions both datasets is treated as divisions_area. Whitespace between
    the parentheses and the quoted path is tolerated.
    """

    def _runtime_ref(match: re.Match) -> str:
        """Build the read_parquet call for the dataset named in the match."""
        quoted = match.group("path")
        is_divisions = "division_area" in quoted or "divisions_area" in quoted
        path = DIVISIONS_AREA_PATH if is_divisions else NATURAL_EARTH_PATH
        # Double single quotes so the path stays a valid SQL string literal.
        escaped = str(path).replace("'", "''")
        return f"read_parquet('{escaped}')"

    # A callable replacement is used on purpose: with a string replacement
    # ``re.sub`` would interpret backslashes in the configured path
    # (e.g. a Windows path) as escape sequences.
    return re.sub(
        r"read_parquet\(\s*['\"]"
        r"(?P<path>[^'\"]*(?:divisions?_area|natural_earth)[^'\"]*)"
        r"['\"]\s*\)",
        _runtime_ref,
        sql,
    )
|
| 41 |
|
| 42 |
|
| 43 |
def _strip_fences(sql: Optional[str]) -> str:
|
|
|
|
| 49 |
return sql.strip()
|
| 50 |
|
| 51 |
|
| 52 |
+
def _execute_sql(
    con: duckdb.DuckDBPyConnection,
    sql: str,
    label: str,
    iteration: int,
) -> Generator[dict[str, Any], None, None]:
    """Run *sql* on *con* and yield the resulting events.

    Yields, depending on the outcome:
    - rows returned: one ``result`` event carrying the DataFrame;
    - no rows: a ``sql_error`` event ("Query returned no rows") followed by a
      ``result`` event with ``df=None``;
    - exception: a ``sql_error`` event with the exception text followed by a
      ``result`` event with ``df=None``.

    Every ``result`` event carries the executed SQL. *label* only prefixes
    the console output; *iteration* is echoed in ``sql_error`` events.
    """
    try:
        frame = con.execute(sql).fetchdf()
        if not frame.empty:
            print(f"[{label}] Result ({len(frame)} row(s))")
            yield {"type": "result", "df": frame, "sql": sql}
        else:
            print(f"[{label}] Query returned no rows.")
            yield {"type": "sql_error", "error": "Query returned no rows", "iteration": iteration}
            yield {"type": "result", "df": None, "sql": sql}
    except Exception as exc:
        message = str(exc)
        print(f"[{label}] Execution error: {message}")
        yield {"type": "sql_error", "error": message, "iteration": iteration}
        yield {"type": "result", "df": None, "sql": sql}
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# ── GGUF path: finetuned model via llama-server (single-shot) ─────────────────
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def run_geo_sql_gguf(
|
| 79 |
con: duckdb.DuckDBPyConnection,
|
| 80 |
user_query: str,
|
| 81 |
candidates_df: pd.DataFrame,
|
|
|
|
| 82 |
) -> Generator[dict[str, Any], None, None]:
|
| 83 |
+
"""Single-shot text-to-SQL via the finetuned GGUF model (llama-server).
|
| 84 |
|
| 85 |
Event types:
|
| 86 |
- ``sql_attempt`` – ``{"type": "sql_attempt", "sql": str, "iteration": int}``
|
|
|
|
| 88 |
- ``result`` – ``{"type": "result", "df": DataFrame | None, "sql": str}``
|
| 89 |
"""
|
| 90 |
if candidates_df.empty:
|
| 91 |
+
print("\n[SQL·GGUF] No candidates to work with — skipping.")
|
| 92 |
+
yield {"type": "result", "df": None, "sql": ""}
|
| 93 |
+
return
|
| 94 |
+
|
| 95 |
+
try:
|
| 96 |
+
sql = generate_sql(user_query, candidates_df)
|
| 97 |
+
except Exception as exc:
|
| 98 |
+
error = f"GGUF generation failed: {exc}"
|
| 99 |
+
print(f"[SQL·GGUF] {error}")
|
| 100 |
+
yield {"type": "sql_error", "error": error, "iteration": 1}
|
| 101 |
+
yield {"type": "result", "df": None, "sql": ""}
|
| 102 |
+
return
|
| 103 |
+
|
| 104 |
+
if not sql:
|
| 105 |
+
print("[SQL·GGUF] Model returned empty SQL.")
|
| 106 |
+
yield {"type": "sql_error", "error": "Empty SQL response", "iteration": 1}
|
| 107 |
+
yield {"type": "result", "df": None, "sql": ""}
|
| 108 |
+
return
|
| 109 |
+
|
| 110 |
+
sql = _rewrite_data_paths(sql)
|
| 111 |
+
print(f"\n[SQL·GGUF] Generated:\n{sql}\n")
|
| 112 |
+
yield {"type": "sql_attempt", "sql": sql, "iteration": 1}
|
| 113 |
+
yield from _execute_sql(con, sql, "SQL·GGUF", iteration=1)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ── DSPy path: cloud/local LM with retry loop ────────────────────────────────
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def run_geo_sql_dspy(
|
| 120 |
+
con: duckdb.DuckDBPyConnection,
|
| 121 |
+
user_query: str,
|
| 122 |
+
candidates_df: pd.DataFrame,
|
| 123 |
+
max_iterations: int = MAX_SQL_ITERATIONS,
|
| 124 |
+
) -> Generator[dict[str, Any], None, None]:
|
| 125 |
+
"""Code-act retry loop using the DSPy SQL writer (Ollama / cloud LM).
|
| 126 |
+
|
| 127 |
+
Same event types as ``run_geo_sql_gguf``.
|
| 128 |
+
"""
|
| 129 |
+
if candidates_df.empty:
|
| 130 |
+
print("\n[SQL·DSPy] No candidates to work with — skipping.")
|
| 131 |
yield {"type": "result", "df": None, "sql": ""}
|
| 132 |
return
|
| 133 |
|
|
|
|
| 137 |
|
| 138 |
for iteration in range(1, max_iterations + 1):
|
| 139 |
print(f"\n{'=' * 60}")
|
| 140 |
+
print(f"[SQL·DSPy] Iteration {iteration}/{max_iterations}")
|
| 141 |
|
| 142 |
try:
|
| 143 |
pred = write_sql(
|
|
|
|
| 182 |
yield {"type": "sql_error", "error": error, "iteration": iteration}
|
| 183 |
|
| 184 |
print(
|
| 185 |
+
f"\n[SQL·DSPy] Exhausted {max_iterations} iterations without a successful query."
|
| 186 |
)
|
| 187 |
yield {"type": "result", "df": None, "sql": ""}
|
uv.lock
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|