File size: 4,168 Bytes
bbdf923
 
24f8f89
87ff5b4
00c8a57
 
1344c31
 
bbdf923
1344c31
00c8a57
1344c31
00c8a57
 
bbdf923
 
 
1344c31
 
bbdf923
00c8a57
 
 
 
bbdf923
 
 
52ae0ac
1344c31
00c8a57
 
 
 
bbdf923
00c8a57
 
1344c31
 
00c8a57
 
 
bbdf923
00c8a57
 
 
1344c31
00c8a57
1344c31
 
 
 
 
 
 
 
bbdf923
 
 
00c8a57
 
 
 
bbdf923
 
 
 
 
 
00c8a57
1344c31
 
00c8a57
bbdf923
1344c31
 
bbdf923
1344c31
 
 
 
bbdf923
 
 
1344c31
bbdf923
 
 
 
 
1344c31
bbdf923
 
 
 
00c8a57
 
 
1344c31
bbdf923
1344c31
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# app.py - Fixed for recent Gradio versions (no allow_flagging)
# Loads a 4-bit quantized Phi-3-mini base model, applies (then merges)
# LoRA adapters, and exposes a single-turn SQL chat UI via Gradio.

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# ────────────────────────────────────────────────────────────────
#   Fastest practical configuration
# ────────────────────────────────────────────────────────────────

# Base checkpoint is already pre-quantized to bnb 4-bit by unsloth;
# the adapter repo holds only the LoRA weights.
BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
LORA_PATH  = "saadkhi/SQL_Chat_finetuned_model"

# Generation settings: greedy decoding (do_sample=False) for speed and
# determinism; MAX_NEW_TOKENS caps latency per request.
MAX_NEW_TOKENS = 180
TEMPERATURE    = 0.0          # greedy = fastest
DO_SAMPLE      = False

# ────────────────────────────────────────────────────────────────
#   4-bit quantization (very important for speed)
# ────────────────────────────────────────────────────────────────

# NF4 + double quantization minimizes VRAM; bfloat16 compute keeps
# matmuls fast on Ampere+ GPUs.
bnb_config = BitsAndBytesConfig(
    load_in_4bit              = True,
    bnb_4bit_quant_type       = "nf4",
    bnb_4bit_use_double_quant = True,
    bnb_4bit_compute_dtype    = torch.bfloat16
)

print("Loading quantized base model...")
# device_map="auto" lets accelerate place layers on the available GPU(s)/CPU.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config = bnb_config,
    device_map          = "auto",
    trust_remote_code   = True,
    torch_dtype         = torch.bfloat16
)

print("Loading LoRA adapters...")
model = PeftModel.from_pretrained(model, LORA_PATH)

# Merge LoRA into base model β†’ much faster inference
# (removes the adapter indirection on every forward pass).
model = model.merge_and_unload()

# Tokenizer comes from the base model; the adapter did not change the vocab.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

model.eval()
print("Model ready!")

# ────────────────────────────────────────────────────────────────
def generate_sql(prompt: str) -> str:
    """Generate a SQL answer for *prompt* with the merged model.

    Args:
        prompt: Natural-language question typed into the Gradio textbox.

    Returns:
        The decoded model reply with Phi-3 chat-template markers stripped.
    """
    messages = [{"role": "user", "content": prompt}]

    # return_dict=True also yields the attention mask; without it,
    # generate() has to guess the mask and emits a warning (and can
    # mis-mask when pad_token_id == eos_token_id, as we set below).
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    # Only pass temperature when sampling: recent transformers versions
    # warn on — or reject — temperature=0.0 combined with do_sample=False.
    gen_kwargs = {
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample":      DO_SAMPLE,
        "use_cache":      True,
        "pad_token_id":   tokenizer.eos_token_id,
    }
    if DO_SAMPLE:
        gen_kwargs["temperature"] = TEMPERATURE

    with torch.inference_mode():
        outputs = model.generate(
            input_ids      = encoded["input_ids"],
            attention_mask = encoded["attention_mask"],
            **gen_kwargs,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Clean output: drop everything before the assistant turn and
    # anything after the end-of-turn marker, if they survived decoding.
    if "<|assistant|>" in response:
        response = response.split("<|assistant|>", 1)[-1].strip()
    if "<|end|>" in response:
        response = response.split("<|end|>")[0].strip()

    return response

# ────────────────────────────────────────────────────────────────
#   Gradio interface - modern style (no allow_flagging)
# ────────────────────────────────────────────────────────────────

# Single-function UI: one text input β†’ one text output, with a few
# clickable example prompts.
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        label="Ask SQL related question",
        placeholder="Show me all employees with salary > 50000...",
        lines=3
    ),
    outputs=gr.Textbox(label="Generated SQL / Answer"),
    title="SQL Chatbot - Optimized",
    description="Phi-3-mini 4bit + LoRA merged",
    examples=[
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees"],
        ["Count orders per customer last month"]
    ],
    # flag button is disabled by default in newer versions β†’ no need for allow_flagging
)

# Launch the local web server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()