File size: 3,692 Bytes
22654ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import gradio as gr
import torch
import re
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face Hub repo of the fine-tuned Qwen2.5-Coder SQL generator.
MODEL = "jinesh90/qwen2.5-coder-sql-generator"

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL)
# Half precision + automatic device placement keeps the 7B model within
# typical single-GPU memory; low_cpu_mem_usage streams weights at load time.
_load_kwargs = {
    "torch_dtype": torch.float16,
    "device_map": "auto",
    "low_cpu_mem_usage": True,
}
model = AutoModelForCausalLM.from_pretrained(MODEL, **_load_kwargs)
model.eval()  # inference only — disable dropout etc.
print("Ready!")

def clean_sql(text):
    """Post-process raw model output into a bare SQL string.

    Truncates each line at its first non-ASCII character (cuts off any
    multilingual babble the model appends), then keeps only the text before
    the first occurrence of each stop marker ("###", "assistant", blank line).
    """
    candidate = re.sub(r'[^\x00-\x7F].*', '', text.strip()).strip()
    for marker in ("###", "assistant", "\n\n"):
        head, found, _ = candidate.partition(marker)
        if found:
            candidate = head.strip()
    return candidate

def build_prompt(question, schema):
    """Return the instruction-style prompt fed to the model.

    The "### SQL:" trailer cues the model to emit only the query; the
    section markers double as stop strings in clean_sql().
    """
    header = (
        "You are a SQL expert. Generate the simplest and most direct SQL query.\n"
        "Use JOINs only when multiple tables are needed."
    )
    return (
        f"{header}\n\n### Schema:\n{schema}\n\n"
        f"### Question:\n{question}\n\n### SQL:"
    )

def generate(question, schema):
    """Generate a SQL query answering *question* against *schema*.

    Returns the cleaned SQL string, or a user-facing error message when
    either input is empty. Uses greedy decoding (deterministic output).
    """
    if not question or not schema:
        return "Please provide both a question and schema!"

    messages = [{"role": "user", "content": build_prompt(question, schema)}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize = False,
        add_generation_prompt = True
    )
    # Truncate long schemas so the prompt fits the context window.
    inputs = tokenizer(
        text,
        return_tensors = "pt",
        truncation = True,
        max_length = 1024
    ).to(model.device)

    # Stop on the generic EOS or Qwen's end-of-turn marker; skip the
    # latter if the tokenizer cannot resolve it to a real id.
    stop_tokens = [tokenizer.eos_token_id]
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if im_end_id is not None and im_end_id != tokenizer.unk_token_id:
        stop_tokens.append(im_end_id)

    with torch.no_grad():
        # Fix: the original passed temperature=0 with do_sample=False —
        # temperature only applies to sampling, and a non-positive value
        # triggers a transformers warning (or error in stricter versions).
        outputs = model.generate(
            **inputs,
            max_new_tokens = 200,
            do_sample = False,
            repetition_penalty = 1.3,
            eos_token_id = stop_tokens,
            pad_token_id = tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    input_len = inputs["input_ids"].shape[1]
    raw = tokenizer.decode(outputs[0, input_len:], skip_special_tokens=True)
    return clean_sql(raw)

# Demo schema: pre-filled in the schema textbox and reused by the
# clickable gr.Examples entries below.
example_schema = """CREATE TABLE employees (
    id INTEGER,
    name VARCHAR,
    salary REAL,
    department VARCHAR,
    age INTEGER
);"""

# Declarative Gradio UI: left column holds the inputs (schema + question),
# right column shows the generated SQL plus static model stats.
with gr.Blocks(title="SQL Query Generator") as demo:
    gr.Markdown("# 🗄️ SQL Query Generator")
    gr.Markdown("Fine-tuned Qwen2.5-Coder 7B on Spider dataset | 42% execution accuracy")
    
    with gr.Row():
        with gr.Column():
            schema = gr.Textbox(
                label = "Database Schema (CREATE TABLE statements)",
                value = example_schema,
                lines = 10
            )
            question = gr.Textbox(
                label = "Question",
                placeholder = "How many employees have salary > 50000?",
                lines = 2
            )
            btn = gr.Button("🚀 Generate SQL", variant="primary")
        
        with gr.Column():
            # gr.Code renders the result with SQL syntax highlighting.
            output = gr.Code(
                label = "Generated SQL",
                language = "sql"
            )
            gr.Markdown("""
            ### 📊 Model Stats
            - **Base model**: Qwen2.5-Coder-7B
            - **Training data**: Spider dataset (7.9k samples)
            - **Simple queries**: 64.2% accuracy
            - **Complex queries**: 17.0% accuracy
            - **Overall**: 42% execution accuracy
            """)
    
    # Wire the button to the inference function.
    btn.click(fn=generate, inputs=[question, schema], outputs=output)
    
    # One-click example prompts that populate both input boxes.
    gr.Examples(
        examples=[
            ["How many employees are there?", example_schema],
            ["Find all employees with salary greater than 50000", example_schema],
            ["What is the average salary by department?", example_schema],
        ],
        inputs=[question, schema]
    )

# Launches the web server at import/run time (standard for HF Spaces app.py).
demo.launch()