import gradio as gr
import torch
import re
from transformers import AutoTokenizer, AutoModelForCausalLM
# Hub repo id of the fine-tuned SQL-generation model (Qwen2.5-Coder base).
MODEL = "jinesh90/qwen2.5-coder-sql-generator"
print("Loading model...")
# Load tokenizer and weights once at process startup.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
# fp16 weights; device_map="auto" places layers on available GPU(s)/CPU,
# low_cpu_mem_usage keeps host-RAM peak down while sharding is loaded.
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype = torch.float16,
    device_map = "auto",
    low_cpu_mem_usage = True,
)
model.eval()  # inference mode: disables dropout etc.
print("Ready!")
def clean_sql(text):
    """Strip model chatter from raw generated text, returning only the SQL.

    Cuts the string at the first non-ASCII character, then truncates at any
    of the known stop markers the model tends to emit after the query.
    """
    # Drop everything from the first non-ASCII character to end of line.
    candidate = re.sub(r'[^\x00-\x7F].*', '', text.strip()).strip()
    # Truncate at each stop marker, in order, keeping only the text before it.
    for marker in ("###", "assistant", "\n\n"):
        head, sep, _tail = candidate.partition(marker)
        if sep:
            candidate = head.strip()
    return candidate
def build_prompt(question, schema):
    """Assemble the instruction prompt sent to the model.

    Mirrors the ``### Schema / ### Question / ### SQL`` layout the model
    was fine-tuned on.
    """
    instructions = (
        "You are a SQL expert. Generate the simplest and most direct SQL query.\n"
        "Use JOINs only when multiple tables are needed."
    )
    return (
        f"{instructions}\n"
        f"### Schema:\n{schema}\n"
        f"### Question:\n{question}\n"
        f"### SQL:"
    )
def generate(question, schema):
    """Generate a SQL query answering *question* against *schema*.

    Builds the fine-tuning prompt, runs greedy decoding on the loaded model,
    and returns the cleaned SQL string. Returns a usage message when either
    input is empty.
    """
    if not question or not schema:
        return "Please provide both a question and schema!"

    messages = [{"role": "user", "content": build_prompt(question, schema)}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize = False,
        add_generation_prompt = True
    )
    inputs = tokenizer(
        text,
        return_tensors = "pt",
        truncation = True,
        max_length = 1024
    ).to(model.device)

    # Stop on EOS; also stop on Qwen's <|im_end|> turn delimiter, but only
    # when the tokenizer actually knows that token (convert_tokens_to_ids
    # returns None or the unk id otherwise, which would corrupt eos_token_id).
    stop_tokens = [tokenizer.eos_token_id]
    im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if im_end is not None and im_end != tokenizer.unk_token_id:
        stop_tokens.append(im_end)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens = 200,
            # Greedy decoding; do NOT pass temperature=0 here — with
            # do_sample=False it is ignored and transformers emits a warning.
            do_sample = False,
            repetition_penalty = 1.3,
            eos_token_id = stop_tokens,
            pad_token_id = tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    input_len = inputs["input_ids"].shape[1]
    raw = tokenizer.decode(outputs[0, input_len:], skip_special_tokens=True)
    return clean_sql(raw)
# Example schema pre-filled into the UI schema textbox and reused by the
# clickable examples further down.
example_schema = """CREATE TABLE employees (
id INTEGER,
name VARCHAR,
salary REAL,
department VARCHAR,
age INTEGER
);"""
# Gradio UI: schema + question inputs on the left, generated SQL on the right.
with gr.Blocks(title="SQL Query Generator") as demo:
    gr.Markdown("# 🗄️ SQL Query Generator")
    gr.Markdown("Fine-tuned Qwen2.5-Coder 7B on Spider dataset | 42% execution accuracy")
    with gr.Row():
        with gr.Column():
            # Input column: schema textbox is pre-filled with the example schema.
            schema = gr.Textbox(
                label = "Database Schema (CREATE TABLE statements)",
                value = example_schema,
                lines = 10
            )
            question = gr.Textbox(
                label = "Question",
                placeholder = "How many employees have salary > 50000?",
                lines = 2
            )
            btn = gr.Button("🚀 Generate SQL", variant="primary")
        with gr.Column():
            # Output column: syntax-highlighted SQL viewer plus model stats.
            output = gr.Code(
                label = "Generated SQL",
                language = "sql"
            )
            gr.Markdown("""
### 📊 Model Stats
- **Base model**: Qwen2.5-Coder-7B
- **Training data**: Spider dataset (7.9k samples)
- **Simple queries**: 64.2% accuracy
- **Complex queries**: 17.0% accuracy
- **Overall**: 42% execution accuracy
""")
    # Button click runs generate(question, schema) and shows its return value.
    btn.click(fn=generate, inputs=[question, schema], outputs=output)
    # Clickable examples that populate both input widgets.
    gr.Examples(
        examples=[
            ["How many employees are there?", example_schema],
            ["Find all employees with salary greater than 50000", example_schema],
            ["What is the average salary by department?", example_schema],
        ],
        inputs=[question, schema]
    )

demo.launch()