saadkhi commited on
Commit
bbdf923
·
verified ·
1 Parent(s): 7697572

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -42
app.py CHANGED
@@ -1,36 +1,37 @@
1
- # app.py
 
2
  import torch
3
  import gradio as gr
4
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
5
  from peft import PeftModel
6
 
7
# ────────────────────────────────────────────────────────────────
# Configuration - fastest practical settings
# ────────────────────────────────────────────────────────────────

BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"

MAX_NEW_TOKENS = 180  # keep generation length reasonable
TEMPERATURE = 0.0     # greedy = fastest & most deterministic
DO_SAMPLE = False     # no sampling = faster

# ────────────────────────────────────────────────────────────────
# 4-bit quantization config (this is the key speedup)
# ────────────────────────────────────────────────────────────────

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # "nf4" usually fastest + good quality
    bnb_4bit_use_double_quant=True,         # nested quantization → extra memory saving
    bnb_4bit_compute_dtype=torch.bfloat16,  # fast compute dtype on modern GPUs
)

print("Loading quantized base model...")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",  # auto = best available (cuda > cpu)
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

print("Loading LoRA adapters...")
model = PeftModel.from_pretrained(model, LORA_PATH)

# Merge LoRA weights into base (faster inference, less adapter overhead).
model = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Optional: enable flash scaled-dot-product attention on supported hardware.
# FIX: the original used a bare `except:`, which also swallows
# KeyboardInterrupt/SystemExit — narrowed to `except Exception`.
if torch.cuda.is_available():
    try:
        import torch.backends.cuda
        torch.backends.cuda.enable_flash_sdp(True)
    except Exception:
        pass  # best-effort: unsupported hardware/older torch just skips it

model.eval()
print("Model ready!")
56
 
57
  # ────────────────────────────────────────────────────────────────
58
def generate_sql(prompt: str):
    """Generate a SQL answer for *prompt* with the merged Phi-3 model.

    Args:
        prompt: The user's natural-language question.

    Returns:
        The assistant's decoded text, with the prompt and special tokens
        stripped.
    """
    # Use proper chat template (Phi-3 expects it).
    messages = [{"role": "user", "content": prompt}]

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            # FIX: only pass temperature when sampling — recent transformers
            # versions warn or raise on temperature=0.0 with greedy decoding.
            **({"temperature": TEMPERATURE} if DO_SAMPLE else {}),
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # FIX: decode only the newly generated tokens instead of the whole
    # sequence — more robust than string-splitting on "<|assistant|>".
    new_tokens = outputs[0][inputs.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # Defensive cleanup in case template markers survive decoding.
    if "<|assistant|>" in response:
        response = response.split("<|assistant|>", 1)[-1].strip()
    if "<|end|>" in response:
        response = response.split("<|end|>")[0].strip()

    return response
88
 
89
  # ────────────────────────────────────────────────────────────────
 
 
 
90
# Gradio UI wiring: one textbox in, one textbox out.
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        label="Ask SQL related question",
        placeholder="Show me all employees with salary > 50000...",
        lines=3,
    ),
    outputs=gr.Textbox(label="Generated SQL / Answer"),
    title="SQL Chatbot - Fast Version",
    description="Phi-3-mini 4bit quantized + LoRA",
    examples=[
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees"],
        ["Count orders per customer last month"],
    ],
    # FIX: `allow_flagging="never"` was removed in Gradio 4+ and raises
    # TypeError there; flagging is disabled by default, so it is omitted.
)
107
 
108
  if __name__ == "__main__":
 
1
+ # app.py - Fixed for recent Gradio versions (no allow_flagging)
2
+
3
  import torch
4
  import gradio as gr
5
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
6
  from peft import PeftModel
7
 
8
# ────────────────────────────────────────────────────────────────
# Fastest practical configuration
# ────────────────────────────────────────────────────────────────

BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"

MAX_NEW_TOKENS = 180
TEMPERATURE = 0.0  # greedy = fastest
DO_SAMPLE = False

# ────────────────────────────────────────────────────────────────
# 4-bit quantization (very important for speed)
# ────────────────────────────────────────────────────────────────

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

print("Loading quantized base model...")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

print("Loading LoRA adapters...")
model = PeftModel.from_pretrained(model, LORA_PATH)

# Merge LoRA into base model → much faster inference
model = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

model.eval()
print("Model ready!")
49
 
50
  # ────────────────────────────────────────────────────────────────
51
def generate_sql(prompt: str):
    """Generate a SQL answer for *prompt* with the merged Phi-3 model.

    Args:
        prompt: The user's natural-language question.

    Returns:
        The assistant's decoded text, with the prompt and template
        markers stripped.
    """
    messages = [{"role": "user", "content": prompt}]

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            # FIX: only pass temperature when sampling — recent transformers
            # versions warn or raise on temperature=0.0 with greedy decoding.
            **({"temperature": TEMPERATURE} if DO_SAMPLE else {}),
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # FIX: decode only the generated continuation (skip the prompt tokens)
    # instead of decoding the full sequence and string-splitting.
    response = tokenizer.decode(
        outputs[0][inputs.shape[-1]:], skip_special_tokens=True
    ).strip()

    # Clean output (defensive, in case template markers survive decoding)
    if "<|assistant|>" in response:
        response = response.split("<|assistant|>", 1)[-1].strip()
    response = response.split("<|end|>")[0].strip() if "<|end|>" in response else response

    return response
79
 
80
  # ────────────────────────────────────────────────────────────────
81
+ # Gradio interface - modern style (no allow_flagging)
82
+ # ────────────────────────────────────────────────────────────────
83
+
84
# Gradio UI: single question textbox in, single answer textbox out.
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        label="Ask SQL related question",
        placeholder="Show me all employees with salary > 50000...",
        lines=3,
    ),
    outputs=gr.Textbox(label="Generated SQL / Answer"),
    title="SQL Chatbot - Optimized",
    description="Phi-3-mini 4bit + LoRA merged",
    examples=[
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees"],
        ["Count orders per customer last month"],
    ],
    # Flagging is disabled by default in newer Gradio versions, so
    # `allow_flagging` is intentionally not passed.
)
101
 
102
  if __name__ == "__main__":