Spaces:
Running
Running
File size: 5,910 Bytes
bc20ef9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 | import os
import json
import httpx
import torch
import random
from typing import List, Dict, Any
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoTokenizer, AutoModelForCausalLM
# ββ Configuration ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
# We use a tiny model for local testing. In the hackathon, upgrade this to 1.5B or 7B.
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-Coder-0.5B-Instruct")
OUTPUT_DIR = "./grpo_sql_debug_results"
# ββ 1. Dataset Generation ββββββββββββββββββββββββββββββββββββββββββββββββββββ
def make_dataset():
"""
Creates a training dataset by pulling observations from the live environment.
"""
print(f"[GRPO] Connecting to {ENV_URL} to build dataset...")
tasks = ["easy_syntax_fix", "medium_logic_fix", "hard_multi_bug"]
rows = []
with httpx.Client(base_url=ENV_URL, timeout=10.0) as client:
for task_id in tasks:
try:
resp = client.post("/reset", json={"task_id": task_id})
resp.raise_for_status()
obs = resp.json()["observation"]
prompt = (
"Fix the following SQL query and provide only the fixed SQL.\n"
f"Task: {obs['task_description']}\n"
f"Broken Query: {obs['original_query']}\n"
"Fixed SQL:"
)
# Each task is repeated to create a batch for the trainer
for _ in range(20):
rows.append({
"prompt": prompt,
"task_id": task_id
})
except Exception as e:
print(f"[GRPO] Failed to pull task {task_id}: {e}")
if not rows:
raise RuntimeError("Could not build dataset. Is the environment server running?")
return Dataset.from_list(rows)
# ββ 2. Reward Function βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def sql_reward_func(completions: List[str], task_id: List[str], **kwargs) -> List[float]:
"""
The heart of the Self-Improving Agent.
It submits the model's generated query to the environment and returns the reward.
"""
rewards = []
with httpx.Client(base_url=ENV_URL, timeout=5.0) as client:
# completions and task_id are lists of the same length
for query, t_id in zip(completions, task_id):
try:
# Use a unique session ID for each generation in the GRPO group
session_id = f"grpo-eval-{os.urandom(4).hex()}"
# 1. Reset to the specific task
client.post("/reset", json={"task_id": t_id}, headers={"x-session-id": session_id})
# 2. Submit the generated query
sql_part = query.split("Fixed SQL:")[-1].strip() if "Fixed SQL:" in query else query.strip()
resp = client.post(
"/step",
json={"action": {"action_type": "submit_query", "query": sql_part}},
headers={"x-session-id": session_id}
)
if resp.status_code == 200:
reward = float(resp.json().get("reward", 0.0))
else:
reward = 0.0
except Exception:
reward = 0.0
# ADD MICROSCOPIC NOISE: Prevents Zero-Variance crash
reward += random.uniform(-1e-6, 1e-6)
print(f" [REWARD] Task: {t_id:18} | Score: {reward:.4f} | Query: {query[:50].strip()}...", flush=True)
rewards.append(reward)
return rewards
# ββ 3. Training Loop βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def train():
print(f"[GRPO] Loading model: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
# Load model
device = "cpu" # Forcing CPU for 100% stability on Mac
print(f"[GRPO] Using device: {device} (Safe Mode)")
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float32,
).to(device)
training_args = GRPOConfig(
output_dir=OUTPUT_DIR,
learning_rate=1e-6,
per_device_train_batch_size=1,
gradient_accumulation_steps=4,
num_generations=4,
max_completion_length=32, # Short and sweet
num_train_epochs=1,
max_steps=5,
logging_steps=1,
max_grad_norm=0.1, # Tightest possible clip
beta=0.01, # Low KL pressure
report_to="wandb",
)
trainer = GRPOTrainer(
model=model,
reward_funcs=[sql_reward_func],
args=training_args,
train_dataset=make_dataset(),
processing_class=tokenizer,
)
print("[GRPO] Starting training...")
trainer.train()
print(f"[GRPO] Training complete. Saving to {OUTPUT_DIR}/final")
trainer.save_model(f"{OUTPUT_DIR}/final")
if __name__ == "__main__":
# Check if server is running
try:
httpx.get(f"{ENV_URL}/health")
train()
except Exception as e:
print(f"ERROR during training execution.")
print(f"Details: {e}")
|