# Hugging Face Spaces app (Space status-badge text removed from scrape)
# app.py
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# ----------------------------------------------------------------
# Configuration - fastest practical settings
# ----------------------------------------------------------------
BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"
MAX_NEW_TOKENS = 180   # cap generation length to keep latency reasonable
TEMPERATURE = 0.0      # greedy decoding = fastest & most deterministic
DO_SAMPLE = False      # no sampling = faster

# ----------------------------------------------------------------
# 4-bit quantization config (this is the key speedup)
# ----------------------------------------------------------------
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",             # "nf4" usually fastest + good quality
    bnb_4bit_use_double_quant=True,        # nested quantization -> extra memory saving
    bnb_4bit_compute_dtype=torch.bfloat16, # fast compute dtype on modern GPUs
)

print("Loading quantized base model...")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",       # auto = best available (cuda > cpu)
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

print("Loading LoRA adapters...")
model = PeftModel.from_pretrained(model, LORA_PATH)
# Merge LoRA weights into the base model: faster inference, no adapter overhead.
model = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Optional: enable flash scaled-dot-product attention on supported hardware.
if torch.cuda.is_available():
    try:
        torch.backends.cuda.enable_flash_sdp(True)
    except (AttributeError, RuntimeError):
        # Older torch builds may not expose this toggle; best-effort only.
        pass

model.eval()
print("Model ready!")
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def generate_sql(prompt: str) -> str:
    """Generate a SQL answer for *prompt* with the fine-tuned Phi-3 model.

    The prompt is wrapped in the Phi-3 chat template, decoded greedily,
    and only the newly generated tokens (the assistant's reply) are
    returned — the input prompt is sliced off before decoding.
    """
    # Phi-3 expects its chat template; add_generation_prompt appends the
    # assistant header so the model continues as the assistant.
    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    gen_kwargs = {
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": DO_SAMPLE,
        "use_cache": True,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    # Only pass temperature when sampling: transformers warns about (and
    # ignores) a temperature supplied together with do_sample=False.
    if DO_SAMPLE:
        gen_kwargs["temperature"] = TEMPERATURE

    with torch.inference_mode():
        outputs = model.generate(input_ids=inputs, **gen_kwargs)

    # Bug fix: decoding the full sequence with skip_special_tokens=True strips
    # the "<|assistant|>" special token, so the old split-on-marker cleanup
    # never matched and the prompt text leaked into the response. Slicing off
    # the prompt tokens returns only the assistant's answer.
    new_tokens = outputs[0][inputs.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return response.strip()
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# ----------------------------------------------------------------
# Gradio UI: one free-text question in, generated SQL / answer out.
_question_box = gr.Textbox(
    label="Ask SQL related question",
    placeholder="Show me all employees with salary > 50000...",
    lines=3,
)

demo = gr.Interface(
    fn=generate_sql,
    inputs=_question_box,
    outputs=gr.Textbox(label="Generated SQL / Answer"),
    title="SQL Chatbot - Fast Version",
    description="Phi-3-mini 4bit quantized + LoRA",
    examples=[
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees"],
        ["Count orders per customer last month"],
    ],
    # NOTE(review): allow_flagging is deprecated in newer Gradio releases
    # (flagging_mode replaces it) — confirm the pinned version before changing.
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()