File size: 14,918 Bytes
37023b5
 
e4ed1f7
 
37023b5
 
 
e4ed1f7
 
 
 
 
 
 
 
 
 
 
 
 
37023b5
 
 
 
e4ed1f7
 
 
 
 
37023b5
e4ed1f7
 
 
 
37023b5
e4ed1f7
 
37023b5
 
 
 
 
 
 
 
 
 
 
 
 
e4ed1f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37023b5
e4ed1f7
 
 
37023b5
e4ed1f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37023b5
 
 
 
 
 
e4ed1f7
 
 
 
37023b5
e4ed1f7
 
 
 
 
 
 
 
 
 
37023b5
 
 
 
 
 
 
 
e4ed1f7
 
 
 
37023b5
 
 
e4ed1f7
 
 
 
 
 
37023b5
 
e4ed1f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37023b5
 
 
e4ed1f7
 
37023b5
 
 
 
 
 
 
 
 
 
 
 
e4ed1f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37023b5
e4ed1f7
37023b5
 
 
 
 
 
 
 
e4ed1f7
37023b5
 
 
 
 
 
 
e4ed1f7
37023b5
 
 
 
e4ed1f7
 
 
37023b5
 
 
 
 
 
 
 
e4ed1f7
37023b5
 
 
 
 
 
 
e4ed1f7
37023b5
 
 
 
 
 
 
 
e4ed1f7
37023b5
 
 
 
 
 
 
 
 
 
e4ed1f7
37023b5
 
 
 
 
e4ed1f7
37023b5
e4ed1f7
 
37023b5
 
 
 
 
e4ed1f7
37023b5
 
 
 
e4ed1f7
37023b5
 
 
e4ed1f7
 
37023b5
 
e4ed1f7
37023b5
 
 
 
e4ed1f7
37023b5
 
 
 
 
 
 
 
 
e4ed1f7
37023b5
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
import os
from flask import Flask, render_template, request, jsonify, session
# Removed SentenceTransformer
# from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from groq import Groq
import numpy as np
import logging
# Import necessary components from transformers and torch
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F # For normalization
# Ensure torch is using CPU if GPU is not available (standard for free tier)
torch.set_num_threads(1) # Limit threads for resource efficiency
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logging.info(f"Using device: {device}")


# Configure logging
logging.basicConfig(level=logging.INFO)

# --- Initialize Models (Load these once using transformers) ---
# Module-level globals; every route checks these for None before using them.
tokenizer = None  # HuggingFace tokenizer for the embedding model
model = None      # HuggingFace AutoModel producing token embeddings
client = None     # Groq API client for chat completions

try:
    # Load tokenizer and model from HuggingFace Hub using transformers.
    # NOTE: downloads weights on first run, so this needs network access and
    # can take a while; on failure both stay None and routes return 500.
    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
    model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2').to(device) # Move model to device
    logging.info("Tokenizer and AutoModel loaded successfully.")
except Exception as e:
    logging.error(f"Error loading Transformer models: {e}")
    # Models are None, will be handled below

# Initialize the Groq client
# NOTE(review): no `app = Flask(__name__)` (nor a session secret key) is
# visible in this chunk, yet `@app.route` and `session` are used below --
# confirm the Flask app is created elsewhere in the file.
groq_api_key = os.environ.get("GROQ_API_KEY")
if not groq_api_key:
    logging.error("GROQ_API_KEY environment variable not set.")
    # In deployment, this should ideally stop the app or show an error page
    # For this example, we'll proceed but Groq calls will fail
    client = None
else:
    client = Groq(api_key=groq_api_key)
    logging.info("Groq client initialized.")


# --- Helper function for Mean Pooling (from documentation) ---
# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    """
    Average a model's token embeddings into one sentence vector.

    Tokens masked out by `attention_mask` (padding) contribute nothing to the
    sum, and the divisor counts only the attended tokens, so padding never
    skews the average.

    Args:
        model_output: transformers model output; element 0 holds the token
            embeddings of shape (batch, seq_len, hidden).
        attention_mask: (batch, seq_len) tensor of 0/1 attention flags.

    Returns:
        (batch, hidden) tensor of mean-pooled sentence embeddings.
    """
    embeddings = model_output[0]
    # Broadcast the mask across the hidden dimension as float weights.
    weights = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    summed = (embeddings * weights).sum(dim=1)
    # Clamp avoids division by zero for an all-padding row.
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts


# --- Function to get embedding using transformers and pooling ---
def get_embedding(text):
    """
    Embed a single text with the transformer model plus mean pooling.

    Args:
        text: the string to embed.

    Returns:
        A 1-D numpy array (L2-normalized sentence embedding), or None when
        the models are not loaded or embedding fails.
    """
    if tokenizer is None or model is None:
        logging.error("Embedding models not loaded. Cannot generate embedding.")
        return None
    try:
        # Tokenize and place the tensors on the same device as the model.
        tokens = tokenizer(text, padding=True, truncation=True, return_tensors='pt').to(device)

        # Inference only -- gradients are unnecessary overhead here.
        with torch.no_grad():
            outputs = model(**tokens)

        # Pool token embeddings into one sentence vector, then L2-normalize.
        pooled = mean_pooling(outputs, tokens['attention_mask'])
        pooled = F.normalize(pooled, p=2, dim=1)

        # Back to CPU; index [0] drops the batch dimension.
        return pooled.cpu().numpy()[0]
    except Exception as e:
        logging.error(f"Error generating embedding: {e}")
        return None


# --- Memory Management Functions (Adapted for Sessions and new embedding method) ---

def add_to_memory(mem_list, role, content):
    """
    Append a message and its embedding to the memory list; return the list.

    Empty or whitespace-only content is rejected (the list is returned
    untouched). If embedding fails, the message is still stored with
    "embedding": None so the conversation text is not lost.

    Args:
        mem_list: the session's memory list (mutated in place).
        role: "user" / "assistant" / "system".
        content: the message text.
    """
    if not content or not content.strip():
        logging.warning(f"Attempted to add empty content to memory for role: {role}")
        return mem_list

    vector = get_embedding(content)
    if vector is None:
        logging.warning(f"Failed to get embedding for message: {content[:50]}...")

    mem_list.append({
        "role": role,
        "content": content,
        # A plain list (not a numpy array) keeps the session JSON-serializable.
        "embedding": vector.tolist() if vector is not None else None,
    })
    return mem_list


def retrieve_relevant_memory(mem_list, user_input, top_k=5):
    """
    Return the top-k memory entries most similar to `user_input`.

    Similarity is cosine similarity between stored embeddings and the
    embedding of the user input. Entries without a usable embedding (missing,
    unconvertible, or wrong dimension) are skipped. Any failure yields [].

    Args:
        mem_list: session memory entries ({"role", "content", "embedding"}).
        user_input: the query text.
        top_k: how many entries to return at most.
    """
    candidates = [entry for entry in mem_list if entry.get("embedding") is not None]
    if not candidates:
        return []

    try:
        query_vec = get_embedding(user_input)
        if query_vec is None:
            logging.error("Failed to get user input embedding for retrieval.")
            return []

        usable = []
        vectors = []
        for entry in candidates:
            try:
                vec = np.array(entry["embedding"])
                # Dimension must match the loaded model (384 for all-MiniLM-L6-v2).
                if vec.shape == (model.config.hidden_size,):
                    usable.append(entry)
                    vectors.append(vec)
                else:
                    logging.warning(f"Embedding dimension mismatch for memory entry: {entry['content'][:50]}...")
            except Exception as conv_e:
                logging.warning(f"Could not convert embedding for memory entry: {entry['content'][:50]}... Error: {conv_e}")

        if not usable:
            return []

        scores = cosine_similarity([query_vec], np.array(vectors))[0]
        # Rank by similarity, highest first; key avoids comparing dicts.
        ranked = sorted(zip(scores, usable), key=lambda pair: pair[0], reverse=True)
        return [entry for _, entry in ranked[:top_k]]

    except Exception as e:
        logging.error(f"Error retrieving memory: {e}")
        return []


# construct_prompt, trim_memory, summarize_memory, index, chat, clear_memory routes
# and the final if __name__ == '__main__': block remain largely the same,
# except they now rely on the global `tokenizer` and `model` being initialized
# and call the new `get_embedding` function internally.

# Ensure the check in the chat route verifies tokenizer and model are not None
@app.route('/chat', methods=['POST'])
def chat():
    """
    Handle an incoming chat message.

    Reads the user's message from the JSON payload, builds a prompt from
    session memory, queries the Groq API, stores both turns back into the
    session, and returns the AI response.

    Returns:
        JSON {"response": <text>} with status 200 on success, 400 for a
        missing/empty message, or 500 if the backend is not initialized.
    """
    # Check that the Groq client AND both embedding models are initialized.
    if client is None or tokenizer is None or model is None:
        error_message = "Chatbot backend is not fully initialized (API key or embedding models missing)."
        logging.error(error_message)
        return jsonify({"response": error_message}), 500

    # Robustness fix: request.json raises for non-JSON bodies, turning a bad
    # request into a 415/500. get_json(silent=True) returns None instead, so
    # malformed requests fall through to the clean 400 below.
    payload = request.get_json(silent=True) or {}
    user_input = payload.get('message')
    if not user_input or not user_input.strip():
        return jsonify({"response": "Please enter a message."}), 400

    current_memory_serializable = session.get('chat_memory', [])

    # construct_prompt handles similarity retrieval internally.
    messages_for_api = construct_prompt(current_memory_serializable, user_input)

    try:
        # NOTE(review): "llama-3.1-8b-instruct-fpt" does not look like a
        # published Groq model id (cf. "llama-3.1-8b-instant") -- confirm
        # against Groq's model list before deploying.
        completion = client.chat.completions.create(
            model="llama-3.1-8b-instruct-fpt",
            messages=messages_for_api,
            temperature=0.6,
            max_tokens=1024,  # Limit response length
            top_p=0.95,
            stream=False,  # Disable streaming for simpler HTTP response handling
            stop=None,
        )
        ai_response_content = completion.choices[0].message.content
    except Exception as e:
        logging.error(f"Error calling Groq API: {e}")
        ai_response_content = "Sorry, I encountered an error when trying to respond. Please try again later."

    # Update the memory buffer (get_embedding is called within add_to_memory).
    current_memory_serializable = add_to_memory(current_memory_serializable, "user", user_input)
    current_memory_serializable = add_to_memory(current_memory_serializable, "assistant", ai_response_content)

    # Keep only the most recent turns so the session cookie stays small.
    current_memory_serializable = trim_memory(current_memory_serializable, max_size=20)

    # Persist the updated memory back into the session.
    session['chat_memory'] = current_memory_serializable

    return jsonify({"response": ai_response_content})


# The construct_prompt, trim_memory, summarize_memory, index, clear_memory functions are mostly unchanged,
# but they now rely on the global `tokenizer` and `model` being available.
# construct_prompt calls retrieve_relevant_memory which calls get_embedding.

def construct_prompt(mem_list, user_input, max_tokens_in_prompt=1000):
    """
    Build the `messages` list for the Groq chat API.

    Combines a fixed system message, the semantically relevant messages from
    session history (kept in chronological order), and the current user
    input, under a rough whitespace-word-count budget. The user input is
    always appended, even if it exceeds the budget.

    Args:
        mem_list: chronological session memory entries.
        user_input: the current user message.
        max_tokens_in_prompt: approximate word budget for the whole prompt.
    """
    # Which stored messages are semantically close to the new input?
    relevant_items = retrieve_relevant_memory(mem_list, user_input)
    relevant_contents = {item["content"] for item in relevant_items if "content" in item}

    system_message = {"role": "system", "content": "You are a helpful and friendly AI assistant."}
    messages_for_api = [system_message]

    # "Tokens" are approximated by whitespace word counts throughout.
    budget_used = len(system_message["content"].split())

    # Walk chronological history, keeping only the relevant messages.
    history = []
    for entry in mem_list:
        content = entry.get("content")
        if content is None or content not in relevant_contents:
            continue
        if entry["role"] not in ["user", "assistant", "system"]:
            continue
        cost = len(f'{entry["role"]}: {content}\n'.split())
        if budget_used + cost > max_tokens_in_prompt:
            break  # Budget exhausted; drop the rest of the history.
        history.append({"role": entry["role"], "content": content})
        budget_used += cost

    messages_for_api.extend(history)

    # Warn (but proceed) if the user input itself blows the budget.
    if budget_used + len(user_input.split()) > max_tokens_in_prompt and len(messages_for_api) > 1:
        logging.warning(f"User input exceeds max_tokens_in_prompt with existing context. Context may be truncated.")

    messages_for_api.append({"role": "user", "content": user_input})
    return messages_for_api


def trim_memory(mem_list, max_size=50):
    """
    Trim the memory list in place so it holds at most `max_size` entries.

    The oldest entries (at the front of the list) are dropped first.

    Args:
        mem_list: the session memory list (mutated in place).
        max_size: maximum number of entries to keep.

    Returns:
        The same (now trimmed) list, for call-chaining convenience.
    """
    # Performance fix: the original popped index 0 in a loop, which is O(n)
    # per pop (every element shifts). One slice deletion does it in O(n) total.
    overflow = len(mem_list) - max_size
    if overflow > 0:
        del mem_list[:overflow]
    return mem_list

def summarize_memory(mem_list):
    """
    Collapse the whole memory buffer into a single summary message.

    Asks the Groq API to summarize the concatenated conversation text.

    Returns:
        [summary_system_message] on success; [] when memory is empty or the
        client is unavailable; the original list unchanged on API failure.
    """
    if not mem_list or client is None:
        logging.warning("Memory is empty or Groq client not initialized. Cannot summarize.")
        return []

    transcript = " ".join(entry["content"] for entry in mem_list if "content" in entry)
    if not transcript.strip():
        logging.warning("Memory content is empty. Cannot summarize.")
        return []

    try:
        completion = client.chat.completions.create(
            model="llama-3.1-8b-instruct-fpt",
            messages=[
                {"role": "system", "content": "Summarize the following conversation for key points. Keep it concise."},
                {"role": "user", "content": transcript},
            ],
            max_tokens=500,
        )
        summary = completion.choices[0].message.content
        logging.info("Memory summarized.")
        # The summary entry carries no embedding; it is a plain system message.
        return [{"role": "system", "content": f"Previous conversation summary: {summary}"}]
    except Exception as e:
        logging.error(f"Error summarizing memory: {e}")
        return mem_list


@app.route('/')
def index():
    """Serve the main chat interface page."""
    # First visit in this session: start with an empty memory buffer.
    if 'chat_memory' not in session:
        session['chat_memory'] = []
    return render_template('index.html')

@app.route('/clear_memory', methods=['POST'])
def clear_memory():
    """Reset this session's chat history to an empty list."""
    session['chat_memory'] = []
    logging.info("Chat memory cleared.")
    return jsonify({"status": "Memory cleared."})


# --- Running the App ---
if __name__ == '__main__':
    # Bug fix: `serve` was called but never imported anywhere in the file.
    # Local import keeps waitress a runtime-only dependency of the CLI path.
    from waitress import serve

    logging.info("Starting Waitress server...")
    # Hosting platforms (e.g. HF Spaces) inject PORT; default to 7860 locally.
    port = int(os.environ.get('PORT', 7860))
    serve(app, host='0.0.0.0', port=port)