Spaces:
Sleeping
Sleeping
"""
training/train_agent.py - SQL Database Engineer Agent
Unsloth + GRPO - Fixed Version
- 100 training steps by default (previously 700)
- Local DatabaseSimulator for reward variance (instead of only the /grader endpoint)
- Real delta rewards, with milestones shown in the logs
- Generates loss_curve.png automatically
"""
| import os, re, json, sys, warnings | |
| from pathlib import Path | |
# -- Suppress known warnings ------------------------------------------------------
# (message regex, category) pairs, installed in this order. ``Warning`` matches
# every category, which is also filterwarnings()'s own default.
_IGNORED_WARNINGS = (
    (r".*max_new_tokens.*max_length.*", Warning),
    (r".*AttentionMaskConverter.*", FutureWarning),
)
for _message_re, _category in _IGNORED_WARNINGS:
    warnings.filterwarnings("ignore", message=_message_re, category=_category)
del _message_re, _category

# -- GPU check + Unsloth ----------------------------------------------------------
# Training needs Unsloth, which only runs on CUDA: exit with status 1 when no GPU
# is visible or when any of the training dependencies fails to import.
UNSLOTH_AVAILABLE = False
try:
    import torch

    if not torch.cuda.is_available():
        print("β No GPU. Unsloth requires CUDA.")
        # SystemExit is not an ImportError, so the handler below does not catch it.
        sys.exit(1)

    # Unsloth asks to be imported before trl/transformers so it can patch them.
    from unsloth import FastLanguageModel
    from trl import GRPOTrainer, GRPOConfig
    from datasets import Dataset

    UNSLOTH_AVAILABLE = True
    print(f"β GPU: {torch.cuda.get_device_name(0)}")
    _vram_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"β VRAM: {_vram_bytes / 1e9:.1f}GB")
except ImportError as import_error:
    print(f"β {import_error}")
    sys.exit(1)
# -- Add project root to path -------------------------------------------------------
# This script sits in <root>/training/; the parent of its directory is the project
# root, which is put first on sys.path so the local ``env`` package can be imported.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT = os.path.dirname(_SCRIPT_DIR)
sys.path.insert(0, ROOT)
from env.db_simulator import DatabaseSimulator  # noqa: E402 (needs ROOT on sys.path)

# -- Config: every setting can be overridden by an environment variable -------------
# NOTE(review): ENV_URL is not read anywhere else in this file (rewards are
# computed locally with DatabaseSimulator); kept for compatibility.
ENV_URL = os.environ.get("ENV_URL", "https://junaid0600-sql-db-engineer-agent.hf.space")
HF_TOKEN = os.environ.get("HF_TOKEN", "")  # empty string -> train() passes token=None
MODEL_NAME = os.environ.get("MODEL_NAME", "unsloth/Qwen2.5-7B-Instruct")
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "./sdea-trained")
MAX_STEPS = int(os.environ.get("MAX_STEPS", "100"))

# One print of the newline-joined lines writes exactly what three prints would.
print(
    "\n".join(
        (
            f"\n[CONFIG] Model: {MODEL_NAME}",
            f"[CONFIG] Max steps: {MAX_STEPS}",
            f"[CONFIG] Output: {OUTPUT_DIR}\n",
        )
    )
)
# Action types the agent may emit; anything else is treated as an unknown action.
VALID_ACTIONS = {
    "inspect_query",
    "analyze_indexes",
    "create_index",
    "rewrite_query",
    "add_column",
    "drop_index",
    "partition_table",
    "analyze_statistics",
    "request_hint",
    "submit_report",
}

# Prefix of every training prompt: the recommended investigation order, the
# JSON-only output rule, and three example actions (one JSON object per line).
SYSTEM_PROMPT = "\n".join(
    (
        "You are a senior database engineer fixing slow database queries.",
        "Investigation pattern:",
        "1. inspect_query β understand WHY query is slow",
        "2. analyze_indexes β see what indexes are missing",
        "3. create_index β add composite index on WHERE/JOIN columns",
        "4. submit_report β when performance target is reached",
        "RESPOND WITH VALID JSON ONLY. No markdown. No explanation.",
        "Examples:",
        '{"action_type": "inspect_query", "payload": {"query_id": "q1"}}',
        '{"action_type": "create_index", "payload": {"table": "users", "columns": ["email"]}}',
        '{"action_type": "create_index", "payload": {"table": "orders", "columns": ["user_id", "status"]}}',
    )
)
| # ββ Load all 15 scenarios βββββββββββββββββββββββββββββββββββββ | |
| def load_scenarios() -> list: | |
| scenarios = [] | |
| for fname in ["easy_scenarios.json", "medium_scenarios.json", "hard_scenarios.json"]: | |
| path = os.path.join(ROOT, "dataset", fname) | |
| try: | |
| with open(path) as f: | |
| data = json.load(f) | |
| scenarios.extend(data) | |
| print(f" β {len(data)} from {fname}") | |
| except FileNotFoundError: | |
| print(f" β οΈ {fname} not found") | |
| print(f" Total: {len(scenarios)} scenarios\n") | |
| return scenarios | |
# Loaded once at import time; read by reward_fn, build_dataset and train().
ALL_SCENARIOS: list = load_scenarios()
| # ββ Parse LLM output β action dict βββββββββββββββββββββββββββ | |
| def parse_action(text: str) -> dict | None: | |
| if not text: | |
| return None | |
| text = text.strip() | |
| text = re.sub(r"```(?:json)?", "", text).replace("```", "").strip() | |
| # Try full parse | |
| try: | |
| obj = json.loads(text) | |
| if isinstance(obj, dict) and "action_type" in obj: | |
| return obj | |
| except json.JSONDecodeError: | |
| pass | |
| # Try extract from partial text | |
| for start in [i for i,c in enumerate(text) if c == "{"]: | |
| depth, in_str, escape = 0, False, False | |
| for i in range(start, len(text)): | |
| ch = text[i] | |
| if in_str: | |
| if escape: escape = False | |
| elif ch == "\\": escape = True | |
| elif ch == '"': in_str = False | |
| continue | |
| if ch == '"': in_str = True; continue | |
| if ch == "{": depth += 1 | |
| elif ch == "}": | |
| depth -= 1 | |
| if depth == 0: | |
| try: | |
| obj = json.loads(text[start:i+1]) | |
| if isinstance(obj, dict) and "action_type" in obj: | |
| return obj | |
| except: pass | |
| break | |
| return None | |
# Fixed per-action step reward. Valid actions not listed here (add_column,
# drop_index, request_hint) get _UNLISTED_STEP_REWARD.
_STEP_REWARDS = {
    "inspect_query": 0.10,
    "analyze_indexes": 0.10,
    "create_index": 0.15,
    "rewrite_query": 0.20,
    "analyze_statistics": 0.08,
    "partition_table": 0.15,
    "submit_report": 0.05,
}
_UNLISTED_STEP_REWARD = 0.001

# (gap-ratio threshold, bonus, log label); every threshold reached adds its bonus.
_MILESTONES = ((0.25, 0.15, "25%"), (0.50, 0.25, "50%"), (0.75, 0.40, "75%"))

# Actions whose effect is taken from DatabaseSimulator.apply_action().
_SIMULATED_ACTIONS = ("create_index", "rewrite_query", "partition_table", "analyze_statistics")


def _hint_match_score(payload: dict, hints: list) -> float:
    """Return how well a create_index payload matches its best missing-index hint (0..1).

    Per hint: 0.4 for naming the hint's table (0.3 * 0.4 = 0.12 for another
    table) plus 0.6 * the fraction of the hint's columns present in the payload
    (case-insensitive). A hint on a different table that shares no column with
    the payload is skipped, so an unrelated index earns no credit.
    """
    p_table = payload.get("table", "")
    p_cols = {str(c).lower() for c in payload.get("columns", [])}
    best = 0.0
    for hint in hints:
        h_cols = {str(c).lower() for c in hint.get("columns", [])}
        if not h_cols:
            continue
        same_table = hint.get("table", "") == p_table
        shared = p_cols & h_cols
        if not same_table and not shared:
            continue  # wrong table AND no common column: nothing to reward
        table_ok = 1.0 if same_table else 0.3
        best = max(best, table_ok * 0.4 + (len(shared) / len(h_cols)) * 0.6)
    return best


def _milestone_bonus(ratio: float) -> tuple:
    """Return (cumulative bonus, log text) for the milestones `ratio` has reached."""
    bonus = 0.0
    labels = []
    for threshold, value, label in _MILESTONES:
        if ratio >= threshold:
            bonus += value
            labels.append(label)
    text = ("π― " + "+".join(labels) + " milestone!") if labels else ""
    return bonus, text


def compute_reward(action: dict, scenario: dict) -> tuple:
    """Score one action against a freshly built DatabaseSimulator.

    Everything is computed locally (no HTTP). A new simulator is created per
    call, so calls do not share state.

    The score is the sum of:
      * a fixed step reward per action type (``_STEP_REWARDS``);
      * a delta reward: 0.65 * the fraction of the gap to 100 that the action
        closed (capped at 0.65);
      * cumulative milestone bonuses at 25% / 50% / 75% of that gap;
      * a -0.05 penalty for a create_index credited by neither the simulator
        nor a missing-index hint.
    The sum is clamped to [0.001, 0.999].

    Returns:
        (score, improvement_pts, milestone_bonus, description)
    """
    sim = DatabaseSimulator(scenario)
    baseline = sim.get_performance_score()
    action_type = action.get("action_type", "")
    payload = action.get("payload", {})
    max_possible = max(1.0, 100.0 - baseline)

    delta = 0.0
    if action_type in _SIMULATED_ACTIONS:
        result = sim.apply_action(action_type, payload)
        delta = result.get("delta", 0.0)
    elif action_type == "submit_report":
        delta = max(0, sim.get_performance_score() - baseline)
    # inspect_query / analyze_indexes (and anything else) leave delta at 0.

    # Hint-match fallback: when the simulator gives create_index no credit
    # (wrong or partial columns), estimate a delta from the scenario's
    # missing_index_hints so GRPO still gets a gradient early in training.
    # A perfect match yields 65% of the gap, i.e. the 25% and 50% milestones.
    if action_type == "create_index" and delta <= 0.0:
        match = _hint_match_score(payload, scenario.get("missing_index_hints", []))
        if match > 0:
            delta = match * max_possible * 0.65

    final = sim.get_performance_score()
    sim_improve = max(0.0, final - baseline)
    # Use the hint-estimated delta when the simulator itself shows no gain.
    improvement = max(sim_improve, delta)
    ratio = improvement / max_possible

    step_r = _STEP_REWARDS.get(action_type, _UNLISTED_STEP_REWARD)
    delta_r = min(0.65, ratio * 0.65)
    milestone, milestone_str = _milestone_bonus(ratio)
    # Small penalty only, so an uncredited create_index still scores ~0.10
    # (a penalty equal to its 0.15 step reward would clamp it to 0.001).
    wrong_pen = -0.05 if (action_type == "create_index" and delta <= 0.0) else 0.0
    total = max(0.001, min(0.999, step_r + delta_r + milestone + wrong_pen))

    src = "sim" if sim_improve > 0 else ("hint" if delta > 0 else "none")
    desc = f"+{improvement:.1f}pts delta={delta:.1f}[{src}] {milestone_str}"
    return total, improvement, milestone, desc
# -- GRPO reward function ------------------------------------------------------------
def reward_fn(prompts, completions, **kwargs):
    """Score each completion locally with compute_reward (DatabaseSimulator).

    GRPOTrainer forwards extra dataset columns to reward functions as keyword
    arguments, one value per completion. ``kwargs["scenario_id"]`` therefore
    says which scenario each completion's prompt was built from. Without that
    column (or for an unknown id), the scenario is picked by position as
    ``ALL_SCENARIOS[i % len(ALL_SCENARIOS)]``.

    Returns:
        One float per completion: 0.001 for unparseable output or any error,
        0.05 for an unknown action_type, otherwise compute_reward()'s score.
    """
    by_id = {s.get("id"): s for s in ALL_SCENARIOS}
    scenario_ids = kwargs.get("scenario_id") or []
    rewards = []
    for i, (_prompt, completion) in enumerate(zip(prompts, completions)):
        try:
            # Conversational completions arrive as [{"role": ..., "content": ...}].
            if isinstance(completion, list):
                text = completion[0].get("content", "") if completion else ""
            else:
                text = str(completion)

            # Score against the scenario the prompt was actually built from.
            scenario = by_id.get(scenario_ids[i]) if i < len(scenario_ids) else None
            if scenario is None:
                scenario = ALL_SCENARIOS[i % len(ALL_SCENARIOS)]
            sid = scenario["id"]

            action = parse_action(text)
            if action is None:
                print(f" [REWARD] {sid} | INVALID JSON | score=0.001")
                rewards.append(0.001)
                continue
            if action.get("action_type") not in VALID_ACTIONS:
                print(f" [REWARD] {sid} | UNKNOWN ACTION | score=0.05")
                rewards.append(0.05)
                continue

            score, _improvement, _milestone, desc = compute_reward(action, scenario)
            rewards.append(score)
            print(f" [REWARD] {sid} | "
                  f"action={action['action_type']} | "
                  f"{desc} | score={score:.3f}")
        except Exception as e:
            # Best effort: one bad completion must not abort the training step.
            print(f" [REWARD] Error: {e}")
            rewards.append(0.001)
    return rewards
# -- Build dataset ---------------------------------------------------------------------
def _scenario_prompt(scenario: dict) -> str:
    """Render one scenario as the plain-text prompt shown to the model."""
    # Serialize first, matching the original evaluation order.
    tables_json = json.dumps(scenario.get("tables", []))
    queries_json = json.dumps(scenario.get("slow_queries", []))
    hints_json = json.dumps(scenario.get("missing_index_hints", []))
    return "".join(
        (
            f"{SYSTEM_PROMPT}\n\n",
            "=== DATABASE STATE ===\n",
            f"Scenario: {scenario['id']} β {scenario.get('description','')}\n",
            f"Tables: {tables_json}\n",
            f"Slow Queries: {queries_json}\n",
            f"Missing Index Hints: {hints_json}\n",
            f"Performance: {scenario.get('performance_score_baseline',0)}/100 ",
            f"β Target: {scenario.get('target_score',85)}/100\n\n",
            "What is your next action? JSON only:",
        )
    )


def build_dataset() -> Dataset:
    """Build one training row per loaded scenario.

    Each row holds ``prompt`` (system prompt plus the serialized database state)
    and ``scenario_id`` (the scenario's ``id``).
    """
    rows = [
        {"prompt": _scenario_prompt(scenario), "scenario_id": scenario["id"]}
        for scenario in ALL_SCENARIOS
    ]
    print(f" β Dataset: {len(rows)} examples")
    return Dataset.from_list(rows)
# -- Generate loss + reward plots ------------------------------------------------------
def _plot_with_moving_average(ax, steps, values, color, raw_label, window=10):
    """Draw `values` faintly plus a bold trailing `window`-point moving average.

    The average is drawn only when there are at least `window` points. It is
    plotted against ``steps[window - 1:]``, so each point sits at the last step
    of its window.
    """
    import numpy as np

    ax.plot(steps, values, f"{color}-", lw=1.0, alpha=0.35, label=raw_label)
    if len(values) >= window:
        smooth = np.convolve(values, np.ones(window) / window, mode="valid")
        ax.plot(steps[window - 1:], smooth, f"{color}-", lw=2.5,
                label=f"{window}-step avg")


def generate_plots(trainer):
    """Save the trainer's log history and plot loss / reward curves.

    Writes ``<OUTPUT_DIR>/training_logs.json`` (the full
    ``trainer.state.log_history``) and ``loss_curve.png`` in the current working
    directory. Only log entries containing ``"loss"`` are plotted; a missing
    ``"reward"`` is plotted as 0. When no entry has a loss, it only prints a
    warning.
    """
    import matplotlib

    matplotlib.use("Agg")  # non-interactive backend: render straight to files
    import matplotlib.pyplot as plt

    logs = [entry for entry in trainer.state.log_history if "loss" in entry]
    if not logs:
        print("β οΈ No logs to plot")
        return
    steps = [entry.get("step", i) for i, entry in enumerate(logs)]
    losses = [entry.get("loss", 0) for entry in logs]
    rewards = [entry.get("reward", 0) for entry in logs]

    # Keep the raw history too, so the plots can be regenerated offline
    # (generate_plots.py).
    logs_path = f"{OUTPUT_DIR}/training_logs.json"
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    with open(logs_path, "w", encoding="utf-8") as f:
        json.dump(trainer.state.log_history, f)
    print(f"β Logs saved to {logs_path}")

    # Label with the configured model rather than a hard-coded name.
    model_label = MODEL_NAME.rsplit("/", 1)[-1]
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))
    try:
        fig.suptitle("GRPO Training β SQL Database Engineer Agent\n"
                     f"{model_label} fine-tuned with Unsloth + TRL",
                     fontsize=13, fontweight="bold")

        # LEFT: loss, raw + smoothed, with start/end values annotated.
        _plot_with_moving_average(ax1, steps, losses, "b", "Raw loss")
        ax1.set_xlabel("Training Step")
        ax1.set_ylabel("Loss")
        ax1.set_title("Training Loss β = model learning")
        ax1.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
        ax1.grid(True, alpha=0.3)
        ax1.legend(fontsize=9)
        top = max(losses)
        ax1.annotate(f"Start: {losses[0]:.2e}",
                     xy=(steps[0], losses[0]),
                     xytext=(steps[0] + 3, top * 0.85),
                     fontsize=8, color="red",
                     arrowprops=dict(arrowstyle="->", color="red", lw=1))
        ax1.annotate(f"End: {losses[-1]:.2e}",
                     xy=(steps[-1], losses[-1]),
                     xytext=(steps[-1] - 12, top * 0.65),
                     fontsize=8, color="green",
                     arrowprops=dict(arrowstyle="->", color="green", lw=1))

        # RIGHT: reward, raw + smoothed.
        _plot_with_moving_average(ax2, steps, rewards, "g", "Raw reward")
        ax2.set_xlabel("Training Step")
        ax2.set_ylabel("Avg Reward")
        ax2.set_title("Reward During Training β = improving")
        ax2.grid(True, alpha=0.3)
        ax2.legend(fontsize=9)

        # Bottom summary: first vs last logged loss and reward.
        start_r, end_r = rewards[0], rewards[-1]
        pct = (end_r - start_r) / max(abs(start_r), 1e-9) * 100
        sign = "+" if pct >= 0 else ""
        fig.text(0.5, 0.01,
                 f"Loss: {losses[0]:.2e} β {losses[-1]:.2e} | "
                 f"Reward: {start_r:.3f} β {end_r:.3f} ({sign}{pct:.0f}%)",
                 ha="center", fontsize=10,
                 bbox=dict(boxstyle="round", facecolor="lightyellow", alpha=0.8))

        fig.tight_layout(rect=[0, 0.07, 1, 1])
        fig.savefig("loss_curve.png", dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if plotting fails part-way.
        plt.close(fig)
    print("β loss_curve.png saved")
    print(f" Loss: {losses[0]:.2e} β {losses[-1]:.2e}")
    valid = [r for r in rewards if r > 0]
    if valid:
        print(f" Reward: {valid[0]:.3f} β {valid[-1]:.3f}")
# -- Main ------------------------------------------------------------------------------
def train():
    """Fine-tune MODEL_NAME with LoRA + GRPO on the loaded scenarios.

    Loads the model in 4-bit through Unsloth, attaches LoRA adapters, and runs
    MAX_STEPS GRPO steps scored by reward_fn. It then saves the adapters and
    tokenizer to ``<OUTPUT_DIR>/final`` and writes the training plots.
    Exits with status 1 when no scenarios were loaded.
    """
    if not ALL_SCENARIOS:
        print("β No scenarios loaded. Check dataset/ folder.")
        sys.exit(1)

    print(f"β³ Loading {MODEL_NAME}...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=2048,
        load_in_4bit=True,
        token=HF_TOKEN or None,  # an empty HF_TOKEN is passed as None
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        lora_alpha=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=42,
    )
    print("β Model + LoRA ready\n")

    dataset = build_dataset()

    num_generations = 4  # completions sampled and compared per prompt
    config = GRPOConfig(
        output_dir=OUTPUT_DIR,
        max_steps=MAX_STEPS,
        # TRL requires the generation batch to be divisible by num_generations;
        # a per-device batch of 1 (x2 accumulation) is rejected for 4 generations.
        # (Recent Unsloth versions reportedly bump it to num_generations anyway.)
        per_device_train_batch_size=num_generations,
        gradient_accumulation_steps=2,
        learning_rate=2e-5,
        max_completion_length=150,
        num_generations=num_generations,
        temperature=1.0,
        logging_steps=1,
        save_steps=25,
        save_total_limit=3,
        warmup_steps=10,
        report_to="none",
        # Keep the scenario_id column: GRPOTrainer passes extra dataset columns
        # to reward functions as keyword arguments.
        remove_unused_columns=False,
        # NOTE(review): max_prompt_length is left at TRL's default; confirm the
        # scenario prompts fit, otherwise they may be truncated.
    )
    trainer = GRPOTrainer(
        model=model,
        # GRPOTrainer takes the tokenizer as `processing_class`; it has no
        # `tokenizer` keyword.
        processing_class=tokenizer,
        reward_funcs=reward_fn,
        args=config,
        train_dataset=dataset,
    )

    print(f"ποΈ GRPO training β {MAX_STEPS} steps")
    print("Expected rewards:")
    print(" inspect_query / analyze_indexes: ~0.10")
    print(" create_index (no table/col match): ~0.10 (was 0.001)")
    print(" create_index (partial hint match): ~0.20-0.45")
    print(" create_index (perfect hint match): ~0.55-0.80")
    print(" create_index (simulator confirms): ~0.75-0.99")
    print(" Milestones: 25%=+0.15 50%=+0.25 75%=+0.40 (cumulative)\n")
    trainer.train()
    print("\nβ Training complete!")

    Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
    model.save_pretrained(f"{OUTPUT_DIR}/final")
    tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")
    print(f"β Saved to {OUTPUT_DIR}/final")

    generate_plots(trainer)

    print("\n" + "="*50)
    print("NEXT:")
    print(" python training/evaluate_agent.py")
    print(" git add loss_curve.png reward_curve.png")
    print(" git commit -m 'Real GRPO training evidence'")
    print(" git push origin main")
    print("="*50)


if __name__ == "__main__":
    train()