# DhirajBot / app.py
import os
from flask import Flask, render_template, request, jsonify, session
# Removed SentenceTransformer
# from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from groq import Groq
import numpy as np
import logging
# Import necessary components from transformers and torch
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F  # For normalization
# Waitress serves the app in production (used in the __main__ block below)
from waitress import serve
# Ensure torch is using CPU if GPU is not available (standard for free tier)
torch.set_num_threads(1) # Limit threads for resource efficiency
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
logging.info(f"Using device: {device}")
# Configure logging
logging.basicConfig(level=logging.INFO)
# --- Initialize Models (load these once using transformers) ---
tokenizer = None
model = None
client = None
try:
    # Load the tokenizer and model from the Hugging Face Hub using transformers
    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
    model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2').to(device)  # Move model to device
    logging.info("Tokenizer and AutoModel loaded successfully.")
except Exception as e:
    logging.error(f"Error loading Transformer models: {e}")
    # tokenizer/model remain None; the routes below check for this
# Initialize the Groq client
groq_api_key = os.environ.get("GROQ_API_KEY")
if not groq_api_key:
    logging.error("GROQ_API_KEY environment variable not set.")
    # In deployment this should ideally stop the app or show an error page;
    # for this example we proceed, but Groq calls will fail.
    client = None
else:
    client = Groq(api_key=groq_api_key)
    logging.info("Groq client initialized.")
# --- Helper function for Mean Pooling (from the model documentation) ---
# Mean pooling - take the attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
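
# Shape sketch for the pooling above (illustrative; assumes all-MiniLM-L6-v2,
# whose hidden size is 384):
#   token_embeddings:    [batch, seq_len, 384]
#   input_mask_expanded: [batch, seq_len, 384]  (1.0 for real tokens, 0.0 for padding)
#   returned tensor:     [batch, 384]           (padding excluded from the average)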
# --- Function to get an embedding using transformers and pooling ---
def get_embedding(text):
    """
    Generate an embedding for a single text using transformers and mean pooling.
    Returns a numpy array.
    """
    if tokenizer is None or model is None:
        logging.error("Embedding models not loaded. Cannot generate embedding.")
        return None
    try:
        # Tokenize the input text and move it to the same device as the model
        encoded_input = tokenizer(text, padding=True, truncation=True, return_tensors='pt').to(device)
        # Compute token embeddings without tracking gradients (inference only)
        with torch.no_grad():
            model_output = model(**encoded_input)
        # Pool token embeddings into a single sentence embedding
        sentence_embedding = mean_pooling(model_output, encoded_input['attention_mask'])
        # L2-normalize so cosine similarity reduces to a dot product
        sentence_embedding = F.normalize(sentence_embedding, p=2, dim=1)
        # Move back to CPU and return the single embedding as a numpy array
        return sentence_embedding.cpu().numpy()[0]
    except Exception as e:
        logging.error(f"Error generating embedding: {e}")
        return None
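
# Example usage (a minimal sketch; assumes the models loaded successfully):
#   vec = get_embedding("Hello, world!")
#   vec.shape            -> (384,) for all-MiniLM-L6-v2
#   np.linalg.norm(vec)  -> ~1.0, since the embedding is L2-normalized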
# --- Memory Management Functions (adapted for sessions and the new embedding method) ---
def add_to_memory(mem_list, role, content):
    """
    Add a message to the provided memory list along with its embedding.
    Returns the updated list.
    """
    # Do not add empty messages
    if not content or not content.strip():
        logging.warning(f"Attempted to add empty content to memory for role: {role}")
        return mem_list
    embedding = get_embedding(content)  # Use the new get_embedding function
    if embedding is not None:
        # Store the embedding as a list so the entry stays JSON-serializable in the session
        mem_list.append({"role": role, "content": content, "embedding": embedding.tolist()})
    else:
        # Add the message without an embedding if embedding generation failed
        logging.warning(f"Failed to get embedding for message: {content[:50]}...")
        mem_list.append({"role": role, "content": content, "embedding": None})
    return mem_list
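
# A stored memory entry looks like this (values illustrative; the embedding is a
# plain list of 384 floats so it survives JSON session serialization):
#   {"role": "user", "content": "What is RAG?", "embedding": [0.012, -0.034, ...]}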
def retrieve_relevant_memory(mem_list, user_input, top_k=5):
    """
    Retrieve the top-k most relevant messages from the provided memory list
    based on cosine similarity with user_input.
    Returns a list of relevant messages (dictionaries).
    """
    # Keep only memory entries that actually have embeddings
    valid_memory_with_embeddings = [m for m in mem_list if m.get("embedding") is not None]
    if not valid_memory_with_embeddings:
        return []
    try:
        # Compute the embedding of the user input using the new function
        user_embedding = get_embedding(user_input)
        if user_embedding is None:
            logging.error("Failed to get user input embedding for retrieval.")
            return []  # Cannot retrieve if the user embedding fails
        # Collect memory entries whose stored embeddings convert cleanly
        # and have the expected dimension
        memory_items = []
        memory_embeddings = []
        for m in valid_memory_with_embeddings:
            try:
                # Convert the stored embedding list back to a numpy array
                np_embedding = np.array(m["embedding"])
                # Check the dimension against the loaded model config (384 for all-MiniLM-L6-v2)
                if np_embedding.shape == (model.config.hidden_size,):
                    memory_items.append(m)
                    memory_embeddings.append(np_embedding)
                else:
                    logging.warning(f"Embedding dimension mismatch for memory entry: {m['content'][:50]}...")
            except Exception as conv_e:
                # Skip this memory entry if the embedding is invalid or conversion fails
                logging.warning(f"Could not convert embedding for memory entry: {m['content'][:50]}... Error: {conv_e}")
        if not memory_items:  # Check again after filtering
            return []
        # Calculate cosine similarities (both arguments must be 2D arrays)
        similarities = cosine_similarity([user_embedding], np.array(memory_embeddings))[0]
        # Sort memory by similarity and return the top-k message dictionaries
        relevant_messages_sorted = sorted(zip(similarities, memory_items), key=lambda x: x[0], reverse=True)
        return [m[1] for m in relevant_messages_sorted[:top_k]]
    except Exception as e:
        logging.error(f"Error retrieving memory: {e}")
        return []
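
# Retrieval sketch (illustrative values, assuming embeddings were generated):
#   memory = add_to_memory([], "user", "My favourite colour is blue.")
#   memory = add_to_memory(memory, "user", "I work as a teacher.")
#   retrieve_relevant_memory(memory, "What colour do I like?", top_k=1)
#   -> [{"role": "user", "content": "My favourite colour is blue.", ...}]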
# construct_prompt, trim_memory, summarize_memory, index, chat, clear_memory routes
# and the final if __name__ == '__main__': block remain largely the same,
# except they now rely on the global `tokenizer` and `model` being initialized
# and call the new `get_embedding` function internally.
# Ensure the check in the chat route verifies tokenizer and model are not None
@app.route('/chat', methods=['POST'])
def chat():
    """
    Handle incoming chat messages, process with the bot logic,
    update session memory, and return the AI response.
    """
    # Check that the Groq client AND the embedding models are initialized
    if client is None or tokenizer is None or model is None:
        error_message = "Chatbot backend is not fully initialized (API key or embedding models missing)."
        logging.error(error_message)
        return jsonify({"response": error_message}), 500
    # ... (rest of the chat function is the same) ...
    user_input = request.json.get('message')
    if not user_input or not user_input.strip():
        return jsonify({"response": "Please enter a message."}), 400
    current_memory_serializable = session.get('chat_memory', [])
    # No need to convert embeddings to numpy here; construct_prompt does it
    # if needed via retrieve_relevant_memory
    messages_for_api = construct_prompt(current_memory_serializable, user_input)
    try:
        # Get the response from the model
        completion = client.chat.completions.create(
            model="llama-3.1-8b-instant",  # Use a suitable, available Groq model
            messages=messages_for_api,  # Pass the list of messages
            temperature=0.6,
            max_tokens=1024,  # Limit response length
            top_p=0.95,
            stream=False,  # Disable streaming for simpler HTTP response handling
            stop=None,
        )
        ai_response_content = completion.choices[0].message.content
    except Exception as e:
        logging.error(f"Error calling Groq API: {e}")
        ai_response_content = "Sorry, I encountered an error when trying to respond. Please try again later."
    # Update the memory buffer (get_embedding is called within add_to_memory)
    current_memory_serializable = add_to_memory(current_memory_serializable, "user", user_input)
    current_memory_serializable = add_to_memory(current_memory_serializable, "assistant", ai_response_content)
    # Trim memory
    current_memory_serializable = trim_memory(current_memory_serializable, max_size=20)
    # Store the updated memory back into the session
    session['chat_memory'] = current_memory_serializable
    return jsonify({"response": ai_response_content})
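
# Example request against this route (a sketch; assumes the app is running
# locally on port 7860, as in the __main__ block below):
#   import requests
#   r = requests.post("http://localhost:7860/chat", json={"message": "Hello!"})
#   r.json()  -> {"response": "..."}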
# The construct_prompt, trim_memory, summarize_memory, index, clear_memory functions are mostly unchanged,
# but they now rely on the global `tokenizer` and `model` being available.
# construct_prompt calls retrieve_relevant_memory which calls get_embedding.
def construct_prompt(mem_list, user_input, max_tokens_in_prompt=1000):  # Increased max tokens slightly
    # ... (This function remains the same as before; it calls retrieve_relevant_memory) ...
    """
    Construct the list of messages suitable for the Groq API's 'messages' parameter
    by combining relevant memory and the current user input.
    Adds relevant memory chronologically from the session history.
    """
    # Retrieve relevant memory *content* based on similarity
    relevant_memory_items = retrieve_relevant_memory(mem_list, user_input)
    # Build a set of content strings from the relevant items for quick lookup
    relevant_content_set = {m["content"] for m in relevant_memory_items if "content" in m}
    messages_for_api = []
    # Add a system message
    messages_for_api.append({"role": "system", "content": "You are a helpful and friendly AI assistant."})
    # Rough token estimate (simple word count), starting with the system message
    current_prompt_tokens = len(messages_for_api[0]["content"].split())
    # Walk the chronological session memory and keep messages whose content
    # appears in the top-k relevant set
    context_messages = []
    for msg in mem_list:
        # Only add messages found in the relevant set with a role suitable for the API messages list
        if "content" in msg and msg["content"] in relevant_content_set and msg["role"] in ["user", "assistant", "system"]:
            # Estimate tokens for this message (simple word count)
            msg_text = f'{msg["role"]}: {msg["content"]}\n'
            msg_tokens = len(msg_text.split())
            if current_prompt_tokens + msg_tokens > max_tokens_in_prompt:
                break  # Stop if adding this message would exceed the limit
            # Add the message in the format expected by the API
            context_messages.append({"role": msg["role"], "content": msg["content"]})
            current_prompt_tokens += msg_tokens
    # Add the chronological context messages
    messages_for_api.extend(context_messages)
    # The current user input is always added as the final message, even if it
    # pushes past the token budget
    user_input_tokens = len(user_input.split())
    if current_prompt_tokens + user_input_tokens > max_tokens_in_prompt and len(messages_for_api) > 1:
        logging.warning("User input exceeds max_tokens_in_prompt with existing context. Context may be truncated.")
    messages_for_api.append({"role": "user", "content": user_input})
    return messages_for_api
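
# Illustrative shape of the list construct_prompt returns (values are examples only):
#   [
#     {"role": "system", "content": "You are a helpful and friendly AI assistant."},
#     {"role": "user", "content": "My favourite colour is blue."},  # retrieved context
#     {"role": "user", "content": "What colour do I like?"},        # current input
#   ]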
def trim_memory(mem_list, max_size=50):
    # ... (This function is unchanged) ...
    """
    Trim the memory list to keep it within the specified max size.
    Removes the oldest entries (from the beginning of the list).
    Returns the trimmed list.
    """
    while len(mem_list) > max_size:
        mem_list.pop(0)  # Remove the oldest entry
    return mem_list
def summarize_memory(mem_list):
    # ... (This function is unchanged, relies on the global client) ...
    """
    Summarize the memory buffer to free up space.
    """
    if not mem_list or client is None:
        logging.warning("Memory is empty or Groq client not initialized. Cannot summarize.")
        return []
    long_term_memory = " ".join([m["content"] for m in mem_list if "content" in m])
    if not long_term_memory.strip():
        logging.warning("Memory content is empty. Cannot summarize.")
        return []
    try:
        summary_completion = client.chat.completions.create(
            model="llama-3.1-8b-instant",
            messages=[
                {"role": "system", "content": "Summarize the following conversation for key points. Keep it concise."},
                {"role": "user", "content": long_term_memory},
            ],
            max_tokens=500,
        )
        summary_text = summary_completion.choices[0].message.content
        logging.info("Memory summarized.")
        # The summary entry replaces the buffer; embeddings are not needed for this entry type
        return [{"role": "system", "content": f"Previous conversation summary: {summary_text}"}]
    except Exception as e:
        logging.error(f"Error summarizing memory: {e}")
        return mem_list
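
# Usage sketch (illustrative): summarize_memory is defined but not wired into the
# /chat flow above; one option would be to call it when the buffer overflows, e.g.:
#   if len(session['chat_memory']) > 50:
#       session['chat_memory'] = summarize_memory(session['chat_memory'])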
@app.route('/')
def index():
    # ... (This route is unchanged) ...
    """
    Serve the main chat interface page.
    """
    if 'chat_memory' not in session:
        session['chat_memory'] = []
    return render_template('index.html')
@app.route('/clear_memory', methods=['POST'])
def clear_memory():
    # ... (This route is unchanged) ...
    """
    Clear the chat memory from the session.
    """
    session['chat_memory'] = []
    logging.info("Chat memory cleared.")
    return jsonify({"status": "Memory cleared."})
# --- Running the App ---
if __name__ == '__main__':
    logging.info("Starting Waitress server...")
    port = int(os.environ.get('PORT', 7860))
    serve(app, host='0.0.0.0', port=port)
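
# Local run sketch (GROQ_API_KEY must be set for chat responses to work):
#   GROQ_API_KEY=gsk_... python app.py
#   # then open http://localhost:7860/ in a browser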