swaroop77 commited on
Commit
e4ed1f7
·
verified ·
1 Parent(s): b597966

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -160
app.py CHANGED
@@ -1,26 +1,42 @@
1
  import os
2
  from flask import Flask, render_template, request, jsonify, session
3
- from sentence_transformers import SentenceTransformer
 
4
  from sklearn.metrics.pairwise import cosine_similarity
5
  from groq import Groq
6
  import numpy as np
7
- import logging # Import logging
8
- from waitress import serve # Import waitress
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  # Configure logging
11
  logging.basicConfig(level=logging.INFO)
12
 
13
- # --- Initialize Models (Load these once) ---
14
- # Ensure model path is accessible, default path works in Docker
 
 
 
15
  try:
16
- model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
17
- logging.info("SentenceTransformer model loaded successfully.")
 
 
18
  except Exception as e:
19
- logging.error(f"Error loading SentenceTransformer model: {e}")
20
- embedding_model = None # Handle potential loading errors
21
 
22
  # Initialize the Groq client
23
- # It's recommended to set the API key as an environment variable GROQ_API_KEY
24
  groq_api_key = os.environ.get("GROQ_API_KEY")
25
  if not groq_api_key:
26
  logging.error("GROQ_API_KEY environment variable not set.")
@@ -32,37 +48,67 @@ else:
32
  logging.info("Groq client initialized.")
33
 
34
 
35
- # --- Flask App Setup ---
36
- app = Flask(__name__)
37
- # A secret key is required for Flask sessions
38
- # USE A STRONG, RANDOM KEY IN PRODUCTION ENVIRONMENT VARIABLE!
39
- app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', 'a_default_secret_key_please_change') # !!! CHANGE THIS DEFAULT or set ENV VAR in Hugging Face Space Secrets !!!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- # --- Memory Management Functions (Adapted for Sessions) ---
 
 
42
 
43
- # These functions will now operate on a memory list passed to them,
44
- # rather than a global variable. The Flask route will manage the session state.
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  def add_to_memory(mem_list, role, content):
47
  """
48
  Add a message to the provided memory list along with its embedding.
49
  Returns the updated list.
50
  """
51
- if embedding_model is None:
52
- logging.error("Embedding model not loaded. Cannot add to memory.")
53
- return mem_list
54
- try:
55
- # Check if content is not empty before encoding
56
- if not content or not content.strip():
57
- logging.warning(f"Attempted to add empty content to memory for role: {role}")
58
- return mem_list # Do not add empty messages
59
 
60
- embedding = embedding_model.encode(content, convert_to_numpy=True)
61
- mem_list.append({"role": role, "content": content, "embedding": embedding.tolist()}) # Store embedding as list for JSON serializability in session
62
- return mem_list
63
- except Exception as e:
64
- logging.error(f"Error adding to memory: {e}")
65
- return mem_list
 
 
 
 
66
 
67
 
68
  def retrieve_relevant_memory(mem_list, user_input, top_k=5):
@@ -71,34 +117,45 @@ def retrieve_relevant_memory(mem_list, user_input, top_k=5):
71
  based on cosine similarity with user_input.
72
  Returns a list of relevant messages (dictionaries).
73
  """
74
- if not mem_list or embedding_model is None:
 
 
 
75
  return []
76
 
77
  try:
78
- # Compute the embedding of the user input
79
- user_embedding = embedding_model.encode(user_input, convert_to_numpy=True)
 
 
 
 
80
 
81
  # Calculate similarities. Ensure all memory entries have valid embeddings.
82
- # We need to convert embedding lists back to numpy arrays for cosine_similarity
83
- valid_memory_with_embeddings = []
84
- for m in mem_list:
85
- if "embedding" in m and m["embedding"] is not None:
86
- try:
87
- # Attempt to convert embedding list back to numpy array
88
- np_embedding = np.array(m["embedding"])
89
- if np_embedding.shape == (embedding_model.get_sentence_embedding_dimension(),): # Check dimension
90
- valid_memory_with_embeddings.append((m, np_embedding))
91
- except Exception as conv_e:
92
- logging.warning(f"Could not convert embedding for memory entry: {m['content'][:50]}... Error: {conv_e}")
93
- pass # Skip this memory entry if embedding is invalid
94
-
95
- if not valid_memory_with_embeddings:
 
 
 
 
 
96
  return []
97
 
98
- memory_items, memory_embeddings = zip(*valid_memory_with_embeddings)
99
-
100
  # Calculate similarities
101
- similarities = cosine_similarity([user_embedding], list(memory_embeddings))[0]
 
102
 
103
  # Sort memory by similarity and return the top-k messages
104
  relevant_messages_sorted = sorted(zip(similarities, memory_items), key=lambda x: x[0], reverse=True)
@@ -111,7 +168,72 @@ def retrieve_relevant_memory(mem_list, user_input, top_k=5):
111
  return []
112
 
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  def construct_prompt(mem_list, user_input, max_tokens_in_prompt=1000): # Increased max tokens slightly
 
115
  """
116
  Construct the list of messages suitable for the Groq API's 'messages' parameter
117
  by combining relevant memory and the current user input.
@@ -120,7 +242,7 @@ def construct_prompt(mem_list, user_input, max_tokens_in_prompt=1000): # Increas
120
  # Retrieve relevant memory *content* based on similarity
121
  relevant_memory_items = retrieve_relevant_memory(mem_list, user_input)
122
  # Create a set of content strings from the relevant items for quick lookup
123
- relevant_content_set = {m["content"] for m in relevant_memory_items}
124
 
125
  messages_for_api = []
126
  # Add a system message
@@ -128,14 +250,14 @@ def construct_prompt(mem_list, user_input, max_tokens_in_prompt=1000): # Increas
128
 
129
  current_prompt_tokens = len(messages_for_api[0]["content"].split()) # Start count with system message
130
 
131
- # Iterate through chronological session memory and add relevant messages
132
  context_messages = []
133
  for msg in mem_list:
134
  # Only add messages whose content is found in the top-k relevant messages
135
  # and which have a role suitable for the API messages list
136
- if msg["content"] in relevant_content_set and msg["role"] in ["user", "assistant", "system"]:
137
- # Estimate tokens for this message (simple word count)
138
- msg_text = f'{msg["role"]}: {msg["content"]}\n' # Estimate based on formatted string length
139
  msg_tokens = len(msg_text.split())
140
  if current_prompt_tokens + msg_tokens > max_tokens_in_prompt:
141
  break # Stop if adding this message exceeds the limit
@@ -144,6 +266,7 @@ def construct_prompt(mem_list, user_input, max_tokens_in_prompt=1000): # Increas
144
  context_messages.append({"role": msg["role"], "content": msg["content"]})
145
  current_prompt_tokens += msg_tokens
146
 
 
147
  # Add the chronological context messages
148
  messages_for_api.extend(context_messages)
149
 
@@ -151,9 +274,7 @@ def construct_prompt(mem_list, user_input, max_tokens_in_prompt=1000): # Increas
151
  # Ensure user input itself doesn't push over the limit significantly (though it should always be included)
152
  user_input_tokens = len(user_input.split())
153
  if current_prompt_tokens + user_input_tokens > max_tokens_in_prompt and len(messages_for_api) > 1:
154
- # If user input pushes over, and there's existing context, log a warning
155
- logging.warning(f"User input exceeds max_tokens_in_prompt with existing context. Truncating context.")
156
- # In a real scenario, you might trim context from the beginning here
157
  pass # User input is always added
158
 
159
  messages_for_api.append({"role": "user", "content": user_input})
@@ -162,6 +283,7 @@ def construct_prompt(mem_list, user_input, max_tokens_in_prompt=1000): # Increas
162
 
163
 
164
  def trim_memory(mem_list, max_size=50):
 
165
  """
166
  Trim the memory list to keep it within the specified max size.
167
  Removes the oldest entries (from the beginning of the list).
@@ -171,138 +293,51 @@ def trim_memory(mem_list, max_size=50):
171
  mem_list.pop(0) # Remove the oldest entry
172
  return mem_list
173
 
174
- # The summarize_memory function is defined but not used in the current web chat loop.
175
- # Keeping it here for completeness.
176
  def summarize_memory(mem_list):
 
177
  """
178
  Summarize the memory buffer to free up space.
179
- This would typically replace the detailed memory with a summary entry.
180
- Needs Groq client and memory list as input.
181
  """
182
  if not mem_list or client is None:
183
  logging.warning("Memory is empty or Groq client not initialized. Cannot summarize.")
184
- return [] # Return empty list or original list? Let's return an empty list + summary
185
 
186
- long_term_memory = " ".join([m["content"] for m in mem_list if "content" in m]) # Add check for content key
187
- if not long_term_memory.strip(): # Check if memory is empty or just whitespace after joining
188
  logging.warning("Memory content is empty. Cannot summarize.")
189
  return []
190
 
191
  try:
192
  summary_completion = client.chat.completions.create(
193
- # Use a currently available Groq model for summarization
194
- model="llama-3.1-8b-instruct-fpt", # Or "llama-3.1-70b-versatile", etc. Check Groq docs.
195
  messages=[
196
  {"role": "system", "content": "Summarize the following conversation for key points. Keep it concise."},
197
  {"role": "user", "content": long_term_memory},
198
  ],
199
- max_tokens= 500, # Limit summary length
200
  )
201
- # Access the content correctly from the message object
202
  summary_text = summary_completion.choices[0].message.content
203
  logging.info("Memory summarized.")
204
- # Replace detailed memory with summary
205
- # Embedding for summary isn't strictly needed for retrieval of detailed conversation, but could be added.
206
- # For simplicity, we'll store it without an embedding here.
207
- return [{"role": "system", "content": f"Previous conversation summary: {summary_text}"}] # Embedding is less relevant for a summary entry
208
  except Exception as e:
209
  logging.error(f"Error summarizing memory: {e}")
210
- return mem_list # Return original memory on failure
211
-
212
 
213
- # --- Flask Routes ---
214
 
215
  @app.route('/')
216
  def index():
 
217
  """
218
  Serve the main chat interface page.
219
  """
220
- # Initialize memory in session if it doesn't exist
221
  if 'chat_memory' not in session:
222
  session['chat_memory'] = []
223
  return render_template('index.html')
224
 
225
- @app.route('/chat', methods=['POST'])
226
- def chat():
227
- """
228
- Handle incoming chat messages, process with the bot logic,
229
- update session memory, and return the AI response.
230
- """
231
- if client is None or embedding_model is None:
232
- # Check if API key was missing or model failed to load at startup
233
- status_code = 500
234
- error_message = "Chatbot backend is not fully initialized (API key or embedding model missing)."
235
- logging.error(error_message)
236
- return jsonify({"response": error_message}), status_code
237
-
238
-
239
- user_input = request.json.get('message')
240
- if not user_input or not user_input.strip():
241
- return jsonify({"response": "Please enter a message."}), 400
242
-
243
- # Get memory from the session
244
- # Session data needs to be JSON serializable, embeddings are numpy arrays
245
- # We stored them as lists, retrieve_relevant_memory expects numpy. Handle conversion.
246
- current_memory_serializable = session.get('chat_memory', [])
247
- # Create a temporary list that converts embedding lists back to numpy for processing
248
- current_memory_for_processing = []
249
- for entry in current_memory_serializable:
250
- temp_entry = entry.copy() # Copy to avoid modifying session directly before commit
251
- if "embedding" in temp_entry and isinstance(temp_entry["embedding"], list):
252
- try:
253
- temp_entry["embedding"] = np.array(temp_entry["embedding"])
254
- current_memory_for_processing.append(temp_entry)
255
- except Exception as conv_e:
256
- logging.warning(f"Failed to convert session embedding to numpy: {conv_e}")
257
- # Skip this entry or handle error
258
- pass # Just skip for now
259
-
260
- # Construct prompt using relevant memory from the current session memory
261
- # The construct_prompt function returns a list of messages for the API
262
- messages_for_api = construct_prompt(current_memory_for_processing, user_input)
263
-
264
- try:
265
- # Get response from the model
266
- completion = client.chat.completions.create(
267
- model="llama-3.1-8b-instruct-fpt", # Use a suitable, available model
268
- messages=messages_for_api, # Pass the list of messages
269
- temperature=0.6,
270
- max_tokens=1024, # Limit response length
271
- top_p=0.95,
272
- stream=False, # Disable streaming for simpler HTTP response handling
273
- stop=None,
274
- )
275
- ai_response_content = completion.choices[0].message.content # Access content correctly
276
-
277
- except Exception as e:
278
- logging.error(f"Error calling Groq API: {e}")
279
- # Provide a user-friendly error message
280
- ai_response_content = "Sorry, I encountered an error when trying to respond. Please try again later."
281
- # Optionally clear memory on API error if it might be corrupted
282
- # session['chat_memory'] = [] # Decide if you want to clear on error
283
-
284
-
285
- # --- Update Memory Buffer in Session ---
286
- # Use the original serializable memory list to add new entries
287
- # The add_to_memory function now returns the updated list
288
- current_memory_serializable = add_to_memory(current_memory_serializable, "user", user_input)
289
- current_memory_serializable = add_to_memory(current_memory_serializable, "assistant", ai_response_content)
290
-
291
- # Optionally trim memory to keep it manageable (e.g., last 20 turns)
292
- # You might want a larger size for better memory recall
293
- current_memory_serializable = trim_memory(current_memory_serializable, max_size=20)
294
-
295
- # Store the updated memory back into the session
296
- # Ensure embeddings are lists when stored
297
- session['chat_memory'] = current_memory_serializable
298
-
299
- # Return the AI response as JSON
300
- return jsonify({"response": ai_response_content})
301
-
302
-
303
- # You can add a route to clear memory if needed (e.g., a "Start New Chat" button)
304
  @app.route('/clear_memory', methods=['POST'])
305
  def clear_memory():
 
306
  """
307
  Clear the chat memory from the session.
308
  """
@@ -314,7 +349,5 @@ def clear_memory():
314
  # --- Running the App ---
315
  if __name__ == '__main__':
316
  logging.info("Starting Waitress server...")
317
- # --- IMPORTANT: Use port 7860 for Hugging Face Spaces Docker SDK ---
318
- # Use the PORT environment variable if set, otherwise default to 7860
319
  port = int(os.environ.get('PORT', 7860))
320
  serve(app, host='0.0.0.0', port=port)
 
1
  import os
2
  from flask import Flask, render_template, request, jsonify, session
3
+ # Removed SentenceTransformer
4
+ # from sentence_transformers import SentenceTransformer
5
  from sklearn.metrics.pairwise import cosine_similarity
6
  from groq import Groq
7
  import numpy as np
8
+ import logging
9
+ # Import necessary components from transformers and torch
10
+ from transformers import AutoTokenizer, AutoModel
11
+ import torch
12
+ import torch.nn.functional as F # For normalization
13
+ # Ensure torch is using CPU if GPU is not available (standard for free tier)
14
+ torch.set_num_threads(1) # Limit threads for resource efficiency
15
+ if torch.cuda.is_available():
16
+ device = torch.device("cuda")
17
+ else:
18
+ device = torch.device("cpu")
19
+ logging.info(f"Using device: {device}")
20
+
21
 
22
  # Configure logging
23
  logging.basicConfig(level=logging.INFO)
24
 
25
+ # --- Initialize Models (Load these once using transformers) ---
26
+ tokenizer = None
27
+ model = None
28
+ client = None
29
+
30
  try:
31
+ # Load tokenizer and model from HuggingFace Hub using transformers
32
+ tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
33
+ model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2').to(device) # Move model to device
34
+ logging.info("Tokenizer and AutoModel loaded successfully.")
35
  except Exception as e:
36
+ logging.error(f"Error loading Transformer models: {e}")
37
+ # Models are None, will be handled below
38
 
39
  # Initialize the Groq client
 
40
  groq_api_key = os.environ.get("GROQ_API_KEY")
41
  if not groq_api_key:
42
  logging.error("GROQ_API_KEY environment variable not set.")
 
48
  logging.info("Groq client initialized.")
49
 
50
 
51
+ # --- Helper function for Mean Pooling (from documentation) ---
52
+ # Mean Pooling - Take attention mask into account for correct averaging
53
+ def mean_pooling(model_output, attention_mask):
54
+ token_embeddings = model_output[0] #First element of model_output contains all token embeddings
55
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
56
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
57
+
58
+
59
+ # --- Function to get embedding using transformers and pooling ---
60
+ def get_embedding(text):
61
+ """
62
+ Generate embedding for a single text using transformers and mean pooling.
63
+ Returns a numpy array.
64
+ """
65
+ if tokenizer is None or model is None:
66
+ logging.error("Embedding models not loaded. Cannot generate embedding.")
67
+ return None
68
+ try:
69
+ # Tokenize the input text
70
+ encoded_input = tokenizer(text, padding=True, truncation=True, return_tensors='pt').to(device) # Move input to device
71
 
72
+ # Compute token embeddings
73
+ with torch.no_grad(): # Disable gradient calculation for inference
74
+ model_output = model(**encoded_input)
75
 
76
+ # Perform pooling
77
+ sentence_embedding = mean_pooling(model_output, encoded_input['attention_mask'])
78
+
79
+ # Normalize embeddings
80
+ sentence_embedding = F.normalize(sentence_embedding, p=2, dim=1)
81
+
82
+ # Convert to numpy and return
83
+ return sentence_embedding.cpu().numpy()[0] # Move back to CPU and get the single embedding array
84
+
85
+ except Exception as e:
86
+ logging.error(f"Error generating embedding: {e}")
87
+ return None
88
+
89
+
90
+ # --- Memory Management Functions (Adapted for Sessions and new embedding method) ---
91
 
92
  def add_to_memory(mem_list, role, content):
93
  """
94
  Add a message to the provided memory list along with its embedding.
95
  Returns the updated list.
96
  """
97
+ # Ensure content is not empty
98
+ if not content or not content.strip():
99
+ logging.warning(f"Attempted to add empty content to memory for role: {role}")
100
+ return mem_list # Do not add empty messages
 
 
 
 
101
 
102
+ embedding = get_embedding(content) # Use the new get_embedding function
103
+
104
+ if embedding is not None:
105
+ mem_list.append({"role": role, "content": content, "embedding": embedding.tolist()}) # Store embedding as list for JSON serializability
106
+ else:
107
+ # Add message without embedding if embedding failed
108
+ logging.warning(f"Failed to get embedding for message: {content[:50]}...")
109
+ mem_list.append({"role": role, "content": content, "embedding": None}) # Store None for embedding
110
+
111
+ return mem_list
112
 
113
 
114
  def retrieve_relevant_memory(mem_list, user_input, top_k=5):
 
117
  based on cosine similarity with user_input.
118
  Returns a list of relevant messages (dictionaries).
119
  """
120
+ # Ensure we have valid memory entries with embeddings and the necessary models
121
+ valid_memory_with_embeddings = [m for m in mem_list if m.get("embedding") is not None]
122
+
123
+ if not valid_memory_with_embeddings:
124
  return []
125
 
126
  try:
127
+ # Compute the embedding of the user input using the new function
128
+ user_embedding = get_embedding(user_input)
129
+
130
+ if user_embedding is None:
131
+ logging.error("Failed to get user input embedding for retrieval.")
132
+ return [] # Cannot retrieve if user embedding fails
133
 
134
  # Calculate similarities. Ensure all memory entries have valid embeddings.
135
+ memory_items = []
136
+ memory_embeddings = []
137
+ for m in valid_memory_with_embeddings:
138
+ try:
139
+ # Attempt to convert embedding list back to numpy array
140
+ np_embedding = np.array(m["embedding"])
141
+ # Optional: Check dimension if known (e.g., 384 for all-MiniLM-L6-v2)
142
+ if np_embedding.shape == (model.config.hidden_size,): # Check dimension based on loaded model config
143
+ memory_items.append(m)
144
+ memory_embeddings.append(np_embedding)
145
+ else:
146
+ logging.warning(f"Embedding dimension mismatch for memory entry: {m['content'][:50]}...")
147
+
148
+ except Exception as conv_e:
149
+ logging.warning(f"Could not convert embedding for memory entry: {m['content'][:50]}... Error: {conv_e}")
150
+ pass # Skip this memory entry if embedding is invalid or conversion fails
151
+
152
+
153
+ if not memory_items: # Check again after filtering
154
  return []
155
 
 
 
156
  # Calculate similarities
157
+ # Ensure both are numpy arrays
158
+ similarities = cosine_similarity([user_embedding], np.array(memory_embeddings))[0]
159
 
160
  # Sort memory by similarity and return the top-k messages
161
  relevant_messages_sorted = sorted(zip(similarities, memory_items), key=lambda x: x[0], reverse=True)
 
168
  return []
169
 
170
 
171
+ # construct_prompt, trim_memory, summarize_memory, index, chat, clear_memory routes
172
+ # and the final if __name__ == '__main__': block remain largely the same,
173
+ # except they now rely on the global `tokenizer` and `model` being initialized
174
+ # and call the new `get_embedding` function internally.
175
+
176
+ # Ensure the check in the chat route verifies tokenizer and model are not None
177
+ @app.route('/chat', methods=['POST'])
178
+ def chat():
179
+ """
180
+ Handle incoming chat messages, process with the bot logic,
181
+ update session memory, and return the AI response.
182
+ """
183
+ # Check if Groq client AND embedding models are initialized
184
+ if client is None or tokenizer is None or model is None:
185
+ status_code = 500
186
+ error_message = "Chatbot backend is not fully initialized (API key or embedding models missing)."
187
+ logging.error(error_message)
188
+ return jsonify({"response": error_message}), status_code
189
+
190
+ # ... (rest of the chat function is the same) ...
191
+ user_input = request.json.get('message')
192
+ if not user_input or not user_input.strip():
193
+ return jsonify({"response": "Please enter a message."}), 400
194
+
195
+ current_memory_serializable = session.get('chat_memory', [])
196
+ # No need to convert embeddings to numpy here, construct_prompt does it if needed via retrieve_relevant_memory
197
+
198
+ messages_for_api = construct_prompt(current_memory_serializable, user_input)
199
+
200
+ try:
201
+ # Get response from the model
202
+ completion = client.chat.completions.create(
203
+ model="llama-3.1-8b-instruct-fpt", # Use a suitable, available model
204
+ messages=messages_for_api, # Pass the list of messages
205
+ temperature=0.6,
206
+ max_tokens=1024, # Limit response length
207
+ top_p=0.95,
208
+ stream=False, # Disable streaming for simpler HTTP response handling
209
+ stop=None,
210
+ )
211
+ ai_response_content = completion.choices[0].message.content
212
+
213
+ except Exception as e:
214
+ logging.error(f"Error calling Groq API: {e}")
215
+ ai_response_content = "Sorry, I encountered an error when trying to respond. Please try again later."
216
+
217
+
218
+ # Update Memory Buffer (get_embedding is called within add_to_memory)
219
+ current_memory_serializable = add_to_memory(current_memory_serializable, "user", user_input)
220
+ current_memory_serializable = add_to_memory(current_memory_serializable, "assistant", ai_response_content)
221
+
222
+ # Trim Memory
223
+ current_memory_serializable = trim_memory(current_memory_serializable, max_size=20)
224
+
225
+ # Store updated memory back into the session
226
+ session['chat_memory'] = current_memory_serializable
227
+
228
+ return jsonify({"response": ai_response_content})
229
+
230
+
231
+ # The construct_prompt, trim_memory, summarize_memory, index, clear_memory functions are mostly unchanged,
232
+ # but they now rely on the global `tokenizer` and `model` being available.
233
+ # construct_prompt calls retrieve_relevant_memory which calls get_embedding.
234
+
235
  def construct_prompt(mem_list, user_input, max_tokens_in_prompt=1000): # Increased max tokens slightly
236
+ # ... (This function remains the same as before, it calls retrieve_relevant_memory) ...
237
  """
238
  Construct the list of messages suitable for the Groq API's 'messages' parameter
239
  by combining relevant memory and the current user input.
 
242
  # Retrieve relevant memory *content* based on similarity
243
  relevant_memory_items = retrieve_relevant_memory(mem_list, user_input)
244
  # Create a set of content strings from the relevant items for quick lookup
245
+ relevant_content_set = {m["content"] for m in relevant_memory_items if "content" in m} # Added content check
246
 
247
  messages_for_api = []
248
  # Add a system message
 
250
 
251
  current_prompt_tokens = len(messages_for_api[0]["content"].split()) # Start count with system message
252
 
253
+ # Iterate through chronological session memory and add relevant messages that are also in the relevant_content_set
254
  context_messages = []
255
  for msg in mem_list:
256
  # Only add messages whose content is found in the top-k relevant messages
257
  # and which have a role suitable for the API messages list
258
+ if "content" in msg and msg["content"] in relevant_content_set and msg["role"] in ["user", "assistant", "system"]:
259
+ # Estimate tokens for this message (simple word count)
260
+ msg_text = f'{msg["role"]}: {msg["content"]}\n'
261
  msg_tokens = len(msg_text.split())
262
  if current_prompt_tokens + msg_tokens > max_tokens_in_prompt:
263
  break # Stop if adding this message exceeds the limit
 
266
  context_messages.append({"role": msg["role"], "content": msg["content"]})
267
  current_prompt_tokens += msg_tokens
268
 
269
+
270
  # Add the chronological context messages
271
  messages_for_api.extend(context_messages)
272
 
 
274
  # Ensure user input itself doesn't push over the limit significantly (though it should always be included)
275
  user_input_tokens = len(user_input.split())
276
  if current_prompt_tokens + user_input_tokens > max_tokens_in_prompt and len(messages_for_api) > 1:
277
+ logging.warning(f"User input exceeds max_tokens_in_prompt with existing context. Context may be truncated.")
 
 
278
  pass # User input is always added
279
 
280
  messages_for_api.append({"role": "user", "content": user_input})
 
283
 
284
 
285
  def trim_memory(mem_list, max_size=50):
286
+ # ... (This function is unchanged) ...
287
  """
288
  Trim the memory list to keep it within the specified max size.
289
  Removes the oldest entries (from the beginning of the list).
 
293
  mem_list.pop(0) # Remove the oldest entry
294
  return mem_list
295
 
 
 
296
  def summarize_memory(mem_list):
297
+ # ... (This function is unchanged, relies on global client) ...
298
  """
299
  Summarize the memory buffer to free up space.
 
 
300
  """
301
  if not mem_list or client is None:
302
  logging.warning("Memory is empty or Groq client not initialized. Cannot summarize.")
303
+ return []
304
 
305
+ long_term_memory = " ".join([m["content"] for m in mem_list if "content" in m])
306
+ if not long_term_memory.strip():
307
  logging.warning("Memory content is empty. Cannot summarize.")
308
  return []
309
 
310
  try:
311
  summary_completion = client.chat.completions.create(
312
+ model="llama-3.1-8b-instruct-fpt",
 
313
  messages=[
314
  {"role": "system", "content": "Summarize the following conversation for key points. Keep it concise."},
315
  {"role": "user", "content": long_term_memory},
316
  ],
317
+ max_tokens= 500,
318
  )
 
319
  summary_text = summary_completion.choices[0].message.content
320
  logging.info("Memory summarized.")
321
+ # When replacing with summary, the embedding logic becomes less relevant for this entry type
322
+ return [{"role": "system", "content": f"Previous conversation summary: {summary_text}"}]
 
 
323
  except Exception as e:
324
  logging.error(f"Error summarizing memory: {e}")
325
+ return mem_list
 
326
 
 
327
 
328
  @app.route('/')
329
  def index():
330
+ # ... (This route is unchanged) ...
331
  """
332
  Serve the main chat interface page.
333
  """
 
334
  if 'chat_memory' not in session:
335
  session['chat_memory'] = []
336
  return render_template('index.html')
337
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  @app.route('/clear_memory', methods=['POST'])
339
  def clear_memory():
340
+ # ... (This route is unchanged) ...
341
  """
342
  Clear the chat memory from the session.
343
  """
 
349
  # --- Running the App ---
350
  if __name__ == '__main__':
351
  logging.info("Starting Waitress server...")
 
 
352
  port = int(os.environ.get('PORT', 7860))
353
  serve(app, host='0.0.0.0', port=port)