srmsoumya commited on
Commit
2bf5583
·
1 Parent(s): 0e492f1

Add model training script for qwen, update eval & gazet lm to work with qwen template

Browse files
finetune/eval_cli.py CHANGED
@@ -1,10 +1,10 @@
1
  """Interactive eval: run test samples through the local GGUF model.
2
 
3
  Requires llama-server running on port 8080:
4
- llama-server -m finetune/models/gemma-270m-q8.gguf -ngl 99 --port 8080 --log-disable
5
 
6
- Uses the /completion endpoint with a raw prompt string no chat template —
7
- matching how the model was fine-tuned (completion_only_loss on plain text).
8
 
9
  Usage
10
  -----
@@ -28,9 +28,9 @@ import urllib.error
28
  import urllib.request
29
  from pathlib import Path
30
 
31
- SERVER_URL = "http://localhost:8080"
32
- MAX_TOKENS = 512
33
- TEMPERATURE = 0
34
 
35
  DEFAULT_RUN_DIR = Path("dataset/output/runs/v3-symbolic-paths")
36
 
@@ -54,21 +54,22 @@ def check_server() -> bool:
54
  return False
55
 
56
 
57
- def complete(prompt: str) -> str:
58
- """Call llama-server /completion endpoint with a raw prompt string."""
59
  payload = json.dumps({
60
- "prompt": prompt,
61
  "n_predict": MAX_TOKENS,
62
  "temperature": TEMPERATURE,
 
63
  }).encode()
64
 
65
  req = urllib.request.Request(
66
- f"{SERVER_URL}/completion",
67
  data=payload,
68
  headers={"Content-Type": "application/json"},
69
  )
70
  with urllib.request.urlopen(req, timeout=60) as resp:
71
- return json.loads(resp.read())["content"]
72
 
73
 
74
  def load_samples(run_dir: Path, task: str) -> list[dict]:
@@ -92,7 +93,7 @@ def build_raw_prompt(sample: dict) -> str:
92
 
93
  def run_sample(sample: dict, task: str, total: int, index: int, verbose: bool = False) -> None:
94
  expected = sample["completion"][0]["content"]
95
- prompt = build_raw_prompt(sample)
96
 
97
  user_content = sample["prompt"][1]["content"]
98
  if "<USER_QUERY>" in user_content:
@@ -107,8 +108,9 @@ def run_sample(sample: dict, task: str, total: int, index: int, verbose: bool =
107
  print(f"\nQuestion: {question}\n")
108
 
109
  if verbose:
 
110
  print(f"{'─' * 60}")
111
- print(f"Full prompt ({len(prompt)} chars, ~{len(prompt.split()) } words):")
112
  print(f"{'─' * 60}")
113
  print(prompt)
114
 
@@ -121,7 +123,7 @@ def run_sample(sample: dict, task: str, total: int, index: int, verbose: bool =
121
  print("Generated:")
122
  print(f"{'─' * 60}")
123
 
124
- raw = complete(prompt)
125
  generated = postprocess_sql(raw) if task == "sql" else raw.strip()
126
  print(generated)
127
 
@@ -145,7 +147,7 @@ def main() -> None:
145
 
146
  if not check_server():
147
  print("llama-server not running. Start it with:")
148
- print(" llama-server -m finetune/models/gemma-270m-q8.gguf -ngl 99 --port 8080 --log-disable")
149
  sys.exit(1)
150
 
151
  samples = load_samples(args.run_dir, args.task)
 
1
  """Interactive eval: run test samples through the local GGUF model.
2
 
3
  Requires llama-server running on port 9000:
4
+ llama-server -m finetune/models/<model>.gguf -ngl 99 --port 9000 --ctx-size 4096 --log-disable
5
 
6
+ Uses the /v1/chat/completions endpoint with a messages list. The Qwen3 GGUF
7
+ embeds its chat template in metadata, so llama-server applies it automatically.
8
 
9
  Usage
10
  -----
 
28
  import urllib.request
29
  from pathlib import Path
30
 
31
+ SERVER_URL = "http://localhost:9000"
32
+ MAX_TOKENS = 2048
33
+ TEMPERATURE = 0.6
34
 
35
  DEFAULT_RUN_DIR = Path("dataset/output/runs/v3-symbolic-paths")
36
 
 
54
  return False
55
 
56
 
57
def chat_complete(messages: list[dict]) -> str:
    """Send ``messages`` to llama-server's chat endpoint and return the reply text.

    The request is POSTed as JSON to ``{SERVER_URL}/v1/chat/completions`` using
    the module-level ``MAX_TOKENS`` and ``TEMPERATURE``; ``enable_thinking`` is
    switched off through ``chat_template_kwargs``.

    Returns the ``content`` of the first choice's message.
    """
    body = {
        "messages": messages,
        "n_predict": MAX_TOKENS,
        "temperature": TEMPERATURE,
        "chat_template_kwargs": {"enable_thinking": False},
    }
    request = urllib.request.Request(
        f"{SERVER_URL}/v1/chat/completions",
        data=json.dumps(body).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=60) as response:
        reply = json.loads(response.read())
    return reply["choices"][0]["message"]["content"]
73
 
74
 
75
  def load_samples(run_dir: Path, task: str) -> list[dict]:
 
93
 
94
  def run_sample(sample: dict, task: str, total: int, index: int, verbose: bool = False) -> None:
95
  expected = sample["completion"][0]["content"]
96
+ messages = sample["prompt"]
97
 
98
  user_content = sample["prompt"][1]["content"]
99
  if "<USER_QUERY>" in user_content:
 
108
  print(f"\nQuestion: {question}\n")
109
 
110
  if verbose:
111
+ prompt = build_raw_prompt(sample)
112
  print(f"{'─' * 60}")
113
+ print(f"Full prompt ({len(prompt)} chars, ~{len(prompt.split())} words):")
114
  print(f"{'─' * 60}")
115
  print(prompt)
116
 
 
123
  print("Generated:")
124
  print(f"{'─' * 60}")
125
 
126
+ raw = chat_complete(messages)
127
  generated = postprocess_sql(raw) if task == "sql" else raw.strip()
128
  print(generated)
129
 
 
147
 
148
  if not check_server():
149
  print("llama-server not running. Start it with:")
150
+ print("llama-server -m finetune/models/<model>.gguf -ngl 99 --port 8080 --ctx-size 2048 --log-disable")
151
  sys.exit(1)
152
 
153
  samples = load_samples(args.run_dir, args.task)
finetune/train_modal_qwen35.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modal training script for gazet Qwen3.5 LoRA fine-tuning with Unsloth.
2
+
3
+ Key differences from train_modal.py (Gemma):
4
+ - Uses Unsloth's FastLanguageModel for memory-efficient training
5
+ - Applies Qwen3.5 chat template to format data (not plain prompt+completion strings)
6
+ - Uses train_on_responses_only with ChatML markers to mask non-assistant tokens
7
+ - Saves merged 16-bit model via unsloth's save_pretrained_merged
8
+
9
+ Usage
10
+ -----
11
+ modal run finetune/train_modal_qwen35.py
12
+ modal run finetune/train_modal_qwen35.py --base-model unsloth/Qwen3.5-0.8B
13
+ modal run finetune/train_modal_qwen35.py --run-dir /mnt/gazet/data/v3-symbolic-paths
14
+ modal run finetune/train_modal_qwen35.py --num-train-epochs 5 --lora-r 32
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import pathlib
20
+ from dataclasses import dataclass
21
+ from datetime import datetime
22
+ from typing import Optional
23
+
24
+ import modal
25
+
26
# Modal application under which the training function and entrypoint are registered.
app = modal.App("gazet-nlg-qwen35-finetune")

# Resource settings consumed by the @app.function decorator on finetune().
GPU_TYPE = "A100-80GB"
TIMEOUT_HOURS = 24  # multiplied out to seconds where the function is declared
MAX_RETRIES = 1

# Container image used for the remote training function.
train_image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install(
        # Use unsloth's bundled CUDA+torch extra so bitsandbytes, xformers,
        # and trl are all resolved together against the same CUDA/torch build.
        # Mirrors the approach in https://modal.com/docs/examples/unsloth_finetune
        "unsloth[cu129-torch280]",
        "unsloth_zoo",
        "transformers~=5.2.0",
        "hf-transfer==0.1.9",
        "trackio[gpu]==0.21.1",
        "datasets",
        "pandas",
    )
    # Copy the local `finetune` package into the image.
    .add_local_python_source("finetune", copy=True)
    .env({
        # Hugging Face cache lives under /mnt/gazet, the mount point used in VOLUMES.
        "HF_HOME": "/mnt/gazet/model_cache",
        "HF_HUB_ENABLE_HF_TRANSFER": "1",
    })
)
52
+
53
+ with train_image.imports():
54
+ from unsloth import FastLanguageModel
55
+ from unsloth.chat_templates import train_on_responses_only
56
+ from trl import SFTConfig, SFTTrainer
57
+ from transformers import set_seed
58
+
59
# Named Modal volume; created on first use if it does not exist yet.
gazet_vol = modal.Volume.from_name("gazet", create_if_missing=True)

# Mount point -> volume mapping passed to @app.function(volumes=...).
VOLUMES = {
    "/mnt/gazet": gazet_vol,
}

# ChatML response markers for Qwen3.5 — the empty <think> block is how Qwen3.5
# formats non-thinking responses. We train only on tokens after this prefix.
INSTRUCTION_PART = "<|im_start|>user\n"
RESPONSE_PART = "<|im_start|>assistant\n<think>\n\n</think>\n\n"
69
+
70
+
71
@dataclass
class Qwen35Config:
    """Hyperparameters and paths for one Qwen3.5 LoRA fine-tuning run.

    When ``experiment_name`` is left as ``None`` a name of the form
    ``<model>-r<lora_r>-<YYYYmmdd-HHMMSS>`` is generated in ``__post_init__``.
    """

    # Model
    base_model: str = "unsloth/Qwen3.5-0.8B"

    # Dataset — path to run dir with {task}/{split}.jsonl files
    run_dir: str = "/mnt/gazet/data/v3-symbolic-paths"
    max_train_samples: Optional[int] = None
    max_eval_samples: Optional[int] = None

    # Sequence
    max_seq_length: int = 2048

    # LoRA — alpha=2*r follows unsloth recommendation for Qwen models
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.0

    # Training
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 32
    per_device_eval_batch_size: int = 16
    # effective batch = per_device_train_batch_size * gradient_accumulation_steps = 32
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    max_grad_norm: float = 1.0
    warmup_steps: int = 50
    lr_scheduler_type: str = "linear"
    weight_decay: float = 0.01
    optim: str = "adamw_8bit"

    # Logging / saving
    logging_steps: int = 10
    save_strategy: str = "steps"
    save_steps: int = 400
    eval_strategy: str = "steps"
    eval_steps: int = 200
    report_to: str = "trackio"
    trackio_space_id: Optional[str] = "srmsoumya/gazet-trackio"
    project: str = "gazet-nlg-qwen35"

    # Experiment
    seed: int = 42
    experiment_name: Optional[str] = None

    def __post_init__(self):
        """Fill in a timestamped experiment name when none was supplied."""
        if self.experiment_name is None:
            timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
            # Keep only the repo name, dropping any "org/" prefix.
            model_short = self.base_model.split("/")[-1]
            self.experiment_name = f"{model_short}-r{self.lora_r}-{timestamp}"
120
+
121
+
122
def _load_data(run_dir: str, tokenizer, max_train_samples=None, max_eval_samples=None, seed: int = 42):
    """Load JSONL data and apply the tokenizer's chat template.

    Each sample must have:
        prompt: list of {role, content} dicts (system + user)
        completion: list of {role, content} dicts (assistant)

    The chat template produces the full ChatML string including the assistant turn.
    train_on_responses_only then masks everything except the assistant response.

    Args:
        run_dir: Directory containing ``{task}/{split}.jsonl`` files for the
            ``sql`` and ``places`` tasks and the ``train`` and ``val`` splits.
        tokenizer: Object exposing ``apply_chat_template``.
        max_train_samples: Optional cap applied to the shuffled train split.
        max_eval_samples: Optional cap applied to the shuffled val split.
        seed: Shuffle seed. Defaults to 42, the value that was previously
            hard-coded, so existing callers get the same ordering.

    Returns:
        A ``DatasetDict`` holding whichever of ``train`` / ``val`` had data.
        Every row has one ``messages`` column containing the rendered text.
    """
    import json
    from datasets import Dataset, DatasetDict

    def load_jsonl(path: pathlib.Path) -> list[dict]:
        # Explicit UTF-8 so parsing does not depend on the container locale.
        with open(path, encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            return [json.loads(line) for line in stripped if line]

    def to_message(sample: dict) -> dict:
        text = tokenizer.apply_chat_template(
            sample["prompt"] + sample["completion"],
            tokenize=False,
            add_generation_prompt=False,
        )
        # The column is named "messages" but holds the rendered string.
        return {"messages": text}

    run_dir = pathlib.Path(run_dir)
    tasks = ("sql", "places")
    splits = ("train", "val")
    ds_dict: dict = {}

    for split in splits:
        combined: list[dict] = []
        for task in tasks:
            path = run_dir / task / f"{split}.jsonl"
            if not path.exists():
                print(f"Missing {path} — skipping")
                continue
            rows = load_jsonl(path)
            combined.extend(to_message(r) for r in rows)
            print(f"Loaded {len(rows):,} {task}/{split} rows")

        if combined:
            ds_dict[split] = Dataset.from_list(combined)
            print(f"{split} split: {len(combined):,} total rows")

    ds = DatasetDict(ds_dict).shuffle(seed=seed)

    # Caps are applied after shuffling so the subset is a random sample.
    if max_train_samples is not None and "train" in ds:
        ds["train"] = ds["train"].select(range(min(max_train_samples, len(ds["train"]))))
    if max_eval_samples is not None and "val" in ds:
        ds["val"] = ds["val"].select(range(min(max_eval_samples, len(ds["val"]))))

    return ds
181
+
182
+
183
+ def _find_latest_checkpoint(checkpoint_dir: pathlib.Path) -> str | None:
184
+ if not checkpoint_dir.exists():
185
+ return None
186
+ checkpoints = list(checkpoint_dir.glob("checkpoint-*"))
187
+ if not checkpoints:
188
+ return None
189
+ latest = max(checkpoints, key=lambda p: int(p.name.split("-")[1]))
190
+ print(f"Found existing checkpoint: {latest}")
191
+ return str(latest)
192
+
193
+
194
@app.function(
    image=train_image,
    gpu=GPU_TYPE,
    volumes=VOLUMES,
    secrets=[modal.Secret.from_name("huggingface-secret")],
    timeout=TIMEOUT_HOURS * 60 * 60,
    retries=modal.Retries(initial_delay=0.0, max_retries=MAX_RETRIES),
)
def finetune(config_dict: dict):
    """Run Qwen3.5 LoRA SFT training with Unsloth inside a Modal container.

    Args:
        config_dict: Keyword arguments for ``Qwen35Config``.

    Returns:
        The experiment name, which is also the sub-directory of
        ``/mnt/gazet/checkpoints`` that receives the outputs.

    Raises:
        FileNotFoundError: If no training rows were found under ``run_dir``.
    """
    config = Qwen35Config(**config_dict)
    set_seed(config.seed)

    experiment_dir = pathlib.Path("/mnt/gazet/checkpoints") / config.experiment_name
    experiment_dir.mkdir(parents=True, exist_ok=True)

    print(f"Experiment: {config.experiment_name}")
    print(f"Model: {config.base_model}")
    print(f"Run dir: {config.run_dir}")

    # Load base model with unsloth — gradient checkpointing is handled internally
    model, processor = FastLanguageModel.from_pretrained(
        config.base_model,
        max_seq_length=config.max_seq_length,
        load_in_4bit=False,
        use_gradient_checkpointing="unsloth",
        fast_inference=False,
    )
    # Use the wrapped tokenizer when a processor is returned; fall back to the
    # object itself when the loader already returned a plain tokenizer.
    tokenizer = getattr(processor, "tokenizer", processor)

    # Apply LoRA adapters — let unsloth select target modules via finetune_* flags
    model = FastLanguageModel.get_peft_model(
        model,
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        finetune_vision_layers=False,
        finetune_language_layers=True,
        finetune_attention_modules=True,
        finetune_mlp_modules=True,
        bias="none",
        random_state=config.seed,
        use_gradient_checkpointing=False,  # already set in from_pretrained
    )

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    ds = _load_data(
        config.run_dir,
        tokenizer,
        max_train_samples=config.max_train_samples,
        max_eval_samples=config.max_eval_samples,
    )
    # _load_data skips missing files, so an absent train split would otherwise
    # surface below as a bare KeyError.
    if "train" not in ds:
        raise FileNotFoundError(f"No training data found under {config.run_dir}")
    has_val = "val" in ds

    print(f"Train samples: {len(ds['train']):,}")
    if has_val:
        print(f"Val samples: {len(ds['val']):,}")
    effective_batch = config.per_device_train_batch_size * config.gradient_accumulation_steps
    print(f"Effective batch: {effective_batch}")

    sft_args = SFTConfig(
        output_dir=str(experiment_dir),
        dataset_text_field="messages",
        max_seq_length=config.max_seq_length,
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=config.per_device_train_batch_size,
        per_device_eval_batch_size=config.per_device_eval_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        max_grad_norm=config.max_grad_norm,
        warmup_steps=config.warmup_steps,
        lr_scheduler_type=config.lr_scheduler_type,
        weight_decay=config.weight_decay,
        optim=config.optim,
        bf16=True,
        logging_steps=config.logging_steps,
        save_strategy=config.save_strategy,
        save_steps=config.save_steps,
        # Periodic evaluation needs an eval dataset; disable it when no val
        # split was loaded.
        eval_strategy=config.eval_strategy if has_val else "no",
        eval_steps=config.eval_steps,
        report_to=config.report_to,
        trackio_space_id=config.trackio_space_id,
        project=config.project,
        dataset_num_proc=8,
        seed=config.seed,
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds["train"],
        eval_dataset=ds["val"] if has_val else None,
        args=sft_args,
    )

    # Mask all tokens except the assistant response — train on completions only
    trainer = train_on_responses_only(
        trainer,
        instruction_part=INSTRUCTION_PART,
        response_part=RESPONSE_PART,
    )

    resume_from = _find_latest_checkpoint(experiment_dir)
    if resume_from:
        print(f"Resuming from {resume_from}")

    trainer.train(resume_from_checkpoint=resume_from)

    # Save LoRA adapter + tokenizer (lightweight, for future merging)
    print(f"Saving LoRA adapter to {experiment_dir}")
    model.save_pretrained(str(experiment_dir))
    tokenizer.save_pretrained(str(experiment_dir))

    # Save merged 16-bit model (full weights, ready for inference / GGUF conversion)
    merged_dir = experiment_dir / "merged"
    merged_dir.mkdir(parents=True, exist_ok=True)
    print(f"Saving merged 16-bit model to {merged_dir}")
    model.save_pretrained_merged(str(merged_dir), tokenizer, save_method="merged_16bit")

    gazet_vol.commit()
    print(f"Training complete: {config.experiment_name}")
    return config.experiment_name
318
+
319
+
320
@app.local_entrypoint()
def main(
    base_model: Optional[str] = None,
    experiment_name: Optional[str] = None,
    run_dir: Optional[str] = None,
    num_train_epochs: Optional[int] = None,
    per_device_train_batch_size: Optional[int] = None,
    max_train_samples: Optional[int] = None,
    max_eval_samples: Optional[int] = None,
    lora_r: Optional[int] = None,
    max_seq_length: Optional[int] = None,
):
    """Build a ``Qwen35Config`` from the given options and launch remote training.

    Options left as ``None`` fall back to the ``Qwen35Config`` defaults. When
    ``lora_r`` is given, ``lora_alpha`` is set to twice its value.
    """
    cli_values = {
        "base_model": base_model,
        "experiment_name": experiment_name,
        "run_dir": run_dir,
        "num_train_epochs": num_train_epochs,
        "per_device_train_batch_size": per_device_train_batch_size,
        "max_train_samples": max_train_samples,
        "max_eval_samples": max_eval_samples,
        "lora_r": lora_r,
        "max_seq_length": max_seq_length,
    }
    # Forward only the options that were actually supplied.
    supplied = {}
    for option, value in cli_values.items():
        if value is not None:
            supplied[option] = value

    config = Qwen35Config(**supplied)
    # Keep alpha = 2 * r when the rank was overridden.
    if lora_r is not None:
        config.lora_alpha = 2 * config.lora_r

    batch_total = config.per_device_train_batch_size * config.gradient_accumulation_steps
    print(f"Starting experiment: {config.experiment_name}")
    print(f"Model: {config.base_model}")
    print(f"Run dir: {config.run_dir}")
    print(f"LoRA: r={config.lora_r}, alpha={config.lora_alpha}")
    print(f"Effective batch: {batch_total}")

    finished = finetune.remote(config.__dict__)
    print(f"Training complete: {finished}")
gazet_demo.py CHANGED
@@ -121,7 +121,7 @@ backend = st.sidebar.radio(
121
  format_func=lambda x: "⚡ GGUF (llama-server)" if x == "gguf" else "🧠 DSPy (cloud LM)",
122
  )
123
  st.sidebar.caption(
124
- "**gguf** → finetuned Gemma 270m via llama-server\n\n"
125
  "**dspy** → Ollama / cloud LM with retry loop"
126
  )
127
 
 
121
  format_func=lambda x: "⚡ GGUF (llama-server)" if x == "gguf" else "🧠 DSPy (cloud LM)",
122
  )
123
  st.sidebar.caption(
124
+ "**gguf** → finetuned Qwen3.5 via llama-server\n\n"
125
  "**dspy** → Ollama / cloud LM with retry loop"
126
  )
127
 
src/gazet/config.py CHANGED
@@ -27,7 +27,7 @@ SQL_GENERATION_MODEL = "gpt-oss:20b-cloud"
27
  MAX_SQL_ITERATIONS = 5
28
 
29
  # ── GGUF / llama-server config ────────────────────────────────────────────────
30
- LLAMA_SERVER_URL = os.environ.get("LLAMA_SERVER_URL", "http://localhost:8080")
31
  LLAMA_MAX_TOKENS = int(os.environ.get("LLAMA_MAX_TOKENS", "2048"))
32
  LLAMA_TEMPERATURE = float(os.environ.get("LLAMA_TEMPERATURE", "0"))
33
 
 
27
  MAX_SQL_ITERATIONS = 5
28
 
29
  # ── GGUF / llama-server config ────────────────────────────────────────────────
30
+ LLAMA_SERVER_URL = os.environ.get("LLAMA_SERVER_URL", "http://localhost:9000")
31
  LLAMA_MAX_TOKENS = int(os.environ.get("LLAMA_MAX_TOKENS", "2048"))
32
  LLAMA_TEMPERATURE = float(os.environ.get("LLAMA_TEMPERATURE", "0"))
33
 
src/gazet/lm.py CHANGED
@@ -251,21 +251,22 @@ def is_llama_server_available() -> bool:
251
  return False
252
 
253
 
254
- def _llama_complete(prompt: str) -> str:
255
- """Call llama-server /completion endpoint and return generated text."""
256
  resp = httpx.post(
257
- f"{LLAMA_SERVER_URL}/completion",
258
  json={
259
- "prompt": prompt,
260
  "n_predict": LLAMA_MAX_TOKENS,
261
  "temperature": LLAMA_TEMPERATURE,
 
262
  },
263
  timeout=60,
264
  )
265
  if resp.status_code != 200:
266
  logger.error("llama-server %s: %s", resp.status_code, resp.text[:500])
267
  resp.raise_for_status()
268
- return resp.json()["content"]
269
 
270
 
271
  _PLACES_SYSTEM_PROMPT = (
@@ -280,8 +281,11 @@ def generate_places(user_query: str) -> PlacesResult:
280
  Uses the same prompt format the model was trained on.
281
  Returns a PlacesResult; falls back to an empty result on parse failure.
282
  """
283
- raw_prompt = _PLACES_SYSTEM_PROMPT + "\n\n" + user_query
284
- raw_output = _llama_complete(raw_prompt).strip()
 
 
 
285
 
286
  # Strip markdown fences if the model wrapped the JSON
287
  if raw_output.startswith("```"):
@@ -317,6 +321,9 @@ def generate_sql(user_query: str, candidates_df: pd.DataFrame) -> str:
317
  question=user_query.strip(),
318
  )
319
 
320
- raw_prompt = _SYSTEM_PROMPT + "\n\n" + user_prompt
321
- raw_output = _llama_complete(raw_prompt)
 
 
 
322
  return _postprocess_sql(raw_output)
 
251
  return False
252
 
253
 
254
def _llama_chat_complete(messages: list[dict]) -> str:
    """POST ``messages`` to llama-server's chat endpoint and return the reply text.

    The request goes to ``{LLAMA_SERVER_URL}/v1/chat/completions`` with the
    configured token limit and temperature, and ``enable_thinking`` turned off
    via ``chat_template_kwargs``. A non-200 status is logged before
    ``raise_for_status()`` is called on the response.

    Returns the ``content`` of the first choice's message.
    """
    request_body = {
        "messages": messages,
        "n_predict": LLAMA_MAX_TOKENS,
        "temperature": LLAMA_TEMPERATURE,
        "chat_template_kwargs": {"enable_thinking": False},
    }
    response = httpx.post(
        f"{LLAMA_SERVER_URL}/v1/chat/completions",
        json=request_body,
        timeout=60,
    )
    if response.status_code != 200:
        logger.error("llama-server %s: %s", response.status_code, response.text[:500])
    response.raise_for_status()
    first_choice = response.json()["choices"][0]
    return first_choice["message"]["content"]
270
 
271
 
272
  _PLACES_SYSTEM_PROMPT = (
 
281
  Uses the same prompt format the model was trained on.
282
  Returns a PlacesResult; falls back to an empty result on parse failure.
283
  """
284
+ messages = [
285
+ {"role": "system", "content": _PLACES_SYSTEM_PROMPT},
286
+ {"role": "user", "content": user_query},
287
+ ]
288
+ raw_output = _llama_chat_complete(messages).strip()
289
 
290
  # Strip markdown fences if the model wrapped the JSON
291
  if raw_output.startswith("```"):
 
321
  question=user_query.strip(),
322
  )
323
 
324
+ messages = [
325
+ {"role": "system", "content": _SYSTEM_PROMPT},
326
+ {"role": "user", "content": user_prompt},
327
+ ]
328
+ raw_output = _llama_chat_complete(messages)
329
  return _postprocess_sql(raw_output)
src/gazet/search.py CHANGED
@@ -5,67 +5,6 @@ from .config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
5
  from .schemas import Place
6
 
7
 
8
- def _fuzzy_search(
9
- con: duckdb.DuckDBPyConnection,
10
- path: str,
11
- source: str,
12
- place: Place,
13
- extra_select: str = "",
14
- limit: int = 5,
15
- is_overture: bool = False,
16
- ) -> pd.DataFrame:
17
- """Generic Levenshtein fuzzy search against any parquet with a names.primary column."""
18
- country_filter = ""
19
- country_params: list = []
20
- if is_overture and place.country:
21
- country_filter = "AND country = ?"
22
- country_params = [place.country]
23
-
24
- subtype_filter = ""
25
- subtype_params: list = []
26
- if is_overture and place.subtype:
27
- subtype_filter = "AND subtype = ?"
28
- subtype_params = [place.subtype]
29
-
30
- params = (
31
- [place.place, place.place, path] + country_params + subtype_params + [limit]
32
- )
33
-
34
- extra_clause = f", {extra_select}" if extra_select else ""
35
- rel = con.execute(
36
- f"""
37
- SELECT
38
- id,
39
- names."primary" AS name,
40
- country,
41
- subtype,
42
- class,
43
- region,
44
- admin_level,
45
- is_land,
46
- is_territorial{extra_clause},
47
- 1.0 - (levenshtein(lower(names."primary"), lower(?))::float
48
- / greatest(length(names."primary"), length(?), 1)) AS similarity
49
- FROM read_parquet(?)
50
- WHERE names."primary" IS NOT NULL AND trim(names."primary") != ''
51
- {country_filter}
52
- {subtype_filter}
53
- ORDER BY similarity DESC, admin_level ASC
54
- LIMIT ?
55
- """,
56
- params,
57
- )
58
- df = rel.fetchdf()
59
- df.insert(0, "source", source)
60
- label = f'"{place.place}"' + (f" [{place.country}]" if place.country else "")
61
- if df.empty:
62
- print(f"\n{source} - {label}: no matches")
63
- else:
64
- print(f"\n{source} - {label} (top {len(df)} by name similarity):")
65
- print(df.to_string(index=False))
66
- return df
67
-
68
-
69
  def simple_fuzzy_search(
70
  con: duckdb.DuckDBPyConnection,
71
  path: str,
@@ -138,22 +77,14 @@ def search_natural_earth(
138
  def search_candidates(
139
  con: duckdb.DuckDBPyConnection, place: Place, limit: int = 5
140
  ) -> list[pd.DataFrame]:
141
- """Return candidate DataFrames for a place, choosing sources by subtype.
142
 
143
- If the place has an Overture admin subtype (region, country, locality, etc.)
144
- it is definitively an admin boundary search divisions_area only.
145
- If no subtype is known, search both sources (handles seas, oceans, terrain).
146
  """
147
  results = []
148
- if place.subtype:
149
- # Known admin division — divisions_area only
150
- df = search_divisions_area(con, place, limit=limit)
151
  if not df.empty:
152
  results.append(df)
153
- else:
154
- # Ambiguous — could be physical feature or admin; search both
155
- for fn in (search_divisions_area, search_natural_earth):
156
- df = fn(con, place, limit=limit)
157
- if not df.empty:
158
- results.append(df)
159
  return results
 
5
  from .schemas import Place
6
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def simple_fuzzy_search(
9
  con: duckdb.DuckDBPyConnection,
10
  path: str,
 
77
def search_candidates(
    con: duckdb.DuckDBPyConnection, place: Place, limit: int = 5
) -> list[pd.DataFrame]:
    """Collect non-empty candidate DataFrames for ``place`` from both sources.

    divisions_area is queried first, then natural_earth, regardless of the
    place's subtype, so natural features are not missed when the model assigns
    an incorrect admin subtype. Empty results are dropped.
    """
    frames = (
        search(con, place, limit=limit)
        for search in (search_divisions_area, search_natural_earth)
    )
    return [frame for frame in frames if not frame.empty]