Spaces:
Sleeping
Sleeping
Add model training script for qwen, update eval & gazet lm to work with qwen template
Browse files
- finetune/eval_cli.py +17 -15
- finetune/train_modal_qwen35.py +359 -0
- gazet_demo.py +1 -1
- src/gazet/config.py +1 -1
- src/gazet/lm.py +16 -9
- src/gazet/search.py +5 -74
finetune/eval_cli.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
"""Interactive eval: run test samples through the local GGUF model.
|
| 2 |
|
| 3 |
Requires llama-server running on port 8080:
|
| 4 |
-
llama-server -m finetune/models/
|
| 5 |
|
| 6 |
-
Uses the /
|
| 7 |
-
|
| 8 |
|
| 9 |
Usage
|
| 10 |
-----
|
|
@@ -28,9 +28,9 @@ import urllib.error
|
|
| 28 |
import urllib.request
|
| 29 |
from pathlib import Path
|
| 30 |
|
| 31 |
-
SERVER_URL = "http://localhost:
|
| 32 |
-
MAX_TOKENS =
|
| 33 |
-
TEMPERATURE = 0
|
| 34 |
|
| 35 |
DEFAULT_RUN_DIR = Path("dataset/output/runs/v3-symbolic-paths")
|
| 36 |
|
|
@@ -54,21 +54,22 @@ def check_server() -> bool:
|
|
| 54 |
return False
|
| 55 |
|
| 56 |
|
| 57 |
-
def
|
| 58 |
-
"""Call llama-server /
|
| 59 |
payload = json.dumps({
|
| 60 |
-
"
|
| 61 |
"n_predict": MAX_TOKENS,
|
| 62 |
"temperature": TEMPERATURE,
|
|
|
|
| 63 |
}).encode()
|
| 64 |
|
| 65 |
req = urllib.request.Request(
|
| 66 |
-
f"{SERVER_URL}/
|
| 67 |
data=payload,
|
| 68 |
headers={"Content-Type": "application/json"},
|
| 69 |
)
|
| 70 |
with urllib.request.urlopen(req, timeout=60) as resp:
|
| 71 |
-
return json.loads(resp.read())["content"]
|
| 72 |
|
| 73 |
|
| 74 |
def load_samples(run_dir: Path, task: str) -> list[dict]:
|
|
@@ -92,7 +93,7 @@ def build_raw_prompt(sample: dict) -> str:
|
|
| 92 |
|
| 93 |
def run_sample(sample: dict, task: str, total: int, index: int, verbose: bool = False) -> None:
|
| 94 |
expected = sample["completion"][0]["content"]
|
| 95 |
-
|
| 96 |
|
| 97 |
user_content = sample["prompt"][1]["content"]
|
| 98 |
if "<USER_QUERY>" in user_content:
|
|
@@ -107,8 +108,9 @@ def run_sample(sample: dict, task: str, total: int, index: int, verbose: bool =
|
|
| 107 |
print(f"\nQuestion: {question}\n")
|
| 108 |
|
| 109 |
if verbose:
|
|
|
|
| 110 |
print(f"{'─' * 60}")
|
| 111 |
-
print(f"Full prompt ({len(prompt)} chars, ~{len(prompt.split())
|
| 112 |
print(f"{'─' * 60}")
|
| 113 |
print(prompt)
|
| 114 |
|
|
@@ -121,7 +123,7 @@ def run_sample(sample: dict, task: str, total: int, index: int, verbose: bool =
|
|
| 121 |
print("Generated:")
|
| 122 |
print(f"{'─' * 60}")
|
| 123 |
|
| 124 |
-
raw =
|
| 125 |
generated = postprocess_sql(raw) if task == "sql" else raw.strip()
|
| 126 |
print(generated)
|
| 127 |
|
|
@@ -145,7 +147,7 @@ def main() -> None:
|
|
| 145 |
|
| 146 |
if not check_server():
|
| 147 |
print("llama-server not running. Start it with:")
|
| 148 |
-
print("
|
| 149 |
sys.exit(1)
|
| 150 |
|
| 151 |
samples = load_samples(args.run_dir, args.task)
|
|
|
|
| 1 |
"""Interactive eval: run test samples through the local GGUF model.
|
| 2 |
|
| 3 |
Requires llama-server running on port 9000 (must match SERVER_URL):
|
| 4 |
+
llama-server -m finetune/models/<model>.gguf -ngl 99 --port 9000 --ctx-size 4096 --log-disable
|
| 5 |
|
| 6 |
+
Uses the /v1/chat/completions endpoint with a messages list. The Qwen3.5 GGUF
|
| 7 |
+
embeds its chat template in metadata, so llama-server applies it automatically.
|
| 8 |
|
| 9 |
Usage
|
| 10 |
-----
|
|
|
|
| 28 |
import urllib.request
|
| 29 |
from pathlib import Path
|
| 30 |
|
| 31 |
+
SERVER_URL = "http://localhost:9000"
|
| 32 |
+
MAX_TOKENS = 2048
|
| 33 |
+
TEMPERATURE = 0.6
|
| 34 |
|
| 35 |
DEFAULT_RUN_DIR = Path("dataset/output/runs/v3-symbolic-paths")
|
| 36 |
|
|
|
|
| 54 |
return False
|
| 55 |
|
| 56 |
|
| 57 |
+
def chat_complete(messages: list[dict]) -> str:
|
| 58 |
+
"""Call llama-server /v1/chat/completions with a messages list."""
|
| 59 |
payload = json.dumps({
|
| 60 |
+
"messages": messages,
|
| 61 |
"n_predict": MAX_TOKENS,
|
| 62 |
"temperature": TEMPERATURE,
|
| 63 |
+
"chat_template_kwargs": {"enable_thinking": False},
|
| 64 |
}).encode()
|
| 65 |
|
| 66 |
req = urllib.request.Request(
|
| 67 |
+
f"{SERVER_URL}/v1/chat/completions",
|
| 68 |
data=payload,
|
| 69 |
headers={"Content-Type": "application/json"},
|
| 70 |
)
|
| 71 |
with urllib.request.urlopen(req, timeout=60) as resp:
|
| 72 |
+
return json.loads(resp.read())["choices"][0]["message"]["content"]
|
| 73 |
|
| 74 |
|
| 75 |
def load_samples(run_dir: Path, task: str) -> list[dict]:
|
|
|
|
| 93 |
|
| 94 |
def run_sample(sample: dict, task: str, total: int, index: int, verbose: bool = False) -> None:
|
| 95 |
expected = sample["completion"][0]["content"]
|
| 96 |
+
messages = sample["prompt"]
|
| 97 |
|
| 98 |
user_content = sample["prompt"][1]["content"]
|
| 99 |
if "<USER_QUERY>" in user_content:
|
|
|
|
| 108 |
print(f"\nQuestion: {question}\n")
|
| 109 |
|
| 110 |
if verbose:
|
| 111 |
+
prompt = build_raw_prompt(sample)
|
| 112 |
print(f"{'─' * 60}")
|
| 113 |
+
print(f"Full prompt ({len(prompt)} chars, ~{len(prompt.split())} words):")
|
| 114 |
print(f"{'─' * 60}")
|
| 115 |
print(prompt)
|
| 116 |
|
|
|
|
| 123 |
print("Generated:")
|
| 124 |
print(f"{'─' * 60}")
|
| 125 |
|
| 126 |
+
raw = chat_complete(messages)
|
| 127 |
generated = postprocess_sql(raw) if task == "sql" else raw.strip()
|
| 128 |
print(generated)
|
| 129 |
|
|
|
|
| 147 |
|
| 148 |
if not check_server():
|
| 149 |
print("llama-server not running. Start it with:")
|
| 150 |
+
print("llama-server -m finetune/models/<model>.gguf -ngl 99 --port 8080 --ctx-size 2048 --log-disable")
|
| 151 |
sys.exit(1)
|
| 152 |
|
| 153 |
samples = load_samples(args.run_dir, args.task)
|
finetune/train_modal_qwen35.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Modal training script for gazet Qwen3.5 LoRA fine-tuning with Unsloth.
|
| 2 |
+
|
| 3 |
+
Key differences from train_modal.py (Gemma):
|
| 4 |
+
- Uses Unsloth's FastLanguageModel for memory-efficient training
|
| 5 |
+
- Applies Qwen3.5 chat template to format data (not plain prompt+completion strings)
|
| 6 |
+
- Uses train_on_responses_only with ChatML markers to mask non-assistant tokens
|
| 7 |
+
- Saves merged 16-bit model via unsloth's save_pretrained_merged
|
| 8 |
+
|
| 9 |
+
Usage
|
| 10 |
+
-----
|
| 11 |
+
modal run finetune/train_modal_qwen35.py
|
| 12 |
+
modal run finetune/train_modal_qwen35.py --base-model unsloth/Qwen3.5-0.8B
|
| 13 |
+
modal run finetune/train_modal_qwen35.py --run-dir /mnt/gazet/data/v3-symbolic-paths
|
| 14 |
+
modal run finetune/train_modal_qwen35.py --num-train-epochs 5 --lora-r 32
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import pathlib
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from datetime import datetime
|
| 22 |
+
from typing import Optional
|
| 23 |
+
|
| 24 |
+
import modal
|
| 25 |
+
|
| 26 |
+
app = modal.App("gazet-nlg-qwen35-finetune")
|
| 27 |
+
|
| 28 |
+
GPU_TYPE = "A100-80GB"
|
| 29 |
+
TIMEOUT_HOURS = 24
|
| 30 |
+
MAX_RETRIES = 1
|
| 31 |
+
|
| 32 |
+
train_image = (
|
| 33 |
+
modal.Image.debian_slim(python_version="3.11")
|
| 34 |
+
.pip_install(
|
| 35 |
+
# Use unsloth's bundled CUDA+torch extra so bitsandbytes, xformers,
|
| 36 |
+
# and trl are all resolved together against the same CUDA/torch build.
|
| 37 |
+
# Mirrors the approach in https://modal.com/docs/examples/unsloth_finetune
|
| 38 |
+
"unsloth[cu129-torch280]",
|
| 39 |
+
"unsloth_zoo",
|
| 40 |
+
"transformers~=5.2.0",
|
| 41 |
+
"hf-transfer==0.1.9",
|
| 42 |
+
"trackio[gpu]==0.21.1",
|
| 43 |
+
"datasets",
|
| 44 |
+
"pandas",
|
| 45 |
+
)
|
| 46 |
+
.add_local_python_source("finetune", copy=True)
|
| 47 |
+
.env({
|
| 48 |
+
"HF_HOME": "/mnt/gazet/model_cache",
|
| 49 |
+
"HF_HUB_ENABLE_HF_TRANSFER": "1",
|
| 50 |
+
})
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
with train_image.imports():
|
| 54 |
+
from unsloth import FastLanguageModel
|
| 55 |
+
from unsloth.chat_templates import train_on_responses_only
|
| 56 |
+
from trl import SFTConfig, SFTTrainer
|
| 57 |
+
from transformers import set_seed
|
| 58 |
+
|
| 59 |
+
gazet_vol = modal.Volume.from_name("gazet", create_if_missing=True)
|
| 60 |
+
|
| 61 |
+
VOLUMES = {
|
| 62 |
+
"/mnt/gazet": gazet_vol,
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
# ChatML response markers for Qwen3.5 — the empty <think> block is how Qwen3.5
|
| 66 |
+
# formats non-thinking responses. We train only on tokens after this prefix.
|
| 67 |
+
INSTRUCTION_PART = "<|im_start|>user\n"
|
| 68 |
+
RESPONSE_PART = "<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
|
| 72 |
+
class Qwen35Config:
|
| 73 |
+
# Model
|
| 74 |
+
base_model: str = "unsloth/Qwen3.5-0.8B"
|
| 75 |
+
|
| 76 |
+
# Dataset — path to run dir with {task}/{split}.jsonl files
|
| 77 |
+
run_dir: str = "/mnt/gazet/data/v3-symbolic-paths"
|
| 78 |
+
max_train_samples: Optional[int] = None
|
| 79 |
+
max_eval_samples: Optional[int] = None
|
| 80 |
+
|
| 81 |
+
# Sequence
|
| 82 |
+
max_seq_length: int = 2048
|
| 83 |
+
|
| 84 |
+
# LoRA — alpha=2*r follows unsloth recommendation for Qwen models
|
| 85 |
+
lora_r: int = 16
|
| 86 |
+
lora_alpha: int = 32
|
| 87 |
+
lora_dropout: float = 0.0
|
| 88 |
+
|
| 89 |
+
# Training
|
| 90 |
+
num_train_epochs: int = 1
|
| 91 |
+
per_device_train_batch_size: int = 32
|
| 92 |
+
per_device_eval_batch_size: int = 16
|
| 93 |
+
gradient_accumulation_steps: int = 1 # effective batch = 32 (32 per device × 1 accumulation step)
|
| 94 |
+
learning_rate: float = 1e-4
|
| 95 |
+
max_grad_norm: float = 1.0
|
| 96 |
+
warmup_steps: int = 50
|
| 97 |
+
lr_scheduler_type: str = "linear"
|
| 98 |
+
weight_decay: float = 0.01
|
| 99 |
+
optim: str = "adamw_8bit"
|
| 100 |
+
|
| 101 |
+
# Logging / saving
|
| 102 |
+
logging_steps: int = 10
|
| 103 |
+
save_strategy: str = "steps"
|
| 104 |
+
save_steps: int = 400
|
| 105 |
+
eval_strategy: str = "steps"
|
| 106 |
+
eval_steps: int = 200
|
| 107 |
+
report_to: str = "trackio"
|
| 108 |
+
trackio_space_id: Optional[str] = "srmsoumya/gazet-trackio"
|
| 109 |
+
project: str = "gazet-nlg-qwen35"
|
| 110 |
+
|
| 111 |
+
# Experiment
|
| 112 |
+
seed: int = 42
|
| 113 |
+
experiment_name: Optional[str] = None
|
| 114 |
+
|
| 115 |
+
def __post_init__(self):
|
| 116 |
+
if self.experiment_name is None:
|
| 117 |
+
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
| 118 |
+
model_short = self.base_model.split("/")[-1]
|
| 119 |
+
self.experiment_name = f"{model_short}-r{self.lora_r}-{timestamp}"
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _load_data(run_dir: str, tokenizer, max_train_samples=None, max_eval_samples=None):
|
| 123 |
+
"""Load JSONL data and apply Qwen3.5 chat template.
|
| 124 |
+
|
| 125 |
+
Each sample must have:
|
| 126 |
+
prompt: list of {role, content} dicts (system + user)
|
| 127 |
+
completion: list of {role, content} dicts (assistant)
|
| 128 |
+
|
| 129 |
+
The chat template produces the full ChatML string including the assistant turn.
|
| 130 |
+
train_on_responses_only then masks everything except the assistant response.
|
| 131 |
+
"""
|
| 132 |
+
import json
|
| 133 |
+
from datasets import Dataset, DatasetDict
|
| 134 |
+
|
| 135 |
+
def load_jsonl(path: pathlib.Path) -> list[dict]:
|
| 136 |
+
rows = []
|
| 137 |
+
with open(path) as f:
|
| 138 |
+
for line in f:
|
| 139 |
+
line = line.strip()
|
| 140 |
+
if line:
|
| 141 |
+
rows.append(json.loads(line))
|
| 142 |
+
return rows
|
| 143 |
+
|
| 144 |
+
def to_message(sample: dict) -> dict:
|
| 145 |
+
text = tokenizer.apply_chat_template(
|
| 146 |
+
sample["prompt"] + sample["completion"],
|
| 147 |
+
tokenize=False,
|
| 148 |
+
add_generation_prompt=False,
|
| 149 |
+
)
|
| 150 |
+
return {"messages": text}
|
| 151 |
+
|
| 152 |
+
run_dir = pathlib.Path(run_dir)
|
| 153 |
+
tasks = ("sql", "places")
|
| 154 |
+
splits = ("train", "val")
|
| 155 |
+
ds_dict: dict = {}
|
| 156 |
+
|
| 157 |
+
for split in splits:
|
| 158 |
+
combined: list[dict] = []
|
| 159 |
+
for task in tasks:
|
| 160 |
+
path = run_dir / task / f"{split}.jsonl"
|
| 161 |
+
if not path.exists():
|
| 162 |
+
print(f"Missing {path} — skipping")
|
| 163 |
+
continue
|
| 164 |
+
rows = load_jsonl(path)
|
| 165 |
+
flattened = [to_message(r) for r in rows]
|
| 166 |
+
combined.extend(flattened)
|
| 167 |
+
print(f"Loaded {len(rows):,} {task}/{split} rows")
|
| 168 |
+
|
| 169 |
+
if combined:
|
| 170 |
+
ds_dict[split] = Dataset.from_list(combined)
|
| 171 |
+
print(f"{split} split: {len(combined):,} total rows")
|
| 172 |
+
|
| 173 |
+
ds = DatasetDict(ds_dict).shuffle(seed=42)
|
| 174 |
+
|
| 175 |
+
if max_train_samples is not None and "train" in ds:
|
| 176 |
+
ds["train"] = ds["train"].select(range(min(max_train_samples, len(ds["train"]))))
|
| 177 |
+
if max_eval_samples is not None and "val" in ds:
|
| 178 |
+
ds["val"] = ds["val"].select(range(min(max_eval_samples, len(ds["val"]))))
|
| 179 |
+
|
| 180 |
+
return ds
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def _find_latest_checkpoint(checkpoint_dir: pathlib.Path) -> str | None:
|
| 184 |
+
if not checkpoint_dir.exists():
|
| 185 |
+
return None
|
| 186 |
+
checkpoints = list(checkpoint_dir.glob("checkpoint-*"))
|
| 187 |
+
if not checkpoints:
|
| 188 |
+
return None
|
| 189 |
+
latest = max(checkpoints, key=lambda p: int(p.name.split("-")[1]))
|
| 190 |
+
print(f"Found existing checkpoint: {latest}")
|
| 191 |
+
return str(latest)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
@app.function(
|
| 195 |
+
image=train_image,
|
| 196 |
+
gpu=GPU_TYPE,
|
| 197 |
+
volumes=VOLUMES,
|
| 198 |
+
secrets=[modal.Secret.from_name("huggingface-secret")],
|
| 199 |
+
timeout=TIMEOUT_HOURS * 60 * 60,
|
| 200 |
+
retries=modal.Retries(initial_delay=0.0, max_retries=MAX_RETRIES),
|
| 201 |
+
)
|
| 202 |
+
def finetune(config_dict: dict):
|
| 203 |
+
"""Run Qwen3.5 LoRA SFT training with Unsloth inside a Modal container."""
|
| 204 |
+
config = Qwen35Config(**config_dict)
|
| 205 |
+
set_seed(config.seed)
|
| 206 |
+
|
| 207 |
+
experiment_dir = pathlib.Path("/mnt/gazet/checkpoints") / config.experiment_name
|
| 208 |
+
experiment_dir.mkdir(parents=True, exist_ok=True)
|
| 209 |
+
|
| 210 |
+
print(f"Experiment: {config.experiment_name}")
|
| 211 |
+
print(f"Model: {config.base_model}")
|
| 212 |
+
print(f"Run dir: {config.run_dir}")
|
| 213 |
+
|
| 214 |
+
# Load base model with unsloth — gradient checkpointing is handled internally
|
| 215 |
+
model, processor = FastLanguageModel.from_pretrained(
|
| 216 |
+
config.base_model,
|
| 217 |
+
max_seq_length=config.max_seq_length,
|
| 218 |
+
load_in_4bit=False,
|
| 219 |
+
use_gradient_checkpointing="unsloth",
|
| 220 |
+
fast_inference=False,
|
| 221 |
+
)
|
| 222 |
+
tokenizer = processor.tokenizer
|
| 223 |
+
|
| 224 |
+
# Apply LoRA adapters — let unsloth select target modules via finetune_* flags
|
| 225 |
+
model = FastLanguageModel.get_peft_model(
|
| 226 |
+
model,
|
| 227 |
+
r=config.lora_r,
|
| 228 |
+
lora_alpha=config.lora_alpha,
|
| 229 |
+
lora_dropout=config.lora_dropout,
|
| 230 |
+
finetune_vision_layers=False,
|
| 231 |
+
finetune_language_layers=True,
|
| 232 |
+
finetune_attention_modules=True,
|
| 233 |
+
finetune_mlp_modules=True,
|
| 234 |
+
bias="none",
|
| 235 |
+
random_state=config.seed,
|
| 236 |
+
use_gradient_checkpointing=False, # already set in from_pretrained
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 240 |
+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 241 |
+
print(f"Total parameters: {total_params:,}")
|
| 242 |
+
print(f"Trainable parameters: {trainable_params:,}")
|
| 243 |
+
|
| 244 |
+
ds = _load_data(
|
| 245 |
+
config.run_dir,
|
| 246 |
+
tokenizer,
|
| 247 |
+
max_train_samples=config.max_train_samples,
|
| 248 |
+
max_eval_samples=config.max_eval_samples,
|
| 249 |
+
)
|
| 250 |
+
print(f"Train samples: {len(ds['train']):,}")
|
| 251 |
+
if "val" in ds:
|
| 252 |
+
print(f"Val samples: {len(ds['val']):,}")
|
| 253 |
+
effective_batch = config.per_device_train_batch_size * config.gradient_accumulation_steps
|
| 254 |
+
print(f"Effective batch: {effective_batch}")
|
| 255 |
+
|
| 256 |
+
sft_args = SFTConfig(
|
| 257 |
+
output_dir=str(experiment_dir),
|
| 258 |
+
dataset_text_field="messages",
|
| 259 |
+
max_seq_length=config.max_seq_length,
|
| 260 |
+
num_train_epochs=config.num_train_epochs,
|
| 261 |
+
per_device_train_batch_size=config.per_device_train_batch_size,
|
| 262 |
+
per_device_eval_batch_size=config.per_device_eval_batch_size,
|
| 263 |
+
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
| 264 |
+
learning_rate=config.learning_rate,
|
| 265 |
+
max_grad_norm=config.max_grad_norm,
|
| 266 |
+
warmup_steps=config.warmup_steps,
|
| 267 |
+
lr_scheduler_type=config.lr_scheduler_type,
|
| 268 |
+
weight_decay=config.weight_decay,
|
| 269 |
+
optim=config.optim,
|
| 270 |
+
bf16=True,
|
| 271 |
+
logging_steps=config.logging_steps,
|
| 272 |
+
save_strategy=config.save_strategy,
|
| 273 |
+
save_steps=config.save_steps,
|
| 274 |
+
eval_strategy=config.eval_strategy,
|
| 275 |
+
eval_steps=config.eval_steps,
|
| 276 |
+
report_to=config.report_to,
|
| 277 |
+
trackio_space_id=config.trackio_space_id,
|
| 278 |
+
project=config.project,
|
| 279 |
+
dataset_num_proc=8,
|
| 280 |
+
seed=config.seed,
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
trainer = SFTTrainer(
|
| 284 |
+
model=model,
|
| 285 |
+
tokenizer=tokenizer,
|
| 286 |
+
train_dataset=ds["train"],
|
| 287 |
+
eval_dataset=ds.get("val"),
|
| 288 |
+
args=sft_args,
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
# Mask all tokens except the assistant response — train on completions only
|
| 292 |
+
trainer = train_on_responses_only(
|
| 293 |
+
trainer,
|
| 294 |
+
instruction_part=INSTRUCTION_PART,
|
| 295 |
+
response_part=RESPONSE_PART,
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
resume_from = _find_latest_checkpoint(experiment_dir)
|
| 299 |
+
if resume_from:
|
| 300 |
+
print(f"Resuming from {resume_from}")
|
| 301 |
+
|
| 302 |
+
trainer.train(resume_from_checkpoint=resume_from)
|
| 303 |
+
|
| 304 |
+
# Save LoRA adapter + tokenizer (lightweight, for future merging)
|
| 305 |
+
print(f"Saving LoRA adapter to {experiment_dir}")
|
| 306 |
+
model.save_pretrained(str(experiment_dir))
|
| 307 |
+
tokenizer.save_pretrained(str(experiment_dir))
|
| 308 |
+
|
| 309 |
+
# Save merged 16-bit model (full weights, ready for inference / GGUF conversion)
|
| 310 |
+
merged_dir = experiment_dir / "merged"
|
| 311 |
+
merged_dir.mkdir(parents=True, exist_ok=True)
|
| 312 |
+
print(f"Saving merged 16-bit model to {merged_dir}")
|
| 313 |
+
model.save_pretrained_merged(str(merged_dir), tokenizer, save_method="merged_16bit")
|
| 314 |
+
|
| 315 |
+
gazet_vol.commit()
|
| 316 |
+
print(f"Training complete: {config.experiment_name}")
|
| 317 |
+
return config.experiment_name
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
@app.local_entrypoint()
|
| 321 |
+
def main(
|
| 322 |
+
base_model: Optional[str] = None,
|
| 323 |
+
experiment_name: Optional[str] = None,
|
| 324 |
+
run_dir: Optional[str] = None,
|
| 325 |
+
num_train_epochs: Optional[int] = None,
|
| 326 |
+
per_device_train_batch_size: Optional[int] = None,
|
| 327 |
+
max_train_samples: Optional[int] = None,
|
| 328 |
+
max_eval_samples: Optional[int] = None,
|
| 329 |
+
lora_r: Optional[int] = None,
|
| 330 |
+
max_seq_length: Optional[int] = None,
|
| 331 |
+
):
|
| 332 |
+
overrides = {
|
| 333 |
+
k: v for k, v in dict(
|
| 334 |
+
base_model=base_model,
|
| 335 |
+
experiment_name=experiment_name,
|
| 336 |
+
run_dir=run_dir,
|
| 337 |
+
num_train_epochs=num_train_epochs,
|
| 338 |
+
per_device_train_batch_size=per_device_train_batch_size,
|
| 339 |
+
max_train_samples=max_train_samples,
|
| 340 |
+
max_eval_samples=max_eval_samples,
|
| 341 |
+
lora_r=lora_r,
|
| 342 |
+
max_seq_length=max_seq_length,
|
| 343 |
+
).items() if v is not None
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
config = Qwen35Config(**overrides)
|
| 347 |
+
# lora_alpha tracks lora_r (alpha = 2 * r) whenever r is overridden; alpha has no CLI flag of its own
|
| 348 |
+
if lora_r is not None:
|
| 349 |
+
config.lora_alpha = 2 * config.lora_r
|
| 350 |
+
|
| 351 |
+
print(f"Starting experiment: {config.experiment_name}")
|
| 352 |
+
print(f"Model: {config.base_model}")
|
| 353 |
+
print(f"Run dir: {config.run_dir}")
|
| 354 |
+
print(f"LoRA: r={config.lora_r}, alpha={config.lora_alpha}")
|
| 355 |
+
effective_batch = config.per_device_train_batch_size * config.gradient_accumulation_steps
|
| 356 |
+
print(f"Effective batch: {effective_batch}")
|
| 357 |
+
|
| 358 |
+
result = finetune.remote(config.__dict__)
|
| 359 |
+
print(f"Training complete: {result}")
|
gazet_demo.py
CHANGED
|
@@ -121,7 +121,7 @@ backend = st.sidebar.radio(
|
|
| 121 |
format_func=lambda x: "⚡ GGUF (llama-server)" if x == "gguf" else "🧠 DSPy (cloud LM)",
|
| 122 |
)
|
| 123 |
st.sidebar.caption(
|
| 124 |
-
"**gguf** → finetuned
|
| 125 |
"**dspy** → Ollama / cloud LM with retry loop"
|
| 126 |
)
|
| 127 |
|
|
|
|
| 121 |
format_func=lambda x: "⚡ GGUF (llama-server)" if x == "gguf" else "🧠 DSPy (cloud LM)",
|
| 122 |
)
|
| 123 |
st.sidebar.caption(
|
| 124 |
+
"**gguf** → finetuned Qwen3.5 via llama-server\n\n"
|
| 125 |
"**dspy** → Ollama / cloud LM with retry loop"
|
| 126 |
)
|
| 127 |
|
src/gazet/config.py
CHANGED
|
@@ -27,7 +27,7 @@ SQL_GENERATION_MODEL = "gpt-oss:20b-cloud"
|
|
| 27 |
MAX_SQL_ITERATIONS = 5
|
| 28 |
|
| 29 |
# ── GGUF / llama-server config ────────────────────────────────────────────────
|
| 30 |
-
LLAMA_SERVER_URL = os.environ.get("LLAMA_SERVER_URL", "http://localhost:
|
| 31 |
LLAMA_MAX_TOKENS = int(os.environ.get("LLAMA_MAX_TOKENS", "2048"))
|
| 32 |
LLAMA_TEMPERATURE = float(os.environ.get("LLAMA_TEMPERATURE", "0"))
|
| 33 |
|
|
|
|
| 27 |
MAX_SQL_ITERATIONS = 5
|
| 28 |
|
| 29 |
# ── GGUF / llama-server config ────────────────────────────────────────────────
|
| 30 |
+
LLAMA_SERVER_URL = os.environ.get("LLAMA_SERVER_URL", "http://localhost:9000")
|
| 31 |
LLAMA_MAX_TOKENS = int(os.environ.get("LLAMA_MAX_TOKENS", "2048"))
|
| 32 |
LLAMA_TEMPERATURE = float(os.environ.get("LLAMA_TEMPERATURE", "0"))
|
| 33 |
|
src/gazet/lm.py
CHANGED
|
@@ -251,21 +251,22 @@ def is_llama_server_available() -> bool:
|
|
| 251 |
return False
|
| 252 |
|
| 253 |
|
| 254 |
-
def
|
| 255 |
-
"""Call llama-server /
|
| 256 |
resp = httpx.post(
|
| 257 |
-
f"{LLAMA_SERVER_URL}/
|
| 258 |
json={
|
| 259 |
-
"
|
| 260 |
"n_predict": LLAMA_MAX_TOKENS,
|
| 261 |
"temperature": LLAMA_TEMPERATURE,
|
|
|
|
| 262 |
},
|
| 263 |
timeout=60,
|
| 264 |
)
|
| 265 |
if resp.status_code != 200:
|
| 266 |
logger.error("llama-server %s: %s", resp.status_code, resp.text[:500])
|
| 267 |
resp.raise_for_status()
|
| 268 |
-
return resp.json()["content"]
|
| 269 |
|
| 270 |
|
| 271 |
_PLACES_SYSTEM_PROMPT = (
|
|
@@ -280,8 +281,11 @@ def generate_places(user_query: str) -> PlacesResult:
|
|
| 280 |
Uses the same prompt format the model was trained on.
|
| 281 |
Returns a PlacesResult; falls back to an empty result on parse failure.
|
| 282 |
"""
|
| 283 |
-
|
| 284 |
-
|
|
|
|
|
|
|
|
|
|
| 285 |
|
| 286 |
# Strip markdown fences if the model wrapped the JSON
|
| 287 |
if raw_output.startswith("```"):
|
|
@@ -317,6 +321,9 @@ def generate_sql(user_query: str, candidates_df: pd.DataFrame) -> str:
|
|
| 317 |
question=user_query.strip(),
|
| 318 |
)
|
| 319 |
|
| 320 |
-
|
| 321 |
-
|
|
|
|
|
|
|
|
|
|
| 322 |
return _postprocess_sql(raw_output)
|
|
|
|
| 251 |
return False
|
| 252 |
|
| 253 |
|
| 254 |
+
def _llama_chat_complete(messages: list[dict]) -> str:
|
| 255 |
+
"""Call llama-server /v1/chat/completions with a messages list."""
|
| 256 |
resp = httpx.post(
|
| 257 |
+
f"{LLAMA_SERVER_URL}/v1/chat/completions",
|
| 258 |
json={
|
| 259 |
+
"messages": messages,
|
| 260 |
"n_predict": LLAMA_MAX_TOKENS,
|
| 261 |
"temperature": LLAMA_TEMPERATURE,
|
| 262 |
+
"chat_template_kwargs": {"enable_thinking": False},
|
| 263 |
},
|
| 264 |
timeout=60,
|
| 265 |
)
|
| 266 |
if resp.status_code != 200:
|
| 267 |
logger.error("llama-server %s: %s", resp.status_code, resp.text[:500])
|
| 268 |
resp.raise_for_status()
|
| 269 |
+
return resp.json()["choices"][0]["message"]["content"]
|
| 270 |
|
| 271 |
|
| 272 |
_PLACES_SYSTEM_PROMPT = (
|
|
|
|
| 281 |
Uses the same prompt format the model was trained on.
|
| 282 |
Returns a PlacesResult; falls back to an empty result on parse failure.
|
| 283 |
"""
|
| 284 |
+
messages = [
|
| 285 |
+
{"role": "system", "content": _PLACES_SYSTEM_PROMPT},
|
| 286 |
+
{"role": "user", "content": user_query},
|
| 287 |
+
]
|
| 288 |
+
raw_output = _llama_chat_complete(messages).strip()
|
| 289 |
|
| 290 |
# Strip markdown fences if the model wrapped the JSON
|
| 291 |
if raw_output.startswith("```"):
|
|
|
|
| 321 |
question=user_query.strip(),
|
| 322 |
)
|
| 323 |
|
| 324 |
+
messages = [
|
| 325 |
+
{"role": "system", "content": _SYSTEM_PROMPT},
|
| 326 |
+
{"role": "user", "content": user_prompt},
|
| 327 |
+
]
|
| 328 |
+
raw_output = _llama_chat_complete(messages)
|
| 329 |
return _postprocess_sql(raw_output)
|
src/gazet/search.py
CHANGED
|
@@ -5,67 +5,6 @@ from .config import DIVISIONS_AREA_PATH, NATURAL_EARTH_PATH
|
|
| 5 |
from .schemas import Place
|
| 6 |
|
| 7 |
|
| 8 |
-
def _fuzzy_search(
|
| 9 |
-
con: duckdb.DuckDBPyConnection,
|
| 10 |
-
path: str,
|
| 11 |
-
source: str,
|
| 12 |
-
place: Place,
|
| 13 |
-
extra_select: str = "",
|
| 14 |
-
limit: int = 5,
|
| 15 |
-
is_overture: bool = False,
|
| 16 |
-
) -> pd.DataFrame:
|
| 17 |
-
"""Generic Levenshtein fuzzy search against any parquet with a names.primary column."""
|
| 18 |
-
country_filter = ""
|
| 19 |
-
country_params: list = []
|
| 20 |
-
if is_overture and place.country:
|
| 21 |
-
country_filter = "AND country = ?"
|
| 22 |
-
country_params = [place.country]
|
| 23 |
-
|
| 24 |
-
subtype_filter = ""
|
| 25 |
-
subtype_params: list = []
|
| 26 |
-
if is_overture and place.subtype:
|
| 27 |
-
subtype_filter = "AND subtype = ?"
|
| 28 |
-
subtype_params = [place.subtype]
|
| 29 |
-
|
| 30 |
-
params = (
|
| 31 |
-
[place.place, place.place, path] + country_params + subtype_params + [limit]
|
| 32 |
-
)
|
| 33 |
-
|
| 34 |
-
extra_clause = f", {extra_select}" if extra_select else ""
|
| 35 |
-
rel = con.execute(
|
| 36 |
-
f"""
|
| 37 |
-
SELECT
|
| 38 |
-
id,
|
| 39 |
-
names."primary" AS name,
|
| 40 |
-
country,
|
| 41 |
-
subtype,
|
| 42 |
-
class,
|
| 43 |
-
region,
|
| 44 |
-
admin_level,
|
| 45 |
-
is_land,
|
| 46 |
-
is_territorial{extra_clause},
|
| 47 |
-
1.0 - (levenshtein(lower(names."primary"), lower(?))::float
|
| 48 |
-
/ greatest(length(names."primary"), length(?), 1)) AS similarity
|
| 49 |
-
FROM read_parquet(?)
|
| 50 |
-
WHERE names."primary" IS NOT NULL AND trim(names."primary") != ''
|
| 51 |
-
{country_filter}
|
| 52 |
-
{subtype_filter}
|
| 53 |
-
ORDER BY similarity DESC, admin_level ASC
|
| 54 |
-
LIMIT ?
|
| 55 |
-
""",
|
| 56 |
-
params,
|
| 57 |
-
)
|
| 58 |
-
df = rel.fetchdf()
|
| 59 |
-
df.insert(0, "source", source)
|
| 60 |
-
label = f'"{place.place}"' + (f" [{place.country}]" if place.country else "")
|
| 61 |
-
if df.empty:
|
| 62 |
-
print(f"\n{source} - {label}: no matches")
|
| 63 |
-
else:
|
| 64 |
-
print(f"\n{source} - {label} (top {len(df)} by name similarity):")
|
| 65 |
-
print(df.to_string(index=False))
|
| 66 |
-
return df
|
| 67 |
-
|
| 68 |
-
|
| 69 |
def simple_fuzzy_search(
|
| 70 |
con: duckdb.DuckDBPyConnection,
|
| 71 |
path: str,
|
|
@@ -138,22 +77,14 @@ def search_natural_earth(
|
|
| 138 |
def search_candidates(
|
| 139 |
con: duckdb.DuckDBPyConnection, place: Place, limit: int = 5
|
| 140 |
) -> list[pd.DataFrame]:
|
| 141 |
-
"""Return candidate DataFrames for a place
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
If no subtype is known, search both sources (handles seas, oceans, terrain).
|
| 146 |
"""
|
| 147 |
results = []
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
df = search_divisions_area(con, place, limit=limit)
|
| 151 |
if not df.empty:
|
| 152 |
results.append(df)
|
| 153 |
-
else:
|
| 154 |
-
# Ambiguous — could be physical feature or admin; search both
|
| 155 |
-
for fn in (search_divisions_area, search_natural_earth):
|
| 156 |
-
df = fn(con, place, limit=limit)
|
| 157 |
-
if not df.empty:
|
| 158 |
-
results.append(df)
|
| 159 |
return results
|
|
|
|
| 5 |
from .schemas import Place
|
| 6 |
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
def simple_fuzzy_search(
|
| 9 |
con: duckdb.DuckDBPyConnection,
|
| 10 |
path: str,
|
|
|
|
| 77 |
def search_candidates(
|
| 78 |
con: duckdb.DuckDBPyConnection, place: Place, limit: int = 5
|
| 79 |
) -> list[pd.DataFrame]:
|
| 80 |
+
"""Return candidate DataFrames for a place from both sources.
|
| 81 |
|
| 82 |
+
Always searches divisions_area and natural_earth to avoid missing
|
| 83 |
+
natural features when the model assigns an incorrect admin subtype.
|
|
|
|
| 84 |
"""
|
| 85 |
results = []
|
| 86 |
+
for fn in (search_divisions_area, search_natural_earth):
|
| 87 |
+
df = fn(con, place, limit=limit)
|
|
|
|
| 88 |
if not df.empty:
|
| 89 |
results.append(df)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
return results
|