abanm commited on
Commit
79e4e8d
·
verified ·
1 Parent(s): 7cf1012

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -35
app.py CHANGED
@@ -113,49 +113,33 @@ for message in st.session_state["messages"]:
113
  st.chat_message("assistant", avatar=DUBS_PATH).write(message["content"])
114
 
115
  # -------------------------
116
- # Streaming Logic using InferenceClient
117
  # -------------------------
118
  def stream_response(prompt_text, api_key):
119
  """
120
  Stream text from the HF Inference Endpoint using the InferenceClient.
121
  Yields each partial chunk of text as it arrives.
122
  """
123
- # Initialize the client with your endpoint_url and API key
124
- client = InferenceClient(
125
- SPACE_URL,
126
- token=api_key
127
- )
128
-
129
- # Define generation parameters
130
- gen_kwargs = dict(
131
- max_new_tokens=512,
132
- top_k=30,
133
- top_p=0.9,
134
- temperature=0.2,
135
- repetition_penalty=1.02,
136
- stop_sequences=["<|endoftext|>"]
137
- )
138
-
139
- # Start streaming from the model
140
  stream = client.text_generation(prompt_text, stream=True, details=True, **gen_kwargs)
141
 
142
- # We'll build the response incrementally
143
  partial_text = ""
144
-
145
  try:
146
  for response in stream:
147
- # Skip special tokens
148
  if response.token.special:
149
  continue
150
- # Break if we encounter a stop sequence
151
- if response.token.text in gen_kwargs["stop_sequences"]:
152
- break
153
-
154
- # Update the partial text
155
- partial_text = response.token.text
156
-
157
- # Yield the text so far so we can stream on the frontend
158
- yield partial_text
159
  except Exception as e:
160
  yield f"Error: {e}"
161
 
@@ -171,7 +155,6 @@ if prompt := st.chat_input():
171
  st.chat_message("user").write(prompt)
172
 
173
  # 2) Build combined chat history for the model prompt
174
- # This format is just an example; adjust as needed for your model
175
  chat_history = "".join(
176
  [f"<|{msg['role']}|>{msg['content']}<|end|>" for msg in st.session_state["messages"]]
177
  )
@@ -181,11 +164,10 @@ if prompt := st.chat_input():
181
  assistant_message_placeholder = st.chat_message("assistant", avatar=DUBS_PATH).empty()
182
 
183
  full_response = ""
184
- # 4) Stream chunks from the Hugging Face InferenceClient
185
  for chunk in stream_response(chat_history, HF_API_KEY):
186
- full_response += chunk + " " # each chunk is the incremental text so far
187
- msg = st.write_stream(full_response)
188
- assistant_message_placeholder.markdown(full_response + "▌")
189
  assistant_message_placeholder.markdown(full_response)
190
 
191
  # 5) Save the final assistant message in session state
 
113
  st.chat_message("assistant", avatar=DUBS_PATH).write(message["content"])
114
 
115
# -------------------------
# Streaming Logic using Generator
# -------------------------
def stream_response(prompt_text, api_key):
    """
    Stream text from the HF Inference Endpoint using the InferenceClient.

    Args:
        prompt_text: Full prompt (combined chat history) sent to the model.
        api_key: Hugging Face API token used to authenticate the request.

    Yields:
        str: Each incremental text chunk as it arrives from the endpoint,
        or a single "Error: ..." string if streaming fails.
    """
    client = InferenceClient(SPACE_URL, token=api_key)

    gen_kwargs = {
        "max_new_tokens": 512,
        "top_k": 30,
        "top_p": 0.9,
        "temperature": 0.2,
        "repetition_penalty": 1.02,
        "stop_sequences": ["<|endoftext|>"],
    }

    stream = client.text_generation(prompt_text, stream=True, details=True, **gen_kwargs)

    try:
        for response in stream:
            # Skip control/special tokens; they carry no user-visible text.
            if response.token.special:
                continue
            # Yield only the new token text; the caller accumulates the
            # full response (the old local accumulator here was never read).
            yield response.token.text
    except Exception as e:
        # Best-effort: surface the failure as a chat message rather than
        # crashing the Streamlit app mid-stream.
        yield f"Error: {e}"
145
 
 
155
  st.chat_message("user").write(prompt)
156
 
157
  # 2) Build combined chat history for the model prompt
 
158
  chat_history = "".join(
159
  [f"<|{msg['role']}|>{msg['content']}<|end|>" for msg in st.session_state["messages"]]
160
  )
 
164
  assistant_message_placeholder = st.chat_message("assistant", avatar=DUBS_PATH).empty()
165
 
166
  full_response = ""
167
+ # 4) Stream chunks from the generator
168
  for chunk in stream_response(chat_history, HF_API_KEY):
169
+ full_response += chunk # Accumulate the full response
170
+ assistant_message_placeholder.markdown(full_response + "▌") # Show streamed response
 
171
  assistant_message_placeholder.markdown(full_response)
172
 
173
  # 5) Save the final assistant message in session state