# SQL_chatbot_API / app.py
# Author: saadkhi
# Commit 979ad48: "optimized sol, review needed"
# File size: 1.13 kB
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from transformers import BitsAndBytesConfig
# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
# NOTE(review): bitsandbytes 4-bit loading generally needs a CUDA GPU, so the
# CPU branch is unlikely to load this checkpoint successfully -- confirm.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Base checkpoint (published already quantized to bnb 4-bit, per its name) and
# the PEFT adapter that is loaded on top of it below.
base_model = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
finetuned_model = "saadkhi/SQL_Chat_finetuned_model"

tokenizer = AutoTokenizer.from_pretrained(base_model)

# NOTE(review): the base checkpoint is pre-quantized; transformers may prefer
# the quantization config stored in the checkpoint over this one -- verify.
bnb = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb,
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    # Let accelerate place (and if needed, split/offload) the weights.
    device_map="auto",
)
# Attach the fine-tuned adapter. No `.to(device)` here: device_map="auto" has
# already placed every weight, and moving a bitsandbytes-quantized /
# accelerate-dispatched model is unsupported and can break offloaded layouts.
model = PeftModel.from_pretrained(model, finetuned_model)
# Inference only: disable dropout and other training-time behavior.
model.eval()
def chat(prompt: str, max_new_tokens: int = 60) -> str:
    """Generate the fine-tuned model's reply to ``prompt``.

    Args:
        prompt: Raw user text (e.g. from the Gradio textbox). It is tokenized
            as-is; no chat template is applied.
        max_new_tokens: Maximum number of tokens to generate after the prompt.
            Defaults to 60, the previously hard-coded limit.

    Returns:
        Only the newly generated text, decoded with special tokens removed.
        Empty, ``None`` or whitespace-only prompts return ``""`` without
        invoking the model.
    """
    # Nothing to complete: avoid handing generate() an empty/blank input.
    if not prompt or not prompt.strip():
        return ""

    # NOTE(review): Phi-3 instruct models are usually prompted through
    # tokenizer.apply_chat_template(...); confirm which prompt format the
    # fine-tuned adapter was trained on before relying on raw text here.

    # Send inputs to wherever the model's weights actually live. With
    # device_map="auto", that is model.device, not necessarily the `device` string.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding. `temperature` is intentionally not passed: it is
            # ignored (and triggers a warning) when do_sample=False.
            do_sample=False,
        )

    # For decoder-only models, generate() returns prompt + continuation.
    # Drop the prompt tokens so the reply does not echo the user's input.
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
# Expose `chat` as a one-textbox-in, one-textbox-out Gradio web UI and start
# serving it immediately (launch() runs at module level, as a script would).
iface = gr.Interface(
    title="SQL Chatbot",
    fn=chat,
    inputs="text",
    outputs="text",
)
iface.launch()