Prajjwalng committed on
Commit
e3cc7e5
·
verified ·
1 Parent(s): 2560070

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -55
app.py CHANGED
@@ -5,7 +5,6 @@ import os
5
  from huggingface_hub import login
6
  from peft import PeftModel, PeftConfig
7
  import time
8
- import threading
9
 
10
  # Login with HF_TOKEN (if available)
11
  hf_token = os.environ.get("HF_TOKEN")
@@ -19,8 +18,8 @@ else:
19
  st.warning("HF_TOKEN environment variable not set. Some features may be limited.")
20
 
21
  # Model and Adapter Configuration
22
- model_id = "Prajjwalng/gemma_customer_care"
23
- adapter_id = "Prajjwalng/gemma_customercare_adapters"
24
 
25
  # Initialize model and tokenizer (load only once)
26
  @st.cache_resource
@@ -32,25 +31,37 @@ def load_model(model_id):
32
  torch_dtype=torch.float16,
33
  device_map={"": 0} if torch.cuda.is_available() else "cpu"
34
  )
 
35
  tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token=True)
36
  return base_model, tokenizer
37
 
38
  merged_model, tokenizer = load_model(model_id)
39
 
40
- # Function to generate chatbot response
41
- def get_completion(query: str, model, tokenizer, stop_event) -> str:
42
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
43
  prompt_template = f"""
44
  <start_of_turn>system You are a support chatbot who helps with user queries chatbot who always responds in the style of a professional.\n<end_of_turn>
45
- <start_of_turn>user\n\n{query}<end_of_turn>\n\n<start_of_turn>model\n"""
 
 
 
 
 
 
 
46
  prompt = prompt_template.format(query=query)
 
47
  encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
 
48
  model_inputs = encodeds.to(device)
 
49
  model.to(device)
 
50
  generated_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=True, pad_token_id=tokenizer.eos_token_id)
51
  decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
52
  model_response = decoded.split("model\n")[-1].strip()
53
- stop_event.set() #signal to stop typing animation.
54
  return model_response
55
 
56
  # Streamlit app
@@ -59,61 +70,47 @@ st.title("Customer Care ChatBot")
59
  # Initialize chat history
60
  if "messages" not in st.session_state:
61
  st.session_state.messages = []
 
62
  initial_message = {"role": "assistant", "content": "Hi, I am Sora, I am your customer support agent."}
63
  st.session_state.messages.append(initial_message)
64
 
65
  # Display chat messages from history on app rerun
66
  for message in st.session_state.messages:
67
- if message["role"] == "assistant":
68
- with st.container():
69
- col1, col2 = st.columns([1, 4])
70
- with col1:
71
- st.write("Agent:")
72
- with col2:
73
- st.markdown(message["content"])
74
- else:
75
- with st.container():
76
- col1, col2 = st.columns([4, 1])
77
- with col1:
78
- st.markdown(message["content"])
79
- with col2:
80
- st.write("Customer:")
81
 
82
  # Accept user input
83
  if prompt := st.chat_input("How can I help you?"):
 
84
  st.session_state.messages.append({"role": "user", "content": prompt})
85
- with st.container():
86
- col1, col2 = st.columns([4, 1])
87
- with col1:
88
- st.markdown(prompt)
89
- with col2:
90
- st.write("Customer:")
91
- with st.container():
92
- col1, col2 = st.columns([1, 4])
93
- with col1:
94
- st.write("Agent:")
95
- with col2:
96
- message_placeholder = st.empty()
97
- typing_placeholder = st.empty()
98
- stop_event = threading.Event() # Create an event to stop the typing animation.
99
-
100
- def animate_typing(placeholder, stop_event):
101
- typing_dots = ""
102
- while not stop_event.is_set():
103
- typing_dots += "."
104
- if len(typing_dots) > 3:
105
- typing_dots = "."
106
- placeholder.markdown(typing_dots)
107
- time.sleep(0.3)
108
- placeholder.empty()
109
-
110
- threading.Thread(target=animate_typing, args=(typing_placeholder, stop_event)).start() #start the typing animation.
111
-
112
- full_response = ""
113
- response = get_completion(prompt, merged_model, tokenizer, stop_event) #pass the stop event.
114
- for chunk in response.split():
115
- full_response += chunk + " "
116
- time.sleep(0.05)
117
- message_placeholder.markdown(full_response + "▌")
118
- message_placeholder.markdown(full_response)
119
  st.session_state.messages.append({"role": "assistant", "content": full_response})
 
5
  from huggingface_hub import login
6
  from peft import PeftModel, PeftConfig
7
  import time
 
8
 
9
  # Login with HF_TOKEN (if available)
10
  hf_token = os.environ.get("HF_TOKEN")
 
18
  st.warning("HF_TOKEN environment variable not set. Some features may be limited.")
19
 
20
  # Model and Adapter Configuration
21
+ model_id = "Prajjwalng/gemma_customer_care" # Base model
22
+ adapter_id = "Prajjwalng/gemma_customercare_adapters" # adapter model
23
 
24
  # Initialize model and tokenizer (load only once)
25
  @st.cache_resource
 
31
  torch_dtype=torch.float16,
32
  device_map={"": 0} if torch.cuda.is_available() else "cpu"
33
  )
34
+
35
  tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token=True)
36
  return base_model, tokenizer
37
 
38
  merged_model, tokenizer = load_model(model_id)
39
 
40
# Function to generate chatbot response using the provided template
def get_completion(query: str, model, tokenizer) -> str:
    """Generate one chatbot reply for *query* with the Gemma chat template.

    Args:
        query: Raw user message to embed in the prompt.
        model: Causal LM (transformers/PEFT) used for generation.
        tokenizer: Matching tokenizer; its EOS token is reused as pad token.

    Returns:
        The model's reply text (everything after the final ``model\\n``
        marker in the decoded output, stripped of surrounding whitespace).
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # BUG FIX: the original built this f-string (interpolating {query}
    # immediately) and then ALSO called .format(query=query) on the result.
    # The second substitution was a no-op at best, and raised
    # KeyError/ValueError whenever the user's message contained '{' or '}'.
    # The f-string alone is the single, correct substitution.
    # NOTE(review): "user queries chatbot who" reads garbled, but this string
    # is part of the deployed prompt, so it is kept byte-identical.
    # NOTE(review): template line indentation reconstructed from a diff view —
    # confirm leading whitespace against the deployed prompt.
    prompt = f"""
<start_of_turn>system You are a support chatbot who helps with user queries chatbot who always responds in the style of a professional.\n<end_of_turn>
<start_of_turn>user


{query}
<end_of_turn>

<start_of_turn>model
"""

    encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

    model_inputs = encodeds.to(device)

    # Idempotent after the first call; kept here so the function is
    # self-contained regardless of where the model was loaded.
    model.to(device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=1000,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    # The reply is whatever follows the last "model\n" turn marker.
    model_response = decoded.split("model\n")[-1].strip()

    return model_response
66
 
67
  # Streamlit app
 
70
  # Initialize chat history
71
  if "messages" not in st.session_state:
72
  st.session_state.messages = []
73
+ # Add initial welcome message
74
  initial_message = {"role": "assistant", "content": "Hi, I am Sora, I am your customer support agent."}
75
  st.session_state.messages.append(initial_message)
76
 
77
  # Display chat messages from history on app rerun
78
  for message in st.session_state.messages:
79
+ with st.chat_message(message["role"]):
80
+ st.markdown(message["content"])
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  # Accept user input
83
  if prompt := st.chat_input("How can I help you?"):
84
+ # Add user message to chat history
85
  st.session_state.messages.append({"role": "user", "content": prompt})
86
+ # Display user message in chat message container
87
+ with st.chat_message("user"):
88
+ st.markdown(prompt)
89
+
90
+ # Generate and display chatbot response
91
+ with st.chat_message("assistant"):
92
+ message_placeholder = st.empty()
93
+ typing_placeholder = st.empty()
94
+ typing_dots = "" # Initialize empty string for typing dots
95
+
96
+ # Animate typing dots
97
+ for i in range(3):
98
+ typing_dots += "."
99
+ typing_placeholder.markdown(typing_dots)
100
+ time.sleep(0.3) # Adjust speed as needed
101
+
102
+ typing_placeholder.empty() # Clear typing dots
103
+
104
+ full_response = ""
105
+ response = get_completion(prompt, merged_model, tokenizer)
106
+
107
+ # Simulate stream of responses with milliseconds delay
108
+ for chunk in response.split():
109
+ full_response += chunk + " "
110
+ time.sleep(0.05)
111
+ # Add a placeholder to stream the response
112
+ message_placeholder.markdown(full_response + "▌")
113
+ message_placeholder.markdown(full_response)
114
+
115
+ # Add assistant response to chat history
 
 
 
 
116
  st.session_state.messages.append({"role": "assistant", "content": full_response})