# nl2sql-bench / local_test.py
# Uploaded with huggingface_hub by ritvik360 (commit a39d8ef, verified).
import asyncio
import os
import sys
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# --- Configuration ---
BASE_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"  # Hugging Face id of the base checkpoint
LORA_DIR = "./qwen-nl2sql-grpo/checkpoint-50"  # local LoRA adapter weights (GRPO fine-tune)
SPACE_URL = "http://localhost:8000" # Local server URL
TASKS = ["simple-filter", "join-aggregation", "analytics-window"]  # benchmark task names the env recognizes
MAX_STEPS = 5  # maximum query attempts per task episode
# Load the base model, then attach the fine-tuned LoRA adapter on top of it.
# NOTE(review): this runs at import time and may download weights — presumably
# intentional for a throwaway test script.
print("Loading Base Model and LoRA weights...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,  # halves memory vs fp32; assumes bf16-capable hardware — TODO confirm
    device_map="auto"  # let accelerate place layers on available devices
)
model = PeftModel.from_pretrained(base_model, LORA_DIR)
# --- System Prompt & LLM Call ---
# Instruction sent as the system message on every generation; asks for bare
# SQL with no markdown (the caller still strips ``` fences defensively).
SYSTEM_PROMPT = """You are an expert SQL analyst working with a SQLite e-commerce database.
Write a single SELECT query. Output ONLY the SQL query, nothing else. No markdown."""
def call_local_llm(user_prompt: str) -> str:
    """Generate one SQL query from the local LoRA model for ``user_prompt``.

    Builds a system+user chat, applies the tokenizer's chat template,
    samples up to 256 new tokens at temperature 0.2, strips any markdown
    code fences from the reply, and falls back to the harmless query
    ``SELECT 1`` when the model returns nothing.
    """
    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    prompt_text = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer([prompt_text], return_tensors="pt").to(model.device)
    with torch.no_grad():
        generated = model.generate(
            **encoded, max_new_tokens=256, temperature=0.2, do_sample=True
        )
    # Decode only the newly generated tail, not the echoed prompt.
    prompt_len = encoded.input_ids.shape[1]
    answer = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
    # Remove markdown fences if the model wrapped its output in ```sql ... ```
    if answer.startswith("```"):
        kept = [ln for ln in answer.split("\n") if not ln.strip().startswith("```")]
        answer = "\n".join(kept).strip()
    return answer or "SELECT 1"
def build_user_prompt(question, schema_context, step, last_query, last_error, last_result, result_columns):
    """Compose the user message for one attempt at answering ``question``.

    Step 1 simply asks for a query.  Later steps replay the previous SQL
    (whitespace-collapsed) plus either the execution error or a preview of
    the first three result rows, then ask the model to refine.

    ``schema_context`` is accepted for interface parity with the caller but
    is not embedded in the prompt text.
    """
    lines = [f"QUESTION: {question}", ""]
    if step <= 1:
        lines.append("Write a SQL query to answer the question.")
        return "\n".join(lines)
    # Feedback turn: show what was tried and what the environment returned.
    lines.append(f"Your previous SQL (step {step - 1}):")
    lines.append(f"  {' '.join(last_query.split())}")
    lines.append("")
    if last_error:
        lines.append(f"ERROR: {last_error}")
    elif last_result:
        preview = str(last_result[:3]).replace("\n", " ")
        lines.append(f"RESULT PREVIEW (first 3 rows): {preview}")
        lines.append(f"COLUMNS: {result_columns}")
    lines.append("")
    lines.append("Please correct or refine your query.")
    return "\n".join(lines)
async def main():
    """Run every benchmark task against the local env server and report scores.

    For each task: select it via the NL2SQL_DEFAULT_TASK env var, reset the
    environment, then let the local model attempt up to MAX_STEPS queries,
    feeding each observation back into the next prompt.  A task's score is
    the mean per-step reward; >= 0.7 counts as success.
    """
    from client import NL2SQLEnv, NL2SQLAction

    summary = []
    for task in TASKS:
        print(f"\n--- Starting Task: {task} ---")
        # The server picks up the active task from this environment variable.
        os.environ["NL2SQL_DEFAULT_TASK"] = task
        try:
            async with NL2SQLEnv(base_url=SPACE_URL) as env:
                obs = (await env.reset()).observation
                rewards = []
                step = 0
                # Stop early when the episode finishes; otherwise cap at MAX_STEPS.
                while step < MAX_STEPS and not obs.done:
                    step += 1
                    prompt = build_user_prompt(
                        obs.question, obs.schema_context, step,
                        obs.last_query, obs.last_error, obs.last_result, obs.result_columns,
                    )
                    sql = call_local_llm(prompt)
                    print(f"Step {step} Agent Output: {sql}")
                    obs = (await env.step(NL2SQLAction(query=sql))).observation
                    reward = obs.reward or 0.0
                    rewards.append(reward)
                    print(f"Step {step} Reward: {reward}")
                # Mean reward across the steps actually taken (0.0 if none).
                score = sum(rewards) / max(len(rewards), 1)
                passed = score >= 0.7
                print(f"Final Score for {task}: {score:.3f}")
                summary.append({"task": task, "score": score, "success": passed})
        except Exception as e:
            # Best-effort: a failing task is reported but doesn't abort the run.
            print(f"Error testing task {task}: {e}")
    print("\n=== Final Results ===")
    for entry in summary:
        print(f"{entry['task']}: Score {entry['score']:.3f} | Success: {entry['success']}")
# Script entry point: drive the async benchmark loop to completion.
if __name__ == "__main__":
    asyncio.run(main())