saadkhi committed on
Commit
32343cc
·
verified ·
1 Parent(s): 02976e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -12
app.py CHANGED
@@ -1,8 +1,8 @@
1
- # app.py - ZeroGPU compatible version (standard transformers + @spaces.GPU)
2
 
3
  import torch
4
  import gradio as gr
5
- import spaces # ← Correct import!
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
7
  from peft import PeftModel
8
 
@@ -14,7 +14,7 @@ MAX_NEW_TOKENS = 180
14
  TEMPERATURE = 0.0
15
  DO_SAMPLE = False
16
 
17
- print("Loading quantized base model (CPU first)...")
18
  bnb_config = BitsAndBytesConfig(
19
  load_in_4bit=True,
20
  bnb_4bit_quant_type="nf4",
@@ -24,29 +24,33 @@ bnb_config = BitsAndBytesConfig(
24
  model = AutoModelForCausalLM.from_pretrained(
25
  BASE_MODEL,
26
  quantization_config=bnb_config,
27
- device_map="auto",
28
  trust_remote_code=True
29
  )
30
 
31
  print("Loading LoRA...")
32
  model = PeftModel.from_pretrained(model, LORA_PATH)
33
- model = model.merge_and_unload() # Merge for faster inference
34
 
35
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
36
  model.eval()
37
 
38
  # ────────────────────────────────────────────────────────────────
39
- @spaces.GPU(duration=60) # ← Decorator! Requests GPU slice only here (60s max recommended)
40
  def generate_sql(prompt: str):
41
  messages = [{"role": "user", "content": prompt}]
42
 
 
43
  inputs = tokenizer.apply_chat_template(
44
  messages,
45
  tokenize=True,
46
  add_generation_prompt=True,
47
  return_tensors="pt"
48
- ).to("cuda") # ZeroGPU makes cuda available here
49
-
 
 
 
50
  with torch.inference_mode():
51
  outputs = model.generate(
52
  input_ids=inputs,
@@ -62,7 +66,8 @@ def generate_sql(prompt: str):
62
  # Clean output
63
  if "<|assistant|>" in response:
64
  response = response.split("<|assistant|>", 1)[-1].strip()
65
- response = response.split("<|end|>")[0].strip() if "<|end|>" in response else response
 
66
 
67
  return response
68
 
@@ -75,13 +80,14 @@ demo = gr.Interface(
75
  lines=3
76
  ),
77
  outputs=gr.Textbox(label="Generated SQL"),
78
- title="SQL Chatbot (ZeroGPU)",
79
- description="Phi-3-mini 4bit + LoRA - Free but limited daily GPU time",
80
  examples=[
81
  ["Find duplicate emails in users table"],
82
  ["Top 5 highest paid employees"],
83
  ["Count orders per customer last month"]
84
- ]
 
85
  )
86
 
87
  if __name__ == "__main__":
 
1
+ # app.py - ZeroGPU safe version (no .to("cuda") outside decorated fn + no caching)
2
 
3
  import torch
4
  import gradio as gr
5
+ import spaces # Correct import
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
7
  from peft import PeftModel
8
 
 
14
  TEMPERATURE = 0.0
15
  DO_SAMPLE = False
16
 
17
+ print("Loading quantized base model on CPU (GPU only in @spaces.GPU)...")
18
  bnb_config = BitsAndBytesConfig(
19
  load_in_4bit=True,
20
  bnb_4bit_quant_type="nf4",
 
24
  model = AutoModelForCausalLM.from_pretrained(
25
  BASE_MODEL,
26
  quantization_config=bnb_config,
27
+ device_map="cpu", # ← Force CPU at load time (required for ZeroGPU)
28
  trust_remote_code=True
29
  )
30
 
31
  print("Loading LoRA...")
32
  model = PeftModel.from_pretrained(model, LORA_PATH)
33
+ model = model.merge_and_unload() # Merge for speed
34
 
35
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
36
  model.eval()
37
 
38
  # ────────────────────────────────────────────────────────────────
39
+ @spaces.GPU(duration=60) # 60s max is safe & gives good queue priority
40
  def generate_sql(prompt: str):
41
  messages = [{"role": "user", "content": prompt}]
42
 
43
+ # Tokenize on CPU first
44
  inputs = tokenizer.apply_chat_template(
45
  messages,
46
  tokenize=True,
47
  add_generation_prompt=True,
48
  return_tensors="pt"
49
+ )
50
+
51
+ # Move to CUDA ONLY inside here (GPU is now allocated)
52
+ inputs = inputs.to("cuda")
53
+
54
  with torch.inference_mode():
55
  outputs = model.generate(
56
  input_ids=inputs,
 
66
  # Clean output
67
  if "<|assistant|>" in response:
68
  response = response.split("<|assistant|>", 1)[-1].strip()
69
+ if "<|end|>" in response:
70
+ response = response.split("<|end|>")[0].strip()
71
 
72
  return response
73
 
 
80
  lines=3
81
  ),
82
  outputs=gr.Textbox(label="Generated SQL"),
83
+ title="SQL Chatbot (ZeroGPU Safe)",
84
+ description="Phi-3-mini 4bit + LoRA - GPU allocated only during generation",
85
  examples=[
86
  ["Find duplicate emails in users table"],
87
  ["Top 5 highest paid employees"],
88
  ["Count orders per customer last month"]
89
+ ],
90
+ cache_examples=False # ← CRITICAL: Disable caching to avoid startup .to("cuda") call
91
  )
92
 
93
  if __name__ == "__main__":