Spaces:
Runtime error
Runtime error
| # app.py | |
| import torch | |
| import gradio as gr | |
| import spaces | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig | |
| from peft import PeftModel | |
# ----------------------------------------------------------------
# Configuration
# ----------------------------------------------------------------
# Hub id of the base checkpoint handed to AutoModelForCausalLM / AutoTokenizer.
BASE_MODEL: str = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
# Hub id of the LoRA adapter that is applied on top of BASE_MODEL.
LORA_PATH: str = "saadkhi/SQL_Chat_finetuned_model"

# Generation settings forwarded to model.generate().
MAX_NEW_TOKENS: int = 180
TEMPERATURE: float = 0.0
DO_SAMPLE: bool = False
# ----------------------------------------------------------------
# Load model safely on CPU first
# ----------------------------------------------------------------
def _load_model_and_tokenizer(quant_config):
    """Load the base model, merge the LoRA adapter into it, and load the tokenizer.

    Args:
        quant_config: BitsAndBytesConfig passed through as ``quantization_config``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been switched to eval mode.
    """
    # NOTE(review): this combines a bitsandbytes 4-bit checkpoint with
    # device_map="cpu"; whether that is supported depends on the installed
    # bitsandbytes/transformers versions -- confirm on the target hardware.
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=quant_config,
        device_map="cpu",  # original author's note: critical for ZeroGPU + CPU spaces
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )

    print("Loading and merging LoRA adapters...")
    # Merge the adapter once at startup instead of keeping the PEFT wrapper
    # around for every request.
    merged = PeftModel.from_pretrained(base, LORA_PATH).merge_and_unload()

    tok = AutoTokenizer.from_pretrained(BASE_MODEL)
    merged.eval()
    return merged, tok


print("Loading base model on CPU...")
try:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model, tokenizer = _load_model_and_tokenizer(bnb_config)
    print("Model successfully loaded on CPU")
except Exception as e:
    # Log the reason, then re-raise so the process fails loudly at startup.
    print(f"Model loading failed: {str(e)}")
    raise
# ----------------------------------------------------------------
# Inference function
# ----------------------------------------------------------------
# NOTE(review): the original comment here ("60 seconds is usually enough")
# appears to describe a decorator such as @spaces.GPU(duration=60) that is not
# present on the function -- confirm whether it should be restored.
def _strip_chat_markers(text: str) -> str:
    """Return *text* with any leftover Phi-3 chat-template markers cut away."""
    if "<|assistant|>" in text:
        # Keep only what follows the first assistant marker.
        text = text.split("<|assistant|>", 1)[-1]
    for marker in ("<|end|>", "<|user|>"):
        # Drop everything from the first end-of-turn / next-user marker on.
        text = text.split(marker, 1)[0]
    return text.strip()


def generate_sql(prompt: str) -> str:
    """Generate an answer for a SQL-related question.

    Args:
        prompt: The user's question. ``None``, empty, or whitespace-only input
            is rejected without running the model.

    Returns:
        The generated answer text, a short hint when the prompt is empty,
        ``"No valid response generated."`` when the model produced nothing,
        or ``"Error during generation: ..."`` if any step raised.
    """
    question = (prompt or "").strip()
    if not question:
        return "Please enter a SQL-related question."

    try:
        messages = [{"role": "user", "content": question}]
        encoded = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,  # also yields the attention mask for generate()
        )
        # Inputs must live on the same device as the model weights; moving
        # them to "cuda" unconditionally breaks when the model stays on CPU.
        encoded = encoded.to(model.device)
        prompt_len = encoded["input_ids"].shape[-1]

        gen_kwargs = {
            "max_new_tokens": MAX_NEW_TOKENS,
            "do_sample": DO_SAMPLE,
            "use_cache": True,
            "pad_token_id": tokenizer.eos_token_id,
        }
        if DO_SAMPLE:
            # Temperature only applies when sampling; passing it for greedy
            # decoding just triggers a configuration warning.
            gen_kwargs["temperature"] = TEMPERATURE

        with torch.inference_mode():
            outputs = model.generate(**encoded, **gen_kwargs)

        # Decode only the newly generated tokens so the user's own prompt is
        # not echoed back in the answer.
        new_tokens = outputs[0][prompt_len:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
        return _strip_chat_markers(response) or "No valid response generated."
    except Exception as e:
        # Surface the failure in the UI instead of crashing the request.
        return f"Error during generation: {e}"
# ----------------------------------------------------------------
# Gradio Interface
# ----------------------------------------------------------------
# Single-textbox UI: the question goes to generate_sql and the returned text
# is shown in the output box. The title/description previously contained
# mis-encoded dash characters; they are written as plain ASCII here.
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        label="Your SQL-related question",
        placeholder="e.g. Find duplicate emails in users table",
        lines=3,
        max_lines=6,
    ),
    outputs=gr.Textbox(
        label="Generated SQL / Answer",
        lines=6,
    ),
    title="SQL Chatbot - Phi-3-mini fine-tuned",
    description=(
        "Ask questions about SQL queries.\n\n"
        "Free CPU version - responses may take 30-120 seconds or more."
    ),
    examples=[
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees from employees table"],
        ["Count total orders per customer in last 30 days"],
        ["Delete duplicate rows based on email column"],
    ],
    # Do not pre-compute example outputs when the interface is built.
    cache_examples=False,
)
if __name__ == "__main__":
    print("Starting Gradio server...")
    try:
        # launch() must block the main thread here: with
        # prevent_thread_lock=True the call returns immediately, the script
        # reaches its end, and the server shuts down with it.
        demo.launch(
            server_name="0.0.0.0",  # listen on all interfaces (container use)
            server_port=7860,
            debug=False,
            quiet=False,
            show_error=True,
        )
    except Exception as e:
        # Log the reason, then re-raise so the failure is visible to the host.
        print(f"Launch failed: {e}")
        raise