saadkhi committed on
Commit 24f8f89 · verified · 1 Parent(s): cdd8e55

Update app.py

Files changed (1)
  1. app.py +79 -25
app.py CHANGED
@@ -1,40 +1,94 @@
-import torch
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 from peft import PeftModel
-from transformers import BitsAndBytesConfig
+import torch
+import os
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
+# Create offload folder (needed when weights are offloaded to disk)
+OFFLOAD_DIR = "offload"
+os.makedirs(OFFLOAD_DIR, exist_ok=True)
 
-base_model = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
-finetuned_model = "saadkhi/SQL_Chat_finetuned_model"
+# 4-bit NF4 quantization config
+quant_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16,
+    bnb_4bit_use_double_quant=True,
+)
 
-tokenizer = AutoTokenizer.from_pretrained(base_model)
+# Model paths
+base_model_name = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
+lora_model_name = "saadkhi/SQL_Chat_finetuned_model"
 
-bnb = BitsAndBytesConfig(load_in_4bit=True)
+print("Loading base model...")
+base_model = AutoModelForCausalLM.from_pretrained(
+    base_model_name,
+    quantization_config=quant_config,
+    device_map="auto",
+    trust_remote_code=True,
+    offload_folder=OFFLOAD_DIR,  # ← required when layers are offloaded to disk
+)
 
-model = AutoModelForCausalLM.from_pretrained(
+print("Loading LoRA adapter...")
+model = PeftModel.from_pretrained(
     base_model,
-    quantization_config=bnb,
-    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
-    device_map="auto"
+    lora_model_name,
+    offload_folder=OFFLOAD_DIR,  # required here as well
 )
 
-model = PeftModel.from_pretrained(model, finetuned_model).to(device)
+tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
+
 model.eval()
+print("Model loaded successfully!")
+
+def chat(message, history):
+    # Build the conversation in the Phi-3 chat format
+    messages = []
+    for user, assistant in history:
+        messages.append({"role": "user", "content": user})
+        if assistant:
+            messages.append({"role": "assistant", "content": assistant})
+    messages.append({"role": "user", "content": message})
+
+    inputs = tokenizer.apply_chat_template(
+        messages,
+        tokenize=True,
+        add_generation_prompt=True,
+        return_tensors="pt"
+    ).to(model.device)
+
+    # Generation settings
+    outputs = model.generate(
+        inputs,
+        max_new_tokens=256,
+        temperature=0.7,
+        do_sample=True,
+        top_p=0.9,
+        repetition_penalty=1.1,
+        use_cache=True,
+        eos_token_id=tokenizer.eos_token_id,
+    )
 
-def chat(prompt):
-    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+    response = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
+
+    history.append((message, response))
+    return history, ""
 
-    with torch.inference_mode():
-        output = model.generate(
-            **inputs,
-            max_new_tokens=60,
-            temperature=0.8,
-            do_sample=False
-        )
+# Gradio UI
+with gr.Blocks(title="SQL Chatbot", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# SQL Chat Assistant")
+    gr.Markdown("Fine-tuned Phi-3 Mini (4-bit) for SQL queries. Responses ~3–10s on GPU.")
+
+    chatbot = gr.Chatbot(height=500)
+    msg = gr.Textbox(
+        label="Your Question",
+        placeholder="e.g., delete duplicate rows from users table based on email",
+        lines=2
+    )
+    clear = gr.Button("Clear")
 
-    return tokenizer.decode(output[0], skip_special_tokens=True)
+    msg.submit(chat, [msg, chatbot], [chatbot, msg])
+    clear.click(lambda: ([], ""), None, [chatbot, msg])
 
-iface = gr.Interface(fn=chat, inputs="text", outputs="text", title="SQL Chatbot")
-iface.launch()
+demo.queue(max_size=30)
+demo.launch()
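
For quick verification of the new prompt-building logic, the sketch below (not part of the commit) reproduces the history-to-messages conversion from chat() and prints the Phi-3 prompt string that apply_chat_template produces. It loads only the tokenizer, so it runs without a GPU; the sample history and question are invented for illustration.

# Standalone sketch of the history -> messages conversion used in chat() above.
# Only the tokenizer is loaded; the example turns are made up.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "unsloth/Phi-3-mini-4k-instruct-bnb-4bit", trust_remote_code=True
)

history = [("show all users", "SELECT * FROM users;")]  # (user, assistant) pairs
message = "now only users created in 2024"

messages = []
for user, assistant in history:
    messages.append({"role": "user", "content": user})
    if assistant:
        messages.append({"role": "assistant", "content": assistant})
messages.append({"role": "user", "content": message})

# tokenize=False returns the raw prompt string instead of token ids,
# making it easy to inspect the chat template the model will actually see.
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))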