Spaces:
Running
Running
| """GRPO Training for Autonomic DBRE — Schema-aware prompts.""" | |
| import os, json, time, torch | |
# Fill in local-development database credentials only when the caller has not
# already exported them; these lines run before the ``dbre`` import that follows
# (presumably because that package reads them at import time — TODO confirm).
# NOTE(review): a plaintext default password is committed here; consider
# requiring DB_PASSWORD to be supplied via the environment / Space secrets.
os.environ.setdefault("DB_USER", "dbre_admin")
os.environ.setdefault("DB_PASSWORD", "dbre_pass")
| from dbre.environment import DBREEnvironment, DBREAction | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| from peft import LoraConfig, get_peft_model | |
| from trl import GRPOConfig, GRPOTrainer | |
| from datasets import Dataset | |
| SCHEMA_PROMPT = """You are a SQL optimization expert. Given a slow query, rewrite it to be faster. | |
| Database schema: | |
| - customers(customer_id, name, email, city, created_at) | |
| - products(product_id, name, category, price, stock) | |
| - orders(order_id, customer_id, order_date, status) | |
| - order_items(item_id, order_id, product_id, quantity, unit_price) | |
| - reviews(review_id, customer_id, product_id, rating, review_text, created_at) | |
| Rules: | |
| - Use only the tables and columns above | |
| - Add JOINs with proper ON conditions | |
| - No SELECT * | |
| - Use specific column names | |
| - Add WHERE clauses to filter rows | |
| - Use LIMIT for large result sets | |
| - Use indexes: customers(email), orders(customer_id), order_items(order_id), reviews(customer_id) | |
| Rewrite this slow query to be more efficient. Output ONLY the SQL, no explanation.""" | |
# Load model
# Hub checkpoint used for BOTH the weights and the tokenizer; defined once so the
# two can never drift apart (previously the id was hard-coded twice).
MODEL_ID = "Qwen/Qwen2.5-Coder-1.5B-Instruct"

print("Loading Qwen2.5-Coder-1.5B...")
# 4-bit QLoRA-style quantization: NF4 weights (the data type QLoRA recommends;
# bitsandbytes falls back to FP4 when bnb_4bit_quant_type is not given),
# bfloat16 compute, and double quantization of the quantization constants.
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
# Qwen2.5 uses the native transformers "qwen2" architecture, so no remote code is
# required; leaving trust_remote_code off avoids executing code from the Hub repo.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Pad with the EOS token.
# NOTE(review): Qwen2.5 tokenizers usually ship a dedicated pad token; confirm
# that overriding it with EOS is intended.
tokenizer.pad_token = tokenizer.eos_token

# LoRA adapters on the attention projections only; lora_alpha == r gives a
# LoRA scaling factor (alpha / r) of 1.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
print("Model loaded.")
# Reward
def _completion_text(completion):
    """Return the generated text of a single GRPO completion.

    With a standard (plain-string prompt) dataset the trainer passes strings;
    with a conversational dataset each completion is a list of
    ``{"role": ..., "content": ...}`` message dicts, in which case the content of
    the last message is used instead of the list's repr.
    """
    if isinstance(completion, list) and completion and isinstance(completion[-1], dict):
        return str(completion[-1].get("content", ""))
    return str(completion)


def _extract_sql(raw, max_len=500):
    """Pull the first SQL statement out of free-form model output.

    Starts at the first standalone ``SELECT`` keyword (case-insensitive) and stops
    at the first ``;`` or Markdown code fence, whichever comes first, then
    re-terminates the statement with ``;``. Output without ``SELECT`` is passed
    through unchanged apart from stripping. The result is at most *max_len* chars.
    """
    import re

    text = raw.strip()
    match = re.search(r"\bselect\b", text, flags=re.IGNORECASE)
    if match is None:
        return text[:max_len]
    # Search the original string rather than an upper-cased copy: str.upper() can
    # change the length of some Unicode text (e.g. "ß" -> "SS"), which would make
    # an index found in the copy point to the wrong place in the original.
    statement = re.split(r";|```", text[match.start():], maxsplit=1)[0].rstrip()
    # Truncate BEFORE appending the terminator so the ";" always survives.
    return statement[: max(max_len - 1, 0)] + ";"


def _scale_breakdown(breakdown):
    """Fold the env's reward components into a single reward clamped to [0, 1]."""
    score = (
        breakdown.get("efficiency", 0) * 0.5
        + breakdown.get("correctness", 0) * 0.3
        + breakdown.get("style", 0) * 0.2
    )
    # Affine stretch: a weighted score of 0.1 maps to 0.0 and 0.6 maps to 1.0.
    return float(min(1.0, max(0.0, score * 2.0 - 0.2)))


def dbre_reward(completions, **kwargs):
    """GRPO reward function: score each completion by replaying it in a DBRE env.

    Each completion is reduced to one SQL statement and submitted as a
    ``rewrite_query`` action on a freshly reset ``DBREEnvironment``. The step's
    ``reward_breakdown`` (efficiency / correctness / style) is then folded into a
    reward in [0, 1].

    Args:
        completions: Generated completions from ``GRPOTrainer``. These are plain
            strings, or lists of chat-message dicts for conversational datasets.
        **kwargs: Extra columns the trainer forwards (e.g. ``prompts``); unused.

    Returns:
        list[float]: One reward per completion, in input order. A completion whose
        SQL extraction or environment step raises scores 0.0.
    """
    import logging

    logger = logging.getLogger(__name__)
    env = DBREEnvironment({'max_steps': 5})
    rewards = []
    for completion in completions:
        # NOTE(review): the observation from reset() is discarded, and the prompt
        # never shows the model which query to rewrite. Confirm that the env's
        # scoring does not depend on a per-episode query the model cannot see.
        env.reset()
        try:
            sql = _extract_sql(_completion_text(completion))
            action = DBREAction(action_type="rewrite_query", new_sql=sql)
            _, _, _, info = env.step(action)
            breakdown = info.get("reward_breakdown") or {}
            rewards.append(_scale_breakdown(breakdown))
        except Exception:
            # Best-effort by design: an invalid rewrite earns 0 rather than
            # aborting training, but keep the cause available for debugging.
            logger.debug("dbre_reward: scoring failed, reward=0.0", exc_info=True)
            rewards.append(0.0)
    return rewards
# Training dataset (named ``dummy``, but it is the real training set): a single
# "prompt" column holding SCHEMA_PROMPT repeated 100 times, so every row and every
# GRPO sampling group starts from the same schema-aware prompt.
dummy = Dataset.from_dict(dict(prompt=[SCHEMA_PROMPT for _ in range(100)]))
# Train
# Destination for the final LoRA adapter + tokenizer; intermediate checkpoints go
# to the GRPOConfig output_dir instead.
FINAL_DIR = "./dbre_trained"

training_args = GRPOConfig(
    output_dir="./grpo_dbre",
    # max_steps takes precedence over num_train_epochs, so the run length is set
    # by step count alone; the trainer cycles the dataset as often as needed.
    max_steps=500,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=5e-5,
    logging_steps=5,
    save_steps=50,
    bf16=True,
    report_to="none",
)
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    train_dataset=dummy,
    reward_funcs=[dbre_reward],
)
print("Training with schema-aware prompts...")
trainer.train()
# Trainer.save_model writes the model's (PEFT adapter) weights and, because a
# processing_class was supplied, the tokenizer as well. It also writes only from
# the process that should save, unlike calling save_pretrained on every rank.
trainer.save_model(FINAL_DIR)
print(f"Done. Model saved to {FINAL_DIR}")