# app.py
"""Gradio demo that answers SQL questions with a 4-bit Phi-3-mini base model plus LoRA adapters.

Importing this module loads the quantized base model, applies and merges the
LoRA adapters, loads the tokenizer, and builds the Gradio interface. The UI is
only launched when the file is run as a script.
"""

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# ────────────────────────────────────────────────────────────────
# Configuration - fastest practical settings
# ────────────────────────────────────────────────────────────────
BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"

MAX_NEW_TOKENS = 180   # ← keep reasonable
TEMPERATURE = 0.0      # only forwarded to generate() when DO_SAMPLE is True
DO_SAMPLE = False      # greedy decoding: deterministic, no sampling overhead

# Chat-turn markers that generate_sql strips from decoded text if they survive decoding.
ASSISTANT_MARKER = "<|assistant|>"
END_MARKER = "<|end|>"

# ────────────────────────────────────────────────────────────────
# 4-bit quantization config (this is the key speedup)
# ────────────────────────────────────────────────────────────────
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # "nf4" usually fastest + good quality
    bnb_4bit_use_double_quant=True,         # nested quantization → extra memory saving
    bnb_4bit_compute_dtype=torch.bfloat16,  # fastest compute type on modern GPUs
)

print("Loading quantized base model...")
# NOTE(review): trust_remote_code=True executes code shipped with the model
# repository — confirm it is still required for this checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",                      # auto = best available (cuda > cpu)
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

print("Loading LoRA adapters...")
model = PeftModel.from_pretrained(model, LORA_PATH)

# Important: merge LoRA weights into base (faster inference, less overhead)
# NOTE(review): merging into 4-bit weights may introduce small rounding
# differences versus running the adapter unmerged — verify output quality.
model = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Optional: small speedup boost on supported hardware. This is best-effort, so
# a failure is reported instead of being silently swallowed or crashing startup.
if torch.cuda.is_available():
    try:
        torch.backends.cuda.enable_flash_sdp(True)  # flash scaled dot product
    except (AttributeError, RuntimeError) as exc:
        print(f"Flash SDP not enabled: {exc}")

model.eval()

# Token ids that end generation. Passing only tokenizer.eos_token_id would
# replace any stop list from the model's generation config, so the chat
# end-of-turn marker is added explicitly when the vocabulary contains it.
STOP_TOKEN_IDS = []
if tokenizer.eos_token_id is not None:
    STOP_TOKEN_IDS.append(tokenizer.eos_token_id)
_end_marker_id = tokenizer.convert_tokens_to_ids(END_MARKER)
if (
    _end_marker_id is not None
    and _end_marker_id != tokenizer.unk_token_id
    and _end_marker_id not in STOP_TOKEN_IDS
):
    STOP_TOKEN_IDS.append(_end_marker_id)

print("Model ready!")


# ────────────────────────────────────────────────────────────────
def _clean_response(text: str) -> str:
    """Strip any chat-turn markers that survived decoding from *text*."""
    if ASSISTANT_MARKER in text:
        text = text.split(ASSISTANT_MARKER, 1)[-1].strip()
    return text.split(END_MARKER)[0].strip()


def generate_sql(prompt: str) -> str:
    """Generate the model's answer to a single SQL-related question.

    Args:
        prompt: The user's question as plain text.

    Returns:
        The assistant's reply with the prompt and chat markers removed, or a
        short hint when *prompt* is empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        return "Please enter a question."

    # Use proper chat template (Phi-3 expects it)
    messages = [{"role": "user", "content": prompt}]

    # Ask for a dict explicitly so both input_ids and attention_mask are
    # available regardless of the library's default return type.
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    generation_kwargs = {
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": DO_SAMPLE,
        "use_cache": True,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if STOP_TOKEN_IDS:
        generation_kwargs["eos_token_id"] = STOP_TOKEN_IDS
    # Temperature only applies when sampling; passing it with greedy decoding
    # is ignored and triggers a generation-config warning.
    if DO_SAMPLE:
        generation_kwargs["temperature"] = TEMPERATURE

    with torch.inference_mode():
        outputs = model.generate(**inputs, **generation_kwargs)

    # generate() returns prompt + completion. Decode only the new tokens so the
    # user's question is not echoed back in the answer.
    prompt_length = inputs["input_ids"].shape[-1]
    completion_ids = outputs[0][prompt_length:]
    response = tokenizer.decode(completion_ids, skip_special_tokens=True)

    # Clean output - defensive, in case a marker is not registered as special.
    return _clean_response(response)


# ────────────────────────────────────────────────────────────────
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        label="Ask SQL related question",
        placeholder="Show me all employees with salary > 50000...",
        lines=3,
    ),
    outputs=gr.Textbox(label="Generated SQL / Answer"),
    title="SQL Chatbot - Fast Version",
    description="Phi-3-mini 4bit quantized + LoRA",
    examples=[
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees"],
        ["Count orders per customer last month"],
    ],
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()