# app.py
import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# ────────────────────────────────────────────────────────────────
BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"

MAX_NEW_TOKENS = 180
TEMPERATURE = 0.0
DO_SAMPLE = False

# Chat-template markers that delimit the assistant turn in the model output.
ASSISTANT_TAG = "<|assistant|>"
END_TAG = "<|end|>"

print("Loading quantized base model on CPU...")
print("(GPU will be used only during inference if available)")

# 4-bit quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load base model → always on CPU first
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="cpu",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

print("Loading LoRA adapters...")
model = PeftModel.from_pretrained(model, LORA_PATH)

# Merge for faster inference (very recommended)
print("Merging LoRA into base model...")
model = model.merge_and_unload()

# Load tokenizer; the EOS token doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token = tokenizer.eos_token

model.eval()


# ────────────────────────────────────────────────────────────────
def _stop_token_ids():
    """Return the token ids at which generation should stop.

    Always contains the tokenizer's EOS id; the id of ``END_TAG`` is added
    when the vocabulary knows that token, so generation also stops at the end
    of the assistant turn rather than only at end-of-text.
    """
    stop_ids = [tokenizer.eos_token_id]
    end_id = tokenizer.convert_tokens_to_ids(END_TAG)
    # An unknown token maps to None or to the UNK id depending on the
    # tokenizer; neither is a usable stop id.
    if (
        isinstance(end_id, int)
        and end_id != tokenizer.unk_token_id
        and end_id not in stop_ids
    ):
        stop_ids.append(end_id)
    return stop_ids


def _clean_response(text: str) -> str:
    """Strip any leftover chat-template markers and surrounding whitespace."""
    # Keep only what follows the assistant tag, if one survived decoding.
    _, tag, tail = text.partition(ASSISTANT_TAG)
    if tag:
        text = tail
    # Drop everything from the end-of-turn marker onwards.
    text = text.partition(END_TAG)[0]
    return text.strip()


@spaces.GPU(duration=60, max_requests=20)  # safe values for ZeroGPU
def generate_sql(prompt: str) -> str:
    """Generate a SQL query (as text) for a natural-language request.

    Args:
        prompt: The user's question or instruction.

    Returns:
        The model's reply with the prompt and chat markers removed. An empty
        or whitespace-only prompt yields an empty string without running the
        model.
    """
    if not prompt or not prompt.strip():
        return ""

    # Prepare chat format
    messages = [{"role": "user", "content": prompt}]

    # return_dict=True also yields the attention mask, which generate() needs
    # to tell padding from content because pad_token == eos_token here.
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )

    # Inputs must live on the same device as the model weights. The model is
    # loaded with device_map="cpu" and never moved in this file, so sending
    # inputs to "cuda" just because a GPU is visible would make generate()
    # fail with a device mismatch.
    # NOTE(review): to actually run on GPU the model has to be placed there
    # at load time; this function then follows automatically.
    device = model.device
    print(f"→ Running inference on device: {device}")
    encoded = encoded.to(device)
    prompt_len = encoded["input_ids"].shape[-1]

    gen_kwargs = {
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": DO_SAMPLE,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": _stop_token_ids(),
    }
    # Temperature only has meaning when sampling; with greedy decoding it is
    # ignored, and 0.0 is not a valid sampling temperature.
    if DO_SAMPLE:
        gen_kwargs["temperature"] = TEMPERATURE

    with torch.inference_mode():
        outputs = model.generate(**encoded, **gen_kwargs)

    # generate() returns prompt + completion. Slice the prompt off by token
    # position: searching the decoded text for the assistant tag is
    # unreliable, since skip_special_tokens=True removes the chat markers
    # whenever they are registered as special tokens (they appear to be for
    # Phi-3 — confirm against the tokenizer config).
    new_tokens = outputs[0][prompt_len:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Defensive cleanup for tokenizers that leave the markers in the text.
    return _clean_response(response)


# ────────────────────────────────────────────────────────────────
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        label="Ask a question about SQL",
        placeholder="Delete duplicate rows from users table based on email",
        lines=3,
    ),
    outputs=gr.Textbox(label="Generated SQL Query"),
    title="SQL Chatbot – Phi-3-mini + LoRA",
    description=(
        "Fine-tuned Phi-3-mini-4k-instruct (4bit) for generating SQL queries\n\n"
        "Works on ZeroGPU and regular GPU hardware"
    ),
    examples=[
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees"],
        ["Count orders per customer last month"],
        ["Show all products that haven't been ordered in the last 6 months"],
        ["Update all orders from 2024 to status 'completed'"],
    ],
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()