Spaces:
Running
Running
| import warnings | |
| warnings.filterwarnings("ignore") | |
| import torch | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Restrict PyTorch to one intra-op CPU thread.
torch.set_num_threads(1)

app = FastAPI()

# Checkpoint name used for both the tokenizer and the language model.
BASE_MODEL = "distilgpt2"

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# eval() returns the module itself, so loading and switching to inference
# mode are chained into one expression.
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL).eval()
print("Model ready")
# ─────────────────────────
# Request schema
# ─────────────────────────
class Query(BaseModel):
    # Payload model with a single required string field.
    # `question` is the user's free-text request that generate_sql() reads.
    question: str
# ─────────────────────────
# SQL filter
# ─────────────────────────
# Lower-case substrings whose presence marks a question as SQL-related.
SQL_KEYWORDS = [
    "sql",
    "database",
    "table",
    "select",
    "insert",
    "update",
    "delete",
    "join",
    "group by",
    "postgres",
    "mysql",
    "sqlite",
    "query",
]


def is_sql_related(text):
    """Return True when *text* contains any entry of SQL_KEYWORDS.

    The comparison is case-insensitive and is a plain substring test, so a
    keyword embedded in a longer word (e.g. "table" inside "comfortable")
    also counts as a match.
    """
    lowered = text.lower()
    for keyword in SQL_KEYWORDS:
        if keyword in lowered:
            return True
    return False
# ─────────────────────────
# Endpoint
# ─────────────────────────
# NOTE(review): no route decorator (e.g. @app.post(...)) is visible on this
# function in the reviewed source, so nothing shown here registers it with
# `app` -- confirm the route is attached.
def generate_sql(data: Query):
    """Generate a one-line SQL statement for a natural-language question.

    Args:
        data: Request payload; only ``data.question`` is read.

    Returns:
        ``{"sql": <text>}`` on success, or ``{"error": <message>}`` when the
        question is blank, is not SQL-related, or is too long for the model.
    """
    user_input = data.question

    if not user_input.strip():
        return {"error": "Empty input"}

    if not is_sql_related(user_input):
        return {"error": "Only SQL-related queries allowed"}

    # Built from explicit lines so the exact prompt text does not depend on
    # how this function happens to be indented in the source file.
    prompt = (
        "\n"
        "You are an expert SQL generator.\n"
        "Only output SQL query.\n"
        f"User: {user_input}\n"
        "SQL:\n"
    )

    max_new_tokens = 80
    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[1]

    # A prompt that leaves no room for the continuation would overrun the
    # model's context window and fail inside generate(); reject it up front.
    if prompt_length + max_new_tokens > tokenizer.model_max_length:
        return {"error": "Input too long"}

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence starts with the prompt tokens. Decode only the
    # continuation: splitting the full text on "SQL:" would pick up a later,
    # model-invented "SQL:" turn (or one typed by the user) instead of the
    # first answer.
    completion = tokenizer.decode(
        output[0][prompt_length:],
        skip_special_tokens=True,
    )

    # Keep just the first line of the answer.
    result = completion.strip().split("\n")[0]
    return {"sql": result}