Swaroop Ingavale committed on
Commit
02ce8d2
·
1 Parent(s): e2b56c1
Files changed (1) hide show
  1. app.py +226 -262
app.py CHANGED
@@ -1,272 +1,236 @@
1
- import gradio as gr
2
- from sentence_transformers import SentenceTransformer
3
  from sklearn.metrics.pairwise import cosine_similarity
4
- import numpy as np
5
  from groq import Groq
6
- import os
7
- import datetime
8
-
9
- client = Groq(
10
- api_key=os.environ.get("GROQ_API_KEY"),
11
- )
12
-
13
- # Initialize sentence transformer model
14
- embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
15
-
16
- # Global memory buffer with embeddings
17
- memory = []
18
-
19
- def add_to_memory(role, content):
20
- """
21
- Add a message to memory along with its embedding.
22
- """
23
- embedding = embedding_model.encode(content, convert_to_numpy=True)
24
- memory.append({"role": role, "content": content, "embedding": embedding})
25
-
26
- def retrieve_relevant_memory(user_input, top_k=5):
27
- """
28
- Retrieve the top-k most relevant messages from memory based on cosine similarity.
29
- """
30
- if not memory:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  return []
32
 
33
- # Compute the embedding of the user input
34
- user_embedding = embedding_model.encode(user_input, convert_to_numpy=True)
35
-
36
- # Calculate similarities
37
- similarities = [cosine_similarity([user_embedding], [m["embedding"]])[0][0] for m in memory]
38
-
39
- # Sort memory by similarity and return the top-k messages
40
- relevant_messages = sorted(zip(similarities, memory), key=lambda x: x[0], reverse=True)
41
- return [m[1] for m in relevant_messages[:top_k]]
42
-
43
- def construct_prompt(memory, user_input, max_tokens=500):
44
- """
45
- Construct the prompt by combining relevant memory and the current user input.
46
- """
47
- relevant_memory = retrieve_relevant_memory(user_input)
48
-
49
- # Combine relevant memory into the prompt
50
- prompt = ""
51
- token_count = 0
52
- for message in relevant_memory:
53
- message_text = f'{message["role"]}: {message["content"]}\n'
54
- token_count += len(message_text.split())
55
- if token_count > max_tokens:
56
- break
57
- prompt += message_text
58
-
59
- # Add the user input at the end
60
- prompt += f'user: {user_input}\n'
61
- return prompt
62
-
63
- def trim_memory(max_size=50):
64
- """
65
- Trim the memory to keep it within the specified max size.
66
- """
67
- if len(memory) > max_size:
68
- memory.pop(0) # Remove the oldest entry
69
-
70
- def summarize_memory():
71
- """
72
- Summarize the memory buffer to free up space.
73
- """
74
- if not memory:
75
- return
76
-
77
- long_term_memory = " ".join([m["content"] for m in memory])
78
- summary = client.chat.completions.create(
79
- messages=[
80
- {"role": "system", "content": "Summarize the following text for key points."},
81
- {"role": "user", "content": long_term_memory},
82
- ],
83
- model="meta-llama/llama-4-scout-17b-16e-instruct",
84
- max_tokens=4096,
85
- )
86
- memory.clear()
87
- # Match the access pattern from main.py if needed
 
 
 
 
88
  try:
89
- # Try the format in app.py first
90
- summary_content = summary.choices[0].message.content
91
- except AttributeError:
92
- # Fall back to the format in main.py
93
- summary_content = summary.choices[0].text
94
-
95
- memory.append({"role": "system", "content": summary_content})
96
-
97
- def get_chatbot_response(
98
- message,
99
- history,
100
- system_message,
101
- max_tokens,
102
- temperature,
103
- top_p,
104
- use_memory=True,
105
- memory_size=50,
106
- ):
107
- """
108
- Generate a response using the chatbot with memory capabilities.
109
- """
110
- if use_memory:
111
- # Process history to maintain memory
112
- for i, (user_msg, bot_msg) in enumerate(history):
113
- if i < len(history) - 1: # Skip the current message which is already in the history
114
- add_to_memory("user", user_msg)
115
- if bot_msg: # Check if bot message exists (might be None for the most recent one)
116
- add_to_memory("assistant", bot_msg)
117
-
118
- # Construct prompt with relevant memory
119
- prompt = construct_prompt(memory, message)
120
-
121
- # Use the prompt with groq client
122
- completion = client.chat.completions.create(
123
  messages=[
124
- {"role": "system", "content": system_message},
125
- {"role": "user", "content": prompt}
126
  ],
127
- model="deepseek-r1-distill-llama-70b",
128
- temperature=temperature,
129
- max_tokens=max_tokens,
130
- top_p=top_p,
131
- stream=True,
132
  )
133
-
134
- # Stream the response
135
- response = ""
136
- for chunk in completion:
137
- response_part = chunk.choices[0].delta.content or ""
138
- response += response_part
139
- yield response
140
-
141
- # Update memory with the current message and response
142
- add_to_memory("user", message)
143
- add_to_memory("assistant", response)
144
-
145
- # Trim memory if needed
146
- trim_memory(max_size=memory_size)
147
-
148
- else:
149
- # If not using memory, just use regular chat completion
150
- messages = [{"role": "system", "content": system_message}]
151
-
152
- for val in history:
153
- if val[0]:
154
- messages.append({"role": "user", "content": val[0]})
155
- if val[1]:
156
- messages.append({"role": "assistant", "content": val[1]})
157
-
158
- messages.append({"role": "user", "content": message})
159
-
 
 
 
 
 
 
 
160
  completion = client.chat.completions.create(
161
- messages=messages,
162
- model="deepseek-r1-distill-llama-70b",
163
- temperature=temperature,
164
- max_tokens=max_tokens,
165
- top_p=top_p,
166
- stream=True,
 
167
  )
168
-
169
- response = ""
170
- for chunk in completion:
171
- response_part = chunk.choices[0].delta.content or ""
172
- response += response_part
173
- yield response
174
-
175
- def view_memory():
176
- """
177
- Create a formatted string showing the current memory contents.
178
- """
179
- if not memory:
180
- return "Memory is empty."
181
-
182
- memory_view = "Current Memory Contents:\n\n"
183
- for i, m in enumerate(memory):
184
- memory_view += f"Memory {i+1}: {m['role']}: {m['content']}\n\n"
185
-
186
- return memory_view
187
-
188
- def clear_memory_action():
189
- """
190
- Clear the memory buffer.
191
- """
192
- memory.clear()
193
- return "Memory has been cleared."
194
-
195
- # Custom CSS for the chat interface - apply using elem_classes
196
- custom_css = """
197
- .user-message {
198
- background-color: #e3f2fd !important;
199
- border-radius: 15px !important;
200
- padding: 10px 15px !important;
201
- }
202
-
203
- .bot-message {
204
- background-color: #f1f8e9 !important;
205
- border-radius: 15px !important;
206
- padding: 10px 15px !important;
207
- }
208
- """
209
-
210
- # Create the Gradio interface
211
- with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
212
- # Header
213
- with gr.Row(elem_classes="header-row"):
214
- gr.Markdown("""
215
- <div style="text-align: center; margin-bottom: 10px; padding: 10px; background-color: #f0f4f8; border-radius: 8px;">
216
- <h1 style="margin: 0; color: #2c3e50;">AI Chatbot With Memory</h1>
217
- <h3 style="margin: 5px 0 0 0; color: #34495e;">Developed by Dhiraj and Swaroop</h3>
218
- </div>
219
- """)
220
-
221
- with gr.Row():
222
- with gr.Column(scale=3):
223
- # Create ChatInterface without css_classes parameter
224
- chatbot = gr.ChatInterface(
225
- get_chatbot_response,
226
- additional_inputs=[
227
- gr.Textbox(value="You are a helpful assistant with memory capabilities.", label="System message"),
228
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
229
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
230
- gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
231
- gr.Checkbox(value=True, label="Use Memory", info="Enable or disable memory capabilities"),
232
- gr.Slider(minimum=10, maximum=200, value=50, step=10, label="Memory Size", info="Maximum number of entries in memory"),
233
- ],
234
- examples=[
235
- ["Tell me about machine learning"],
236
- ["What are the best practices for data preprocessing?"],
237
- ["Can you explain neural networks?"],
238
- ],
239
- title="Chat with AI Assistant",
240
- # Removed css_classes parameter
241
- )
242
-
243
- with gr.Column(scale=1):
244
- with gr.Group():
245
- gr.Markdown("## Memory Management")
246
- memory_display = gr.Textbox(label="Memory Contents", lines=20, max_lines=30, interactive=False)
247
- view_memory_btn = gr.Button("View Memory Contents")
248
- clear_memory_btn = gr.Button("Clear Memory")
249
- summarize_memory_btn = gr.Button("Summarize Memory")
250
- memory_status = gr.Textbox(label="Memory Status", lines=2, interactive=False)
251
-
252
- # Set up button actions
253
- view_memory_btn.click(view_memory, inputs=[], outputs=[memory_display])
254
- clear_memory_btn.click(clear_memory_action, inputs=[], outputs=[memory_status])
255
- summarize_memory_btn.click(
256
- lambda: (summarize_memory(), "Memory summarized successfully."),
257
- inputs=[],
258
- outputs=[memory_status]
259
- )
260
-
261
- # Footer
262
- with gr.Row(elem_classes="footer-row"):
263
- gr.Markdown(f"""
264
- <div style="text-align: center; margin-top: 20px; padding: 10px; background-color: #f0f4f8; border-radius: 8px;">
265
- <p style="margin: 0; color: #2c3e50;">
266
- Developed by Dhiraj and Swaroop | © {datetime.datetime.now().year} | Version 1.0
267
- </p>
268
- </div>
269
- """)
270
-
271
- if __name__ == "__main__":
272
- demo.launch()
 
1
+ import os
2
+ from flask import Flask, render_template, request, jsonify, session
3
  from sklearn.metrics.pairwise import cosine_similarity
 
4
  from groq import Groq
5
+ import numpy as np
6
+ import logging
7
+ from transformers import AutoTokenizer, AutoModel # Keep these
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+ # Configure logging
12
+ logging.basicConfig(level=logging.INFO)
13
+
14
+ # --- Flask App Setup --- (MUST come before routes or app-dependent code) ---
15
+ app = Flask(__name__)
16
+ app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', 'a_default_secret_key_please_change')
17
+
18
+ # --- Initialize Models ---
19
+ device = torch.device("cpu") # Force CPU for free tier
20
+ if torch.cuda.is_available():
21
+ device = torch.device("cuda") # Should not happen on free tier
22
+ logging.info(f"Using device: {device}")
23
+
24
+ tokenizer = None
25
+ model = None
26
+ client = None
27
+
28
+ try:
29
+ # Load tokenizer and model from HuggingFace Hub using transformers
30
+ tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
31
+ # Re-add from_tf=True here for AutoModel.from_pretrained
32
+ model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2', from_tf=True).to(device)
33
+ logging.info("Tokenizer and AutoModel loaded successfully with from_tf=True.")
34
+ except Exception as e:
35
+ logging.error(f"Error loading Transformer models: {e}")
36
+ tokenizer = None
37
+ model = None
38
+
39
+ # Initialize the Groq client
40
+ groq_api_key = os.environ.get("GROQ_API_KEY")
41
+ if not groq_api_key:
42
+ logging.error("GROQ_API_KEY environment variable not set.")
43
+ client = None
44
+ else:
45
+ client = Groq(api_key=groq_api_key)
46
+ logging.info("Groq client initialized.")
47
+
48
+
49
+ # --- Helper function for Mean Pooling ---
50
+ def mean_pooling(model_output, attention_mask):
51
+ token_embeddings = model_output[0]
52
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float().to(token_embeddings.device)
53
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
54
+
55
+ # --- Function to get embedding ---
56
+ def get_embedding(text):
57
+ if tokenizer is None or model is None:
58
+ logging.error("Embedding models not loaded. Cannot generate embedding.")
59
+ return None
60
+ try:
61
+ encoded_input = tokenizer(text, padding=True, truncation=True, return_tensors='pt').to(device)
62
+ with torch.no_grad():
63
+ model_output = model(**encoded_input)
64
+ sentence_embedding = mean_pooling(model_output, encoded_input['attention_mask'])
65
+ sentence_embedding = F.normalize(sentence_embedding, p=2, dim=1)
66
+ return sentence_embedding.cpu().numpy()[0]
67
+ except Exception as e:
68
+ logging.error(f"Error generating embedding: {e}")
69
+ return None
70
+
71
+ # --- Memory Management Functions (rely on get_embedding) ---
72
+ # ... (add_to_memory, retrieve_relevant_memory, construct_prompt, trim_memory, summarize_memory - these remain the same, calling get_embedding) ...
73
+
74
+ def add_to_memory(mem_list, role, content):
75
+ if not content or not content.strip():
76
+ logging.warning(f"Attempted to add empty content to memory for role: {role}")
77
+ return mem_list
78
+ embedding = get_embedding(content)
79
+ if embedding is not None:
80
+ mem_list.append({"role": role, "content": content, "embedding": embedding.tolist()})
81
+ else:
82
+ logging.warning(f"Failed to get embedding for message: {content[:50]}...")
83
+ mem_list.append({"role": role, "content": content, "embedding": None})
84
+ return mem_list
85
+
86
+ def retrieve_relevant_memory(mem_list, user_input, top_k=5):
87
+ if not mem_list or tokenizer is None or model is None:
88
+ return []
89
+ user_embedding = get_embedding(user_input)
90
+ if user_embedding is None:
91
+ logging.error("Failed to get user input embedding for retrieval.")
92
  return []
93
 
94
+ valid_memory_items = []
95
+ memory_embeddings_np = []
96
+ for m in mem_list:
97
+ if m.get("embedding") is not None and isinstance(m["embedding"], list):
98
+ try:
99
+ np_embedding = np.array(m["embedding"])
100
+ if np_embedding.shape == (model.config.hidden_size,): # Use model config for dimension
101
+ valid_memory_items.append(m)
102
+ memory_embeddings_np.append(np_embedding)
103
+ else:
104
+ logging.warning(f"Embedding dimension mismatch for memory entry: {m['content'][:50]}...")
105
+ except Exception as conv_e:
106
+ logging.warning(f"Could not convert embedding for memory entry: {m['content'][:50]}... Error: {conv_e}")
107
+ pass
108
+
109
+ if not valid_memory_items:
110
+ return []
111
+ similarities = cosine_similarity([user_embedding], np.array(memory_embeddings_np))[0]
112
+ relevant_messages_sorted = sorted(zip(similarities, valid_memory_items), key=lambda x: x[0], reverse=True)
113
+ return [m[1] for m in relevant_messages_sorted[:top_k]]
114
+
115
+ def construct_prompt(mem_list, user_input, max_tokens_in_prompt=1000):
116
+ relevant_memory_items = retrieve_relevant_memory(mem_list, user_input)
117
+ relevant_content_set = {m["content"] for m in relevant_memory_items if "content" in m}
118
+
119
+ messages_for_api = []
120
+ messages_for_api.append({"role": "system", "content": "You are a helpful and friendly AI assistant."})
121
+ current_prompt_tokens = len(messages_for_api[0]["content"].split())
122
+
123
+ context_messages = []
124
+ for msg in mem_list:
125
+ if "content" in msg and msg["content"] in relevant_content_set and msg["role"] in ["user", "assistant", "system"]:
126
+ msg_text = f'{msg["role"]}: {msg["content"]}\n'
127
+ msg_tokens = len(msg_text.split())
128
+ if current_prompt_tokens + msg_tokens > max_tokens_in_prompt:
129
+ break
130
+ context_messages.append({"role": msg["role"], "content": msg["content"]})
131
+ current_prompt_tokens += msg_tokens
132
+
133
+ messages_for_api.extend(context_messages)
134
+ user_input_tokens = len(user_input.split())
135
+ if current_prompt_tokens + user_input_tokens > max_tokens_in_prompt and len(messages_for_api) > 1:
136
+ logging.warning(f"User input exceeds max_tokens_in_prompt with existing context. Context may be truncated.")
137
+ messages_for_api.append({"role": "user", "content": user_input})
138
+ return messages_for_api
139
+
140
+ def trim_memory(mem_list, max_size=50):
141
+ while len(mem_list) > max_size:
142
+ mem_list.pop(0)
143
+ return mem_list
144
+
145
+ def summarize_memory(mem_list):
146
+ if not mem_list or client is None:
147
+ logging.warning("Memory is empty or Groq client not initialized. Cannot summarize.")
148
+ return []
149
+ long_term_memory = " ".join([m["content"] for m in mem_list if "content" in m])
150
+ if not long_term_memory.strip():
151
+ logging.warning("Memory content is empty. Cannot summarize.")
152
+ return []
153
  try:
154
+ summary_completion = client.chat.completions.create(
155
+             model="llama-3.1-8b-instant",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  messages=[
157
+ {"role": "system", "content": "Summarize the following conversation for key points. Keep it concise."},
158
+ {"role": "user", "content": long_term_memory},
159
  ],
160
+             max_tokens=500,
 
 
 
 
161
  )
162
+ summary_text = summary_completion.choices[0].message.content
163
+ logging.info("Memory summarized.")
164
+ return [{"role": "system", "content": f"Previous conversation summary: {summary_text}"}]
165
+ except Exception as e:
166
+ logging.error(f"Error summarizing memory: {e}")
167
+ return mem_list
168
+
169
+
170
+ # --- Flask Routes --- (MUST come AFTER app is defined) ---
171
+
172
+ @app.route('/')
173
+ def index():
174
+ if 'chat_memory' not in session:
175
+ session['chat_memory'] = []
176
+ return render_template('index.html')
177
+
178
+ @app.route('/chat', methods=['POST'])
179
+ def chat():
180
+ # Check if Groq client AND embedding models are initialized
181
+ if client is None or tokenizer is None or model is None:
182
+ status_code = 500
183
+ error_message = "Chatbot backend is not fully initialized (API key or embedding models missing)."
184
+ logging.error(error_message)
185
+ return jsonify({"response": error_message}), status_code
186
+
187
+ user_input = request.json.get('message')
188
+ if not user_input or not user_input.strip():
189
+ return jsonify({"response": "Please enter a message."}), 400
190
+
191
+ current_memory_serializable = session.get('chat_memory', [])
192
+
193
+ messages_for_api = construct_prompt(current_memory_serializable, user_input)
194
+
195
+ try:
196
  completion = client.chat.completions.create(
197
+             model="llama-3.1-8b-instant",
198
+ messages=messages_for_api,
199
+ temperature=0.6,
200
+ max_tokens=1024,
201
+ top_p=0.95,
202
+ stream=False,
203
+ stop=None,
204
  )
205
+ ai_response_content = completion.choices[0].message.content
206
+
207
+ except Exception as e:
208
+ logging.error(f"Error calling Groq API: {e}")
209
+ ai_response_content = "Sorry, I encountered an error when trying to respond. Please try again later."
210
+
211
+ current_memory_serializable = add_to_memory(current_memory_serializable, "user", user_input)
212
+     current_memory_serializable = add_to_memory(current_memory_serializable, "assistant", ai_response_content)
213
+
214
+ current_memory_serializable = trim_memory(current_memory_serializable, max_size=20)
215
+
216
+ session['chat_memory'] = current_memory_serializable
217
+
218
+ return jsonify({"response": ai_response_content})
219
+
220
+
221
+ @app.route('/clear_memory', methods=['POST'])
222
+ def clear_memory():
223
+ session['chat_memory'] = []
224
+ logging.info("Chat memory cleared.")
225
+ return jsonify({"status": "Memory cleared."})
226
+
227
+
228
+ # --- Running the App ---
229
+ if __name__ == '__main__':
230
+     # Serve with Flask's built-in WSGI server instead of Uvicorn:
231
+     # uvicorn is an ASGI server and does NOT auto-detect WSGI apps, so
232
+     # passing the Flask app to uvicorn.run directly fails at request time.
233
+     logging.info("Starting Flask server...")
234
+     port = int(os.environ.get('PORT', 7860))
235
+     app.run(host="0.0.0.0", port=port)