Spaces:
Sleeping
Sleeping
| # app.py | |
| import torch | |
| import gradio as gr | |
| import spaces | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig | |
| from peft import PeftModel | |
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Hugging Face repo ids: the pre-quantized base checkpoint and the LoRA
# adapter that is applied on top of it.
BASE_MODEL: str = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
LORA_PATH: str = "saadkhi/SQL_Chat_finetuned_model"

# Generation settings consumed by generate_sql(). DO_SAMPLE=False means
# greedy decoding.
MAX_NEW_TOKENS: int = 180
TEMPERATURE: float = 0.0
DO_SAMPLE: bool = False

# NOTE(review): the model below is loaded with device_map="cpu" and is never
# moved afterwards, so the second message overstates GPU usage — confirm.
print("Loading quantized base model on CPU...")
print("(GPU will be used only during inference if available)")
# bitsandbytes settings handed to from_pretrained(): 4-bit weights using the
# NF4 quant type, nested (double) quantization enabled, bfloat16 compute dtype.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)
# ---------------------------------------------------------------------------
# Model + tokenizer loading (runs once at import time)
# ---------------------------------------------------------------------------
def _load_model_and_tokenizer():
    """Build the inference model and its tokenizer.

    Steps, in order: load the quantized base checkpoint onto the CPU, wrap it
    with the LoRA adapter from LORA_PATH, merge the adapter into the base
    weights, then load the base tokenizer and reuse EOS as its pad token.

    Returns:
        tuple: ``(merged_model, tokenizer)``; the model has had ``.eval()``
        called on it.
    """
    # device_map="cpu" pins every weight to the CPU; nothing later moves it.
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="cpu",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )

    print("Loading LoRA adapters...")
    adapted = PeftModel.from_pretrained(base, LORA_PATH)

    # Fold the adapter into the base weights so inference runs a plain model.
    # NOTE(review): the base weights are 4-bit; confirm that merging into
    # quantized weights is acceptable for output quality.
    print("Merging LoRA into base model...")
    merged = adapted.merge_and_unload()

    tok = AutoTokenizer.from_pretrained(BASE_MODEL)
    tok.pad_token = tok.eos_token  # reuse EOS for padding

    merged.eval()
    return merged, tok


model, tokenizer = _load_model_and_tokenizer()
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def generate_sql(prompt: str) -> str:
    """Generate a SQL answer for a natural-language request.

    The prompt is wrapped as a single user turn with the tokenizer's chat
    template and decoded with the module-level settings MAX_NEW_TOKENS,
    DO_SAMPLE and (only when sampling) TEMPERATURE.

    NOTE(review): ``spaces`` is imported by this file but this function is not
    decorated with ``@spaces.GPU``; on ZeroGPU hardware a GPU is presumably
    never attached to this call — confirm if GPU inference is intended.

    Args:
        prompt: The user's question, e.g. "Top 5 highest paid employees".

    Returns:
        The model's reply with the echoed prompt and chat markers removed,
        or an empty string when ``prompt`` is empty or only whitespace.
    """
    # Nothing to answer; skip an expensive generate() call.
    if not prompt or not prompt.strip():
        return ""

    messages = [{"role": "user", "content": prompt}]
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )

    # Inputs must live on the same device as the weights. The model is loaded
    # with device_map="cpu" and never moved, so choosing "cuda" merely because
    # a GPU is present would make generate() fail with a device mismatch.
    device = model.device
    print(f"β Running inference on device: {device}")
    input_ids = input_ids.to(device)

    gen_kwargs = dict(
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=DO_SAMPLE,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # temperature only affects sampling; passing it with greedy decoding just
    # produces a transformers warning on every request.
    if DO_SAMPLE:
        gen_kwargs["temperature"] = TEMPERATURE

    with torch.inference_mode():
        outputs = model.generate(input_ids=input_ids, **gen_kwargs)

    # generate() returns prompt + continuation for a decoder-only model; keep
    # only the newly generated tokens. Splitting the full decoded text on
    # "<|assistant|>" is unreliable because skip_special_tokens=True may have
    # already removed that marker, which would leave the prompt in the reply.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Defensive cleanup in case chat markers survive decoding.
    if "<|assistant|>" in response:
        response = response.split("<|assistant|>", 1)[-1]
    response = response.split("<|end|>", 1)[0]
    return response.strip()
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# A single-textbox-in / single-textbox-out interface around generate_sql().
# NOTE(review): the "β" in the title looks like a mis-encoded em dash, and the
# description mentions GPU hardware although the model is loaded with
# device_map="cpu" — confirm both.
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        lines=3,
        label="Ask a question about SQL",
        placeholder="Delete duplicate rows from users table based on email",
    ),
    outputs=gr.Textbox(label="Generated SQL Query"),
    examples=[
        [question]
        for question in (
            "Find duplicate emails in users table",
            "Top 5 highest paid employees",
            "Count orders per customer last month",
            "Show all products that haven't been ordered in the last 6 months",
            "Update all orders from 2024 to status 'completed'",
        )
    ],
    cache_examples=False,  # examples are not pre-run at startup
    title="SQL Chatbot β Phi-3-mini + LoRA",
    description=(
        "Fine-tuned Phi-3-mini-4k-instruct (4bit) for generating SQL queries\n\n"
        "Works on ZeroGPU and regular GPU hardware"
    ),
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()