import os
from flask import Flask, render_template, request, jsonify, session
from sklearn.metrics.pairwise import cosine_similarity
from groq import Groq
import numpy as np
import logging
from transformers import AutoTokenizer, AutoModel  # used to compute sentence embeddings
import torch
import torch.nn.functional as F

# Configure logging
logging.basicConfig(level=logging.INFO)

# --- Flask App Setup (MUST come before routes or app-dependent code) ---
app = Flask(__name__)
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', 'a_default_secret_key_please_change')

# --- Initialize Models ---
device = torch.device("cpu") # Force CPU for free tier
if torch.cuda.is_available():
    device = torch.device("cuda") # Should not happen on free tier
logging.info(f"Using device: {device}")

tokenizer = None
model = None
client = None

try:
    # Load tokenizer and model from HuggingFace Hub using transformers
    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
    # from_tf=True converts the TensorFlow checkpoint on load (requires TensorFlow to be installed)
    model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2', from_tf=True).to(device)
    logging.info("Tokenizer and AutoModel loaded successfully with from_tf=True.")
except Exception as e:
    logging.error(f"Error loading Transformer models: {e}")
    tokenizer = None
    model = None

# Initialize the Groq client
groq_api_key = os.environ.get("GROQ_API_KEY")
if not groq_api_key:
    logging.error("GROQ_API_KEY environment variable not set.")
    client = None
else:
    client = Groq(api_key=groq_api_key)
    logging.info("Groq client initialized.")


# --- Helper function for Mean Pooling ---
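# Mean pooling averages the token embeddings, weighted by the attention mask so that
# padding tokens do not contribute to the sentence vector.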
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float().to(token_embeddings.device)
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

# --- Function to get embedding ---
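# Encodes the text with the MiniLM tokenizer, mean-pools the token embeddings and
# L2-normalizes them; returns a 384-dimensional numpy vector (for all-MiniLM-L6-v2),
# or None on failure.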
def get_embedding(text):
    if tokenizer is None or model is None:
        logging.error("Embedding models not loaded. Cannot generate embedding.")
        return None
    try:
        encoded_input = tokenizer(text, padding=True, truncation=True, return_tensors='pt').to(device)
        with torch.no_grad():
            model_output = model(**encoded_input)
        sentence_embedding = mean_pooling(model_output, encoded_input['attention_mask'])
        sentence_embedding = F.normalize(sentence_embedding, p=2, dim=1)
        return sentence_embedding.cpu().numpy()[0]
    except Exception as e:
        logging.error(f"Error generating embedding: {e}")
        return None

# --- Memory Management Functions (rely on get_embedding) ---
# add_to_memory and retrieve_relevant_memory call get_embedding; construct_prompt,
# trim_memory and summarize_memory operate on the resulting memory list.

def add_to_memory(mem_list, role, content):
    if not content or not content.strip():
         logging.warning(f"Attempted to add empty content to memory for role: {role}")
         return mem_list
    embedding = get_embedding(content)
    if embedding is not None:
        mem_list.append({"role": role, "content": content, "embedding": embedding.tolist()})
    else:
        logging.warning(f"Failed to get embedding for message: {content[:50]}...")
        mem_list.append({"role": role, "content": content, "embedding": None})
    return mem_list

def retrieve_relevant_memory(mem_list, user_input, top_k=5):
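    # Rank stored messages by cosine similarity between their saved embeddings and
    # the embedding of the current user input, and return the top_k most similar ones.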
    if not mem_list or tokenizer is None or model is None:
        return []
    user_embedding = get_embedding(user_input)
    if user_embedding is None:
        logging.error("Failed to get user input embedding for retrieval.")
        return []

    valid_memory_items = []
    memory_embeddings_np = []
    for m in mem_list:
        if m.get("embedding") is not None and isinstance(m["embedding"], list):
            try:
                 np_embedding = np.array(m["embedding"])
                 if np_embedding.shape == (model.config.hidden_size,): # Use model config for dimension
                      valid_memory_items.append(m)
                      memory_embeddings_np.append(np_embedding)
                 else:
                     logging.warning(f"Embedding dimension mismatch for memory entry: {m['content'][:50]}...")
            except Exception as conv_e:
                logging.warning(f"Could not convert embedding for memory entry: {m['content'][:50]}... Error: {conv_e}")
                pass

    if not valid_memory_items:
         return []
    similarities = cosine_similarity([user_embedding], np.array(memory_embeddings_np))[0]
    relevant_messages_sorted = sorted(zip(similarities, valid_memory_items), key=lambda x: x[0], reverse=True)
    return [m[1] for m in relevant_messages_sorted[:top_k]]

def construct_prompt(mem_list, user_input, max_tokens_in_prompt=1000):
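    # Build the message list for the chat completion call: a system prompt, the most
    # relevant past messages (kept within a rough whitespace-token budget), then the new user input.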
    relevant_memory_items = retrieve_relevant_memory(mem_list, user_input)
    relevant_content_set = {m["content"] for m in relevant_memory_items if "content" in m}

    messages_for_api = []
    messages_for_api.append({"role": "system", "content": "You are a helpful and friendly AI assistant."})
    current_prompt_tokens = len(messages_for_api[0]["content"].split())

    context_messages = []
    for msg in mem_list:
        if "content" in msg and msg["content"] in relevant_content_set and msg["role"] in ["user", "assistant", "system"]:
            msg_text = f'{msg["role"]}: {msg["content"]}\n'
            msg_tokens = len(msg_text.split())
            if current_prompt_tokens + msg_tokens > max_tokens_in_prompt:
                break
            context_messages.append({"role": msg["role"], "content": msg["content"]})
            current_prompt_tokens += msg_tokens

    messages_for_api.extend(context_messages)
    user_input_tokens = len(user_input.split())
    if current_prompt_tokens + user_input_tokens > max_tokens_in_prompt and len(messages_for_api) > 1:
        logging.warning("User input exceeds max_tokens_in_prompt with the existing context. Context may be truncated.")
    messages_for_api.append({"role": "user", "content": user_input})
    return messages_for_api

def trim_memory(mem_list, max_size=50):
    while len(mem_list) > max_size:
        mem_list.pop(0)
    return mem_list

def summarize_memory(mem_list):
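    # Ask the chat model to condense the whole conversation into one short summary,
    # returned as a single system message; on failure the original memory is returned unchanged.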
    if not mem_list or client is None:
        logging.warning("Memory is empty or Groq client not initialized. Cannot summarize.")
        return []
    long_term_memory = " ".join([m["content"] for m in mem_list if "content" in m])
    if not long_term_memory.strip():
         logging.warning("Memory content is empty. Cannot summarize.")
         return []
    try:
        summary_completion = client.chat.completions.create(
            model="llama-3.1-8b-instruct-fpt",
            messages=[
                {"role": "system", "content": "Summarize the following conversation for key points. Keep it concise."},
                {"role": "user", "content": long_term_memory},
            ],
            max_tokens=500,
        )
        summary_text = summary_completion.choices[0].message.content
        logging.info("Memory summarized.")
        return [{"role": "system", "content": f"Previous conversation summary: {summary_text}"}]
    except Exception as e:
        logging.error(f"Error summarizing memory: {e}")
        return mem_list


# --- Flask Routes (MUST come AFTER app is defined) ---

@app.route('/')
def index():
    if 'chat_memory' not in session:
        session['chat_memory'] = []
    return render_template('index.html')

@app.route('/chat', methods=['POST'])
def chat():
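    # Flow: validate the input, build a prompt from relevant memory, call the Groq chat
    # completion API, store both turns in session memory, and return the reply as JSON.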
    # Check if Groq client AND embedding models are initialized
    if client is None or tokenizer is None or model is None:
         status_code = 500
         error_message = "Chatbot backend is not fully initialized (API key or embedding models missing)."
         logging.error(error_message)
         return jsonify({"response": error_message}), status_code

    user_input = request.json.get('message')
    if not user_input or not user_input.strip():
        return jsonify({"response": "Please enter a message."}), 400

    # Memory lives in the Flask session, so embeddings are stored as JSON-serializable lists rather than numpy arrays.
    current_memory_serializable = session.get('chat_memory', [])

    messages_for_api = construct_prompt(current_memory_serializable, user_input)

    try:
        completion = client.chat.completions.create(
            model="llama-3.1-8b-instruct-fpt",
            messages=messages_for_api,
            temperature=0.6,
            max_tokens=1024,
            top_p=0.95,
            stream=False,
            stop=None,
        )
        ai_response_content = completion.choices[0].message.content

    except Exception as e:
        logging.error(f"Error calling Groq API: {e}")
        ai_response_content = "Sorry, I encountered an error when trying to respond. Please try again later."

    current_memory_serializable = add_to_memory(current_memory_serializable, "user", user_input)
    current_memory_serializable = add_to_memory(current_memory_serializable, "assistant", ai_response_content)

    current_memory_serializable = trim_memory(current_memory_serializable, max_size=20)

    session['chat_memory'] = current_memory_serializable

    return jsonify({"response": ai_response_content})


@app.route('/clear_memory', methods=['POST'])
def clear_memory():
    session['chat_memory'] = []
    logging.info("Chat memory cleared.")
    return jsonify({"status": "Memory cleared."})


# --- Running the App ---
if __name__ == '__main__':
    # Using Uvicorn instead of Waitress
    logging.info("Starting Uvicorn server...")
    port = int(os.environ.get('PORT', 7860))
    # Flask is a WSGI app and Uvicorn is an ASGI server, so the app is served
    # through Uvicorn's WSGI interface (auto-detection does not cover WSGI apps).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=port, interface="wsgi")