"""Gradio demo that turns natural-language database questions into SQL.

The app loads a small causal language model (``distilgpt2``) on CPU, rejects
questions that do not look SQL/database related, and returns the first line
the model generates after an ``SQL:`` cue.
"""

import re
import warnings

# Silence library warnings before the heavy imports below can emit any.
warnings.filterwarnings("ignore")

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Reduce CPU pressure on small shared hosts.
torch.set_num_threads(1)

# Lightweight model so the demo fits a modest CPU/RAM budget.
BASE_MODEL = "distilgpt2"

# Upper bound on tokens generated per answer.
MAX_NEW_TOKENS = 80

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
model.eval()
print("Model ready")

# Longest sequence (prompt + generated tokens) the model has position
# embeddings for; going past it makes generate() fail, so prompts are
# checked against it before generating.
CONTEXT_LIMIT = getattr(model.config, "max_position_embeddings", None) or 1024


# ─────────────────────────
# SQL FILTER
# ─────────────────────────

SQL_KEYWORDS = [
    "sql", "database", "table", "select", "insert", "update", "delete",
    "join", "group by", "postgres", "postgresql", "mysql", "sqlite",
    "query", "queries",
    # Common analytics phrasing, so questions like the UI examples
    # ("Top 5 highest paid employees", "Count orders per customer ...")
    # are accepted.
    "count", "top", "highest", "lowest", "average", "duplicate",
    "row", "column", "record", "order by", "having", "schema", "index",
]

# Whole-word match (allowing simple plural/verb suffixes) so that e.g.
# "table" does not fire on "comfortable", while "tables", "updated" and
# "joining" still match. Multi-word keywords tolerate any run of whitespace.
_SQL_KEYWORD_RE = re.compile(
    r"\b(?:"
    + "|".join(r"\s+".join(map(re.escape, kw.split())) for kw in SQL_KEYWORDS)
    + r")(?:s|es|d|ed|ing)?\b",
    re.IGNORECASE,
)


def is_sql_related(text):
    """Return True if *text* mentions at least one SQL/database keyword.

    Empty or ``None`` text is never considered SQL related.
    """
    return bool(text) and _SQL_KEYWORD_RE.search(text) is not None


# ─────────────────────────
# PROMPT
# ─────────────────────────

SYSTEM_PROMPT = """
You are an expert SQL generator.
Rules:
- Only respond to SQL or database related questions.
- Output ONLY SQL query.
- No explanation.
"""


# ─────────────────────────
# GENERATION
# ─────────────────────────

def generate_sql(user_input):
    """Generate a single-line SQL query for a natural-language question.

    Args:
        user_input: The question typed into the UI (may be empty or None).

    Returns:
        The generated SQL, or a short message explaining why nothing was
        generated: empty input, an off-topic question, a question too long
        for the model's context window, or an empty model response.
    """
    if not user_input or not user_input.strip():
        return "Enter SQL question."

    if not is_sql_related(user_input):
        return "Only SQL/database questions are allowed."

    prompt = f"{SYSTEM_PROMPT}\nUser: {user_input}\nSQL:"
    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[1]

    # Positions beyond the model's context window fail inside generate().
    if prompt_len + MAX_NEW_TOKENS > CONTEXT_LIMIT:
        return "Question is too long. Please shorten it."

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=0.2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. Splitting the full decoded text
    # on its *last* "SQL:" would return the answer to a follow-up turn the
    # model hallucinated, instead of the answer to the user's question.
    completion = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    # Keep just the first line; the model tends to keep rambling afterwards.
    result = completion.strip().split("\n", 1)[0].strip()
    return result or "Could not generate SQL. Try rephrasing the question."


# ─────────────────────────
# UI
# ─────────────────────────

demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        lines=3,
        label="SQL Question",
        placeholder="Find duplicate emails in users table",
    ),
    outputs=gr.Textbox(
        lines=6,
        label="Generated SQL",
    ),
    title="AI SQL Generator (Portfolio Project)",
    description="Only SQL/database queries are supported.",
    examples=[
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees"],
        ["Count orders per customer last month"],
        # Deliberately off-topic: shows the keyword filter rejecting it.
        ["Write a joke about cats"],
    ],
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)