kdevoe committed on
Commit
9266488
·
verified ·
1 Parent(s): c9c6a47

Adding clear history button and using last 100 tokens of history

Browse files
Files changed (1) hide show
  1. app.py +36 -11
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
3
  import torch
4
  from langchain.memory import ConversationBufferMemory
5
 
@@ -16,21 +16,39 @@ model.to(device)
16
  # Set up conversational memory using LangChain's ConversationBufferMemory
17
  memory = ConversationBufferMemory()
18
 
 
 
 
 
 
 
 
 
 
 
 
19
  # Define the chatbot function with memory and additional parameters
20
  def chat_with_dialogpt(input_text, temperature, top_p, top_k):
21
  # Retrieve conversation history
22
  conversation_history = memory.load_memory_variables({})['history']
23
 
24
  # Combine the (possibly summarized) history with the current user input
25
- no_memory_input = f"Question: {input_text}\nAnswer:"
26
-
27
- # Tokenize the input and convert to tensor
28
- input_ids = tokenizer.encode(no_memory_input, return_tensors="pt").to(device)
29
-
 
 
 
 
 
 
 
30
  # Generate the response using the model with adjusted parameters
31
  outputs = model.generate(
32
- input_ids,
33
- max_length=input_ids.shape[1] + 50, # Limit total length
34
  max_new_tokens=15,
35
  num_return_sequences=1,
36
  no_repeat_ngram_size=3,
@@ -50,10 +68,15 @@ def chat_with_dialogpt(input_text, temperature, top_p, top_k):
50
  memory.save_context({"input": input_text}, {"output": response})
51
 
52
  # Format the chat history for display
53
- chat_history = conversation_history + f"\nYou: {input_text}\nBot: {response}\n"
54
 
55
  return chat_history
56
 
 
 
 
 
 
57
  # Set up the Gradio interface with the input box below the output box
58
  with gr.Blocks() as interface:
59
  chatbot_output = gr.Textbox(label="Conversation", lines=15, placeholder="Chat history will appear here...", interactive=False)
@@ -79,10 +102,12 @@ with gr.Blocks() as interface:
79
  inputs=[user_input, chatbot_output, temperature_slider, top_p_slider, top_k_slider],
80
  outputs=[chatbot_output, user_input])
81
 
 
 
 
 
82
  # Layout for sliders and chatbot UI
83
  gr.Row([temperature_slider, top_p_slider, top_k_slider])
84
 
85
  # Launch the Gradio app
86
  interface.launch()
87
-
88
-
 
1
  import gradio as gr
2
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
3
  import torch
4
  from langchain.memory import ConversationBufferMemory
5
 
 
16
  # Set up conversational memory using LangChain's ConversationBufferMemory
17
  memory = ConversationBufferMemory()
18
 
19
# Helper: keep only the most recent part of the conversation history.
def truncate_history_to_100_tokens(history, tokenizer, max_tokens=100):
    """Encode `history` with `tokenizer` and return at most the trailing
    `max_tokens` token ids (the newest turns of the conversation).

    Args:
        history: Conversation history as a single string.
        tokenizer: Tokenizer exposing an `encode(str) -> list[int]` method.
        max_tokens: Token budget for the retained tail (default 100).

    Returns:
        A list of token ids, length <= max_tokens (for max_tokens > 0).
    """
    token_ids = tokenizer.encode(history)
    # A negative slice is a no-op when the sequence already fits the budget,
    # so no explicit length check is needed.
    return token_ids[-max_tokens:]
29
+
30
  # Define the chatbot function with memory and additional parameters
31
  def chat_with_dialogpt(input_text, temperature, top_p, top_k):
32
  # Retrieve conversation history
33
  conversation_history = memory.load_memory_variables({})['history']
34
 
35
  # Combine the (possibly summarized) history with the current user input
36
+ full_history = conversation_history + f"\nYou: {input_text}\nBot:"
37
+
38
+ # Truncate history to the most recent 100 tokens
39
+ truncated_input_ids = truncate_history_to_100_tokens(full_history, tokenizer)
40
+
41
+ # Tokenize the user input and append to truncated history
42
+ input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
43
+ truncated_input_ids_tensor = torch.tensor([truncated_input_ids]).to(device)
44
+
45
+ # Concatenate truncated history with the new input
46
+ final_input_ids = torch.cat((truncated_input_ids_tensor, input_ids), dim=1)
47
+
48
  # Generate the response using the model with adjusted parameters
49
  outputs = model.generate(
50
+ final_input_ids,
51
+ max_length=final_input_ids.shape[1] + 50, # Limit total length
52
  max_new_tokens=15,
53
  num_return_sequences=1,
54
  no_repeat_ngram_size=3,
 
68
  memory.save_context({"input": input_text}, {"output": response})
69
 
70
  # Format the chat history for display
71
+ chat_history = full_history + f"\nBot: {response}\n"
72
 
73
  return chat_history
74
 
75
# Reset handler wired to the "Clear History" button.
def clear_history():
    """Wipe the LangChain conversation memory and blank the chat display.

    Returns:
        An empty string, which Gradio writes into the bound output Textbox.
    """
    memory.clear()
    return ""
79
+
80
  # Set up the Gradio interface with the input box below the output box
81
  with gr.Blocks() as interface:
82
  chatbot_output = gr.Textbox(label="Conversation", lines=15, placeholder="Chat history will appear here...", interactive=False)
 
102
  inputs=[user_input, chatbot_output, temperature_slider, top_p_slider, top_k_slider],
103
  outputs=[chatbot_output, user_input])
104
 
105
+ # Add a clear history button
106
+ clear_button = gr.Button("Clear History")
107
+ clear_button.click(fn=clear_history, outputs=[chatbot_output])
108
+
109
  # Layout for sliders and chatbot UI
110
  gr.Row([temperature_slider, top_p_slider, top_k_slider])
111
 
112
  # Launch the Gradio app
113
  interface.launch()