srmsoumya committed on
Commit
c77ca5f
·
1 Parent(s): 2bf5583

chore: clean sql generation, use conversation format, move prompts from user to system

Browse files
dataset/scripts/convert_to_conversation_format.py DELETED
@@ -1,134 +0,0 @@
1
- #!/usr/bin/env python3
2
- """Convert prompt-completion format to conversation format.
3
-
4
- Reads SQL and places JSONL from a run directory and converts to a single
5
- "messages" list format suitable for various downstream uses.
6
-
7
- Input format (current):
8
- {
9
- "prompt": [
10
- {"role": "system", "content": "..."},
11
- {"role": "user", "content": "..."}
12
- ],
13
- "completion": [
14
- {"role": "assistant", "content": "..."}
15
- ],
16
- "metadata": {...}
17
- }
18
-
19
- Output format:
20
- {
21
- "messages": [
22
- {"role": "system", "content": "..."},
23
- {"role": "user", "content": "..."},
24
- {"role": "assistant", "content": "..."}
25
- ]
26
- }
27
-
28
- Saves to JSONL files:
29
- - train_conversation_sql.jsonl
30
- - val_conversation_sql.jsonl
31
- - test_conversation_sql.jsonl
32
- - train_conversation_places.jsonl
33
- - val_conversation_places.jsonl
34
- - test_conversation_places.jsonl
35
-
36
- Usage with datasets library:
37
- from datasets import load_dataset
38
-
39
- train_sql = load_dataset(
40
- "json",
41
- data_files="dataset/output/conversations/train_conversation_sql.jsonl",
42
- split="train"
43
- )
44
-
45
- # Access messages:
46
- print(train_sql[0]["messages"])
47
- """
48
-
49
- import argparse
50
- import json
51
- from pathlib import Path
52
-
53
-
54
- def load_jsonl(path: Path) -> list[dict]:
55
- rows = []
56
- with open(path) as f:
57
- for line in f:
58
- line = line.strip()
59
- if line:
60
- rows.append(json.loads(line))
61
- return rows
62
-
63
-
64
- def to_conversation_format(sample: dict) -> dict:
65
- """Convert prompt+completion format to messages format."""
66
- return {
67
- "messages": sample["prompt"] + sample["completion"],
68
- }
69
-
70
-
71
- def process_task(run_dir: Path, task: str, output_dir: Path):
72
- """Process all splits for a single task (sql or places)."""
73
- task_dir = run_dir / task
74
-
75
- for split in ["train", "val", "test"]:
76
- input_path = task_dir / f"{split}.jsonl"
77
- if not input_path.exists():
78
- print(f" Skipping {task}/{split}: {input_path} not found")
79
- continue
80
-
81
- samples = load_jsonl(input_path)
82
- conversations = [to_conversation_format(s) for s in samples]
83
-
84
- output_path = output_dir / f"{split}_conversation_{task}.jsonl"
85
- output_path.parent.mkdir(parents=True, exist_ok=True)
86
-
87
- with open(output_path, "w") as f:
88
- for conv in conversations:
89
- f.write(json.dumps(conv, ensure_ascii=False) + "\n")
90
-
91
- print(f" {task}/{split}: {len(conversations)} samples β†’ {output_path}")
92
-
93
-
94
- def main():
95
- parser = argparse.ArgumentParser(
96
- description="Convert prompt-completion format to conversation format"
97
- )
98
- parser.add_argument(
99
- "--run-dir",
100
- type=Path,
101
- default=Path("dataset/output/runs/v3-symbolic-paths"),
102
- help="Path to run directory containing sql/ and places/ subdirectories",
103
- )
104
- parser.add_argument(
105
- "--output-dir",
106
- type=Path,
107
- default=Path("dataset/output/conversations"),
108
- help="Output directory for JSONL files",
109
- )
110
-
111
- args = parser.parse_args()
112
-
113
- run_dir = args.run_dir
114
- output_dir = args.output_dir
115
-
116
- if not run_dir.exists():
117
- print(f"Error: Run directory not found: {run_dir}")
118
- return 1
119
-
120
- print(f"Converting from: {run_dir}")
121
- print(f"Output directory: {output_dir}")
122
- print()
123
-
124
- for task in ["sql", "places"]:
125
- print(f"Processing {task}:")
126
- process_task(run_dir, task, output_dir)
127
-
128
- print()
129
- print("Conversion complete!")
130
- print(f"Output files in: {output_dir}/")
131
-
132
-
133
- if __name__ == "__main__":
134
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dataset/scripts/export_training_data.py CHANGED
@@ -27,8 +27,11 @@ from collections import defaultdict
27
  from pathlib import Path
28
  from typing import Any, Dict, List, Optional, Tuple
29
 
 
30
  import yaml
31
 
 
 
32
 
33
  # ---------------------------------------------------------------------------
34
  # Loading
@@ -89,16 +92,12 @@ def stratified_split(
89
  # Conversational prompt-completion: model sees system + user, generates SQL.
90
  # ---------------------------------------------------------------------------
91
 
92
- _SQL_SYSTEM = (
93
- "You are a text to SQL query translator that helps in natural language geocoding."
94
- )
95
 
96
- _CANDIDATES_COLS = [
97
- "source", "id", "name", "subtype", "country", "region",
98
- "admin_level", "similarity",
99
- ]
100
 
101
- _SCHEMA = """1. divisions_area -- Overture polygon/multipolygon admin boundaries
 
102
  query: read_parquet('divisions_area')
103
  columns:
104
  id VARCHAR -- unique feature id
@@ -127,11 +126,17 @@ _SCHEMA = """1. divisions_area -- Overture polygon/multipolygon admin boundarie
127
  is_land BOOLEAN
128
  is_territorial BOOLEAN
129
  geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
 
130
 
131
  The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
132
  Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly.
133
  Use ST_AsGeoJSON(geometry) for all geometry outputs."""
134
 
 
 
 
 
 
135
 
136
  def _candidates_csv(candidates: List[Dict]) -> str:
137
  import io
@@ -149,26 +154,42 @@ def _candidates_csv(candidates: List[Dict]) -> str:
149
  return buf.getvalue().strip()
150
 
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  def sample_to_sql_pair(sample: Dict[str, Any]) -> Optional[Dict]:
153
  """Convert a raw sample to a conversational prompt-completion pair for SQL generation."""
154
  sql = sample.get("target", {}).get("sql", "").strip()
155
  if not sql:
156
  return None
 
157
 
158
  user_content = (
159
- "GIVEN the <SCHEMA_DETAILS>, <CANDIDATES> and <USER_QUERY>, "
160
- "generate the corresponding SQL command to retrieve the desired geometry.\n\n"
161
- f"<SCHEMA_DETAILS>\n{_SCHEMA}\n</SCHEMA_DETAILS>\n\n"
162
  f"<CANDIDATES>\n{_candidates_csv(sample.get('candidates', []))}\n</CANDIDATES>\n\n"
163
  f"<USER_QUERY>\n{sample['question']}\n</USER_QUERY>"
164
  )
165
 
166
  return {
167
- "prompt": [
168
- {"role": "system", "content": _SQL_SYSTEM},
169
- {"role": "user", "content": user_content},
170
- ],
171
- "completion": [
172
  {"role": "assistant", "content": sql},
173
  ],
174
  "metadata": sample.get("metadata", {}),
@@ -180,10 +201,21 @@ def sample_to_sql_pair(sample: Dict[str, Any]) -> Optional[Dict]:
180
  # Derived from the same SQL samples: selected_candidates → PlacesResult JSON.
181
  # ---------------------------------------------------------------------------
182
 
183
- _PLACE_SYSTEM = (
184
- "You are a geographic entity extractor. "
185
- "Extract place names from the query and return valid JSON only."
186
- )
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  # Overture division subtypes — used to filter out natural_earth candidates
189
  # from the place extraction output (NE features don't have these subtypes).
@@ -241,11 +273,9 @@ def sample_to_place_pair(sample: Dict[str, Any]) -> Optional[Dict]:
241
  completion_json = json.dumps({"places": places}, ensure_ascii=False)
242
 
243
  return {
244
- "prompt": [
245
- {"role": "system", "content": _PLACE_SYSTEM},
246
- {"role": "user", "content": sample["question"]},
247
- ],
248
- "completion": [
249
  {"role": "assistant", "content": completion_json},
250
  ],
251
  "metadata": sample.get("metadata", {}),
 
27
  from pathlib import Path
28
  from typing import Any, Dict, List, Optional, Tuple
29
 
30
+ import sqlparse
31
  import yaml
32
 
33
+ from gazet.config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
34
+
35
 
36
  # ---------------------------------------------------------------------------
37
  # Loading
 
92
  # Conversational prompt-completion: model sees system + user, generates SQL.
93
  # ---------------------------------------------------------------------------
94
 
95
+ _SQL_SYSTEM = """You are a text to SQL query translator that helps in natural language geocoding.
 
 
96
 
97
+ You have access to two DuckDB parquet tables. Given a set of candidate entities and a user query, generate the SQL to retrieve the desired geometry.
 
 
 
98
 
99
+ <SCHEMA>
100
+ 1. divisions_area -- Overture polygon/multipolygon admin boundaries
101
  query: read_parquet('divisions_area')
102
  columns:
103
  id VARCHAR -- unique feature id
 
126
  is_land BOOLEAN
127
  is_territorial BOOLEAN
128
  geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
129
+ </SCHEMA>
130
 
131
  The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
132
  Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly.
133
  Use ST_AsGeoJSON(geometry) for all geometry outputs."""
134
 
135
+ _CANDIDATES_COLS = [
136
+ "source", "id", "name", "subtype", "country", "region",
137
+ "admin_level", "similarity",
138
+ ]
139
+
140
 
141
  def _candidates_csv(candidates: List[Dict]) -> str:
142
  import io
 
154
  return buf.getvalue().strip()
155
 
156
 
157
+ def _to_symbolic_sql(sql: str) -> str:
158
+ """Normalize any hardcoded or runtime paths back to symbolic names."""
159
+ sql = sql.replace(DIVISIONS_AREA_PATH, "divisions_area")
160
+ sql = sql.replace(NATURAL_EARTH_PATH, "natural_earth")
161
+ sql = sql.replace("/data/overture/division_area/*.parquet", "divisions_area")
162
+ sql = sql.replace("/data/overture/divisions_area/*.parquet", "divisions_area")
163
+ sql = sql.replace("/data/natural_earth_geoparquet/ne_geography.parquet", "natural_earth")
164
+ return sql
165
+
166
+
167
+ def _format_sql(sql: str) -> str:
168
+ """Pretty-print SQL so the model learns clean, readable style."""
169
+ return sqlparse.format(
170
+ sql,
171
+ reindent=True,
172
+ keyword_case="upper",
173
+ indent_width=4,
174
+ ).strip()
175
+
176
+
177
  def sample_to_sql_pair(sample: Dict[str, Any]) -> Optional[Dict]:
178
  """Convert a raw sample to a conversational prompt-completion pair for SQL generation."""
179
  sql = sample.get("target", {}).get("sql", "").strip()
180
  if not sql:
181
  return None
182
+ sql = _format_sql(_to_symbolic_sql(sql))
183
 
184
  user_content = (
 
 
 
185
  f"<CANDIDATES>\n{_candidates_csv(sample.get('candidates', []))}\n</CANDIDATES>\n\n"
186
  f"<USER_QUERY>\n{sample['question']}\n</USER_QUERY>"
187
  )
188
 
189
  return {
190
+ "messages": [
191
+ {"role": "system", "content": _SQL_SYSTEM},
192
+ {"role": "user", "content": user_content},
 
 
193
  {"role": "assistant", "content": sql},
194
  ],
195
  "metadata": sample.get("metadata", {}),
 
201
  # Derived from the same SQL samples: selected_candidates → PlacesResult JSON.
202
  # ---------------------------------------------------------------------------
203
 
204
+ _PLACE_SYSTEM = """You are a geographic entity extractor. Extract place names from the user query and return valid JSON only.
205
+
206
+ OUTPUT FORMAT:
207
+ {"places": [{"place": "<name>", "country": "<ISO-2>", "subtype": "<subtype>"}]}
208
+ "country" and "subtype" are optional; omit if not applicable.
209
+
210
+ RULES:
211
+ - Only extract places explicitly mentioned. Never infer or expand (e.g. "states of India" -> extract "India" only).
212
+ - No duplicate place names.
213
+ - "country": ISO 3166-1 alpha-2. Include only if explicitly mentioned or unambiguous.
214
+ - "subtype": include only when the geographic level is clear from the query.
215
+
216
+ SUBTYPES:
217
+ country, dependency, region, county, localadmin, locality, macrohood, neighborhood, microhood
218
+ - Default to locality for cities/towns; omit for physical features (oceans, rivers, mountains)."""
219
 
220
  # Overture division subtypes — used to filter out natural_earth candidates
221
  # from the place extraction output (NE features don't have these subtypes).
 
273
  completion_json = json.dumps({"places": places}, ensure_ascii=False)
274
 
275
  return {
276
+ "messages": [
277
+ {"role": "system", "content": _PLACE_SYSTEM},
278
+ {"role": "user", "content": sample["question"]},
 
 
279
  {"role": "assistant", "content": completion_json},
280
  ],
281
  "metadata": sample.get("metadata", {}),
dataset/scripts/validate_dataset.py CHANGED
@@ -50,6 +50,18 @@ def _resolve_paths(sql: str) -> str:
50
  return sql
51
 
52
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def validate_sql(con: duckdb.DuckDBPyConnection, sql: str) -> tuple[bool, str]:
54
  """Validate that SQL executes without error.
55
 
@@ -120,6 +132,8 @@ def validate_sample_worker(sample: Dict[str, Any]) -> Tuple[str, bool, List[str]
120
  try:
121
  is_valid, issues = validate_sample(con, sample)
122
  con.close()
 
 
123
  return (sample['id'], is_valid, issues, sample if is_valid else None)
124
  except Exception as e:
125
  con.close()
 
50
  return sql
51
 
52
 
53
+ def _to_symbolic_sql(sql: str) -> str:
54
+ """Normalize any hardcoded or runtime paths back to symbolic names for storage."""
55
+ # Current local runtime paths
56
+ sql = sql.replace(DIVISIONS_AREA_PATH, "divisions_area")
57
+ sql = sql.replace(NATURAL_EARTH_PATH, "natural_earth")
58
+ # Legacy Docker paths
59
+ sql = sql.replace("/data/overture/division_area/*.parquet", "divisions_area")
60
+ sql = sql.replace("/data/overture/divisions_area/*.parquet", "divisions_area")
61
+ sql = sql.replace("/data/natural_earth_geoparquet/ne_geography.parquet", "natural_earth")
62
+ return sql
63
+
64
+
65
  def validate_sql(con: duckdb.DuckDBPyConnection, sql: str) -> tuple[bool, str]:
66
  """Validate that SQL executes without error.
67
 
 
132
  try:
133
  is_valid, issues = validate_sample(con, sample)
134
  con.close()
135
+ if is_valid:
136
+ sample['target']['sql'] = _to_symbolic_sql(sample['target']['sql'])
137
  return (sample['id'], is_valid, issues, sample if is_valid else None)
138
  except Exception as e:
139
  con.close()
finetune/eval_cli.py CHANGED
@@ -83,19 +83,15 @@ def load_samples(run_dir: Path, task: str) -> list[dict]:
83
 
84
 
85
  def build_raw_prompt(sample: dict) -> str:
86
- """Reconstruct the plain prompt string from message-list format.
87
-
88
- sample["prompt"] is [{role:system, content:...}, {role:user, content:...}].
89
- Joins them with a blank line — same format used during training.
90
- """
91
- return sample["prompt"][0]["content"] + "\n\n" + sample["prompt"][1]["content"]
92
 
93
 
94
  def run_sample(sample: dict, task: str, total: int, index: int, verbose: bool = False) -> None:
95
- expected = sample["completion"][0]["content"]
96
- messages = sample["prompt"]
97
 
98
- user_content = sample["prompt"][1]["content"]
99
  if "<USER_QUERY>" in user_content:
100
  question = user_content.split("<USER_QUERY>")[-1].split("</USER_QUERY>")[0].strip()
101
  else:
 
83
 
84
 
85
  def build_raw_prompt(sample: dict) -> str:
86
+ """Reconstruct the plain prompt string from messages format (all turns except assistant)."""
87
+ return "\n\n".join(m["content"] for m in sample["messages"][:-1])
 
 
 
 
88
 
89
 
90
  def run_sample(sample: dict, task: str, total: int, index: int, verbose: bool = False) -> None:
91
+ expected = sample["messages"][-1]["content"]
92
+ messages = sample["messages"][:-1]
93
 
94
+ user_content = sample["messages"][-2]["content"]
95
  if "<USER_QUERY>" in user_content:
96
  question = user_content.split("<USER_QUERY>")[-1].split("</USER_QUERY>")[0].strip()
97
  else:
finetune/nlg.py CHANGED
@@ -66,49 +66,54 @@ from trl import SFTConfig, SFTTrainer
66
 
67
  LOGGER = logging.getLogger("nlg")
68
 
69
- SYSTEM_PROMPT = (
70
- "You are a text to SQL query translator that helps in natural language geocoding."
71
- )
72
 
73
- USER_PROMPT_TEMPLATE = """GIVEN the <SCHEMA_DETAILS>, <CANDIDATES> and <USER_QUERY>, generate the corresponding SQL command to retrieve the desired geometry.
74
 
75
- <SCHEMA_DETAILS>
76
- {schema_details}
77
- </SCHEMA_DETAILS>
78
-
79
- <CANDIDATES>
80
- {candidates_csv}
81
- </CANDIDATES>
82
-
83
- <USER_QUERY>
84
- {question}
85
- </USER_QUERY>
86
- """
87
-
88
- DEFAULT_SCHEMA_DETAILS = """1. divisions_area β€” Overture polygon/multipolygon admin boundaries
89
- path: '/data/overture/division_area/*.parquet'
90
  columns:
91
- id VARCHAR
92
  names STRUCT("primary" VARCHAR, ...)
93
- country VARCHAR
94
- subtype VARCHAR
 
95
  class VARCHAR
96
  region VARCHAR
97
  admin_level INTEGER
98
  division_id VARCHAR
99
  is_land BOOLEAN
100
  is_territorial BOOLEAN
101
- geometry GEOMETRY
102
 
103
- 2. natural_earth β€” Natural Earth geography polygons
104
- path: '/data/natural_earth_geoparquet/ne_geography.parquet'
105
  columns:
106
- id VARCHAR
107
- name VARCHAR
108
- featurecla VARCHAR
109
- scalerank INTEGER
110
- min_zoom DOUBLE
111
- geometry GEOMETRY"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
 
114
  @dataclass
@@ -126,12 +131,6 @@ def setup_logging(verbose: bool = False) -> None:
126
  )
127
 
128
 
129
- def read_text(path: Optional[str], default: str) -> str:
130
- if not path:
131
- return default
132
- return Path(path).read_text(encoding="utf-8")
133
-
134
-
135
  def candidates_to_csv(candidates: Sequence[Dict[str, Any]]) -> str:
136
  df = pd.DataFrame(list(candidates))
137
  if "candidate_id" in df.columns:
@@ -139,15 +138,14 @@ def candidates_to_csv(candidates: Sequence[Dict[str, Any]]) -> str:
139
  return df.to_csv(index=False)
140
 
141
 
142
- def build_user_prompt(question: str, candidates: Sequence[Dict[str, Any]], schema_details: str) -> str:
143
  return USER_PROMPT_TEMPLATE.format(
144
- schema_details=schema_details.strip(),
145
  candidates_csv=candidates_to_csv(candidates).strip(),
146
  question=question.strip(),
147
  )
148
 
149
 
150
- def make_messages(sample: Dict[str, Any], schema_details: str) -> Dict[str, Any]:
151
  messages = [
152
  {"role": "system", "content": SYSTEM_PROMPT},
153
  {
@@ -155,7 +153,6 @@ def make_messages(sample: Dict[str, Any], schema_details: str) -> Dict[str, Any]
155
  "content": build_user_prompt(
156
  question=sample["question"],
157
  candidates=sample["candidates"],
158
- schema_details=schema_details,
159
  ),
160
  },
161
  ]
@@ -178,10 +175,10 @@ def load_jsonl_splits(
178
  return load_dataset("json", data_files=data_files)
179
 
180
 
181
- def format_dataset_for_sft(dataset: DatasetDict, schema_details: str) -> DatasetDict:
182
  formatted = DatasetDict()
183
  for split, ds in dataset.items():
184
- formatted[split] = ds.map(lambda row: make_messages(row, schema_details))
185
  return formatted
186
 
187
 
@@ -271,9 +268,8 @@ def build_lora_config(args: argparse.Namespace) -> LoraConfig:
271
 
272
  def train(args: argparse.Namespace) -> None:
273
  set_seed(args.seed)
274
- schema_details = read_text(args.schema_file, DEFAULT_SCHEMA_DETAILS)
275
  raw_ds = load_jsonl_splits(args.train_jsonl, args.val_jsonl, args.test_jsonl)
276
- ds = format_dataset_for_sft(raw_ds, schema_details)
277
 
278
  if args.max_train_samples is not None:
279
  ds["train"] = ds["train"].select(range(min(args.max_train_samples, len(ds["train"]))))
@@ -362,7 +358,6 @@ def generate_sql(
362
  tokenizer,
363
  question: str,
364
  candidates: Sequence[Dict[str, Any]],
365
- schema_details: str,
366
  max_new_tokens: int = 256,
367
  do_sample: bool = False,
368
  temperature: float = 0.1,
@@ -371,7 +366,6 @@ def generate_sql(
371
  ) -> GenerationResult:
372
  messages = make_messages(
373
  {"question": question, "candidates": list(candidates), "target": {}},
374
- schema_details,
375
  )["messages"]
376
  prompt = render_prompt(tokenizer, messages)
377
  inputs = tokenizer.apply_chat_template(
@@ -446,7 +440,6 @@ def execute_sqlite(sql: str, sqlite_db: str, limit: Optional[int] = None) -> Tup
446
 
447
 
448
  def cmd_generate(args: argparse.Namespace) -> None:
449
- schema_details = read_text(args.schema_file, DEFAULT_SCHEMA_DETAILS)
450
  question = read_question(args)
451
  candidates = read_candidates(args)
452
  model, tokenizer = load_model_for_inference(
@@ -463,7 +456,6 @@ def cmd_generate(args: argparse.Namespace) -> None:
463
  tokenizer=tokenizer,
464
  question=question,
465
  candidates=candidates,
466
- schema_details=schema_details,
467
  max_new_tokens=args.max_new_tokens,
468
  do_sample=args.do_sample,
469
  temperature=args.temperature,
@@ -511,7 +503,6 @@ def build_parser() -> argparse.ArgumentParser:
511
  train_p.add_argument("--train-jsonl", required=True)
512
  train_p.add_argument("--val-jsonl")
513
  train_p.add_argument("--test-jsonl")
514
- train_p.add_argument("--schema-file")
515
  train_p.add_argument("--output-dir", required=True)
516
  train_p.add_argument("--max-train-samples", type=int)
517
  train_p.add_argument("--max-eval-samples", type=int)
@@ -552,7 +543,6 @@ def build_parser() -> argparse.ArgumentParser:
552
  gen_p.add_argument("--model-path")
553
  gen_p.add_argument("--base-model")
554
  gen_p.add_argument("--adapter-path")
555
- gen_p.add_argument("--schema-file")
556
  gen_p.add_argument("--question")
557
  gen_p.add_argument("--candidates-json")
558
  gen_p.add_argument("--sample-jsonl")
 
66
 
67
  LOGGER = logging.getLogger("nlg")
68
 
69
+ SYSTEM_PROMPT = """You are a text to SQL query translator that helps in natural language geocoding.
 
 
70
 
71
+ You have access to two DuckDB parquet tables. Given a set of candidate entities and a user query, generate the SQL to retrieve the desired geometry.
72
 
73
+ <SCHEMA>
74
+ 1. divisions_area -- Overture polygon/multipolygon admin boundaries
75
+ query: read_parquet('divisions_area')
 
 
 
 
 
 
 
 
 
 
 
 
76
  columns:
77
+ id VARCHAR -- unique feature id
78
  names STRUCT("primary" VARCHAR, ...)
79
+ country VARCHAR -- ISO 3166-1 alpha-2
80
+ subtype VARCHAR -- country | region | dependency | county | localadmin |
81
+ locality | macrohood | neighborhood | microhood
82
  class VARCHAR
83
  region VARCHAR
84
  admin_level INTEGER
85
  division_id VARCHAR
86
  is_land BOOLEAN
87
  is_territorial BOOLEAN
88
+ geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
89
 
90
+ 2. natural_earth -- Natural Earth geography polygons (oceans, seas, rivers, terrain)
91
+ query: read_parquet('natural_earth')
92
  columns:
93
+ id VARCHAR -- unique feature id prefixed 'ne_'
94
+ names STRUCT("primary" VARCHAR, ...)
95
+ country VARCHAR
96
+ subtype VARCHAR -- e.g. 'ocean', 'sea', 'bay', 'Terrain area', 'Island group'
97
+ class VARCHAR
98
+ region VARCHAR
99
+ admin_level INTEGER
100
+ is_land BOOLEAN
101
+ is_territorial BOOLEAN
102
+ geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
103
+ </SCHEMA>
104
+
105
+ The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
106
+ Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly.
107
+ Use ST_AsGeoJSON(geometry) for all geometry outputs."""
108
+
109
+ USER_PROMPT_TEMPLATE = """<CANDIDATES>
110
+ {candidates_csv}
111
+ </CANDIDATES>
112
+
113
+ <USER_QUERY>
114
+ {question}
115
+ </USER_QUERY>
116
+ """
117
 
118
 
119
  @dataclass
 
131
  )
132
 
133
 
 
 
 
 
 
 
134
  def candidates_to_csv(candidates: Sequence[Dict[str, Any]]) -> str:
135
  df = pd.DataFrame(list(candidates))
136
  if "candidate_id" in df.columns:
 
138
  return df.to_csv(index=False)
139
 
140
 
141
+ def build_user_prompt(question: str, candidates: Sequence[Dict[str, Any]]) -> str:
142
  return USER_PROMPT_TEMPLATE.format(
 
143
  candidates_csv=candidates_to_csv(candidates).strip(),
144
  question=question.strip(),
145
  )
146
 
147
 
148
+ def make_messages(sample: Dict[str, Any]) -> Dict[str, Any]:
149
  messages = [
150
  {"role": "system", "content": SYSTEM_PROMPT},
151
  {
 
153
  "content": build_user_prompt(
154
  question=sample["question"],
155
  candidates=sample["candidates"],
 
156
  ),
157
  },
158
  ]
 
175
  return load_dataset("json", data_files=data_files)
176
 
177
 
178
+ def format_dataset_for_sft(dataset: DatasetDict) -> DatasetDict:
179
  formatted = DatasetDict()
180
  for split, ds in dataset.items():
181
+ formatted[split] = ds.map(make_messages)
182
  return formatted
183
 
184
 
 
268
 
269
  def train(args: argparse.Namespace) -> None:
270
  set_seed(args.seed)
 
271
  raw_ds = load_jsonl_splits(args.train_jsonl, args.val_jsonl, args.test_jsonl)
272
+ ds = format_dataset_for_sft(raw_ds)
273
 
274
  if args.max_train_samples is not None:
275
  ds["train"] = ds["train"].select(range(min(args.max_train_samples, len(ds["train"]))))
 
358
  tokenizer,
359
  question: str,
360
  candidates: Sequence[Dict[str, Any]],
 
361
  max_new_tokens: int = 256,
362
  do_sample: bool = False,
363
  temperature: float = 0.1,
 
366
  ) -> GenerationResult:
367
  messages = make_messages(
368
  {"question": question, "candidates": list(candidates), "target": {}},
 
369
  )["messages"]
370
  prompt = render_prompt(tokenizer, messages)
371
  inputs = tokenizer.apply_chat_template(
 
440
 
441
 
442
  def cmd_generate(args: argparse.Namespace) -> None:
 
443
  question = read_question(args)
444
  candidates = read_candidates(args)
445
  model, tokenizer = load_model_for_inference(
 
456
  tokenizer=tokenizer,
457
  question=question,
458
  candidates=candidates,
 
459
  max_new_tokens=args.max_new_tokens,
460
  do_sample=args.do_sample,
461
  temperature=args.temperature,
 
503
  train_p.add_argument("--train-jsonl", required=True)
504
  train_p.add_argument("--val-jsonl")
505
  train_p.add_argument("--test-jsonl")
 
506
  train_p.add_argument("--output-dir", required=True)
507
  train_p.add_argument("--max-train-samples", type=int)
508
  train_p.add_argument("--max-eval-samples", type=int)
 
543
  gen_p.add_argument("--model-path")
544
  gen_p.add_argument("--base-model")
545
  gen_p.add_argument("--adapter-path")
 
546
  gen_p.add_argument("--question")
547
  gen_p.add_argument("--candidates-json")
548
  gen_p.add_argument("--sample-jsonl")
finetune/prompts.py CHANGED
@@ -6,35 +6,21 @@ from typing import Any, Dict, Sequence
6
 
7
  import pandas as pd
8
 
9
- SYSTEM_PROMPT = (
10
- "You are a text to SQL query translator that helps in natural language geocoding."
11
- )
12
 
13
- USER_PROMPT_TEMPLATE = """GIVEN the <SCHEMA_DETAILS>, <CANDIDATES> and <USER_QUERY>, generate the corresponding SQL command to retrieve the desired geometry.
14
 
15
- <SCHEMA_DETAILS>
16
- {schema_details}
17
- </SCHEMA_DETAILS>
18
-
19
- <CANDIDATES>
20
- {candidates_csv}
21
- </CANDIDATES>
22
-
23
- <USER_QUERY>
24
- {question}
25
- </USER_QUERY>
26
- """
27
-
28
- DEFAULT_SCHEMA_DETAILS = """1. divisions_area -- Overture polygon/multipolygon admin boundaries
29
  query: read_parquet('divisions_area')
30
  columns:
31
- id VARCHAR -- unique feature id (use to filter precisely)
32
  names STRUCT("primary" VARCHAR, ...)
33
  country VARCHAR -- ISO 3166-1 alpha-2
34
  subtype VARCHAR -- country | region | dependency | county | localadmin |
35
  locality | macrohood | neighborhood | microhood
36
  class VARCHAR
37
- region VARCHAR -- region code e.g. 'IN-OR'
38
  admin_level INTEGER
39
  division_id VARCHAR
40
  is_land BOOLEAN
@@ -54,9 +40,20 @@ DEFAULT_SCHEMA_DETAILS = """1. divisions_area -- Overture polygon/multipolygon
54
  is_land BOOLEAN
55
  is_territorial BOOLEAN
56
  geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
 
57
 
58
  The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
59
- Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly."""
 
 
 
 
 
 
 
 
 
 
60
 
61
 
62
  def candidates_to_csv(candidates: Sequence[Dict[str, Any]]) -> str:
@@ -69,10 +66,8 @@ def candidates_to_csv(candidates: Sequence[Dict[str, Any]]) -> str:
69
  def build_user_prompt(
70
  question: str,
71
  candidates: Sequence[Dict[str, Any]],
72
- schema_details: str,
73
  ) -> str:
74
  return USER_PROMPT_TEMPLATE.format(
75
- schema_details=schema_details.strip(),
76
  candidates_csv=candidates_to_csv(candidates).strip(),
77
  question=question.strip(),
78
  )
@@ -80,12 +75,10 @@ def build_user_prompt(
80
 
81
  def make_prompt_completion(
82
  sample: Dict[str, Any],
83
- schema_details: str,
84
  ) -> Dict[str, str]:
85
  prompt = SYSTEM_PROMPT + "\n\n" + build_user_prompt(
86
  question=sample["question"],
87
  candidates=sample["candidates"],
88
- schema_details=schema_details,
89
  )
90
  completion = sample.get("target", {}).get("sql", "")
91
  return {"prompt": prompt, "completion": completion}
 
6
 
7
  import pandas as pd
8
 
9
+ SYSTEM_PROMPT = """You are a text to SQL query translator that helps in natural language geocoding.
 
 
10
 
11
+ You have access to two DuckDB parquet tables. Given a set of candidate entities and a user query, generate the SQL to retrieve the desired geometry.
12
 
13
+ <SCHEMA>
14
+ 1. divisions_area -- Overture polygon/multipolygon admin boundaries
 
 
 
 
 
 
 
 
 
 
 
 
15
  query: read_parquet('divisions_area')
16
  columns:
17
+ id VARCHAR -- unique feature id
18
  names STRUCT("primary" VARCHAR, ...)
19
  country VARCHAR -- ISO 3166-1 alpha-2
20
  subtype VARCHAR -- country | region | dependency | county | localadmin |
21
  locality | macrohood | neighborhood | microhood
22
  class VARCHAR
23
+ region VARCHAR
24
  admin_level INTEGER
25
  division_id VARCHAR
26
  is_land BOOLEAN
 
40
  is_land BOOLEAN
41
  is_territorial BOOLEAN
42
  geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
43
+ </SCHEMA>
44
 
45
  The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
46
+ Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly.
47
+ Use ST_AsGeoJSON(geometry) for all geometry outputs."""
48
+
49
+ USER_PROMPT_TEMPLATE = """<CANDIDATES>
50
+ {candidates_csv}
51
+ </CANDIDATES>
52
+
53
+ <USER_QUERY>
54
+ {question}
55
+ </USER_QUERY>
56
+ """
57
 
58
 
59
  def candidates_to_csv(candidates: Sequence[Dict[str, Any]]) -> str:
 
66
  def build_user_prompt(
67
  question: str,
68
  candidates: Sequence[Dict[str, Any]],
 
69
  ) -> str:
70
  return USER_PROMPT_TEMPLATE.format(
 
71
  candidates_csv=candidates_to_csv(candidates).strip(),
72
  question=question.strip(),
73
  )
 
75
 
76
  def make_prompt_completion(
77
  sample: Dict[str, Any],
 
78
  ) -> Dict[str, str]:
79
  prompt = SYSTEM_PROMPT + "\n\n" + build_user_prompt(
80
  question=sample["question"],
81
  candidates=sample["candidates"],
 
82
  )
83
  completion = sample.get("target", {}).get("sql", "")
84
  return {"prompt": prompt, "completion": completion}
finetune/train_modal_qwen35.py CHANGED
@@ -123,8 +123,7 @@ def _load_data(run_dir: str, tokenizer, max_train_samples=None, max_eval_samples
123
  """Load JSONL data and apply Qwen3.5 chat template.
124
 
125
  Each sample must have:
126
- prompt: list of {role, content} dicts (system + user)
127
- completion: list of {role, content} dicts (assistant)
128
 
129
  The chat template produces the full ChatML string including the assistant turn.
130
  train_on_responses_only then masks everything except the assistant response.
@@ -143,7 +142,7 @@ def _load_data(run_dir: str, tokenizer, max_train_samples=None, max_eval_samples
143
 
144
  def to_message(sample: dict) -> dict:
145
  text = tokenizer.apply_chat_template(
146
- sample["prompt"] + sample["completion"],
147
  tokenize=False,
148
  add_generation_prompt=False,
149
  )
 
123
  """Load JSONL data and apply Qwen3.5 chat template.
124
 
125
  Each sample must have:
126
+ messages: list of {role, content} dicts (system + user + assistant)
 
127
 
128
  The chat template produces the full ChatML string including the assistant turn.
129
  train_on_responses_only then masks everything except the assistant response.
 
142
 
143
  def to_message(sample: dict) -> dict:
144
  text = tokenizer.apply_chat_template(
145
+ sample["messages"],
146
  tokenize=False,
147
  add_generation_prompt=False,
148
  )
src/gazet/lm.py CHANGED
@@ -176,11 +176,12 @@ write_sql = SQLWriter(lm=sql_generation_lm)
176
 
177
  # ── GGUF SQL generation via llama-server ──────────────────────────────────────
178
 
179
- _SYSTEM_PROMPT = (
180
- "You are a text to SQL query translator that helps in natural language geocoding."
181
- )
182
 
183
- _SCHEMA_DETAILS = """1. divisions_area -- Overture polygon/multipolygon admin boundaries
 
184
  query: read_parquet('divisions_area')
185
  columns:
186
  id VARCHAR -- unique feature id
@@ -209,18 +210,13 @@ _SCHEMA_DETAILS = """1. divisions_area -- Overture polygon/multipolygon admin b
209
  is_land BOOLEAN
210
  is_territorial BOOLEAN
211
  geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
 
212
 
213
  The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
214
  Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly.
215
  Use ST_AsGeoJSON(geometry) for all geometry outputs."""
216
 
217
- _USER_PROMPT_TEMPLATE = """GIVEN the <SCHEMA_DETAILS>, <CANDIDATES> and <USER_QUERY>, generate the corresponding SQL command to retrieve the desired geometry.
218
-
219
- <SCHEMA_DETAILS>
220
- {schema_details}
221
- </SCHEMA_DETAILS>
222
-
223
- <CANDIDATES>
224
  {candidates_csv}
225
  </CANDIDATES>
226
 
@@ -269,10 +265,21 @@ def _llama_chat_complete(messages: list[dict]) -> str:
269
  return resp.json()["choices"][0]["message"]["content"]
270
 
271
 
272
- _PLACES_SYSTEM_PROMPT = (
273
- "You are a geographic entity extractor. "
274
- "Extract place names from the query and return valid JSON only."
275
- )
 
 
 
 
 
 
 
 
 
 
 
276
 
277
 
278
  def generate_places(user_query: str) -> PlacesResult:
@@ -307,7 +314,7 @@ def generate_sql(user_query: str, candidates_df: pd.DataFrame) -> str:
307
  """Generate SQL from a natural language query using the finetuned GGUF model.
308
 
309
  Uses the same prompt format the model was trained on:
310
- SYSTEM_PROMPT + USER_PROMPT_TEMPLATE with schema, candidates CSV, and question.
311
  Single-shot — no retry loop (the finetuned model can't improve from error feedback).
312
  """
313
  # Keep only columns the model was trained on
@@ -316,7 +323,6 @@ def generate_sql(user_query: str, candidates_df: pd.DataFrame) -> str:
316
  candidates_csv = candidates_df[cols].to_csv(index=False)
317
 
318
  user_prompt = _USER_PROMPT_TEMPLATE.format(
319
- schema_details=_SCHEMA_DETAILS.strip(),
320
  candidates_csv=candidates_csv.strip(),
321
  question=user_query.strip(),
322
  )
 
176
 
177
  # ── GGUF SQL generation via llama-server ──────────────────────────────────────
178
 
179
+ _SYSTEM_PROMPT = """You are a text to SQL query translator that helps in natural language geocoding.
180
+
181
+ You have access to two DuckDB parquet tables. Given a set of candidate entities and a user query, generate the SQL to retrieve the desired geometry.
182
 
183
+ <SCHEMA>
184
+ 1. divisions_area -- Overture polygon/multipolygon admin boundaries
185
  query: read_parquet('divisions_area')
186
  columns:
187
  id VARCHAR -- unique feature id
 
210
  is_land BOOLEAN
211
  is_territorial BOOLEAN
212
  geometry GEOMETRY -- WGS-84 polygon/multipolygon (spatial ext loaded)
213
+ </SCHEMA>
214
 
215
  The candidates table has a 'source' column: 'divisions_area' or 'natural_earth'.
216
  Use read_parquet('divisions_area') or read_parquet('natural_earth') accordingly.
217
  Use ST_AsGeoJSON(geometry) for all geometry outputs."""
218
 
219
+ _USER_PROMPT_TEMPLATE = """<CANDIDATES>
 
 
 
 
 
 
220
  {candidates_csv}
221
  </CANDIDATES>
222
 
 
265
  return resp.json()["choices"][0]["message"]["content"]
266
 
267
 
268
+ _PLACES_SYSTEM_PROMPT = """You are a geographic entity extractor. Extract place names from the user query and return valid JSON only.
269
+
270
+ OUTPUT FORMAT:
271
+ {"places": [{"place": "<name>", "country": "<ISO-2>", "subtype": "<subtype>"}]}
272
+ "country" and "subtype" are optional; omit if not applicable.
273
+
274
+ RULES:
275
+ - Only extract places explicitly mentioned. Never infer or expand (e.g. "states of India" -> extract "India" only).
276
+ - No duplicate place names.
277
+ - "country": ISO 3166-1 alpha-2. Include only if explicitly mentioned or unambiguous.
278
+ - "subtype": include only when the geographic level is clear from the query.
279
+
280
+ SUBTYPES:
281
+ country, dependency, region, county, localadmin, locality, macrohood, neighborhood, microhood
282
+ - Default to locality for cities/towns; omit for physical features (oceans, rivers, mountains)."""
283
 
284
 
285
  def generate_places(user_query: str) -> PlacesResult:
 
314
  """Generate SQL from a natural language query using the finetuned GGUF model.
315
 
316
  Uses the same prompt format the model was trained on:
317
+ SYSTEM_PROMPT (includes schema) + USER_PROMPT_TEMPLATE with candidates CSV and question.
318
  Single-shot — no retry loop (the finetuned model can't improve from error feedback).
319
  """
320
  # Keep only columns the model was trained on
 
323
  candidates_csv = candidates_df[cols].to_csv(index=False)
324
 
325
  user_prompt = _USER_PROMPT_TEMPLATE.format(
 
326
  candidates_csv=candidates_csv.strip(),
327
  question=user_query.strip(),
328
  )