botsi commited on
Commit
a4c5e87
·
verified ·
1 Parent(s): 962c084

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -18
app.py CHANGED
@@ -142,8 +142,7 @@ def generate(
142
 
143
  # Use the global variable to store the chat history
144
  # global global_chat_history
145
-
146
- conversation = []
147
 
148
  # Move the condition here after the assignment
149
  if input_prompt:
@@ -152,15 +151,17 @@ def generate(
152
  # Convert input prompt to tensor
153
  input_ids = tokenizer(input_prompt, return_tensors="pt").to(model.device)
154
 
155
-
156
  for user, assistant in chat_history:
157
  conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
 
158
  conversation.append({"role": "user", "content": message})
159
 
160
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
 
161
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
162
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
163
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
164
  input_ids = input_ids.to(model.device)
165
 
166
  # Set up the TextIteratorStreamer
@@ -226,11 +227,14 @@ examples=[
226
 
227
  with gr.Blocks(css="style.css") as demo:
228
  ## gradio-chatbot-read-query-param
229
- session_index = 'eb3636167d3a63fbeee32934610e5b2f'
230
- #url_params = gr.JSON({}, visible=False, label="URL Params")
231
- #def get_session_index(url_params):
232
- # session_index = url_params.get('session_index')
233
- # print(session_index)
 
 
 
234
  personalized_data = fetch_personalized_data(session_index)
235
  print(personalized_data)
236
 
@@ -270,6 +274,16 @@ with gr.Blocks(css="style.css") as demo:
270
  input_prompt += f"{user} [/INST] {assistant} <s>[INST] "
271
  input_prompt += f"{message} [/INST] "
272
  return input_prompt
 
 
 
 
 
 
 
 
 
 
273
 
274
  chat_interface.render()
275
  #gr.Markdown(LICENSE)
@@ -281,13 +295,4 @@ if __name__ == "__main__":
281
 
282
  # Register the function to be called when the program exits
283
  # atexit.register(save_chat_history)
284
-
285
-
286
- ''' demo.load(
287
- fn=lambda x: x,
288
- inputs=[url_params],
289
- outputs=[url_params],
290
- _js=get_window_url_params(),
291
- queue=False
292
- )
293
- '''
 
142
 
143
  # Use the global variable to store the chat history
144
  # global global_chat_history
145
+ # conversation = []
 
146
 
147
  # Move the condition here after the assignment
148
  if input_prompt:
 
151
  # Convert input prompt to tensor
152
  input_ids = tokenizer(input_prompt, return_tensors="pt").to(model.device)
153
 
 
154
  for user, assistant in chat_history:
155
  conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
156
+
157
  conversation.append({"role": "user", "content": message})
158
 
159
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
160
+
161
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
162
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
163
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
164
+
165
  input_ids = input_ids.to(model.device)
166
 
167
  # Set up the TextIteratorStreamer
 
227
 
228
  with gr.Blocks(css="style.css") as demo:
229
  ## gradio-chatbot-read-query-param
230
+ #session_index = 'eb3636167d3a63fbeee32934610e5b2f'
231
+ url_params = gr.JSON({}, visible=False, label="URL Params")
232
+ def get_session_index(url_params):
233
+ session_index = url_params.get('session_index')
234
+ print(session_index)
235
+ return session_index
236
+
237
+ session_index = get_session_index(url_params)
238
  personalized_data = fetch_personalized_data(session_index)
239
  print(personalized_data)
240
 
 
274
  input_prompt += f"{user} [/INST] {assistant} <s>[INST] "
275
  input_prompt += f"{message} [/INST] "
276
  return input_prompt
277
+
278
+ clear.click(lambda: None, None, chatbot, queue=False)
279
+
280
+ demo.load(
281
+ fn=lambda x: x,
282
+ inputs=[url_params],
283
+ outputs=[url_params],
284
+ _js=get_window_url_params(),
285
+ queue=False
286
+ )
287
 
288
  chat_interface.render()
289
  #gr.Markdown(LICENSE)
 
295
 
296
  # Register the function to be called when the program exits
297
  # atexit.register(save_chat_history)
298
+