# NOTE: Hugging Face Spaces page-header residue ("Spaces / Sleeping / File size")
# removed — it was a scrape artifact, not part of the source file.
# import torch
# import gradio as gr
# from transformers import AutoTokenizer, AutoModelForCausalLM
# from peft import PeftModel
# from transformers import BitsAndBytesConfig
# device = "cuda" if torch.cuda.is_available() else "cpu"
# base_model = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
# finetuned_model = "saadkhi/SQL_Chat_finetuned_model"
# tokenizer = AutoTokenizer.from_pretrained(base_model)
# bnb = BitsAndBytesConfig(load_in_4bit=True)
# model = AutoModelForCausalLM.from_pretrained(
# base_model,
# quantization_config=bnb,
# torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
# device_map="auto"
# )
# model = PeftModel.from_pretrained(model, finetuned_model).to(device)
# model.eval()
# def chat(prompt):
# inputs = tokenizer(prompt, return_tensors="pt").to(device)
# with torch.inference_mode():
# output = model.generate(
# **inputs,
# max_new_tokens=60,
# temperature=0.1,
# do_sample=False
# )
# return tokenizer.decode(output[0], skip_special_tokens=True)
# iface = gr.Interface(fn=chat, inputs="text", outputs="text", title="SQL Chatbot")
# iface.launch()
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import torch
# Quantization config for fast 4-bit loading: NF4 quant type with double
# quantization, computing in bfloat16.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
# Load base model + the fine-tuned LoRA adapter once at startup.
base_model_name = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
lora_model_name = "saadkhi/SQL_Chat_finetuned_model"
print("Loading model (20–40 seconds first time)...")
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=quant_config,
    device_map="auto",
    trust_remote_code=True,
    # BUG FIX: "flash_attention_2" requires the flash-attn package AND an
    # Ampere-or-newer GPU — it is NOT supported on T4 (compute cap. 7.5) and
    # crashes at load time there. PyTorch's built-in "sdpa" kernel works on
    # every GPU (and CPU) with no extra dependency.
    attn_implementation="sdpa",
)
model = PeftModel.from_pretrained(base_model, lora_model_name)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
model.eval()  # inference mode: disables dropout etc.
print("Model ready!")
def chat(message, history):
    """Generate an assistant reply for *message* given the chat so far.

    Args:
        message: The latest user message (plain text).
        history: Gradio tuple-format history — a list of
            (user_message, assistant_reply) pairs.

    Returns:
        Tuple of (updated history, "") — the empty string clears the
        input textbox bound as the second output.
    """
    # Rebuild the full conversation in chat-template message format.
    messages = []
    for user, assistant in history:
        messages.append({"role": "user", "content": user})
        if assistant:  # the in-flight turn may not have a reply yet
            messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})
    # return_dict=True gives both input_ids AND attention_mask; passing the
    # mask to generate() avoids the "attention mask is not set" warning and
    # the pad==eos ambiguity.
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    # inference_mode disables autograd bookkeeping during generation.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.1,
            use_cache=True,  # KV caching = much faster decoding
            eos_token_id=tokenizer.eos_token_id,
            # Phi-3 has no dedicated pad token; reuse EOS to silence the
            # "Setting pad_token_id to eos_token_id" warning.
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens (everything after the prompt).
    prompt_len = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    history.append((message, response))
    return history, ""
# Gradio interface: tuple-format Chatbot with manual submit/clear wiring.
with gr.Blocks(title="SQL Chatbot", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# SQL Chat Assistant")
    gr.Markdown("Fine-tuned Phi-3 Mini for SQL queries. Responses in 2–6 seconds on GPU.")
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Your Question", placeholder="e.g., delete duplicate rows from users table based on email", lines=2)
    clear = gr.Button("Clear")
    # Enter in the textbox: run chat(), update the chat log, clear the box.
    msg.submit(chat, [msg, chatbot], [chatbot, msg])
    # BUG FIX: the lambda returns TWO values ([], ""), so it needs two output
    # components. With only `chatbot` as output, the whole tuple ([], "") was
    # assigned as the chatbot's value instead of clearing it, and the textbox
    # was never cleared.
    clear.click(lambda: ([], ""), None, [chatbot, msg])
demo.queue(max_size=30)  # cap concurrent/queued requests
demo.launch()