Prajjwalng commited on
Commit
2560070
·
verified ·
1 Parent(s): 95fb0a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -52
app.py CHANGED
@@ -5,6 +5,7 @@ import os
5
  from huggingface_hub import login
6
  from peft import PeftModel, PeftConfig
7
  import time
 
8
 
9
  # Login with HF_TOKEN (if available)
10
  hf_token = os.environ.get("HF_TOKEN")
@@ -18,8 +19,8 @@ else:
18
  st.warning("HF_TOKEN environment variable not set. Some features may be limited.")
19
 
20
  # Model and Adapter Configuration
21
- model_id = "Prajjwalng/gemma_customer_care" # Base model
22
- adapter_id = "Prajjwalng/gemma_customercare_adapters" # adapter model
23
 
24
  # Initialize model and tokenizer (load only once)
25
  @st.cache_resource
@@ -31,37 +32,25 @@ def load_model(model_id):
31
  torch_dtype=torch.float16,
32
  device_map={"": 0} if torch.cuda.is_available() else "cpu"
33
  )
34
-
35
  tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token=True)
36
  return base_model, tokenizer
37
 
38
  merged_model, tokenizer = load_model(model_id)
39
 
40
- # Function to generate chatbot response using the provided template
41
- def get_completion(query: str, model, tokenizer) -> str:
42
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
43
-
44
  prompt_template = f"""
45
  <start_of_turn>system You are a support chatbot who helps with user queries chatbot who always responds in the style of a professional.\n<end_of_turn>
46
- <start_of_turn>user
47
-
48
-
49
- {query}
50
- <end_of_turn>
51
-
52
- <start_of_turn>model
53
- """
54
  prompt = prompt_template.format(query=query)
55
-
56
  encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
57
-
58
  model_inputs = encodeds.to(device)
59
-
60
  model.to(device)
61
-
62
  generated_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=True, pad_token_id=tokenizer.eos_token_id)
63
  decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
64
  model_response = decoded.split("model\n")[-1].strip()
 
65
  return model_response
66
 
67
  # Streamlit app
@@ -70,47 +59,61 @@ st.title("Customer Care ChatBot")
70
  # Initialize chat history
71
  if "messages" not in st.session_state:
72
  st.session_state.messages = []
73
- # Add initial welcome message
74
  initial_message = {"role": "assistant", "content": "Hi, I am Sora, I am your customer support agent."}
75
  st.session_state.messages.append(initial_message)
76
 
77
  # Display chat messages from history on app rerun
78
  for message in st.session_state.messages:
79
- with st.chat_message(message["role"]):
80
- st.markdown(message["content"])
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  # Accept user input
83
  if prompt := st.chat_input("How can I help you?"):
84
- # Add user message to chat history
85
  st.session_state.messages.append({"role": "user", "content": prompt})
86
- # Display user message in chat message container
87
- with st.chat_message("user"):
88
- st.markdown(prompt)
89
-
90
- # Generate and display chatbot response
91
- with st.chat_message("assistant"):
92
- message_placeholder = st.empty()
93
- typing_placeholder = st.empty()
94
- typing_dots = "" # Initialize empty string for typing dots
95
-
96
- # Animate typing dots
97
- for i in range(3):
98
- typing_dots += "."
99
- typing_placeholder.markdown(typing_dots)
100
- time.sleep(0.3) # Adjust speed as needed
101
-
102
- typing_placeholder.empty() # Clear typing dots
103
-
104
- full_response = ""
105
- response = get_completion(prompt, merged_model, tokenizer)
106
-
107
- # Simulate stream of responses with milliseconds delay
108
- for chunk in response.split():
109
- full_response += chunk + " "
110
- time.sleep(0.05)
111
- # Add a placeholder to stream the response
112
- message_placeholder.markdown(full_response + "▌")
113
- message_placeholder.markdown(full_response)
114
-
115
- # Add assistant response to chat history
 
 
 
 
116
  st.session_state.messages.append({"role": "assistant", "content": full_response})
 
5
  from huggingface_hub import login
6
  from peft import PeftModel, PeftConfig
7
  import time
8
+ import threading
9
 
10
  # Login with HF_TOKEN (if available)
11
  hf_token = os.environ.get("HF_TOKEN")
 
19
  st.warning("HF_TOKEN environment variable not set. Some features may be limited.")
20
 
21
  # Model and Adapter Configuration
22
+ model_id = "Prajjwalng/gemma_customer_care"
23
+ adapter_id = "Prajjwalng/gemma_customercare_adapters"
24
 
25
  # Initialize model and tokenizer (load only once)
26
  @st.cache_resource
 
32
  torch_dtype=torch.float16,
33
  device_map={"": 0} if torch.cuda.is_available() else "cpu"
34
  )
 
35
  tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token=True)
36
  return base_model, tokenizer
37
 
38
  merged_model, tokenizer = load_model(model_id)
39
 
40
+ # Function to generate chatbot response
41
+ def get_completion(query: str, model, tokenizer, stop_event) -> str:
42
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
43
  prompt_template = f"""
44
  <start_of_turn>system You are a support chatbot who helps with user queries chatbot who always responds in the style of a professional.\n<end_of_turn>
45
+ <start_of_turn>user\n\n{query}<end_of_turn>\n\n<start_of_turn>model\n"""
 
 
 
 
 
 
 
46
  prompt = prompt_template.format(query=query)
 
47
  encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
 
48
  model_inputs = encodeds.to(device)
 
49
  model.to(device)
 
50
  generated_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=True, pad_token_id=tokenizer.eos_token_id)
51
  decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
52
  model_response = decoded.split("model\n")[-1].strip()
53
+ stop_event.set() #signal to stop typing animation.
54
  return model_response
55
 
56
  # Streamlit app
 
59
  # Initialize chat history
60
  if "messages" not in st.session_state:
61
  st.session_state.messages = []
 
62
  initial_message = {"role": "assistant", "content": "Hi, I am Sora, I am your customer support agent."}
63
  st.session_state.messages.append(initial_message)
64
 
65
  # Display chat messages from history on app rerun
66
  for message in st.session_state.messages:
67
+ if message["role"] == "assistant":
68
+ with st.container():
69
+ col1, col2 = st.columns([1, 4])
70
+ with col1:
71
+ st.write("Agent:")
72
+ with col2:
73
+ st.markdown(message["content"])
74
+ else:
75
+ with st.container():
76
+ col1, col2 = st.columns([4, 1])
77
+ with col1:
78
+ st.markdown(message["content"])
79
+ with col2:
80
+ st.write("Customer:")
81
 
82
  # Accept user input
83
  if prompt := st.chat_input("How can I help you?"):
 
84
  st.session_state.messages.append({"role": "user", "content": prompt})
85
+ with st.container():
86
+ col1, col2 = st.columns([4, 1])
87
+ with col1:
88
+ st.markdown(prompt)
89
+ with col2:
90
+ st.write("Customer:")
91
+ with st.container():
92
+ col1, col2 = st.columns([1, 4])
93
+ with col1:
94
+ st.write("Agent:")
95
+ with col2:
96
+ message_placeholder = st.empty()
97
+ typing_placeholder = st.empty()
98
+ stop_event = threading.Event() # Create an event to stop the typing animation.
99
+
100
+ def animate_typing(placeholder, stop_event):
101
+ typing_dots = ""
102
+ while not stop_event.is_set():
103
+ typing_dots += "."
104
+ if len(typing_dots) > 3:
105
+ typing_dots = "."
106
+ placeholder.markdown(typing_dots)
107
+ time.sleep(0.3)
108
+ placeholder.empty()
109
+
110
+ threading.Thread(target=animate_typing, args=(typing_placeholder, stop_event)).start() #start the typing animation.
111
+
112
+ full_response = ""
113
+ response = get_completion(prompt, merged_model, tokenizer, stop_event) #pass the stop event.
114
+ for chunk in response.split():
115
+ full_response += chunk + " "
116
+ time.sleep(0.05)
117
+ message_placeholder.markdown(full_response + "▌")
118
+ message_placeholder.markdown(full_response)
119
  st.session_state.messages.append({"role": "assistant", "content": full_response})