import asyncio
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# --- Configuration ---
BASE_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
LORA_DIR = "./qwen-nl2sql-grpo/checkpoint-50"
SPACE_URL = "http://localhost:8000"  # URL of the environment server (a locally running Space here)
TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
MAX_STEPS = 5

print("Loading base model and LoRA adapter...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
model = PeftModel.from_pretrained(base_model, LORA_DIR)
model.eval()  # inference only: disables dropout in the adapter layers
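# Optionally merge the LoRA weights into the base model for faster inference
# (PEFT's merge_and_unload()); skip this if you plan to swap adapters later:
# model = model.merge_and_unload()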

# --- System Prompt & LLM Call ---
SYSTEM_PROMPT = """You are an expert SQL analyst working with a SQLite e-commerce database.

Write a single SELECT query. Output ONLY the SQL query, nothing else. No markdown."""

def call_local_llm(user_prompt: str) -> str:
    """Generate a single SQL query from the fine-tuned model, stripping any markdown fences."""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt}
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    
    # Low temperature plus sampling keeps outputs near-greedy without being fully deterministic
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=256, temperature=0.2, do_sample=True)
    
    response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
    
    # Strip markdown code fences if model wraps in ```sql ... ```
    if response.startswith("```"):
        lines = response.split("\n")
        response = "\n".join(l for l in lines if not l.strip().startswith("```")).strip()
    return response if response else "SELECT 1"
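
# Optional sanity check before running the full evaluation (hypothetical question):
# print(call_local_llm("QUESTION: How many orders are there?\n\nWrite a SQL query to answer the question."))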

def build_user_prompt(question, schema_context, step, last_query, last_error, last_result, result_columns):
    """Assemble the user prompt, including the schema and feedback from the previous step."""
    parts = [f"QUESTION: {question}", ""]
    if schema_context:
        parts.append(f"SCHEMA:\n{schema_context}")
        parts.append("")
    if step > 1:
        parts.append(f"Your previous SQL (step {step - 1}):")
        parts.append(f"  {' '.join(last_query.split())}")
        parts.append("")
        if last_error:
            parts.append(f"ERROR: {last_error}")
        elif last_result:
            preview = str(last_result[:3]).replace("\n", " ")
            parts.append(f"RESULT PREVIEW (first 3 rows): {preview}")
            parts.append(f"COLUMNS: {result_columns}")
        parts.append("")
        parts.append("Please correct or refine your query.")
    else:
        parts.append("Write a SQL query to answer the question.")
    return "\n".join(parts)
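
# On retry steps the assembled prompt looks roughly like this (illustrative shape only):
#   QUESTION: <question>
#
#   SCHEMA:
#   <schema_context>
#
#   Your previous SQL (step N-1):
#     <last_query>
#
#   ERROR: <last_error>    (or RESULT PREVIEW / COLUMNS when the query succeeded)
#
#   Please correct or refine your query.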

async def main():
    # Client bindings for the environment; `client.py` is assumed to live next to this script.
    from client import NL2SQLEnv, NL2SQLAction
    
    all_results = []
    
    for task_name in TASKS:
        print(f"\n--- Starting Task: {task_name} ---")
        # The environment presumably reads the default task from this variable on reset.
        os.environ["NL2SQL_DEFAULT_TASK"] = task_name
        
        try:
            async with NL2SQLEnv(base_url=SPACE_URL) as env:
                result = await env.reset()
                obs = result.observation
                
                rewards = []
                
                for step in range(1, MAX_STEPS + 1):
                    if obs.done:
                        break
                        
                    user_prompt = build_user_prompt(
                        obs.question, obs.schema_context, step, 
                        obs.last_query, obs.last_error, obs.last_result, obs.result_columns
                    )
                    
                    sql = call_local_llm(user_prompt)
                    
                    print(f"Step {step} Agent Output: {sql}")
                    
                    step_result = await env.step(NL2SQLAction(query=sql))
                    obs = step_result.observation
                    
                    reward = obs.reward or 0.0
                    rewards.append(reward)
                    print(f"Step {step} Reward: {reward}")
                    
                    if obs.done:
                        break
                
                score = sum(rewards) / max(len(rewards), 1)
                success = score >= 0.7
                print(f"Final Score for {task_name}: {score:.3f}")
                all_results.append({"task": task_name, "score": score, "success": success})
                
        except Exception as e:
            print(f"Error testing task {task_name}: {e}")
            
    print("\n=== Final Results ===")
    for r in all_results:
        print(f"{r['task']}: Score {r['score']:.3f} | Success: {r['success']}")
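
    # Optionally persist the summary so later runs can be compared (hypothetical path):
    # import json
    # with open("eval_results.json", "w") as f:
    #     json.dump(all_results, f, indent=2)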

if __name__ == "__main__":
    asyncio.run(main())