"""Minimal FastAPI service that asks a small causal LM to draft a SQL query.

Importing this module loads the tokenizer and model for ``BASE_MODEL`` and
exposes a single endpoint, ``POST /generate-sql``.
"""

import warnings

# Installed before the heavy imports below so the filter is already active
# while those libraries load.
warnings.filterwarnings("ignore")

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM

torch.set_num_threads(1)

app = FastAPI()

BASE_MODEL = "distilgpt2"

print("Loading model...")

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
model.eval()

print("Model ready")


# ─────────────────────────
# Request schema
# ─────────────────────────
class Query(BaseModel):
    """Request body for ``/generate-sql``."""

    # Natural-language request to be turned into a SQL query.
    question: str


# ─────────────────────────
# SQL FILTER
# ─────────────────────────
SQL_KEYWORDS = [
    "sql", "database", "table", "select", "insert",
    "update", "delete", "join", "group by",
    "postgres", "mysql", "sqlite", "query",
]


def is_sql_related(text: str) -> bool:
    """Return True if *text* contains any entry of ``SQL_KEYWORDS``.

    Matching is a case-insensitive substring test, so a keyword embedded in a
    longer word (e.g. "table" inside "tables") also counts.
    """
    lowered = text.lower()
    return any(keyword in lowered for keyword in SQL_KEYWORDS)


# ─────────────────────────
# Endpoint
# ─────────────────────────
@app.post("/generate-sql")
def generate_sql(data: Query):
    """Generate a one-line SQL query for ``data.question``.

    Returns:
        ``{"sql": <first line of the model's completion>}`` on success, or
        ``{"error": <message>}`` when the question is blank or does not
        contain any SQL keyword.
    """
    user_input = data.question

    if not user_input.strip():
        return {"error": "Empty input"}

    if not is_sql_related(user_input):
        return {"error": "Only SQL-related queries allowed"}

    prompt = f"""
You are an expert SQL generator.
Only output SQL query.

User: {user_input}
SQL:
"""

    inputs = tokenizer(prompt, return_tensors="pt")

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=80,
            temperature=0.2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # For a causal LM the returned sequence is the prompt ids followed by the
    # newly generated ids. Decode only the new part: splitting the full text
    # on "SQL:" and taking the last piece would return a later, invented
    # "User:/SQL:" turn whenever the model keeps following the prompt pattern.
    prompt_length = inputs["input_ids"].shape[-1]
    completion = tokenizer.decode(
        output[0][prompt_length:],
        skip_special_tokens=True,
    )

    # Keep only the first line of the completion as the query.
    result = completion.strip().partition("\n")[0]

    return {"sql": result}