```python
import asyncio
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
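# Assumes torch, transformers, and peft are installed, plus the environment's
# `client` module (imported inside main() below) from the Space repository.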
# --- Configuration ---
BASE_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
LORA_DIR = "./qwen-nl2sql-grpo/checkpoint-50"
SPACE_URL = "http://localhost:8000"  # Local server URL
TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
MAX_STEPS = 5
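# Each task gets up to MAX_STEPS attempts; from step 2 onward the model sees
# its previous SQL plus the resulting error or row preview and can refine it.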
| print("Loading Base Model and LoRA weights...") | |
| tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) | |
| base_model = AutoModelForCausalLM.from_pretrained( | |
| BASE_MODEL, | |
| torch_dtype=torch.bfloat16, | |
| device_map="auto" | |
| ) | |
| model = PeftModel.from_pretrained(base_model, LORA_DIR) | |
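# device_map="auto" requires the accelerate package; on a CPU-only machine you
# may prefer float32 and no device_map.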
# --- System Prompt & LLM Call ---
SYSTEM_PROMPT = """You are an expert SQL analyst working with a SQLite e-commerce database.
Write a single SELECT query. Output ONLY the SQL query, nothing else. No markdown."""
def call_local_llm(user_prompt: str) -> str:
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=256, temperature=0.2, do_sample=True)
    response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
    # Strip markdown code fences if the model wraps its answer in ```sql ... ```
    if response.startswith("```"):
        lines = response.split("\n")
        response = "\n".join(l for l in lines if not l.strip().startswith("```")).strip()
    return response if response else "SELECT 1"
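# "SELECT 1" is a harmless fallback so that an empty generation still yields a
# syntactically valid query instead of failing the environment step outright.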
def build_user_prompt(question, schema_context, step, last_query, last_error, last_result, result_columns):
    parts = [f"QUESTION: {question}", ""]
    if schema_context:
        # Surface the schema so the model can reference real table/column names.
        parts.append(f"SCHEMA:\n{schema_context}")
        parts.append("")
    if step > 1:
        parts.append(f"Your previous SQL (step {step - 1}):")
        parts.append(f"  {' '.join(last_query.split())}")
        parts.append("")
        if last_error:
            parts.append(f"ERROR: {last_error}")
        elif last_result:
            preview = str(last_result[:3]).replace("\n", " ")
            parts.append(f"RESULT PREVIEW (first 3 rows): {preview}")
            parts.append(f"COLUMNS: {result_columns}")
        parts.append("")
        parts.append("Please correct or refine your query.")
    else:
        parts.append("Write a SQL query to answer the question.")
    return "\n".join(parts)
async def main():
    from client import NL2SQLEnv, NL2SQLAction

    all_results = []
    for task_name in TASKS:
        print(f"\n--- Starting Task: {task_name} ---")
        os.environ["NL2SQL_DEFAULT_TASK"] = task_name
        try:
            async with NL2SQLEnv(base_url=SPACE_URL) as env:
                result = await env.reset()
                obs = result.observation
                rewards = []
                for step in range(1, MAX_STEPS + 1):
                    if obs.done:
                        break
                    user_prompt = build_user_prompt(
                        obs.question, obs.schema_context, step,
                        obs.last_query, obs.last_error, obs.last_result, obs.result_columns,
                    )
                    sql = call_local_llm(user_prompt)
                    print(f"Step {step} Agent Output: {sql}")
                    step_result = await env.step(NL2SQLAction(query=sql))
                    obs = step_result.observation
                    reward = obs.reward or 0.0
                    rewards.append(reward)
                    print(f"Step {step} Reward: {reward}")
                    if obs.done:
                        break
                # Mean reward over attempted steps; 0.7 is the success bar.
                score = sum(rewards) / max(len(rewards), 1)
                success = score >= 0.7
                print(f"Final Score for {task_name}: {score:.3f}")
                all_results.append({"task": task_name, "score": score, "success": success})
        except Exception as e:
            print(f"Error testing task {task_name}: {e}")
| print("\n=== Final Results ===") | |
| for r in all_results: | |
| print(f"{r['task']}: Score {r['score']:.3f} | Success: {r['success']}") | |
| if __name__ == "__main__": | |
| asyncio.run(main()) |
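To sanity-check the checkpoint without the environment server, the generation path can be exercised directly in the same session. A minimal sketch, assuming the script above has been run; the question text is made up:

```python
# Smoke test: build a first-step prompt and generate once. No server needed.
prompt = build_user_prompt(
    question="How many orders were placed in 2023?",
    schema_context="", step=1, last_query="",
    last_error=None, last_result=None, result_columns=None,
)
print(call_local_llm(prompt))
```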