junaid0600 committed on
Commit
4edd88e
·
verified ·
1 Parent(s): 028dbb9

Update training/train_agent.py

Browse files
Files changed (1) hide show
  1. training/train_agent.py +286 -285
training/train_agent.py CHANGED
@@ -1,372 +1,373 @@
1
  """
2
  training/train_agent.py — SQL Database Engineer Agent
3
- FIXED: Uses local DatabaseSimulator for rewards (no HF Space calls)
4
- - No shared singleton state
5
- - Real delta rewards (0.0 for wrong actions, 40-75pts for correct)
6
- - Clear reward difference teaches model to prefer create_index over inspect_query
 
 
 
 
7
  """
8
 
9
- import os, json, sys, time
10
- from pathlib import Path
 
 
 
11
 
12
- # ── GPU check ─────────────────────────────────────────────────
13
- UNSLOTH_AVAILABLE = False
14
  try:
15
- import torch
16
- if not torch.cuda.is_available():
17
- print("❌ No GPU. Unsloth requires CUDA GPU.")
18
- sys.exit(1)
19
  from unsloth import FastLanguageModel
20
  from trl import GRPOTrainer, GRPOConfig
21
- from datasets import Dataset
22
  UNSLOTH_AVAILABLE = True
23
- print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
24
- print(f"✅ VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")
25
- except ImportError as e:
26
- print(f"❌ {e}\nRun: pip install unsloth trl transformers datasets accelerate")
27
- sys.exit(1)
28
 
29
- # Add project root so we can import DatabaseSimulator
30
- sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
31
- from env.db_simulator import DatabaseSimulator
32
 
33
- # ── Config ────────────────────────────────────────────────────
34
  ENV_URL = os.getenv("ENV_URL", "https://junaid0600-sql-db-engineer-agent.hf.space")
35
  HF_TOKEN = os.getenv("HF_TOKEN", "")
36
- MODEL_NAME = os.getenv("MODEL_NAME", "unsloth/Qwen2.5-1.5B-Instruct")
37
  OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./sdea-trained")
38
- MAX_STEPS = int(os.getenv("MAX_STEPS", "100"))
39
 
40
- print(f"\n[CONFIG] Model: {MODEL_NAME}")
41
- print(f"[CONFIG] Max steps: {MAX_STEPS}")
42
- print(f"[CONFIG] Output: {OUTPUT_DIR}\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- # ── System prompt ─────────────────────────────────────────────
45
- SYSTEM_PROMPT = """You are a senior database engineer fixing slow database queries.
46
- You will see slow queries and table structures. Choose the BEST action.
47
 
48
- Key insight: create_index on the RIGHT columns fixes slow queries.
49
- Wrong columns = no improvement. Right columns = massive improvement.
 
 
 
 
 
50
 
51
- Respond with ONLY valid JSON:
52
- {"action_type": "create_index", "payload": {"table": "TABLE_NAME", "columns": ["COL1", "COL2"]}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- Available actions: inspect_query, analyze_indexes, create_index, rewrite_query, analyze_statistics, submit_report"""
55
 
56
- # ── Load all 15 scenarios ─────────────────────────────────────
57
- def load_all_scenarios() -> list:
58
- scenarios = []
59
- base = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
60
- for fname in ["easy_scenarios.json", "medium_scenarios.json", "hard_scenarios.json"]:
61
- path = os.path.join(base, "dataset", fname)
62
- try:
63
- with open(path) as f:
64
- data = json.load(f)
65
- scenarios.extend(data)
66
- print(f" ✅ Loaded {len(data)} from {fname}")
67
- except FileNotFoundError:
68
- print(f" ⚠️ {fname} not found")
69
- print(f" Total: {len(scenarios)} scenarios\n")
70
- return scenarios
71
 
72
- ALL_SCENARIOS = load_all_scenarios()
 
 
 
 
 
 
 
 
 
 
73
 
74
- # ── Parse LLM output ─────────────────────────────────────────
75
- def parse_action(text: str) -> dict:
76
- """Parse LLM output into action dict."""
77
- try:
78
- text = text.strip()
79
- for marker in ["```json", "```"]:
80
- if marker in text:
81
- parts = text.split(marker)
82
- text = parts[1] if len(parts) > 1 else parts[0]
83
- text = text.strip()
84
- data = json.loads(text)
85
- if "action_type" in data and "payload" in data:
86
- return data
87
- except Exception:
88
- pass
89
- return None # None = invalid JSON = penalized
90
 
 
 
 
91
 
92
- # ── LOCAL reward function using DatabaseSimulator ─────────────
93
- def compute_local_reward(action: dict, scenario: dict) -> tuple:
94
- """
95
- Compute reward LOCALLY using DatabaseSimulator.
96
- No HF Space calls. No shared state. Clean every time.
97
-
98
- Returns (reward_score, db_delta, milestone_bonus)
99
- """
100
- sim = DatabaseSimulator(scenario)
101
- baseline = sim.get_performance_score()
102
- hints = scenario.get("missing_index_hints", [])
103
-
104
- action_type = action.get("action_type", "")
105
- payload = action.get("payload", {})
106
-
107
- # Apply action to simulator
108
- if action_type == "create_index":
109
- result = sim.apply_action("create_index", payload)
110
- delta = result.get("delta", 0.0)
111
-
112
- elif action_type == "inspect_query":
113
- # Investigation — small reward, no DB change
114
- delta = 0.0
115
-
116
- elif action_type == "analyze_indexes":
117
- delta = 0.0
118
-
119
- elif action_type == "rewrite_query":
120
- result = sim.apply_action("rewrite_query", payload)
121
- delta = result.get("delta", 0.0)
122
-
123
- elif action_type == "analyze_statistics":
124
- result = sim.apply_action("analyze_statistics", payload)
125
- delta = result.get("delta", 0.0)
126
-
127
- elif action_type == "partition_table":
128
- result = sim.apply_action("partition_table", payload)
129
- delta = result.get("delta", 0.0)
130
-
131
- elif action_type == "submit_report":
132
- # Terminal: score based on how much DB improved so far
133
- final = sim.get_performance_score()
134
- improvement = max(0, final - baseline)
135
- delta = improvement
136
-
137
- else:
138
- delta = -5.0 # Unknown action = penalty
139
-
140
- final_score = sim.get_performance_score()
141
- improvement = max(0.0, final_score - baseline)
142
- max_possible = max(1.0, 100.0 - baseline)
143
-
144
- # ── Reward components ─────────────────────────────────────
145
- # 1. Step reward — different per action type
146
- step_rewards = {
147
- "inspect_query": 0.10,
148
- "analyze_indexes": 0.10,
149
- "create_index": 0.15,
150
- "rewrite_query": 0.20,
151
- "analyze_statistics":0.08,
152
- "partition_table": 0.15,
153
- "submit_report": 0.05,
154
- }
155
- step_r = step_rewards.get(action_type, 0.001)
156
-
157
- # 2. Delta reward — proportional to actual improvement
158
- delta_r = min(0.70, (improvement / max_possible) * 0.70)
159
-
160
- # 3. Milestone bonus — one-time for big improvements
161
- milestone_r = 0.0
162
- if improvement / max_possible >= 0.75:
163
- milestone_r = 0.40
164
- elif improvement / max_possible >= 0.50:
165
- milestone_r = 0.25
166
- elif improvement / max_possible >= 0.25:
167
- milestone_r = 0.15
168
-
169
- # 4. Penalty for wrong index (delta=0 on create_index)
170
- wrong_index_pen = 0.0
171
- if action_type == "create_index" and delta <= 0.0:
172
- wrong_index_pen = -0.15 # created useless index
173
-
174
- total = step_r + delta_r + milestone_r + wrong_index_pen
175
- total = max(0.001, min(0.999, total))
176
-
177
- return total, improvement, milestone_r
178
-
179
-
180
- # ── GRPO reward function ──────────────────────────────────────
181
  def reward_fn(prompts, completions, **kwargs):
182
  """
183
- LOCAL reward — no HTTP calls, no shared state.
184
- Each completion gets its own fresh DatabaseSimulator.
185
-
186
- Reward differences:
187
- inspect_query (always): 0.10 + 0.0 = 0.10
188
- create_index (wrong col): 0.15 - 0.15 = 0.001
189
- create_index (right col): 0.15 + 0.60 = 0.75+
190
-
191
- GRPO will learn: right create_index >> inspect_query >> wrong create_index
192
  """
193
- rewards = []
 
 
 
 
 
 
 
194
 
195
  for i, (prompt, completion) in enumerate(zip(prompts, completions)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  try:
197
- # Get text
198
- if isinstance(completion, list):
199
- text = completion[0].get("content", "") if completion else ""
200
- else:
201
- text = str(completion)
202
-
203
- # Pick scenario (rotate through all)
204
- scenario = ALL_SCENARIOS[i % len(ALL_SCENARIOS)]
205
-
206
- # Parse action
207
- action = parse_action(text)
208
-
209
- if action is None:
210
- # Invalid JSON output — penalize
211
- rewards.append(0.001)
212
- print(f" [REWARD] scenario={scenario['id']} | "
213
- f"INVALID JSON | score=0.001")
214
- continue
215
-
216
- # Compute reward locally
217
- score, improvement, milestone = compute_local_reward(action, scenario)
218
  rewards.append(score)
219
 
220
- print(f" [REWARD] scenario={scenario['id']} | "
221
- f"action={action.get('action_type')} | "
222
- f"improvement=+{improvement:.1f}pts | "
223
- f"milestone=+{milestone:.2f} | "
224
- f"score={score:.3f}")
225
-
226
  except Exception as e:
227
- print(f" [REWARD] Error: {e}")
228
  rewards.append(0.001)
229
 
230
  return rewards
231
 
232
 
233
- # ── Build dataset ─────────────────────────────────────────────
234
- def build_dataset() -> Dataset:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  examples = []
236
- for i, s in enumerate(ALL_SCENARIOS):
237
- tables_str = json.dumps(s.get("tables", []))
238
- queries_str = json.dumps(s.get("slow_queries", []))
239
- hints_str = json.dumps(s.get("missing_index_hints", []))
 
 
240
 
241
  prompt = (
242
  f"{SYSTEM_PROMPT}\n\n"
243
- f"=== DATABASE STATE ===\n"
244
- f"Scenario: {s['id']}\n"
245
- f"Description: {s.get('description','')}\n"
246
- f"Tables: {tables_str}\n"
247
- f"Slow Queries: {queries_str}\n"
248
- f"Missing Index Hints: {hints_str}\n"
249
- f"Performance: {s.get('performance_score_baseline',0)}/100 "
250
- f"→ Target: {s.get('target_score',85)}/100\n\n"
251
- f"What action should you take? Output JSON only:"
252
  )
 
253
  examples.append({
254
- "prompt": prompt,
255
- "scenario_id": s["id"],
256
  })
257
 
258
- print(f" ✅ Dataset: {len(examples)} examples")
259
  return Dataset.from_list(examples)
260
 
261
 
262
- # ── Generate plots ────────────────────────────────────────────
263
- def generate_plots(trainer):
264
- import matplotlib
265
- matplotlib.use("Agg")
266
- import matplotlib.pyplot as plt
267
-
268
- logs = [l for l in trainer.state.log_history if "loss" in l]
269
- if not logs:
270
- print("⚠️ No logs for plotting")
271
- return
272
-
273
- steps = [l.get("step", i) for i,l in enumerate(logs)]
274
- losses = [l.get("loss", 0) for l in logs]
275
 
276
- fig, ax = plt.subplots(1, 1, figsize=(8, 4))
277
- fig.suptitle("GRPO Training — SQL Database Engineer Agent",
278
- fontsize=13, fontweight="bold")
 
 
 
 
279
 
280
- ax.plot(steps, losses, "b-o", lw=2, ms=4)
281
- ax.set_xlabel("Training Step")
282
- ax.set_ylabel("Loss")
283
- ax.set_title("Training Loss (↓ = model learning DBA pattern)")
284
- ax.grid(True, alpha=0.3)
285
 
286
- if losses:
287
- ax.annotate(f"Start: {losses[0]:.4f}",
288
- xy=(steps[0], losses[0]),
289
- xytext=(steps[0]+1, losses[0]*1.1),
290
- fontsize=9, color="red")
291
- ax.annotate(f"End: {losses[-1]:.4f}",
292
- xy=(steps[-1], losses[-1]),
293
- xytext=(steps[-1]-5, losses[-1]*1.1),
294
- fontsize=9, color="green")
295
 
296
- plt.tight_layout()
297
- plt.savefig("loss_curve.png", dpi=150, bbox_inches="tight")
298
- print("✅ loss_curve.png saved")
299
- print(f" Loss: {losses[0]:.4f} → {losses[-1]:.4f}")
 
300
 
 
 
301
 
302
- # ── Main ──────────────────────────────────────────────────────
303
- def train():
304
- if not ALL_SCENARIOS:
305
- print("❌ No scenarios found. Check dataset/ folder.")
306
- sys.exit(1)
 
 
307
 
308
- print(f"⏳ Loading {MODEL_NAME}...")
309
  model, tokenizer = FastLanguageModel.from_pretrained(
310
  model_name = MODEL_NAME,
311
- max_seq_length = 2048,
312
  load_in_4bit = True,
313
  token = HF_TOKEN or None,
314
  )
 
 
315
  model = FastLanguageModel.get_peft_model(
316
  model,
317
- r=16, lora_alpha=16,
318
- target_modules=["q_proj","k_proj","v_proj","o_proj",
319
- "gate_proj","up_proj","down_proj"],
320
- lora_dropout=0, bias="none",
321
- use_gradient_checkpointing="unsloth",
322
- random_state=42,
 
323
  )
324
- print("✅ Model ready\n")
325
 
 
326
  dataset = build_dataset()
327
 
 
328
  config = GRPOConfig(
329
  output_dir = OUTPUT_DIR,
330
- max_steps = MAX_STEPS,
331
- per_device_train_batch_size = 1,
332
- gradient_accumulation_steps = 4,
333
- learning_rate = 5e-6,
334
- max_completion_length = 150,
335
- num_generations = 4, # compare 4 actions per step
336
- temperature = 1.0,
337
- logging_steps = 1,
338
- save_steps = 25,
339
- save_total_limit = 3,
340
  warmup_ratio = 0.1,
341
  report_to = "none",
342
- remove_unused_columns = False,
343
  )
344
 
345
  trainer = GRPOTrainer(
346
  model = model,
347
  tokenizer = tokenizer,
348
- reward_funcs = reward_fn,
349
  args = config,
350
  train_dataset = dataset,
351
  )
352
 
353
- print(f"🏋️ GRPO training — {MAX_STEPS} steps")
354
- print("Watch for: improvement > 0 and score > 0.5 on create_index\n")
 
 
 
 
 
 
355
  trainer.train()
356
- print("\n✅ Training complete!")
357
 
358
- Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
359
  model.save_pretrained(f"{OUTPUT_DIR}/final")
360
  tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")
361
- print(f"✅ Saved to {OUTPUT_DIR}/final")
362
-
363
- generate_plots(trainer)
364
-
365
- print("\n" + "="*50)
366
- print("NEXT: python training/evaluate_agent.py")
367
- print("THEN: git add loss_curve.png reward_curve.png")
368
- print("="*50)
369
 
370
 
371
  if __name__ == "__main__":
372
- train()
 
1
  """
2
  training/train_agent.py — SQL Database Engineer Agent
3
+ Unsloth + GRPO training script.
4
+ Run on venue GPU (April 25-26) with compute credits.
5
+
6
+ FIXES applied:
7
+ 1. Robust JSON extraction via regex (kills PARSE FALLBACK)
8
+ 2. task_id from kwargs directly — not from kwargs["batch"] (kills only-easy_s001)
9
+ 3. Reward calls /grader (stateless) instead of /reset+/step (kills race condition + flat 0.500)
10
+ 4. Format bonus so valid JSON gets non-zero reward even before agent learns DBA actions
11
  """
12
 
13
+ import os
14
+ import re
15
+ import json
16
+ import requests
17
+ from datasets import Dataset
18
 
19
+ # ── Try importing Unsloth (GPU only) ─────────────────────────
 
20
  try:
 
 
 
 
21
  from unsloth import FastLanguageModel
22
  from trl import GRPOTrainer, GRPOConfig
 
23
  UNSLOTH_AVAILABLE = True
24
+ except ImportError:
25
+ UNSLOTH_AVAILABLE = False
26
+ print("⚠️ Unsloth not available. Run: pip install unsloth trl")
 
 
27
 
28
+ # ─────────────────────────────────────────────
29
+ # CONFIG
30
+ # ─────────────────────────────────────────────
31
 
 
32
  ENV_URL = os.getenv("ENV_URL", "https://junaid0600-sql-db-engineer-agent.hf.space")
33
  HF_TOKEN = os.getenv("HF_TOKEN", "")
34
+ MODEL_NAME = os.getenv("MODEL_NAME", "unsloth/Qwen2.5-7B-Instruct")
35
  OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./sdea-trained")
 
36
 
37
+ # Valid Round 2 action types — model must use one of these
38
+ VALID_ACTION_TYPES = {
39
+ "inspect_query", "analyze_indexes", "create_index",
40
+ "rewrite_query", "add_column", "drop_index",
41
+ "partition_table", "analyze_statistics",
42
+ "request_hint", "submit_report",
43
+ }
44
+
45
+ SYSTEM_PROMPT = """You are a senior database engineer.
46
+ Given a database scenario with slow queries, choose the BEST single action to improve performance.
47
+
48
+ Investigation pattern (follow this order):
49
+ 1. Use inspect_query to understand WHY a query is slow (scan type, rows examined)
50
+ 2. Use analyze_indexes to see what indexes exist and what is missing
51
+ 3. Use create_index to add the missing index on WHERE/JOIN columns
52
+ 4. Use rewrite_query if the SQL itself is inefficient
53
+ 5. Use partition_table for tables with 1M+ rows and range queries
54
+ 6. Use submit_report when performance target is reached
55
+
56
+ RESPOND WITH VALID JSON ONLY. No explanation. No markdown. No preamble.
57
+ Examples:
58
+ {"action_type": "inspect_query", "payload": {"query_id": "q1"}}
59
+ {"action_type": "analyze_indexes", "payload": {"table": "users"}}
60
+ {"action_type": "create_index", "payload": {"table": "users", "columns": ["email"]}}
61
+ {"action_type": "create_index", "payload": {"table": "orders", "columns": ["user_id", "status"]}}
62
+ {"action_type": "submit_report", "payload": {"summary": "Added composite index on orders(user_id, status). Performance improved from 5.0 to 85.0."}}"""
63
+
64
+
65
+ # ─────────────────────────────────────────────
66
+ # JSON EXTRACTION (FIX 1 — kills PARSE FALLBACK)
67
+ # ─────────────────────────────────────────────
68
+
69
+ def _extract_json(text: str) -> dict | None:
70
+ """
71
+ Robustly extract a JSON object from model output.
72
+ Handles: pure JSON, markdown blocks, JSON buried in text, partial JSON.
73
+ Returns parsed dict or None if nothing parseable found.
74
+ """
75
+ if not text:
76
+ return None
77
 
78
+ # Strip common markdown wrappers
79
+ text = text.strip()
80
+ text = re.sub(r"```(?:json)?", "", text).replace("```", "").strip()
81
 
82
+ # Try 1: entire text is valid JSON
83
+ try:
84
+ obj = json.loads(text)
85
+ if isinstance(obj, dict) and "action_type" in obj:
86
+ return obj
87
+ except json.JSONDecodeError:
88
+ pass
89
 
90
+ # Try 2: find outermost {...} block using regex (handles extra text around JSON)
91
+ matches = re.findall(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)?\}', text, re.DOTALL)
92
+ for m in matches:
93
+ try:
94
+ obj = json.loads(m)
95
+ if isinstance(obj, dict) and "action_type" in obj:
96
+ return obj
97
+ except json.JSONDecodeError:
98
+ continue
99
+
100
+ # Try 3: greedy — find first { to last }
101
+ start = text.find("{")
102
+ end = text.rfind("}")
103
+ if start != -1 and end != -1 and end > start:
104
+ try:
105
+ obj = json.loads(text[start:end + 1])
106
+ if isinstance(obj, dict) and "action_type" in obj:
107
+ return obj
108
+ except json.JSONDecodeError:
109
+ pass
110
 
111
+ return None
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
+ def _is_valid_action(action: dict) -> bool:
115
+ """Check action has correct structure before sending to /grader."""
116
+ if not isinstance(action, dict):
117
+ return False
118
+ if "action_type" not in action:
119
+ return False
120
+ if action["action_type"] not in VALID_ACTION_TYPES:
121
+ return False
122
+ if "payload" not in action or not isinstance(action.get("payload"), dict):
123
+ return False
124
+ return True
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
+ # ─────────────────────────────────────────────
128
+ # REWARD FUNCTION (FIX 2 + FIX 3)
129
+ # ─────────────────────────────────────────────
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  def reward_fn(prompts, completions, **kwargs):
132
  """
133
+ GRPO reward function — calls /grader (STATELESS).
134
+
135
+ FIX 2: task_ids from kwargs["task_id"] directly (TRL passes dataset
136
+ columns as direct kwargs, NOT inside a "batch" key).
137
+
138
+ FIX 3: calls /grader instead of /reset + /step.
139
+ /grader is stateless — no race condition, no global env mutation,
140
+ no flat reward from concurrent resets overwriting each other.
 
141
  """
142
+ rewards = []
143
+
144
+ # ── FIX 2: correct task_id extraction ────────────────────────
145
+ # TRL GRPO passes dataset columns directly as kwargs.
146
+ # With num_generations=4, each task_id is repeated 4x in the list.
147
+ raw_task_ids = kwargs.get("task_id", [])
148
+ if isinstance(raw_task_ids, str):
149
+ raw_task_ids = [raw_task_ids]
150
 
151
  for i, (prompt, completion) in enumerate(zip(prompts, completions)):
152
+ task_id = (
153
+ raw_task_ids[i]
154
+ if i < len(raw_task_ids)
155
+ else "easy_s001"
156
+ )
157
+
158
+ # ── Extract text from completion ──────────────────────────
159
+ if isinstance(completion, list):
160
+ # Standard TRL format: [{"role": "assistant", "content": "..."}]
161
+ text = completion[0].get("content", "") if completion else ""
162
+ elif isinstance(completion, dict):
163
+ text = completion.get("content", "")
164
+ else:
165
+ text = str(completion)
166
+
167
+ # ── FIX 1: robust JSON parse ──────────────────────────────
168
+ action = _extract_json(text)
169
+
170
+ if action is None:
171
+ # Complete parse failure — 0.001 (not 0.0, avoids GRPO div-by-zero)
172
+ rewards.append(0.001)
173
+ continue
174
+
175
+ # Format bonus: valid JSON with correct structure = small positive signal
176
+ # This gives the model SOMETHING to learn from even before it learns
177
+ # the right actions, avoiding the all-zero gradient problem.
178
+ if not _is_valid_action(action):
179
+ # JSON parsed but action_type is wrong/missing
180
+ rewards.append(0.05)
181
+ continue
182
+
183
+ # ── FIX 3: stateless /grader call ────────────────────────
184
  try:
185
+ resp = requests.post(
186
+ f"{ENV_URL}/grader",
187
+ json={"task_id": task_id, "action": action},
188
+ timeout=20,
189
+ )
190
+ resp.raise_for_status()
191
+ score = float(resp.json().get("score", 0.001))
192
+ score = max(0.001, min(0.999, score))
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  rewards.append(score)
194
 
195
+ except requests.exceptions.Timeout:
196
+ rewards.append(0.05) # grader timed out — give format credit
 
 
 
 
197
  except Exception as e:
198
+ print(f"[reward_fn] grader call failed for {task_id}: {e}")
199
  rewards.append(0.001)
200
 
201
  return rewards
202
 
203
 
204
+ # ─────────────────────────────────────────────
205
+ # BUILD TRAINING DATASET
206
+ # ─────────────────────────────────────────────
207
+
208
+ def build_dataset():
209
+ """
210
+ Build training examples from all Round 2 scenario JSON files.
211
+ Each example: {"prompt": "...", "task_id": "easy_s001"}.
212
+ task_id is passed through to reward_fn via kwargs (TRL behaviour).
213
+ """
214
+ scenarios = []
215
+
216
+ for fname in [
217
+ "dataset/easy_scenarios.json",
218
+ "dataset/medium_scenarios.json",
219
+ "dataset/hard_scenarios.json",
220
+ ]:
221
+ try:
222
+ with open(fname) as f:
223
+ loaded = json.load(f)
224
+ scenarios.extend(loaded)
225
+ print(f" Loaded {len(loaded)} scenarios from {fname}")
226
+ except FileNotFoundError:
227
+ print(f" {fname} not found, skipping")
228
+
229
+ if not scenarios:
230
+ print(" Falling back to /tasks endpoint...")
231
+ try:
232
+ resp = requests.get(f"{ENV_URL}/tasks", timeout=15)
233
+ tasks = resp.json().get("tasks", [])
234
+ scenarios = [{"id": t["id"], "description": t.get("description", "")}
235
+ for t in tasks if "_s" in t["id"]]
236
+ except Exception as e:
237
+ print(f" /tasks fallback failed: {e}")
238
+ # Minimal fallback so training doesn't crash
239
+ scenarios = [{"id": "easy_s001",
240
+ "description": "User lookup query taking 2s. Add index.",
241
+ "tables": [{"name": "users", "rows": 10000, "indexes": ["PRIMARY"]}],
242
+ "slow_queries": [{"id": "q1", "sql": "SELECT * FROM users WHERE email=?", "avg_ms": 2000}],
243
+ "performance_score_baseline": 8.0,
244
+ "target_score": 80.0}]
245
+
246
  examples = []
247
+ for s in scenarios:
248
+ tables_txt = json.dumps(s.get("tables", []), separators=(",", ":"))
249
+ queries_txt = json.dumps(s.get("slow_queries", []), separators=(",", ":"))
250
+ baseline = s.get("performance_score_baseline", s.get("performance_score", 0))
251
+ target = s.get("target_score", 85)
252
+ max_steps = s.get("max_steps", 50)
253
 
254
  prompt = (
255
  f"{SYSTEM_PROMPT}\n\n"
256
+ f"=== DATABASE SCENARIO ===\n"
257
+ f"Scenario ID: {s.get('id', 'unknown')}\n"
258
+ f"Description: {s.get('description', '')}\n"
259
+ f"Tables: {tables_txt}\n"
260
+ f"Slow Queries: {queries_txt}\n"
261
+ f"Current Performance Score: {baseline} / 100\n"
262
+ f"Target Performance Score: {target} / 100\n"
263
+ f"Step Budget: {max_steps}\n\n"
264
+ f"What is your FIRST action?"
265
  )
266
+
267
  examples.append({
268
+ "prompt": prompt,
269
+ "task_id": s.get("id", "easy_s001"),
270
  })
271
 
272
+ print(f"Built {len(examples)} training examples from {len(scenarios)} scenarios")
273
  return Dataset.from_list(examples)
274
 
275
 
276
+ # ─────────────────────────────────────────────
277
+ # REWARD WRAPPER (FIX 2 continued)
278
+ # ─────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
279
 
280
+ def reward_wrapper(prompts, completions, **kwargs):
281
+ """
282
+ Thin wrapper — passes kwargs straight through.
283
+ TRL GRPO sends dataset columns (including task_id) as direct kwargs.
284
+ DO NOT use kwargs.get("batch") — that key does not exist in TRL GRPO.
285
+ """
286
+ return reward_fn(prompts, completions, **kwargs)
287
 
 
 
 
 
 
288
 
289
+ # ─────────────────────────────────────────────
290
+ # MAIN TRAINING
291
+ # ─────────────────────────────────────────────
 
 
 
 
 
 
292
 
293
+ def train():
294
+ if not UNSLOTH_AVAILABLE:
295
+ print("Cannot train — Unsloth not installed")
296
+ print("Run: pip install 'unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git' trl transformers datasets accelerate")
297
+ return
298
 
299
+ print(f"🚀 Loading model: {MODEL_NAME}")
300
+ print(f"🌐 Environment: {ENV_URL}")
301
 
302
+ # Sanity check — make sure environment is reachable
303
+ try:
304
+ r = requests.get(f"{ENV_URL}/health", timeout=10)
305
+ print(f"✅ Environment health: {r.json()}")
306
+ except Exception as e:
307
+ print(f"⚠️ Cannot reach environment at {ENV_URL}: {e}")
308
+ print(" Training will likely fail — check ENV_URL")
309
 
310
+ # Load model with Unsloth 4-bit quantization
311
  model, tokenizer = FastLanguageModel.from_pretrained(
312
  model_name = MODEL_NAME,
313
+ max_seq_length = 4096,
314
  load_in_4bit = True,
315
  token = HF_TOKEN or None,
316
  )
317
+
318
+ # Add LoRA adapters
319
  model = FastLanguageModel.get_peft_model(
320
  model,
321
+ r = 16,
322
+ lora_alpha = 16,
323
+ target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
324
+ "gate_proj", "up_proj", "down_proj"],
325
+ lora_dropout = 0,
326
+ bias = "none",
327
+ use_gradient_checkpointing = "unsloth",
328
  )
 
329
 
330
+ # Build dataset
331
  dataset = build_dataset()
332
 
333
+ # GRPO config
334
  config = GRPOConfig(
335
  output_dir = OUTPUT_DIR,
336
+ num_train_epochs = 3,
337
+ per_device_train_batch_size = 2,
338
+ gradient_accumulation_steps = 8,
339
+ learning_rate = 5e-5,
340
+ max_completion_length = 256,
341
+ num_generations = 4,
342
+ logging_steps = 5,
343
+ save_steps = 50,
 
 
344
  warmup_ratio = 0.1,
345
  report_to = "none",
 
346
  )
347
 
348
  trainer = GRPOTrainer(
349
  model = model,
350
  tokenizer = tokenizer,
351
+ reward_funcs = reward_wrapper,
352
  args = config,
353
  train_dataset = dataset,
354
  )
355
 
356
+ print("🏋️ Starting GRPO training...")
357
+ print(" Expected reward progression:")
358
+ print(" Steps 10: ~0.05-0.15 (model still outputting free text)")
359
+ print(" Steps 50: ~0.20-0.35 (learning JSON format)")
360
+ print(" Steps 100: ~0.35-0.50 (learning correct action types)")
361
+ print(" Steps 200: ~0.55-0.70 (learning DBA investigation pattern)")
362
+ print(" Steps 300: ~0.70-0.82 (strategic multi-action planning)")
363
+
364
  trainer.train()
 
365
 
366
+ # Save
367
  model.save_pretrained(f"{OUTPUT_DIR}/final")
368
  tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")
369
+ print(f"✅ Training complete. Model saved to {OUTPUT_DIR}/final")
 
 
 
 
 
 
 
370
 
371
 
372
  if __name__ == "__main__":
373
+ train()