Swaroop Ingavale commited on
Commit
b251dd6
·
1 Parent(s): bccdade
Files changed (2) hide show
  1. app.py +254 -46
  2. custom_css.css +38 -0
app.py CHANGED
@@ -1,64 +1,272 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
 
 
 
3
 
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
 
 
 
 
 
 
 
 
9
 
10
- def respond(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  message,
12
- history: list[tuple[str, str]],
13
  system_message,
14
  max_tokens,
15
  temperature,
16
  top_p,
 
 
17
  ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- messages.append({"role": "user", "content": message})
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- response = ""
 
 
 
 
 
29
 
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
 
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
  """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
- )
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  if __name__ == "__main__":
64
- demo.launch()
 
1
import gradio as gr
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from groq import Groq
import os
import datetime

# Groq API client; reads the key from the GROQ_API_KEY environment variable
# (None if unset — the client will fail on first request, not at import).
client = Groq(
    api_key=os.environ.get("GROQ_API_KEY"),
)

# Initialize sentence transformer model used to embed messages for
# similarity-based memory retrieval.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Global memory buffer with embeddings: a list of
# {"role": str, "content": str, "embedding": np.ndarray} dicts,
# oldest entry first. Shared by all sessions of this app process.
memory = []
18
+
19
def add_to_memory(role, content):
    """Append one message to the global memory buffer with its embedding.

    Args:
        role: Speaker label stored alongside the text (e.g. "user").
        content: Message text; embedded with the module-level sentence model.
    """
    vector = embedding_model.encode(content, convert_to_numpy=True)
    entry = {"role": role, "content": content, "embedding": vector}
    memory.append(entry)
25
+
26
def retrieve_relevant_memory(user_input, top_k=5):
    """Return up to ``top_k`` memory entries most similar to ``user_input``.

    Similarity is cosine similarity between sentence embeddings; results are
    ordered most-similar first. Returns an empty list when memory is empty.
    """
    if not memory:
        return []

    # Embed the query once, then score every stored entry against it.
    query_vec = embedding_model.encode(user_input, convert_to_numpy=True)

    scored = []
    for entry in memory:
        score = cosine_similarity([query_vec], [entry["embedding"]])[0][0]
        scored.append((score, entry))

    # Stable descending sort on similarity, keep the best top_k entries.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [entry for _, entry in scored[:top_k]]
42
+
43
def construct_prompt(memory, user_input, max_tokens=500):
    """Build a prompt from relevant memory followed by the current user input.

    NOTE(review): the ``memory`` parameter shadows the module-level buffer and
    is never read here — relevance comes from retrieve_relevant_memory(),
    which uses the global. Kept for signature compatibility; confirm before
    removing. "Tokens" are approximated as whitespace-separated words.
    """
    pieces = []
    used_tokens = 0
    for entry in retrieve_relevant_memory(user_input):
        line = f'{entry["role"]}: {entry["content"]}\n'
        # Stop adding context once the approximate budget is exceeded.
        used_tokens += len(line.split())
        if used_tokens > max_tokens:
            break
        pieces.append(line)

    # The user's current message always goes at the end.
    pieces.append(f'user: {user_input}\n')
    return "".join(pieces)
62
+
63
def trim_memory(max_size=50):
    """Shrink the global memory buffer to at most ``max_size`` entries.

    Oldest entries (front of the list) are discarded first.

    Bug fix: the previous implementation removed only ONE entry per call, so
    memory could stay above ``max_size`` indefinitely when it had grown by
    more than one entry (e.g. after replaying a long history). This removes
    the entire overflow in a single slice-delete.
    """
    overflow = len(memory) - max_size
    if overflow > 0:
        del memory[:overflow]  # drop the oldest entries in one O(n) pass
69
+
70
def summarize_memory():
    """Collapse the whole memory buffer into a single "system" summary entry.

    Sends the concatenated memory contents to the Groq chat API, then replaces
    the buffer with one summary entry. No-op when memory is empty.

    Bug fix: the previous version appended a bare dict without an "embedding"
    key; any later call to retrieve_relevant_memory() would then crash on
    ``m["embedding"]``. Storing via add_to_memory() keeps every entry's shape
    consistent. Memory is now cleared only AFTER the API response is parsed,
    so a failed request no longer wipes the buffer.
    """
    if not memory:
        return

    long_term_memory = " ".join(m["content"] for m in memory)
    summary = client.chat.completions.create(
        messages=[
            {"role": "system", "content": "Summarize the following text for key points."},
            {"role": "user", "content": long_term_memory},
        ],
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        max_tokens=4096,
    )

    # Response shape differs between client versions; prefer the chat format
    # and fall back to the completions-style ``.text`` attribute.
    try:
        summary_content = summary.choices[0].message.content
    except AttributeError:
        summary_content = summary.choices[0].text

    memory.clear()
    add_to_memory("system", summary_content)
96
+
97
def get_chatbot_response(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
    use_memory=True,
    memory_size=50,
):
    """
    Generate a response using the chatbot with memory capabilities.

    Streams the answer: yields the accumulated response string after each
    chunk so Gradio's ChatInterface renders it incrementally.

    Args:
        message: Current user message.
        history: Chat history from Gradio — assumed to be (user, bot) pairs;
            bot may be falsy for an in-progress turn. TODO confirm format
            against the installed Gradio version (tuples vs. message dicts).
        system_message: System prompt passed to the model.
        max_tokens / temperature / top_p: Sampling parameters.
        use_memory: When True, use the global semantic-memory buffer.
        memory_size: Cap on memory entries enforced after each turn.
    """
    if use_memory:
        # Process history to maintain memory.
        # NOTE(review): this replays the history on EVERY call, so the same
        # pairs are re-embedded and re-appended each turn and memory fills
        # with duplicates; the add_to_memory calls at the end of this branch
        # already record each turn — confirm whether this replay is needed.
        for i, (user_msg, bot_msg) in enumerate(history):
            if i < len(history) - 1:  # Skip the current message which is already in the history
                add_to_memory("user", user_msg)
                if bot_msg:  # Check if bot message exists (might be None for the most recent one)
                    add_to_memory("assistant", bot_msg)

        # Construct prompt with relevant memory (construct_prompt ignores its
        # ``memory`` argument and reads the module-level buffer).
        prompt = construct_prompt(memory, message)

        # Use the prompt with groq client (streaming completion).
        completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": prompt}
            ],
            model="deepseek-r1-distill-llama-70b",
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            stream=True,
        )

        # Stream the response, yielding the running total each chunk.
        response = ""
        for chunk in completion:
            response_part = chunk.choices[0].delta.content or ""
            response += response_part
            yield response

        # Update memory with the current message and response.
        add_to_memory("user", message)
        add_to_memory("assistant", response)

        # Trim memory if needed so it never exceeds memory_size entries.
        trim_memory(max_size=memory_size)

    else:
        # If not using memory, just use regular chat completion built from
        # the raw history.
        messages = [{"role": "system", "content": system_message}]

        for val in history:
            if val[0]:
                messages.append({"role": "user", "content": val[0]})
            if val[1]:
                messages.append({"role": "assistant", "content": val[1]})

        messages.append({"role": "user", "content": message})

        completion = client.chat.completions.create(
            messages=messages,
            model="deepseek-r1-distill-llama-70b",
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            stream=True,
        )

        # Same streaming pattern as the memory branch.
        response = ""
        for chunk in completion:
            response_part = chunk.choices[0].delta.content or ""
            response += response_part
            yield response
174
 
175
def view_memory():
    """Render the current memory buffer as a human-readable string.

    Returns a fixed placeholder when memory is empty; otherwise a header
    followed by one blank-line-separated line per entry.
    """
    if not memory:
        return "Memory is empty."

    rows = [
        f"Memory {idx}: {entry['role']}: {entry['content']}\n\n"
        for idx, entry in enumerate(memory, start=1)
    ]
    return "Current Memory Contents:\n\n" + "".join(rows)
187
 
188
def clear_memory_action():
    """Empty the global memory buffer and return a status message for the UI."""
    del memory[:]  # in-place wipe, same object stays shared module-wide
    return "Memory has been cleared."
194
 
195
# Custom CSS for the chat interface - apply using elem_classes
# (injected into gr.Blocks via its ``css=`` argument; selectors target
# message elements — NOTE(review): .user-message/.bot-message classes are not
# set explicitly in this file, verify they match the rendered DOM).
custom_css = """
.user-message {
    background-color: #e3f2fd !important;
    border-radius: 15px !important;
    padding: 10px 15px !important;
}

.bot-message {
    background-color: #f1f8e9 !important;
    border-radius: 15px !important;
    padding: 10px 15px !important;
}
"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
# Create the Gradio interface: chat column on the left, memory-management
# panel on the right, with a styled header and footer.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    # Header banner
    with gr.Row(elem_classes="header-row"):
        gr.Markdown("""
        <div style="text-align: center; margin-bottom: 10px; padding: 10px; background-color: #f0f4f8; border-radius: 8px;">
            <h1 style="margin: 0; color: #2c3e50;">AI Chatbot With Memory</h1>
            <h3 style="margin: 5px 0 0 0; color: #34495e;">Developed by Dhiraj and Swaroop</h3>
        </div>
        """)

    with gr.Row():
        with gr.Column(scale=3):
            # Chat column: ChatInterface streams get_chatbot_response.
            chatbot = gr.ChatInterface(
                get_chatbot_response,
                additional_inputs=[
                    gr.Textbox(value="You are a helpful assistant with memory capabilities.", label="System message"),
                    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
                    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
                    gr.Checkbox(value=True, label="Use Memory", info="Enable or disable memory capabilities"),
                    gr.Slider(minimum=10, maximum=200, value=50, step=10, label="Memory Size", info="Maximum number of entries in memory"),
                ],
                examples=[
                    ["Tell me about machine learning"],
                    ["What are the best practices for data preprocessing?"],
                    ["Can you explain neural networks?"],
                ],
                title="Chat with AI Assistant",
            )

        with gr.Column(scale=1):
            with gr.Group():
                # Memory-management side panel.
                gr.Markdown("## Memory Management")
                memory_display = gr.Textbox(label="Memory Contents", lines=20, max_lines=30, interactive=False)
                view_memory_btn = gr.Button("View Memory Contents")
                clear_memory_btn = gr.Button("Clear Memory")
                summarize_memory_btn = gr.Button("Summarize Memory")
                memory_status = gr.Textbox(label="Memory Status", lines=2, interactive=False)

                def _summarize_and_report():
                    """Summarize memory, then return ONE status string.

                    Bug fix: the previous ``lambda`` returned a 2-tuple
                    ``(summarize_memory(), "...")`` into a single output
                    component, which Gradio cannot map to one Textbox.
                    """
                    summarize_memory()
                    return "Memory summarized successfully."

                # Set up button actions (each writes to exactly one output).
                view_memory_btn.click(view_memory, inputs=[], outputs=[memory_display])
                clear_memory_btn.click(clear_memory_action, inputs=[], outputs=[memory_status])
                summarize_memory_btn.click(
                    _summarize_and_report,
                    inputs=[],
                    outputs=[memory_status]
                )

    # Footer with the current year.
    with gr.Row(elem_classes="footer-row"):
        gr.Markdown(f"""
        <div style="text-align: center; margin-top: 20px; padding: 10px; background-color: #f0f4f8; border-radius: 8px;">
            <p style="margin: 0; color: #2c3e50;">
                Developed by Dhiraj and Swaroop | © {datetime.datetime.now().year} | Version 1.0
            </p>
        </div>
        """)
270
 
271
if __name__ == "__main__":
    # Launch the Gradio server only when run as a script (not on import).
    demo.launch()
custom_css.css ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* Custom CSS for the AI Chatbot */

/* Spacing around the header/footer rows (elem_classes set in app.py) */
.header-row {
    margin-bottom: 20px;
}

.footer-row {
    margin-top: 30px;
}

/* Center the app and cap its width */
.gradio-container {
    max-width: 1200px !important;
    margin-left: auto !important;
    margin-right: auto !important;
}

/* Styling for chat messages */
.chat-bot .user-message {
    background-color: #e3f2fd !important;
    border-radius: 15px !important;
    padding: 10px 15px !important;
}

.chat-bot .bot-message {
    background-color: #f1f8e9 !important;
    border-radius: 15px !important;
    padding: 10px 15px !important;
}

/* Button styling: subtle lift on hover */
button {
    transition: all 0.3s ease !important;
}

button:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1) !important;
}