import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import torch
import os
# Create offload folder (very important!)
OFFLOAD_DIR = "offload"
os.makedirs(OFFLOAD_DIR, exist_ok=True)
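# With device_map="auto", accelerate fills the GPU first, then CPU RAM, and
# memory-maps any remaining weights into this folder on disk.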
# Optimal 4-bit quantization config
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
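# NF4 plus double quantization stores each weight in ~4 bits, so the ~3.8B
# parameters of Phi-3 Mini need roughly 2 GB; matmuls run in bfloat16.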
# Model paths
base_model_name = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
lora_model_name = "saadkhi/SQL_Chat_finetuned_model"
print("Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=quant_config,
    device_map="auto",
    trust_remote_code=True,
    offload_folder=OFFLOAD_DIR,  # required so accelerate can spill layers to disk
)
print("Loading LoRA adapter...")
model = PeftModel.from_pretrained(
    base_model,
    lora_model_name,
    offload_folder=OFFLOAD_DIR,  # needed here as well for the offloaded layers
)
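# The adapter's low-rank matrices stay separate from the frozen 4-bit base
# weights; PEFT adds their update to each wrapped layer's output at inference.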
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
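# Optional sanity check: see which layers landed on GPU, CPU, or disk.
# print(base_model.hf_device_map)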
model.eval()
print("Model loaded successfully!")
def chat(message, history):
    # Rebuild the conversation in the chat format Phi-3's template expects
    messages = []
    for user, assistant in history:
        messages.append({"role": "user", "content": user})
        if assistant:
            messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # Fast generation settings
    outputs = model.generate(
        inputs,
        max_new_tokens=256,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        repetition_penalty=1.1,
        use_cache=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,  # avoids the missing-pad-token warning
    )

    # Decode only the newly generated tokens, not the prompt
    response = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
    history.append((message, response))
    return history, ""
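# Quick smoke test outside the UI (hypothetical question; uncomment to run):
# updated, _ = chat("Count orders per customer in the orders table", [])
# print(updated[-1][1])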
# Gradio UI
with gr.Blocks(title="SQL Chatbot", theme=gr.themes.Soft()) as demo:
gr.Markdown("# SQL Chat Assistant")
gr.Markdown("Fine-tuned Phi-3 Mini (4-bit) for SQL queries. Responses ~3β10s on GPU.")
chatbot = gr.Chatbot(height=500)
msg = gr.Textbox(
label="Your Question",
placeholder="e.g., delete duplicate rows from users table based on email",
lines=2
)
clear = gr.Button("Clear")
msg.submit(chat, [msg, chatbot], [chatbot, msg])
clear.click(lambda: ([], ""), None, chatbot)
demo.queue(max_size=30)
demo.launch() |