import os
import logging

from flask import Flask, render_template, request, jsonify, session
from sklearn.metrics.pairwise import cosine_similarity
from groq import Groq
import numpy as np
from transformers import AutoTokenizer, AutoModel  # Keep these
import torch
import torch.nn.functional as F

# Configure logging
logging.basicConfig(level=logging.INFO)

# HuggingFace checkpoint used to embed chat messages for memory retrieval.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Groq chat model used for both replies and summaries. Can be overridden with
# the GROQ_MODEL environment variable; the default is the id this app has
# always used.
# NOTE(review): confirm this id exists in Groq's published model list.
GROQ_MODEL = os.environ.get("GROQ_MODEL", "llama-3.1-8b-instruct-fpt")

# Maximum number of messages kept in the per-session chat memory.
MAX_MEMORY_MESSAGES = 20

# Reply returned (with HTTP 200) when the Groq completion call fails.
_GROQ_ERROR_REPLY = (
    "Sorry, I encountered an error when trying to respond. Please try again later."
)

# --- Flask App Setup --- (MUST come before routes or app-dependent code) ---
app = Flask(__name__)
_DEFAULT_SECRET_KEY = "a_default_secret_key_please_change"
app.config["SECRET_KEY"] = os.environ.get("SECRET_KEY", _DEFAULT_SECRET_KEY)
if app.config["SECRET_KEY"] == _DEFAULT_SECRET_KEY:
    # The session (including chat_memory) is signed with this key. With the
    # publicly known default, clients can forge session contents.
    logging.warning("SECRET_KEY is not set; using an insecure default key.")

# --- Initialize Models ---
# CPU is expected (free tier); CUDA is used only if a GPU happens to be present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")


def _load_embedding_models():
    """Load the tokenizer and encoder for EMBEDDING_MODEL_NAME onto ``device``.

    Returns:
        A ``(tokenizer, model)`` pair, or ``(None, None)`` if loading fails
        for any reason. The failure is logged, not raised.
    """
    try:
        tok = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_NAME)
        # NOTE(review): from_tf=True converts a TensorFlow checkpoint and
        # presumably needs TensorFlow installed; confirm it is still required.
        mdl = AutoModel.from_pretrained(EMBEDDING_MODEL_NAME, from_tf=True).to(device)
    except Exception as e:
        logging.error(f"Error loading Transformer models: {e}")
        return None, None
    logging.info("Tokenizer and AutoModel loaded successfully with from_tf=True.")
    return tok, mdl


tokenizer, model = _load_embedding_models()

# Initialize the Groq client (None when no API key is configured).
groq_api_key = os.environ.get("GROQ_API_KEY")
if not groq_api_key:
    logging.error("GROQ_API_KEY environment variable not set.")
    client = None
else:
    client = Groq(api_key=groq_api_key)
    logging.info("Groq client initialized.")


# --- Helper function for Mean Pooling ---
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions where attention_mask is set.

    Args:
        model_output: encoder output whose first element holds the token
            embeddings; they must have the shape of ``attention_mask`` plus a
            trailing hidden dimension (required by the ``expand`` below).
        attention_mask: 0/1 mask of real (non-padding) tokens.

    Returns:
        One pooled vector per row of ``attention_mask``.
    """
    token_embeddings = model_output[0]
    input_mask_expanded = (
        attention_mask.unsqueeze(-1)
        .expand(token_embeddings.size())
        .float()
        .to(token_embeddings.device)
    )
    summed = torch.sum(token_embeddings * input_mask_expanded, 1)
    # Clamp avoids division by zero for a row with no unmasked tokens.
    counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return summed / counts


# --- Function to get embedding ---
def get_embedding(text):
    """Return the L2-normalized, mean-pooled embedding of *text*.

    Returns:
        A 1-D numpy array, or None if the models are not loaded or
        encoding fails. Failures are logged, not raised.
    """
    if tokenizer is None or model is None:
        logging.error("Embedding models not loaded. Cannot generate embedding.")
        return None
    try:
        encoded_input = tokenizer(
            text, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            model_output = model(**encoded_input)
        sentence_embedding = mean_pooling(model_output, encoded_input["attention_mask"])
        sentence_embedding = F.normalize(sentence_embedding, p=2, dim=1)
        # A single input text yields a batch of one; return that row.
        return sentence_embedding.cpu().numpy()[0]
    except Exception as e:
        logging.error(f"Error generating embedding: {e}")
        return None


# --- Memory Management Functions (rely on get_embedding) ---
def add_to_memory(mem_list, role, content):
    """Append ``{"role", "content", "embedding"}`` to *mem_list* and return it.

    Blank content is skipped with a warning. If embedding fails, the message
    is still stored, with ``"embedding": None``.

    NOTE(review): each embedding is stored as a list of
    ``model.config.hidden_size`` floats inside the Flask session. Flask's
    default session is a signed cookie, and browsers typically drop cookies
    over ~4 KB, so this memory likely does not persist between requests.
    Consider server-side storage.
    """
    if not content or not content.strip():
        logging.warning(f"Attempted to add empty content to memory for role: {role}")
        return mem_list
    embedding = get_embedding(content)
    if embedding is not None:
        mem_list.append({"role": role, "content": content, "embedding": embedding.tolist()})
    else:
        logging.warning(f"Failed to get embedding for message: {content[:50]}...")
        mem_list.append({"role": role, "content": content, "embedding": None})
    return mem_list


def retrieve_relevant_memory(mem_list, user_input, top_k=5):
    """Return up to *top_k* entries of *mem_list* most similar to *user_input*.

    Entries are ranked by cosine similarity, highest first. Entries with a
    missing, non-list, non-numeric or wrongly sized embedding are skipped.
    Returns [] when nothing can be compared.
    """
    if not mem_list or tokenizer is None or model is None:
        return []
    user_embedding = get_embedding(user_input)
    if user_embedding is None:
        logging.error("Failed to get user input embedding for retrieval.")
        return []

    expected_shape = (model.config.hidden_size,)
    valid_memory_items = []
    memory_embeddings_np = []
    for m in mem_list:
        raw_embedding = m.get("embedding")
        if raw_embedding is None or not isinstance(raw_embedding, list):
            continue
        snippet = str(m.get("content", ""))[:50]
        try:
            # dtype=float rejects non-numeric entries here instead of
            # failing later inside cosine_similarity.
            np_embedding = np.asarray(raw_embedding, dtype=float)
        except (TypeError, ValueError) as conv_e:
            logging.warning(
                f"Could not convert embedding for memory entry: {snippet}... Error: {conv_e}"
            )
            continue
        if np_embedding.shape != expected_shape:
            logging.warning(f"Embedding dimension mismatch for memory entry: {snippet}...")
            continue
        valid_memory_items.append(m)
        memory_embeddings_np.append(np_embedding)

    if not valid_memory_items:
        return []
    similarities = cosine_similarity([user_embedding], np.array(memory_embeddings_np))[0]
    ranked = sorted(zip(similarities, valid_memory_items), key=lambda x: x[0], reverse=True)
    return [item for _, item in ranked[:top_k]]


def construct_prompt(mem_list, user_input, max_tokens_in_prompt=1000):
    """Build the chat-completion message list for *user_input*.

    The list contains a system prompt, then the memory entries judged
    relevant (kept in their original chronological order), then the user
    turn. "Tokens" here are whitespace-separated words, a rough budget
    rather than model tokens. Context stops being added once the budget
    would be exceeded. The user turn is always appended; if it overflows
    the budget, only a warning is logged.
    """
    relevant_memory_items = retrieve_relevant_memory(mem_list, user_input)
    relevant_content_set = {m["content"] for m in relevant_memory_items if "content" in m}

    messages_for_api = [
        {"role": "system", "content": "You are a helpful and friendly AI assistant."}
    ]
    current_prompt_tokens = len(messages_for_api[0]["content"].split())

    context_messages = []
    for msg in mem_list:
        if (
            "content" in msg
            and msg["content"] in relevant_content_set
            and msg.get("role") in ("user", "assistant", "system")
        ):
            msg_text = f'{msg["role"]}: {msg["content"]}\n'
            msg_tokens = len(msg_text.split())
            if current_prompt_tokens + msg_tokens > max_tokens_in_prompt:
                break
            context_messages.append({"role": msg["role"], "content": msg["content"]})
            current_prompt_tokens += msg_tokens
    messages_for_api.extend(context_messages)

    user_input_tokens = len(user_input.split())
    if current_prompt_tokens + user_input_tokens > max_tokens_in_prompt and len(messages_for_api) > 1:
        logging.warning(
            "User input exceeds max_tokens_in_prompt with existing context. Context may be truncated."
        )
    messages_for_api.append({"role": "user", "content": user_input})
    return messages_for_api


def trim_memory(mem_list, max_size=50):
    """Drop the oldest entries of *mem_list* in place so at most *max_size* remain.

    A ``max_size`` of 0 or less empties the list. Returns the same list.
    """
    overflow = len(mem_list) - max(max_size, 0)
    if overflow > 0:
        # A single slice delete is O(n); repeated pop(0) is O(n^2).
        del mem_list[:overflow]
    return mem_list


def summarize_memory(mem_list):
    """Ask Groq to summarize *mem_list* into a single system message.

    Returns:
        ``[{"role": "system", "content": "Previous conversation summary: ..."}]``
        on success; [] if there is nothing to summarize or no client; the
        original *mem_list* unchanged if the API call fails.
    """
    if not mem_list or client is None:
        logging.warning("Memory is empty or Groq client not initialized. Cannot summarize.")
        return []
    long_term_memory = " ".join(m["content"] for m in mem_list if m.get("content"))
    if not long_term_memory.strip():
        logging.warning("Memory content is empty. Cannot summarize.")
        return []
    try:
        summary_completion = client.chat.completions.create(
            model=GROQ_MODEL,
            messages=[
                {"role": "system", "content": "Summarize the following conversation for key points. Keep it concise."},
                {"role": "user", "content": long_term_memory},
            ],
            max_tokens=500,
        )
        summary_text = summary_completion.choices[0].message.content
    except Exception as e:
        logging.error(f"Error summarizing memory: {e}")
        return mem_list
    logging.info("Memory summarized.")
    return [{"role": "system", "content": f"Previous conversation summary: {summary_text}"}]


# --- Flask Routes --- (MUST come AFTER app is defined) ---
@app.route("/")
def index():
    """Render the chat page, initializing an empty memory for new sessions."""
    if "chat_memory" not in session:
        session["chat_memory"] = []
    return render_template("index.html")


@app.route("/chat", methods=["POST"])
def chat():
    """Answer ``{"message": str}`` with ``{"response": str}``.

    Returns 500 if the backend is not initialized, and 400 for a missing,
    non-string or blank message. If the Groq call fails, an apology is
    returned (status 200) and the failed turn is not written to memory.
    """
    # Check if Groq client AND embedding models are initialized
    if client is None or tokenizer is None or model is None:
        error_message = "Chatbot backend is not fully initialized (API key or embedding models missing)."
        logging.error(error_message)
        return jsonify({"response": error_message}), 500

    # silent=True yields None instead of raising for non-JSON bodies.
    payload = request.get_json(silent=True)
    user_input = payload.get("message") if isinstance(payload, dict) else None
    if not isinstance(user_input, str) or not user_input.strip():
        return jsonify({"response": "Please enter a message."}), 400

    current_memory_serializable = session.get("chat_memory", [])
    messages_for_api = construct_prompt(current_memory_serializable, user_input)

    try:
        completion = client.chat.completions.create(
            model=GROQ_MODEL,
            messages=messages_for_api,
            temperature=0.6,
            max_tokens=1024,
            top_p=0.95,
            stream=False,
            stop=None,
        )
        ai_response_content = completion.choices[0].message.content
    except Exception as e:
        logging.error(f"Error calling Groq API: {e}")
        # Do not persist the failed turn: the apology would otherwise be
        # retrieved later as "relevant" conversation context.
        return jsonify({"response": _GROQ_ERROR_REPLY})

    current_memory_serializable = add_to_memory(current_memory_serializable, "user", user_input)
    current_memory_serializable = add_to_memory(current_memory_serializable, "assistant", ai_response_content)
    current_memory_serializable = trim_memory(current_memory_serializable, max_size=MAX_MEMORY_MESSAGES)
    # Reassigning the key marks the session as modified so Flask saves it.
    session["chat_memory"] = current_memory_serializable
    return jsonify({"response": ai_response_content})


@app.route("/clear_memory", methods=["POST"])
def clear_memory():
    """Reset this session's chat memory."""
    session["chat_memory"] = []
    logging.info("Chat memory cleared.")
    return jsonify({"status": "Memory cleared."})


# --- Running the App ---
if __name__ == "__main__":
    import uvicorn

    logging.info("Starting Uvicorn server...")
    port = int(os.environ.get("PORT", 7860))
    # Flask is a WSGI app. Uvicorn's interface="auto" only chooses between
    # ASGI2 and ASGI3, so the WSGI interface must be requested explicitly.
    # NOTE(review): recent Uvicorn releases deprecate the built-in WSGI
    # interface; consider a WSGI server or an adapter if upgrading.
    uvicorn.run(app, host="0.0.0.0", port=port, interface="wsgi")