kdevoe committed on
Commit
d51dabc
·
verified ·
1 Parent(s): 0c0cab0

Fixing memory issue

Browse files
Files changed (1) hide show
  1. app.py +65 -67
app.py CHANGED
@@ -1,89 +1,87 @@
1
  import gradio as gr
2
- from transformers import GPT2Tokenizer, GPT2LMHeadModel
3
  import torch
4
  from langchain.memory import ConversationBufferMemory
5
 
6
  # Move model to device (GPU if available)
7
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
8
 
9
  # Load all three DialoGPT models (small, medium, large)
10
  models = {
11
- "small": GPT2LMHeadModel.from_pretrained("microsoft/DialoGPT-small").to(device),
12
- "medium": GPT2LMHeadModel.from_pretrained("microsoft/DialoGPT-medium").to(device),
13
- "large": GPT2LMHeadModel.from_pretrained("microsoft/DialoGPT-large").to(device)
14
  }
15
 
16
  # Load the tokenizer (same tokenizer for all models)
17
- tokenizer = GPT2Tokenizer.from_pretrained("microsoft/DialoGPT-medium")
18
 
19
- # Set up conversational memory using LangChain's ConversationBufferMemory
20
- memory = ConversationBufferMemory()
21
 
22
- # Function to truncate tokens to the last 100 tokens
23
- def truncate_history_to_100_tokens(history, tokenizer, max_tokens=100):
24
- # Tokenize the history
25
- tokenized_history = tokenizer.encode(history)
26
-
27
- # Truncate to the last 100 tokens if necessary
28
- if len(tokenized_history) > max_tokens:
29
- tokenized_history = tokenized_history[-max_tokens:]
30
-
31
- return tokenized_history
32
 
33
  # Define the chatbot function with memory and additional parameters
34
  def chat_with_dialogpt(input_text, temperature, top_p, top_k, model_size):
35
- # Retrieve conversation history
36
- conversation_history = memory.load_memory_variables({})['history']
37
-
38
- # Combine the (possibly summarized) history with the current user input
39
- full_history = conversation_history + f">> User: {input_text}"
40
-
41
- # Truncate history to the most recent 100 tokens
42
- truncated_input_ids = truncate_history_to_100_tokens(full_history, tokenizer)
43
-
44
- # Tokenize the user input and append to truncated history
45
- input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
46
- truncated_input_ids_tensor = torch.tensor([truncated_input_ids]).to(device)
47
-
48
- # Concatenate truncated history with the new input
49
- final_input_ids = torch.cat((truncated_input_ids_tensor, input_ids), dim=1)
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  # Get the model corresponding to the selected size
52
  model = models[model_size]
53
-
54
- # Generate the response using the model with adjusted parameters
55
- outputs = model.generate(
56
- final_input_ids,
57
- max_length=final_input_ids.shape[1] + 50, # Limit total length
58
- max_new_tokens=15,
59
- num_return_sequences=1,
 
 
 
60
  no_repeat_ngram_size=3,
61
  repetition_penalty=1.2,
62
  early_stopping=True,
63
- pad_token_id=tokenizer.eos_token_id,
64
- eos_token_id=tokenizer.eos_token_id,
65
- temperature=temperature, # Add temperature from slider
66
- top_p=top_p, # Add top_p from slider
67
- top_k=top_k # Add top_k from slider
68
  )
69
 
70
- # Decode the model output
71
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
72
 
73
- # Update the memory with the user input and model response
74
- memory.save_context({"input": input_text}, {"output": response})
75
 
76
  # Format the chat history for display
77
- chat_history = full_history + f"\nBot: {response}\n"
 
 
 
 
 
 
 
78
 
79
- return chat_history
80
 
81
- # Function to clear the chat history
82
- def clear_history():
83
- memory.clear() # Clear the memory object
84
- return "" # Return empty string to reset the chat display
85
-
86
- # Set up the Gradio interface with the input box below the output box
87
  with gr.Blocks() as interface:
88
  chatbot_output = gr.Textbox(label="Conversation", lines=15, placeholder="Chat history will appear here...", interactive=False)
89
 
@@ -92,34 +90,34 @@ with gr.Blocks() as interface:
92
 
93
  # Add a dropdown for selecting the model size (small, medium, large)
94
  model_selector = gr.Dropdown(choices=["small", "medium", "large"], value="medium", label="Select Model Size")
95
-
96
  # Add a clear history button
97
- clear_button = gr.Button("Clear History", scale=0)
98
  clear_button.click(fn=clear_history, outputs=[chatbot_output])
99
 
100
  # Input box for the user
101
  user_input = gr.Textbox(label="Your Input", placeholder="Type your message here...", lines=2, show_label=True)
102
 
103
  # Sliders for temperature, top_p, and top_k
104
- temperature_slider = gr.Slider(0.1, 1.0, step=0.1, value=1.0, label="Temperature", scale=0)
105
- top_p_slider = gr.Slider(0.0, 1.0, step=0.1, value=1.0, label="Top-p", scale=0)
106
- top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k", scale=0)
107
 
108
  # Define the function to update the chat
109
- def update_chat(input_text, chat_history, temperature, top_p, top_k, model_size):
110
  updated_history = chat_with_dialogpt(input_text, temperature, top_p, top_k, model_size)
111
  return updated_history, ""
112
 
113
  # Submit when pressing Shift + Enter
114
  user_input.submit(update_chat,
115
- inputs=[user_input, chatbot_output, temperature_slider, top_p_slider, top_k_slider, model_selector],
116
  outputs=[chatbot_output, user_input])
117
-
118
  # Layout for sliders and chatbot UI
119
  gr.Row([temperature_slider, top_p_slider, top_k_slider])
120
-
121
  # Layout for model selector and clear button in a row
122
  gr.Row([model_selector, clear_button])
123
-
124
  # Launch the Gradio app
125
  interface.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  from langchain.memory import ConversationBufferMemory
5
 
6
  # Move model to device (GPU if available)
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
 
9
  # Load all three DialoGPT models (small, medium, large)
10
  models = {
11
+ "small": AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small").to(device),
12
+ "medium": AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium").to(device),
13
+ "large": AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large").to(device)
14
  }
15
 
16
  # Load the tokenizer (same tokenizer for all models)
17
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
18
 
19
+ # Initialize conversation history
20
+ conversation_history = []
21
 
22
+ # Function to clear the chat history
23
+ def clear_history():
24
+ global conversation_history
25
+ conversation_history = []
26
+ return ""
 
 
 
 
 
27
 
28
  # Define the chatbot function with memory and additional parameters
29
  def chat_with_dialogpt(input_text, temperature, top_p, top_k, model_size):
30
+ global conversation_history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
+ # Encode the user input and append the end-of-text token
33
+ new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt').to(device)
34
+
35
+ # Append the user input to the conversation history
36
+ conversation_history.append(new_user_input_ids)
37
+
38
+ # Concatenate conversation history
39
+ bot_input_ids = torch.cat(conversation_history, dim=-1)
40
+
41
+ # Truncate input_ids to the last 100 tokens if necessary
42
+ max_length = 100
43
+ if bot_input_ids.size(-1) > max_length:
44
+ bot_input_ids = bot_input_ids[:, -max_length:]
45
+
46
  # Get the model corresponding to the selected size
47
  model = models[model_size]
48
+
49
+ # Generate a response
50
+ response_ids = model.generate(
51
+ bot_input_ids,
52
+ max_length=bot_input_ids.shape[-1] + 50,
53
+ pad_token_id=tokenizer.eos_token_id,
54
+ eos_token_id=tokenizer.eos_token_id,
55
+ temperature=temperature,
56
+ top_p=top_p,
57
+ top_k=top_k,
58
  no_repeat_ngram_size=3,
59
  repetition_penalty=1.2,
60
  early_stopping=True,
 
 
 
 
 
61
  )
62
 
63
+ # Extract only the new tokens generated
64
+ new_response_ids = response_ids[:, bot_input_ids.shape[-1]:]
65
+
66
+ # Decode the response
67
+ response = tokenizer.decode(new_response_ids[0], skip_special_tokens=True)
68
 
69
+ # Append the bot response to the conversation history
70
+ conversation_history.append(new_response_ids)
71
 
72
  # Format the chat history for display
73
+ # For display purposes, reconstruct the conversation
74
+ display_conversation = ""
75
+ for i in range(0, len(conversation_history), 2):
76
+ user_input = tokenizer.decode(conversation_history[i], skip_special_tokens=True)
77
+ display_conversation += f"You: {user_input}\n"
78
+ if i+1 < len(conversation_history):
79
+ bot_response = tokenizer.decode(conversation_history[i+1], skip_special_tokens=True)
80
+ display_conversation += f"Bot: {bot_response}\n"
81
 
82
+ return display_conversation
83
 
84
+ # Set up the Gradio interface
 
 
 
 
 
85
  with gr.Blocks() as interface:
86
  chatbot_output = gr.Textbox(label="Conversation", lines=15, placeholder="Chat history will appear here...", interactive=False)
87
 
 
90
 
91
  # Add a dropdown for selecting the model size (small, medium, large)
92
  model_selector = gr.Dropdown(choices=["small", "medium", "large"], value="medium", label="Select Model Size")
93
+
94
  # Add a clear history button
95
+ clear_button = gr.Button("Clear History")
96
  clear_button.click(fn=clear_history, outputs=[chatbot_output])
97
 
98
  # Input box for the user
99
  user_input = gr.Textbox(label="Your Input", placeholder="Type your message here...", lines=2, show_label=True)
100
 
101
  # Sliders for temperature, top_p, and top_k
102
+ temperature_slider = gr.Slider(0.1, 1.0, step=0.1, value=1.0, label="Temperature")
103
+ top_p_slider = gr.Slider(0.0, 1.0, step=0.1, value=1.0, label="Top-p")
104
+ top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
105
 
106
  # Define the function to update the chat
107
+ def update_chat(input_text, temperature, top_p, top_k, model_size):
108
  updated_history = chat_with_dialogpt(input_text, temperature, top_p, top_k, model_size)
109
  return updated_history, ""
110
 
111
  # Submit when pressing Shift + Enter
112
  user_input.submit(update_chat,
113
+ inputs=[user_input, temperature_slider, top_p_slider, top_k_slider, model_selector],
114
  outputs=[chatbot_output, user_input])
115
+
116
  # Layout for sliders and chatbot UI
117
  gr.Row([temperature_slider, top_p_slider, top_k_slider])
118
+
119
  # Layout for model selector and clear button in a row
120
  gr.Row([model_selector, clear_button])
121
+
122
  # Launch the Gradio app
123
  interface.launch()