Spaces:
Sleeping
Sleeping
chore: clean sql generation, use conversation format, move prompts from user to system
Browse files- dataset/scripts/convert_to_conversation_format.py +0 -134
- dataset/scripts/export_training_data.py +55 -25
- dataset/scripts/validate_dataset.py +14 -0
- finetune/eval_cli.py +5 -9
- finetune/nlg.py +41 -51
- finetune/prompts.py +18 -25
- finetune/train_modal_qwen35.py +2 -3
- src/gazet/lm.py +23 -17
dataset/scripts/convert_to_conversation_format.py
DELETED
|
@@ -1,134 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""Convert prompt-completion format to conversation format.
|
| 3 |
-
|
| 4 |
-
Reads SQL and places JSONL from a run directory and converts to a single
|
| 5 |
-
"messages" list format suitable for various downstream uses.
|
| 6 |
-
|
| 7 |
-
Input format (current):
|
| 8 |
-
{
|
| 9 |
-
"prompt": [
|
| 10 |
-
{"role": "system", "content": "..."},
|
| 11 |
-
{"role": "user", "content": "..."}
|
| 12 |
-
],
|
| 13 |
-
"completion": [
|
| 14 |
-
{"role": "assistant", "content": "..."}
|
| 15 |
-
],
|
| 16 |
-
"metadata": {...}
|
| 17 |
-
}
|
| 18 |
-
|
| 19 |
-
Output format:
|
| 20 |
-
{
|
| 21 |
-
"messages": [
|
| 22 |
-
{"role": "system", "content": "..."},
|
| 23 |
-
{"role": "user", "content": "..."},
|
| 24 |
-
{"role": "assistant", "content": "..."}
|
| 25 |
-
]
|
| 26 |
-
}
|
| 27 |
-
|
| 28 |
-
Saves to JSONL files:
|
| 29 |
-
- train_conversation_sql.jsonl
|
| 30 |
-
- val_conversation_sql.jsonl
|
| 31 |
-
- test_conversation_sql.jsonl
|
| 32 |
-
- train_conversation_places.jsonl
|
| 33 |
-
- val_conversation_places.jsonl
|
| 34 |
-
- test_conversation_places.jsonl
|
| 35 |
-
|
| 36 |
-
Usage with datasets library:
|
| 37 |
-
from datasets import load_dataset
|
| 38 |
-
|
| 39 |
-
train_sql = load_dataset(
|
| 40 |
-
"json",
|
| 41 |
-
data_files="dataset/output/conversations/train_conversation_sql.jsonl",
|
| 42 |
-
split="train"
|
| 43 |
-
)
|
| 44 |
-
|
| 45 |
-
# Access messages:
|
| 46 |
-
print(train_sql[0]["messages"])
|
| 47 |
-
"""
|
| 48 |
-
|
| 49 |
-
import argparse
|
| 50 |
-
import json
|
| 51 |
-
from pathlib import Path
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def load_jsonl(path: Path) -> list[dict]:
|
| 55 |
-
rows = []
|
| 56 |
-
with open(path) as f:
|
| 57 |
-
for line in f:
|
| 58 |
-
line = line.strip()
|
| 59 |
-
if line:
|
| 60 |
-
rows.append(json.loads(line))
|
| 61 |
-
return rows
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
def to_conversation_format(sample: dict) -> dict:
|
| 65 |
-
"""Convert prompt+completion format to messages format."""
|
| 66 |
-
return {
|
| 67 |
-
"messages": sample["prompt"] + sample["completion"],
|
| 68 |
-
}
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
def process_task(run_dir: Path, task: str, output_dir: Path):
|
| 72 |
-
"""Process all splits for a single task (sql or places)."""
|
| 73 |
-
task_dir = run_dir / task
|
| 74 |
-
|
| 75 |
-
for split in ["train", "val", "test"]:
|
| 76 |
-
input_path = task_dir / f"{split}.jsonl"
|
| 77 |
-
if not input_path.exists():
|
| 78 |
-
print(f" Skipping {task}/{split}: {input_path} not found")
|
| 79 |
-
continue
|
| 80 |
-
|
| 81 |
-
samples = load_jsonl(input_path)
|
| 82 |
-
conversations = [to_conversation_format(s) for s in samples]
|
| 83 |
-
|
| 84 |
-
output_path = output_dir / f"{split}_conversation_{task}.jsonl"
|
| 85 |
-
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 86 |
-
|
| 87 |
-
with open(output_path, "w") as f:
|
| 88 |
-
for conv in conversations:
|
| 89 |
-
f.write(json.dumps(conv, ensure_ascii=False) + "\n")
|
| 90 |
-
|
| 91 |
-
print(f" {task}/{split}: {len(conversations)} samples β {output_path}")
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
def main():
|
| 95 |
-
parser = argparse.ArgumentParser(
|
| 96 |
-
description="Convert prompt-completion format to conversation format"
|
| 97 |
-
)
|
| 98 |
-
parser.add_argument(
|
| 99 |
-
"--run-dir",
|
| 100 |
-
type=Path,
|
| 101 |
-
default=Path("dataset/output/runs/v3-symbolic-paths"),
|
| 102 |
-
help="Path to run directory containing sql/ and places/ subdirectories",
|
| 103 |
-
)
|
| 104 |
-
parser.add_argument(
|
| 105 |
-
"--output-dir",
|
| 106 |
-
type=Path,
|
| 107 |
-
default=Path("dataset/output/conversations"),
|
| 108 |
-
help="Output directory for JSONL files",
|
| 109 |
-
)
|
| 110 |
-
|
| 111 |
-
args = parser.parse_args()
|
| 112 |
-
|
| 113 |
-
run_dir = args.run_dir
|
| 114 |
-
output_dir = args.output_dir
|
| 115 |
-
|
| 116 |
-
if not run_dir.exists():
|
| 117 |
-
print(f"Error: Run directory not found: {run_dir}")
|
| 118 |
-
return 1
|
| 119 |
-
|
| 120 |
-
print(f"Converting from: {run_dir}")
|
| 121 |
-
print(f"Output directory: {output_dir}")
|
| 122 |
-
print()
|
| 123 |
-
|
| 124 |
-
for task in ["sql", "places"]:
|
| 125 |
-
print(f"Processing {task}:")
|
| 126 |
-
process_task(run_dir, task, output_dir)
|
| 127 |
-
|
| 128 |
-
print()
|
| 129 |
-
print("Conversion complete!")
|
| 130 |
-
print(f"Output files in: {output_dir}/")
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
if __name__ == "__main__":
|
| 134 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataset/scripts/export_training_data.py
CHANGED
|
@@ -27,8 +27,11 @@ from collections import defaultdict
|
|
| 27 |
from pathlib import Path
|
| 28 |
from typing import Any, Dict, List, Optional, Tuple
|
| 29 |
|
|
|
|
| 30 |
import yaml
|
| 31 |
|
|
|
|
|
|
|
| 32 |
|
| 33 |
# ---------------------------------------------------------------------------
|
| 34 |
# Loading
|
|
@@ -89,16 +92,12 @@ def stratified_split(
|
|
| 89 |
# Conversational prompt-completion: model sees system + user, generates SQL.
|
| 90 |
# ---------------------------------------------------------------------------
|
| 91 |
|
| 92 |
-
_SQL_SYSTEM =
|
| 93 |
-
"You are a text to SQL query translator that helps in natural language geocoding."
|
| 94 |
-
)
|
| 95 |
|
| 96 |
-
|
| 97 |
-
"source", "id", "name", "subtype", "country", "region",
|
| 98 |
-
"admin_level", "similarity",
|
| 99 |
-
]
|
| 100 |
|
| 101 |
-
|
|
|
|
| 102 |
query: read_parquet('divisions_area')
|
| 103 |
columns:
|
| 104 |
id VARCHAR -- unique feature id
|
|
@@ -127,11 +126,17 @@ _SCHEMA = """1. divisions_area -- Overture polygon/multipolygon admin boundarie
|
|
| 127 |
is_land BOOLEAN
|
| 128 |
is_territorial BOOLEAN
|
| 129 |
geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
|
|
|
|
| 130 |
|
| 131 |
The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
|
| 132 |
Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly.
|
| 133 |
Use ST_AsGeoJSON(geometry) for all geometry outputs."""
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
def _candidates_csv(candidates: List[Dict]) -> str:
|
| 137 |
import io
|
|
@@ -149,26 +154,42 @@ def _candidates_csv(candidates: List[Dict]) -> str:
|
|
| 149 |
return buf.getvalue().strip()
|
| 150 |
|
| 151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
def sample_to_sql_pair(sample: Dict[str, Any]) -> Optional[Dict]:
|
| 153 |
"""Convert a raw sample to a conversational prompt-completion pair for SQL generation."""
|
| 154 |
sql = sample.get("target", {}).get("sql", "").strip()
|
| 155 |
if not sql:
|
| 156 |
return None
|
|
|
|
| 157 |
|
| 158 |
user_content = (
|
| 159 |
-
"GIVEN the <SCHEMA_DETAILS>, <CANDIDATES> and <USER_QUERY>, "
|
| 160 |
-
"generate the corresponding SQL command to retrieve the desired geometry.\n\n"
|
| 161 |
-
f"<SCHEMA_DETAILS>\n{_SCHEMA}\n</SCHEMA_DETAILS>\n\n"
|
| 162 |
f"<CANDIDATES>\n{_candidates_csv(sample.get('candidates', []))}\n</CANDIDATES>\n\n"
|
| 163 |
f"<USER_QUERY>\n{sample['question']}\n</USER_QUERY>"
|
| 164 |
)
|
| 165 |
|
| 166 |
return {
|
| 167 |
-
"
|
| 168 |
-
{"role": "system",
|
| 169 |
-
{"role": "user",
|
| 170 |
-
],
|
| 171 |
-
"completion": [
|
| 172 |
{"role": "assistant", "content": sql},
|
| 173 |
],
|
| 174 |
"metadata": sample.get("metadata", {}),
|
|
@@ -180,10 +201,21 @@ def sample_to_sql_pair(sample: Dict[str, Any]) -> Optional[Dict]:
|
|
| 180 |
# Derived from the same SQL samples: selected_candidates β PlacesResult JSON.
|
| 181 |
# ---------------------------------------------------------------------------
|
| 182 |
|
| 183 |
-
_PLACE_SYSTEM =
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
# Overture division subtypes β used to filter out natural_earth candidates
|
| 189 |
# from the place extraction output (NE features don't have these subtypes).
|
|
@@ -241,11 +273,9 @@ def sample_to_place_pair(sample: Dict[str, Any]) -> Optional[Dict]:
|
|
| 241 |
completion_json = json.dumps({"places": places}, ensure_ascii=False)
|
| 242 |
|
| 243 |
return {
|
| 244 |
-
"
|
| 245 |
-
{"role": "system",
|
| 246 |
-
{"role": "user",
|
| 247 |
-
],
|
| 248 |
-
"completion": [
|
| 249 |
{"role": "assistant", "content": completion_json},
|
| 250 |
],
|
| 251 |
"metadata": sample.get("metadata", {}),
|
|
|
|
| 27 |
from pathlib import Path
|
| 28 |
from typing import Any, Dict, List, Optional, Tuple
|
| 29 |
|
| 30 |
+
import sqlparse
|
| 31 |
import yaml
|
| 32 |
|
| 33 |
+
from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
|
| 34 |
+
|
| 35 |
|
| 36 |
# ---------------------------------------------------------------------------
|
| 37 |
# Loading
|
|
|
|
| 92 |
# Conversational prompt-completion: model sees system + user, generates SQL.
|
| 93 |
# ---------------------------------------------------------------------------
|
| 94 |
|
| 95 |
+
_SQL_SYSTEM = """You are a text to SQL query translator that helps in natural language geocoding.
|
|
|
|
|
|
|
| 96 |
|
| 97 |
+
You have access to two DuckDB parquet tables. Given a set of candidate entities and a user query, generate the SQL to retrieve the desired geometry.
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
+
<SCHEMA>
|
| 100 |
+
1. divisions_area -- Overture polygon/multipolygon admin boundaries
|
| 101 |
query: read_parquet('divisions_area')
|
| 102 |
columns:
|
| 103 |
id VARCHAR -- unique feature id
|
|
|
|
| 126 |
is_land BOOLEAN
|
| 127 |
is_territorial BOOLEAN
|
| 128 |
geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
|
| 129 |
+
</SCHEMA>
|
| 130 |
|
| 131 |
The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
|
| 132 |
Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly.
|
| 133 |
Use ST_AsGeoJSON(geometry) for all geometry outputs."""
|
| 134 |
|
| 135 |
+
_CANDIDATES_COLS = [
|
| 136 |
+
"source", "id", "name", "subtype", "country", "region",
|
| 137 |
+
"admin_level", "similarity",
|
| 138 |
+
]
|
| 139 |
+
|
| 140 |
|
| 141 |
def _candidates_csv(candidates: List[Dict]) -> str:
|
| 142 |
import io
|
|
|
|
| 154 |
return buf.getvalue().strip()
|
| 155 |
|
| 156 |
|
| 157 |
+
def _to_symbolic_sql(sql: str) -> str:
|
| 158 |
+
"""Normalize any hardcoded or runtime paths back to symbolic names."""
|
| 159 |
+
sql = sql.replace(DIVISIONS_AREA_PATH, "divisions_area")
|
| 160 |
+
sql = sql.replace(NATURAL_EARTH_PATH, "natural_earth")
|
| 161 |
+
sql = sql.replace("/data/overture/division_area/*.parquet", "divisions_area")
|
| 162 |
+
sql = sql.replace("/data/overture/divisions_area/*.parquet", "divisions_area")
|
| 163 |
+
sql = sql.replace("/data/natural_earth_geoparquet/ne_geography.parquet", "natural_earth")
|
| 164 |
+
return sql
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def _format_sql(sql: str) -> str:
|
| 168 |
+
"""Pretty-print SQL so the model learns clean, readable style."""
|
| 169 |
+
return sqlparse.format(
|
| 170 |
+
sql,
|
| 171 |
+
reindent=True,
|
| 172 |
+
keyword_case="upper",
|
| 173 |
+
indent_width=4,
|
| 174 |
+
).strip()
|
| 175 |
+
|
| 176 |
+
|
| 177 |
def sample_to_sql_pair(sample: Dict[str, Any]) -> Optional[Dict]:
|
| 178 |
"""Convert a raw sample to a conversational prompt-completion pair for SQL generation."""
|
| 179 |
sql = sample.get("target", {}).get("sql", "").strip()
|
| 180 |
if not sql:
|
| 181 |
return None
|
| 182 |
+
sql = _format_sql(_to_symbolic_sql(sql))
|
| 183 |
|
| 184 |
user_content = (
|
|
|
|
|
|
|
|
|
|
| 185 |
f"<CANDIDATES>\n{_candidates_csv(sample.get('candidates', []))}\n</CANDIDATES>\n\n"
|
| 186 |
f"<USER_QUERY>\n{sample['question']}\n</USER_QUERY>"
|
| 187 |
)
|
| 188 |
|
| 189 |
return {
|
| 190 |
+
"messages": [
|
| 191 |
+
{"role": "system", "content": _SQL_SYSTEM},
|
| 192 |
+
{"role": "user", "content": user_content},
|
|
|
|
|
|
|
| 193 |
{"role": "assistant", "content": sql},
|
| 194 |
],
|
| 195 |
"metadata": sample.get("metadata", {}),
|
|
|
|
| 201 |
# Derived from the same SQL samples: selected_candidates β PlacesResult JSON.
|
| 202 |
# ---------------------------------------------------------------------------
|
| 203 |
|
| 204 |
+
_PLACE_SYSTEM = """You are a geographic entity extractor. Extract place names from the user query and return valid JSON only.
|
| 205 |
+
|
| 206 |
+
OUTPUT FORMAT:
|
| 207 |
+
{"places": [{"place": "<name>", "country": "<ISO-2>", "subtype": "<subtype>"}]}
|
| 208 |
+
"country" and "subtype" are optional; omit if not applicable.
|
| 209 |
+
|
| 210 |
+
RULES:
|
| 211 |
+
- Only extract places explicitly mentioned. Never infer or expand (e.g. "states of India" -> extract "India" only).
|
| 212 |
+
- No duplicate place names.
|
| 213 |
+
- "country": ISO 3166-1 alpha-2. Include only if explicitly mentioned or unambiguous.
|
| 214 |
+
- "subtype": include only when the geographic level is clear from the query.
|
| 215 |
+
|
| 216 |
+
SUBTYPES:
|
| 217 |
+
country, dependency, region, county, localadmin, locality, macrohood, neighborhood, microhood
|
| 218 |
+
- Default to locality for cities/towns; omit for physical features (oceans, rivers, mountains)."""
|
| 219 |
|
| 220 |
# Overture division subtypes β used to filter out natural_earth candidates
|
| 221 |
# from the place extraction output (NE features don't have these subtypes).
|
|
|
|
| 273 |
completion_json = json.dumps({"places": places}, ensure_ascii=False)
|
| 274 |
|
| 275 |
return {
|
| 276 |
+
"messages": [
|
| 277 |
+
{"role": "system", "content": _PLACE_SYSTEM},
|
| 278 |
+
{"role": "user", "content": sample["question"]},
|
|
|
|
|
|
|
| 279 |
{"role": "assistant", "content": completion_json},
|
| 280 |
],
|
| 281 |
"metadata": sample.get("metadata", {}),
|
dataset/scripts/validate_dataset.py
CHANGED
|
@@ -50,6 +50,18 @@ def _resolve_paths(sql: str) -> str:
|
|
| 50 |
return sql
|
| 51 |
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
def validate_sql(con: duckdb.DuckDBPyConnection, sql: str) -> tuple[bool, str]:
|
| 54 |
"""Validate that SQL executes without error.
|
| 55 |
|
|
@@ -120,6 +132,8 @@ def validate_sample_worker(sample: Dict[str, Any]) -> Tuple[str, bool, List[str]
|
|
| 120 |
try:
|
| 121 |
is_valid, issues = validate_sample(con, sample)
|
| 122 |
con.close()
|
|
|
|
|
|
|
| 123 |
return (sample['id'], is_valid, issues, sample if is_valid else None)
|
| 124 |
except Exception as e:
|
| 125 |
con.close()
|
|
|
|
| 50 |
return sql
|
| 51 |
|
| 52 |
|
| 53 |
+
def _to_symbolic_sql(sql: str) -> str:
|
| 54 |
+
"""Normalize any hardcoded or runtime paths back to symbolic names for storage."""
|
| 55 |
+
# Current local runtime paths
|
| 56 |
+
sql = sql.replace(DIVISIONS_AREA_PATH, "divisions_area")
|
| 57 |
+
sql = sql.replace(NATURAL_EARTH_PATH, "natural_earth")
|
| 58 |
+
# Legacy Docker paths
|
| 59 |
+
sql = sql.replace("/data/overture/division_area/*.parquet", "divisions_area")
|
| 60 |
+
sql = sql.replace("/data/overture/divisions_area/*.parquet", "divisions_area")
|
| 61 |
+
sql = sql.replace("/data/natural_earth_geoparquet/ne_geography.parquet", "natural_earth")
|
| 62 |
+
return sql
|
| 63 |
+
|
| 64 |
+
|
| 65 |
def validate_sql(con: duckdb.DuckDBPyConnection, sql: str) -> tuple[bool, str]:
|
| 66 |
"""Validate that SQL executes without error.
|
| 67 |
|
|
|
|
| 132 |
try:
|
| 133 |
is_valid, issues = validate_sample(con, sample)
|
| 134 |
con.close()
|
| 135 |
+
if is_valid:
|
| 136 |
+
sample['target']['sql'] = _to_symbolic_sql(sample['target']['sql'])
|
| 137 |
return (sample['id'], is_valid, issues, sample if is_valid else None)
|
| 138 |
except Exception as e:
|
| 139 |
con.close()
|
finetune/eval_cli.py
CHANGED
|
@@ -83,19 +83,15 @@ def load_samples(run_dir: Path, task: str) -> list[dict]:
|
|
| 83 |
|
| 84 |
|
| 85 |
def build_raw_prompt(sample: dict) -> str:
|
| 86 |
-
"""Reconstruct the plain prompt string from
|
| 87 |
-
|
| 88 |
-
sample["prompt"] is [{role:system, content:...}, {role:user, content:...}].
|
| 89 |
-
Joins them with a blank line β same format used during training.
|
| 90 |
-
"""
|
| 91 |
-
return sample["prompt"][0]["content"] + "\n\n" + sample["prompt"][1]["content"]
|
| 92 |
|
| 93 |
|
| 94 |
def run_sample(sample: dict, task: str, total: int, index: int, verbose: bool = False) -> None:
|
| 95 |
-
expected = sample["
|
| 96 |
-
messages = sample["
|
| 97 |
|
| 98 |
-
user_content = sample["
|
| 99 |
if "<USER_QUERY>" in user_content:
|
| 100 |
question = user_content.split("<USER_QUERY>")[-1].split("</USER_QUERY>")[0].strip()
|
| 101 |
else:
|
|
|
|
| 83 |
|
| 84 |
|
| 85 |
def build_raw_prompt(sample: dict) -> str:
|
| 86 |
+
"""Reconstruct the plain prompt string from messages format (all turns except assistant)."""
|
| 87 |
+
return "\n\n".join(m["content"] for m in sample["messages"][:-1])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
|
| 90 |
def run_sample(sample: dict, task: str, total: int, index: int, verbose: bool = False) -> None:
|
| 91 |
+
expected = sample["messages"][-1]["content"]
|
| 92 |
+
messages = sample["messages"][:-1]
|
| 93 |
|
| 94 |
+
user_content = sample["messages"][-2]["content"]
|
| 95 |
if "<USER_QUERY>" in user_content:
|
| 96 |
question = user_content.split("<USER_QUERY>")[-1].split("</USER_QUERY>")[0].strip()
|
| 97 |
else:
|
finetune/nlg.py
CHANGED
|
@@ -66,49 +66,54 @@ from trl import SFTConfig, SFTTrainer
|
|
| 66 |
|
| 67 |
LOGGER = logging.getLogger("nlg")
|
| 68 |
|
| 69 |
-
SYSTEM_PROMPT =
|
| 70 |
-
"You are a text to SQL query translator that helps in natural language geocoding."
|
| 71 |
-
)
|
| 72 |
|
| 73 |
-
|
| 74 |
|
| 75 |
-
<
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
<CANDIDATES>
|
| 80 |
-
{candidates_csv}
|
| 81 |
-
</CANDIDATES>
|
| 82 |
-
|
| 83 |
-
<USER_QUERY>
|
| 84 |
-
{question}
|
| 85 |
-
</USER_QUERY>
|
| 86 |
-
"""
|
| 87 |
-
|
| 88 |
-
DEFAULT_SCHEMA_DETAILS = """1. divisions_area β Overture polygon/multipolygon admin boundaries
|
| 89 |
-
path: '/data/overture/division_area/*.parquet'
|
| 90 |
columns:
|
| 91 |
-
id VARCHAR
|
| 92 |
names STRUCT("primary" VARCHAR, ...)
|
| 93 |
-
country VARCHAR
|
| 94 |
-
subtype VARCHAR
|
|
|
|
| 95 |
class VARCHAR
|
| 96 |
region VARCHAR
|
| 97 |
admin_level INTEGER
|
| 98 |
division_id VARCHAR
|
| 99 |
is_land BOOLEAN
|
| 100 |
is_territorial BOOLEAN
|
| 101 |
-
geometry GEOMETRY
|
| 102 |
|
| 103 |
-
2. natural_earth
|
| 104 |
-
|
| 105 |
columns:
|
| 106 |
-
id VARCHAR
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
|
| 114 |
@dataclass
|
|
@@ -126,12 +131,6 @@ def setup_logging(verbose: bool = False) -> None:
|
|
| 126 |
)
|
| 127 |
|
| 128 |
|
| 129 |
-
def read_text(path: Optional[str], default: str) -> str:
|
| 130 |
-
if not path:
|
| 131 |
-
return default
|
| 132 |
-
return Path(path).read_text(encoding="utf-8")
|
| 133 |
-
|
| 134 |
-
|
| 135 |
def candidates_to_csv(candidates: Sequence[Dict[str, Any]]) -> str:
|
| 136 |
df = pd.DataFrame(list(candidates))
|
| 137 |
if "candidate_id" in df.columns:
|
|
@@ -139,15 +138,14 @@ def candidates_to_csv(candidates: Sequence[Dict[str, Any]]) -> str:
|
|
| 139 |
return df.to_csv(index=False)
|
| 140 |
|
| 141 |
|
| 142 |
-
def build_user_prompt(question: str, candidates: Sequence[Dict[str, Any]]
|
| 143 |
return USER_PROMPT_TEMPLATE.format(
|
| 144 |
-
schema_details=schema_details.strip(),
|
| 145 |
candidates_csv=candidates_to_csv(candidates).strip(),
|
| 146 |
question=question.strip(),
|
| 147 |
)
|
| 148 |
|
| 149 |
|
| 150 |
-
def make_messages(sample: Dict[str, Any]
|
| 151 |
messages = [
|
| 152 |
{"role": "system", "content": SYSTEM_PROMPT},
|
| 153 |
{
|
|
@@ -155,7 +153,6 @@ def make_messages(sample: Dict[str, Any], schema_details: str) -> Dict[str, Any]
|
|
| 155 |
"content": build_user_prompt(
|
| 156 |
question=sample["question"],
|
| 157 |
candidates=sample["candidates"],
|
| 158 |
-
schema_details=schema_details,
|
| 159 |
),
|
| 160 |
},
|
| 161 |
]
|
|
@@ -178,10 +175,10 @@ def load_jsonl_splits(
|
|
| 178 |
return load_dataset("json", data_files=data_files)
|
| 179 |
|
| 180 |
|
| 181 |
-
def format_dataset_for_sft(dataset: DatasetDict
|
| 182 |
formatted = DatasetDict()
|
| 183 |
for split, ds in dataset.items():
|
| 184 |
-
formatted[split] = ds.map(
|
| 185 |
return formatted
|
| 186 |
|
| 187 |
|
|
@@ -271,9 +268,8 @@ def build_lora_config(args: argparse.Namespace) -> LoraConfig:
|
|
| 271 |
|
| 272 |
def train(args: argparse.Namespace) -> None:
|
| 273 |
set_seed(args.seed)
|
| 274 |
-
schema_details = read_text(args.schema_file, DEFAULT_SCHEMA_DETAILS)
|
| 275 |
raw_ds = load_jsonl_splits(args.train_jsonl, args.val_jsonl, args.test_jsonl)
|
| 276 |
-
ds = format_dataset_for_sft(raw_ds
|
| 277 |
|
| 278 |
if args.max_train_samples is not None:
|
| 279 |
ds["train"] = ds["train"].select(range(min(args.max_train_samples, len(ds["train"]))))
|
|
@@ -362,7 +358,6 @@ def generate_sql(
|
|
| 362 |
tokenizer,
|
| 363 |
question: str,
|
| 364 |
candidates: Sequence[Dict[str, Any]],
|
| 365 |
-
schema_details: str,
|
| 366 |
max_new_tokens: int = 256,
|
| 367 |
do_sample: bool = False,
|
| 368 |
temperature: float = 0.1,
|
|
@@ -371,7 +366,6 @@ def generate_sql(
|
|
| 371 |
) -> GenerationResult:
|
| 372 |
messages = make_messages(
|
| 373 |
{"question": question, "candidates": list(candidates), "target": {}},
|
| 374 |
-
schema_details,
|
| 375 |
)["messages"]
|
| 376 |
prompt = render_prompt(tokenizer, messages)
|
| 377 |
inputs = tokenizer.apply_chat_template(
|
|
@@ -446,7 +440,6 @@ def execute_sqlite(sql: str, sqlite_db: str, limit: Optional[int] = None) -> Tup
|
|
| 446 |
|
| 447 |
|
| 448 |
def cmd_generate(args: argparse.Namespace) -> None:
|
| 449 |
-
schema_details = read_text(args.schema_file, DEFAULT_SCHEMA_DETAILS)
|
| 450 |
question = read_question(args)
|
| 451 |
candidates = read_candidates(args)
|
| 452 |
model, tokenizer = load_model_for_inference(
|
|
@@ -463,7 +456,6 @@ def cmd_generate(args: argparse.Namespace) -> None:
|
|
| 463 |
tokenizer=tokenizer,
|
| 464 |
question=question,
|
| 465 |
candidates=candidates,
|
| 466 |
-
schema_details=schema_details,
|
| 467 |
max_new_tokens=args.max_new_tokens,
|
| 468 |
do_sample=args.do_sample,
|
| 469 |
temperature=args.temperature,
|
|
@@ -511,7 +503,6 @@ def build_parser() -> argparse.ArgumentParser:
|
|
| 511 |
train_p.add_argument("--train-jsonl", required=True)
|
| 512 |
train_p.add_argument("--val-jsonl")
|
| 513 |
train_p.add_argument("--test-jsonl")
|
| 514 |
-
train_p.add_argument("--schema-file")
|
| 515 |
train_p.add_argument("--output-dir", required=True)
|
| 516 |
train_p.add_argument("--max-train-samples", type=int)
|
| 517 |
train_p.add_argument("--max-eval-samples", type=int)
|
|
@@ -552,7 +543,6 @@ def build_parser() -> argparse.ArgumentParser:
|
|
| 552 |
gen_p.add_argument("--model-path")
|
| 553 |
gen_p.add_argument("--base-model")
|
| 554 |
gen_p.add_argument("--adapter-path")
|
| 555 |
-
gen_p.add_argument("--schema-file")
|
| 556 |
gen_p.add_argument("--question")
|
| 557 |
gen_p.add_argument("--candidates-json")
|
| 558 |
gen_p.add_argument("--sample-jsonl")
|
|
|
|
| 66 |
|
| 67 |
LOGGER = logging.getLogger("nlg")
|
| 68 |
|
| 69 |
+
SYSTEM_PROMPT = """You are a text to SQL query translator that helps in natural language geocoding.
|
|
|
|
|
|
|
| 70 |
|
| 71 |
+
You have access to two DuckDB parquet tables. Given a set of candidate entities and a user query, generate the SQL to retrieve the desired geometry.
|
| 72 |
|
| 73 |
+
<SCHEMA>
|
| 74 |
+
1. divisions_area -- Overture polygon/multipolygon admin boundaries
|
| 75 |
+
query: read_parquet('divisions_area')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
columns:
|
| 77 |
+
id VARCHAR -- unique feature id
|
| 78 |
names STRUCT("primary" VARCHAR, ...)
|
| 79 |
+
country VARCHAR -- ISO 3166-1 alpha-2
|
| 80 |
+
subtype VARCHAR -- country | region | dependency | county | localadmin |
|
| 81 |
+
locality | macrohood | neighborhood | microhood
|
| 82 |
class VARCHAR
|
| 83 |
region VARCHAR
|
| 84 |
admin_level INTEGER
|
| 85 |
division_id VARCHAR
|
| 86 |
is_land BOOLEAN
|
| 87 |
is_territorial BOOLEAN
|
| 88 |
+
geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
|
| 89 |
|
| 90 |
+
2. natural_earth -- Natural Earth geography polygons (oceans, seas, rivers, terrain)
|
| 91 |
+
query: read_parquet('natural_earth')
|
| 92 |
columns:
|
| 93 |
+
id VARCHAR -- unique feature id prefixed 'ne_'
|
| 94 |
+
names STRUCT("primary" VARCHAR, ...)
|
| 95 |
+
country VARCHAR
|
| 96 |
+
subtype VARCHAR -- e.g. 'ocean', 'sea', 'bay', 'Terrain area', 'Island group'
|
| 97 |
+
class VARCHAR
|
| 98 |
+
region VARCHAR
|
| 99 |
+
admin_level INTEGER
|
| 100 |
+
is_land BOOLEAN
|
| 101 |
+
is_territorial BOOLEAN
|
| 102 |
+
geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
|
| 103 |
+
</SCHEMA>
|
| 104 |
+
|
| 105 |
+
The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
|
| 106 |
+
Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly.
|
| 107 |
+
Use ST_AsGeoJSON(geometry) for all geometry outputs."""
|
| 108 |
+
|
| 109 |
+
USER_PROMPT_TEMPLATE = """<CANDIDATES>
|
| 110 |
+
{candidates_csv}
|
| 111 |
+
</CANDIDATES>
|
| 112 |
+
|
| 113 |
+
<USER_QUERY>
|
| 114 |
+
{question}
|
| 115 |
+
</USER_QUERY>
|
| 116 |
+
"""
|
| 117 |
|
| 118 |
|
| 119 |
@dataclass
|
|
|
|
| 131 |
)
|
| 132 |
|
| 133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
def candidates_to_csv(candidates: Sequence[Dict[str, Any]]) -> str:
|
| 135 |
df = pd.DataFrame(list(candidates))
|
| 136 |
if "candidate_id" in df.columns:
|
|
|
|
| 138 |
return df.to_csv(index=False)
|
| 139 |
|
| 140 |
|
| 141 |
+
def build_user_prompt(question: str, candidates: Sequence[Dict[str, Any]]) -> str:
|
| 142 |
return USER_PROMPT_TEMPLATE.format(
|
|
|
|
| 143 |
candidates_csv=candidates_to_csv(candidates).strip(),
|
| 144 |
question=question.strip(),
|
| 145 |
)
|
| 146 |
|
| 147 |
|
| 148 |
+
def make_messages(sample: Dict[str, Any]) -> Dict[str, Any]:
|
| 149 |
messages = [
|
| 150 |
{"role": "system", "content": SYSTEM_PROMPT},
|
| 151 |
{
|
|
|
|
| 153 |
"content": build_user_prompt(
|
| 154 |
question=sample["question"],
|
| 155 |
candidates=sample["candidates"],
|
|
|
|
| 156 |
),
|
| 157 |
},
|
| 158 |
]
|
|
|
|
| 175 |
return load_dataset("json", data_files=data_files)
|
| 176 |
|
| 177 |
|
| 178 |
+
def format_dataset_for_sft(dataset: DatasetDict) -> DatasetDict:
|
| 179 |
formatted = DatasetDict()
|
| 180 |
for split, ds in dataset.items():
|
| 181 |
+
formatted[split] = ds.map(make_messages)
|
| 182 |
return formatted
|
| 183 |
|
| 184 |
|
|
|
|
| 268 |
|
| 269 |
def train(args: argparse.Namespace) -> None:
|
| 270 |
set_seed(args.seed)
|
|
|
|
| 271 |
raw_ds = load_jsonl_splits(args.train_jsonl, args.val_jsonl, args.test_jsonl)
|
| 272 |
+
ds = format_dataset_for_sft(raw_ds)
|
| 273 |
|
| 274 |
if args.max_train_samples is not None:
|
| 275 |
ds["train"] = ds["train"].select(range(min(args.max_train_samples, len(ds["train"]))))
|
|
|
|
| 358 |
tokenizer,
|
| 359 |
question: str,
|
| 360 |
candidates: Sequence[Dict[str, Any]],
|
|
|
|
| 361 |
max_new_tokens: int = 256,
|
| 362 |
do_sample: bool = False,
|
| 363 |
temperature: float = 0.1,
|
|
|
|
| 366 |
) -> GenerationResult:
|
| 367 |
messages = make_messages(
|
| 368 |
{"question": question, "candidates": list(candidates), "target": {}},
|
|
|
|
| 369 |
)["messages"]
|
| 370 |
prompt = render_prompt(tokenizer, messages)
|
| 371 |
inputs = tokenizer.apply_chat_template(
|
|
|
|
| 440 |
|
| 441 |
|
| 442 |
def cmd_generate(args: argparse.Namespace) -> None:
|
|
|
|
| 443 |
question = read_question(args)
|
| 444 |
candidates = read_candidates(args)
|
| 445 |
model, tokenizer = load_model_for_inference(
|
|
|
|
| 456 |
tokenizer=tokenizer,
|
| 457 |
question=question,
|
| 458 |
candidates=candidates,
|
|
|
|
| 459 |
max_new_tokens=args.max_new_tokens,
|
| 460 |
do_sample=args.do_sample,
|
| 461 |
temperature=args.temperature,
|
|
|
|
| 503 |
train_p.add_argument("--train-jsonl", required=True)
|
| 504 |
train_p.add_argument("--val-jsonl")
|
| 505 |
train_p.add_argument("--test-jsonl")
|
|
|
|
| 506 |
train_p.add_argument("--output-dir", required=True)
|
| 507 |
train_p.add_argument("--max-train-samples", type=int)
|
| 508 |
train_p.add_argument("--max-eval-samples", type=int)
|
|
|
|
| 543 |
gen_p.add_argument("--model-path")
|
| 544 |
gen_p.add_argument("--base-model")
|
| 545 |
gen_p.add_argument("--adapter-path")
|
|
|
|
| 546 |
gen_p.add_argument("--question")
|
| 547 |
gen_p.add_argument("--candidates-json")
|
| 548 |
gen_p.add_argument("--sample-jsonl")
|
finetune/prompts.py
CHANGED
|
@@ -6,35 +6,21 @@ from typing import Any, Dict, Sequence
|
|
| 6 |
|
| 7 |
import pandas as pd
|
| 8 |
|
| 9 |
-
SYSTEM_PROMPT =
|
| 10 |
-
"You are a text to SQL query translator that helps in natural language geocoding."
|
| 11 |
-
)
|
| 12 |
|
| 13 |
-
|
| 14 |
|
| 15 |
-
<
|
| 16 |
-
|
| 17 |
-
</SCHEMA_DETAILS>
|
| 18 |
-
|
| 19 |
-
<CANDIDATES>
|
| 20 |
-
{candidates_csv}
|
| 21 |
-
</CANDIDATES>
|
| 22 |
-
|
| 23 |
-
<USER_QUERY>
|
| 24 |
-
{question}
|
| 25 |
-
</USER_QUERY>
|
| 26 |
-
"""
|
| 27 |
-
|
| 28 |
-
DEFAULT_SCHEMA_DETAILS = """1. divisions_area -- Overture polygon/multipolygon admin boundaries
|
| 29 |
query: read_parquet('divisions_area')
|
| 30 |
columns:
|
| 31 |
-
id VARCHAR -- unique feature id
|
| 32 |
names STRUCT("primary" VARCHAR, ...)
|
| 33 |
country VARCHAR -- ISO 3166-1 alpha-2
|
| 34 |
subtype VARCHAR -- country | region | dependency | county | localadmin |
|
| 35 |
locality | macrohood | neighborhood | microhood
|
| 36 |
class VARCHAR
|
| 37 |
-
region VARCHAR
|
| 38 |
admin_level INTEGER
|
| 39 |
division_id VARCHAR
|
| 40 |
is_land BOOLEAN
|
|
@@ -54,9 +40,20 @@ DEFAULT_SCHEMA_DETAILS = """1. divisions_area -- Overture polygon/multipolygon
|
|
| 54 |
is_land BOOLEAN
|
| 55 |
is_territorial BOOLEAN
|
| 56 |
geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
|
|
|
|
| 57 |
|
| 58 |
The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
|
| 59 |
-
Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
|
| 62 |
def candidates_to_csv(candidates: Sequence[Dict[str, Any]]) -> str:
|
|
@@ -69,10 +66,8 @@ def candidates_to_csv(candidates: Sequence[Dict[str, Any]]) -> str:
|
|
| 69 |
def build_user_prompt(
|
| 70 |
question: str,
|
| 71 |
candidates: Sequence[Dict[str, Any]],
|
| 72 |
-
schema_details: str,
|
| 73 |
) -> str:
|
| 74 |
return USER_PROMPT_TEMPLATE.format(
|
| 75 |
-
schema_details=schema_details.strip(),
|
| 76 |
candidates_csv=candidates_to_csv(candidates).strip(),
|
| 77 |
question=question.strip(),
|
| 78 |
)
|
|
@@ -80,12 +75,10 @@ def build_user_prompt(
|
|
| 80 |
|
| 81 |
def make_prompt_completion(
|
| 82 |
sample: Dict[str, Any],
|
| 83 |
-
schema_details: str,
|
| 84 |
) -> Dict[str, str]:
|
| 85 |
prompt = SYSTEM_PROMPT + "\n\n" + build_user_prompt(
|
| 86 |
question=sample["question"],
|
| 87 |
candidates=sample["candidates"],
|
| 88 |
-
schema_details=schema_details,
|
| 89 |
)
|
| 90 |
completion = sample.get("target", {}).get("sql", "")
|
| 91 |
return {"prompt": prompt, "completion": completion}
|
|
|
|
| 6 |
|
| 7 |
import pandas as pd
|
| 8 |
|
| 9 |
+
SYSTEM_PROMPT = """You are a text to SQL query translator that helps in natural language geocoding.
|
|
|
|
|
|
|
| 10 |
|
| 11 |
+
You have access to two DuckDB parquet tables. Given a set of candidate entities and a user query, generate the SQL to retrieve the desired geometry.
|
| 12 |
|
| 13 |
+
<SCHEMA>
|
| 14 |
+
1. divisions_area -- Overture polygon/multipolygon admin boundaries
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
query: read_parquet('divisions_area')
|
| 16 |
columns:
|
| 17 |
+
id VARCHAR -- unique feature id
|
| 18 |
names STRUCT("primary" VARCHAR, ...)
|
| 19 |
country VARCHAR -- ISO 3166-1 alpha-2
|
| 20 |
subtype VARCHAR -- country | region | dependency | county | localadmin |
|
| 21 |
locality | macrohood | neighborhood | microhood
|
| 22 |
class VARCHAR
|
| 23 |
+
region VARCHAR
|
| 24 |
admin_level INTEGER
|
| 25 |
division_id VARCHAR
|
| 26 |
is_land BOOLEAN
|
|
|
|
| 40 |
is_land BOOLEAN
|
| 41 |
is_territorial BOOLEAN
|
| 42 |
geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
|
| 43 |
+
</SCHEMA>
|
| 44 |
|
| 45 |
The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
|
| 46 |
+
Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly.
|
| 47 |
+
Use ST_AsGeoJSON(geometry) for all geometry outputs."""
|
| 48 |
+
|
| 49 |
+
USER_PROMPT_TEMPLATE = """<CANDIDATES>
|
| 50 |
+
{candidates_csv}
|
| 51 |
+
</CANDIDATES>
|
| 52 |
+
|
| 53 |
+
<USER_QUERY>
|
| 54 |
+
{question}
|
| 55 |
+
</USER_QUERY>
|
| 56 |
+
"""
|
| 57 |
|
| 58 |
|
| 59 |
def candidates_to_csv(candidates: Sequence[Dict[str, Any]]) -> str:
|
|
|
|
| 66 |
def build_user_prompt(
|
| 67 |
question: str,
|
| 68 |
candidates: Sequence[Dict[str, Any]],
|
|
|
|
| 69 |
) -> str:
|
| 70 |
return USER_PROMPT_TEMPLATE.format(
|
|
|
|
| 71 |
candidates_csv=candidates_to_csv(candidates).strip(),
|
| 72 |
question=question.strip(),
|
| 73 |
)
|
|
|
|
| 75 |
|
| 76 |
def make_prompt_completion(
|
| 77 |
sample: Dict[str, Any],
|
|
|
|
| 78 |
) -> Dict[str, str]:
|
| 79 |
prompt = SYSTEM_PROMPT + "\n\n" + build_user_prompt(
|
| 80 |
question=sample["question"],
|
| 81 |
candidates=sample["candidates"],
|
|
|
|
| 82 |
)
|
| 83 |
completion = sample.get("target", {}).get("sql", "")
|
| 84 |
return {"prompt": prompt, "completion": completion}
|
finetune/train_modal_qwen35.py
CHANGED
|
@@ -123,8 +123,7 @@ def _load_data(run_dir: str, tokenizer, max_train_samples=None, max_eval_samples
|
|
| 123 |
"""Load JSONL data and apply Qwen3.5 chat template.
|
| 124 |
|
| 125 |
Each sample must have:
|
| 126 |
-
|
| 127 |
-
completion: list of {role, content} dicts (assistant)
|
| 128 |
|
| 129 |
The chat template produces the full ChatML string including the assistant turn.
|
| 130 |
train_on_responses_only then masks everything except the assistant response.
|
|
@@ -143,7 +142,7 @@ def _load_data(run_dir: str, tokenizer, max_train_samples=None, max_eval_samples
|
|
| 143 |
|
| 144 |
def to_message(sample: dict) -> dict:
|
| 145 |
text = tokenizer.apply_chat_template(
|
| 146 |
-
sample["
|
| 147 |
tokenize=False,
|
| 148 |
add_generation_prompt=False,
|
| 149 |
)
|
|
|
|
| 123 |
"""Load JSONL data and apply Qwen3.5 chat template.
|
| 124 |
|
| 125 |
Each sample must have:
|
| 126 |
+
messages: list of {role, content} dicts (system + user + assistant)
|
|
|
|
| 127 |
|
| 128 |
The chat template produces the full ChatML string including the assistant turn.
|
| 129 |
train_on_responses_only then masks everything except the assistant response.
|
|
|
|
| 142 |
|
| 143 |
def to_message(sample: dict) -> dict:
|
| 144 |
text = tokenizer.apply_chat_template(
|
| 145 |
+
sample["messages"],
|
| 146 |
tokenize=False,
|
| 147 |
add_generation_prompt=False,
|
| 148 |
)
|
src/gazet/lm.py
CHANGED
|
@@ -176,11 +176,12 @@ write_sql = SQLWriter(lm=sql_generation_lm)
|
|
| 176 |
|
| 177 |
# ββ GGUF SQL generation via llama-server ββββββββββββββββββββββββββββββββββββββ
|
| 178 |
|
| 179 |
-
_SYSTEM_PROMPT =
|
| 180 |
-
|
| 181 |
-
|
| 182 |
|
| 183 |
-
|
|
|
|
| 184 |
query: read_parquet('divisions_area')
|
| 185 |
columns:
|
| 186 |
id VARCHAR -- unique feature id
|
|
@@ -209,18 +210,13 @@ _SCHEMA_DETAILS = """1. divisions_area -- Overture polygon/multipolygon admin b
|
|
| 209 |
is_land BOOLEAN
|
| 210 |
is_territorial BOOLEAN
|
| 211 |
geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
|
|
|
|
| 212 |
|
| 213 |
The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
|
| 214 |
Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly.
|
| 215 |
Use ST_AsGeoJSON(geometry) for all geometry outputs."""
|
| 216 |
|
| 217 |
-
_USER_PROMPT_TEMPLATE = """
|
| 218 |
-
|
| 219 |
-
<SCHEMA_DETAILS>
|
| 220 |
-
{schema_details}
|
| 221 |
-
</SCHEMA_DETAILS>
|
| 222 |
-
|
| 223 |
-
<CANDIDATES>
|
| 224 |
{candidates_csv}
|
| 225 |
</CANDIDATES>
|
| 226 |
|
|
@@ -269,10 +265,21 @@ def _llama_chat_complete(messages: list[dict]) -> str:
|
|
| 269 |
return resp.json()["choices"][0]["message"]["content"]
|
| 270 |
|
| 271 |
|
| 272 |
-
_PLACES_SYSTEM_PROMPT =
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
|
| 277 |
|
| 278 |
def generate_places(user_query: str) -> PlacesResult:
|
|
@@ -307,7 +314,7 @@ def generate_sql(user_query: str, candidates_df: pd.DataFrame) -> str:
|
|
| 307 |
"""Generate SQL from a natural language query using the finetuned GGUF model.
|
| 308 |
|
| 309 |
Uses the same prompt format the model was trained on:
|
| 310 |
-
SYSTEM_PROMPT + USER_PROMPT_TEMPLATE with
|
| 311 |
Single-shot β no retry loop (the finetuned model can't improve from error feedback).
|
| 312 |
"""
|
| 313 |
# Keep only columns the model was trained on
|
|
@@ -316,7 +323,6 @@ def generate_sql(user_query: str, candidates_df: pd.DataFrame) -> str:
|
|
| 316 |
candidates_csv = candidates_df[cols].to_csv(index=False)
|
| 317 |
|
| 318 |
user_prompt = _USER_PROMPT_TEMPLATE.format(
|
| 319 |
-
schema_details=_SCHEMA_DETAILS.strip(),
|
| 320 |
candidates_csv=candidates_csv.strip(),
|
| 321 |
question=user_query.strip(),
|
| 322 |
)
|
|
|
|
| 176 |
|
| 177 |
# ββ GGUF SQL generation via llama-server ββββββββββββββββββββββββββββββββββββββ
|
| 178 |
|
| 179 |
+
_SYSTEM_PROMPT = """You are a text to SQL query translator that helps in natural language geocoding.
|
| 180 |
+
|
| 181 |
+
You have access to two DuckDB parquet tables. Given a set of candidate entities and a user query, generate the SQL to retrieve the desired geometry.
|
| 182 |
|
| 183 |
+
<SCHEMA>
|
| 184 |
+
1. divisions_area -- Overture polygon/multipolygon admin boundaries
|
| 185 |
query: read_parquet('divisions_area')
|
| 186 |
columns:
|
| 187 |
id VARCHAR -- unique feature id
|
|
|
|
| 210 |
is_land BOOLEAN
|
| 211 |
is_territorial BOOLEAN
|
| 212 |
geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
|
| 213 |
+
</SCHEMA>
|
| 214 |
|
| 215 |
The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
|
| 216 |
Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly.
|
| 217 |
Use ST_AsGeoJSON(geometry) for all geometry outputs."""
|
| 218 |
|
| 219 |
+
_USER_PROMPT_TEMPLATE = """<CANDIDATES>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
{candidates_csv}
|
| 221 |
</CANDIDATES>
|
| 222 |
|
|
|
|
| 265 |
return resp.json()["choices"][0]["message"]["content"]
|
| 266 |
|
| 267 |
|
| 268 |
+
_PLACES_SYSTEM_PROMPT = """You are a geographic entity extractor. Extract place names from the user query and return valid JSON only.
|
| 269 |
+
|
| 270 |
+
OUTPUT FORMAT:
|
| 271 |
+
{"places": [{"place": "<name>", "country": "<ISO-2>", "subtype": "<subtype>"}]}
|
| 272 |
+
"country" and "subtype" are optional; omit if not applicable.
|
| 273 |
+
|
| 274 |
+
RULES:
|
| 275 |
+
- Only extract places explicitly mentioned. Never infer or expand (e.g. "states of India" -> extract "India" only).
|
| 276 |
+
- No duplicate place names.
|
| 277 |
+
- "country": ISO 3166-1 alpha-2. Include only if explicitly mentioned or unambiguous.
|
| 278 |
+
- "subtype": include only when the geographic level is clear from the query.
|
| 279 |
+
|
| 280 |
+
SUBTYPES:
|
| 281 |
+
country, dependency, region, county, localadmin, locality, macrohood, neighborhood, microhood
|
| 282 |
+
- Default to locality for cities/towns; omit for physical features (oceans, rivers, mountains)."""
|
| 283 |
|
| 284 |
|
| 285 |
def generate_places(user_query: str) -> PlacesResult:
|
|
|
|
| 314 |
"""Generate SQL from a natural language query using the finetuned GGUF model.
|
| 315 |
|
| 316 |
Uses the same prompt format the model was trained on:
|
| 317 |
+
SYSTEM_PROMPT (includes schema) + USER_PROMPT_TEMPLATE with candidates CSV and question.
|
| 318 |
Single-shot β no retry loop (the finetuned model can't improve from error feedback).
|
| 319 |
"""
|
| 320 |
# Keep only columns the model was trained on
|
|
|
|
| 323 |
candidates_csv = candidates_df[cols].to_csv(index=False)
|
| 324 |
|
| 325 |
user_prompt = _USER_PROMPT_TEMPLATE.format(
|
|
|
|
| 326 |
candidates_csv=candidates_csv.strip(),
|
| 327 |
question=user_query.strip(),
|
| 328 |
)
|