# app.py
import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
# ----------------------------------------------------------------
# Configuration
# ----------------------------------------------------------------
# Hub IDs: 4-bit (bnb) quantized Phi-3-mini base model + LoRA adapter weights.
BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"
MAX_NEW_TOKENS = 180   # cap on generated tokens per request
TEMPERATURE = 0.0      # only meaningful when DO_SAMPLE is True
DO_SAMPLE = False      # greedy (deterministic) decoding
# ----------------------------------------------------------------
# Load model safely on CPU first
# ----------------------------------------------------------------
# Module-level load: runs once at import time. Everything stays on CPU here;
# GPU placement only happens inside the @spaces.GPU-decorated function below.
print("Loading base model on CPU...")
try:
    # 4-bit NF4 quantization keeps the model small enough for a free Space.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="cpu",  # Critical for ZeroGPU + CPU spaces
        trust_remote_code=True,
        low_cpu_mem_usage=True
    )
    print("Loading and merging LoRA adapters...")
    model = PeftModel.from_pretrained(model, LORA_PATH)
    # Merge adapters into the base weights once, so inference pays no
    # per-forward LoRA overhead.
    model = model.merge_and_unload()
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    model.eval()
    print("Model successfully loaded on CPU")
except Exception as e:
    # Loading failure is fatal for the app — log and re-raise so the Space
    # shows the real error instead of serving a broken model.
    print(f"Model loading failed: {str(e)}")
    raise
# ----------------------------------------------------------------
# Inference function — GPU only here
# ----------------------------------------------------------------
@spaces.GPU(duration=60)  # 60 seconds is usually enough
def generate_sql(prompt: str) -> str:
    """Generate a SQL answer for *prompt* using the fine-tuned Phi-3 model.

    Tokenizes the prompt with the model's chat template, runs greedy
    generation (sampling parameters come from the module-level config),
    and strips Phi-3 chat-template markers from the decoded output.

    Returns the cleaned model response, or a human-readable error string —
    this function is a Gradio handler, so it must not raise.
    """
    try:
        messages = [{"role": "user", "content": prompt.strip()}]
        # Tokenize on CPU.
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        )
        # Move the inputs to wherever the model actually lives. The model was
        # loaded with device_map="cpu", so hard-coding "cuda" here would crash
        # with a device mismatch whenever CUDA is available but the model was
        # not moved; model.device is always correct.
        inputs = inputs.to(model.device)
        # temperature is only valid when sampling — passing temperature=0.0
        # with do_sample=False triggers a transformers warning (and an error
        # in newer versions), so forward it conditionally.
        sampling_kwargs = {"temperature": TEMPERATURE} if DO_SAMPLE else {}
        with torch.inference_mode():
            outputs = model.generate(
                input_ids=inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=DO_SAMPLE,
                use_cache=True,
                pad_token_id=tokenizer.eos_token_id,
                **sampling_kwargs,
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Strip Phi-3 chat-template markers: keep only the assistant's turn,
        # truncated at the first end-of-turn or follow-up-user marker.
        if "<|assistant|>" in response:
            response = response.split("<|assistant|>", 1)[-1].strip()
        if "<|end|>" in response:
            response = response.split("<|end|>")[0].strip()
        if "<|user|>" in response:
            response = response.split("<|user|>")[0].strip()
        return response.strip() or "No valid response generated."
    except Exception as e:
        # Gradio handler boundary: surface the error as text, never raise.
        return f"Error during generation: {str(e)}"
# ----------------------------------------------------------------
# Gradio Interface
# ----------------------------------------------------------------
# Components are built separately so the Interface call stays readable.
_question_input = gr.Textbox(
    label="Your SQL-related question",
    placeholder="e.g. Find duplicate emails in users table",
    lines=3,
    max_lines=6,
)
_answer_output = gr.Textbox(label="Generated SQL / Answer", lines=6)

_example_prompts = [
    ["Find duplicate emails in users table"],
    ["Top 5 highest paid employees from employees table"],
    ["Count total orders per customer in last 30 days"],
    ["Delete duplicate rows based on email column"],
]

demo = gr.Interface(
    fn=generate_sql,
    inputs=_question_input,
    outputs=_answer_output,
    title="SQL Chatbot β Phi-3-mini fine-tuned",
    description=(
        "Ask questions about SQL queries.\n\n"
        "Free CPU version β responses may take 30β120 seconds or more."
    ),
    examples=_example_prompts,
    cache_examples=False,  # examples go through the live model, never cached
)
if __name__ == "__main__":
    print("Starting Gradio server...")
    import time
    time.sleep(15)  # Give extra time for model/Gradio to settle
    # BUG FIX: the original had an `except` clause with no matching `try:`,
    # which is a SyntaxError and prevents the app from starting at all.
    try:
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            debug=False,
            quiet=False,
            show_error=True,
            prevent_thread_lock=True,  # Helps in containers
        )
    except Exception as e:
        # Log the launch failure, then re-raise so the Space reports it.
        print(f"Launch failed: {str(e)}")
        raise