saadkhi commited on
Commit
979ad48
·
1 Parent(s): fe626d6

Optimized solution; review needed

Browse files
Files changed (3) hide show
  1. app.py +28 -7
  2. app_old.txt +18 -0
  3. requirements.txt +2 -1
app.py CHANGED
@@ -1,19 +1,40 @@
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  from peft import PeftModel
 
 
 
4
 
5
- # Load base + finetuned model
6
  base_model = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
7
  finetuned_model = "saadkhi/SQL_Chat_finetuned_model"
8
 
9
  tokenizer = AutoTokenizer.from_pretrained(base_model)
10
- model = AutoModelForCausalLM.from_pretrained(base_model)
11
- model = PeftModel.from_pretrained(model, finetuned_model)
 
 
 
 
 
 
 
 
 
 
12
 
13
  def chat(prompt):
14
- inputs = tokenizer(prompt, return_tensors="pt")
15
- outputs = model.generate(**inputs, max_new_tokens=200)
16
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
17
 
18
  iface = gr.Interface(fn=chat, inputs="text", outputs="text", title="SQL Chatbot")
19
- iface.launch()
 
1
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# Prefer the GPU when available; CPU generation works but is slow.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Base checkpoint (pre-quantized 4-bit via bitsandbytes) plus the LoRA adapter.
base_model = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
finetuned_model = "saadkhi/SQL_Chat_finetuned_model"

tokenizer = AutoTokenizer.from_pretrained(base_model)

# Load the base model in 4-bit; device_map="auto" lets accelerate place the
# weights on the right device(s) at load time.
bnb = BitsAndBytesConfig(load_in_4bit=True)

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb,
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    device_map="auto",
)

# FIX: dropped the trailing `.to(device)` — calling .to() on a bitsandbytes
# 4-bit quantized model is unsupported (transformers raises/warns), and
# device_map="auto" has already placed the weights.
model = PeftModel.from_pretrained(model, finetuned_model)
model.eval()
26
def chat(prompt):
    """Generate a completion for *prompt* and return the decoded text.

    Returns the full decoded sequence (prompt + completion), with special
    tokens stripped, exactly as produced by greedy decoding.
    """
    # Move the tokenized inputs to the same device the model runs on.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # inference_mode disables autograd bookkeeping for a speed/memory win.
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=60,
            # FIX: removed temperature=0.1 — temperature is ignored (and
            # warned about) when do_sample=False, since greedy decoding is
            # deterministic.
            do_sample=False,
        )

    return tokenizer.decode(output[0], skip_special_tokens=True)
38
 
39
# Wire the chat function into a minimal text-in/text-out Gradio UI.
iface = gr.Interface(fn=chat, inputs="text", outputs="text", title="SQL Chatbot")
iface.launch()
app_old.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from gradio_client import Client

# Use the existing Hugging Face Space as the backend
client = Client("saadkhi/SQL_chatbot_API")


def chat(prompt: str) -> str:
    """Proxy the prompt to the remote Space /chat endpoint."""
    # Forward the question verbatim and hand back the Space's reply.
    response = client.predict(prompt=prompt, api_name="/chat")
    return response
13
+
14
+
15
if __name__ == "__main__":
    # Simple CLI smoke test: forward one question and print the answer.
    question = input("Enter your SQL question: ")
    answer = chat(question)
    print(answer)
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  gradio
2
  transformers
3
  peft
 
 
4
  torch
5
- bitsandbytes
 
1
  gradio
2
  transformers
3
  peft
4
+ accelerate
5
+ bitsandbytes
6
  torch