Spaces:
Running
Running
| import os | |
| # CRITICAL: This line must come first, before any PyTorch import! | |
| os.environ["CUDA_VISIBLE_DEVICES"] = "0,7" | |
| import sys | |
| import torch | |
| from datasets import Dataset | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import LoraConfig | |
| from trl import GRPOConfig, GRPOTrainer | |
| sys.path.insert(0, "./server") | |
| from environment import NL2SQLEnvironment | |
| from models import NL2SQLAction | |
| from tasks import all_task_names, get_task | |
# Model identifier passed to from_pretrained() for both the tokenizer and the model.
MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct"
# Training output directory; the trained model and tokenizer are saved under OUTPUT_DIR/final.
OUTPUT_DIR = "./qwen-7b-coder-nl2sql-grpo"
# System message placed first in every training prompt.
SYSTEM_PROMPT = """You are a Senior Database Architect and an expert in SQLite.
Your task is to translate natural language questions into highly optimized, correct SQLite SELECT queries.
STRICT RULES:
1. Output EXACTLY ONE valid SQLite query.
2. DO NOT wrap the query in markdown formatting (no ```sql or ```).
3. DO NOT output any explanations, conversational text, or preambles (e.g., never say "Here is the query").
4. ONLY use standard SQLite functions. Avoid SQL Server, MySQL, or PostgreSQL specific syntax.
5. If the question implies ordering, use the correct ORDER BY clause.
Your output must be executable directly against the database as-is."""
def build_dataset():
    """Assemble the training set from every registered NL2SQL task.

    Each example of each task becomes one row holding a two-message chat
    prompt (the system instructions, then the task schema plus the
    example's question) and the name of the task it came from.

    Returns:
        A ``Dataset`` built from the rows, with the keys ``prompt`` and
        ``task_name``.
    """
    rows = []
    for name in all_task_names():
        current = get_task(name)
        # The schema text is fetched once per task and shared by its examples.
        schema_text = current.schema_context()
        rows.extend(
            {
                "prompt": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {
                        "role": "user",
                        "content": f"SCHEMA:\n{schema_text}\n\nQUESTION: {example.question}",
                    },
                ],
                "task_name": name,
            }
            for example in current.examples
        )
    return Dataset.from_list(rows)
def _strip_code_fences(text):
    """Return *text* trimmed, with Markdown code-fence lines removed.

    The text is stripped of surrounding whitespace first, so a fence that
    is preceded by a blank line or spaces is still recognised.
    """
    trimmed = text.strip()
    if not trimmed.startswith("```"):
        return trimmed
    kept = [ln for ln in trimmed.split("\n") if not ln.strip().startswith("```")]
    return "\n".join(kept).strip()


def sql_reward_func(prompts, completions, task_name, **kwargs):
    """Score each generated completion by running it in the NL2SQL environment.

    Args:
        prompts: Accepted for the reward-function call signature; not used.
        completions: One entry per sample. An entry is either a plain string
            or a list whose first element is a dict with a ``"content"`` key.
        task_name: Either a list with one task name per completion, or a
            single task name applied to every completion.
        **kwargs: Extra keyword arguments; ignored.

    Returns:
        A list of floats, one per completion. A completion whose environment
        step (or reward conversion) raises is scored ``0.0``.
    """
    import logging

    log = logging.getLogger(__name__)
    env = NL2SQLEnvironment()
    rewards = []
    for idx, completion in enumerate(completions):
        raw_text = completion[0]["content"] if isinstance(completion, list) else completion
        # Models sometimes ignore the "no markdown" rule; unwrap the query so
        # it is judged on the SQL itself rather than on the formatting.
        query = _strip_code_fences(raw_text)
        current_task = task_name[idx] if isinstance(task_name, list) else task_name
        # Reset stays outside the try block: a failing reset is not a
        # property of the generated query and should not be masked as 0.0.
        env.reset(task_name=current_task)
        try:
            obs = env.step(NL2SQLAction(query=query))
            rewards.append(float(obs.reward))
        except Exception:
            # Best effort: an unscorable completion earns no reward, but the
            # cause is kept available for debugging instead of being lost.
            log.debug(
                "Scoring failed for task %r; assigning reward 0.0",
                current_task,
                exc_info=True,
            )
            rewards.append(0.0)
    return rewards
def main():
    """Fine-tune MODEL_NAME with GRPO and LoRA, then save the result.

    Builds the dataset, loads the tokenizer and the bf16 model, configures
    LoRA and GRPO, trains, and finally writes the trained model and the
    tokenizer to ``OUTPUT_DIR/final`` on the main process only.
    """
    train_data = build_dataset()

    tok = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="right")
    if tok.pad_token is None:
        # Fall back to the EOS token when the tokenizer defines no pad token.
        tok.pad_token = tok.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        # Defaulting to sdpa to avoid any flash_attn setup issues.
        attn_implementation="sdpa",
    )

    lora_targets = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ]
    lora = LoraConfig(
        r=128,
        lora_alpha=256,
        target_modules=lora_targets,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # NOTE(review): TRL constrains how the batch size relates to
    # num_generations; confirm these values against the installed TRL version.
    grpo_settings = {
        "output_dir": OUTPUT_DIR,
        "learning_rate": 2e-5,
        "per_device_train_batch_size": 2,
        "gradient_accumulation_steps": 4,
        "max_completion_length": 256,
        "num_generations": 8,
        "temperature": 0.5,
        "bf16": True,
        "logging_steps": 5,
        "num_train_epochs": 10,
        "report_to": "none",
        # Keep the extra "task_name" column so the reward function receives it.
        "remove_unused_columns": False,
        "ddp_find_unused_parameters": False,
    }
    grpo_args = GRPOConfig(**grpo_settings)

    trainer = GRPOTrainer(
        model=base_model,
        reward_funcs=sql_reward_func,
        args=grpo_args,
        train_dataset=train_data,
        peft_config=lora,
        processing_class=tok,
    )
    trainer.train()

    # Only the main process writes the final artifacts.
    if not trainer.accelerator.is_main_process:
        return
    final_dir = f"{OUTPUT_DIR}/final"
    trainer.model.save_pretrained(final_dir)
    tok.save_pretrained(final_dir)


if __name__ == "__main__":
    main()