File size: 5,075 Bytes
4bc3e8b
 
c7c0d53
7f424d1
02976e0
7f424d1
 
00c8a57
0fad5f5
4bc3e8b
8b67be0
4bc3e8b
84031c5
4bc3e8b
02976e0
 
4bc3e8b
 
a2f39c6
8b67be0
4bc3e8b
8b67be0
4bc3e8b
 
8b67be0
 
 
 
 
 
 
 
 
4bc3e8b
 
 
 
8b67be0
 
4bc3e8b
8b67be0
4bc3e8b
22df2c5
02976e0
8b67be0
 
a2f39c6
4bc3e8b
8b67be0
 
 
a2f39c6
0fad5f5
4bc3e8b
8b67be0
4bc3e8b
 
8b67be0
4bc3e8b
22df2c5
8b67be0
 
 
 
 
7f424d1
22df2c5
8b67be0
 
4bc3e8b
 
 
 
 
 
8b67be0
02976e0
8b67be0
22df2c5
4bc3e8b
 
 
 
02976e0
4bc3e8b
84031c5
8b67be0
4bc3e8b
8b67be0
 
4bc3e8b
0fad5f5
4bc3e8b
84031c5
4bc3e8b
 
 
 
 
 
8b67be0
4bc3e8b
 
 
8b67be0
4bc3e8b
 
 
 
02976e0
4bc3e8b
02976e0
4bc3e8b
 
 
32343cc
4bc3e8b
84031c5
1344c31
84031c5
4bc3e8b
9ae2c39
4bc3e8b
 
 
 
 
 
bf06a99
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# app.py
# Minimal & stable version for free CPU Hugging Face Space – Phi-3-mini + LoRA

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# ────────────────────────────────────────────────────────────────
# Config
# ────────────────────────────────────────────────────────────────

# Hugging Face Hub id of the base model (the name indicates a bitsandbytes
# 4-bit checkpoint of Phi-3-mini-4k-instruct).
BASE_MODEL: str = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
# Hub id of the LoRA adapter that is loaded on top of BASE_MODEL and merged.
LORA_PATH: str = "saadkhi/SQL_Chat_finetuned_model"

# Generation settings consumed by generate_sql().
MAX_NEW_TOKENS: int = 180   # hard cap on tokens produced per answer
TEMPERATURE: float = 0.0    # only meaningful when sampling is enabled
DO_SAMPLE: bool = False     # False → deterministic greedy decoding

# ────────────────────────────────────────────────────────────────
# Load model & tokenizer
# ────────────────────────────────────────────────────────────────

def _load_model_and_tokenizer():
    """Build the inference model and its tokenizer.

    Loads BASE_MODEL with a 4-bit NF4 BitsAndBytes config mapped to the CPU,
    attaches the LoRA adapter from LORA_PATH, merges it into the base
    weights and switches the merged model to eval mode. Progress messages
    are printed between the stages.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    # NOTE(review): 4-bit bitsandbytes loading with device_map="cpu" depends on
    # the installed bitsandbytes build — confirm it is supported in the Space.
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=quant_cfg,
        device_map="cpu",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )

    print("Loading LoRA...")
    with_adapter = PeftModel.from_pretrained(base, LORA_PATH)
    print("Merging LoRA weights...")
    # NOTE(review): merging into quantized weights may introduce rounding
    # error; verify answer quality against the unmerged adapter if in doubt.
    merged = with_adapter.merge_and_unload()

    tok = AutoTokenizer.from_pretrained(BASE_MODEL)
    merged.eval()
    return merged, tok


print("Loading base model (CPU)...")
try:
    model, tokenizer = _load_model_and_tokenizer()
    print("Model & tokenizer loaded successfully")
except Exception as e:
    # Surface the reason in the logs, then abort startup by re-raising.
    print(f"Model loading failed: {str(e)}")
    raise

# ────────────────────────────────────────────────────────────────
# Inference function
# ────────────────────────────────────────────────────────────────

def _clean_phi3_output(text: str) -> str:
    """Remove Phi-3 chat-template markers that survived decoding and trim whitespace."""
    # Anything up to the last assistant marker is template/prompt residue.
    if "<|assistant|>" in text:
        text = text.rsplit("<|assistant|>", 1)[1]
    # <|end|> closes the assistant turn and <|user|> would start a new one,
    # so the answer is whatever comes BEFORE them.
    for stop_marker in ("<|end|>", "<|user|>"):
        text = text.split(stop_marker, 1)[0]
    return text.strip()


def generate_sql(question: str):
    """Answer a natural-language SQL *question* with the fine-tuned model.

    The question is wrapped as a single user turn via the tokenizer's chat
    template and passed to ``model.generate`` (at most MAX_NEW_TOKENS new
    tokens; sampling controlled by DO_SAMPLE). Only the newly generated
    tokens are decoded, so the question is not echoed back.

    Args:
        question: Text from the Gradio textbox. ``None`` or blank input is
            answered with a short hint instead of running the model.

    Returns:
        The cleaned model answer; ``"(empty response)"`` if nothing remains
        after cleaning; or ``"Generation error: ..."`` if tokenization or
        generation raised (errors are returned rather than raised so they
        appear in the UI output box).
    """
    question = (question or "").strip()
    if not question:
        return "(please enter a SQL question)"

    try:
        messages = [{"role": "user", "content": question}]

        # return_dict=True yields both input_ids and attention_mask, so the
        # code does not rely on the library's default return type.
        encoded = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        )
        input_ids = encoded["input_ids"].to(model.device)
        attention_mask = encoded["attention_mask"].to(model.device)

        gen_kwargs = {
            "max_new_tokens": MAX_NEW_TOKENS,
            "do_sample": DO_SAMPLE,
            "use_cache": True,
            "pad_token_id": tokenizer.eos_token_id,
        }
        # temperature is only used when sampling; passing it (especially 0.0)
        # together with greedy decoding just produces a transformers warning.
        if DO_SAMPLE:
            gen_kwargs["temperature"] = TEMPERATURE

        with torch.inference_mode():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                **gen_kwargs,
            )

        # generate() returns prompt + continuation for decoder-only models;
        # slice off the prompt. With skip_special_tokens=True any role markers
        # registered as special tokens are dropped, so they cannot be relied
        # on to locate the answer inside a full-sequence decode.
        new_tokens = outputs[0][input_ids.shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)

        return _clean_phi3_output(response) or "(empty response)"

    except Exception as e:
        return f"Generation error: {str(e)}"

# ────────────────────────────────────────────────────────────────
# Gradio UI
# ────────────────────────────────────────────────────────────────

# One textbox in, one textbox out, wired to generate_sql().
# cache_examples=False: the example questions are not pre-run at startup,
# so the (slow, CPU-only) model is only invoked on user request.
demo = gr.Interface(
    fn              = generate_sql,
    inputs          = gr.Textbox(
        label       = "SQL question",
        placeholder = "Find duplicate emails in users table",
        lines       = 3,
        max_lines   = 6
    ),
    outputs         = gr.Textbox(
        label       = "Generated SQL",
        lines       = 8
    ),
    title           = "SQL Chat – Phi-3-mini fine-tuned (CPU)",
    # The description is rendered as Markdown, where a single "\n" collapses
    # into a space; a blank line is needed to actually break the text.
    description     = (
        "Free CPU version – first answer usually takes 60–180+ seconds.\n\n"
        "Later answers are faster (model stays in memory)."
    ),
    examples        = [
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees"],
        ["Count orders per customer last month"],
        ["Delete duplicate rows based on email"]
    ],
    cache_examples  = False,
)

if __name__ == "__main__":
    print("Launching interface...")
    # launch() must block the main thread here. With prevent_thread_lock=True
    # it returns immediately and Gradio shuts the server down as soon as this
    # script finishes, so the flag is intentionally not passed.
    demo.launch(
        server_name       = "0.0.0.0",  # listen on all network interfaces
        # No fixed server_port → Gradio uses GRADIO_SERVER_PORT if set,
        # otherwise the first free port starting at 7860.
        debug             = False,
        quiet             = False,
        show_error        = True,
    )