saadkhi committed
Commit a2f39c6 · verified
1 Parent(s): 32343cc

Update app.py

Files changed (1)
  1. app.py +59 -33
app.py CHANGED
@@ -1,93 +1,119 @@
-# app.py - ZeroGPU safe version (no .to("cuda") outside decorated fn + no caching)
-
+# app.py
 import torch
 import gradio as gr
-import spaces  # Correct import
+import spaces
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 from peft import PeftModel
 
 # ────────────────────────────────────────────────────────────────
 BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
-LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"
+LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"
 
 MAX_NEW_TOKENS = 180
-TEMPERATURE = 0.0
-DO_SAMPLE = False
-
-print("Loading quantized base model on CPU (GPU only in @spaces.GPU)...")
+TEMPERATURE = 0.0
+DO_SAMPLE = False
+
+print("Loading quantized base model on CPU...")
+print("(GPU will be used only during inference if available)")
+
+# 4-bit quantization config
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_quant_type="nf4",
-    bnb_4bit_compute_dtype=torch.bfloat16
+    bnb_4bit_compute_dtype=torch.bfloat16,
+    bnb_4bit_use_double_quant=True,
 )
 
+# Load base model → always on CPU first
 model = AutoModelForCausalLM.from_pretrained(
     BASE_MODEL,
     quantization_config=bnb_config,
-    device_map="cpu",  # ← Force CPU at load time (required for ZeroGPU)
-    trust_remote_code=True
+    device_map="cpu",
+    trust_remote_code=True,
+    torch_dtype=torch.bfloat16,
 )
 
-print("Loading LoRA...")
+print("Loading LoRA adapters...")
 model = PeftModel.from_pretrained(model, LORA_PATH)
-model = model.merge_and_unload()  # Merge for speed
 
+# Merge for faster inference (recommended)
+print("Merging LoRA into base model...")
+model = model.merge_and_unload()
+
+# Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
-model.eval()
+tokenizer.pad_token = tokenizer.eos_token
 
+model.eval()
 # ────────────────────────────────────────────────────────────────
-@spaces.GPU(duration=60)  # 60s max is safe & gives good queue priority
+
+@spaces.GPU(duration=60, max_requests=20)  # safe values for ZeroGPU
 def generate_sql(prompt: str):
-    messages = [{"role": "user", "content": prompt}]
-
-    # Tokenize on CPU first
+    # Prepare chat format
+    messages = [
+        {"role": "user", "content": prompt}
+    ]
+
+    # Tokenize on CPU (safe everywhere)
     inputs = tokenizer.apply_chat_template(
         messages,
         tokenize=True,
         add_generation_prompt=True,
         return_tensors="pt"
     )
-
-    # Move to CUDA ONLY inside here (GPU is now allocated)
-    inputs = inputs.to("cuda")
-
+
+    # Choose device dynamically - this is the ZeroGPU-safe way
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"→ Running inference on device: {device}")
+
+    inputs = inputs.to(device)
+
     with torch.inference_mode():
         outputs = model.generate(
             input_ids=inputs,
             max_new_tokens=MAX_NEW_TOKENS,
             temperature=TEMPERATURE,
             do_sample=DO_SAMPLE,
-            use_cache=True,
             pad_token_id=tokenizer.eos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
         )
 
+    # Decode and clean output
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    # Clean output
+
+    # Remove user's prompt + assistant tag if present
     if "<|assistant|>" in response:
         response = response.split("<|assistant|>", 1)[-1].strip()
+
+    # Cut at end token if it exists
     if "<|end|>" in response:
-        response = response.split("<|end|>")[0].strip()
-
-    return response
+        response = response.split("<|end|>", 1)[0].strip()
+
+    return response.strip()
 
 # ────────────────────────────────────────────────────────────────
 demo = gr.Interface(
     fn=generate_sql,
     inputs=gr.Textbox(
-        label="Ask SQL question",
+        label="Ask a question about SQL",
         placeholder="Delete duplicate rows from users table based on email",
-        lines=3
+        lines=3,
     ),
-    outputs=gr.Textbox(label="Generated SQL"),
-    title="SQL Chatbot (ZeroGPU Safe)",
-    description="Phi-3-mini 4bit + LoRA - GPU allocated only during generation",
+    outputs=gr.Textbox(label="Generated SQL Query"),
+    title="SQL Chatbot – Phi-3-mini + LoRA",
+    description=(
+        "Fine-tuned Phi-3-mini-4k-instruct (4bit) for generating SQL queries\n\n"
+        "Works on ZeroGPU and regular GPU hardware"
+    ),
     examples=[
         ["Find duplicate emails in users table"],
         ["Top 5 highest paid employees"],
-        ["Count orders per customer last month"]
+        ["Count orders per customer last month"],
+        ["Show all products that haven't been ordered in the last 6 months"],
+        ["Update all orders from 2024 to status 'completed'"],
     ],
-    cache_examples=False  # ← CRITICAL: Disable caching to avoid startup .to("cuda") call
+    cache_examples=False,
 )
 
 if __name__ == "__main__":
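
The pattern this commit lands on is the standard ZeroGPU one: do all heavy loading on CPU at import time, and touch CUDA only inside the @spaces.GPU-decorated function, where a GPU has actually been attached. A minimal self-contained sketch of that pattern, with a hypothetical double function standing in for generate_sql:

import torch
import spaces

# Created on CPU at import time - no CUDA calls are allowed out here.
weights = torch.ones(4)

@spaces.GPU(duration=60)  # a GPU is attached only while this function runs
def double(xs):
    # Inside the decorated function, torch.cuda.is_available() is True on
    # ZeroGPU; on CPU-only hardware this falls back cleanly.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    t = torch.tensor(xs, device=device)
    return (t * 2).cpu().tolist()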
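The cleanup at the end of generate_sql tracks Phi-3's chat markup, in which turns are delimited by <|user|>, <|assistant|>, and <|end|> tags. The same two splits, run on a hard-coded sample string invented for illustration rather than real model output:

# Standalone sketch of the output cleanup on a made-up decoded string.
sample = (
    "<|user|> Find duplicate emails in users table <|end|> "
    "<|assistant|> SELECT email FROM users GROUP BY email "
    "HAVING COUNT(*) > 1; <|end|>"
)
response = sample
if "<|assistant|>" in response:
    # Keep only what follows the assistant tag (drops the echoed prompt).
    response = response.split("<|assistant|>", 1)[-1].strip()
if "<|end|>" in response:
    # Truncate at the first end-of-turn marker.
    response = response.split("<|end|>", 1)[0].strip()
print(response)  # SELECT email FROM users GROUP BY email HAVING COUNT(*) > 1;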
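Since the app is a plain gr.Interface, the deployed Space can also be called programmatically with gradio_client; the Space id below is a placeholder, not something stated in this commit:

from gradio_client import Client

client = Client("saadkhi/SQL-Chatbot")  # hypothetical Space id - substitute the real one
sql = client.predict(
    "Find duplicate emails in users table",  # value for the prompt Textbox
    api_name="/predict",                     # default endpoint name for gr.Interface
)
print(sql)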