# SQL_chatbot_API / app.py — Hugging Face Space by saadkhi
# (revision 84031c5, verified; ~3.29 kB)
# app.py - Optimized for Hugging Face Spaces (Unsloth = 2-4x faster)
import torch
import gradio as gr
from unsloth import FastLanguageModel
# ────────────────────────────────────────────────────────────────
# Configuration constants for model loading and generation.
BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"  # 4-bit pre-quantized base model on the Hub
LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"  # fine-tuned LoRA adapter repo on the Hub
MAX_NEW_TOKENS = 180  # hard cap on generated tokens per reply (keeps latency bounded)
TEMPERATURE = 0.0 # Greedy = fastest & deterministic
# ────────────────────────────────────────────────────────────────
print("Loading fine-tuned model with Unsloth (4-bit)...")
# BUG FIX: the original code loaded BASE_MODEL and then called
# FastLanguageModel.get_peft_model(...), which attaches a *freshly
# initialized* (random-weight) LoRA adapter — the trained adapter at
# LORA_PATH was never loaded, so the app was serving the raw base model.
# Passing the adapter repo to from_pretrained makes Unsloth resolve the
# base model from the adapter's config AND load the trained LoRA weights
# in one step. (get_peft_model / gradient checkpointing are training-time
# APIs and have no place in an inference-only Space.)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = LORA_PATH,   # adapter repo; base model taken from its adapter_config
    max_seq_length = 2048,
    dtype = None,             # Auto: bfloat16 on supported GPUs
    load_in_4bit = True,      # keep the quantized base — low memory on Spaces
)

# Enable 2x faster inference kernels (also puts the model in eval mode)
FastLanguageModel.for_inference(model)
print("Model ready! (very fast now)")
# ────────────────────────────────────────────────────────────────
def generate_sql(prompt: str) -> str:
    """Generate a SQL answer for a natural-language question.

    Args:
        prompt: The user's question (e.g. "Find duplicate emails in users").

    Returns:
        The model's reply as plain text, with the prompt and special
        tokens stripped.
    """
    messages = [{"role": "user", "content": prompt}]
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)  # follow the model's device instead of re-probing CUDA

    do_sample = TEMPERATURE > 0.01
    outputs = model.generate(
        input_ids = input_ids,
        max_new_tokens = MAX_NEW_TOKENS,
        # Only pass temperature when actually sampling: transformers warns
        # about temperature=0.0 combined with greedy decoding.
        temperature = TEMPERATURE if do_sample else None,
        do_sample = do_sample,
        use_cache = True,
        pad_token_id = tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens. Slicing off the prompt is
    # template-agnostic and more robust than string-splitting on markers
    # like "<|assistant|>" / "<|end|>" (which skip_special_tokens may or
    # may not have removed already).
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# ────────────────────────────────────────────────────────────────
# ────────────────────────────────────────────────────────────────
# Gradio UI: one text box in, one text box out, wired to generate_sql.
example_questions = [
    ["Find duplicate emails in users table"],
    ["Top 5 highest paid employees"],
    ["Count orders per customer last month"],
]

question_input = gr.Textbox(
    label="Ask SQL question",
    placeholder="Delete duplicate rows from users table based on email",
    lines=3,
)

demo = gr.Interface(
    fn=generate_sql,
    inputs=question_input,
    outputs=gr.Textbox(label="Generated SQL"),
    title="SQL Chatbot - Ultra Fast (Unsloth)",
    description="Phi-3-mini 4-bit + your LoRA",
    examples=example_questions,
)

if __name__ == "__main__":
    demo.launch()