junaid0600 commited on
Commit
44e9354
Β·
verified Β·
1 Parent(s): 86abfc1

Update training/train_agent.py

Browse files
Files changed (1) hide show
  1. training/train_agent.py +330 -225
training/train_agent.py CHANGED
@@ -1,98 +1,116 @@
1
  """
2
  training/train_agent.py β€” SQL Database Engineer Agent
3
- Fixed version: reward_fn runs FULL EPISODES via /reset + /step
4
- This gives real delta rewards, milestones, and meaningful learning signal.
5
 
6
- FREE T4 (Colab/Kaggle): MODEL_NAME=unsloth/Qwen2.5-1.5B-Instruct
7
- VENUE A100: MODEL_NAME=unsloth/Qwen2.5-7B-Instruct
8
  """
9
 
10
- import os, json, requests, sys, time
 
 
 
 
11
  from pathlib import Path
12
 
13
- # ── GPU check + imports ───────────────────────────────────────
14
- UNSLOTH_AVAILABLE = False
15
  try:
16
- import torch
17
- if not torch.cuda.is_available():
18
- print("❌ No GPU found. Unsloth requires GPU.")
19
- sys.exit(1)
20
  from unsloth import FastLanguageModel
21
  from trl import GRPOTrainer, GRPOConfig
22
- from datasets import Dataset
23
  UNSLOTH_AVAILABLE = True
24
- print(f"βœ… GPU: {torch.cuda.get_device_name(0)}")
25
- print(f"βœ… VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
26
- except ImportError as e:
27
- print(f"❌ Import error: {e}")
28
- print("Run: pip install unsloth trl transformers datasets accelerate")
29
- sys.exit(1)
30
-
31
- # ── Config ────────────────────────────────────────────────────
 
32
  ENV_URL = os.getenv("ENV_URL", "https://junaid0600-sql-db-engineer-agent.hf.space")
33
  HF_TOKEN = os.getenv("HF_TOKEN", "")
34
- MODEL_NAME = os.getenv("MODEL_NAME", "unsloth/Qwen2.5-1.5B-Instruct")
35
  OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./sdea-trained")
36
- MAX_STEPS = int(os.getenv("MAX_STEPS", "100"))
37
 
38
- print(f"\n[CONFIG] Model: {MODEL_NAME}")
39
- print(f"[CONFIG] ENV URL: {ENV_URL}")
40
- print(f"[CONFIG] Max steps: {MAX_STEPS}")
41
- print(f"[CONFIG] Output: {OUTPUT_DIR}\n")
 
 
 
 
42
 
43
- # ── System prompt ─────────────────────────────────────────────
44
  SYSTEM_PROMPT = """You are a senior database engineer.
45
- Given the current database state, choose the BEST next action.
46
-
47
- Rules:
48
- 1. First action MUST be inspect_query to see what's slow
49
- 2. Then analyze_indexes to see what's missing
50
- 3. Then create_index with correct table and columns
51
- 4. Then analyze_statistics to update planner
52
- 5. Finally submit_report when performance target is reached
53
-
54
- Respond with ONLY valid JSON β€” no markdown, no explanation:
55
- {"action_type": "create_index", "payload": {"table": "users", "columns": ["email"]}}"""
56
-
57
- # ── All 15 Round 2 scenario IDs ───────────────────────────────
58
- ALL_SCENARIOS = [
59
- "easy_s001", "easy_s002", "easy_s003", "easy_s004", "easy_s005",
60
- "medium_s001", "medium_s002", "medium_s003", "medium_s004", "medium_s005",
61
- "hard_s001", "hard_s002", "hard_s003", "hard_s004", "hard_s005",
62
- ]
63
-
64
- # ── Parse LLM output β†’ action dict ───────────────────────────
65
- def parse_action(text: str) -> dict:
66
  try:
67
  text = text.strip()
68
- # Strip markdown
69
- for marker in ["```json", "```"]:
70
- if marker in text:
71
- parts = text.split(marker)
72
- text = parts[1] if len(parts) > 1 else parts[0]
73
  text = text.strip()
 
74
  data = json.loads(text)
75
- if "action_type" in data and "payload" in data:
76
  return data
77
  except Exception:
78
- pass
79
- # Safe fallback
80
- return {"action_type": "inspect_query", "payload": {"query_id": "q1"}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
 
83
- # ── REWARD FUNCTION β€” runs FULL EPISODE ───────────────────────
84
  def reward_fn(prompts, completions, **kwargs):
85
  """
86
- KEY FIX: Runs a full episode per completion.
87
- 1. /reset with scenario
88
- 2. Parse LLM output as action
89
- 3. /step with that action
90
- 4. Get REAL reward including delta + milestones
91
- 5. /step with submit_report to get terminal score
92
- Returns real rewards β€” not constant 0.5
93
  """
94
  rewards = []
95
- batch = kwargs.get("batch", [])
 
 
 
 
 
 
 
96
 
97
  for i, (prompt, completion) in enumerate(zip(prompts, completions)):
98
  try:
@@ -102,62 +120,65 @@ def reward_fn(prompts, completions, **kwargs):
102
  else:
103
  text = str(completion)
104
 
105
- # Pick scenario β€” rotate through all 15
106
- scenario_id = ALL_SCENARIOS[i % len(ALL_SCENARIOS)]
107
-
108
- # Parse LLM output
109
  action = parse_action(text)
 
 
 
 
110
 
111
- # Step 1: Reset environment for this scenario
112
- r = requests.post(f"{ENV_URL}/reset",
113
- json={"task_id": scenario_id}, timeout=15)
114
- if r.status_code != 200:
115
- rewards.append(0.001)
116
- continue
117
-
118
- obs = r.json()
119
- baseline = obs.get("current_context", {}).get("performance_score", 0)
120
-
121
- # Step 2: Submit the LLM's action
122
- r2 = requests.post(f"{ENV_URL}/step",
123
- json=action, timeout=15)
124
- if r2.status_code != 200:
125
  rewards.append(0.001)
 
126
  continue
127
 
128
- data = r2.json()
129
- step_score = data.get("reward", {}).get("score", 0.001)
130
- db_delta = data.get("info", {}).get("db_delta", 0)
131
- perf_score = data.get("info", {}).get("performance_score", baseline)
132
- milestones = data.get("info", {}).get("milestones", [])
133
- done = data.get("done", False)
134
-
135
- # Step 3: If not done, submit report to get terminal score
136
- if not done:
137
- r3 = requests.post(f"{ENV_URL}/step",
138
- json={"action_type": "submit_report",
139
- "payload": {"summary": "Training episode complete."}},
140
- timeout=15)
141
- if r3.status_code == 200:
142
- final_data = r3.json()
143
- terminal = final_data.get("reward", {}).get("score", step_score)
144
- # Combine step reward + terminal
145
- final_score = (step_score * 0.4) + (terminal * 0.6)
146
- else:
147
- final_score = step_score
 
 
 
 
 
 
148
  else:
149
- final_score = step_score
 
 
 
 
 
 
 
 
 
 
 
150
 
151
- # Clamp
152
- final_score = max(0.001, min(0.999, final_score))
153
- rewards.append(final_score)
154
 
155
- print(f" [REWARD] scenario={scenario_id} | "
156
- f"action={action.get('action_type')} | "
157
- f"db_delta=+{db_delta:.1f} | "
158
- f"milestones={milestones} | "
159
- f"score={final_score:.3f}")
160
 
 
 
161
  except Exception as e:
162
  print(f" [REWARD] Error: {e}")
163
  rewards.append(0.001)
@@ -165,149 +186,210 @@ def reward_fn(prompts, completions, **kwargs):
165
  return rewards
166
 
167
 
168
- # ── Build dataset with all 15 scenarios ───────────────────────
 
 
 
169
  def build_dataset():
 
170
  scenarios = []
171
- for fname in ["dataset/easy_scenarios.json",
172
- "dataset/medium_scenarios.json",
173
- "dataset/hard_scenarios.json"]:
 
 
 
174
  try:
175
  with open(fname) as f:
176
  data = json.load(f)
177
  scenarios.extend(data)
178
- print(f" Loaded {len(data)} from {fname}")
179
  except FileNotFoundError:
180
- print(f" ⚠️ {fname} not found")
181
 
182
  if not scenarios:
183
- print(" Fetching from live environment...")
184
- resp = requests.get(f"{ENV_URL}/tasks", timeout=15)
185
- tasks = resp.json().get("tasks", [])
186
- scenarios = [t for t in tasks if t["id"].startswith(("easy_s","medium_s","hard_s"))]
 
 
 
 
 
187
 
188
  examples = []
189
- for i, s in enumerate(scenarios):
190
- # Build rich prompt with full DB state
191
- tables_str = json.dumps(s.get("tables", []))
192
- queries_str = json.dumps(s.get("slow_queries", []))
193
- prompt = (
194
- f"{SYSTEM_PROMPT}\n\n"
195
- f"Scenario: {s.get('id')}\n"
196
- f"Description: {s.get('description','')}\n"
197
- f"Tables: {tables_str}\n"
198
- f"Slow Queries: {queries_str}\n"
199
- f"Performance: {s.get('performance_score_baseline',0)}/100 "
200
- f"(target: {s.get('target_score',85)})\n\n"
201
- f"What is your FIRST action?"
202
- )
203
  examples.append({
204
- "prompt": prompt,
205
- "task_id": s.get("id", ALL_SCENARIOS[i % len(ALL_SCENARIOS)]),
206
- "scenario_id": s.get("id", ALL_SCENARIOS[i % len(ALL_SCENARIOS)]),
207
  })
208
 
209
- print(f" βœ… Dataset: {len(examples)} examples")
 
 
 
 
 
 
 
 
 
 
 
210
  return Dataset.from_list(examples)
211
 
212
 
213
- # ── Generate loss + reward curve from training logs ───────────
214
- def generate_training_plots(trainer):
215
- import matplotlib
216
- matplotlib.use("Agg")
217
- import matplotlib.pyplot as plt
218
 
219
- logs = [l for l in trainer.state.log_history if "loss" in l]
220
- if not logs:
221
- print("⚠️ No training logs found for plotting")
222
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
- steps = [l.get("step", i) for i, l in enumerate(logs)]
225
- losses = [l.get("loss", 0) for l in logs]
226
- rewards = [l.get("reward", l.get("train_loss", 0)) for l in logs]
227
-
228
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
229
- fig.suptitle("GRPO Training β€” SQL Database Engineer Agent",
230
- fontsize=13, fontweight="bold")
231
-
232
- ax1.plot(steps, losses, "b-o", lw=2, ms=4, label="Loss")
233
- ax1.set_xlabel("Training Step")
234
- ax1.set_ylabel("Loss")
235
- ax1.set_title("Training Loss (↓ = learning)")
236
- ax1.grid(True, alpha=0.3)
237
- ax1.legend()
238
-
239
- ax2.plot(steps, rewards, "g-o", lw=2, ms=4, label="Reward")
240
- ax2.set_xlabel("Training Step")
241
- ax2.set_ylabel("Reward")
242
- ax2.set_title("Reward During Training (↑ = improving)")
243
- ax2.grid(True, alpha=0.3)
244
- ax2.legend()
245
-
246
- plt.tight_layout()
247
- plt.savefig("loss_curve.png", dpi=150, bbox_inches="tight")
248
- print("βœ… loss_curve.png saved")
249
-
250
- # Print summary
251
- if losses:
252
- print(f" Loss: {losses[0]:.4f} β†’ {losses[-1]:.4f}")
253
- if rewards:
254
- valid_r = [r for r in rewards if r > 0]
255
- if valid_r:
256
- print(f" Reward: {valid_r[0]:.4f} β†’ {valid_r[-1]:.4f}")
257
-
258
-
259
- # ── Main training ─────────────────────────────────────────────
260
  def train():
261
- # Verify environment
 
 
 
 
 
 
 
 
262
  try:
263
  r = requests.get(f"{ENV_URL}/health", timeout=10)
264
- print(f"βœ… Environment live: v{r.json().get('version','?')}\n")
 
265
  except Exception as e:
266
- print(f"❌ Cannot reach {ENV_URL}: {e}")
 
267
  sys.exit(1)
268
 
269
- # Load model
270
- print(f"⏳ Loading {MODEL_NAME}...")
271
  model, tokenizer = FastLanguageModel.from_pretrained(
272
  model_name = MODEL_NAME,
273
  max_seq_length = 2048,
274
- load_in_4bit = True,
 
275
  token = HF_TOKEN or None,
276
  )
 
 
 
277
  model = FastLanguageModel.get_peft_model(
278
  model,
279
  r = 16,
280
  lora_alpha = 16,
281
- target_modules = ["q_proj","k_proj","v_proj","o_proj",
282
- "gate_proj","up_proj","down_proj"],
283
  lora_dropout = 0,
284
  bias = "none",
285
  use_gradient_checkpointing = "unsloth",
286
  random_state = 42,
287
  )
288
- print("βœ… Model + LoRA ready\n")
289
 
290
- # Dataset
291
- print("⏳ Building dataset...")
292
  dataset = build_dataset()
 
293
 
294
- # Reward wrapper
295
  def reward_wrapper(prompts, completions, **kwargs):
296
- return reward_fn(prompts, completions, **kwargs)
297
-
298
- # GRPO config
 
 
 
 
 
 
 
 
 
299
  config = GRPOConfig(
300
  output_dir = OUTPUT_DIR,
301
  max_steps = MAX_STEPS,
302
- per_device_train_batch_size = 1,
303
- gradient_accumulation_steps = 4,
304
  learning_rate = 5e-6,
305
- max_completion_length = 200,
306
- num_generations = 2,
307
- temperature = 0.9,
308
- logging_steps = 1,
309
- save_steps = 25,
310
- save_total_limit = 3,
311
  warmup_ratio = 0.1,
312
  report_to = "none",
313
  remove_unused_columns = False,
@@ -321,27 +403,50 @@ def train():
321
  train_dataset = dataset,
322
  )
323
 
324
- print(f"πŸ‹οΈ Starting GRPO β€” {MAX_STEPS} steps")
325
- print("Watch for: db_delta > 0 and milestones in logs\n")
 
326
  trainer.train()
327
- print("\nβœ… Training complete!")
328
 
329
- # Save
 
 
 
 
330
  Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
331
  model.save_pretrained(f"{OUTPUT_DIR}/final")
332
  tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")
333
- print(f"βœ… Saved to {OUTPUT_DIR}/final")
334
-
335
- # Generate plots
336
- generate_training_plots(trainer)
337
-
338
- print("\n" + "="*55)
339
- print("NEXT STEPS:")
340
- print(" python training/evaluate_agent.py")
341
- print(" git add loss_curve.png reward_curve.png")
342
- print(" git commit -m 'Real training evidence'")
343
- print(" git push origin main")
344
- print("="*55)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
 
346
 
347
  if __name__ == "__main__":
 
1
  """
2
  training/train_agent.py β€” SQL Database Engineer Agent
3
+ Unsloth + GRPO training script.
4
+ Run on venue GPU (April 25-26) with compute credits.
5
 
6
+ FREE T4 (Colab): MODEL_NAME=unsloth/Qwen2.5-1.5B-Instruct (default)
7
+ VENUE A100: set ENV_VAR MODEL_NAME=unsloth/Qwen2.5-7B-Instruct
8
  """
9
 
10
+ import os
11
+ import json
12
+ import requests
13
+ import sys
14
+ import re
15
  from pathlib import Path
16
 
17
+ # ── Try importing Unsloth (GPU only) ─────────────────────────
 
18
  try:
 
 
 
 
19
  from unsloth import FastLanguageModel
20
  from trl import GRPOTrainer, GRPOConfig
21
+ import torch
22
  UNSLOTH_AVAILABLE = True
23
+ print("Unsloth + TRL loaded successfully")
24
+ except ImportError:
25
+ UNSLOTH_AVAILABLE = False
26
+ print("Unsloth not available. Run: pip install unsloth trl")
27
+
28
+ # ─────────────────────────────────────────────
29
+ # CONFIG β€” change MODEL_NAME via env var at venue
30
+ # ─────────────────────────────────────────────
31
+
32
  ENV_URL = os.getenv("ENV_URL", "https://junaid0600-sql-db-engineer-agent.hf.space")
33
  HF_TOKEN = os.getenv("HF_TOKEN", "")
34
+ MODEL_NAME = os.getenv("MODEL_NAME", "unsloth/Qwen2.5-1.5B-Instruct") # 1.5B for free T4
35
  OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./sdea-trained")
36
+ MAX_STEPS = int(os.getenv("MAX_STEPS", "100")) # increase to 300+ at venue
37
 
38
+ print(f"[CONFIG] Model: {MODEL_NAME}")
39
+ print(f"[CONFIG] Output: {OUTPUT_DIR}")
40
+ print(f"[CONFIG] Max steps: {MAX_STEPS}")
41
+ print(f"[CONFIG] ENV URL: {ENV_URL}")
42
+
43
+ # ─────────────────────────────────────────────
44
+ # SYSTEM PROMPT
45
+ # ─────────────────────────────────────────────
46
 
 
47
  SYSTEM_PROMPT = """You are a senior database engineer.
48
+ Given the current database state with slow queries, choose the BEST action to improve performance.
49
+ Think step by step:
50
+ 1. If you have not inspected queries yet -> use inspect_query
51
+ 2. If you have not analyzed indexes -> use analyze_indexes
52
+ 3. If you know which index is missing -> use create_index
53
+ 4. If query can be rewritten better -> use rewrite_query
54
+ 5. If table is huge (1M+ rows) -> use partition_table
55
+ 6. When performance target is reached -> use submit_report
56
+
57
+ Respond with JSON only β€” no explanation, no markdown:
58
+ {"action_type": "...", "payload": {...}}"""
59
+
60
+
61
+ # ─────────────────────────────────────────────
62
+ # REWARD FUNCTION (calls live HF Space)
63
+ # ─────────────────────────────────────────────
64
+
65
+ def parse_action(text: str) -> dict | None:
66
+ """Parse LLM output into action dict. Returns None on failure."""
 
 
67
  try:
68
  text = text.strip()
69
+ if "```" in text:
70
+ text = text.split("```")[1]
71
+ if text.startswith("json"):
72
+ text = text[4:]
 
73
  text = text.strip()
74
+ # Try direct JSON first
75
  data = json.loads(text)
76
+ if "action_type" in data:
77
  return data
78
  except Exception:
79
+ # Try extracting first JSON object from mixed text output
80
+ match = re.search(r"\{[\s\S]*\}", text)
81
+ if match:
82
+ try:
83
+ data = json.loads(match.group(0))
84
+ if "action_type" in data:
85
+ return data
86
+ except Exception:
87
+ pass
88
+ return None
89
+
90
+
91
+ def _extract_task_id_from_prompt(prompt_text: str) -> str | None:
92
+ """Fallback extractor when GRPO doesn't pass task_id column."""
93
+ match = re.search(r"-\s*Scenario:\s*([a-z]+_[a-z]?\d+)", prompt_text, flags=re.IGNORECASE)
94
+ if match:
95
+ return match.group(1)
96
+ return None
97
 
98
 
 
99
  def reward_fn(prompts, completions, **kwargs):
100
  """
101
+ GRPO reward function β€” calls /grader on live environment.
102
+ Returns list of float rewards, one per completion.
103
+ Score always between 0.001 and 0.999.
 
 
 
 
104
  """
105
  rewards = []
106
+ task_ids = kwargs.get("task_ids")
107
+ if not task_ids:
108
+ # GRPO can pass dataset columns directly as kwargs, not always via batch.
109
+ task_ids = kwargs.get("task_id")
110
+ if not task_ids:
111
+ task_ids = ["easy_s001"] * len(prompts)
112
+ if isinstance(task_ids, str):
113
+ task_ids = [task_ids] * len(prompts)
114
 
115
  for i, (prompt, completion) in enumerate(zip(prompts, completions)):
116
  try:
 
120
  else:
121
  text = str(completion)
122
 
123
+ # Parse into action
 
 
 
124
  action = parse_action(text)
125
+ task_id = task_ids[i] if i < len(task_ids) else "easy_s001"
126
+ if not task_id:
127
+ task_id = _extract_task_id_from_prompt(str(prompt)) or "easy_s001"
128
+ task_id = str(task_id)
129
 
130
+ if action is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  rewards.append(0.001)
132
+ print(f" [REWARD] task={task_id} | action=parse_failed | score=0.001")
133
  continue
134
 
135
+ # Use environment step reward so dense + milestone logic is used.
136
+ # This also guarantees the sampled task_id actually drives reward.
137
+ difficulty = "easy"
138
+ if str(task_id).startswith("medium_"):
139
+ difficulty = "medium"
140
+ elif str(task_id).startswith("hard_"):
141
+ difficulty = "hard"
142
+
143
+ reset_resp = requests.post(
144
+ f"{ENV_URL}/reset",
145
+ json={"difficulty": difficulty, "task_id": task_id},
146
+ timeout=15,
147
+ headers={"Content-Type": "application/json"},
148
+ )
149
+ if reset_resp.status_code != 200:
150
+ raise RuntimeError(f"/reset failed for {task_id}: {reset_resp.status_code}")
151
+
152
+ step_resp = requests.post(
153
+ f"{ENV_URL}/step",
154
+ json=action,
155
+ timeout=15,
156
+ headers={"Content-Type": "application/json"},
157
+ )
158
+ if step_resp.status_code == 200:
159
+ score = step_resp.json().get("reward", {}).get("score", 0.001)
160
+ score = max(0.001, min(0.999, float(score)))
161
  else:
162
+ # Fallback to grader for robustness.
163
+ grader_resp = requests.post(
164
+ f"{ENV_URL}/grader",
165
+ json={"task_id": task_id, "action": action},
166
+ timeout=15,
167
+ headers={"Content-Type": "application/json"},
168
+ )
169
+ if grader_resp.status_code == 200:
170
+ score = grader_resp.json().get("score", 0.001)
171
+ score = max(0.001, min(0.999, float(score)))
172
+ else:
173
+ score = 0.001
174
 
175
+ action_name = str(action.get("action_type", "unknown"))
 
 
176
 
177
+ rewards.append(score)
178
+ print(f" [REWARD] task={task_id} | action={action_name} | score={score:.3f}")
 
 
 
179
 
180
+ except json.JSONDecodeError:
181
+ rewards.append(0.001)
182
  except Exception as e:
183
  print(f" [REWARD] Error: {e}")
184
  rewards.append(0.001)
 
186
  return rewards
187
 
188
 
189
+ # ─────────────────────────────────────────────
190
+ # BUILD TRAINING DATASET
191
+ # ─────────────────────────────────────────────
192
+
193
  def build_dataset():
194
+ """Build training examples from all 15 Round 2 scenarios."""
195
  scenarios = []
196
+
197
+ for fname in [
198
+ "dataset/easy_scenarios.json",
199
+ "dataset/medium_scenarios.json",
200
+ "dataset/hard_scenarios.json"
201
+ ]:
202
  try:
203
  with open(fname) as f:
204
  data = json.load(f)
205
  scenarios.extend(data)
206
+ print(f" Loaded {len(data)} scenarios from {fname}")
207
  except FileNotFoundError:
208
+ print(f"{fname} not found, skipping")
209
 
210
  if not scenarios:
211
+ print("No local scenarios found. Fetching from live environment...")
212
+ try:
213
+ resp = requests.get(f"{ENV_URL}/tasks", timeout=15)
214
+ tasks = resp.json().get("tasks", [])
215
+ scenarios = [{"id": t["id"], "description": t["description"]} for t in tasks]
216
+ print(f" Fetched {len(scenarios)} tasks from HF Space")
217
+ except Exception as e:
218
+ print(f"Could not fetch tasks: {e}")
219
+ sys.exit(1)
220
 
221
  examples = []
222
+ for s in scenarios:
223
+ prompt = f"""{SYSTEM_PROMPT}
224
+
225
+ Current Database State:
226
+ - Scenario: {s.get('id', 'unknown')}
227
+ - Description: {s.get('description', '')}
228
+ - Tables: {json.dumps(s.get('tables', []))}
229
+ - Slow Queries: {json.dumps(s.get('slow_queries', []))}
230
+ - Performance Score: {s.get('performance_score_baseline', 0)} / 100
231
+ - Target Score: {s.get('target_score', 85)}
232
+
233
+ What is your next action?"""
234
+
 
235
  examples.append({
236
+ "prompt": prompt,
237
+ "task_id": s.get("id", "easy_s001"),
 
238
  })
239
 
240
+ diff_counts = {"easy": 0, "medium": 0, "hard": 0}
241
+ for ex in examples:
242
+ tid = ex["task_id"]
243
+ if str(tid).startswith("medium_"):
244
+ diff_counts["medium"] += 1
245
+ elif str(tid).startswith("hard_"):
246
+ diff_counts["hard"] += 1
247
+ else:
248
+ diff_counts["easy"] += 1
249
+ print(f" Built {len(examples)} training examples total")
250
+ print(f" Difficulty mix: easy={diff_counts['easy']} medium={diff_counts['medium']} hard={diff_counts['hard']}")
251
+ from datasets import Dataset
252
  return Dataset.from_list(examples)
253
 
254
 
255
+ # ─────────────────────────────────────────────
256
+ # INFERENCE TEST β€” run immediately after save
257
+ # ─────────────────────────────────────────────
 
 
258
 
259
+ def test_inference(model, tokenizer):
260
+ """
261
+ REQUIRED: Test inference immediately after saving.
262
+ If this fails, the model was not saved correctly.
263
+ """
264
+ print("\n[INFERENCE TEST] Testing saved model...")
265
+ try:
266
+ FastLanguageModel.for_inference(model)
267
+
268
+ test_prompt = f"""{SYSTEM_PROMPT}
269
+
270
+ Current Database State:
271
+ - Scenario: easy_s001
272
+ - Description: User lookup query taking 2s on 10K users table
273
+ - Tables: [{{"name": "users", "rows": 10000, "indexes": ["PRIMARY"]}}]
274
+ - Slow Queries: [{{"id": "q1", "sql": "SELECT * FROM users WHERE email=?", "avg_ms": 2000}}]
275
+ - Performance Score: 8.0 / 100
276
+ - Target Score: 80.0
277
+
278
+ What is your next action?"""
279
+
280
+ inputs = tokenizer(
281
+ test_prompt,
282
+ return_tensors="pt",
283
+ truncation=True,
284
+ max_length=1024
285
+ ).to(model.device)
286
+
287
+ with torch.no_grad():
288
+ outputs = model.generate(
289
+ **inputs,
290
+ max_new_tokens = 100,
291
+ temperature = 0.3,
292
+ do_sample = True,
293
+ pad_token_id = tokenizer.eos_token_id,
294
+ )
295
+
296
+ response = tokenizer.decode(
297
+ outputs[0][inputs["input_ids"].shape[1]:],
298
+ skip_special_tokens=True
299
+ ).strip()
300
+
301
+ print(f"[INFERENCE TEST] Model output:\n {response}")
302
+
303
+ # Validate output
304
+ action = parse_action(response)
305
+ print(f"[INFERENCE TEST] Parsed action: {action}")
306
+ print("[INFERENCE TEST] PASSED β€” model saved correctly!")
307
+ return True
308
+
309
+ except Exception as e:
310
+ print(f"[INFERENCE TEST] FAILED: {e}")
311
+ print("[INFERENCE TEST] Check model save path. Do NOT proceed without fixing this.")
312
+ return False
313
+
314
+
315
+ # ─────────────────────────────────────────────
316
+ # MAIN TRAINING
317
+ # ─────────────────────────────────────────────
318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  def train():
320
+ if not UNSLOTH_AVAILABLE:
321
+ print(" Cannot train β€” Unsloth not installed or no GPU found")
322
+ print("Run: pip install unsloth trl transformers datasets accelerate")
323
+ return
324
+
325
+ print(f"\n Loading model: {MODEL_NAME}")
326
+ print(f" Environment: {ENV_URL}\n")
327
+
328
+ # Verify environment is reachable
329
  try:
330
  r = requests.get(f"{ENV_URL}/health", timeout=10)
331
+ version = r.json().get("version", "?")
332
+ print(f" Environment reachable β€” version {version}")
333
  except Exception as e:
334
+ print(f" Cannot reach environment at {ENV_URL}: {e}")
335
+ print("Check ENV_URL and make sure HF Space is running.")
336
  sys.exit(1)
337
 
338
+ # ── Load model ───────────────────────────────────────────
 
339
  model, tokenizer = FastLanguageModel.from_pretrained(
340
  model_name = MODEL_NAME,
341
  max_seq_length = 2048,
342
+ load_in_4bit = True, # QLoRA β€” required for T4
343
+ dtype = None, # Auto detect
344
  token = HF_TOKEN or None,
345
  )
346
+ print(" Model loaded")
347
+
348
+ # ── Apply LoRA adapters ──────────────────────────────────
349
  model = FastLanguageModel.get_peft_model(
350
  model,
351
  r = 16,
352
  lora_alpha = 16,
353
+ target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
354
+ "gate_proj", "up_proj", "down_proj"],
355
  lora_dropout = 0,
356
  bias = "none",
357
  use_gradient_checkpointing = "unsloth",
358
  random_state = 42,
359
  )
360
+ print(" LoRA adapters applied")
361
 
362
+ # ── Build dataset ────────────────────────────────────────
363
+ print("\n[DATASET] Building training dataset...")
364
  dataset = build_dataset()
365
+ print(f" Dataset ready: {len(dataset)} examples")
366
 
367
+ # ── Reward wrapper ───────────────────────────────────────
368
  def reward_wrapper(prompts, completions, **kwargs):
369
+ batch = kwargs.get("batch", [])
370
+ if batch and hasattr(batch[0], "get"):
371
+ task_ids = [b.get("task_id", "easy_s001") for b in batch]
372
+ elif "task_id" in kwargs and kwargs["task_id"]:
373
+ task_ids = kwargs["task_id"]
374
+ else:
375
+ task_ids = ["easy_s001"] * len(prompts)
376
+ return reward_fn(prompts, completions, task_ids=task_ids)
377
+
378
+ # ── GRPO config ──────────────────────────────────────────
379
+ # NOTE: batch_size=1, num_generations=2 for free T4
380
+ # At venue A100: increase to batch_size=2, num_generations=4
381
  config = GRPOConfig(
382
  output_dir = OUTPUT_DIR,
383
  max_steps = MAX_STEPS,
384
+ per_device_train_batch_size = 1, # 1 for T4, 2 for A100
385
+ gradient_accumulation_steps = 8,
386
  learning_rate = 5e-6,
387
+ max_completion_length = 256,
388
+ num_generations = 2, # 2 for T4, 4 for A100
389
+ temperature = 0.8,
390
+ logging_steps = 5,
391
+ save_steps = 50,
392
+ save_total_limit = 2,
393
  warmup_ratio = 0.1,
394
  report_to = "none",
395
  remove_unused_columns = False,
 
403
  train_dataset = dataset,
404
  )
405
 
406
+ # ── Train ────────────────────────────────────────────────
407
+ print(f"\nπŸ‹οΈ Starting GRPO training β€” {MAX_STEPS} steps...")
408
+ print("Watch the 'reward' column β€” it should increase over time.\n")
409
  trainer.train()
410
+ print("\n Training complete!")
411
 
412
+ # ── Save β€” ADAPTER ONLY (correct way for QLoRA) ──────────
413
+ # DO NOT call merge_and_unload() on 4-bit model
414
+ # DO NOT upcast to 16-bit and merge naively
415
+ # CORRECT: save adapter weights only, load with from_pretrained later
416
+ print(f"\n[SAVE] Saving adapter to {OUTPUT_DIR}/final ...")
417
  Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
418
  model.save_pretrained(f"{OUTPUT_DIR}/final")
419
  tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")
420
+
421
+ # Save config for reference
422
+ with open(f"{OUTPUT_DIR}/final/training_config.json", "w") as f:
423
+ json.dump({
424
+ "model_name": MODEL_NAME,
425
+ "max_steps": MAX_STEPS,
426
+ "save_method": "adapter_only_qlora",
427
+ "lora_r": 16,
428
+ "lora_alpha": 16,
429
+ }, f, indent=2)
430
+ print(f" Adapter saved to {OUTPUT_DIR}/final")
431
+
432
+ # ── IMMEDIATE inference test (required) ──────────────────
433
+ passed = test_inference(model, tokenizer)
434
+
435
+ # ── Summary ──────────────────────────────────────────────
436
+ print("\n" + "="*60)
437
+ print("TRAINING COMPLETE")
438
+ print("="*60)
439
+ print(f" Model: {MODEL_NAME}")
440
+ print(f" Steps: {MAX_STEPS}")
441
+ print(f" Saved to: {OUTPUT_DIR}/final")
442
+ print(f" Save method: Adapter only (QLoRA safe)")
443
+ print(f" Inference test: {' PASSED' if passed else ' FAILED'}")
444
+ print("="*60)
445
+ print("\nNext steps:")
446
+ print(" 1. python training/evaluate_agent.py")
447
+ print(" 2. Open reward_curve.png β€” show to judges")
448
+ print(" 3. git add reward_curve.png && git commit && git push")
449
+ print("="*60)
450
 
451
 
452
  if __name__ == "__main__":