"""Small text-to-SQL service: a FastAPI endpoint plus a Gradio test UI.

Importing this module loads the TinyLlama chat model and starts the Gradio
UI in a background thread; the FastAPI ``app`` object is meant to be served
by an ASGI server.
"""

import threading
import warnings

import gradio as gr
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

warnings.filterwarnings("ignore")

torch.set_num_threads(1)

app = FastAPI()

BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

print("Loading model...")

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
model.eval()

print("Model ready")

# ─────────────────────────
# SQL FILTER
# ─────────────────────────
SQL_KEYWORDS = [
    "sql", "database", "table", "select", "insert",
    "update", "delete", "join", "group by",
    "postgres", "mysql", "sqlite", "query",
]


def is_sql_related(text: str) -> bool:
    """Return True if *text* contains any SQL keyword (case-insensitive).

    This is a plain substring check, so a keyword embedded in a longer
    word (e.g. "table" inside "vegetable") also counts as a match.
    """
    lowered = text.lower()  # lower-case once instead of once per keyword
    return any(keyword in lowered for keyword in SQL_KEYWORDS)


SYSTEM_PROMPT = """
You are an expert SQL generator.
Only output SQL query.
"""


def generate_sql(user_input: str) -> str:
    """Generate a single-line SQL statement for a natural-language request.

    Args:
        user_input: The user's question or instruction.

    Returns:
        The first line of the model's completion, or a fixed guidance
        message when the input is blank or does not mention any of
        ``SQL_KEYWORDS``.
    """
    if not user_input.strip():
        return "Enter SQL question."

    if not is_sql_related(user_input):
        return "Only SQL/database questions allowed."

    prompt = f"{SYSTEM_PROMPT}\nUser: {user_input}\nSQL:"

    inputs = tokenizer(prompt, return_tensors="pt")

    with torch.no_grad():
        # Greedy decoding: `temperature` only applies when sampling, so it
        # is not passed here (passing it with do_sample=False just triggers
        # a transformers warning about an unused generation flag).
        output = model.generate(
            **inputs,
            max_new_tokens=120,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # For a causal LM, generate() returns the prompt tokens followed by the
    # new tokens. Decode only the completion rather than searching the full
    # text for the "SQL:" marker, which picks the wrong segment whenever the
    # model emits another "SQL:" itself.
    prompt_length = inputs["input_ids"].shape[1]
    completion = tokenizer.decode(
        output[0][prompt_length:], skip_special_tokens=True
    )

    # Keep only the first non-blank line of the completion.
    first_line, _, _ = completion.strip().partition("\n")
    return first_line.rstrip()


# ─────────────────────────
# FastAPI Routes
# ─────────────────────────
class Query(BaseModel):
    """Request body for ``POST /generate``."""

    # Natural-language question to translate into SQL.
    text: str


@app.get("/")
def root():
    """Health check: report that the API is up."""
    return {"status": "API running"}


@app.post("/generate")
def generate(query: Query):
    """Return the generated SQL (or a guidance message) for ``query.text``."""
    return {"result": generate_sql(query.text)}


# ─────────────────────────
# Gradio UI (for testing)
# ─────────────────────────
def launch_gradio():
    """Build the Gradio test interface and serve it on 0.0.0.0:7861."""
    demo = gr.Interface(
        fn=generate_sql,
        inputs=gr.Textbox(lines=3, label="SQL Question"),
        outputs=gr.Textbox(lines=6, label="Generated SQL"),
        title="SQL Generator Test UI",
    )

    demo.launch(server_name="0.0.0.0", server_port=7861)


# Run Gradio in parallel thread
threading.Thread(target=launch_gradio).start()