SQL_chatbot_API / app.py
saadkhi's picture
Update app.py
52ae0ac verified
raw
history blame
2.65 kB
import gradio as gr
import torch
from unsloth import FastLanguageModel
# ── Global model (loaded once at startup) ───────────────────────────────
print("Loading model...")
model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/Phi-3-mini-4k-instruct-bnb-4bit",  # very fast pre-quantized base
    max_seq_length=2048,
    dtype=None,  # auto (bf16/float16)
    load_in_4bit=True,
)
# Attach the fine-tuned LoRA adapter.
# BUG FIX: transformers' load_adapter() mutates the model IN PLACE and
# returns None, so it must not be used as an expression — the original code
# passed its None result into for_inference(), leaving `model` unusable.
model.load_adapter("saadkhi/SQL_Chat_finetuned_model")
model = FastLanguageModel.for_inference(model)  # switch to fast inference mode
print("Model loaded successfully!")
# ── Chat function ───────────────────────────────────────────────────────
def generate_response(message, history):
    """Generate one assistant reply for the Gradio ChatInterface.

    Args:
        message: The latest user message (str).
        history: Prior turns as a list of (user_msg, assistant_msg) string
            pairs (Gradio tuple-style history).

    Returns:
        The model's reply as a plain string, with the prompt removed.
    """
    # Rebuild the full multi-turn conversation as role/content messages.
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    # Use the proper chat template (very important for Phi-3);
    # add_generation_prompt appends the assistant header so the model
    # continues speaking as the assistant.
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda" if torch.cuda.is_available() else "cpu")
    # Greedy decoding = fastest & deterministic. `temperature` is ignored
    # when do_sample=False (and triggers a transformers warning), so it is
    # omitted here.
    outputs = model.generate(
        input_ids=inputs,
        max_new_tokens=180,  # increased but still reasonable
        do_sample=False,
        use_cache=True,
    )
    # BUG FIX: decoding with skip_special_tokens=True strips "<|assistant|>"
    # from the text, so the old split("<|assistant|>") never matched and the
    # entire prompt leaked into the reply. Slice off the prompt tokens and
    # decode only what the model newly generated.
    new_tokens = outputs[0][inputs.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# ── Gradio UI ───────────────────────────────────────────────────────────
# ── Gradio UI ───────────────────────────────────────────────────────────
# Chat UI wired to the module-level model; example prompts are listed but
# not pre-computed (cache_examples=False avoids running the model at build).
demo = gr.ChatInterface(
    fn=generate_response,
    title="SQL Chat Assistant (Fast Version)",
    # BUG FIX: the bullet was mojibake ("β€’" — UTF-8 "•" mis-decoded as
    # Latin-1); restore the intended character in the user-facing text.
    description="Ask SQL related questions • Powered by Phi-3-mini + your fine-tune",
    examples=[
        "Write a query to find duplicate emails in users table",
        "How to delete rows with NULL values in column price?",
        "Select top 10 most expensive products",
    ],
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()