Spaces:
Sleeping
Sleeping
| import warnings | |
| warnings.filterwarnings("ignore") | |
| import torch | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import gradio as gr | |
| import threading | |
# ASGI application object for the HTTP API defined in this module.
app = FastAPI()

# Limit PyTorch to a single intra-op CPU thread.
torch.set_num_threads(1)

# Hugging Face model id used for both the tokenizer and the weights.
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# The tokenizer and model are loaded once, at import time, and shared by
# every request.
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL).eval()
print("Model ready")
# ─────────────────────────
# SQL FILTER
# ─────────────────────────
# Lower-case keywords; a request containing any of them is treated as
# SQL/database related.
SQL_KEYWORDS = [
    "sql",
    "database",
    "table",
    "select",
    "insert",
    "update",
    "delete",
    "join",
    "group by",
    "postgres",
    "mysql",
    "sqlite",
    "query",
]


def is_sql_related(text):
    """Return True when *text* contains at least one entry of SQL_KEYWORDS.

    The check is a case-insensitive substring search, so a keyword that is
    embedded in a longer word (for example "table" inside "vegetable") also
    counts as a match.
    """
    lowered = text.lower()
    for keyword in SQL_KEYWORDS:
        if keyword in lowered:
            return True
    return False
# Instruction preamble for the model. The leading and trailing newlines are
# part of the value and are kept exactly.
SYSTEM_PROMPT = (
    "\n"
    "You are an expert SQL generator.\n"
    "Only output SQL query.\n"
)
def generate_sql(user_input: str):
    """Generate a one-line SQL statement for a natural-language request.

    Args:
        user_input: The question typed by the user.

    Returns:
        The first line of the model's completion, or a fixed guidance
        message when the input is blank or does not pass the SQL keyword
        filter.
    """
    # Guard clauses: blank (or missing) input and off-topic requests never
    # reach the model.
    if not user_input or not user_input.strip():
        return "Enter SQL question."
    if not is_sql_related(user_input):
        return "Only SQL/database questions allowed."

    prompt = f"{SYSTEM_PROMPT}\nUser: {user_input}\nSQL:"
    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        # Greedy decoding. `temperature` only has an effect when sampling,
        # so it is not passed together with do_sample=False.
        output = model.generate(
            **inputs,
            max_new_tokens=120,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the tokens produced after the prompt. Splitting the full
    # text on "SQL:" would pick up a later "SQL:" marker if the model keeps
    # writing extra "User:/SQL:" turns, returning the wrong statement.
    completion = tokenizer.decode(
        output[0][prompt_length:], skip_special_tokens=True
    )

    # Keep just the first line of the answer.
    return completion.strip().partition("\n")[0]
# ─────────────────────────
# FastAPI Routes
# ─────────────────────────
class Query(BaseModel):
    """Payload model holding the user's question."""

    # Free-form question text; the `generate` handler passes it to
    # generate_sql.
    text: str
# Register the handler on the FastAPI app; without a path-operation
# decorator the function is never exposed over HTTP.
@app.get("/")
def root():
    """Return a small status payload showing that the API is running."""
    return {"status": "API running"}
# Register the handler on the FastAPI app; without a path-operation
# decorator the function is never exposed over HTTP.
# NOTE(review): the "/generate" path mirrors the function name -- confirm it
# matches what existing clients call.
@app.post("/generate")
def generate(query: Query):
    """Run the SQL generator on the posted question.

    Args:
        query: Request body whose ``text`` field holds the question.

    Returns:
        A dict with a single ``result`` key containing the generator output.
    """
    return {"result": generate_sql(query.text)}
# ─────────────────────────
# Gradio UI (for testing)
# ─────────────────────────
def launch_gradio():
    """Build the Gradio test interface around generate_sql and launch it."""
    question_box = gr.Textbox(lines=3, label="SQL Question")
    answer_box = gr.Textbox(lines=6, label="Generated SQL")

    ui = gr.Interface(
        fn=generate_sql,
        inputs=question_box,
        outputs=answer_box,
        title="SQL Generator Test UI",
    )

    # Listen on all network interfaces, port 7861.
    ui.launch(server_name="0.0.0.0", server_port=7861)
# Start the Gradio UI on its own thread at import time.
gradio_thread = threading.Thread(target=launch_gradio)
gradio_thread.start()