|
|
import os |
|
|
from flask import Flask, render_template, request, jsonify, session |
|
|
|
|
|
|
|
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
from groq import Groq |
|
|
import numpy as np |
|
|
import logging |
|
|
|
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
# --- Runtime setup: logging, device selection, Flask app, models, API client ---

# Configure logging FIRST. In the original order the "Using device" message
# was emitted before basicConfig() and silently dropped by the last-resort
# handler (which only passes WARNING and above).
logging.basicConfig(level=logging.INFO)

# Embedding batches here are single sentences; extra intra-op threads only
# add contention on small inputs.
torch.set_num_threads(1)

# Prefer the GPU when available; all tensors below are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")

# NOTE(review): `app` was referenced by the route decorators below but never
# created anywhere in this file; it is created here so the module imports.
# Flask sessions (used for chat memory) require a secret key — read it from
# the environment, falling back to a per-process random key.
app = Flask(__name__)
app.secret_key = os.environ.get("FLASK_SECRET_KEY", os.urandom(24).hex())

# Globals initialized below; None means "unavailable" and every handler that
# uses them must check before use.
tokenizer = None
model = None
client = None

try:
    # Sentence-embedding model used for memory relevance scoring.
    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
    model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2').to(device)
    logging.info("Tokenizer and AutoModel loaded successfully.")
except Exception as e:
    # Keep the app importable even if model download/load fails; handlers
    # return a 500 when tokenizer/model are None.
    logging.error(f"Error loading Transformer models: {e}")

# Groq chat-completions client; requires GROQ_API_KEY in the environment.
groq_api_key = os.environ.get("GROQ_API_KEY")
if not groq_api_key:
    logging.error("GROQ_API_KEY environment variable not set.")
    client = None
else:
    client = Groq(api_key=groq_api_key)
    logging.info("Groq client initialized.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def mean_pooling(model_output, attention_mask):
    """
    Average a transformer's token embeddings, ignoring padded positions.

    model_output[0] is the token-embedding tensor of shape (batch, seq, dim);
    attention_mask is (batch, seq) with 1 for real tokens and 0 for padding.
    Returns a (batch, dim) tensor of per-sequence masked means.
    """
    embeddings = model_output[0]
    # Broadcast the mask across the embedding dimension so padded tokens
    # contribute nothing to the sum.
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    summed = torch.sum(embeddings * mask, 1)
    # Clamp the token count away from zero to avoid division by zero on
    # fully-masked rows.
    counts = torch.clamp(mask.sum(1), min=1e-9)
    return summed / counts
|
|
|
|
|
|
|
|
|
|
|
def get_embedding(text):
    """
    Embed a single text with the MiniLM model (mean pooling + L2 norm).

    Returns a 1-D numpy array, or None when the models are not loaded or
    the embedding attempt fails.
    """
    if tokenizer is None or model is None:
        logging.error("Embedding models not loaded. Cannot generate embedding.")
        return None

    try:
        encoded = tokenizer(text, padding=True, truncation=True, return_tensors='pt').to(device)

        # Inference only — no gradient tracking needed.
        with torch.no_grad():
            output = model(**encoded)

        pooled = mean_pooling(output, encoded['attention_mask'])

        # L2-normalize so downstream cosine similarity works on unit vectors.
        normalized = F.normalize(pooled, p=2, dim=1)

        # Single input sentence → row 0 as a plain numpy vector.
        return normalized.cpu().numpy()[0]
    except Exception as e:
        logging.error(f"Error generating embedding: {e}")
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_to_memory(mem_list, role, content):
    """
    Append a message (plus its embedding) to mem_list and return the list.

    Blank content is rejected unchanged. If embedding generation fails, the
    message is still stored with embedding=None so the transcript stays
    complete; retrieval skips entries without embeddings.
    """
    if not content or not content.strip():
        logging.warning(f"Attempted to add empty content to memory for role: {role}")
        return mem_list

    embedding = get_embedding(content)

    entry = {"role": role, "content": content}
    if embedding is None:
        logging.warning(f"Failed to get embedding for message: {content[:50]}...")
        entry["embedding"] = None
    else:
        # Lists (not numpy arrays) so the entry is JSON/session-serializable.
        entry["embedding"] = embedding.tolist()
    mem_list.append(entry)

    return mem_list
|
|
|
|
|
|
|
|
def retrieve_relevant_memory(mem_list, user_input, top_k=5):
    """
    Retrieve the top-k most relevant messages from the provided memory list
    based on cosine similarity with user_input.

    Entries without a stored embedding, or whose embedding cannot be
    converted to a vector of the model's hidden size, are skipped with a
    warning. Any unexpected failure returns [] rather than raising, so the
    caller degrades to a memory-less prompt.

    Returns a list of relevant messages (dictionaries), most similar first.
    """
    # Only entries that actually carry an embedding can be scored.
    valid_memory_with_embeddings = [m for m in mem_list if m.get("embedding") is not None]

    if not valid_memory_with_embeddings:
        return []

    try:
        user_embedding = get_embedding(user_input)

        if user_embedding is None:
            logging.error("Failed to get user input embedding for retrieval.")
            return []

        # Collect entries and their vectors in lockstep so the similarity
        # row index maps back to the right message.
        memory_items = []
        memory_embeddings = []
        for m in valid_memory_with_embeddings:
            try:
                np_embedding = np.array(m["embedding"])

                # Guard against stale entries whose dimensionality no longer
                # matches the loaded model (e.g. after a model swap).
                # NOTE: `model.config.hidden_size` raises if the model failed
                # to load; that lands in the inner except below by design.
                if np_embedding.shape == (model.config.hidden_size,):
                    memory_items.append(m)
                    memory_embeddings.append(np_embedding)
                else:
                    logging.warning(f"Embedding dimension mismatch for memory entry: {m['content'][:50]}...")

            except Exception as conv_e:
                # Malformed embedding payload: skip this entry, keep going.
                logging.warning(f"Could not convert embedding for memory entry: {m['content'][:50]}... Error: {conv_e}")
                pass

        if not memory_items:
            return []

        # One query vector vs. all memory vectors → a single row of scores.
        similarities = cosine_similarity([user_embedding], np.array(memory_embeddings))[0]

        # Pair each score with its message and sort descending by score only.
        relevant_messages_sorted = sorted(zip(similarities, memory_items), key=lambda x: x[0], reverse=True)

        return [m[1] for m in relevant_messages_sorted[:top_k]]

    except Exception as e:
        logging.error(f"Error retrieving memory: {e}")
        return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.route('/chat', methods=['POST'])
def chat():
    """
    Handle an incoming chat message.

    Expects a JSON body {"message": str}. Builds a prompt from the relevant
    session memory, calls the Groq chat API, stores both conversation turns
    (with embeddings) back in the session, and returns {"response": str}.

    Returns 500 if the backend failed to initialize, 400 for a missing or
    blank message.
    """
    # Fail fast if startup could not load the embedding models or API client.
    if client is None or tokenizer is None or model is None:
        error_message = "Chatbot backend is not fully initialized (API key or embedding models missing)."
        logging.error(error_message)
        return jsonify({"response": error_message}), 500

    # get_json(silent=True) returns None instead of raising on a missing or
    # non-JSON body, so bad requests get a clean 400 rather than a 500.
    payload = request.get_json(silent=True) or {}
    user_input = payload.get('message')
    if not user_input or not user_input.strip():
        return jsonify({"response": "Please enter a message."}), 400

    # Session memory: list of {"role", "content", "embedding"} dicts.
    current_memory = session.get('chat_memory', [])

    messages_for_api = construct_prompt(current_memory, user_input)

    try:
        completion = client.chat.completions.create(
            # NOTE(review): verify this model id is valid for the Groq API.
            model="llama-3.1-8b-instruct-fpt",
            messages=messages_for_api,
            temperature=0.6,
            max_tokens=1024,
            top_p=0.95,
            stream=False,
            stop=None,
        )
        ai_response_content = completion.choices[0].message.content
    except Exception as e:
        logging.error(f"Error calling Groq API: {e}")
        ai_response_content = "Sorry, I encountered an error when trying to respond. Please try again later."

    # Record both turns (even the fallback error text, so the transcript
    # stays complete) and cap memory growth.
    current_memory = add_to_memory(current_memory, "user", user_input)
    current_memory = add_to_memory(current_memory, "assistant", ai_response_content)
    current_memory = trim_memory(current_memory, max_size=20)

    # Reassign the key so Flask marks the session as modified.
    session['chat_memory'] = current_memory

    return jsonify({"response": ai_response_content})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def construct_prompt(mem_list, user_input, max_tokens_in_prompt=1000):
    """
    Build the Groq `messages` list for a chat completion.

    Layout: system prompt, then memory messages judged relevant to
    user_input (kept in their original chronological order), then the
    current user message. A rough whitespace-token budget
    (max_tokens_in_prompt) limits how much memory is included; the user
    message itself is always appended, even when over budget.
    """
    relevant_items = retrieve_relevant_memory(mem_list, user_input)
    relevant_contents = {m["content"] for m in relevant_items if "content" in m}

    system_message = {"role": "system", "content": "You are a helpful and friendly AI assistant."}
    messages_for_api = [system_message]

    # Whitespace-split word count as a cheap proxy for token count.
    budget_used = len(system_message["content"].split())

    # Walk the full history in order, keeping only messages whose content was
    # judged relevant, stopping once the budget would be exceeded.
    selected = []
    for msg in mem_list:
        if "content" not in msg or msg["content"] not in relevant_contents:
            continue
        if msg["role"] not in ["user", "assistant", "system"]:
            continue

        cost = len(f'{msg["role"]}: {msg["content"]}\n'.split())
        if budget_used + cost > max_tokens_in_prompt:
            break

        selected.append({"role": msg["role"], "content": msg["content"]})
        budget_used += cost

    messages_for_api.extend(selected)

    # Warn (but do not truncate) when the user input pushes past the budget
    # while context messages are present.
    if budget_used + len(user_input.split()) > max_tokens_in_prompt and len(messages_for_api) > 1:
        logging.warning(f"User input exceeds max_tokens_in_prompt with existing context. Context may be truncated.")

    messages_for_api.append({"role": "user", "content": user_input})

    return messages_for_api
|
|
|
|
|
|
|
|
def trim_memory(mem_list, max_size=50):
    """
    Trim mem_list in place so it holds at most max_size entries.

    The oldest entries (front of the list) are discarded. Returns the same
    list object for call-chaining.
    """
    excess = len(mem_list) - max_size
    if excess > 0:
        # One O(n) slice-delete instead of the original repeated pop(0)
        # calls, each of which shifted the whole list (O(n^2) overall).
        del mem_list[:excess]
    return mem_list
|
|
|
|
|
def summarize_memory(mem_list):
    """
    Collapse the memory buffer into one summarizing system message.

    Returns [] when there is nothing to summarize (empty memory, missing
    client, or blank content) and the original list unchanged when the
    summarization API call fails.
    """
    if not mem_list or client is None:
        logging.warning("Memory is empty or Groq client not initialized. Cannot summarize.")
        return []

    # Flatten the whole transcript into one text blob for the summarizer.
    transcript = " ".join([m["content"] for m in mem_list if "content" in m])
    if not transcript.strip():
        logging.warning("Memory content is empty. Cannot summarize.")
        return []

    try:
        summary_completion = client.chat.completions.create(
            model="llama-3.1-8b-instruct-fpt",
            messages=[
                {"role": "system", "content": "Summarize the following conversation for key points. Keep it concise."},
                {"role": "user", "content": transcript},
            ],
            max_tokens=500,
        )
        summary_text = summary_completion.choices[0].message.content
        logging.info("Memory summarized.")
        return [{"role": "system", "content": f"Previous conversation summary: {summary_text}"}]
    except Exception as e:
        # Keep the unsummarized memory rather than losing it.
        logging.error(f"Error summarizing memory: {e}")
        return mem_list
|
|
|
|
|
|
|
|
@app.route('/')
def index():
    """
    Serve the main chat interface page.

    First visit in a session gets an empty chat-memory list so later
    handlers can assume the key exists.
    """
    if 'chat_memory' not in session:
        session['chat_memory'] = []
    return render_template('index.html')
|
|
|
|
|
@app.route('/clear_memory', methods=['POST'])
def clear_memory():
    """
    Reset the session's chat memory to an empty list.

    Returns a small JSON status payload for the client UI.
    """
    session['chat_memory'] = []
    logging.info("Chat memory cleared.")
    return jsonify({"status": "Memory cleared."})
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # NOTE(review): `serve` was called here without ever being imported,
    # so the entry point raised NameError. Waitress is the server this
    # file was written against; import it locally at the entry point.
    from waitress import serve

    logging.info("Starting Waitress server...")
    # PORT defaults to 7860 unless overridden by the environment.
    port = int(os.environ.get('PORT', 7860))
    serve(app, host='0.0.0.0', port=port)