saadkhi committed on
Commit
c7c0d53
Β·
verified Β·
1 Parent(s): 2af8046

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -52
app.py CHANGED
@@ -1,11 +1,11 @@
1
- # app.py
 
2
  import torch
3
  import gradio as gr
4
- import spaces
5
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
6
  from peft import PeftModel
7
 
8
- # ────────────────────────────────────────────────────────────────
9
  BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
10
  LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"
11
 
@@ -13,48 +13,37 @@ MAX_NEW_TOKENS = 180
13
  TEMPERATURE = 0.0
14
  DO_SAMPLE = False
15
 
16
- print("Loading quantized base model on CPU...")
17
- print("(GPU will be used only during inference if available)")
18
 
19
- # 4-bit quantization config
20
  bnb_config = BitsAndBytesConfig(
21
  load_in_4bit=True,
22
  bnb_4bit_quant_type="nf4",
23
- bnb_4bit_compute_dtype=torch.bfloat16,
24
- bnb_4bit_use_double_quant=True,
25
  )
26
 
27
- # Load base model β†’ always on CPU first
28
  model = AutoModelForCausalLM.from_pretrained(
29
  BASE_MODEL,
30
  quantization_config=bnb_config,
31
  device_map="cpu",
32
- trust_remote_code=True,
33
- torch_dtype=torch.bfloat16,
34
  )
35
 
36
- print("Loading LoRA adapters...")
37
  model = PeftModel.from_pretrained(model, LORA_PATH)
38
 
39
- # Merge for faster inference (very recommended)
40
- print("Merging LoRA into base model...")
41
  model = model.merge_and_unload()
42
 
43
- # Load tokenizer
44
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
45
- tokenizer.pad_token = tokenizer.eos_token
46
-
47
  model.eval()
48
- # ────────────────────────────────────────────────────────────────
49
 
50
- @spaces.GPU(duration=60, max_requests=20) # safe values for ZeroGPU
51
  def generate_sql(prompt: str):
52
- # Prepare chat format
53
- messages = [
54
- {"role": "user", "content": prompt}
55
- ]
56
 
57
- # Tokenize on CPU (safe everywhere)
58
  inputs = tokenizer.apply_chat_template(
59
  messages,
60
  tokenize=True,
@@ -62,59 +51,44 @@ def generate_sql(prompt: str):
62
  return_tensors="pt"
63
  )
64
 
65
- # Choose device dynamically - this is the ZeroGPU-safe way
66
- device = "cuda" if torch.cuda.is_available() else "cpu"
67
- print(f"β†’ Running inference on device: {device}")
68
-
69
- inputs = inputs.to(device)
70
-
71
  with torch.inference_mode():
72
  outputs = model.generate(
73
  input_ids=inputs,
74
  max_new_tokens=MAX_NEW_TOKENS,
75
  temperature=TEMPERATURE,
76
  do_sample=DO_SAMPLE,
 
77
  pad_token_id=tokenizer.eos_token_id,
78
- eos_token_id=tokenizer.eos_token_id,
79
  )
80
 
81
- # Decode and clean output
82
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
83
 
84
- # Remove user's prompt + assistant tag if present
85
  if "<|assistant|>" in response:
86
  response = response.split("<|assistant|>", 1)[-1].strip()
87
-
88
- # Cut at end token if exists
89
  if "<|end|>" in response:
90
- response = response.split("<|end|>", 1)[0].strip()
91
-
92
- return response.strip()
93
 
 
94
 
95
- # ────────────────────────────────────────────────────────────────
96
  demo = gr.Interface(
97
  fn=generate_sql,
98
  inputs=gr.Textbox(
99
- label="Ask a question about SQL",
100
  placeholder="Delete duplicate rows from users table based on email",
101
- lines=3,
102
- ),
103
- outputs=gr.Textbox(label="Generated SQL Query"),
104
- title="SQL Chatbot – Phi-3-mini + LoRA",
105
- description=(
106
- "Fine-tuned Phi-3-mini-4k-instruct (4bit) for generating SQL queries\n\n"
107
- "Works on ZeroGPU and regular GPU hardware"
108
  ),
 
 
 
109
  examples=[
110
  ["Find duplicate emails in users table"],
111
  ["Top 5 highest paid employees"],
112
- ["Count orders per customer last month"],
113
- ["Show all products that haven't been ordered in the last 6 months"],
114
- ["Update all orders from 2024 to status 'completed'"],
115
  ],
116
- cache_examples=False,
117
  )
118
 
119
  if __name__ == "__main__":
120
- demo.launch()
 
1
# app.py - CPU SAFE VERSION (No CUDA, No GPU)
#
# Loads the 4-bit quantized Phi-3-mini base model on CPU, applies the
# fine-tuned SQL LoRA adapter, and merges it so inference runs against a
# single plain model.

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# ─────────────────────────────────────────────
BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"

MAX_NEW_TOKENS = 180  # generation budget; referenced by generate_sql below
TEMPERATURE = 0.0
DO_SAMPLE = False     # greedy decoding -> deterministic SQL output

print("Loading model on CPU...")

# 4-bit config (works on CPU but slower)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load base model on CPU
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="cpu",
    trust_remote_code=True,
)

print("Loading LoRA...")
model = PeftModel.from_pretrained(model, LORA_PATH)

# Merge LoRA for simpler inference
model = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Phi-3 tokenizers may ship without a pad token; fall back to EOS so
# generate() can pad safely (the previous revision set this explicitly).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model.eval()
 
41
 
42
# ─────────────────────────────────────────────
def generate_sql(prompt: str):
    """Generate a SQL query for a natural-language request.

    Args:
        prompt: The user's question, e.g. "Find duplicate emails in users table".

    Returns:
        The decoded model answer with the Phi-3 chat-template markers
        (``<|assistant|>``, ``<|end|>``) stripped.
    """
    messages = [{"role": "user", "content": prompt}]

    # Tokenize (CPU). add_generation_prompt makes the template end with the
    # <|assistant|> tag so the model answers instead of continuing the user
    # turn — the cleanup below relies on that tag being present.
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )

    # Only forward temperature when sampling: transformers warns (and newer
    # versions reject) temperature=0.0 combined with do_sample=False.
    sampling_kwargs = {"temperature": TEMPERATURE} if DO_SAMPLE else {}

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            **sampling_kwargs,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Cleanup: keep only the assistant's turn, cut at the end-of-turn tag.
    if "<|assistant|>" in response:
        response = response.split("<|assistant|>", 1)[-1].strip()
    if "<|end|>" in response:
        response = response.split("<|end|>")[0].strip()

    return response
73
 
74
# ─────────────────────────────────────────────
# Gradio UI: a single textbox in, a single textbox out.
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        label="Ask SQL question",
        placeholder="Delete duplicate rows from users table based on email",
        lines=3,
    ),
    outputs=gr.Textbox(label="Generated SQL"),
    title="SQL Chatbot (CPU Mode)",
    description="Phi-3-mini 4bit + LoRA (CPU only, slower inference)",
    examples=[
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees"],
        ["Count orders per customer last month"],
    ],
    # Each cached example would trigger a slow CPU generation at startup.
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()