File size: 3,960 Bytes
a2f39c6
7f424d1
02976e0
a2f39c6
7f424d1
 
00c8a57
02976e0
84031c5
a2f39c6
02976e0
 
a2f39c6
 
 
 
 
02976e0
a2f39c6
02976e0
 
 
a2f39c6
 
02976e0
7f424d1
a2f39c6
7f424d1
 
 
a2f39c6
 
 
84031c5
02976e0
a2f39c6
7f424d1
02976e0
a2f39c6
 
 
 
 
7f424d1
a2f39c6
107fcf0
a2f39c6
02976e0
a2f39c6
 
84031c5
a2f39c6
 
 
 
 
 
02976e0
 
 
 
 
32343cc
a2f39c6
 
 
 
 
 
 
7f424d1
 
02976e0
 
 
 
7f424d1
a2f39c6
7f424d1
02976e0
a2f39c6
7f424d1
a2f39c6
 
84031c5
 
a2f39c6
 
32343cc
a2f39c6
 
 
02976e0
84031c5
02976e0
84031c5
7f424d1
02976e0
a2f39c6
02976e0
a2f39c6
 
 
 
 
 
 
02976e0
 
 
 
a2f39c6
 
 
32343cc
a2f39c6
84031c5
1344c31
84031c5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# app.py
import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# ────────────────────────────────────────────────────────────────
# Hub repo of the base checkpoint. Its name ("bnb-4bit") suggests it is
# already bitsandbytes 4-bit quantized — confirm against the repo config.
BASE_MODEL: str = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
# Hub repo holding the LoRA adapter weights applied on top of BASE_MODEL.
LORA_PATH: str = "saadkhi/SQL_Chat_finetuned_model"

# Decoding settings consumed by the inference function below.
MAX_NEW_TOKENS: int = 180   # cap on tokens generated per reply
TEMPERATURE: float = 0.0    # only meaningful when sampling is enabled
DO_SAMPLE: bool = False     # False -> greedy decoding

print("Loading quantized base model on CPU...")
print("(GPU will be used only during inference if available)")
# NOTE(review): nothing in this block moves the model off the CPU; the
# message above states intent, not something this code does.

# bitsandbytes 4-bit setup: NF4 quant type, nested (double) quantization,
# and bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)


def _load_merged_model():
    """Load BASE_MODEL on CPU, attach the LoRA adapters and fold them in.

    Returns the model produced by ``merge_and_unload()`` (adapters merged,
    PEFT wrapper removed).
    """
    # NOTE(review): if BASE_MODEL already ships its own quantization config,
    # transformers may use that instead of bnb_config — confirm. 4-bit
    # bitsandbytes on CPU also needs a bitsandbytes build with CPU support.
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="cpu",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )

    print("Loading LoRA adapters...")
    adapted_model = PeftModel.from_pretrained(base_model, LORA_PATH)

    # Merging removes the per-call adapter overhead at inference time.
    # NOTE(review): merging into 4-bit weights may shift outputs slightly
    # versus the unmerged adapter (rounding) — verify quality.
    print("Merging LoRA into base model...")
    return adapted_model.merge_and_unload()


model = _load_merged_model()

# The tokenizer comes from the same repo as the base weights; the eos token
# doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token = tokenizer.eos_token

# Inference only: disable training-mode behaviour such as dropout.
model.eval()
# ────────────────────────────────────────────────────────────────

@spaces.GPU(duration=60, max_requests=20)  # safe values for ZeroGPU
def generate_sql(prompt: str) -> str:
    """Generate a SQL answer for a natural-language request.

    The prompt is wrapped as a single user turn, rendered with the
    tokenizer's chat template (with the assistant generation prompt
    appended), and passed to the module-level ``model``. Decoding uses
    ``MAX_NEW_TOKENS`` and ``DO_SAMPLE`` (greedy when ``DO_SAMPLE`` is False).

    Args:
        prompt: The user's question, e.g. "Top 5 highest paid employees".

    Returns:
        The model's reply only (the prompt is not echoed back), stripped of
        surrounding whitespace. Returns an empty string for an empty or
        whitespace-only prompt without running the model.
    """
    # Nothing to answer -> don't spend a (possibly GPU-quota'd) generate call.
    if not prompt or not prompt.strip():
        return ""

    messages = [{"role": "user", "content": prompt}]

    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # Some transformers versions return a BatchEncoding here rather than a
    # bare tensor — normalise to the input_ids tensor.
    input_ids = encoded if isinstance(encoded, torch.Tensor) else encoded["input_ids"]

    # Put the inputs wherever the model's weights actually are. Choosing
    # "cuda" just because it is available crashes with a device mismatch
    # when the model was loaded with device_map="cpu" and never moved.
    device = model.device
    print(f"→ Running inference on device: {device}")
    input_ids = input_ids.to(device)
    # A single unpadded sequence: every position is a real token. Passing the
    # mask explicitly avoids ambiguity because pad_token == eos_token.
    attention_mask = torch.ones_like(input_ids)

    # Phi-3 closes an assistant turn with <|end|>, which is not the
    # tokenizer's eos token; stop on either so the reply doesn't run on.
    stop_ids = [tokenizer.eos_token_id]
    end_id = tokenizer.convert_tokens_to_ids("<|end|>")
    if end_id is not None and end_id != tokenizer.unk_token_id and end_id not in stop_ids:
        stop_ids.append(end_id)

    gen_kwargs = {
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": DO_SAMPLE,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": stop_ids,
    }
    # temperature is only used when sampling; passing it to greedy decoding
    # just produces a transformers warning.
    if DO_SAMPLE:
        gen_kwargs["temperature"] = TEMPERATURE

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **gen_kwargs,
        )

    # generate() returns prompt + continuation; decode only the continuation.
    # (skip_special_tokens=True strips chat markers such as <|assistant|>, so
    # splitting the full decoded text on them would never remove the prompt.)
    new_tokens = outputs[0][input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Defensive trimming in case a tokenizer does not treat these markers as
    # special tokens and they survive decoding.
    if "<|assistant|>" in response:
        response = response.split("<|assistant|>", 1)[-1]
    if "<|end|>" in response:
        response = response.split("<|end|>", 1)[0]

    return response.strip()


# ────────────────────────────────────────────────────────────────
# Free-text question box the user types into.
_question_input = gr.Textbox(
    label="Ask a question about SQL",
    placeholder="Delete duplicate rows from users table based on email",
    lines=3,
)
# Read-only box that shows whatever generate_sql returns.
_sql_output = gr.Textbox(label="Generated SQL Query")

# Clickable sample prompts; each example row holds the single input value.
_EXAMPLE_PROMPTS = (
    "Find duplicate emails in users table",
    "Top 5 highest paid employees",
    "Count orders per customer last month",
    "Show all products that haven't been ordered in the last 6 months",
    "Update all orders from 2024 to status 'completed'",
)

# Single-function Gradio UI: question in, generated SQL out. Examples are not
# pre-computed (cache_examples=False), so nothing runs the model at startup.
demo = gr.Interface(
    fn=generate_sql,
    inputs=_question_input,
    outputs=_sql_output,
    title="SQL Chatbot – Phi-3-mini + LoRA",
    description=(
        "Fine-tuned Phi-3-mini-4k-instruct (4bit) for generating SQL queries\n\n"
        "Works on ZeroGPU and regular GPU hardware"
    ),
    examples=[[example] for example in _EXAMPLE_PROMPTS],
    cache_examples=False,
)


if __name__ == "__main__":
    demo.launch()