Spaces:
Running
Running
Add scripts to train & infer model on modal
Browse files- finetune/__init__.py +1 -0
- finetune/check_token_lengths.py +173 -0
- finetune/config.py +76 -0
- finetune/data.py +45 -0
- finetune/eval_demo.py +329 -0
- finetune/infer_modal.py +236 -0
- finetune/prompts.py +83 -0
- finetune/train_modal.py +279 -0
finetune/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
finetune/check_token_lengths.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Check token lengths of training samples to validate max_length setting.
|
| 2 |
+
|
| 3 |
+
Usage
|
| 4 |
+
-----
|
| 5 |
+
modal run finetune/check_token_lengths.py \
|
| 6 |
+
--train-jsonl /data/train.jsonl \
|
| 7 |
+
--val-jsonl /data/val.jsonl \
|
| 8 |
+
--base-model google/gemma-3-270m-it
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import modal
|
| 14 |
+
|
| 15 |
+
app = modal.App("gazet-check-token-lengths")
|
| 16 |
+
|
| 17 |
+
check_image = (
|
| 18 |
+
modal.Image.debian_slim(python_version="3.11")
|
| 19 |
+
.pip_install(
|
| 20 |
+
"datasets>=3.0",
|
| 21 |
+
"pandas>=2.2",
|
| 22 |
+
"transformers>=4.46",
|
| 23 |
+
)
|
| 24 |
+
.add_local_python_source("finetune", copy=True)
|
| 25 |
+
.env({"HF_HOME": "/mnt/gazet/model_cache"})
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
gazet_vol = modal.Volume.from_name("gazet", create_if_missing=True)
|
| 29 |
+
|
| 30 |
+
VOLUMES = {
|
| 31 |
+
"/mnt/gazet": gazet_vol,
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@app.function(
    image=check_image,
    volumes=VOLUMES,
    secrets=[modal.Secret.from_name("huggingface-secret")],
)
def analyze_token_lengths(
    train_jsonl: str,
    val_jsonl: str | None,
    base_model: str,
    schema_file: str | None = None,
):
    """Print token-length statistics for the SFT-formatted samples.

    Loads the tokenizer for ``base_model``, formats the train (and optional
    val) JSONL splits into prompt/completion pairs, tokenizes
    ``prompt + completion`` for every row and prints per-split statistics,
    combined statistics and a ``max_length`` recommendation.

    Args:
        train_jsonl: Path to the training JSONL file.
        val_jsonl: Optional path to the validation JSONL file.
        base_model: Model id or path passed to ``AutoTokenizer.from_pretrained``.
        schema_file: Optional text file overriding ``DEFAULT_SCHEMA_DETAILS``.

    Returns:
        None. All results are printed.
    """
    from transformers import AutoTokenizer
    from finetune.data import format_dataset_for_sft, load_jsonl_splits, read_text
    from finetune.prompts import DEFAULT_SCHEMA_DETAILS

    print(f"Loading tokenizer: {base_model}")
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    print(f"Loading dataset from {train_jsonl}")
    schema_details = read_text(schema_file, DEFAULT_SCHEMA_DETAILS)
    ds = load_jsonl_splits(train_jsonl, val_jsonl)
    formatted = format_dataset_for_sft(ds, schema_details)

    def compute_lengths(split_name: str, dataset):
        """Print statistics for one split and return its sorted token lengths."""
        print(f"\n{'='*60}")
        print(f"Analyzing {split_name} split ({len(dataset)} samples)")
        print(f"{'='*60}")

        lengths = sorted(
            len(tokenizer.encode(row["prompt"] + row["completion"]))
            for row in dataset
        )
        n = len(lengths)
        if n == 0:
            # min()/max() and the percentage divisions below are undefined
            # for an empty split, so report it and bail out early.
            print("\nNo samples in this split; skipping statistics.")
            return lengths

        print(f"\nToken length statistics:")
        print(f" Samples: {n:,}")
        print(f" Min: {min(lengths):,}")
        print(f" Max: {max(lengths):,}")
        print(f" Mean: {sum(lengths)/n:.0f}")
        print(f" Median: {lengths[n//2]:,}")
        # int(n * q) < n for q < 1, so these indices are always in range.
        print(f" P90: {lengths[int(n*0.90)]:,}")
        print(f" P95: {lengths[int(n*0.95)]:,}")
        print(f" P99: {lengths[int(n*0.99)]:,}")

        buckets = [
            (512, "0-512"),
            (1024, "513-1024"),
            (2048, "1025-2048"),
            (4096, "2049-4096"),
            (8192, "4097-8192"),
            (float("inf"), "8193+"),
        ]

        print(f"\nFrequency distribution:")
        prev_limit = 0
        for limit, label in buckets:
            count = sum(1 for length in lengths if prev_limit < length <= limit)
            pct = 100 * count / n
            bar = "█" * int(pct / 2)  # one block per 2 percentage points
            print(f" {label:>12}: {count:5,} ({pct:5.1f}%) {bar}")
            prev_limit = limit

        thresholds = [1024, 2048, 4096, 8192]
        print(f"\nSamples exceeding thresholds:")
        for threshold in thresholds:
            count = sum(1 for length in lengths if length > threshold)
            pct = 100 * count / n
            print(f" > {threshold:5,}: {count:5,} ({pct:5.1f}%)")

        return lengths

    train_lengths = compute_lengths("train", formatted["train"])

    if "val" in formatted:
        val_lengths = compute_lengths("val", formatted["val"])
    else:
        val_lengths = []

    all_lengths = sorted(train_lengths + val_lengths)
    n = len(all_lengths)
    if n == 0:
        # Nothing was tokenized, so any "all samples fit" advice would be
        # misleading; say so explicitly instead.
        print("\nNo samples found; cannot recommend a max_length.")
        print()
        return

    print(f"\n{'='*60}")
    print(f"COMBINED STATISTICS")
    print(f"{'='*60}")
    print(f" Total samples: {n:,}")
    print(f" Max length: {max(all_lengths):,}")
    print(f" P99: {all_lengths[int(n*0.99)]:,}")

    for threshold in [1024, 2048, 4096]:
        count = sum(1 for length in all_lengths if length > threshold)
        pct = 100 * count / n
        # Anything above 2048 tokens is flagged; samples over 4096 exceed
        # that limit too, so they must not be reported as OK.
        status = "⚠️ WARNING" if count > 0 and threshold >= 2048 else "✓ OK"
        print(f" > {threshold:5,}: {count:5,} ({pct:5.1f}%) {status}")

    print(f"\n{'='*60}")
    print("RECOMMENDATIONS")
    print(f"{'='*60}")

    max_len = max(all_lengths)
    over_2048 = sum(1 for length in all_lengths if length > 2048)

    if max_len <= 1024:
        print("✓ All samples fit within 1024 tokens")
        print(" Recommended max_length: 1024")
    elif max_len <= 2048:
        print("✓ All samples fit within 2048 tokens")
        print(" Recommended max_length: 2048")
    elif over_2048 < n * 0.01:
        print(f"⚠️ {over_2048} samples ({100*over_2048/n:.1f}%) exceed 2048 tokens")
        print(" Options:")
        print(" 1. Keep max_length=2048 (truncates <1% of samples)")
        print(" 2. Increase to max_length=4096 (uses more GPU memory)")
        print(" 3. Reduce candidate rows in preprocessing")
    else:
        print(f"⚠️ {over_2048} samples ({100*over_2048/n:.1f}%) exceed 2048 tokens")
        print(f" Recommended max_length: {max_len} (or reduce candidate rows)")

    print()
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
@app.local_entrypoint()
def main(
    train_jsonl: str = "/mnt/gazet/data/output/train.jsonl",
    val_jsonl: str | None = "/mnt/gazet/data/output/val.jsonl",
    base_model: str = "google/gemma-3-270m-it",
    schema_file: str | None = None,
):
    """Local entrypoint: echo the run settings, then run the remote analysis.

    Args:
        train_jsonl: Path to the training JSONL file.
        val_jsonl: Optional path to the validation JSONL file.
        base_model: Model id or path whose tokenizer is used.
        schema_file: Optional schema text file forwarded to the analysis.
    """
    banner = [
        f"Checking token lengths for:",
        f" Model: {base_model}",
        f" Train: {train_jsonl}",
    ]
    if val_jsonl:
        banner.append(f" Val: {val_jsonl}")
    for line in banner:
        print(line)

    # Blocks until the Modal function has finished.
    analyze_token_lengths.remote(train_jsonl, val_jsonl, base_model, schema_file)
    print("Analysis complete!")
|
finetune/config.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training configuration for text-to-SQL LoRA finetuning."""
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from typing import List, Optional
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
LORA_TARGET_MODULES = [
|
| 9 |
+
"q_proj",
|
| 10 |
+
"k_proj",
|
| 11 |
+
"v_proj",
|
| 12 |
+
"o_proj",
|
| 13 |
+
"gate_proj",
|
| 14 |
+
"up_proj",
|
| 15 |
+
"down_proj",
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class TrainingConfig:
    """Settings for one text-to-SQL LoRA fine-tuning run.

    Every field has a default, so ``TrainingConfig()`` is a complete
    configuration. When ``experiment_name`` is left as ``None`` it is
    generated from the model name, the LoRA rank and the current local time.
    """

    # --- Base model ---
    base_model: str = "google/gemma-3-270m-it"

    # --- Data: JSONL paths under the /mnt/gazet mount ---
    train_jsonl: str = "/mnt/gazet/data/output/train.jsonl"
    val_jsonl: Optional[str] = "/mnt/gazet/data/output/val.jsonl"
    test_jsonl: Optional[str] = "/mnt/gazet/data/output/test.jsonl"
    schema_file: Optional[str] = None
    max_train_samples: Optional[int] = None
    max_eval_samples: Optional[int] = None

    # --- LoRA adapter ---
    lora_r: int = 16
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    # A fresh copy per instance so one config cannot mutate the shared list.
    target_modules: List[str] = field(default_factory=lambda: list(LORA_TARGET_MODULES))

    # --- Optimisation ---
    num_train_epochs: int = 2
    per_device_train_batch_size: int = 12
    per_device_eval_batch_size: int = 12
    gradient_accumulation_steps: int = 2
    gradient_checkpointing: bool = True
    optim: str = "adamw_torch_fused"
    learning_rate: float = 1e-4
    max_grad_norm: float = 0.7
    # NOTE(review): if these values are handed to Hugging Face transformers,
    # its "constant" scheduler does not use warmup_steps; "constant_with_warmup"
    # would be needed for the warmup to take effect. Confirm the intent.
    warmup_steps: int = 50
    lr_scheduler_type: str = "constant"
    weight_decay: float = 0.0
    packing: bool = False
    max_length: int = 2048

    # --- Logging, checkpointing and evaluation cadence ---
    logging_steps: int = 10
    save_strategy: str = "steps"
    save_steps: int = 300
    eval_strategy: str = "steps"
    eval_steps: int = 100
    report_to: str = "trackio"
    trackio_space_id: Optional[str] = "srmsoumya/gazet-trackio"
    project: str = "gazet-nlg"

    # --- SFT options ---
    completion_only_loss: bool = True
    dataset_num_proc: Optional[int] = 8

    # --- Experiment bookkeeping ---
    seed: int = 42
    experiment_name: Optional[str] = None
    merge_after_training: bool = True

    def __post_init__(self):
        """Fill in ``experiment_name`` when the caller did not supply one."""
        if self.experiment_name is not None:
            return
        # Last path component of the model id, e.g. "org/name" -> "name".
        short_name = self.base_model.rsplit("/", 1)[-1]
        stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        self.experiment_name = f"{short_name}-r{self.lora_r}-{stamp}"
|
finetune/data.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dataset loading and SFT formatting for text-to-SQL finetuning."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Dict, Optional
|
| 8 |
+
|
| 9 |
+
from datasets import DatasetDict, load_dataset
|
| 10 |
+
|
| 11 |
+
from finetune.prompts import DEFAULT_SCHEMA_DETAILS, make_prompt_completion
|
| 12 |
+
|
| 13 |
+
LOGGER = logging.getLogger("nlg.data")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def read_text(path: Optional[str], default: str) -> str:
    """Return the UTF-8 contents of ``path``, or ``default`` if no path is given.

    Args:
        path: File to read; ``None`` or an empty string selects ``default``.
        default: Text returned when ``path`` is falsy.
    """
    return Path(path).read_text(encoding="utf-8") if path else default
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def load_jsonl_splits(
    train_jsonl: str,
    val_jsonl: Optional[str] = None,
    test_jsonl: Optional[str] = None,
) -> DatasetDict:
    """Load the JSONL files into a ``DatasetDict`` keyed by split name.

    The ``train`` split is always present; ``val`` and ``test`` are added
    only when a non-empty path is supplied for them.
    """
    optional_splits = {"val": val_jsonl, "test": test_jsonl}
    data_files: Dict[str, str] = {"train": train_jsonl}
    data_files.update(
        {name: path for name, path in optional_splits.items() if path}
    )
    LOGGER.info("Loading dataset splits: %s", data_files)
    return load_dataset("json", data_files=data_files)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def format_dataset_for_sft(
    dataset: DatasetDict,
    schema_details: str = DEFAULT_SCHEMA_DETAILS,
) -> DatasetDict:
    """Map every split through ``make_prompt_completion``.

    Args:
        dataset: Splits to convert.
        schema_details: Schema text passed to ``make_prompt_completion`` for
            each row.

    Returns:
        A new ``DatasetDict`` with the same split names.
    """

    def _to_prompt_completion(row):
        return make_prompt_completion(row, schema_details)

    result = DatasetDict()
    for split_name, split_ds in dataset.items():
        result[split_name] = split_ds.map(_to_prompt_completion)
    return result
|
finetune/eval_demo.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Streamlit eval viewer: compare expected vs predicted SQL and view results on a map.
|
| 2 |
+
|
| 3 |
+
Usage: streamlit run finetune/eval_demo.py
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import difflib
|
| 7 |
+
import json
|
| 8 |
+
import math
|
| 9 |
+
import os
|
| 10 |
+
import pathlib
|
| 11 |
+
|
| 12 |
+
import duckdb
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pandas as pd
|
| 15 |
+
import pydeck as pdk
|
| 16 |
+
import sqlparse
|
| 17 |
+
import streamlit as st
|
| 18 |
+
|
| 19 |
+
PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
| 20 |
+
DATA_DIR = pathlib.Path(
|
| 21 |
+
os.environ.get("GAZET_DATA_DIR", str(PROJECT_ROOT / "data"))
|
| 22 |
+
)
|
| 23 |
+
EVAL_DIR = PROJECT_ROOT / "data" / "eval_results"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def load_eval_results(path):
    """Load one eval results JSON file and return the parsed object.

    The file is opened as UTF-8 explicitly so that non-ASCII questions or SQL
    literals decode the same way regardless of the platform's default locale
    encoding.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def rewrite_data_paths(sql):
    """Point every ``/data/`` occurrence in ``sql`` at the local data directory."""
    local_prefix = f"{DATA_DIR}/"
    return sql.replace("/data/", local_prefix)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def format_sql(sql):
    """Return ``sql`` re-indented by sqlparse with upper-cased keywords."""
    options = {"reindent": True, "keyword_case": "upper"}
    return sqlparse.format(sql, **options)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def sql_diff_html(expected, predicted):
    """Build an HTML table diffing the pretty-printed expected and predicted SQL.

    Both statements go through ``format_sql`` first, and the full text is
    shown (``context=False``).
    """
    left = format_sql(expected).splitlines()
    right = format_sql(predicted).splitlines()
    differ = difflib.HtmlDiff(tabsize=2, wrapcolumn=80)
    return differ.make_table(
        left,
        right,
        fromdesc="Expected",
        todesc="Predicted",
        context=False,
    )
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def get_duckdb_connection():
    """Open a DuckDB connection with the ``spatial`` extension installed and loaded."""
    connection = duckdb.connect()
    for statement in ("INSTALL spatial", "LOAD spatial"):
        connection.execute(statement)
    return connection
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def execute_sql(con, sql):
    """Run ``sql`` and return a DataFrame with geometries as GeoJSON strings.

    The statement is first inspected through ``con.sql`` to learn its column
    names and types. It is then wrapped in an outer ``SELECT`` that converts
    every column whose type name contains ``GEOMETRY`` into simplified
    GeoJSON text (tolerance ``0.001``), leaving other columns untouched.

    Args:
        con: Connection object exposing ``sql`` and ``execute``.
        sql: A single query; trailing semicolons and whitespace are ignored.

    Returns:
        The result of ``fetchdf()`` on the wrapped query.

    Raises:
        ValueError: If the statement produces no result set.
    """
    # A trailing ';' is common in generated SQL but is a syntax error once
    # the statement is embedded as a subquery.
    inner_sql = sql.strip().rstrip("; \t\r\n")

    rel = con.sql(inner_sql)
    if rel is None:
        raise ValueError("SQL statement did not produce a result set")

    select_parts = []
    for col, dtype in zip(rel.columns, (str(t) for t in rel.dtypes)):
        # Double any embedded quote so the name stays one quoted identifier.
        ident = '"' + str(col).replace('"', '""') + '"'
        if "GEOMETRY" in dtype.upper():
            select_parts.append(
                f"ST_AsGeoJSON(ST_SimplifyPreserveTopology({ident}, 0.001)) AS {ident}"
            )
        else:
            select_parts.append(ident)

    # The newlines keep a trailing "-- comment" in the inner query from
    # swallowing the closing parenthesis.
    wrapped = f"SELECT {', '.join(select_parts)} FROM (\n{inner_sql}\n)"
    return con.execute(wrapped).fetchdf()
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _is_notna(val):
|
| 80 |
+
"""Check if a value is not NA, handling arrays/lists/numpy arrays safely."""
|
| 81 |
+
if isinstance(val, (list, tuple, np.ndarray)):
|
| 82 |
+
return len(val) > 0
|
| 83 |
+
return pd.notna(val)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _to_python(val):
|
| 87 |
+
"""Convert numpy/pandas types to native Python for JSON serialization."""
|
| 88 |
+
if isinstance(val, (np.integer,)):
|
| 89 |
+
return int(val)
|
| 90 |
+
if isinstance(val, (np.floating,)):
|
| 91 |
+
return float(val)
|
| 92 |
+
if isinstance(val, np.ndarray):
|
| 93 |
+
return val.tolist()
|
| 94 |
+
if isinstance(val, (np.bool_,)):
|
| 95 |
+
return bool(val)
|
| 96 |
+
return val
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def to_feature_collection(result_df):
    """Turn a query result into a GeoJSON ``FeatureCollection`` dict.

    A column is treated as geometry when, among its first five rows, there is
    at least one string and every string starts with ``{"type":``. The first
    such column supplies each feature's geometry; columns not detected as
    geometry become properties, with missing values omitted.
    """

    def _holds_geojson(column):
        sample = [v for v in result_df[column].head(5) if isinstance(v, str)]
        return bool(sample) and all(
            v.lstrip().startswith('{"type":') for v in sample
        )

    geom_cols = [c for c in result_df.columns if _holds_geojson(c)]
    prop_cols = [c for c in result_df.columns if c not in geom_cols]

    features = []
    for _, row in result_df.iterrows():
        geometry = None
        if geom_cols:
            raw = row[geom_cols[0]]
            if raw and isinstance(raw, str):
                geometry = json.loads(raw)

        properties = {
            c: _to_python(row[c]) for c in prop_cols if _is_notna(row[c])
        }
        features.append(
            {"type": "Feature", "geometry": geometry, "properties": properties}
        )

    return {"type": "FeatureCollection", "features": features}
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def bbox_from_geojson(geojson):
    """Return ``(min_lng, min_lat, max_lng, max_lat)`` over all features.

    Features without a geometry are ignored. Returns ``None`` when no
    coordinate was found.
    """
    xs, ys = [], []
    geometries = (feat.get("geometry") for feat in geojson.get("features", []))
    for geometry in geometries:
        if not geometry:
            continue
        for position in _extract_coords(geometry):
            xs.append(position[0])
            ys.append(position[1])

    if xs:
        return min(xs), min(ys), max(xs), max(ys)
    return None
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _extract_coords(geom):
|
| 140 |
+
t = geom.get("type", "")
|
| 141 |
+
coords = geom.get("coordinates", [])
|
| 142 |
+
if t == "Point":
|
| 143 |
+
yield coords
|
| 144 |
+
elif t in ("LineString", "MultiPoint"):
|
| 145 |
+
yield from coords
|
| 146 |
+
elif t == "Polygon":
|
| 147 |
+
for ring in coords:
|
| 148 |
+
yield from ring
|
| 149 |
+
elif t in ("MultiLineString", "MultiPolygon"):
|
| 150 |
+
for part in coords:
|
| 151 |
+
if t == "MultiLineString":
|
| 152 |
+
yield from part
|
| 153 |
+
else:
|
| 154 |
+
for ring in part:
|
| 155 |
+
yield from ring
|
| 156 |
+
elif t == "GeometryCollection":
|
| 157 |
+
for g in geom.get("geometries", []):
|
| 158 |
+
yield from _extract_coords(g)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def _centroids_from_geojson(geojson):
    """Return one ``{"lng", "lat"}`` marker per feature that has coordinates.

    The marker is the arithmetic mean of all the feature's coordinate
    positions; features with no geometry or no coordinates are skipped.
    """
    markers = []
    for feature in geojson.get("features", []):
        geometry = feature.get("geometry")
        if not geometry:
            continue

        xs, ys = [], []
        for position in _extract_coords(geometry):
            xs.append(position[0])
            ys.append(position[1])
        if not xs:
            continue

        markers.append({"lng": sum(xs) / len(xs), "lat": sum(ys) / len(ys)})
    return markers
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def render_map(geojson, color, key):
    """Draw ``geojson`` as a pydeck chart in the current Streamlit container.

    Args:
        geojson: FeatureCollection dict to display.
        color: Fill colour as a list ``[r, g, b, a]``.
        key: Streamlit element key for the chart.

    Shows an info message instead of a map when there are no features. The
    view is centred on the data's bounding box; when the computed zoom is
    below 4, a scatter layer with one marker per feature is added on top.
    """
    if not len(geojson.get("features", [])):
        st.info("Query returned no features.")
        return

    shape_layer = pdk.Layer(
        "GeoJsonLayer",
        data=geojson,
        get_fill_color=color,
        get_line_color=[100, 100, 100, 200],
        get_line_width=2,
        pickable=True,
    )
    layers = [shape_layer]

    bbox = bbox_from_geojson(geojson)
    if bbox is None:
        # No coordinates at all: fall back to a world view.
        view = pdk.ViewState(latitude=0, longitude=0, zoom=1)
    else:
        west, south, east, north = bbox
        # Floor the extent so a single point does not divide by zero.
        extent = max(east - west, north - south, 1e-6)
        zoom = max(0, min(18, math.log2(360 / extent) - 0.8))

        # Add scatter markers when polygons would be too small to see
        if zoom < 4:
            markers = _centroids_from_geojson(geojson)
            if markers:
                marker_layer = pdk.Layer(
                    "ScatterplotLayer",
                    data=markers,
                    get_position=["lng", "lat"],
                    get_fill_color=color[:3] + [220],
                    get_radius=50000,
                    radius_min_pixels=6,
                    pickable=True,
                )
                layers.append(marker_layer)

        view = pdk.ViewState(
            latitude=(south + north) / 2,
            longitude=(west + east) / 2,
            zoom=zoom,
        )

    deck = pdk.Deck(layers=layers, initial_view_state=view, map_style=None)
    st.pydeck_chart(deck, width="stretch", height=400, key=key)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# --- App ---
|
| 233 |
+
|
| 234 |
+
st.set_page_config(page_title="Eval Viewer", layout="wide")
st.title("Eval Viewer")

# Pick one of the eval-*.json files from the eval results directory.
eval_files = sorted(EVAL_DIR.glob("eval-*.json"))
if not eval_files:
    st.error(f"No eval result files found in {EVAL_DIR}")
    st.stop()

selected_file = st.sidebar.selectbox(
    "Eval file",
    eval_files,
    format_func=lambda p: p.stem,
)

data = load_eval_results(selected_file)
summary = data["summary"]
results = data["results"]

st.sidebar.markdown(f"""
**Model**: `{summary.get('label', '')}`
**Exact match**: {summary['exact_matches']}/{summary['num_samples']} ({summary['exact_match_rate']:.1%})
""")

# Optional filtering by exact-match outcome; "All" keeps every result.
filter_option = st.sidebar.radio("Filter", ["All", "Matches only", "Mismatches only"])
keep_result = {
    "Matches only": lambda r: r["exact_match"],
    "Mismatches only": lambda r: not r["exact_match"],
}.get(filter_option)
if keep_result is not None:
    results = [r for r in results if keep_result(r)]

if not results:
    st.warning("No results match the current filter.")
    st.stop()

questions = [f"[{r['index']}] {r['question']}" for r in results]
selected_idx = st.selectbox("Select a query", range(len(questions)), format_func=lambda i: questions[i])
row = results[selected_idx]

if row["exact_match"]:
    match_label, match_color = "MATCH", "green"
else:
    match_label, match_color = "MISMATCH", "red"
st.markdown(f"### :{match_color}[{match_label}]")

# Formatted SQL side-by-side
col_expected, col_predicted = st.columns(2)
sql_panels = (
    (col_expected, "**Expected SQL**", "expected_sql"),
    (col_predicted, "**Predicted SQL**", "predicted_sql"),
)
for column, heading, sql_key in sql_panels:
    with column:
        st.markdown(heading)
        st.code(format_sql(row[sql_key]), language="sql")

# Diff view (only shown for mismatches)
diff_css = """
<style>
.diff_add { background-color: rgba(40, 167, 69, 0.15); }
.diff_sub { background-color: rgba(220, 53, 69, 0.15); }
.diff_chg { background-color: rgba(255, 193, 7, 0.15); }
.diff_header { background-color: rgba(128, 128, 128, 0.1); font-weight: bold; }
table.diff { border-collapse: collapse; width: 100%; font-family: monospace; color: inherit; }
table.diff td, table.diff th { padding: 4px 8px; border: 1px solid rgba(128, 128, 128, 0.2); }
</style>
"""
if not row["exact_match"]:
    with st.expander("SQL Diff", expanded=True):
        diff_html = sql_diff_html(row["expected_sql"], row["predicted_sql"])
        st.html(f"{diff_css}<div style='overflow-x:auto; font-size:13px;'>{diff_html}</div>")

# Auto-execute both SQLs and show maps
con = get_duckdb_connection()

map_col1, map_col2 = st.columns(2)
map_panels = (
    (map_col1, "**Expected result**", "expected_sql", [40, 180, 160, 140], "map_expected"),
    (map_col2, "**Predicted result**", "predicted_sql", [180, 80, 60, 140], "map_predicted"),
)
for column, heading, sql_key, fill_color, chart_key in map_panels:
    with column:
        st.markdown(heading)
        sql = rewrite_data_paths(row[sql_key])
        try:
            df = execute_sql(con, sql)
            geojson = to_feature_collection(df)
            render_map(geojson, fill_color, key=chart_key)
            with st.expander("Result table"):
                st.dataframe(df, width="stretch")
        except Exception as e:
            # One failing query must not hide the other panel.
            st.error(f"Execution error: {e}")

con.close()
|
finetune/infer_modal.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Modal eval script: run a model on the test set and save results.
|
| 2 |
+
|
| 3 |
+
Usage
|
| 4 |
+
-----
|
| 5 |
+
# Eval finetuned model (uses raw prompt-completion format):
|
| 6 |
+
modal run finetune/infer_modal.py --label finetuned
|
| 7 |
+
|
| 8 |
+
# Eval base model (uses chat template so the model understands the instruction):
|
| 9 |
+
modal run finetune/infer_modal.py \
|
| 10 |
+
--model-path google/gemma-3-270m-it \
|
| 11 |
+
--label base \
|
| 12 |
+
--use-chat-template
|
| 13 |
+
|
| 14 |
+
# Limit samples:
|
| 15 |
+
modal run finetune/infer_modal.py --max-samples 50
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import json
|
| 21 |
+
import pathlib
|
| 22 |
+
from datetime import datetime
|
| 23 |
+
from typing import Optional
|
| 24 |
+
|
| 25 |
+
import modal
|
| 26 |
+
|
| 27 |
+
app = modal.App("gazet-nlg-eval")
|
| 28 |
+
|
| 29 |
+
infer_image = (
|
| 30 |
+
modal.Image.debian_slim(python_version="3.11")
|
| 31 |
+
.pip_install(
|
| 32 |
+
"accelerate>=1.0",
|
| 33 |
+
"pandas>=2.2",
|
| 34 |
+
"torch>=2.4",
|
| 35 |
+
"transformers>=4.46",
|
| 36 |
+
)
|
| 37 |
+
.add_local_python_source("finetune", copy=True)
|
| 38 |
+
.env({"HF_HOME": "/mnt/gazet/model_cache"})
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
gazet_vol = modal.Volume.from_name("gazet", create_if_missing=True)
|
| 42 |
+
|
| 43 |
+
VOLUMES = {
|
| 44 |
+
"/mnt/gazet": gazet_vol,
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
DEFAULT_MODEL_PATH = "/mnt/gazet/checkpoints/gemma-3-270m-it-r16-20260331-134642/merged"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def postprocess_sql(text: str) -> str:
    """Extract the SQL statement from raw model output.

    Keeps the text that follows the first ```` ```sql ```` fence tag (matched
    case-insensitively, so ```` ```SQL ```` works too), drops a bare leading
    fence, cuts everything from the next fence onwards and strips surrounding
    whitespace. Text without any fence is returned stripped.
    """
    import re

    cleaned = text.strip()

    # Case-insensitive so "```SQL" does not leave "SQL" in front of the query.
    tagged_fence = re.search(r"```sql", cleaned, flags=re.IGNORECASE)
    if tagged_fence:
        cleaned = cleaned[tagged_fence.end():]
    if cleaned.startswith("```"):
        cleaned = cleaned[3:]

    # partition() returns the whole string when no closing fence exists.
    cleaned = cleaned.partition("```")[0]
    return cleaned.strip()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@app.function(
    image=infer_image,
    gpu="L40S",
    volumes=VOLUMES,
    secrets=[modal.Secret.from_name("huggingface-secret")],
    timeout=60 * 60,
)
def run_eval(
    model_path: str,
    label: str,
    samples: list[dict],
    output_path: str,
    max_new_tokens: int = 512,
    batch_size: int = 16,
    use_chat_template: bool = False,
):
    """Run batched greedy inference on all samples and save results to the volume.

    Args:
        model_path: Path or name passed to ``from_pretrained`` for both the
            tokenizer and the model.
        label: Free-form tag echoed in the logs and stored in the summary.
        samples: Records with ``question`` and ``candidates`` keys and an
            optional ``target`` mapping holding the reference ``sql``.
        output_path: Destination JSON file; parent directories are created.
        max_new_tokens: Generation budget per sample.
        batch_size: Number of prompts per ``generate`` call.
        use_chat_template: When true, the system prompt and user prompt are
            sent as a single user turn rendered through the tokenizer's chat
            template; otherwise they are sent as raw concatenated text.

    The written JSON has a ``summary`` (label, model path, counts, exact-match
    rate, timestamp) and a ``results`` list with one entry per sample.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from finetune.prompts import SYSTEM_PROMPT, build_user_prompt, DEFAULT_SCHEMA_DETAILS

    print(f"Loading model [{label}]: {model_path}")
    print(f"Chat template: {use_chat_template}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Left padding keeps every prompt flush against the generated tokens, so
    # one shared prompt length can be sliced off all rows after generation.
    tokenizer.padding_side = "left"

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        device_map="auto",
    )
    model.eval()

    # Build all prompts upfront
    prompts = []
    for sample in samples:
        user_content = build_user_prompt(
            question=sample["question"],
            candidates=sample["candidates"],
            schema_details=DEFAULT_SCHEMA_DETAILS,
        )
        if use_chat_template:
            messages = [
                {"role": "user", "content": SYSTEM_PROMPT + "\n\n" + user_content},
            ]
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            prompt = SYSTEM_PROMPT + "\n\n" + user_content
        prompts.append(prompt)

    # Batched inference
    all_predictions = []
    num_batches = (len(prompts) + batch_size - 1) // batch_size

    for batch_idx in range(num_batches):
        start = batch_idx * batch_size
        end = min(start + batch_size, len(prompts))
        batch_prompts = prompts[start:end]

        # NOTE(review): prompts longer than 2048 tokens are truncated with the
        # tokenizer's default truncation side -- confirm the eval prompts fit.
        inputs = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
            # Text rendered by the chat template already contains the special
            # tokens the template emits; adding them again would duplicate
            # them (e.g. a second BOS) and shift the model's input.
            add_special_tokens=not use_chat_template,
        ).to(model.device)
        input_len = inputs["input_ids"].shape[1]

        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Drop the (padded) prompt columns and decode only the continuations.
        generations = tokenizer.batch_decode(
            outputs[:, input_len:], skip_special_tokens=True
        )
        all_predictions.extend(postprocess_sql(text) for text in generations)

        print(f"Batch {batch_idx+1}/{num_batches} done ({end}/{len(prompts)} samples)")

    # Build results
    results = []
    matches = 0
    for i, (sample, predicted) in enumerate(zip(samples, all_predictions)):
        expected = sample.get("target", {}).get("sql", "")
        # Exact match is a whitespace-trimmed string comparison.
        is_match = predicted.strip() == expected.strip()
        if is_match:
            matches += 1

        results.append({
            "index": i,
            "question": sample["question"],
            "candidates": sample["candidates"],
            "expected_sql": expected,
            "predicted_sql": predicted,
            "exact_match": is_match,
        })

    total = len(results)
    exact_match_rate = matches / total if total else 0

    output = {
        "summary": {
            "label": label,
            "model_path": model_path,
            "num_samples": total,
            "exact_matches": matches,
            "exact_match_rate": exact_match_rate,
            "timestamp": datetime.now().isoformat(),
        },
        "results": results,
    }

    path = pathlib.Path(output_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    # Persist the file written under the mounted volume.
    gazet_vol.commit()

    print(f"\n{'='*60}")
    print(f"[{label}] {matches}/{total} exact matches ({100*exact_match_rate:.1f}%)")
    print(f"Results saved to {output_path}")
    print(f"{'='*60}")
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
@app.function(
    image=infer_image,
    volumes=VOLUMES,
)
def read_test_data(test_jsonl: str) -> list[dict]:
    """Read test JSONL from the volume.

    Args:
        test_jsonl: Path of a JSON Lines file, one JSON document per line.

    Returns:
        The parsed documents in file order. Blank lines are skipped.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    with open(test_jsonl, encoding="utf-8") as f:
        for line in f:
            # A trailing newline or an accidental empty line must not abort
            # the whole evaluation run.
            if line.strip():
                records.append(json.loads(line))
    return records
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
@app.local_entrypoint()
def main(
    model_path: str = DEFAULT_MODEL_PATH,
    label: str = "finetuned",
    test_jsonl: str = "/mnt/gazet/data/output/test.jsonl",
    max_samples: Optional[int] = None,
    max_new_tokens: int = 512,
    batch_size: int = 16,
    use_chat_template: bool = False,
    output_dir: str = "/mnt/gazet/eval_results",
):
    """Local entrypoint: fetch the test set remotely, then launch the evaluation.

    The samples are read via ``read_test_data``, optionally limited to the
    first ``max_samples`` records, and passed to ``run_eval``, which writes
    ``<output_dir>/eval-<label>.json``.
    """
    print(f"Model: {model_path}")
    print(f"Label: {label}")
    print(f"Chat template: {use_chat_template}")

    print("Loading test data...")
    eval_samples = read_test_data.remote(test_jsonl)
    # A falsy max_samples (None or 0) means "use every sample".
    if max_samples:
        eval_samples = eval_samples[:max_samples]
    print(f"Eval samples: {len(eval_samples)}")

    run_eval.remote(
        model_path=model_path,
        label=label,
        samples=eval_samples,
        output_path=f"{output_dir}/eval-{label}.json",
        max_new_tokens=max_new_tokens,
        batch_size=batch_size,
        use_chat_template=use_chat_template,
    )
|
finetune/prompts.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Prompt templates and message formatting for natural language geocoding."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, Sequence
|
| 6 |
+
|
| 7 |
+
import pandas as pd
|
| 8 |
+
|
| 9 |
+
# Role instruction for the model; `make_prompt_completion` places it in front
# of the user prompt, separated by a blank line.
SYSTEM_PROMPT = (
    "You are a text to SQL query translator that helps in natural language geocoding."
)
|
| 12 |
+
|
| 13 |
+
USER_PROMPT_TEMPLATE = """GIVEN the <SCHEMA_DETAILS>, <CANDIDATES> and <USER_QUERY>, generate the corresponding SQL command to retrieve the desired geometry.
|
| 14 |
+
|
| 15 |
+
<SCHEMA_DETAILS>
|
| 16 |
+
{schema_details}
|
| 17 |
+
</SCHEMA_DETAILS>
|
| 18 |
+
|
| 19 |
+
<CANDIDATES>
|
| 20 |
+
{candidates_csv}
|
| 21 |
+
</CANDIDATES>
|
| 22 |
+
|
| 23 |
+
<USER_QUERY>
|
| 24 |
+
{question}
|
| 25 |
+
</USER_QUERY>
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
DEFAULT_SCHEMA_DETAILS = """1. divisions_area -- Overture polygon/multipolygon admin boundaries
|
| 29 |
+
path: '/data/overture/division_area/*.parquet'
|
| 30 |
+
columns:
|
| 31 |
+
id VARCHAR
|
| 32 |
+
names STRUCT("primary" VARCHAR, ...)
|
| 33 |
+
country VARCHAR
|
| 34 |
+
subtype VARCHAR
|
| 35 |
+
class VARCHAR
|
| 36 |
+
region VARCHAR
|
| 37 |
+
admin_level INTEGER
|
| 38 |
+
division_id VARCHAR
|
| 39 |
+
is_land BOOLEAN
|
| 40 |
+
is_territorial BOOLEAN
|
| 41 |
+
geometry GEOMETRY
|
| 42 |
+
|
| 43 |
+
2. natural_earth -- Natural Earth geography polygons
|
| 44 |
+
path: '/data/natural_earth_geoparquet/ne_geography.parquet'
|
| 45 |
+
columns:
|
| 46 |
+
id VARCHAR
|
| 47 |
+
name VARCHAR
|
| 48 |
+
featurecla VARCHAR
|
| 49 |
+
scalerank INTEGER
|
| 50 |
+
min_zoom DOUBLE
|
| 51 |
+
geometry GEOMETRY"""
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def candidates_to_csv(candidates: Sequence[Dict[str, Any]]) -> str:
    """Serialize candidate records to CSV text (header row, no index column).

    The ``candidate_id`` column, when present, is left out of the output.
    """
    table = pd.DataFrame(list(candidates))
    # errors="ignore" turns the drop into a no-op when the column is absent.
    table = table.drop(columns=["candidate_id"], errors="ignore")
    return table.to_csv(index=False)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def build_user_prompt(
    question: str,
    candidates: Sequence[Dict[str, Any]],
    schema_details: str,
) -> str:
    """Fill USER_PROMPT_TEMPLATE with the schema text, candidate CSV and question.

    Each of the three pieces is whitespace-trimmed before substitution.
    """
    fields = {
        "schema_details": schema_details.strip(),
        "candidates_csv": candidates_to_csv(candidates).strip(),
        "question": question.strip(),
    }
    return USER_PROMPT_TEMPLATE.format(**fields)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def make_prompt_completion(
    sample: Dict[str, Any],
    schema_details: str,
) -> Dict[str, str]:
    """Convert one raw sample into a prompt/completion pair.

    The prompt is SYSTEM_PROMPT, a blank line, then the user prompt built from
    the sample's ``question`` and ``candidates``. The completion is
    ``sample["target"]["sql"]``, or an empty string when either key is missing.
    """
    user_prompt = build_user_prompt(
        question=sample["question"],
        candidates=sample["candidates"],
        schema_details=schema_details,
    )
    target = sample.get("target", {})
    return {
        "prompt": f"{SYSTEM_PROMPT}\n\n{user_prompt}",
        "completion": target.get("sql", ""),
    }
|
finetune/train_modal.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Modal training script for text-to-SQL LoRA finetuning.
|
| 2 |
+
|
| 3 |
+
Usage
|
| 4 |
+
-----
|
| 5 |
+
modal run finetune/train_modal.py \
    --base-model google/gemma-3-1b-it \
    --num-train-epochs 3 \
    --max-train-samples 1000

Each CLI option overrides the TrainingConfig field of the same name. Fields that `main` does not expose (for example the train/val JSONL paths) keep their TrainingConfig defaults. Run with --help for details.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import pathlib
|
| 16 |
+
from typing import Optional
|
| 17 |
+
|
| 18 |
+
import modal
|
| 19 |
+
|
| 20 |
+
# Modal application under which the training function is registered.
app = modal.App("gazet-nlg-finetune")

# Settings for the remote training function (see the @app.function decorator).
GPU_TYPE = "A100-80GB"  # alternative: "L40S"
TIMEOUT_HOURS = 6  # converted to seconds for the function timeout
MAX_RETRIES = 1  # passed to modal.Retries as max_retries

# Container image for training: Debian slim with Python 3.11, the listed
# minimum library versions, the local `finetune` package copied into the image,
# the Hugging Face cache redirected below /mnt/gazet and hf-transfer enabled.
train_image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install(
        "accelerate>=1.0",
        "datasets>=3.0",
        "hf-transfer>=0.1",
        "huggingface_hub>=0.25",
        "jinja2>=3.0",
        "pandas>=2.2",
        "peft>=0.13",
        "torch>=2.4",
        "trackio[gpu]",
        "transformers>=4.46",
        "trl>=0.12",
    )
    .add_local_python_source("finetune", copy=True)
    .env({"HF_HOME": "/mnt/gazet/model_cache", "HF_HUB_ENABLE_HF_TRANSFER": "1"})
)

# Imports scoped to the training image via train_image.imports(); the module
# helpers below use these names directly.
with train_image.imports():
    import torch
    from datasets import DatasetDict
    from peft import LoraConfig
    from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
    from trl import SFTConfig, SFTTrainer

# Named volume "gazet"; created on first use if it does not exist yet.
gazet_vol = modal.Volume.from_name("gazet", create_if_missing=True)

# Mount point -> volume mapping handed to the @app.function decorator below.
VOLUMES = {
    "/mnt/gazet": gazet_vol,
}
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _load_tokenizer(model_name: str):
    """Load the tokenizer for *model_name*, guaranteeing that a pad token is set.

    When the tokenizer defines no pad token, the EOS token is used instead.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    needs_pad_token = tok.pad_token is None
    if needs_pad_token:
        tok.pad_token = tok.eos_token
    return tok
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _load_model(model_name: str):
    """Load *model_name* as a causal LM in bfloat16 with SDPA attention.

    Device placement is delegated to ``device_map="auto"``.
    """
    load_kwargs = {
        "torch_dtype": torch.bfloat16,
        "attn_implementation": "sdpa",
        "device_map": "auto",
    }
    return AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _build_lora_config(config) -> LoraConfig:
    """Translate the LoRA fields of the training config into a peft LoraConfig.

    Rank, alpha, dropout and target modules come from *config*; bias handling
    and task type are fixed to "none" and "CAUSAL_LM".
    """
    lora_kwargs = dict(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=config.target_modules,
    )
    return LoraConfig(**lora_kwargs)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _load_and_format_dataset(config) -> DatasetDict:
    """Load the JSONL splits, format them for SFT and apply the sample caps.

    ``config.max_train_samples`` limits the "train" split and
    ``config.max_eval_samples`` limits the "val" split (when that split
    exists); a cap of ``None`` leaves the split untouched.
    """
    from finetune.data import (
        format_dataset_for_sft,
        load_jsonl_splits,
        read_text,
    )
    from finetune.prompts import DEFAULT_SCHEMA_DETAILS

    def _head(split, limit):
        # Keep the first `limit` rows, or all rows if the split is shorter.
        return split.select(range(min(limit, len(split))))

    schema_details = read_text(config.schema_file, DEFAULT_SCHEMA_DETAILS)
    raw_splits = load_jsonl_splits(
        config.train_jsonl, config.val_jsonl, config.test_jsonl
    )
    formatted = format_dataset_for_sft(raw_splits, schema_details)

    train_cap = config.max_train_samples
    if train_cap is not None:
        formatted["train"] = _head(formatted["train"], train_cap)

    eval_cap = config.max_eval_samples
    if eval_cap is not None and "val" in formatted:
        formatted["val"] = _head(formatted["val"], eval_cap)

    return formatted
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _find_latest_checkpoint(checkpoint_dir: pathlib.Path) -> Optional[str]:
    """Return the path of the highest-numbered ``checkpoint-<step>`` directory.

    Only directories whose name is exactly ``checkpoint-`` followed by digits
    are considered. Other ``checkpoint-*`` entries (for example
    ``checkpoint-best`` or stray files) are ignored instead of raising
    ``ValueError`` during the step-number conversion.

    Args:
        checkpoint_dir: Directory that may contain checkpoint sub-directories.

    Returns:
        The checkpoint path as a string, or ``None`` when the directory does
        not exist or holds no numbered checkpoint.
    """
    import re  # local import keeps this helper self-contained

    if not checkpoint_dir.exists():
        return None

    step_by_path = {}
    for entry in checkpoint_dir.glob("checkpoint-*"):
        numbered = re.fullmatch(r"checkpoint-([0-9]+)", entry.name)
        if numbered and entry.is_dir():
            step_by_path[entry] = int(numbered.group(1))
    if not step_by_path:
        return None

    # Compare by numeric step so that checkpoint-10 outranks checkpoint-2.
    latest = max(step_by_path, key=step_by_path.get)
    print(f"Found existing checkpoint: {latest}")
    return str(latest)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@app.function(
    image=train_image,
    gpu=GPU_TYPE,
    volumes=VOLUMES,
    secrets=[modal.Secret.from_name("huggingface-secret")],
    timeout=TIMEOUT_HOURS * 60 * 60,
    retries=modal.Retries(initial_delay=0.0, max_retries=MAX_RETRIES),
)
def finetune(config_dict: dict):
    """Run LoRA SFT training inside a Modal container.

    Args:
        config_dict: Keyword arguments used to rebuild a ``TrainingConfig``
            (the local entrypoint passes ``config.__dict__``).

    Returns:
        ``config.experiment_name`` of the finished run.

    Outputs are written below ``/mnt/gazet/checkpoints/<experiment_name>``:
    trainer checkpoints, the final adapter and tokenizer and, when
    ``config.merge_after_training`` is set, a merged model in ``merged/``.
    """
    from finetune.config import TrainingConfig

    config = TrainingConfig(**config_dict)
    set_seed(config.seed)

    experiment_dir = pathlib.Path("/mnt/gazet/checkpoints") / config.experiment_name
    experiment_dir.mkdir(parents=True, exist_ok=True)

    print(f"Experiment: {config.experiment_name}")
    print(f"Model: {config.base_model}")

    # Model and tokenizer
    tokenizer = _load_tokenizer(config.base_model)
    model = _load_model(config.base_model)

    # Dataset
    ds = _load_and_format_dataset(config)
    print(f"Train samples: {len(ds['train']):,}")
    if "val" in ds:
        print(f"Val samples: {len(ds['val']):,}")

    # LoRA
    peft_config = _build_lora_config(config)

    # SFT config: every value except bf16 is taken from the training config;
    # checkpoints go to the experiment directory on the volume.
    sft_args = SFTConfig(
        output_dir=str(experiment_dir),
        max_length=config.max_length,
        packing=config.packing,
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=config.per_device_train_batch_size,
        per_device_eval_batch_size=config.per_device_eval_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        gradient_checkpointing=config.gradient_checkpointing,
        optim=config.optim,
        logging_steps=config.logging_steps,
        save_strategy=config.save_strategy,
        save_steps=config.save_steps,
        eval_strategy=config.eval_strategy,
        eval_steps=config.eval_steps,
        learning_rate=config.learning_rate,
        bf16=True,
        max_grad_norm=config.max_grad_norm,
        warmup_steps=config.warmup_steps,
        lr_scheduler_type=config.lr_scheduler_type,
        weight_decay=config.weight_decay,
        report_to=config.report_to,
        trackio_space_id=config.trackio_space_id,
        project=config.project,
        completion_only_loss=config.completion_only_loss,
        dataset_num_proc=config.dataset_num_proc,
        seed=config.seed,
    )

    # eval_dataset is None when the dataset has no "val" split.
    trainer = SFTTrainer(
        model=model,
        args=sft_args,
        train_dataset=ds["train"],
        eval_dataset=ds.get("val"),
        peft_config=peft_config,
        processing_class=tokenizer,
    )

    # Parameter counts are taken after the trainer has been constructed.
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Resume from checkpoint if available (handles preemption)
    resume_from = _find_latest_checkpoint(experiment_dir)
    if resume_from:
        print(f"Resuming from {resume_from}")

    # resume_from is None on a fresh run, which starts training from scratch.
    trainer.train(resume_from_checkpoint=resume_from)

    # Save final adapter + tokenizer
    print(f"Saving adapter to {experiment_dir}")
    trainer.save_model(str(experiment_dir))
    tokenizer.save_pretrained(str(experiment_dir))
    # Persist what was written under the mounted volume.
    gazet_vol.commit()

    # Optionally merge adapter into base model
    if config.merge_after_training:
        _merge_and_save(config, experiment_dir)

    print(f"Training complete: {config.experiment_name}")
    return config.experiment_name
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def _merge_and_save(config, experiment_dir: pathlib.Path):
    """Merge the trained adapter into the base model and save it to ``merged/``.

    The base model is reloaded on CPU, the adapter stored in *experiment_dir*
    is applied and merged, and the result is written as safetensors shards of
    at most 2GB next to a tokenizer loaded from the base model. The volume is
    committed afterwards.
    """
    from peft import PeftModel

    merged_dir = experiment_dir / "merged"
    merged_dir.mkdir(parents=True, exist_ok=True)
    target = str(merged_dir)

    base_model = AutoModelForCausalLM.from_pretrained(
        config.base_model, device_map="cpu"
    )
    adapter_model = PeftModel.from_pretrained(base_model, str(experiment_dir))
    adapter_model.merge_and_unload().save_pretrained(
        target, safe_serialization=True, max_shard_size="2GB"
    )

    _load_tokenizer(config.base_model).save_pretrained(target)
    gazet_vol.commit()
    print(f"Merged model saved to {merged_dir}")
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# ---------------------------------------------------------------------------
|
| 241 |
+
# Local entrypoint
|
| 242 |
+
# ---------------------------------------------------------------------------
|
| 243 |
+
|
| 244 |
+
@app.local_entrypoint()
def main(
    base_model: Optional[str] = None,
    experiment_name: Optional[str] = None,
    per_device_train_batch_size: Optional[int] = None,
    max_train_samples: Optional[int] = None,
    max_eval_samples: Optional[int] = None,
    num_train_epochs: Optional[int] = None,
    lora_r: Optional[int] = None,
    max_length: Optional[int] = None,
):
    """Local entrypoint: build the TrainingConfig and launch remote training.

    Every option left at ``None`` is omitted, so the matching TrainingConfig
    default applies; all other options override the field of the same name.
    """
    from finetune.config import TrainingConfig

    cli_values = {
        "base_model": base_model,
        "experiment_name": experiment_name,
        "per_device_train_batch_size": per_device_train_batch_size,
        "max_train_samples": max_train_samples,
        "max_eval_samples": max_eval_samples,
        "num_train_epochs": num_train_epochs,
        "lora_r": lora_r,
        "max_length": max_length,
    }
    overrides = {
        name: value for name, value in cli_values.items() if value is not None
    }

    config = TrainingConfig(**overrides)

    print(f"Starting experiment: {config.experiment_name}")
    print(f"Model: {config.base_model}")
    print(f"LoRA: r={config.lora_r}, alpha={config.lora_alpha}")
    effective_batch = (
        config.per_device_train_batch_size * config.gradient_accumulation_steps
    )
    print(f"Effective batch size: {effective_batch}")

    # The config travels to the container as a plain dict of its attributes.
    result = finetune.remote(config.__dict__)
    print(f"Training complete: {result}")
|