# SQL Debug Env: FINAL REAL-WORLD BRIDGE
# (This script automatically installs its own dependencies)
# 1. AUTO-INSTALL LIBRARIES
import os
import subprocess
import sys

print("π¦ Checking libraries...")
# Install into *this* interpreter's environment: a bare `pip` on PATH may
# belong to a different Python. check_call raises immediately if the install
# fails, instead of silently continuing like os.system would.
subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "trl", "accelerate", "wandb", "-U"]
)

# Third-party imports happen *after* the install above so a fresh machine
# does not crash on a missing package.
import httpx
import torch
import random
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- 2. BRIDGE CONFIGURATION ---
# Put your Localtunnel URL here
# NOTE(review): Localtunnel hostnames are ephemeral — this URL changes every
# time the tunnel restarts, so update it before each training session.
BRIDGE_URL = "https://metal-bushes-lie.loca.lt"
# Hugging Face model id of the policy model to fine-tune.
MODEL_NAME = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
# Headers to bypass the Localtunnel landing page
BYPASS_HEADERS = {"Bypass-Tunnel-Reminder": "true"}
# --- 3. REAL DATASET GENERATION ---
def make_real_dataset():
    """Build the GRPO training dataset by pulling task prompts from the bridge.

    For each known task id, POSTs ``/reset`` to the local environment server
    (through the Localtunnel bridge), renders a fix-this-SQL prompt from the
    returned observation, and duplicates it 10x so the trainer samples each
    task across many batches.

    Returns:
        datasets.Dataset with columns ``prompt`` and ``task_id``.

    Raises:
        RuntimeError: if no task could be fetched (server or tunnel down).
    """
    print(f"π Connecting to your Mac at {BRIDGE_URL}...")
    tasks = ["easy_syntax_fix", "medium_logic_fix", "hard_multi_bug"]
    rows = []
    with httpx.Client(base_url=BRIDGE_URL, headers=BYPASS_HEADERS, timeout=30.0) as client:
        for t_id in tasks:
            try:
                resp = client.post("/reset", json={"task_id": t_id})
                # Fail loudly on HTTP errors rather than raising a confusing
                # KeyError from an error-page body below.
                resp.raise_for_status()
                obs = resp.json()["observation"]
                prompt = (
                    "Fix the following SQL query and provide only the fixed SQL.\n"
                    f"Task: {obs['task_description']}\n"
                    f"Broken Query: {obs['original_query']}\n"
                    "Fixed SQL:"
                )
                # Repeat each task 10x so GRPO sees it in multiple groups.
                rows.extend({"prompt": prompt, "task_id": t_id} for _ in range(10))
            except Exception as e:  # best-effort: skip tasks the bridge can't serve
                print(f"β οΈ Error fetching task {t_id}: {e}")
    if not rows:
        raise RuntimeError("Dataset is empty. Is your local server and tunnel running?")
    return Dataset.from_list(rows)
# --- 4. REAL REWARD FUNCTION ---
def sql_reward_func(completions, task_id, **kwargs):
    """Score each completion by executing it against the real SQL environment.

    For every (completion, task_id) pair: resets the environment to that task,
    strips everything up to the ``"Fixed SQL:"`` marker (if present), submits
    the candidate query via ``/step``, and uses the environment's reward.
    Any connection/HTTP failure scores 0.0 instead of aborting training.

    A tiny uniform jitter (±1e-6) is added to every reward so GRPO's
    group-wise advantage normalization never hits a zero std-dev when all
    rewards in a group tie.

    Returns:
        list[float]: one reward per completion.
    """
    rewards = []
    with httpx.Client(base_url=BRIDGE_URL, headers=BYPASS_HEADERS, timeout=30.0) as client:
        for query, t_id in zip(completions, task_id):
            try:
                reset_resp = client.post("/reset", json={"task_id": t_id})
                # Surface HTTP errors here instead of a KeyError further down.
                reset_resp.raise_for_status()
                sql_part = query.split("Fixed SQL:")[-1].strip() if "Fixed SQL:" in query else query.strip()
                resp = client.post("/step", json={"action": {"action_type": "submit_query", "query": sql_part}})
                resp.raise_for_status()
                reward = resp.json()["reward"]
            except Exception as e:
                print(f"β Connection Error for {t_id}: {e}")
                reward = 0.0  # failed rollout: neutral score, keep training
            reward += random.uniform(-1e-6, 1e-6)
            rewards.append(reward)
    return rewards
# --- 5. TRAINING LOOP ---
def run_real_world_train():
    """Fine-tune the model with GRPO against the live SQL-debug environment.

    Loads the tokenizer and model, builds the dataset over the bridge, and
    runs a short (20-step) GRPO loop, logging to wandb and pushing
    checkpoints to the Hugging Face Hub.
    """
    print("π Starting Real-World GRPO on Cloud GPU...")  # no placeholders, plain string
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    if tokenizer.pad_token is None:
        # Some tokenizers ship without a pad token; one is required for
        # batched generation. Don't clobber an existing pad token.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,  # full precision; fp16 also disabled below
        device_map="auto"
    )
    training_args = GRPOConfig(
        output_dir="./real_results",
        learning_rate=1e-5,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        num_generations=4,  # GRPO group size for advantage normalization
        max_completion_length=64,
        num_train_epochs=1,
        max_steps=20,  # max_steps takes precedence over num_train_epochs
        logging_steps=1,
        fp16=False,
        report_to="wandb",
        push_to_hub=True,  # pushes logs and model to HF
        hub_model_id="sql-debug-agent-7b",  # your HF Model Repo Name
        hub_strategy="every_save"
    )
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=[sql_reward_func],
        args=training_args,
        train_dataset=make_real_dataset(),
        processing_class=tokenizer,
    )
    print("π§ Cloud Brain connected. Starting Real-World training...")
    trainer.train()
if __name__ == "__main__":
    # Script entry point: kick off the end-to-end real-world training run.
    run_real_world_train()