Prajjwalng commited on
Commit
53cc210
·
verified ·
1 Parent(s): 6035193

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -10
app.py CHANGED
@@ -3,6 +3,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  import os
5
  from huggingface_hub import login
 
6
 
7
  # Login with HF_TOKEN (if available)
8
  hf_token = os.environ.get("HF_TOKEN")
@@ -15,19 +16,30 @@ if hf_token:
15
  else:
16
  st.warning("HF_TOKEN environment variable not set. Some features may be limited.")
17
 
 
 
 
 
18
  # Initialize model and tokenizer (load only once)
19
  @st.cache_resource
20
- def load_model():
21
- model_name = "google/gemma-2b-it"
22
- tokenizer = AutoTokenizer.from_pretrained(model_name)
23
- model = AutoModelForCausalLM.from_pretrained(model_name)
24
- return tokenizer, model
 
 
 
 
 
 
 
25
 
26
- tokenizer, model = load_model()
27
 
28
  # Function to generate chatbot response using the provided template
29
  def get_completion(query: str, model, tokenizer) -> str:
30
- device = "cuda:0" if torch.cuda.is_available() else "cpu" #Use cuda if available.
31
 
32
  prompt_template = f"""
33
  <start_of_turn>system You are a support chatbot who helps with user queries chatbot who always responds in the style of a professional.\n<end_of_turn>
@@ -45,7 +57,7 @@ def get_completion(query: str, model, tokenizer) -> str:
45
 
46
  model_inputs = encodeds.to(device)
47
 
48
- model.to(device) #Move model to device.
49
 
50
  generated_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=True, pad_token_id=tokenizer.eos_token_id)
51
  decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
@@ -53,7 +65,7 @@ def get_completion(query: str, model, tokenizer) -> str:
53
  return model_response
54
 
55
  # Streamlit app
56
- st.title("Gemma-2b-it Support Chatbot")
57
 
58
  # Initialize chat history
59
  if "messages" not in st.session_state:
@@ -76,7 +88,7 @@ if prompt := st.chat_input("How can I help you?"):
76
  with st.chat_message("assistant"):
77
  message_placeholder = st.empty()
78
  full_response = ""
79
- response = get_completion(prompt, model, tokenizer)
80
 
81
  # Simulate stream of responses with milliseconds delay
82
  import time
 
3
  import torch
4
  import os
5
  from huggingface_hub import login
6
+ from peft import PeftModel, PeftConfig
7
 
8
  # Login with HF_TOKEN (if available)
9
  hf_token = os.environ.get("HF_TOKEN")
 
16
  else:
17
  st.warning("HF_TOKEN environment variable not set. Some features may be limited.")
18
 
19
+ # Model and Adapter Configuration
20
+ model_id = "google/gemma-2b-it" # Base model
21
+ adapter_id = "Prajjwalng/gemma_customercare_adapters" #adapter model
22
+
23
  # Initialize model and tokenizer (load only once)
24
  @st.cache_resource
25
+ def load_model(model_id, adapter_id):
26
+ base_model = AutoModelForCausalLM.from_pretrained(
27
+ model_id,
28
+ low_cpu_mem_usage=True,
29
+ return_dict=True,
30
+ torch_dtype=torch.float16,
31
+ device_map={"": 0} if torch.cuda.is_available() else "cpu"
32
+ )
33
+
34
+ merged_model = PeftModel.from_pretrained(base_model, adapter_id)
35
+ tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token=True)
36
+ return merged_model, tokenizer
37
 
38
+ merged_model, tokenizer = load_model(model_id, adapter_id)
39
 
40
  # Function to generate chatbot response using the provided template
41
  def get_completion(query: str, model, tokenizer) -> str:
42
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
43
 
44
  prompt_template = f"""
45
  <start_of_turn>system You are a support chatbot who helps with user queries chatbot who always responds in the style of a professional.\n<end_of_turn>
 
57
 
58
  model_inputs = encodeds.to(device)
59
 
60
+ model.to(device)
61
 
62
  generated_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=True, pad_token_id=tokenizer.eos_token_id)
63
  decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
 
65
  return model_response
66
 
67
  # Streamlit app
68
+ st.title("Gemma-2b-it Customer Care Chatbot")
69
 
70
  # Initialize chat history
71
  if "messages" not in st.session_state:
 
88
  with st.chat_message("assistant"):
89
  message_placeholder = st.empty()
90
  full_response = ""
91
+ response = get_completion(prompt, merged_model, tokenizer)
92
 
93
  # Simulate stream of responses with milliseconds delay
94
  import time