"""Flask chatbot backend.

Keeps per-user chat memory in the Flask session, embeds messages with a
sentence-transformer model (raw ``transformers`` + mean pooling), retrieves
the most relevant past messages by cosine similarity, and answers with the
Groq chat-completions API. Served with Waitress.
"""

import logging
import os

import numpy as np
import torch
import torch.nn.functional as F  # for embedding normalization
from flask import Flask, jsonify, render_template, request, session
from groq import Groq
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer
from waitress import serve  # FIX: serve() was called but never imported

# Configure logging before the first logging call so startup messages appear.
logging.basicConfig(level=logging.INFO)

# Limit threads for resource efficiency (standard for small/free-tier hosts).
torch.set_num_threads(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")

# FIX: the original used @app.route without ever creating the Flask app.
app = Flask(__name__)
# FIX: Flask sessions require a secret key. Prefer a stable key from the
# environment; the os.urandom fallback invalidates sessions on each restart.
app.secret_key = os.environ.get("FLASK_SECRET_KEY") or os.urandom(24)

# --- Initialize models (load these once, at import time) ---
tokenizer = None
model = None
try:
    # Load tokenizer and model from the HuggingFace Hub using transformers.
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2").to(device)
    logging.info("Tokenizer and AutoModel loaded successfully.")
except Exception as e:
    # Leave tokenizer/model as None; the routes check for this and fail softly.
    logging.error(f"Error loading Transformer models: {e}")

# Initialize the Groq client (None when the API key is missing, so routes can
# report a clear error instead of crashing at import time).
groq_api_key = os.environ.get("GROQ_API_KEY")
if not groq_api_key:
    logging.error("GROQ_API_KEY environment variable not set.")
    client = None
else:
    client = Groq(api_key=groq_api_key)
    logging.info("Groq client initialized.")


def mean_pooling(model_output, attention_mask):
    """Mean-pool token embeddings into one sentence vector.

    Takes the attention mask into account so padding tokens do not skew the
    average (the standard sentence-transformers pooling recipe).
    """
    # First element of model_output contains all token embeddings.
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )


def get_embedding(text):
    """Generate a normalized embedding for a single text.

    Returns a 1-D numpy array, or None when the models are unavailable or
    embedding fails.
    """
    if tokenizer is None or model is None:
        logging.error("Embedding models not loaded. Cannot generate embedding.")
        return None
    try:
        encoded_input = tokenizer(
            text, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        with torch.no_grad():  # inference only -- no gradients needed
            model_output = model(**encoded_input)
        sentence_embedding = mean_pooling(model_output, encoded_input["attention_mask"])
        # L2-normalize so cosine similarity reduces to a dot product.
        sentence_embedding = F.normalize(sentence_embedding, p=2, dim=1)
        # Move back to CPU and return the single embedding as a numpy array.
        return sentence_embedding.cpu().numpy()[0]
    except Exception as e:
        logging.error(f"Error generating embedding: {e}")
        return None


def add_to_memory(mem_list, role, content):
    """Append a message (with its embedding) to the memory list and return it.

    Embeddings are stored as plain lists so the memory stays JSON-serializable
    for the Flask session. Empty messages are dropped.
    """
    if not content or not content.strip():
        logging.warning(f"Attempted to add empty content to memory for role: {role}")
        return mem_list  # do not add empty messages
    embedding = get_embedding(content)
    if embedding is not None:
        mem_list.append({"role": role, "content": content, "embedding": embedding.tolist()})
    else:
        # Keep the message even when embedding failed; retrieval skips it.
        logging.warning(f"Failed to get embedding for message: {content[:50]}...")
        mem_list.append({"role": role, "content": content, "embedding": None})
    return mem_list


def retrieve_relevant_memory(mem_list, user_input, top_k=5):
    """Return the top-k memory entries most similar to user_input.

    Uses cosine similarity between stored embeddings and the embedding of the
    current input. Entries without a valid, correctly-sized embedding are
    skipped. Returns a list of message dicts (possibly empty).
    """
    valid_memory_with_embeddings = [m for m in mem_list if m.get("embedding") is not None]
    if not valid_memory_with_embeddings:
        return []
    try:
        user_embedding = get_embedding(user_input)
        if user_embedding is None:
            logging.error("Failed to get user input embedding for retrieval.")
            return []  # cannot retrieve without a query embedding

        memory_items = []
        memory_embeddings = []
        for m in valid_memory_with_embeddings:
            try:
                np_embedding = np.array(m["embedding"])
                # Reject stale entries whose dimension no longer matches the
                # loaded model (e.g. after a model swap).
                if np_embedding.shape == (model.config.hidden_size,):
                    memory_items.append(m)
                    memory_embeddings.append(np_embedding)
                else:
                    logging.warning(
                        f"Embedding dimension mismatch for memory entry: {m['content'][:50]}..."
                    )
            except Exception as conv_e:
                logging.warning(
                    f"Could not convert embedding for memory entry: "
                    f"{m['content'][:50]}... Error: {conv_e}"
                )
        if not memory_items:  # everything was filtered out
            return []

        similarities = cosine_similarity([user_embedding], np.array(memory_embeddings))[0]
        # Rank by similarity, highest first, and return the message dicts.
        ranked = sorted(zip(similarities, memory_items), key=lambda x: x[0], reverse=True)
        return [item for _, item in ranked[:top_k]]
    except Exception as e:
        logging.error(f"Error retrieving memory: {e}")
        return []


@app.route("/chat", methods=["POST"])
def chat():
    """Handle an incoming chat message and return the AI response as JSON.

    Updates the session-stored memory with both the user message and the
    assistant reply, trimming it to a bounded size.
    """
    # Fail fast when the Groq client or embedding models are not initialized.
    if client is None or tokenizer is None or model is None:
        error_message = (
            "Chatbot backend is not fully initialized (API key or embedding models missing)."
        )
        logging.error(error_message)
        return jsonify({"response": error_message}), 500

    user_input = request.json.get("message")
    if not user_input or not user_input.strip():
        return jsonify({"response": "Please enter a message."}), 400

    # Session memory is stored JSON-serializable; retrieval converts
    # embeddings back to numpy internally.
    current_memory_serializable = session.get("chat_memory", [])
    messages_for_api = construct_prompt(current_memory_serializable, user_input)

    try:
        completion = client.chat.completions.create(
            # NOTE(review): confirm this model id is still available on Groq.
            model="llama-3.1-8b-instruct-fpt",
            messages=messages_for_api,
            temperature=0.6,
            max_tokens=1024,  # limit response length
            top_p=0.95,
            stream=False,  # simpler HTTP response handling than streaming
            stop=None,
        )
        ai_response_content = completion.choices[0].message.content
    except Exception as e:
        logging.error(f"Error calling Groq API: {e}")
        ai_response_content = (
            "Sorry, I encountered an error when trying to respond. Please try again later."
        )

    # Update the memory buffer (embeddings are computed inside add_to_memory).
    current_memory_serializable = add_to_memory(current_memory_serializable, "user", user_input)
    current_memory_serializable = add_to_memory(
        current_memory_serializable, "assistant", ai_response_content
    )
    current_memory_serializable = trim_memory(current_memory_serializable, max_size=20)
    session["chat_memory"] = current_memory_serializable

    return jsonify({"response": ai_response_content})


def construct_prompt(mem_list, user_input, max_tokens_in_prompt=1000):
    """Build the Groq `messages` list: system prompt + relevant history + input.

    Relevant history is selected by embedding similarity but appended in its
    original chronological order, subject to a rough word-count token budget.
    The current user input is always appended last.
    """
    relevant_memory_items = retrieve_relevant_memory(mem_list, user_input)
    # Content strings of the top-k relevant items, for O(1) membership tests.
    relevant_content_set = {m["content"] for m in relevant_memory_items if "content" in m}

    messages_for_api = [
        {"role": "system", "content": "You are a helpful and friendly AI assistant."}
    ]
    # Rough token estimate: whitespace-separated word count.
    current_prompt_tokens = len(messages_for_api[0]["content"].split())

    # Walk the chronological memory, keeping only messages that made the
    # relevance cut and have an API-compatible role.
    context_messages = []
    for msg in mem_list:
        if (
            "content" in msg
            and msg["content"] in relevant_content_set
            and msg["role"] in ["user", "assistant", "system"]
        ):
            msg_text = f'{msg["role"]}: {msg["content"]}\n'
            msg_tokens = len(msg_text.split())
            if current_prompt_tokens + msg_tokens > max_tokens_in_prompt:
                break  # budget exhausted; drop the rest of the history
            context_messages.append({"role": msg["role"], "content": msg["content"]})
            current_prompt_tokens += msg_tokens

    messages_for_api.extend(context_messages)

    # The current user input is always included, even if it blows the budget;
    # we only warn so the operator can tune max_tokens_in_prompt.
    user_input_tokens = len(user_input.split())
    if current_prompt_tokens + user_input_tokens > max_tokens_in_prompt and len(messages_for_api) > 1:
        logging.warning(
            "User input exceeds max_tokens_in_prompt with existing context. "
            "Context may be truncated."
        )
    messages_for_api.append({"role": "user", "content": user_input})
    return messages_for_api


def trim_memory(mem_list, max_size=50):
    """Trim the memory list in place to at most max_size entries.

    Removes the oldest entries (from the front) and returns the same list.
    """
    excess = len(mem_list) - max_size
    if excess > 0:
        # One front-slice delete instead of repeated O(n) pop(0) calls.
        del mem_list[:excess]
    return mem_list


def summarize_memory(mem_list):
    """Collapse the memory into a single system-message summary via Groq.

    Returns a one-element list with the summary message on success, [] when
    there is nothing to summarize, or the original list if the API call fails.
    """
    if not mem_list or client is None:
        logging.warning("Memory is empty or Groq client not initialized. Cannot summarize.")
        return []
    long_term_memory = " ".join([m["content"] for m in mem_list if "content" in m])
    if not long_term_memory.strip():
        logging.warning("Memory content is empty. Cannot summarize.")
        return []
    try:
        summary_completion = client.chat.completions.create(
            model="llama-3.1-8b-instruct-fpt",
            messages=[
                {
                    "role": "system",
                    "content": "Summarize the following conversation for key points. Keep it concise.",
                },
                {"role": "user", "content": long_term_memory},
            ],
            max_tokens=500,
        )
        summary_text = summary_completion.choices[0].message.content
        logging.info("Memory summarized.")
        # Summary entries carry no embedding; retrieval simply skips them.
        return [{"role": "system", "content": f"Previous conversation summary: {summary_text}"}]
    except Exception as e:
        logging.error(f"Error summarizing memory: {e}")
        return mem_list


@app.route("/")
def index():
    """Serve the main chat interface page, initializing session memory."""
    if "chat_memory" not in session:
        session["chat_memory"] = []
    return render_template("index.html")


@app.route("/clear_memory", methods=["POST"])
def clear_memory():
    """Clear the chat memory from the session."""
    session["chat_memory"] = []
    logging.info("Chat memory cleared.")
    return jsonify({"status": "Memory cleared."})


# --- Running the App ---
if __name__ == "__main__":
    logging.info("Starting Waitress server...")
    port = int(os.environ.get("PORT", 7860))
    serve(app, host="0.0.0.0", port=port)