Spaces:
Sleeping
Sleeping
File size: 4,168 Bytes
bbdf923 24f8f89 87ff5b4 00c8a57 1344c31 bbdf923 1344c31 00c8a57 1344c31 00c8a57 bbdf923 1344c31 bbdf923 00c8a57 bbdf923 52ae0ac 1344c31 00c8a57 bbdf923 00c8a57 1344c31 00c8a57 bbdf923 00c8a57 1344c31 00c8a57 1344c31 bbdf923 00c8a57 bbdf923 00c8a57 1344c31 00c8a57 bbdf923 1344c31 bbdf923 1344c31 bbdf923 1344c31 bbdf923 1344c31 bbdf923 00c8a57 1344c31 bbdf923 1344c31 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
# app.py - Fixed for recent Gradio versions (no allow_flagging)
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
# ----------------------------------------------------------------
# Fastest practical configuration
# ----------------------------------------------------------------
# Identifier of the base checkpoint; passed to from_pretrained() for both the
# model and the tokenizer.
BASE_MODEL: str = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
# Identifier of the LoRA adapter that is loaded on top of the base model.
LORA_PATH: str = "saadkhi/SQL_Chat_finetuned_model"
# Upper bound on the number of tokens generated per request.
MAX_NEW_TOKENS: int = 180
# Generation temperature; paired with DO_SAMPLE = False so decoding is greedy.
TEMPERATURE: float = 0.0 # greedy = fastest
# Sampling switch forwarded to generate(); False disables sampling.
DO_SAMPLE: bool = False
# ----------------------------------------------------------------
# 4-bit quantization (very important for speed)
# ----------------------------------------------------------------
# bitsandbytes settings handed to from_pretrained() as quantization_config.
# NOTE(review): the BASE_MODEL name ends in "-bnb-4bit", so the checkpoint is
# presumably already quantized -- confirm whether this config still has an
# effect or is overridden by the one stored with the checkpoint.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                       # keep weights in 4-bit form
    bnb_4bit_quant_type="nf4",               # NF4 quantization data type
    bnb_4bit_use_double_quant=True,          # nested (double) quantization
    bnb_4bit_compute_dtype=torch.bfloat16,   # dtype used for computation
)
# Step 1: load the base checkpoint; device placement is left to
# device_map="auto".
print("Loading quantized base model...")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

# Step 2: attach the LoRA adapter and immediately fold it into the base
# weights, so inference runs on a plain model without the PEFT wrapper.
print("Loading LoRA adapters...")
model = PeftModel.from_pretrained(model, LORA_PATH).merge_and_unload()

# Step 3: the tokenizer is taken from the base checkpoint, not the adapter.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Step 4: switch to evaluation mode before serving requests.
model.eval()
print("Model ready!")
# ----------------------------------------------------------------
def generate_sql(prompt: str) -> str:
    """Answer a natural-language SQL question with the fine-tuned model.

    The prompt is wrapped as a single user turn, rendered through the
    tokenizer's chat template, and passed to ``model.generate``. Only the
    newly generated tokens are decoded, so the returned text never contains
    an echo of the prompt.

    Args:
        prompt: The user's question. ``None``, empty, or whitespace-only
            input short-circuits without running the model.

    Returns:
        The generated answer with surrounding whitespace removed, or a short
        hint when no question was supplied.
    """
    if not prompt or not prompt.strip():
        # Nothing to answer; skip the expensive generation call entirely.
        return "Please enter a question."

    messages = [{"role": "user", "content": prompt}]

    # return_dict=True yields input_ids together with the attention_mask, so
    # generate() does not have to guess the mask (pad id == eos id here).
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    input_ids = encoded["input_ids"]
    prompt_length = input_ids.shape[-1]

    generation_kwargs = {
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": DO_SAMPLE,
        "use_cache": True,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if DO_SAMPLE:
        # Temperature only matters when sampling; forwarding it alongside
        # do_sample=False is at best ignored and may trigger a warning.
        generation_kwargs["temperature"] = TEMPERATURE

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=encoded["attention_mask"],
            **generation_kwargs,
        )

    # The returned sequence starts with the prompt tokens; drop them so only
    # the model's continuation is decoded.
    new_tokens = outputs[0][prompt_length:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Defensive cleanup in case chat markers survive decoding as plain text.
    if "<|assistant|>" in response:
        response = response.split("<|assistant|>", 1)[-1]
    response = response.split("<|end|>", 1)[0]
    return response.strip()
# ----------------------------------------------------------------
# Gradio interface - modern style (no allow_flagging)
# ----------------------------------------------------------------
# Single-function UI: one multi-line question box in, one text box out.
demo = gr.Interface(
    title="SQL Chatbot - Optimized",
    description="Phi-3-mini 4bit + LoRA merged",
    fn=generate_sql,
    inputs=gr.Textbox(
        label="Ask SQL related question",
        placeholder="Show me all employees with salary > 50000...",
        lines=3,
    ),
    outputs=gr.Textbox(label="Generated SQL / Answer"),
    # Each example is a one-element list because the interface has exactly
    # one input component.
    examples=[
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees"],
        ["Count orders per customer last month"],
    ],
    # No flagging argument is passed (allow_flagging is deliberately
    # omitted), so Gradio's own default flagging behaviour applies.
)
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script,
    # not when it is imported.
    demo.launch()