srmsoumya committed on
Commit
599b4c3
·
1 Parent(s): c88725f

Add scripts to train & infer model on modal

Browse files
finetune/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
finetune/check_token_lengths.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Check token lengths of training samples to validate max_length setting.
2
+
3
+ Usage
4
+ -----
5
+ modal run finetune/check_token_lengths.py \
6
+ --train-jsonl /data/train.jsonl \
7
+ --val-jsonl /data/val.jsonl \
8
+ --base-model google/gemma-3-270m-it
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import modal
14
+
15
+ app = modal.App("gazet-check-token-lengths")
16
+
17
+ check_image = (
18
+ modal.Image.debian_slim(python_version="3.11")
19
+ .pip_install(
20
+ "datasets>=3.0",
21
+ "pandas>=2.2",
22
+ "transformers>=4.46",
23
+ )
24
+ .add_local_python_source("finetune", copy=True)
25
+ .env({"HF_HOME": "/mnt/gazet/model_cache"})
26
+ )
27
+
28
+ gazet_vol = modal.Volume.from_name("gazet", create_if_missing=True)
29
+
30
+ VOLUMES = {
31
+ "/mnt/gazet": gazet_vol,
32
+ }
33
+
34
+
35
@app.function(
    image=check_image,
    volumes=VOLUMES,
    secrets=[modal.Secret.from_name("huggingface-secret")],
)
def analyze_token_lengths(
    train_jsonl: str,
    val_jsonl: str | None,
    base_model: str,
    schema_file: str | None = None,
):
    """Print token-length statistics for the formatted train/val samples.

    Each sample is formatted into a prompt/completion pair, the two strings are
    concatenated and tokenized with the tokenizer of ``base_model``, and the
    resulting lengths are summarised per split and combined. Finally a
    ``max_length`` recommendation is printed. Nothing is returned.

    Args:
        train_jsonl: Path of the training JSONL file.
        val_jsonl: Optional path of the validation JSONL file.
        base_model: Model identifier passed to ``AutoTokenizer.from_pretrained``.
        schema_file: Optional path to a schema description; when falsy the
            default schema text from ``finetune.prompts`` is used.
    """
    # Imported lazily so these packages are only needed where the function runs.
    from transformers import AutoTokenizer
    from finetune.data import format_dataset_for_sft, load_jsonl_splits, read_text
    from finetune.prompts import DEFAULT_SCHEMA_DETAILS

    print(f"Loading tokenizer: {base_model}")
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    print(f"Loading dataset from {train_jsonl}")
    schema_details = read_text(schema_file, DEFAULT_SCHEMA_DETAILS)
    ds = load_jsonl_splits(train_jsonl, val_jsonl)
    formatted = format_dataset_for_sft(ds, schema_details)

    def compute_lengths(split_name: str, dataset):
        """Print statistics for one split and return its sorted token lengths."""
        print(f"\n{'='*60}")
        print(f"Analyzing {split_name} split ({len(dataset)} samples)")
        print(f"{'='*60}")

        lengths = []
        for row in dataset:
            text = row["prompt"] + row["completion"]
            # NOTE(review): encode() is called with default arguments; confirm
            # this counts tokens the same way the trainer does.
            tokens = tokenizer.encode(text)
            lengths.append(len(tokens))

        lengths.sort()
        n = len(lengths)

        # NOTE(review): an empty split makes min()/max() raise ValueError and
        # the mean divide by zero; callers are assumed to pass non-empty splits.
        print(f"\nToken length statistics:")
        print(f" Samples: {n:,}")
        print(f" Min: {min(lengths):,}")
        print(f" Max: {max(lengths):,}")
        print(f" Mean: {sum(lengths)/n:.0f}")
        # Percentiles are read directly from the sorted list (upper element,
        # no interpolation).
        print(f" Median: {lengths[n//2]:,}")
        print(f" P90: {lengths[int(n*0.90)]:,}")
        print(f" P95: {lengths[int(n*0.95)]:,}")
        print(f" P99: {lengths[int(n*0.99)]:,}")

        # (inclusive upper limit, label) pairs; each bucket counts lengths in
        # the half-open range (previous limit, limit].
        buckets = [
            (512, "0-512"),
            (1024, "513-1024"),
            (2048, "1025-2048"),
            (4096, "2049-4096"),
            (8192, "4097-8192"),
            (float("inf"), "8193+"),
        ]

        print(f"\nFrequency distribution:")
        prev_limit = 0
        for limit, label in buckets:
            count = sum(1 for l in lengths if prev_limit < l <= limit)
            pct = 100 * count / n
            # One bar character per two percentage points.
            bar = "█" * int(pct / 2)
            print(f" {label:>12}: {count:5,} ({pct:5.1f}%) {bar}")
            prev_limit = limit

        thresholds = [1024, 2048, 4096, 8192]
        print(f"\nSamples exceeding thresholds:")
        for threshold in thresholds:
            count = sum(1 for l in lengths if l > threshold)
            pct = 100 * count / n
            print(f" > {threshold:5,}: {count:5,} ({pct:5.1f}%)")

        return lengths

    train_lengths = compute_lengths("train", formatted["train"])

    # The validation split exists only when a val_jsonl path was supplied.
    if "val" in formatted:
        val_lengths = compute_lengths("val", formatted["val"])
    else:
        val_lengths = []

    all_lengths = train_lengths + val_lengths
    if all_lengths:
        print(f"\n{'='*60}")
        print(f"COMBINED STATISTICS")
        print(f"{'='*60}")
        all_lengths.sort()
        n = len(all_lengths)
        print(f" Total samples: {n:,}")
        print(f" Max length: {max(all_lengths):,}")
        print(f" P99: {all_lengths[int(n*0.99)]:,}")

        for threshold in [1024, 2048, 4096]:
            count = sum(1 for l in all_lengths if l > threshold)
            pct = 100 * count / n
            # Only the 2048 threshold is flagged with a warning.
            status = "⚠️ WARNING" if count > 0 and threshold == 2048 else "✓ OK"
            print(f" > {threshold:5,}: {count:5,} ({pct:5.1f}%) {status}")

    print(f"\n{'='*60}")
    print("RECOMMENDATIONS")
    print(f"{'='*60}")

    max_len = max(all_lengths) if all_lengths else 0
    over_2048 = sum(1 for l in all_lengths if l > 2048) if all_lengths else 0

    # `n` is bound only inside the `if all_lengths:` block above. The branches
    # below that read it are reached only when max_len > 2048, which implies
    # all_lengths is non-empty.
    if max_len <= 1024:
        print("✓ All samples fit within 1024 tokens")
        print(" Recommended max_length: 1024")
    elif max_len <= 2048:
        print("✓ All samples fit within 2048 tokens")
        print(" Recommended max_length: 2048")
    elif over_2048 < n * 0.01:
        print(f"⚠️ {over_2048} samples ({100*over_2048/n:.1f}%) exceed 2048 tokens")
        print(" Options:")
        print(" 1. Keep max_length=2048 (truncates <1% of samples)")
        print(" 2. Increase to max_length=4096 (uses more GPU memory)")
        print(" 3. Reduce candidate rows in preprocessing")
    else:
        print(f"⚠️ {over_2048} samples ({100*over_2048/n:.1f}%) exceed 2048 tokens")
        print(f" Recommended max_length: {max_len} (or reduce candidate rows)")

    print()
157
+
158
+
159
@app.local_entrypoint()
def main(
    train_jsonl: str = "/mnt/gazet/data/output/train.jsonl",
    val_jsonl: str | None = "/mnt/gazet/data/output/val.jsonl",
    base_model: str = "google/gemma-3-270m-it",
    schema_file: str | None = None,
):
    """Local entrypoint: echo the settings, then run the analysis remotely.

    The validation path is echoed only when one is given.
    """
    banner = [
        "Checking token lengths for:",
        f" Model: {base_model}",
        f" Train: {train_jsonl}",
    ]
    if val_jsonl:
        banner.append(f" Val: {val_jsonl}")
    print("\n".join(banner))

    analyze_token_lengths.remote(
        train_jsonl,
        val_jsonl,
        base_model,
        schema_file,
    )
    print("Analysis complete!")
finetune/config.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Training configuration for text-to-SQL LoRA finetuning."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from datetime import datetime
5
+ from typing import List, Optional
6
+
7
+
8
+ LORA_TARGET_MODULES = [
9
+ "q_proj",
10
+ "k_proj",
11
+ "v_proj",
12
+ "o_proj",
13
+ "gate_proj",
14
+ "up_proj",
15
+ "down_proj",
16
+ ]
17
+
18
+
19
@dataclass
class TrainingConfig:
    """Hyper-parameters and paths for one LoRA finetuning run.

    Every field has a default. When ``experiment_name`` is left as ``None`` a
    name is derived in ``__post_init__`` from the base model, the LoRA rank and
    the current timestamp.
    """

    # Model
    base_model: str = "google/gemma-3-270m-it"

    # Dataset (paths on the Modal volume)
    train_jsonl: str = "/mnt/gazet/data/output/train.jsonl"
    val_jsonl: Optional[str] = "/mnt/gazet/data/output/val.jsonl"
    test_jsonl: Optional[str] = "/mnt/gazet/data/output/test.jsonl"
    schema_file: Optional[str] = None
    max_train_samples: Optional[int] = None
    max_eval_samples: Optional[int] = None

    # LoRA
    lora_r: int = 16
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    # default_factory copies the module-level list so instances never share
    # (or mutate) LORA_TARGET_MODULES itself.
    target_modules: List[str] = field(default_factory=lambda: list(LORA_TARGET_MODULES))

    # Training
    num_train_epochs: int = 2
    per_device_train_batch_size: int = 12
    per_device_eval_batch_size: int = 12
    gradient_accumulation_steps: int = 2
    gradient_checkpointing: bool = True
    optim: str = "adamw_torch_fused"
    learning_rate: float = 1e-4
    max_grad_norm: float = 0.7
    warmup_steps: int = 50
    # NOTE(review): if these values are forwarded to the transformers
    # scheduler, the plain "constant" schedule presumably ignores
    # warmup_steps ("constant_with_warmup" applies it) -- confirm the intent.
    lr_scheduler_type: str = "constant"
    weight_decay: float = 0.0
    packing: bool = False
    max_length: int = 2048

    # Logging / saving
    logging_steps: int = 10
    save_strategy: str = "steps"
    save_steps: int = 300
    eval_strategy: str = "steps"
    eval_steps: int = 100
    report_to: str = "trackio"
    trackio_space_id: Optional[str] = "srmsoumya/gazet-trackio"
    project: str = "gazet-nlg"

    # SFT-specific
    completion_only_loss: bool = True
    dataset_num_proc: Optional[int] = 8

    # Experiment
    seed: int = 42
    experiment_name: Optional[str] = None
    merge_after_training: bool = True

    def __post_init__(self):
        """Fill in ``experiment_name`` when the caller did not provide one."""
        if self.experiment_name is None:
            # Local wall-clock time, formatted as YYYYMMDD-HHMMSS.
            timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
            # Keep only the part after the last "/" of the model id.
            model_short = self.base_model.split("/")[-1]
            self.experiment_name = f"{model_short}-r{self.lora_r}-{timestamp}"
finetune/data.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dataset loading and SFT formatting for text-to-SQL finetuning."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from pathlib import Path
7
+ from typing import Dict, Optional
8
+
9
+ from datasets import DatasetDict, load_dataset
10
+
11
+ from finetune.prompts import DEFAULT_SCHEMA_DETAILS, make_prompt_completion
12
+
13
+ LOGGER = logging.getLogger("nlg.data")
14
+
15
+
16
def read_text(path: Optional[str], default: str) -> str:
    """Return the UTF-8 contents of ``path``, or ``default`` when ``path`` is falsy."""
    if path:
        with open(path, encoding="utf-8") as handle:
            return handle.read()
    return default
20
+
21
+
22
def load_jsonl_splits(
    train_jsonl: str,
    val_jsonl: Optional[str] = None,
    test_jsonl: Optional[str] = None,
) -> DatasetDict:
    """Load the JSONL files as a ``DatasetDict``.

    The ``train`` split is always present; ``val`` and ``test`` are added only
    when a truthy path is given for them.
    """
    optional_splits = (("val", val_jsonl), ("test", test_jsonl))
    data_files: Dict[str, str] = {"train": train_jsonl}
    data_files.update({name: path for name, path in optional_splits if path})
    LOGGER.info("Loading dataset splits: %s", data_files)
    return load_dataset("json", data_files=data_files)
34
+
35
+
36
def format_dataset_for_sft(
    dataset: DatasetDict,
    schema_details: str = DEFAULT_SCHEMA_DETAILS,
) -> DatasetDict:
    """Map every split through ``make_prompt_completion`` and return the result.

    The returned ``DatasetDict`` has the same split names as ``dataset``.
    """

    def _to_prompt_completion(row):
        return make_prompt_completion(row, schema_details)

    return DatasetDict(
        {name: split.map(_to_prompt_completion) for name, split in dataset.items()}
    )
finetune/eval_demo.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Streamlit eval viewer: compare expected vs predicted SQL and view results on a map.
2
+
3
+ Usage: streamlit run finetune/eval_demo.py
4
+ """
5
+
6
+ import difflib
7
+ import json
8
+ import math
9
+ import os
10
+ import pathlib
11
+
12
+ import duckdb
13
+ import numpy as np
14
+ import pandas as pd
15
+ import pydeck as pdk
16
+ import sqlparse
17
+ import streamlit as st
18
+
19
+ PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
20
+ DATA_DIR = pathlib.Path(
21
+ os.environ.get("GAZET_DATA_DIR", str(PROJECT_ROOT / "data"))
22
+ )
23
+ EVAL_DIR = PROJECT_ROOT / "data" / "eval_results"
24
+
25
+
26
def load_eval_results(path):
    """Load one eval result file and return the parsed JSON document.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.

    Args:
        path: Path (``str`` or ``pathlib.Path``) of the JSON file.

    Returns:
        The decoded JSON value.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)
29
+
30
+
31
def rewrite_data_paths(sql):
    """Point every "/data/" occurrence in ``sql`` at the local data directory."""
    local_prefix = f"{DATA_DIR}/"
    return sql.replace("/data/", local_prefix)
34
+
35
+
36
def format_sql(sql):
    """Return ``sql`` re-indented by sqlparse with upper-cased keywords."""
    options = {"reindent": True, "keyword_case": "upper"}
    return sqlparse.format(sql, **options)
39
+
40
+
41
def sql_diff_html(expected, predicted):
    """Build an HTML table diffing the pretty-printed expected and predicted SQL."""
    differ = difflib.HtmlDiff(tabsize=2, wrapcolumn=80)
    left = format_sql(expected).splitlines()
    right = format_sql(predicted).splitlines()
    return differ.make_table(
        left,
        right,
        fromdesc="Expected",
        todesc="Predicted",
        context=False,
    )
51
+
52
+
53
def get_duckdb_connection():
    """Open a DuckDB connection with the spatial extension installed and loaded."""
    connection = duckdb.connect()
    for statement in ("INSTALL spatial", "LOAD spatial"):
        connection.execute(statement)
    return connection
58
+
59
+
60
def execute_sql(con, sql):
    """Execute SQL, converting geometry columns to simplified GeoJSON strings.

    The query is first inspected through ``con.sql`` to learn its column names
    and types. It is then wrapped in an outer SELECT that turns every column
    whose type name contains "GEOMETRY" into
    ``ST_AsGeoJSON(ST_SimplifyPreserveTopology(col, 0.001))``.

    Args:
        con: Connection object exposing ``sql()`` and ``execute()``.
        sql: A single SELECT statement. Trailing semicolons and whitespace are
            tolerated.

    Returns:
        Whatever ``con.execute(...).fetchdf()`` returns for the wrapped query.
    """
    # Generated SQL often ends with ";", which is a syntax error once the
    # statement is embedded as a sub-query.
    sql = sql.rstrip("; \t\r\n")

    rel = con.sql(sql)
    cols = rel.columns
    types = [str(t) for t in rel.dtypes]

    select_parts = []
    for col, dtype in zip(cols, types):
        # Double any embedded quote so the name stays a valid quoted identifier.
        ident = '"' + str(col).replace('"', '""') + '"'
        if "GEOMETRY" in dtype.upper():
            select_parts.append(
                f"ST_AsGeoJSON(ST_SimplifyPreserveTopology({ident}, 0.001)) AS {ident}"
            )
        else:
            select_parts.append(ident)

    # Newlines around the sub-query keep a trailing "-- comment" in `sql` from
    # swallowing the closing parenthesis.
    wrapped = f"SELECT {', '.join(select_parts)} FROM (\n{sql}\n)"
    return con.execute(wrapped).fetchdf()
77
+
78
+
79
+ def _is_notna(val):
80
+ """Check if a value is not NA, handling arrays/lists/numpy arrays safely."""
81
+ if isinstance(val, (list, tuple, np.ndarray)):
82
+ return len(val) > 0
83
+ return pd.notna(val)
84
+
85
+
86
+ def _to_python(val):
87
+ """Convert numpy/pandas types to native Python for JSON serialization."""
88
+ if isinstance(val, (np.integer,)):
89
+ return int(val)
90
+ if isinstance(val, (np.floating,)):
91
+ return float(val)
92
+ if isinstance(val, np.ndarray):
93
+ return val.tolist()
94
+ if isinstance(val, (np.bool_,)):
95
+ return bool(val)
96
+ return val
97
+
98
+
99
def to_feature_collection(result_df):
    """Convert a result DataFrame into a GeoJSON FeatureCollection dict.

    A column counts as geometry when, among its first five values, there is at
    least one string and every string starts with '{"type":'. The first such
    column supplies each feature's geometry; the non-geometry columns become
    properties, skipping values that are missing.
    """

    def _looks_like_geojson(column):
        samples = [v for v in result_df[column].head(5) if isinstance(v, str)]
        return bool(samples) and all(
            s.lstrip().startswith('{"type":') for s in samples
        )

    geom_cols = [c for c in result_df.columns if _looks_like_geojson(c)]
    prop_cols = [c for c in result_df.columns if c not in geom_cols]

    features = []
    for _, record in result_df.iterrows():
        geometry = None
        if geom_cols:
            raw = record[geom_cols[0]]
            if raw and isinstance(raw, str):
                geometry = json.loads(raw)

        properties = {}
        for name in prop_cols:
            value = record[name]
            if not _is_notna(value):
                continue
            properties[name] = _to_python(value)

        feature = {"type": "Feature", "geometry": geometry, "properties": properties}
        features.append(feature)

    return {"type": "FeatureCollection", "features": features}
124
+
125
+
126
def bbox_from_geojson(geojson):
    """Return ``(min_lng, min_lat, max_lng, max_lat)`` over all features.

    Returns ``None`` when no feature contributes a coordinate.
    """
    points = []
    for feature in geojson.get("features", []):
        geometry = feature.get("geometry")
        if not geometry:
            continue
        points.extend(_extract_coords(geometry))

    if not points:
        return None

    lngs = [point[0] for point in points]
    lats = [point[1] for point in points]
    return min(lngs), min(lats), max(lngs), max(lats)
137
+
138
+
139
+ def _extract_coords(geom):
140
+ t = geom.get("type", "")
141
+ coords = geom.get("coordinates", [])
142
+ if t == "Point":
143
+ yield coords
144
+ elif t in ("LineString", "MultiPoint"):
145
+ yield from coords
146
+ elif t == "Polygon":
147
+ for ring in coords:
148
+ yield from ring
149
+ elif t in ("MultiLineString", "MultiPolygon"):
150
+ for part in coords:
151
+ if t == "MultiLineString":
152
+ yield from part
153
+ else:
154
+ for ring in part:
155
+ yield from ring
156
+ elif t == "GeometryCollection":
157
+ for g in geom.get("geometries", []):
158
+ yield from _extract_coords(g)
159
+
160
+
161
def _centroids_from_geojson(geojson):
    """Return one ``{"lng", "lat"}`` marker per feature that has coordinates.

    Each marker is the arithmetic mean of the feature's coordinate positions.
    """
    markers = []
    for feature in geojson.get("features", []):
        geometry = feature.get("geometry")
        if not geometry:
            continue
        positions = list(_extract_coords(geometry))
        if not positions:
            continue
        total = len(positions)
        mean_lng = sum(p[0] for p in positions) / total
        mean_lat = sum(p[1] for p in positions) / total
        markers.append({"lng": mean_lng, "lat": mean_lat})
    return markers
175
+
176
+
177
def render_map(geojson, color, key):
    """Render a FeatureCollection as a pydeck map inside the Streamlit page.

    Args:
        geojson: FeatureCollection dict to draw.
        color: Fill colour as a list ``[r, g, b, a]`` (sliced and concatenated
            with a list below, so it must be a list).
        key: Streamlit widget key for the chart.
    """
    feature_count = len(geojson.get("features", []))
    if feature_count == 0:
        st.info("Query returned no features.")
        return

    shapes = pdk.Layer(
        "GeoJsonLayer",
        data=geojson,
        get_fill_color=color,
        get_line_color=[100, 100, 100, 200],
        get_line_width=2,
        pickable=True,
    )
    layers = [shapes]

    bounds = bbox_from_geojson(geojson)
    if not bounds:
        view = pdk.ViewState(latitude=0, longitude=0, zoom=1)
    else:
        west, south, east, north = bounds
        # Floor the extent so a single point does not divide by zero.
        extent = max(east - west, north - south, 1e-6)
        zoom = max(0, min(18, math.log2(360 / extent) - 0.8))

        # Add scatter markers when polygons would be too small to see
        if zoom < 4:
            markers = _centroids_from_geojson(geojson)
            if markers:
                dots = pdk.Layer(
                    "ScatterplotLayer",
                    data=markers,
                    get_position=["lng", "lat"],
                    get_fill_color=color[:3] + [220],
                    get_radius=50000,
                    radius_min_pixels=6,
                    pickable=True,
                )
                layers.append(dots)

        view = pdk.ViewState(
            latitude=(south + north) / 2,
            longitude=(west + east) / 2,
            zoom=zoom,
        )

    deck = pdk.Deck(layers=layers, initial_view_state=view, map_style=None)
    st.pydeck_chart(deck, width="stretch", height=400, key=key)
230
+
231
+
232
# --- App ---
# Streamlit re-runs this script top to bottom on every interaction, so the
# statements below are the whole page: pick a file, pick a query, show the
# SQL, the diff and the two result maps.

st.set_page_config(page_title="Eval Viewer", layout="wide")
st.title("Eval Viewer")

# Result files are expected to be named eval-<label>.json inside EVAL_DIR.
eval_files = sorted(EVAL_DIR.glob("eval-*.json"))
if not eval_files:
    st.error(f"No eval result files found in {EVAL_DIR}")
    st.stop()

selected_file = st.sidebar.selectbox(
    "Eval file",
    eval_files,
    format_func=lambda p: p.stem,
)

# The file holds a "summary" object and a "results" list.
data = load_eval_results(selected_file)
summary = data["summary"]
results = data["results"]

st.sidebar.markdown(f"""
**Model**: `{summary.get('label', '')}`
**Exact match**: {summary['exact_matches']}/{summary['num_samples']} ({summary['exact_match_rate']:.1%})
""")

# Optionally narrow the list to matching or mismatching predictions.
filter_option = st.sidebar.radio("Filter", ["All", "Matches only", "Mismatches only"])
if filter_option == "Matches only":
    results = [r for r in results if r["exact_match"]]
elif filter_option == "Mismatches only":
    results = [r for r in results if not r["exact_match"]]

if not results:
    st.warning("No results match the current filter.")
    st.stop()

# The selectbox works on positions in the filtered list; the label shows the
# sample's own "index" field from the result file.
questions = [f"[{r['index']}] {r['question']}" for r in results]
selected_idx = st.selectbox("Select a query", range(len(questions)), format_func=lambda i: questions[i])
row = results[selected_idx]

match_label = "MATCH" if row["exact_match"] else "MISMATCH"
match_color = "green" if row["exact_match"] else "red"
st.markdown(f"### :{match_color}[{match_label}]")

# Formatted SQL side-by-side
col_expected, col_predicted = st.columns(2)
with col_expected:
    st.markdown("**Expected SQL**")
    st.code(format_sql(row["expected_sql"]), language="sql")
with col_predicted:
    st.markdown("**Predicted SQL**")
    st.code(format_sql(row["predicted_sql"]), language="sql")

# Diff view
if not row["exact_match"]:
    with st.expander("SQL Diff", expanded=True):
        diff_html = sql_diff_html(row["expected_sql"], row["predicted_sql"])
        # Styles for the CSS classes used in the difflib.HtmlDiff table.
        diff_css = """
        <style>
        .diff_add { background-color: rgba(40, 167, 69, 0.15); }
        .diff_sub { background-color: rgba(220, 53, 69, 0.15); }
        .diff_chg { background-color: rgba(255, 193, 7, 0.15); }
        .diff_header { background-color: rgba(128, 128, 128, 0.1); font-weight: bold; }
        table.diff { border-collapse: collapse; width: 100%; font-family: monospace; color: inherit; }
        table.diff td, table.diff th { padding: 4px 8px; border: 1px solid rgba(128, 128, 128, 0.2); }
        </style>
        """
        st.html(f"{diff_css}<div style='overflow-x:auto; font-size:13px;'>{diff_html}</div>")

# Auto-execute both SQLs and show maps
# A fresh connection is opened on every rerun and closed at the end of the
# script; execution errors are shown in the page instead of aborting it.
con = get_duckdb_connection()

map_col1, map_col2 = st.columns(2)

with map_col1:
    st.markdown("**Expected result**")
    sql = rewrite_data_paths(row["expected_sql"])
    try:
        df = execute_sql(con, sql)
        geojson = to_feature_collection(df)
        render_map(geojson, [40, 180, 160, 140], key="map_expected")
        with st.expander("Result table"):
            st.dataframe(df, width="stretch")
    except Exception as e:
        st.error(f"Execution error: {e}")

with map_col2:
    st.markdown("**Predicted result**")
    sql = rewrite_data_paths(row["predicted_sql"])
    try:
        df = execute_sql(con, sql)
        geojson = to_feature_collection(df)
        render_map(geojson, [180, 80, 60, 140], key="map_predicted")
        with st.expander("Result table"):
            st.dataframe(df, width="stretch")
    except Exception as e:
        st.error(f"Execution error: {e}")

con.close()
finetune/infer_modal.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modal eval script: run a model on the test set and save results.
2
+
3
+ Usage
4
+ -----
5
+ # Eval finetuned model (uses raw prompt-completion format):
6
+ modal run finetune/infer_modal.py --label finetuned
7
+
8
+ # Eval base model (uses chat template so the model understands the instruction):
9
+ modal run finetune/infer_modal.py \
10
+ --model-path google/gemma-3-270m-it \
11
+ --label base \
12
+ --use-chat-template
13
+
14
+ # Limit samples:
15
+ modal run finetune/infer_modal.py --max-samples 50
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import json
21
+ import pathlib
22
+ from datetime import datetime
23
+ from typing import Optional
24
+
25
+ import modal
26
+
27
+ app = modal.App("gazet-nlg-eval")
28
+
29
+ infer_image = (
30
+ modal.Image.debian_slim(python_version="3.11")
31
+ .pip_install(
32
+ "accelerate>=1.0",
33
+ "pandas>=2.2",
34
+ "torch>=2.4",
35
+ "transformers>=4.46",
36
+ )
37
+ .add_local_python_source("finetune", copy=True)
38
+ .env({"HF_HOME": "/mnt/gazet/model_cache"})
39
+ )
40
+
41
+ gazet_vol = modal.Volume.from_name("gazet", create_if_missing=True)
42
+
43
+ VOLUMES = {
44
+ "/mnt/gazet": gazet_vol,
45
+ }
46
+
47
+ DEFAULT_MODEL_PATH = "/mnt/gazet/checkpoints/gemma-3-270m-it-r16-20260331-134642/merged"
48
+
49
+
50
def postprocess_sql(text: str) -> str:
    """Extract the SQL statement from raw model output.

    Handles output wrapped in a Markdown code fence. The "sql" language tag is
    matched case-insensitively, so "```SQL" is treated like "```sql" instead of
    leaving the tag inside the returned query.

    Rules, applied in order:
      1. If a "```sql" fence (any letter case) occurs, keep what follows it.
      2. If the text then starts with a bare "```", drop that fence.
      3. Cut everything from the next "```" onwards.

    Args:
        text: Decoded model generation.

    Returns:
        The SQL text with fences and surrounding whitespace removed.
    """
    import re  # local import: only this helper needs it

    cleaned = text.strip()
    tagged = re.split(r"```sql", cleaned, maxsplit=1, flags=re.IGNORECASE)
    if len(tagged) == 2:
        cleaned = tagged[1]
    if cleaned.startswith("```"):
        cleaned = cleaned[3:]
    # split() returns the whole string when no closing fence is present.
    cleaned = cleaned.split("```", 1)[0]
    return cleaned.strip()
59
+
60
+
61
@app.function(
    image=infer_image,
    gpu="L40S",
    volumes=VOLUMES,
    secrets=[modal.Secret.from_name("huggingface-secret")],
    timeout=60 * 60,
)
def run_eval(
    model_path: str,
    label: str,
    samples: list[dict],
    output_path: str,
    max_new_tokens: int = 512,
    batch_size: int = 16,
    use_chat_template: bool = False,
):
    """Run batched inference on all samples, save results to volume.

    For every sample a prompt is built from its question and candidates,
    completions are generated greedily in batches, post-processed with
    ``postprocess_sql`` and compared (after stripping) with the expected SQL.
    A JSON document with a ``summary`` and per-sample ``results`` is written to
    ``output_path`` and the volume is committed.

    Args:
        model_path: Path or hub id handed to ``from_pretrained``.
        label: Free-form tag stored in the summary and printed in logs.
        samples: Dicts with ``question``, ``candidates`` and optionally
            ``target.sql``.
        output_path: Destination JSON file; parent directories are created.
        max_new_tokens: Generation budget per sample.
        batch_size: Number of prompts per ``generate`` call.
        use_chat_template: Wrap the prompt in the tokenizer's chat template
            instead of using the raw prompt text.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from finetune.prompts import SYSTEM_PROMPT, build_user_prompt, DEFAULT_SCHEMA_DETAILS

    print(f"Loading model [{label}]: {model_path}")
    print(f"Chat template: {use_chat_template}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Left padding makes every prompt end at the same position, so the
    # generated tokens of all rows start at index input_len.
    tokenizer.padding_side = "left"

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        device_map="auto",
    )
    model.eval()

    # Build all prompts upfront
    prompts = []
    for sample in samples:
        user_content = build_user_prompt(
            question=sample["question"],
            candidates=sample["candidates"],
            schema_details=DEFAULT_SCHEMA_DETAILS,
        )
        if use_chat_template:
            messages = [
                {"role": "user", "content": SYSTEM_PROMPT + "\n\n" + user_content},
            ]
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            prompt = SYSTEM_PROMPT + "\n\n" + user_content
        prompts.append(prompt)

    # Batched inference
    all_predictions = []
    num_batches = (len(prompts) + batch_size - 1) // batch_size

    for batch_idx in range(num_batches):
        start = batch_idx * batch_size
        end = min(start + batch_size, len(prompts))
        batch_prompts = prompts[start:end]

        # Text rendered by apply_chat_template already carries the template's
        # special tokens; adding them again here would duplicate BOS.
        # NOTE(review): truncation uses the tokenizer's default truncation
        # side; confirm over-long prompts do not lose the user query.
        inputs = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
            add_special_tokens=not use_chat_template,
        ).to(model.device)
        input_len = inputs["input_ids"].shape[1]

        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        for j in range(len(batch_prompts)):
            # Keep only the newly generated tail of each row.
            generated = tokenizer.decode(
                outputs[j][input_len:], skip_special_tokens=True
            )
            all_predictions.append(postprocess_sql(generated))

        print(f"Batch {batch_idx+1}/{num_batches} done ({end}/{len(prompts)} samples)")

    # Build results
    results = []
    matches = 0
    for i, sample in enumerate(samples):
        # Tolerate a missing or null target / sql field.
        expected = (sample.get("target") or {}).get("sql") or ""
        predicted = all_predictions[i]
        is_match = predicted.strip() == expected.strip()
        if is_match:
            matches += 1

        results.append({
            "index": i,
            "question": sample["question"],
            "candidates": sample["candidates"],
            "expected_sql": expected,
            "predicted_sql": predicted,
            "exact_match": is_match,
        })

    total = len(results)
    exact_match_rate = matches / total if total else 0

    output = {
        "summary": {
            "label": label,
            "model_path": model_path,
            "num_samples": total,
            "exact_matches": matches,
            "exact_match_rate": exact_match_rate,
            "timestamp": datetime.now().isoformat(),
        },
        "results": results,
    }

    path = pathlib.Path(output_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    # Persist the written file to the Modal volume.
    gazet_vol.commit()

    print(f"\n{'='*60}")
    print(f"[{label}] {matches}/{total} exact matches ({100*exact_match_rate:.1f}%)")
    print(f"Results saved to {output_path}")
    print(f"{'='*60}")
196
+
197
+
198
@app.function(
    image=infer_image,
    volumes=VOLUMES,
)
def read_test_data(test_jsonl: str) -> list[dict]:
    """Read test JSONL from the volume.

    Blank lines (for example a trailing newline at the end of the file) are
    skipped instead of failing in ``json.loads``. The file is decoded as UTF-8.

    Args:
        test_jsonl: Path of the JSONL file.

    Returns:
        One decoded JSON value per non-blank line, in file order.
    """
    with open(test_jsonl, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
209
+
210
+
211
@app.local_entrypoint()
def main(
    model_path: str = DEFAULT_MODEL_PATH,
    label: str = "finetuned",
    test_jsonl: str = "/mnt/gazet/data/output/test.jsonl",
    max_samples: Optional[int] = None,
    max_new_tokens: int = 512,
    batch_size: int = 16,
    use_chat_template: bool = False,
    output_dir: str = "/mnt/gazet/eval_results",
):
    """Local entrypoint: fetch the test samples, then run the remote eval.

    Results are written to ``<output_dir>/eval-<label>.json``. When
    ``max_samples`` is truthy only the first ``max_samples`` samples are used.
    """
    settings = (
        ("Model", model_path),
        ("Label", label),
        ("Chat template", use_chat_template),
    )
    for title, value in settings:
        print(f"{title}: {value}")

    print("Loading test data...")
    samples = read_test_data.remote(test_jsonl)
    if max_samples:
        samples = samples[:max_samples]
    print(f"Eval samples: {len(samples)}")

    output_file = f"{output_dir}/eval-{label}.json"
    run_eval.remote(
        model_path,
        label,
        samples,
        output_file,
        max_new_tokens,
        batch_size,
        use_chat_template,
    )
finetune/prompts.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Prompt templates and message formatting for natural language geocoding."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, Sequence
6
+
7
+ import pandas as pd
8
+
9
+ SYSTEM_PROMPT = (
10
+ "You are a text to SQL query translator that helps in natural language geocoding."
11
+ )
12
+
13
+ USER_PROMPT_TEMPLATE = """GIVEN the <SCHEMA_DETAILS>, <CANDIDATES> and <USER_QUERY>, generate the corresponding SQL command to retrieve the desired geometry.
14
+
15
+ <SCHEMA_DETAILS>
16
+ {schema_details}
17
+ </SCHEMA_DETAILS>
18
+
19
+ <CANDIDATES>
20
+ {candidates_csv}
21
+ </CANDIDATES>
22
+
23
+ <USER_QUERY>
24
+ {question}
25
+ </USER_QUERY>
26
+ """
27
+
28
+ DEFAULT_SCHEMA_DETAILS = """1. divisions_area -- Overture polygon/multipolygon admin boundaries
29
+ path: '/data/overture/division_area/*.parquet'
30
+ columns:
31
+ id VARCHAR
32
+ names STRUCT("primary" VARCHAR, ...)
33
+ country VARCHAR
34
+ subtype VARCHAR
35
+ class VARCHAR
36
+ region VARCHAR
37
+ admin_level INTEGER
38
+ division_id VARCHAR
39
+ is_land BOOLEAN
40
+ is_territorial BOOLEAN
41
+ geometry GEOMETRY
42
+
43
+ 2. natural_earth -- Natural Earth geography polygons
44
+ path: '/data/natural_earth_geoparquet/ne_geography.parquet'
45
+ columns:
46
+ id VARCHAR
47
+ name VARCHAR
48
+ featurecla VARCHAR
49
+ scalerank INTEGER
50
+ min_zoom DOUBLE
51
+ geometry GEOMETRY"""
52
+
53
+
54
def candidates_to_csv(candidates: Sequence[Dict[str, Any]]) -> str:
    """Serialise candidate rows as CSV text without the ``candidate_id`` column.

    The header row is included and the DataFrame index is omitted.
    """
    records = list(candidates)
    table = pd.DataFrame(records).drop(columns=["candidate_id"], errors="ignore")
    return table.to_csv(index=False)
59
+
60
+
61
def build_user_prompt(
    question: str,
    candidates: Sequence[Dict[str, Any]],
    schema_details: str,
) -> str:
    """Fill ``USER_PROMPT_TEMPLATE`` with the schema, candidate CSV and question.

    Each of the three inserted texts is stripped of surrounding whitespace.
    """
    fields = {
        "schema_details": schema_details.strip(),
        "candidates_csv": candidates_to_csv(candidates).strip(),
        "question": question.strip(),
    }
    return USER_PROMPT_TEMPLATE.format(**fields)
71
+
72
+
73
def make_prompt_completion(
    sample: Dict[str, Any],
    schema_details: str,
) -> Dict[str, str]:
    """Turn one raw sample into a ``{"prompt", "completion"}`` pair.

    The prompt is the system prompt, a blank line and the user prompt built
    from the sample's question and candidates. The completion is the sample's
    ``target.sql``.

    A missing or null ``target`` (or ``sql``) yields an empty completion rather
    than an ``AttributeError`` or a ``None`` completion, so the returned values
    are always strings.

    Args:
        sample: Mapping with ``question``, ``candidates`` and optionally
            ``target`` (a mapping holding ``sql``).
        schema_details: Schema description inserted into the prompt.

    Returns:
        Dict with the keys ``prompt`` and ``completion``.
    """
    prompt = SYSTEM_PROMPT + "\n\n" + build_user_prompt(
        question=sample["question"],
        candidates=sample["candidates"],
        schema_details=schema_details,
    )
    # `or` also covers keys that are present but explicitly null.
    target = sample.get("target") or {}
    completion = target.get("sql") or ""
    return {"prompt": prompt, "completion": completion}
finetune/train_modal.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modal training script for text-to-SQL LoRA finetuning.
2
+
3
+ Usage
4
+ -----
5
+ modal run finetune/train_modal.py \
6
+ --train-jsonl /data/train.jsonl \
7
+ --val-jsonl /data/val.jsonl \
8
+ --base-model google/gemma-3-1b-it
9
+
10
+ All CLI arguments map to TrainingConfig fields. Run with --help for details.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import pathlib
16
+ from typing import Optional
17
+
18
+ import modal
19
+
20
# Modal application that the remote function and local entrypoint below attach to.
app = modal.App("gazet-nlg-finetune")

# Resource settings consumed by the @app.function decorator on finetune().
GPU_TYPE = "A100-80GB"  # "L40S"
TIMEOUT_HOURS = 6  # multiplied out to seconds for the function timeout
MAX_RETRIES = 1  # passed to modal.Retries(max_retries=...)

# Container image for training: Debian slim with Python 3.11 plus the
# training stack installed via pip.
train_image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install(
        "accelerate>=1.0",
        "datasets>=3.0",
        "hf-transfer>=0.1",
        "huggingface_hub>=0.25",
        "jinja2>=3.0",
        "pandas>=2.2",
        "peft>=0.13",
        "torch>=2.4",
        "trackio[gpu]",
        "transformers>=4.46",
        "trl>=0.12",
    )
    # Include the local `finetune` package so the remote code can import it.
    .add_local_python_source("finetune", copy=True)
    # HF_HOME points under /mnt/gazet, the mount point declared in VOLUMES.
    .env({"HF_HOME": "/mnt/gazet/model_cache", "HF_HUB_ENABLE_HF_TRANSFER": "1"})
)

# NOTE(review): these imports are presumably only executed inside containers
# built from train_image, so the packages need not be installed locally —
# confirm against Modal's Image.imports() documentation.
with train_image.imports():
    import torch
    from datasets import DatasetDict
    from peft import LoraConfig
    from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
    from trl import SFTConfig, SFTTrainer

# Named volume "gazet"; created if it does not exist yet.
gazet_vol = modal.Volume.from_name("gazet", create_if_missing=True)

# Mount point -> volume mapping handed to @app.function(volumes=...).
VOLUMES = {
    "/mnt/gazet": gazet_vol,
}
57
+
58
+
59
def _load_tokenizer(model_name: str):
    """Load the tokenizer for ``model_name`` and make sure it has a pad token.

    When the loaded tokenizer defines no pad token, its EOS token is assigned
    as the pad token.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token is None:
        # Reuse EOS for padding when the tokenizer ships without a pad token.
        tok.pad_token = tok.eos_token
    return tok
64
+
65
+
66
def _load_model(model_name: str):
    """Load the causal LM ``model_name`` in bfloat16 with SDPA attention.

    Device placement is delegated to ``device_map="auto"``.
    """
    load_kwargs = {
        "torch_dtype": torch.bfloat16,
        "attn_implementation": "sdpa",
        "device_map": "auto",
    }
    return AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
73
+
74
+
75
def _build_lora_config(config) -> LoraConfig:
    """Build the LoRA configuration from the training config.

    Rank, alpha, dropout and target modules are read from ``config``; the
    bias mode ("none") and task type ("CAUSAL_LM") are fixed.
    """
    lora_kwargs = dict(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.target_modules,
        bias="none",
        task_type="CAUSAL_LM",
    )
    return LoraConfig(**lora_kwargs)
84
+
85
+
86
def _load_and_format_dataset(config) -> DatasetDict:
    """Load the JSONL splits, format them for SFT and apply sample caps.

    The schema text comes from ``read_text(config.schema_file,
    DEFAULT_SCHEMA_DETAILS)`` (the second argument is presumably the fallback
    when no schema file is configured). When ``config.max_train_samples`` /
    ``config.max_eval_samples`` are set, the "train" / "val" splits are
    truncated to at most that many leading rows.
    """
    from finetune.data import format_dataset_for_sft, load_jsonl_splits, read_text
    from finetune.prompts import DEFAULT_SCHEMA_DETAILS

    schema_text = read_text(config.schema_file, DEFAULT_SCHEMA_DETAILS)
    splits = format_dataset_for_sft(
        load_jsonl_splits(config.train_jsonl, config.val_jsonl, config.test_jsonl),
        schema_text,
    )

    train_cap = config.max_train_samples
    if train_cap is not None:
        train_split = splits["train"]
        splits["train"] = train_split.select(range(min(train_cap, len(train_split))))

    eval_cap = config.max_eval_samples
    # The validation split is optional, so only cap it when it was loaded.
    if eval_cap is not None and "val" in splits:
        val_split = splits["val"]
        splits["val"] = val_split.select(range(min(eval_cap, len(val_split))))

    return splits
108
+
109
+
110
def _find_latest_checkpoint(checkpoint_dir: pathlib.Path) -> str | None:
    """Return the path of the highest-numbered checkpoint directory, if any.

    Only sub-directories named ``checkpoint-<integer>`` are considered.
    Stray files and entries with a non-numeric suffix (for example
    ``checkpoint-final``) are ignored instead of raising ``ValueError`` or
    being picked as the resume target.

    Args:
        checkpoint_dir: Directory that may contain ``checkpoint-*`` entries.

    Returns:
        The path of the checkpoint with the largest step number as a string,
        or ``None`` when the directory does not exist or holds no usable
        checkpoint.
    """
    if not checkpoint_dir.exists():
        return None

    prefix = "checkpoint-"
    numbered = []
    for path in checkpoint_dir.glob(f"{prefix}*"):
        step = path.name[len(prefix):]
        # isascii() + isdigit() guarantees int() cannot fail on the suffix.
        if path.is_dir() and step.isascii() and step.isdigit():
            numbered.append((int(step), path))
    if not numbered:
        return None

    # Compare numerically so that checkpoint-100 beats checkpoint-20.
    _, latest = max(numbered, key=lambda pair: pair[0])
    print(f"Found existing checkpoint: {latest}")
    return str(latest)
119
+
120
+
121
@app.function(
    image=train_image,
    gpu=GPU_TYPE,
    volumes=VOLUMES,
    secrets=[modal.Secret.from_name("huggingface-secret")],
    timeout=TIMEOUT_HOURS * 60 * 60,
    retries=modal.Retries(initial_delay=0.0, max_retries=MAX_RETRIES),
)
def finetune(config_dict: dict):
    """Run LoRA SFT training inside a Modal container.

    ``config_dict`` holds the keyword arguments for ``TrainingConfig``.
    Outputs are written under ``/mnt/gazet/checkpoints/<experiment_name>``.
    If that directory already contains a checkpoint, training resumes from
    the latest one. The final adapter and tokenizer are saved there and the
    volume is committed; the adapter is additionally merged into the base
    model when ``config.merge_after_training`` is set.

    Returns:
        The experiment name.
    """
    from finetune.config import TrainingConfig

    config = TrainingConfig(**config_dict)
    set_seed(config.seed)

    run_dir = pathlib.Path("/mnt/gazet/checkpoints") / config.experiment_name
    run_dir.mkdir(parents=True, exist_ok=True)
    run_dir_str = str(run_dir)

    print(f"Experiment: {config.experiment_name}")
    print(f"Model: {config.base_model}")

    # Tokenizer first, then the model weights.
    tokenizer = _load_tokenizer(config.base_model)
    model = _load_model(config.base_model)

    # Prompt/completion formatted splits.
    ds = _load_and_format_dataset(config)
    print(f"Train samples: {len(ds['train']):,}")
    if "val" in ds:
        print(f"Val samples: {len(ds['val']):,}")

    peft_config = _build_lora_config(config)

    # These SFTConfig arguments share their name with the TrainingConfig
    # field they are read from, so they are forwarded by name.
    forwarded_fields = (
        "max_length",
        "packing",
        "num_train_epochs",
        "per_device_train_batch_size",
        "per_device_eval_batch_size",
        "gradient_accumulation_steps",
        "gradient_checkpointing",
        "optim",
        "logging_steps",
        "save_strategy",
        "save_steps",
        "eval_strategy",
        "eval_steps",
        "learning_rate",
        "max_grad_norm",
        "warmup_steps",
        "lr_scheduler_type",
        "weight_decay",
        "report_to",
        "trackio_space_id",
        "project",
        "completion_only_loss",
        "dataset_num_proc",
        "seed",
    )
    sft_args = SFTConfig(
        output_dir=run_dir_str,
        bf16=True,
        **{name: getattr(config, name) for name in forwarded_fields},
    )

    trainer = SFTTrainer(
        model=model,
        args=sft_args,
        train_dataset=ds["train"],
        eval_dataset=ds.get("val"),
        peft_config=peft_config,
        processing_class=tokenizer,
    )

    # Count parameters after the trainer has been constructed.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Pick up an earlier checkpoint when one exists (handles preemption).
    resume_from = _find_latest_checkpoint(run_dir)
    if resume_from:
        print(f"Resuming from {resume_from}")

    trainer.train(resume_from_checkpoint=resume_from)

    # Persist the final adapter and tokenizer, then commit the volume.
    print(f"Saving adapter to {run_dir}")
    trainer.save_model(run_dir_str)
    tokenizer.save_pretrained(run_dir_str)
    gazet_vol.commit()

    # Merging the adapter into the base model is optional.
    if config.merge_after_training:
        _merge_and_save(config, run_dir)

    print(f"Training complete: {config.experiment_name}")
    return config.experiment_name
218
+
219
+
220
def _merge_and_save(config, experiment_dir: pathlib.Path):
    """Merge the trained adapter into the base model and save the result.

    The base model is loaded on CPU, the adapter stored in ``experiment_dir``
    is applied and merged, and the merged weights are written to
    ``experiment_dir / "merged"`` as safetensors shards of at most 2GB. The
    base model's tokenizer is saved alongside, then the volume is committed.
    """
    from peft import PeftModel

    output_dir = experiment_dir / "merged"
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = str(output_dir)

    base_model = AutoModelForCausalLM.from_pretrained(
        config.base_model,
        device_map="cpu",
    )
    adapted = PeftModel.from_pretrained(base_model, str(experiment_dir))
    adapted.merge_and_unload().save_pretrained(
        output_path,
        safe_serialization=True,
        max_shard_size="2GB",
    )

    _load_tokenizer(config.base_model).save_pretrained(output_path)
    gazet_vol.commit()
    print(f"Merged model saved to {output_dir}")
238
+
239
+
240
+ # ---------------------------------------------------------------------------
241
+ # Local entrypoint
242
+ # ---------------------------------------------------------------------------
243
+
244
@app.local_entrypoint()
def main(
    base_model: Optional[str] = None,
    experiment_name: Optional[str] = None,
    per_device_train_batch_size: Optional[int] = None,
    max_train_samples: Optional[int] = None,
    max_eval_samples: Optional[int] = None,
    num_train_epochs: Optional[int] = None,
    lora_r: Optional[int] = None,
    max_length: Optional[int] = None,
    train_jsonl: Optional[str] = None,
    val_jsonl: Optional[str] = None,
    test_jsonl: Optional[str] = None,
    schema_file: Optional[str] = None,
):
    """Build a TrainingConfig from CLI overrides and launch remote training.

    Every argument is optional. Only values that were actually supplied are
    passed to ``TrainingConfig``; all other fields keep their defaults.

    The dataset and schema paths (``train_jsonl``, ``val_jsonl``,
    ``test_jsonl``, ``schema_file``) are accepted here so that the
    ``--train-jsonl`` / ``--val-jsonl`` invocation shown in the module
    docstring works. The remote function reads these fields from the config.
    """
    from finetune.config import TrainingConfig

    requested = dict(
        base_model=base_model,
        experiment_name=experiment_name,
        per_device_train_batch_size=per_device_train_batch_size,
        max_train_samples=max_train_samples,
        max_eval_samples=max_eval_samples,
        num_train_epochs=num_train_epochs,
        lora_r=lora_r,
        max_length=max_length,
        train_jsonl=train_jsonl,
        val_jsonl=val_jsonl,
        test_jsonl=test_jsonl,
        schema_file=schema_file,
    )
    # Drop unset options so TrainingConfig's own defaults apply.
    overrides = {key: value for key, value in requested.items() if value is not None}

    config = TrainingConfig(**overrides)

    print(f"Starting experiment: {config.experiment_name}")
    print(f"Model: {config.base_model}")
    print(f"LoRA: r={config.lora_r}, alpha={config.lora_alpha}")
    effective_batch = config.per_device_train_batch_size * config.gradient_accumulation_steps
    print(f"Effective batch size: {effective_batch}")

    # The config travels to the container as a plain dict and is rebuilt there.
    result = finetune.remote(dict(vars(config)))
    print(f"Training complete: {result}")