Spaces:
Sleeping
Sleeping
| import warnings | |
| warnings.filterwarnings("ignore") | |
| import torch | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| # Reduce CPU pressure | |
| torch.set_num_threads(1) | |
| # β Use lightweight model (IMPORTANT) | |
| BASE_MODEL = "distilgpt2" | |
| print("Loading model...") | |
| tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) | |
| model = AutoModelForCausalLM.from_pretrained(BASE_MODEL) | |
| model.eval() | |
| print("Model ready") | |
| # βββββββββββββββββββββββββ | |
| # SQL FILTER | |
| # βββββββββββββββββββββββββ | |
| SQL_KEYWORDS = [ | |
| "sql", "database", "table", "select", "insert", | |
| "update", "delete", "join", "group by", | |
| "postgres", "mysql", "sqlite", "query" | |
| ] | |
| def is_sql_related(text): | |
| text = text.lower() | |
| return any(k in text for k in SQL_KEYWORDS) | |
| # βββββββββββββββββββββββββ | |
| # PROMPT | |
| # βββββββββββββββββββββββββ | |
| SYSTEM_PROMPT = """ | |
| You are an expert SQL generator. | |
| Rules: | |
| - Only respond to SQL or database related questions. | |
| - Output ONLY SQL query. | |
| - No explanation. | |
| """ | |
| # βββββββββββββββββββββββββ | |
| # GENERATION | |
| # βββββββββββββββββββββββββ | |
| def generate_sql(user_input): | |
| if not user_input.strip(): | |
| return "Enter SQL question." | |
| if not is_sql_related(user_input): | |
| return "Only SQL/database questions are allowed." | |
| prompt = f"{SYSTEM_PROMPT}\nUser: {user_input}\nSQL:" | |
| inputs = tokenizer(prompt, return_tensors="pt") | |
| with torch.no_grad(): | |
| output = model.generate( | |
| **inputs, | |
| max_new_tokens=80, | |
| temperature=0.2, | |
| do_sample=True, | |
| pad_token_id=tokenizer.eos_token_id, | |
| ) | |
| text = tokenizer.decode(output[0], skip_special_tokens=True) | |
| result = text.split("SQL:")[-1].strip() | |
| result = result.split("\n")[0] | |
| return result | |
| # βββββββββββββββββββββββββ | |
| # UI | |
| # βββββββββββββββββββββββββ | |
| demo = gr.Interface( | |
| fn=generate_sql, | |
| inputs=gr.Textbox( | |
| lines=3, | |
| label="SQL Question", | |
| placeholder="Find duplicate emails in users table" | |
| ), | |
| outputs=gr.Textbox( | |
| lines=6, | |
| label="Generated SQL" | |
| ), | |
| title="AI SQL Generator (Portfolio Project)", | |
| description="Only SQL/database queries are supported.", | |
| examples=[ | |
| ["Find duplicate emails in users table"], | |
| ["Top 5 highest paid employees"], | |
| ["Count orders per customer last month"], | |
| ["Write a joke about cats"] | |
| ], | |
| ) | |
| demo.launch(server_name="0.0.0.0", server_port=7860) |