saadkhi committed on
Commit
02976e0
·
verified ·
1 Parent(s): 806622f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -18
app.py CHANGED
@@ -1,53 +1,87 @@
1
# app.py - ZeroGPU compatible version (NO Unsloth)

import gradio as gr
import torch
# FIX: the ZeroGPU decorator lives in the standalone `spaces` package;
# `from huggingface_hub import spaces` raises ImportError.
import spaces

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# Your model paths
BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"

print("Loading model on CPU first... (will use GPU only during @spaces.GPU)")
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)
model = PeftModel.from_pretrained(model, LORA_PATH)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model.eval()


@spaces.GPU  # requests a GPU slice only while this function runs
def generate_sql(prompt: str):
    """Generate a SQL answer for *prompt* with the LoRA-tuned Phi-3 model.

    Greedy decoding (do_sample=False); the chat-template markers are
    stripped from the decoded text before returning.
    """
    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to("cuda")  # CUDA is available inside the @spaces.GPU call

    with torch.inference_mode():
        outputs = model.generate(
            inputs,
            max_new_tokens=180,
            # FIX: temperature removed — it is ignored when do_sample=False
            # and triggers a transformers warning.
            do_sample=False,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the assistant turn, then cut at the end-of-turn marker
    # (str.split is a no-op when the marker is absent).
    if "<|assistant|>" in response:
        response = response.split("<|assistant|>", 1)[-1].strip()
    return response.split("<|end|>")[0].strip()


demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(label="Your SQL question"),
    outputs="text",
    title="SQL Chatbot (ZeroGPU)",
    description="Free but limited daily GPU time"
)
53
  if __name__ == "__main__":
 
1
+ # app.py - ZeroGPU compatible version (standard transformers + @spaces.GPU)
2
 
 
3
  import torch
4
+ import gradio as gr
5
+ import spaces # ← Correct import!
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
7
  from peft import PeftModel
 
8
 
9
+ # ────────────────────────────────────────────────────────────────
10
  BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
11
+ LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"
12
+
13
+ MAX_NEW_TOKENS = 180
14
+ TEMPERATURE = 0.0
15
+ DO_SAMPLE = False
16
+
17
+ print("Loading quantized base model (CPU first)...")
18
+ bnb_config = BitsAndBytesConfig(
19
+ load_in_4bit=True,
20
+ bnb_4bit_quant_type="nf4",
21
+ bnb_4bit_compute_dtype=torch.bfloat16
22
+ )
23
 
 
 
24
  model = AutoModelForCausalLM.from_pretrained(
25
  BASE_MODEL,
26
  quantization_config=bnb_config,
27
  device_map="auto",
28
  trust_remote_code=True
29
  )
30
+
31
+ print("Loading LoRA...")
32
  model = PeftModel.from_pretrained(model, LORA_PATH)
33
+ model = model.merge_and_unload() # Merge for faster inference
34
+
35
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
36
  model.eval()
37
 
38
# ────────────────────────────────────────────────────────────────
@spaces.GPU(duration=60)  # GPU slice is attached only while this function runs
def generate_sql(prompt: str):
    """Generate a SQL answer for *prompt* with the LoRA-tuned Phi-3 model.

    Runs greedy decoding (DO_SAMPLE is False) and strips the Phi-3
    chat-template markers from the decoded text before returning it.

    Args:
        prompt: the user's natural-language SQL question.
    Returns:
        The model's answer as a plain string.
    """
    messages = [{"role": "user", "content": prompt}]

    # FIX: return_dict=True also yields the attention mask, which generate()
    # needs — without it transformers warns and may mask incorrectly.
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to("cuda")  # ZeroGPU makes cuda available here

    gen_kwargs = dict(
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=DO_SAMPLE,
        use_cache=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # FIX: temperature is only meaningful when sampling; passing it with
    # do_sample=False is ignored and triggers a transformers warning.
    if DO_SAMPLE:
        gen_kwargs["temperature"] = TEMPERATURE

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            **gen_kwargs,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Clean output: keep only the assistant turn, then cut at the
    # end-of-turn marker (str.split is a no-op when the marker is absent,
    # so no guard is needed).
    if "<|assistant|>" in response:
        response = response.split("<|assistant|>", 1)[-1].strip()
    return response.split("<|end|>")[0].strip()
68
 
69
# ────────────────────────────────────────────────────────────────
# Gradio UI: one question textbox in, generated SQL out, with a few
# ready-made example prompts.
_question_box = gr.Textbox(
    label="Ask SQL question",
    placeholder="Delete duplicate rows from users table based on email",
    lines=3,
)
_sql_box = gr.Textbox(label="Generated SQL")
_examples = [
    ["Find duplicate emails in users table"],
    ["Top 5 highest paid employees"],
    ["Count orders per customer last month"],
]

demo = gr.Interface(
    fn=generate_sql,
    inputs=_question_box,
    outputs=_sql_box,
    title="SQL Chatbot (ZeroGPU)",
    description="Phi-3-mini 4bit + LoRA - Free but limited daily GPU time",
    examples=_examples,
)
86
 
87
  if __name__ == "__main__":