| | import gradio as gr |
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| | import torch |
| | import re |
| | import sqlparse |
| |
|
| | |
# Hugging Face Hub id of the fine-tuned text-to-SQL checkpoint; shared by the
# model and tokenizer loads so the two can never drift apart.
MODEL_ID = "onkolahmet/Qwen2-0.5B-Instruct-SQL-generator"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# With device_map="auto" the weights are placed by accelerate, not by this
# script, so take the input device from the loaded model. A hand-picked
# "cuda"/"cpu" string can disagree with that placement (e.g. when an Apple
# "mps" device is selected), and generate() would then fail with a
# tensors-on-different-devices error.
device = model.device
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
def generate_sql(question, context=None, max_new_tokens=128):
    """Ask the model to translate a natural-language question into SQL.

    Args:
        question: The user's question in natural language.
        context: Optional table schema text (e.g. CREATE TABLE statements).
            It is added to the prompt only when it contains non-whitespace
            characters.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The decoded completion (prompt tokens excluded), with surrounding
        whitespace stripped. It may still contain code fences, comments or
        repeated queries; ``clean_sql_output`` post-processes it.
    """
    prompt = "Translate natural language questions to SQL queries.\n\n"
    if context and context.strip():
        prompt += f"Table Context:\n{context}\n\n"
    prompt += f"Q: {question}\nSQL:"

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # If the tokenizer defines no pad token, fall back to EOS so generate()
    # does not have to guess one (and warn about it).
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        inputs.input_ids,
        # Pass the mask explicitly instead of letting generate() infer it.
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_token_id,
    )

    # generate() returns prompt + completion; decode only the new tokens.
    prompt_length = inputs.input_ids.shape[-1]
    sql_query = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return sql_query.strip()
| |
|
# Markdown code fences (```sql, ```SQL, ```) that models often wrap output in.
_SQL_FENCE_RE = re.compile(r"```(?:sql)?", re.IGNORECASE)
# A whole-word, case-insensitive SELECT keyword.
_SELECT_KEYWORD_RE = re.compile(r"\bSELECT\b", re.IGNORECASE)
# Set operators at the end of a prefix: a SELECT right after them continues
# the current statement instead of starting a new one.
_SET_OPERATOR_TAIL_RE = re.compile(
    r"\b(?:UNION|INTERSECT|EXCEPT|MINUS|ALL|DISTINCT)\s*\Z", re.IGNORECASE
)


def _strip_sql_noise(sql_text):
    """Remove SQL line/block comments and Markdown code fences."""
    sql_text = re.sub(r"--.*?$", "", sql_text, flags=re.MULTILINE)
    sql_text = re.sub(r"/\*.*?\*/", "", sql_text, flags=re.DOTALL)
    return _SQL_FENCE_RE.sub("", sql_text)


def _statement_select_offsets(text):
    """Return offsets of SELECT keywords that appear to start a new statement.

    A SELECT is ignored when it is nested inside parentheses (sub-query, CTE
    body, derived table), inside a single-quoted string literal, or directly
    after a set operator such as ``UNION ALL``; in all of those cases it is
    part of the surrounding statement.
    """
    depth = 0
    in_string = False
    top_level = []  # top_level[i]: text[i] is outside parentheses and strings
    for ch in text:
        top_level.append(depth == 0 and not in_string)
        if in_string:
            # A doubled '' escape closes and immediately reopens the literal,
            # so toggling on every quote handles it.
            if ch == "'":
                in_string = False
        elif ch == "'":
            in_string = True
        elif ch == "(":
            depth += 1
        elif ch == ")" and depth > 0:
            depth -= 1
    return [
        match.start()
        for match in _SELECT_KEYWORD_RE.finditer(text)
        if top_level[match.start()]
        and not _SET_OPERATOR_TAIL_RE.search(text[: match.start()])
    ]


def _split_candidate_queries(sql_text):
    """Split cleaned model output into candidate SQL statements."""
    if ";" in sql_text:
        return [q.strip() for q in sql_text.split(";") if q.strip()]

    # No terminators: fall back to statement-level SELECT keywords.
    collapsed = re.sub(r"\s+", " ", sql_text)
    offsets = _statement_select_offsets(collapsed)
    if len(offsets) < 2:
        return [sql_text]
    # Any text before the first statement-level SELECT is dropped.
    bounds = offsets + [len(collapsed)]
    return [collapsed[start:end].strip() for start, end in zip(bounds, bounds[1:])]


def _dedupe_key(query):
    """Normalise a query so copies differing only in case/layout compare equal."""
    try:
        return sqlparse.format(
            query if query.strip().endswith(";") else query + ";",
            keyword_case="lower",
            identifier_case="lower",
            strip_comments=True,
            reindent=True,
        )
    except Exception:
        # If sqlparse fails on the text, compare case- and whitespace-insensitively.
        return re.sub(r"\s+", " ", query.lower().strip())


def _pick_best_query(queries):
    """Drop duplicate candidates and return the longest one, preferring SELECTs."""
    if len(queries) == 1:
        return queries[0]
    seen = set()
    unique_queries = []
    for query in queries:
        key = _dedupe_key(query)
        if key not in seen:
            seen.add(key)
            unique_queries.append(query)
    candidates = [q for q in unique_queries if _SELECT_KEYWORD_RE.search(q)] or unique_queries
    # max() keeps the first of equally long candidates.
    return max(candidates, key=len)


def clean_sql_output(sql_text):
    """
    Clean and deduplicate SQL queries produced by the model:
    1. Remove SQL comments and Markdown code fences
    2. Split into candidate statements (on ';', or on statement-level SELECTs)
    3. Remove duplicate queries and keep the most relevant (longest SELECT)
    4. Terminate with ';' and pretty-print with sqlparse when possible

    Returns an empty string when ``sql_text`` is empty/None or contains no
    usable statement.
    """
    if not sql_text:
        return ""

    queries = [
        q for q in _split_candidate_queries(_strip_sql_noise(sql_text)) if q.strip()
    ]
    if not queries:
        return ""

    best_query = _pick_best_query(queries).strip()
    if not best_query.endswith(";"):
        best_query += ";"
    best_query = re.sub(r"\s+", " ", best_query)

    try:
        return sqlparse.format(
            best_query,
            keyword_case="upper",
            identifier_case="lower",
            reindent=True,
            indent_width=2,
        )
    except Exception:
        # Formatting is cosmetic; fall back to the single-line query.
        return best_query
| |
|
def process_input(question, table_context):
    """Turn a UI question (plus optional schema) into a cleaned SQL query.

    Args:
        question: Text from the question box. ``None`` or blank input is
            answered with a prompt message without calling the model.
        table_context: Optional schema text forwarded to ``generate_sql``.

    Returns:
        The formatted SQL string, or a human-readable message when the
        question is empty or no usable SQL could be extracted.
    """
    # Gradio normally sends "" for an empty textbox; also accept None so
    # programmatic callers get the same message instead of an AttributeError.
    if question is None or not question.strip():
        return "Please enter a question."

    # Strip so stray newlines from the multi-line textbox don't leak into the prompt.
    raw_sql = generate_sql(question.strip(), table_context)
    cleaned_sql = clean_sql_output(raw_sql)

    if not cleaned_sql:
        return "Sorry, I couldn't generate a valid SQL query. Please try rephrasing your question."

    return cleaned_sql
| |
|
| | |
# Sample schemas used as "Table Context" by the UI examples, which index this
# list: 0 -> customers, 1 -> products, 2 -> employees + departments.
_CUSTOMERS_SCHEMA = """
CREATE TABLE customers (
id INT PRIMARY KEY,
name VARCHAR(100),
email VARCHAR(100),
order_date DATE
);
"""

_PRODUCTS_SCHEMA = """
CREATE TABLE products (
id INT PRIMARY KEY,
name VARCHAR(100),
category VARCHAR(50),
price DECIMAL(10,2),
stock_quantity INT
);
"""

_HR_SCHEMA = """
CREATE TABLE employees (
id INT PRIMARY KEY,
name VARCHAR(100),
department VARCHAR(50),
salary DECIMAL(10,2),
hire_date DATE
);
CREATE TABLE departments (
id INT PRIMARY KEY,
name VARCHAR(50),
manager_id INT,
budget DECIMAL(15,2)
);
"""

example_contexts = [_CUSTOMERS_SCHEMA, _PRODUCTS_SCHEMA, _HR_SCHEMA]
| |
|
| | |
# Sample natural-language questions that fit the example schemas.
# NOTE(review): nothing in this file's visible code reads this list; the
# gr.Examples rows in the UI spell out their own questions. Confirm it is
# still needed.
example_questions = [
    'Get the names and emails of customers who placed an order in the last 30 days.',
    'Find all products with less than 10 items in stock.',
    'List all employees in the Sales department with a salary greater than 50000.',
    'What is the total budget for departments with more than 5 employees?',
    'Count how many products are in each category where the price is greater than 100.',
]
| |
|
| | |
with gr.Blocks(title="Text to SQL Converter") as demo:
    # ---- Header -------------------------------------------------------------
    gr.Markdown("# Text to SQL Query Converter")
    gr.Markdown("Enter your question and optional table context to generate an SQL query.")

    # ---- Inputs on the left, generated SQL on the right ---------------------
    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(
                label="Your Question",
                placeholder="e.g., Find all products with price less than $50",
                lines=2,
            )
            table_context = gr.Textbox(
                label="Table Context (Optional)",
                placeholder="Enter your database schema or table definitions here...",
                lines=10,
            )
            submit_btn = gr.Button("Generate SQL Query")

        with gr.Column():
            sql_output = gr.Code(label="Generated SQL Query", language="sql", lines=12)

    # ---- Clickable examples: each row fills [question, table context] -------
    gr.Markdown("### Try some examples")
    example_rows = [
        ["List all products in the 'Electronics' category with price less than $500", example_contexts[1]],
        ["Find the total number of employees in each department", example_contexts[2]],
        ["Get customers who placed orders in the last 7 days", example_contexts[0]],
        ["Count the number of products in each category", example_contexts[1]],
        ["Find the average salary by department", example_contexts[2]],
    ]
    example_selector = gr.Examples(
        examples=example_rows,
        inputs=[question_input, table_context],
    )

    # ---- Events: the button and Enter in the question box do the same thing -
    for register_event in (submit_btn.click, question_input.submit):
        register_event(
            fn=process_input,
            inputs=[question_input, table_context],
            outputs=sql_output,
        )

    # ---- Footer -------------------------------------------------------------
    gr.Markdown("""
### About
This app uses a fine-tuned language model to convert natural language questions into SQL queries.

- **Model**: [onkolahmet/Qwen2-0.5B-Instruct-SQL-generator](https://huggingface.co/onkolahmet/Qwen2-0.5B-Instruct-SQL-generator)
- **How to use**:
1. Enter your question in natural language
2. If you have specific table schemas, add them in the Table Context field
3. Click "Generate SQL Query" or press Enter

Note: The model works best when table context is provided, but can generate generic SQL queries without it.
""")
| |
|
| | |
# Start the Gradio server only when this file is executed as a script
# (e.g. `python app.py`), so importing the module from elsewhere (tests,
# another script) does not block on a running server.
if __name__ == "__main__":
    demo.launch()