"""Gradio demo: natural-language question + database schema -> SQL query.

Loads the causal language model named by ``MODEL`` once at import time,
wraps it in a small text-to-SQL pipeline (prompt building, greedy decoding,
output clean-up) and serves it through a Gradio Blocks UI.
"""
import re

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL = "jinesh90/qwen2.5-coder-sql-generator"

# Hard limits for one request, measured in tokens.
MAX_INPUT_TOKENS = 1024
MAX_NEW_TOKENS = 200

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True,
)
model.eval()
print("Ready!")


def _build_stop_token_ids():
    """Return the distinct, valid token ids that should end generation."""
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    # A token missing from the vocabulary maps to the unknown-token id
    # (possibly None); that must not be used as a stop token.
    if im_end_id == tokenizer.unk_token_id:
        im_end_id = None

    stop_ids = []
    for token_id in (tokenizer.eos_token_id, im_end_id):
        if token_id is not None and token_id not in stop_ids:
            stop_ids.append(token_id)
    return stop_ids


# Computed once: the tokenizer does not change between requests.
STOP_TOKEN_IDS = _build_stop_token_ids()

# Removes everything from the first non-ASCII character to the end of that
# line ('.' does not cross newlines).
_NON_ASCII_TAIL_RE = re.compile(r"[^\x00-\x7F].*")

# Captures the contents of a leading Markdown code fence (optionally tagged
# "sql"), up to the closing fence or the end of the text.
_CODE_FENCE_RE = re.compile(
    r"```(?:sql\b)?\s*(.*?)\s*(?:```|\Z)",
    re.IGNORECASE | re.DOTALL,
)

# The output is cut at the first occurrence of each marker, in this order.
# NOTE: "assistant" is matched as a plain substring, so a query that
# legitimately contains that word (e.g. a column called assistant_id) is
# truncated there as well.
_STOP_MARKERS = ("###", "assistant", "\n\n")


def clean_sql(text):
    """Reduce raw model output to the SQL statement it starts with.

    Steps, in order: strip whitespace, unwrap a leading Markdown code fence,
    drop anything from the first non-ASCII character to the end of its line,
    then cut at the first occurrence of each marker in ``_STOP_MARKERS``.

    Args:
        text: Decoded model output.

    Returns:
        The cleaned SQL string (possibly empty).
    """
    cleaned = text.strip()
    if cleaned.startswith("```"):
        # The pattern always matches a string that starts with a fence.
        cleaned = _CODE_FENCE_RE.match(cleaned).group(1)
    cleaned = _NON_ASCII_TAIL_RE.sub("", cleaned).strip()
    for marker in _STOP_MARKERS:
        cleaned = cleaned.split(marker)[0].strip()
    return cleaned


def build_prompt(question, schema):
    """Return the instruction prompt for one question/schema pair.

    Args:
        question: Natural-language question to answer with SQL.
        schema: Database schema text (CREATE TABLE statements).
    """
    # NOTE(review): the line breaks in this template should match the prompt
    # format the model was fine-tuned on -- confirm against the training code.
    return f"""You are a SQL expert. Generate the simplest and most direct SQL query.
Use JOINs only when multiple tables are needed.

### Schema:
{schema}

### Question:
{question}

### SQL:"""


def generate(question, schema):
    """Generate a SQL query for ``question`` against ``schema``.

    Args:
        question: Natural-language question from the UI.
        schema: Schema text from the UI.

    Returns:
        The cleaned SQL produced by the model, or a human-readable message
        when an input is missing/blank or the prompt exceeds
        ``MAX_INPUT_TOKENS``.
    """
    question = (question or "").strip()
    schema = (schema or "").strip()
    if not question or not schema:
        return "Please provide both a question and schema!"

    messages = [{"role": "user", "content": build_prompt(question, schema)}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(text, return_tensors="pt")
    input_len = inputs["input_ids"].shape[1]
    # Truncating an over-long prompt would cut off its end -- the question,
    # the "### SQL:" cue and the generation marker -- so refuse instead of
    # silently generating from a broken prompt.
    if input_len > MAX_INPUT_TOKENS:
        return (
            f"Input is too long ({input_len} tokens, limit {MAX_INPUT_TOKENS}). "
            "Please shorten the schema or the question."
        )
    inputs = inputs.to(model.device)

    with torch.no_grad():
        # Greedy decoding: sampling parameters such as temperature do not
        # apply when do_sample is False, so none are passed.
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            repetition_penalty=1.3,
            eos_token_id=STOP_TOKEN_IDS or None,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    raw = tokenizer.decode(outputs[0, input_len:], skip_special_tokens=True)
    return clean_sql(raw)


# Example schemas for demo
example_schema = """CREATE TABLE employees (
    id INTEGER,
    name VARCHAR,
    salary REAL,
    department VARCHAR,
    age INTEGER
);"""

# Kept at module level (unindented) so the Markdown renders as a list.
MODEL_STATS_MD = """
### 📊 Model Stats
- **Base model**: Qwen2.5-Coder-7B
- **Training data**: Spider dataset (7.9k samples)
- **Simple queries**: 64.2% accuracy
- **Complex queries**: 17.0% accuracy
- **Overall**: 42% execution accuracy
"""

with gr.Blocks(title="SQL Query Generator") as demo:
    gr.Markdown("# 🗄️ SQL Query Generator")
    gr.Markdown("Fine-tuned Qwen2.5-Coder 7B on Spider dataset | 42% execution accuracy")

    with gr.Row():
        with gr.Column():
            schema = gr.Textbox(
                label="Database Schema (CREATE TABLE statements)",
                value=example_schema,
                lines=10,
            )
            question = gr.Textbox(
                label="Question",
                placeholder="How many employees have salary > 50000?",
                lines=2,
            )
            btn = gr.Button("🚀 Generate SQL", variant="primary")

        with gr.Column():
            output = gr.Code(
                label="Generated SQL",
                language="sql",
            )
            gr.Markdown(MODEL_STATS_MD)

    btn.click(fn=generate, inputs=[question, schema], outputs=output)

    gr.Examples(
        examples=[
            ["How many employees are there?", example_schema],
            ["Find all employees with salary greater than 50000", example_schema],
            ["What is the average salary by department?", example_schema],
        ],
        inputs=[question, schema],
    )

demo.launch()