SQL_chatbot_API / app.py
saadkhi's picture
Update app.py
1344c31 verified
raw
history blame
3.75 kB
# app.py
import inspect

import torch
import gradio as gr
from unsloth import FastLanguageModel
# ────────────────────────────────────────────────────────────────
# Tunables: edit these to switch checkpoints or decoding behaviour
# ────────────────────────────────────────────────────────────────
MAX_NEW_TOKENS: int = 96  # hard cap on tokens generated per reply
TEMPERATURE: float = 0.0  # values <= 0.01 select greedy (non-sampled) decoding
BASE_MODEL: str = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"  # base checkpoint (name indicates bnb 4-bit)
LORA_PATH: str = "saadkhi/SQL_Chat_finetuned_model"  # Hub repo holding the SQL fine-tune
# ────────────────────────────────────────────────────────────────
print("Loading fine-tuned model with Unsloth...")
# Load the fine-tuned checkpoint itself. Unsloth's from_pretrained accepts a
# repo of saved LoRA adapters: it reads the adapter config, loads the base
# model named there and attaches the *trained* adapter weights.
# The previous code called FastLanguageModel.get_peft_model(), which attaches
# brand-new, untrained adapters (the training-time API); LORA_PATH was never
# read, so the app was effectively serving the untuned base model.
# NOTE(review): the base model now comes from the adapter's own config —
# confirm it matches BASE_MODEL.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=LORA_PATH,
    max_seq_length=2048,
    dtype=None,  # auto-detect (bf16 where supported, otherwise fp16)
    load_in_4bit=True,
)

print("Preparing model for inference...")
# Enables Unsloth's inference mode. Called for its in-place effect, as in
# Unsloth's own examples, rather than relying on its return value.
FastLanguageModel.for_inference(model)

# Optional torch.compile (available from PyTorch 2.0). The previous check,
# `torch.__version__ >= "2.0"`, compared strings lexicographically.
# NOTE(review): a compiled wrapper forwards attribute lookups such as
# `.generate` to the original module, so decoding may not actually run
# through the compiled graph — benchmark before relying on a speedup.
if torch.cuda.is_available() and hasattr(torch, "compile"):
    print("Compiling model...")
    model = torch.compile(model, mode="reduce-overhead")
print("Model ready!")
# ────────────────────────────────────────────────────────────────
def generate_sql(prompt: str):
# Very clean chat template usage
messages = [{"role": "user", "content": prompt}]
inputs = tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt"
).to("cuda" if torch.cuda.is_available() else "cpu")
outputs = model.generate(
input_ids=inputs,
max_new_tokens=MAX_NEW_TOKENS,
temperature=TEMPERATURE,
do_sample=(TEMPERATURE > 0.01),
use_cache=True,
pad_token_id=tokenizer.eos_token_id,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Try to cut after assistant's answer
if "<|assistant|>" in response:
response = response.split("<|assistant|>", 1)[-1].strip()
if "<|end|>" in response:
response = response.split("<|end|>")[0].strip()
return response
# ────────────────────────────────────────────────────────────────
# Gradio 5 renamed Interface's ``allow_flagging`` to ``flagging_mode``. Pass
# the flag under whichever name the installed Interface.__init__ declares, so
# the app starts on both older and newer Gradio releases. (Inspecting the
# signature is used instead of try/except because some versions accept
# **kwargs and would silently ignore an unknown flag name.)
_FLAGGING_KWARG = (
    "flagging_mode"
    if "flagging_mode" in inspect.signature(gr.Interface.__init__).parameters
    else "allow_flagging"
)

demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        label="Ask SQL related question",
        placeholder="Show me all employees with salary > 50000...",
        lines=3,
    ),
    outputs=gr.Textbox(label="Generated SQL / Answer"),
    title="SQL Chat Assistant (Phi-3-mini fine-tuned)",
    description="Fast version using Unsloth",
    examples=[
        ["Find all duplicate emails in users table"],
        ["Get top 5 highest paid employees"],
        ["How many orders per customer last month?"],
    ],
    **{_FLAGGING_KWARG: "never"},  # disable the flagging button
)

if __name__ == "__main__":
    demo.launch()