File size: 5,555 Bytes
8b67be0
c7c0d53
7f424d1
02976e0
0fad5f5
7f424d1
 
00c8a57
0fad5f5
8b67be0
 
 
84031c5
0fad5f5
02976e0
 
0fad5f5
 
a2f39c6
8b67be0
 
 
7f424d1
8b67be0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02976e0
8b67be0
 
a2f39c6
8b67be0
 
 
 
a2f39c6
0fad5f5
8b67be0
 
 
 
84031c5
8b67be0
 
 
 
 
 
 
 
 
7f424d1
8b67be0
 
 
 
 
 
 
 
 
 
 
 
 
 
02976e0
8b67be0
 
 
 
 
 
 
 
 
02976e0
8b67be0
84031c5
8b67be0
 
 
 
 
0fad5f5
8b67be0
84031c5
7f424d1
02976e0
8b67be0
 
 
 
 
 
 
 
 
 
 
 
9158eaa
02976e0
 
 
8b67be0
 
 
32343cc
9158eaa
 
84031c5
1344c31
84031c5
8b67be0
9ae2c39
 
 
 
 
 
 
 
 
 
 
8b67be0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# app.py

import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# ────────────────────────────────────────────────────────────────
# Configuration
# ────────────────────────────────────────────────────────────────

# 4-bit-quantized Phi-3-mini base checkpoint and the LoRA adapter repo
# that is merged on top of it at startup.
BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
LORA_PATH  = "saadkhi/SQL_Chat_finetuned_model"

# Generation settings: greedy decoding (DO_SAMPLE=False), capped at 180
# new tokens.  NOTE(review): TEMPERATURE=0.0 is only meaningful when
# sampling is enabled; with greedy decoding it is ignored (and recent
# transformers versions warn if it is passed anyway).
MAX_NEW_TOKENS = 180
TEMPERATURE    = 0.0
DO_SAMPLE      = False

# ────────────────────────────────────────────────────────────────
# Load model safely on CPU first
# ────────────────────────────────────────────────────────────────

# Load everything on CPU at import time; GPU work happens only inside the
# @spaces.GPU-decorated inference function below.  Any failure here is
# fatal (re-raised) so the Space shows the loading error instead of
# serving a broken app.
print("Loading base model on CPU...")
try:
    # NF4 4-bit quantization config; compute in bfloat16.
    # NOTE(review): bitsandbytes 4-bit kernels generally require a CUDA
    # device — loading with device_map="cpu" may only work with recent
    # bitsandbytes versions; confirm against the Space's runtime.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="cpu",           # Critical for ZeroGPU + CPU spaces
        trust_remote_code=True,
        low_cpu_mem_usage=True
    )

    # Apply the fine-tuned LoRA adapter, then fold its weights into the
    # base model so inference skips the adapter indirection entirely.
    print("Loading and merging LoRA adapters...")
    model = PeftModel.from_pretrained(model, LORA_PATH)
    model = model.merge_and_unload()  # Merge once → faster inference

    # Tokenizer comes from the base repo (adapter repo has no tokenizer).
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    model.eval()  # disable dropout etc. for inference

    print("Model successfully loaded on CPU")
except Exception as e:
    # Log the reason, then re-raise so the Space fails loudly at startup.
    print(f"Model loading failed: {str(e)}")
    raise

# ────────────────────────────────────────────────────────────────
# Inference function – GPU only here
# ────────────────────────────────────────────────────────────────

@spaces.GPU(duration=60)  # request a ZeroGPU slot; 60 seconds is usually enough
def generate_sql(prompt: str) -> str:
    """Generate a SQL answer for *prompt* with the merged Phi-3 model.

    The prompt is wrapped in the model's chat template, generated
    greedily (per module-level DO_SAMPLE / MAX_NEW_TOKENS), and only the
    newly generated tokens are decoded and cleaned of chat markers.

    Returns:
        The model's response text, a placeholder when the model produced
        nothing, or an ``"Error during generation: ..."`` string on
        failure (the Gradio UI renders whichever string comes back).
    """
    try:
        messages = [{"role": "user", "content": prompt.strip()}]

        # Tokenize on CPU using the model's chat template.
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt"
        )

        # BUG FIX: the original moved inputs to "cuda" whenever CUDA was
        # available while the model had been loaded with device_map="cpu",
        # causing a device mismatch unless something else relocated the
        # model.  model.device always matches wherever the weights live.
        inputs = inputs.to(model.device)

        # Build generate() kwargs.  Only pass temperature when sampling:
        # under greedy decoding it is ignored, and recent transformers
        # versions warn (or raise on 0.0) when it is supplied anyway.
        gen_kwargs = dict(
            input_ids=inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,  # silence missing-pad warning
        )
        if DO_SAMPLE:
            gen_kwargs["temperature"] = TEMPERATURE

        with torch.inference_mode():
            outputs = model.generate(**gen_kwargs)

        # Decode only the newly generated tokens so the prompt is never
        # echoed back; the marker splits below remain as a defensive
        # cleanup in case special tokens survive decoding.
        response = tokenizer.decode(
            outputs[0][inputs.shape[-1]:], skip_special_tokens=True
        )

        if "<|assistant|>" in response:
            response = response.split("<|assistant|>", 1)[-1].strip()
        if "<|end|>" in response:
            response = response.split("<|end|>")[0].strip()
        if "<|user|>" in response:
            response = response.split("<|user|>")[0].strip()

        return response.strip() or "No valid response generated."

    except Exception as e:
        # Return the error as text so the UI shows it instead of a 500.
        return f"Error during generation: {str(e)}"

# ────────────────────────────────────────────────────────────────
# Gradio Interface
# ────────────────────────────────────────────────────────────────

# Single-textbox-in, single-textbox-out UI around generate_sql.
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        label="Your SQL-related question",
        placeholder="e.g. Find duplicate emails in users table",
        lines=3,
        max_lines=6
    ),
    outputs=gr.Textbox(
        label="Generated SQL / Answer",
        lines=6
    ),
    title="SQL Chatbot – Phi-3-mini fine-tuned",
    description=(
        "Ask questions about SQL queries.\n\n"
        "Free CPU version – responses may take 30–120 seconds or more."
    ),
    examples=[
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees from employees table"],
        ["Count total orders per customer in last 30 days"],
        ["Delete duplicate rows based on email column"]
    ],
    # Don't pre-run the examples at startup — each one would trigger a
    # slow model generation while the Space is still booting.
    cache_examples=False,
    # NOTE(review): `allow_flagging` was removed in Gradio 4.x; do not
    # re-add it here.
)

if __name__ == "__main__":
    print("Starting Gradio server...")
    import time
    time.sleep(15)  # give the container extra time for model/Gradio to settle

    # BUG FIX: the original file had an `except Exception` clause with no
    # matching `try:` around demo.launch(), which is a SyntaxError and
    # prevented the app from starting at all.
    try:
        # NOTE(review): prevent_thread_lock=True makes launch() return
        # immediately; with nothing after it the process would exit unless
        # the hosting container keeps it alive — confirm this is intended.
        demo.launch(
            server_name="0.0.0.0",     # listen on all interfaces (container)
            server_port=7860,          # default Hugging Face Spaces port
            debug=False,
            quiet=False,
            show_error=True,           # surface tracebacks in the UI
            prevent_thread_lock=True   # Helps in containers
        )
    except Exception as e:
        # Log and re-raise so the container reports a failed start.
        print(f"Launch failed: {str(e)}")
        raise