# app.py
"""Gradio app that answers SQL questions with a LoRA-fine-tuned Phi-3-mini.

The base model and LoRA adapter are loaded (and merged) once at import time;
``generate_sql`` is the inference entry point wired into the Gradio UI.
"""
import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# ────────────────────────────────────────────────────────────────
# Configuration
# ────────────────────────────────────────────────────────────────
BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"

MAX_NEW_TOKENS = 180
TEMPERATURE = 0.0
DO_SAMPLE = False

# Chat-template markers that may survive decoding; used for output clean-up.
ASSISTANT_MARKER = "<|assistant|>"
STOP_MARKERS = ("<|end|>", "<|user|>")

# ────────────────────────────────────────────────────────────────
# Load model safely on CPU first
# ────────────────────────────────────────────────────────────────
print("Loading base model on CPU...")

try:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="cpu",  # Critical for ZeroGPU + CPU spaces
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )

    print("Loading and merging LoRA adapters...")
    model = PeftModel.from_pretrained(model, LORA_PATH)
    model = model.merge_and_unload()  # Merge once → faster inference

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

    model.eval()
    print("Model successfully loaded on CPU")

except Exception as e:
    # Top-level boundary: log for the container output, then fail fast.
    print(f"Model loading failed: {e}")
    raise


# ────────────────────────────────────────────────────────────────
# Inference function
# ────────────────────────────────────────────────────────────────
@spaces.GPU(duration=60)  # 60 seconds is usually enough
def generate_sql(prompt: str) -> str:
    """Generate an SQL answer for a natural-language question.

    Args:
        prompt: The user's question as entered in the UI.

    Returns:
        The cleaned model response, a fallback message when the prompt is
        empty or the model produced no text, or an ``"Error during
        generation: ..."`` string if generation raised. Exceptions are
        reported as text (not re-raised) so the UI always shows something.
    """
    # Nothing to ask the model; avoid spending a generation call on it.
    if not prompt or not prompt.strip():
        return "Please enter a question."

    try:
        messages = [{"role": "user", "content": prompt.strip()}]

        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        )

        # Put the inputs wherever the model actually lives. Hard-coding
        # "cuda" here while the model was loaded with device_map="cpu"
        # leaves the tensors on different devices.
        inputs = inputs.to(model.device)
        prompt_length = inputs.shape[-1]

        generation_kwargs = {
            "input_ids": inputs,
            "max_new_tokens": MAX_NEW_TOKENS,
            "do_sample": DO_SAMPLE,
            "use_cache": True,
            "pad_token_id": tokenizer.eos_token_id,
        }
        # Temperature only applies when sampling; passing it with greedy
        # decoding is ignored and triggers a configuration warning.
        if DO_SAMPLE:
            generation_kwargs["temperature"] = TEMPERATURE

        with torch.inference_mode():
            outputs = model.generate(**generation_kwargs)

        # generate() returns prompt + continuation; decode only the new
        # tokens so the user's question is not echoed back in the answer.
        new_tokens = outputs[0][prompt_length:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)

        # Fallback clean-up in case any template markers survive decoding.
        if ASSISTANT_MARKER in response:
            response = response.split(ASSISTANT_MARKER, 1)[-1]
        for marker in STOP_MARKERS:
            response = response.split(marker)[0]

        return response.strip() or "No valid response generated."

    except Exception as e:
        return f"Error during generation: {e}"


# ────────────────────────────────────────────────────────────────
# Gradio Interface
# ────────────────────────────────────────────────────────────────
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        label="Your SQL-related question",
        placeholder="e.g. Find duplicate emails in users table",
        lines=3,
        max_lines=6,
    ),
    outputs=gr.Textbox(
        label="Generated SQL / Answer",
        lines=6,
    ),
    title="SQL Chatbot – Phi-3-mini fine-tuned",
    description=(
        "Ask questions about SQL queries.\n\n"
        "Free CPU version – responses may take 30–120 seconds or more."
    ),
    examples=[
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees from employees table"],
        ["Count total orders per customer in last 30 days"],
        ["Delete duplicate rows based on email column"],
    ],
    cache_examples=False,
)

if __name__ == "__main__":
    print("Starting Gradio server...")
    try:
        # launch() must block here: with prevent_thread_lock=True the call
        # returns immediately and the server stops when the script exits.
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            debug=False,
            quiet=False,
            show_error=True,
        )
    except Exception as e:
        print(f"Launch failed: {e}")
        raise