Spaces:
Runtime error
Runtime error
File size: 1,956 Bytes
e62bece 7f3026b c7c0d53 7f424d1 02976e0 e62bece 00c8a57 e95c2d3 bb16527 4bc3e8b e95c2d3 7f3026b e95c2d3 60e496e e95c2d3 02976e0 bb16527 7f3026b bb16527 a2f39c6 bb16527 4bc3e8b bb16527 22df2c5 e95c2d3 bb16527 7f3026b e62bece 7f424d1 22df2c5 e95c2d3 02976e0 bb16527 8b67be0 bb16527 e95c2d3 bb16527 84031c5 e62bece 7f3026b e95c2d3 bb16527 84031c5 1344c31 bb16527 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 | import warnings
# Install an "ignore" filter for every warning category. This runs before the
# imports below, so warnings emitted while those packages are imported are
# suppressed as well.
warnings.filterwarnings("ignore")
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Restrict PyTorch to a single intra-op CPU thread.
# NOTE(review): presumably chosen for a small shared-CPU host; confirm, since
# it also caps generation speed on machines with more cores.
torch.set_num_threads(1)
# ---------------------
# MODEL
# ---------------------
# Checkpoint identifier passed to both from_pretrained() calls below, so the
# model and its tokenizer always come from the same repository.
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
print("Loading model...")
# Load the causal language model on CPU in full float32 precision.
# NOTE(review): device_map / low_cpu_mem_usage normally require the
# `accelerate` package to be installed; confirm it is listed in the
# deployment requirements.
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
device_map="cpu",
torch_dtype=torch.float32,
low_cpu_mem_usage=True,
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Put the module in evaluation mode; this app only runs inference.
model.eval()
print("Model ready")
# ---------------------
# GENERATION
# ---------------------
def generate_sql(question):
    """Generate a SQL query for a natural-language request.

    Args:
        question: The user's request as free text. ``None``, an empty string
            or a whitespace-only string is treated as "no input".

    Returns:
        The text produced by the model after the prompt, stripped of
        surrounding whitespace, or the hint ``"Enter SQL question."`` when no
        usable input was given.
    """
    # Guard first so the model is never invoked for missing or blank input.
    if not question or not question.strip():
        return "Enter SQL question."

    prompt = f"""
You are a SQL expert.
Convert the user request into SQL query only.
User: {question}
SQL:
"""
    inputs = tokenizer(prompt, return_tensors="pt")
    # Number of prompt tokens; generate() returns prompt + continuation for
    # decoder-only models, so this is where the new text starts.
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=120,
            # Greedy decoding. No temperature is passed: it only applies when
            # sampling and otherwise just triggers a configuration warning.
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens instead of searching the full
    # text for "SQL:", which breaks when the completion repeats that marker.
    completion_ids = output[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
# ---------------------
# UI
# ---------------------
# Sample requests offered as clickable examples in the UI.
_example_questions = (
    "Find duplicate emails in users table",
    "Top 5 highest paid employees",
    "Orders per customer last month",
)

# Create the widgets up front, question box first and result box second,
# which is the same order they are created in when built inline.
_question_box = gr.Textbox(lines=3, label="SQL Question")
_sql_box = gr.Textbox(lines=8, label="Generated SQL")

# Single-input, single-output interface around generate_sql.
demo = gr.Interface(
    fn=generate_sql,
    inputs=_question_box,
    outputs=_sql_box,
    title="SQL Generator (Portfolio Demo)",
    description="Fast CPU model for portfolio demo.",
    # Each example row is a one-item list: one value per input component.
    examples=[[text] for text in _example_questions],
)

# Listen on all network interfaces rather than only on localhost.
demo.launch(server_name="0.0.0.0")
|