junaid0600 committed on
Commit
bb2cfec
·
verified ·
1 Parent(s): 188eb9c

Update training/train_agent.py

Browse files
Files changed (1) hide show
  1. training/train_agent.py +266 -347
training/train_agent.py CHANGED
@@ -1,184 +1,228 @@
1
  """
2
  training/train_agent.py — SQL Database Engineer Agent
3
- Unsloth + GRPO training script.
4
- Run on venue GPU (April 25-26) with compute credits.
5
-
6
- FREE T4 (Colab): MODEL_NAME=unsloth/Qwen2.5-1.5B-Instruct (default)
7
- VENUE A100: set ENV_VAR MODEL_NAME=unsloth/Qwen2.5-7B-Instruct
8
  """
9
 
10
- import os
11
- import json
12
- import requests
13
- import sys
14
- import re
15
  from pathlib import Path
16
 
17
- # ── Try importing Unsloth (GPU only) ─────────────────────────
 
18
  try:
 
 
 
 
19
  from unsloth import FastLanguageModel
20
  from trl import GRPOTrainer, GRPOConfig
21
- import torch
22
  UNSLOTH_AVAILABLE = True
23
- print("Unsloth + TRL loaded successfully")
24
- except ImportError:
25
- UNSLOTH_AVAILABLE = False
26
- print("Unsloth not available. Run: pip install unsloth trl")
 
27
 
28
- # ─────────────────────────────────────────────
29
- # CONFIG β€” change MODEL_NAME via env var at venue
30
- # ─────────────────────────────────────────────
31
 
 
32
  ENV_URL = os.getenv("ENV_URL", "https://junaid0600-sql-db-engineer-agent.hf.space")
33
  HF_TOKEN = os.getenv("HF_TOKEN", "")
34
- MODEL_NAME = os.getenv("MODEL_NAME", "unsloth/Qwen2.5-1.5B-Instruct") # 1.5B for free T4
35
  OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./sdea-trained")
36
- MAX_STEPS = int(os.getenv("MAX_STEPS", "100")) # increase to 300+ at venue
 
 
 
 
37
 
38
- print(f"[CONFIG] Model: {MODEL_NAME}")
39
- print(f"[CONFIG] Output: {OUTPUT_DIR}")
40
- print(f"[CONFIG] Max steps: {MAX_STEPS}")
41
- print(f"[CONFIG] ENV URL: {ENV_URL}")
42
 
43
- # ─────────────────────────────────────────────
44
- # SYSTEM PROMPT
45
- # ─────────────────────────────────────────────
46
 
47
- SYSTEM_PROMPT = """You are a senior database engineer.
48
- Given the current database state with slow queries, choose the BEST action to improve performance.
49
- Think step by step:
50
- 1. If you have not inspected queries yet -> use inspect_query
51
- 2. If you have not analyzed indexes -> use analyze_indexes
52
- 3. If you know which index is missing -> use create_index
53
- 4. If query can be rewritten better -> use rewrite_query
54
- 5. If table is huge (1M+ rows) -> use partition_table
55
- 6. When performance target is reached -> use submit_report
56
 
57
- Respond with JSON only β€” no explanation, no markdown:
58
- {"action_type": "...", "payload": {...}}"""
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- # ─────────────────────────────────────────────
62
- # REWARD FUNCTION (calls live HF Space)
63
- # ─────────────────────────────────────────────
64
 
65
- def parse_action(text: str) -> dict | None:
66
- """Parse LLM output into action dict. Returns None on failure."""
 
67
  try:
68
  text = text.strip()
69
- if "```" in text:
70
- text = text.split("```")[1]
71
- if text.startswith("json"):
72
- text = text[4:]
73
  text = text.strip()
74
- # Try direct JSON first
75
  data = json.loads(text)
76
- if "action_type" in data:
77
  return data
78
  except Exception:
79
- # Try extracting first JSON object from mixed text output
80
- match = re.search(r"\{[\s\S]*\}", text)
81
- if match:
82
- try:
83
- data = json.loads(match.group(0))
84
- if "action_type" in data:
85
- return data
86
- except Exception:
87
- pass
88
- return None
89
-
90
-
91
- def _extract_task_id_from_prompt(prompt_text: str) -> str | None:
92
- """Fallback extractor when GRPO doesn't pass task_id column."""
93
- match = re.search(r"-\s*Scenario:\s*([a-z]+_[a-z]?\d+)", prompt_text, flags=re.IGNORECASE)
94
- if match:
95
- return match.group(1)
96
- return None
97
 
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  def reward_fn(prompts, completions, **kwargs):
100
  """
101
- GRPO reward function β€” calls /grader on live environment.
102
- Returns list of float rewards, one per completion.
103
- Score always between 0.001 and 0.999.
 
 
 
 
 
 
104
  """
105
  rewards = []
106
- task_ids = kwargs.get("task_ids")
107
- if not task_ids:
108
- # GRPO can pass dataset columns directly as kwargs, not always via batch.
109
- task_ids = kwargs.get("task_id")
110
- if not task_ids:
111
- task_ids = ["easy_s001"] * len(prompts)
112
- if isinstance(task_ids, str):
113
- task_ids = [task_ids] * len(prompts)
114
 
115
  for i, (prompt, completion) in enumerate(zip(prompts, completions)):
116
  try:
117
- # Get completion text
118
  if isinstance(completion, list):
119
  text = completion[0].get("content", "") if completion else ""
120
  else:
121
  text = str(completion)
122
 
123
- # Parse into action
 
 
 
124
  action = parse_action(text)
125
- task_id = task_ids[i] if i < len(task_ids) else "easy_s001"
126
- if not task_id:
127
- task_id = _extract_task_id_from_prompt(str(prompt)) or "easy_s001"
128
- task_id = str(task_id)
129
 
130
  if action is None:
 
131
  rewards.append(0.001)
132
- print(f" [REWARD] task={task_id} | action=parse_failed | score=0.001")
 
133
  continue
134
 
135
- # Use environment step reward so dense + milestone logic is used.
136
- # This also guarantees the sampled task_id actually drives reward.
137
- difficulty = "easy"
138
- if str(task_id).startswith("medium_"):
139
- difficulty = "medium"
140
- elif str(task_id).startswith("hard_"):
141
- difficulty = "hard"
142
-
143
- reset_resp = requests.post(
144
- f"{ENV_URL}/reset",
145
- json={"difficulty": difficulty, "task_id": task_id},
146
- timeout=15,
147
- headers={"Content-Type": "application/json"},
148
- )
149
- if reset_resp.status_code != 200:
150
- raise RuntimeError(f"/reset failed for {task_id}: {reset_resp.status_code}")
151
-
152
- step_resp = requests.post(
153
- f"{ENV_URL}/step",
154
- json=action,
155
- timeout=15,
156
- headers={"Content-Type": "application/json"},
157
- )
158
- if step_resp.status_code == 200:
159
- score = step_resp.json().get("reward", {}).get("score", 0.001)
160
- score = max(0.001, min(0.999, float(score)))
161
- else:
162
- # Fallback to grader for robustness.
163
- grader_resp = requests.post(
164
- f"{ENV_URL}/grader",
165
- json={"task_id": task_id, "action": action},
166
- timeout=15,
167
- headers={"Content-Type": "application/json"},
168
- )
169
- if grader_resp.status_code == 200:
170
- score = grader_resp.json().get("score", 0.001)
171
- score = max(0.001, min(0.999, float(score)))
172
- else:
173
- score = 0.001
174
-
175
- action_name = str(action.get("action_type", "unknown"))
176
-
177
  rewards.append(score)
178
- print(f" [REWARD] task={task_id} | action={action_name} | score={score:.3f}")
179
 
180
- except json.JSONDecodeError:
181
- rewards.append(0.001)
 
 
 
 
182
  except Exception as e:
183
  print(f" [REWARD] Error: {e}")
184
  rewards.append(0.001)
@@ -186,210 +230,113 @@ def reward_fn(prompts, completions, **kwargs):
186
  return rewards
187
 
188
 
189
- # ─────────────────────────────────────────────
190
- # BUILD TRAINING DATASET
191
- # ─────────────────────────────────────────────
192
-
193
- def build_dataset():
194
- """Build training examples from all 15 Round 2 scenarios."""
195
- scenarios = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
- for fname in [
198
- "dataset/easy_scenarios.json",
199
- "dataset/medium_scenarios.json",
200
- "dataset/hard_scenarios.json"
201
- ]:
202
- try:
203
- with open(fname) as f:
204
- data = json.load(f)
205
- scenarios.extend(data)
206
- print(f" Loaded {len(data)} scenarios from {fname}")
207
- except FileNotFoundError:
208
- print(f"{fname} not found, skipping")
209
 
210
- if not scenarios:
211
- print("No local scenarios found. Fetching from live environment...")
212
- try:
213
- resp = requests.get(f"{ENV_URL}/tasks", timeout=15)
214
- tasks = resp.json().get("tasks", [])
215
- scenarios = [{"id": t["id"], "description": t["description"]} for t in tasks]
216
- print(f" Fetched {len(scenarios)} tasks from HF Space")
217
- except Exception as e:
218
- print(f"Could not fetch tasks: {e}")
219
- sys.exit(1)
220
 
221
- examples = []
222
- for s in scenarios:
223
- prompt = f"""{SYSTEM_PROMPT}
 
 
224
 
225
- Current Database State:
226
- - Scenario: {s.get('id', 'unknown')}
227
- - Description: {s.get('description', '')}
228
- - Tables: {json.dumps(s.get('tables', []))}
229
- - Slow Queries: {json.dumps(s.get('slow_queries', []))}
230
- - Performance Score: {s.get('performance_score_baseline', 0)} / 100
231
- - Target Score: {s.get('target_score', 85)}
232
 
233
- What is your next action?"""
 
234
 
235
- examples.append({
236
- "prompt": prompt,
237
- "task_id": s.get("id", "easy_s001"),
238
- })
239
 
240
- diff_counts = {"easy": 0, "medium": 0, "hard": 0}
241
- for ex in examples:
242
- tid = ex["task_id"]
243
- if str(tid).startswith("medium_"):
244
- diff_counts["medium"] += 1
245
- elif str(tid).startswith("hard_"):
246
- diff_counts["hard"] += 1
247
- else:
248
- diff_counts["easy"] += 1
249
- print(f" Built {len(examples)} training examples total")
250
- print(f" Difficulty mix: easy={diff_counts['easy']} medium={diff_counts['medium']} hard={diff_counts['hard']}")
251
- from datasets import Dataset
252
- return Dataset.from_list(examples)
253
 
 
 
 
 
 
 
 
 
 
254
 
255
- # ─────────────────────────────────────────────
256
- # INFERENCE TEST β€” run immediately after save
257
- # ─────────────────────────────────────────────
 
258
 
259
- def test_inference(model, tokenizer):
260
- """
261
- REQUIRED: Test inference immediately after saving.
262
- If this fails, the model was not saved correctly.
263
- """
264
- print("\n[INFERENCE TEST] Testing saved model...")
265
- try:
266
- FastLanguageModel.for_inference(model)
267
-
268
- test_prompt = f"""{SYSTEM_PROMPT}
269
-
270
- Current Database State:
271
- - Scenario: easy_s001
272
- - Description: User lookup query taking 2s on 10K users table
273
- - Tables: [{{"name": "users", "rows": 10000, "indexes": ["PRIMARY"]}}]
274
- - Slow Queries: [{{"id": "q1", "sql": "SELECT * FROM users WHERE email=?", "avg_ms": 2000}}]
275
- - Performance Score: 8.0 / 100
276
- - Target Score: 80.0
277
-
278
- What is your next action?"""
279
-
280
- inputs = tokenizer(
281
- test_prompt,
282
- return_tensors="pt",
283
- truncation=True,
284
- max_length=1024
285
- ).to(model.device)
286
-
287
- with torch.no_grad():
288
- outputs = model.generate(
289
- **inputs,
290
- max_new_tokens = 100,
291
- temperature = 0.3,
292
- do_sample = True,
293
- pad_token_id = tokenizer.eos_token_id,
294
- )
295
-
296
- response = tokenizer.decode(
297
- outputs[0][inputs["input_ids"].shape[1]:],
298
- skip_special_tokens=True
299
- ).strip()
300
-
301
- print(f"[INFERENCE TEST] Model output:\n {response}")
302
-
303
- # Validate output
304
- action = parse_action(response)
305
- print(f"[INFERENCE TEST] Parsed action: {action}")
306
- print("[INFERENCE TEST] PASSED β€” model saved correctly!")
307
- return True
308
-
309
- except Exception as e:
310
- print(f"[INFERENCE TEST] FAILED: {e}")
311
- print("[INFERENCE TEST] Check model save path. Do NOT proceed without fixing this.")
312
- return False
313
-
314
-
315
- # ─────────────────────────────────────────────
316
- # MAIN TRAINING
317
- # ─────────────────────────────────────────────
318
 
 
319
  def train():
320
- if not UNSLOTH_AVAILABLE:
321
- print(" Cannot train β€” Unsloth not installed or no GPU found")
322
- print("Run: pip install unsloth trl transformers datasets accelerate")
323
- return
324
-
325
- print(f"\n Loading model: {MODEL_NAME}")
326
- print(f" Environment: {ENV_URL}\n")
327
-
328
- # Verify environment is reachable
329
- try:
330
- r = requests.get(f"{ENV_URL}/health", timeout=10)
331
- version = r.json().get("version", "?")
332
- print(f" Environment reachable β€” version {version}")
333
- except Exception as e:
334
- print(f" Cannot reach environment at {ENV_URL}: {e}")
335
- print("Check ENV_URL and make sure HF Space is running.")
336
  sys.exit(1)
337
 
338
- # ── Load model ───────────────────────────────────────────
339
  model, tokenizer = FastLanguageModel.from_pretrained(
340
  model_name = MODEL_NAME,
341
  max_seq_length = 2048,
342
- load_in_4bit = True, # QLoRA β€” required for T4
343
- dtype = None, # Auto detect
344
  token = HF_TOKEN or None,
345
  )
346
- print(" Model loaded")
347
-
348
- # ── Apply LoRA adapters ──────────────────────────────────
349
  model = FastLanguageModel.get_peft_model(
350
  model,
351
- r = 16,
352
- lora_alpha = 16,
353
- target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
354
- "gate_proj", "up_proj", "down_proj"],
355
- lora_dropout = 0,
356
- bias = "none",
357
- use_gradient_checkpointing = "unsloth",
358
- random_state = 42,
359
  )
360
- print(" LoRA adapters applied")
361
 
362
- # ── Build dataset ────────────────────────────────────────
363
- print("\n[DATASET] Building training dataset...")
364
  dataset = build_dataset()
365
- print(f" Dataset ready: {len(dataset)} examples")
366
-
367
- # ── Reward wrapper ───────────────────────────────────────
368
- def reward_wrapper(prompts, completions, **kwargs):
369
- batch = kwargs.get("batch", [])
370
- if batch and hasattr(batch[0], "get"):
371
- task_ids = [b.get("task_id", "easy_s001") for b in batch]
372
- elif "task_id" in kwargs and kwargs["task_id"]:
373
- task_ids = kwargs["task_id"]
374
- else:
375
- task_ids = ["easy_s001"] * len(prompts)
376
- return reward_fn(prompts, completions, task_ids=task_ids)
377
-
378
- # ── GRPO config ──────────────────────────────────────────
379
- # NOTE: batch_size=1, num_generations=2 for free T4
380
- # At venue A100: increase to batch_size=2, num_generations=4
381
  config = GRPOConfig(
382
  output_dir = OUTPUT_DIR,
383
  max_steps = MAX_STEPS,
384
- per_device_train_batch_size = 1, # 1 for T4, 2 for A100
385
- gradient_accumulation_steps = 8,
386
  learning_rate = 5e-6,
387
- max_completion_length = 256,
388
- num_generations = 2, # 2 for T4, 4 for A100
389
- temperature = 0.8,
390
- logging_steps = 5,
391
- save_steps = 50,
392
- save_total_limit = 2,
393
  warmup_ratio = 0.1,
394
  report_to = "none",
395
  remove_unused_columns = False,
@@ -398,56 +345,28 @@ def train():
398
  trainer = GRPOTrainer(
399
  model = model,
400
  tokenizer = tokenizer,
401
- reward_funcs = reward_wrapper,
402
  args = config,
403
  train_dataset = dataset,
404
  )
405
 
406
- # ── Train ────────────────────────────────────────────────
407
- print(f"\nπŸ‹οΈ Starting GRPO training β€” {MAX_STEPS} steps...")
408
- print("Watch the 'reward' column β€” it should increase over time.\n")
409
  trainer.train()
410
- print("\n Training complete!")
411
 
412
- # ── Save β€” ADAPTER ONLY (correct way for QLoRA) ──────────
413
- # DO NOT call merge_and_unload() on 4-bit model
414
- # DO NOT upcast to 16-bit and merge naively
415
- # CORRECT: save adapter weights only, load with from_pretrained later
416
- print(f"\n[SAVE] Saving adapter to {OUTPUT_DIR}/final ...")
417
  Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
418
  model.save_pretrained(f"{OUTPUT_DIR}/final")
419
  tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")
 
 
 
420
 
421
- # Save config for reference
422
- with open(f"{OUTPUT_DIR}/final/training_config.json", "w") as f:
423
- json.dump({
424
- "model_name": MODEL_NAME,
425
- "max_steps": MAX_STEPS,
426
- "save_method": "adapter_only_qlora",
427
- "lora_r": 16,
428
- "lora_alpha": 16,
429
- }, f, indent=2)
430
- print(f" Adapter saved to {OUTPUT_DIR}/final")
431
-
432
- # ── IMMEDIATE inference test (required) ──────────────────
433
- passed = test_inference(model, tokenizer)
434
-
435
- # ── Summary ──────────────────────────────────────────────
436
- print("\n" + "="*60)
437
- print("TRAINING COMPLETE")
438
- print("="*60)
439
- print(f" Model: {MODEL_NAME}")
440
- print(f" Steps: {MAX_STEPS}")
441
- print(f" Saved to: {OUTPUT_DIR}/final")
442
- print(f" Save method: Adapter only (QLoRA safe)")
443
- print(f" Inference test: {' PASSED' if passed else ' FAILED'}")
444
- print("="*60)
445
- print("\nNext steps:")
446
- print(" 1. python training/evaluate_agent.py")
447
- print(" 2. Open reward_curve.png β€” show to judges")
448
- print(" 3. git add reward_curve.png && git commit && git push")
449
- print("="*60)
450
 
451
 
452
  if __name__ == "__main__":
453
- train()
 
1
  """
2
  training/train_agent.py — SQL Database Engineer Agent
3
+ FIXED: Uses local DatabaseSimulator for rewards (no HF Space calls)
4
+ - No shared singleton state
5
+ - Real delta rewards (0.0 for wrong actions, 40-75pts for correct)
6
+ - Clear reward difference teaches model to prefer create_index over inspect_query
 
7
  """
8
 
9
+ import os, json, sys, time
 
 
 
 
10
  from pathlib import Path
11
 
12
+ # ── GPU check ─────────────────────────────────────────────────
13
+ UNSLOTH_AVAILABLE = False
14
  try:
15
+ import torch
16
+ if not torch.cuda.is_available():
17
+ print("❌ No GPU. Unsloth requires CUDA GPU.")
18
+ sys.exit(1)
19
  from unsloth import FastLanguageModel
20
  from trl import GRPOTrainer, GRPOConfig
21
+ from datasets import Dataset
22
  UNSLOTH_AVAILABLE = True
23
+ print(f"βœ… GPU: {torch.cuda.get_device_name(0)}")
24
+ print(f"βœ… VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")
25
+ except ImportError as e:
26
+ print(f"❌ {e}\nRun: pip install unsloth trl transformers datasets accelerate")
27
+ sys.exit(1)
28
 
29
+ # Add project root so we can import DatabaseSimulator
30
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
31
+ from env.db_simulator import DatabaseSimulator
32
 
33
+ # ── Config ────────────────────────────────────────────────────
34
  ENV_URL = os.getenv("ENV_URL", "https://junaid0600-sql-db-engineer-agent.hf.space")
35
  HF_TOKEN = os.getenv("HF_TOKEN", "")
36
+ MODEL_NAME = os.getenv("MODEL_NAME", "unsloth/Qwen2.5-1.5B-Instruct")
37
  OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./sdea-trained")
38
+ MAX_STEPS = int(os.getenv("MAX_STEPS", "100"))
39
+
40
+ print(f"\n[CONFIG] Model: {MODEL_NAME}")
41
+ print(f"[CONFIG] Max steps: {MAX_STEPS}")
42
+ print(f"[CONFIG] Output: {OUTPUT_DIR}\n")
43
 
44
+ # ── System prompt ─────────────────────────────────────────────
45
+ SYSTEM_PROMPT = """You are a senior database engineer fixing slow database queries.
46
+ You will see slow queries and table structures. Choose the BEST action.
 
47
 
48
+ Key insight: create_index on the RIGHT columns fixes slow queries.
49
+ Wrong columns = no improvement. Right columns = massive improvement.
 
50
 
51
+ Respond with ONLY valid JSON:
52
+ {"action_type": "create_index", "payload": {"table": "TABLE_NAME", "columns": ["COL1", "COL2"]}}
 
 
 
 
 
 
 
53
 
54
+ Available actions: inspect_query, analyze_indexes, create_index, rewrite_query, analyze_statistics, submit_report"""
 
55
 
56
+ # ── Load all 15 scenarios ─────────────────────────────────────
57
def load_all_scenarios() -> list:
    """
    Load every scenario from the project's dataset/ folder.

    Reads easy_scenarios.json, medium_scenarios.json and hard_scenarios.json
    from the "dataset" directory that sits next to this script's parent
    directory, and concatenates their contents in that order.

    A missing file is reported and skipped, so the result may be empty.

    Returns:
        A list with the entries of every scenario file that was found.
    """
    scenarios = []
    base = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    for fname in ["easy_scenarios.json", "medium_scenarios.json", "hard_scenarios.json"]:
        path = os.path.join(base, "dataset", fname)
        try:
            # Pin the encoding: the platform default (e.g. cp1252 on Windows)
            # can mis-decode or reject UTF-8 scenario files.
            with open(path, encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            print(f" ⚠️ {fname} not found")
            continue
        scenarios.extend(data)
        print(f" ✅ Loaded {len(data)} from {fname}")
    print(f" Total: {len(scenarios)} scenarios\n")
    return scenarios
71
 
72
+ ALL_SCENARIOS = load_all_scenarios()
 
 
73
 
74
+ # ── Parse LLM output ─────────────────────────────────────────
75
def parse_action(text: str) -> "dict | None":
    """
    Parse raw LLM output into an action dict.

    Accepts either bare JSON or JSON wrapped in a Markdown code fence
    (with or without a "json" language tag, with or without a closing
    fence, and with optional prose around the fence).

    Args:
        text: The raw completion text produced by the model.

    Returns:
        The decoded JSON object when it is a dict containing both
        "action_type" and "payload"; otherwise None, which callers treat
        as an invalid (penalized) completion.
    """
    try:
        candidate = text.strip()
        if "```" in candidate:
            # Keep only what sits between the first and second fence.
            # Splitting once on the bare fence handles "```json" and "```"
            # alike; the previous two-pass split discarded the JSON of a
            # properly closed "```json" block and kept the text after it.
            candidate = candidate.split("```")[1]
            if candidate[:4].lower() == "json":
                candidate = candidate[4:]
        data = json.loads(candidate.strip())
    except (AttributeError, TypeError, ValueError):
        # Non-string input or malformed JSON: not a usable action.
        return None

    # Only a JSON object can be an action; a JSON string or list could
    # otherwise pass the membership checks below by accident.
    if isinstance(data, dict) and "action_type" in data and "payload" in data:
        return data
    return None  # None = invalid output = penalized
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
 
92
+ # ── LOCAL reward function using DatabaseSimulator ─────────────
93
def compute_local_reward(action: dict, scenario: dict) -> tuple:
    """
    Compute the reward for one action locally using DatabaseSimulator.

    A new DatabaseSimulator is built from `scenario` on every call, so no
    state is shared between calls and no HTTP request is made.

    Args:
        action: Parsed action with "action_type" and "payload" keys.
        scenario: Scenario definition passed to DatabaseSimulator.

    Returns:
        A tuple (reward, improvement, milestone_bonus):
          - reward: total reward, clamped to [0.001, 0.999];
          - improvement: performance-score gain over the baseline,
            floored at 0.0;
          - milestone_bonus: the bonus component included in the reward.
    """
    # Actions forwarded to the simulator. Every other action type
    # (including unknown ones) leaves the simulator untouched.
    simulated_actions = (
        "create_index",
        "rewrite_query",
        "analyze_statistics",
        "partition_table",
    )
    # Flat per-action reward; unknown action types fall back to 0.001.
    step_rewards = {
        "inspect_query": 0.10,
        "analyze_indexes": 0.10,
        "create_index": 0.15,
        "rewrite_query": 0.20,
        "analyze_statistics": 0.08,
        "partition_table": 0.15,
        "submit_report": 0.05,
    }

    sim = DatabaseSimulator(scenario)
    baseline = sim.get_performance_score()

    action_type = action.get("action_type", "")
    payload = action.get("payload", {})

    # The simulator-reported delta is only consulted for the wrong-index
    # penalty below; the reward itself uses the measured score change.
    reported_delta = 0.0
    if action_type in simulated_actions:
        result = sim.apply_action(action_type, payload)
        reported_delta = result.get("delta", 0.0)

    improvement = max(0.0, sim.get_performance_score() - baseline)
    max_possible = max(1.0, 100.0 - baseline)  # never divide by zero
    ratio = improvement / max_possible

    # 1. Step reward: depends only on the action type.
    step_r = step_rewards.get(action_type, 0.001)

    # 2. Delta reward: proportional to the achieved share of the headroom.
    delta_r = min(0.70, ratio * 0.70)

    # 3. Milestone bonus: only the highest tier reached applies.
    if ratio >= 0.75:
        milestone_r = 0.40
    elif ratio >= 0.50:
        milestone_r = 0.25
    elif ratio >= 0.25:
        milestone_r = 0.15
    else:
        milestone_r = 0.0

    # 4. Penalty for an index the simulator reports as useless.
    wrong_index_pen = 0.0
    if action_type == "create_index" and reported_delta <= 0.0:
        wrong_index_pen = -0.15

    total = step_r + delta_r + milestone_r + wrong_index_pen
    total = max(0.001, min(0.999, total))

    return total, improvement, milestone_r
178
+
179
+
180
+ # ── GRPO reward function ──────────────────────────────────────
181
def reward_fn(prompts, completions, **kwargs):
    """
    GRPO reward function: score every completion locally, with no HTTP calls.

    Each completion is parsed with parse_action() and scored by
    compute_local_reward() against the scenario its prompt was built from.

    The scenario is resolved in this order:
      1. the "scenario_id" dataset column, when it is passed in as a
         keyword argument;
      2. the "Scenario: <id>" line that build_dataset() writes into the
         prompt;
      3. positional rotation through ALL_SCENARIOS (the previous
         behaviour), kept only as a last resort.

    Args:
        prompts: Prompts of the batch, aligned with `completions`.
        completions: Model outputs; each is either a string or a list of
            message dicts whose first element carries a "content" key.
        **kwargs: Extra keyword arguments; only "scenario_id" is read.

    Returns:
        A list of floats, one per (prompt, completion) pair. Unparseable
        completions and any error while scoring yield 0.001.
    """
    rewards = []

    scenarios_by_id = {
        s.get("id"): s for s in ALL_SCENARIOS if isinstance(s, dict)
    }
    scenario_ids = kwargs.get("scenario_id")
    if isinstance(scenario_ids, str):
        # A single id applies to the whole batch.
        scenario_ids = [scenario_ids] * len(completions)

    for i, (prompt, completion) in enumerate(zip(prompts, completions)):
        try:
            # Get text
            if isinstance(completion, list):
                text = completion[0].get("content", "") if completion else ""
            else:
                text = str(completion)

            # Resolve the scenario this prompt belongs to. Scoring against
            # a scenario picked by batch position would grade the action
            # on a database the model was never shown.
            scenario = None
            if scenario_ids is not None and i < len(scenario_ids):
                scenario = scenarios_by_id.get(scenario_ids[i])
            if scenario is None and isinstance(prompt, str) and "Scenario: " in prompt:
                prompt_id = prompt.split("Scenario: ", 1)[1].split("\n", 1)[0].strip()
                scenario = scenarios_by_id.get(prompt_id)
            if scenario is None:
                scenario = ALL_SCENARIOS[i % len(ALL_SCENARIOS)]

            # Parse action
            action = parse_action(text)

            if action is None:
                # Invalid JSON output: penalize
                rewards.append(0.001)
                print(f" [REWARD] scenario={scenario['id']} | "
                      f"INVALID JSON | score=0.001")
                continue

            # Compute reward locally
            score, improvement, milestone = compute_local_reward(action, scenario)
            rewards.append(score)

            print(f" [REWARD] scenario={scenario['id']} | "
                  f"action={action.get('action_type')} | "
                  f"improvement=+{improvement:.1f}pts | "
                  f"milestone=+{milestone:.2f} | "
                  f"score={score:.3f}")

        except Exception as e:
            # Best effort: one bad completion must not abort the batch.
            print(f" [REWARD] Error: {e}")
            rewards.append(0.001)

    return rewards
231
 
232
 
233
+ # ── Build dataset ─────────────────────────────────────────────
234
+ def build_dataset() -> Dataset:
235
+ examples = []
236
+ for i, s in enumerate(ALL_SCENARIOS):
237
+ tables_str = json.dumps(s.get("tables", []))
238
+ queries_str = json.dumps(s.get("slow_queries", []))
239
+ hints_str = json.dumps(s.get("missing_index_hints", []))
240
+
241
+ prompt = (
242
+ f"{SYSTEM_PROMPT}\n\n"
243
+ f"=== DATABASE STATE ===\n"
244
+ f"Scenario: {s['id']}\n"
245
+ f"Description: {s.get('description','')}\n"
246
+ f"Tables: {tables_str}\n"
247
+ f"Slow Queries: {queries_str}\n"
248
+ f"Missing Index Hints: {hints_str}\n"
249
+ f"Performance: {s.get('performance_score_baseline',0)}/100 "
250
+ f"β†’ Target: {s.get('target_score',85)}/100\n\n"
251
+ f"What action should you take? Output JSON only:"
252
+ )
253
+ examples.append({
254
+ "prompt": prompt,
255
+ "scenario_id": s["id"],
256
+ })
257
 
258
+ print(f" βœ… Dataset: {len(examples)} examples")
259
+ return Dataset.from_list(examples)
 
 
 
 
 
 
 
 
 
 
260
 
 
 
 
 
 
 
 
 
 
 
261
 
262
+ # ── Generate plots ────────────────────────────────────────────
263
+ def generate_plots(trainer):
264
+ import matplotlib
265
+ matplotlib.use("Agg")
266
+ import matplotlib.pyplot as plt
267
 
268
+ logs = [l for l in trainer.state.log_history if "loss" in l]
269
+ if not logs:
270
+ print("⚠️ No logs for plotting")
271
+ return
 
 
 
272
 
273
+ steps = [l.get("step", i) for i,l in enumerate(logs)]
274
+ losses = [l.get("loss", 0) for l in logs]
275
 
276
+ fig, ax = plt.subplots(1, 1, figsize=(8, 4))
277
+ fig.suptitle("GRPO Training β€” SQL Database Engineer Agent",
278
+ fontsize=13, fontweight="bold")
 
279
 
280
+ ax.plot(steps, losses, "b-o", lw=2, ms=4)
281
+ ax.set_xlabel("Training Step")
282
+ ax.set_ylabel("Loss")
283
+ ax.set_title("Training Loss (↓ = model learning DBA pattern)")
284
+ ax.grid(True, alpha=0.3)
 
 
 
 
 
 
 
 
285
 
286
+ if losses:
287
+ ax.annotate(f"Start: {losses[0]:.4f}",
288
+ xy=(steps[0], losses[0]),
289
+ xytext=(steps[0]+1, losses[0]*1.1),
290
+ fontsize=9, color="red")
291
+ ax.annotate(f"End: {losses[-1]:.4f}",
292
+ xy=(steps[-1], losses[-1]),
293
+ xytext=(steps[-1]-5, losses[-1]*1.1),
294
+ fontsize=9, color="green")
295
 
296
+ plt.tight_layout()
297
+ plt.savefig("loss_curve.png", dpi=150, bbox_inches="tight")
298
+ print("βœ… loss_curve.png saved")
299
+ print(f" Loss: {losses[0]:.4f} β†’ {losses[-1]:.4f}")
300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
+ # ── Main ──────────────────────────────────────────────────────
303
  def train():
304
+ if not ALL_SCENARIOS:
305
+ print("❌ No scenarios found. Check dataset/ folder.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  sys.exit(1)
307
 
308
+ print(f"⏳ Loading {MODEL_NAME}...")
309
  model, tokenizer = FastLanguageModel.from_pretrained(
310
  model_name = MODEL_NAME,
311
  max_seq_length = 2048,
312
+ load_in_4bit = True,
 
313
  token = HF_TOKEN or None,
314
  )
 
 
 
315
  model = FastLanguageModel.get_peft_model(
316
  model,
317
+ r=16, lora_alpha=16,
318
+ target_modules=["q_proj","k_proj","v_proj","o_proj",
319
+ "gate_proj","up_proj","down_proj"],
320
+ lora_dropout=0, bias="none",
321
+ use_gradient_checkpointing="unsloth",
322
+ random_state=42,
 
 
323
  )
324
+ print("βœ… Model ready\n")
325
 
 
 
326
  dataset = build_dataset()
327
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  config = GRPOConfig(
329
  output_dir = OUTPUT_DIR,
330
  max_steps = MAX_STEPS,
331
+ per_device_train_batch_size = 1,
332
+ gradient_accumulation_steps = 4,
333
  learning_rate = 5e-6,
334
+ max_completion_length = 150,
335
+ num_generations = 4, # compare 4 actions per step
336
+ temperature = 1.0,
337
+ logging_steps = 1,
338
+ save_steps = 25,
339
+ save_total_limit = 3,
340
  warmup_ratio = 0.1,
341
  report_to = "none",
342
  remove_unused_columns = False,
 
345
  trainer = GRPOTrainer(
346
  model = model,
347
  tokenizer = tokenizer,
348
+ reward_funcs = reward_fn,
349
  args = config,
350
  train_dataset = dataset,
351
  )
352
 
353
+ print(f"πŸ‹οΈ GRPO training β€” {MAX_STEPS} steps")
354
+ print("Watch for: improvement > 0 and score > 0.5 on create_index\n")
 
355
  trainer.train()
356
+ print("\nβœ… Training complete!")
357
 
 
 
 
 
 
358
  Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
359
  model.save_pretrained(f"{OUTPUT_DIR}/final")
360
  tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")
361
+ print(f"βœ… Saved to {OUTPUT_DIR}/final")
362
+
363
+ generate_plots(trainer)
364
 
365
+ print("\n" + "="*50)
366
+ print("NEXT: python training/evaluate_agent.py")
367
+ print("THEN: git add loss_curve.png reward_curve.png")
368
+ print("="*50)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
 
370
 
371
  if __name__ == "__main__":
372
+ train()