# NOTE: lines that previously appeared here ("Spaces: Running", file size,
# git blame commit hashes, and a line-number gutter) were web-scrape residue
# from the Hugging Face file viewer, not part of app.py; removed.
# app.py
import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
# ────────────────────────────────────────────────────────────────
# Configuration
BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"  # pre-quantized 4-bit Phi-3-mini base
LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"  # fine-tuned LoRA adapters on the Hub
MAX_NEW_TOKENS = 180  # cap on tokens generated per request
TEMPERATURE = 0.0  # ignored under greedy decoding (DO_SAMPLE=False); kept for clarity
DO_SAMPLE = False  # greedy decoding → deterministic SQL output
print("Loading quantized base model on CPU...")
print("(GPU will be used only during inference if available)")
# 4-bit quantization config (NF4 + double quantization, bf16 compute)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
# Load base model — always on CPU first (ZeroGPU: the GPU only exists inside
# @spaces.GPU-decorated calls, so module-level loading must stay on CPU)
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
quantization_config=bnb_config,
device_map="cpu",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
)
print("Loading LoRA adapters...")
# Attach the fine-tuned LoRA weights on top of the quantized base
model = PeftModel.from_pretrained(model, LORA_PATH)
# Merge for faster inference (very recommended) — folds the adapter deltas
# into the base weights so generation skips the PEFT wrapper entirely
print("Merging LoRA into base model...")
model = model.merge_and_unload()
# Load tokenizer from the base repo (LoRA training did not change the vocab)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Phi-3 ships without a pad token; reuse EOS so batching/padding never crashes
tokenizer.pad_token = tokenizer.eos_token
model.eval()
# ────────────────────────────────────────────────────────────────
@spaces.GPU(duration=60, max_requests=20)  # safe values for ZeroGPU
def generate_sql(prompt: str) -> str:
    """Generate a SQL query (or SQL explanation) for a natural-language prompt.

    Runs greedy decoding on the merged Phi-3-mini + LoRA model. Tokenization
    happens on CPU; tensors are moved to CUDA only if it is available at call
    time (the ZeroGPU-safe pattern — the GPU exists only inside this call).

    Args:
        prompt: The user's natural-language question about SQL.

    Returns:
        The model's response with chat-template markers stripped.
    """
    # Phi-3 chat format: a single user turn, template adds the assistant cue.
    messages = [
        {"role": "user", "content": prompt}
    ]
    # Tokenize on CPU (safe everywhere)
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt"
    )
    # Choose device dynamically - this is the ZeroGPU-safe way
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"→ Running inference on device: {device}")
    inputs = inputs.to(device)
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs,
            # TEMPERATURE is ignored when do_sample=False (greedy decoding)
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            do_sample=DO_SAMPLE,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # BUG FIX: the old code decoded the FULL sequence with
    # skip_special_tokens=True and then split on "<|assistant|>". But that
    # marker is a special token and is stripped by the decode, so the split
    # never matched and the user's prompt leaked into the returned text.
    # Decode only the newly generated tokens instead.
    new_tokens = outputs[0][inputs.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # Cut at end-of-turn marker if the tokenizer ever emits it as plain text
    if "<|end|>" in response:
        response = response.split("<|end|>", 1)[0].strip()
    return response.strip()
# ────────────────────────────────────────────────────────────────
# Gradio UI: one text box in, generated SQL out.
question_box = gr.Textbox(
    label="Ask a question about SQL",
    placeholder="Delete duplicate rows from users table based on email",
    lines=3,
)

sql_output_box = gr.Textbox(label="Generated SQL Query")

# Clickable example prompts shown under the input (not pre-computed:
# cache_examples=False keeps startup fast and avoids a GPU call at build time).
example_prompts = [
    ["Find duplicate emails in users table"],
    ["Top 5 highest paid employees"],
    ["Count orders per customer last month"],
    ["Show all products that haven't been ordered in the last 6 months"],
    ["Update all orders from 2024 to status 'completed'"],
]

demo = gr.Interface(
    fn=generate_sql,
    inputs=question_box,
    outputs=sql_output_box,
    title="SQL Chatbot β Phi-3-mini + LoRA",
    description=(
        "Fine-tuned Phi-3-mini-4k-instruct (4bit) for generating SQL queries\n\n"
        "Works on ZeroGPU and regular GPU hardware"
    ),
    examples=example_prompts,
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()