# Hugging Face Space: AI chatbot with semantic memory (Gradio + Groq).
# NOTE(review): the scraped Space page reported a build error; the source
# below was recovered from the page's table-formatted listing.
import gradio as gr
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from groq import Groq
import os
import datetime

# Groq API client. GROQ_API_KEY must be set in the environment;
# os.environ.get returns None when missing, and the client will fail on
# the first request rather than here.
client = Groq(
    api_key=os.environ.get("GROQ_API_KEY"),
)

# Sentence-transformer used to embed messages for similarity retrieval.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Global conversation memory: list of
# {"role": str, "content": str, "embedding": np.ndarray} dicts.
memory = []
def add_to_memory(role, content):
    """Append one message, together with its embedding, to the global buffer.

    Args:
        role: Speaker label, e.g. "user", "assistant" or "system".
        content: Message text; embedded with the module-level model.
    """
    vector = embedding_model.encode(content, convert_to_numpy=True)
    entry = {"role": role, "content": content, "embedding": vector}
    memory.append(entry)
def retrieve_relevant_memory(user_input, top_k=5):
    """Return the top-k memory entries most similar to ``user_input``.

    Similarity is cosine similarity between sentence embeddings.

    Args:
        user_input: Query text to embed and compare against memory.
        top_k: Maximum number of entries to return.

    Returns:
        Up to ``top_k`` memory dicts, most similar first; [] when memory
        is empty.
    """
    if not memory:
        return []
    # Embed the query once.
    user_embedding = embedding_model.encode(user_input, convert_to_numpy=True)
    # Score all stored embeddings in a single vectorized call instead of
    # one cosine_similarity call per entry (the original looped in Python).
    stored = np.vstack([m["embedding"] for m in memory])
    similarities = cosine_similarity(user_embedding.reshape(1, -1), stored)[0]
    # Indices of the highest similarities, descending.
    top_indices = np.argsort(similarities)[::-1][:top_k]
    return [memory[i] for i in top_indices]
def construct_prompt(memory, user_input, max_tokens=500):
    """Build the LLM prompt from relevant memory plus the current input.

    Args:
        memory: Unused — kept for backward compatibility with existing
            callers; retrieval reads the module-level buffer through
            retrieve_relevant_memory.
        user_input: Current user message, appended as the final line.
        max_tokens: Rough budget for the memory portion, measured in
            whitespace-separated words (NOT model tokens).

    Returns:
        "role: content" lines for the retrieved messages, ending with the
        current user turn.
    """
    relevant_memory = retrieve_relevant_memory(user_input)
    parts = []
    token_count = 0
    for message in relevant_memory:
        message_text = f'{message["role"]}: {message["content"]}\n'
        # Approximate token cost by word count; stop once the budget is hit
        # (the message that overflows the budget is dropped).
        token_count += len(message_text.split())
        if token_count > max_tokens:
            break
        parts.append(message_text)
    # The current user input always goes last.
    parts.append(f'user: {user_input}\n')
    # join() avoids the quadratic string concatenation of the original.
    return "".join(parts)
def trim_memory(max_size=50, buffer=None):
    """Drop the oldest entries so the buffer holds at most ``max_size`` items.

    Args:
        max_size: Maximum number of entries to keep.
        buffer: Optional list to trim in place; defaults to the global
            ``memory`` (added, backward-compatible — existing callers pass
            only ``max_size``).

    The original implementation popped a single entry per call, so a buffer
    more than one over the limit stayed over it; this removes the full
    overshoot in one pass.
    """
    target = memory if buffer is None else buffer
    if len(target) > max_size:
        # Delete in place so every alias of the buffer sees the trim.
        del target[: len(target) - max_size]
def summarize_memory():
    """Condense the whole memory buffer into one LLM-generated summary entry.

    No-op when memory is empty. Sends the concatenated message contents to a
    Groq summarization model, then replaces the buffer with a single
    system-role entry holding the summary.
    """
    if not memory:
        return
    long_term_memory = " ".join([m["content"] for m in memory])
    summary = client.chat.completions.create(
        messages=[
            {"role": "system", "content": "Summarize the following text for key points."},
            {"role": "user", "content": long_term_memory},
        ],
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        max_tokens=4096,
    )
    # Extract the text BEFORE clearing: the original cleared first, so an
    # unexpected response shape could wipe memory without storing a summary.
    try:
        # Standard chat-completions response shape.
        summary_content = summary.choices[0].message.content
    except AttributeError:
        # Some client versions expose the text directly on the choice.
        summary_content = summary.choices[0].text
    memory.clear()
    # Store via add_to_memory so the entry carries an "embedding" key; the
    # original appended a bare dict without one, which made
    # retrieve_relevant_memory raise KeyError afterwards.
    add_to_memory("system", summary_content)
def get_chatbot_response(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
    use_memory=True,
    memory_size=50,
):
    """Stream a chat completion, optionally using the semantic memory buffer.

    Args:
        message: Current user message.
        history: List of (user_msg, bot_msg) pairs from the Gradio chat.
        system_message: System prompt for the model.
        max_tokens: Completion token limit passed to Groq.
        temperature: Sampling temperature passed to Groq.
        top_p: Nucleus-sampling parameter passed to Groq.
        use_memory: When True, retrieve relevant past messages into the prompt.
        memory_size: Maximum entries kept in the global memory buffer.

    Yields:
        The accumulated response text after each streamed chunk.
    """

    def _remember_if_new(role, content):
        # Guard against duplicates: the full chat history is replayed on
        # every call while the global buffer persists across calls, so the
        # original version re-added every turn each time.
        if content and not any(
            m["role"] == role and m["content"] == content for m in memory
        ):
            add_to_memory(role, content)

    if use_memory:
        # Ingest any history turns that are not in memory yet.
        for user_msg, bot_msg in history:
            _remember_if_new("user", user_msg)
            _remember_if_new("assistant", bot_msg)
        # Build a prompt from the most relevant remembered messages.
        prompt = construct_prompt(memory, message)
        completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": prompt},
            ],
            model="deepseek-r1-distill-llama-70b",
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            stream=True,
        )
        # Stream: yield the growing response after each chunk.
        response = ""
        for chunk in completion:
            response += chunk.choices[0].delta.content or ""
            yield response
        # Persist the completed exchange, then enforce the size cap.
        _remember_if_new("user", message)
        _remember_if_new("assistant", response)
        trim_memory(max_size=memory_size)
    else:
        # Stateless mode: send the raw history as a normal message list.
        messages = [{"role": "system", "content": system_message}]
        for user_msg, bot_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if bot_msg:
                messages.append({"role": "assistant", "content": bot_msg})
        messages.append({"role": "user", "content": message})
        completion = client.chat.completions.create(
            messages=messages,
            model="deepseek-r1-distill-llama-70b",
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            stream=True,
        )
        response = ""
        for chunk in completion:
            response += chunk.choices[0].delta.content or ""
            yield response
def view_memory(buffer=None):
    """Return a human-readable dump of the memory buffer.

    Args:
        buffer: Optional list of entries to render; defaults to the global
            ``memory`` (added, backward-compatible — the UI button calls
            this with no arguments).

    Returns:
        A formatted "Memory N: role: content" listing, or a placeholder
        string when the buffer is empty.
    """
    entries = memory if buffer is None else buffer
    if not entries:
        return "Memory is empty."
    lines = ["Current Memory Contents:\n\n"]
    for i, m in enumerate(entries):
        lines.append(f"Memory {i+1}: {m['role']}: {m['content']}\n\n")
    return "".join(lines)
def clear_memory_action(buffer=None):
    """Empty the memory buffer and return a status message for the UI.

    Args:
        buffer: Optional list to clear in place; defaults to the global
            ``memory`` (added, backward-compatible — the UI button calls
            this with no arguments).

    Returns:
        A confirmation string shown in the status textbox.
    """
    target = memory if buffer is None else buffer
    target.clear()
    return "Memory has been cleared."
# Custom CSS for the chat interface - apply using elem_classes
# NOTE(review): .user-message/.bot-message are not attached to any component
# via elem_classes in the visible code — confirm they are used elsewhere.
custom_css = """
.user-message {
    background-color: #e3f2fd !important;
    border-radius: 15px !important;
    padding: 10px 15px !important;
}
.bot-message {
    background-color: #f1f8e9 !important;
    border-radius: 15px !important;
    padding: 10px 15px !important;
}
"""
# Create the Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    # Header banner
    with gr.Row(elem_classes="header-row"):
        gr.Markdown("""
<div style="text-align: center; margin-bottom: 10px; padding: 10px; background-color: #f0f4f8; border-radius: 8px;">
<h1 style="margin: 0; color: #2c3e50;">AI Chatbot With Memory</h1>
<h3 style="margin: 5px 0 0 0; color: #34495e;">Developed by Dhiraj and Swaroop</h3>
</div>
""")
    with gr.Row():
        with gr.Column(scale=3):
            # Chat area: ChatInterface drives get_chatbot_response and
            # exposes sampling / memory knobs as additional inputs.
            chatbot = gr.ChatInterface(
                get_chatbot_response,
                additional_inputs=[
                    gr.Textbox(value="You are a helpful assistant with memory capabilities.", label="System message"),
                    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
                    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
                    gr.Checkbox(value=True, label="Use Memory", info="Enable or disable memory capabilities"),
                    gr.Slider(minimum=10, maximum=200, value=50, step=10, label="Memory Size", info="Maximum number of entries in memory"),
                ],
                examples=[
                    ["Tell me about machine learning"],
                    ["What are the best practices for data preprocessing?"],
                    ["Can you explain neural networks?"],
                ],
                title="Chat with AI Assistant",
            )
        with gr.Column(scale=1):
            # Memory-management side panel.
            with gr.Group():
                gr.Markdown("## Memory Management")
                memory_display = gr.Textbox(label="Memory Contents", lines=20, max_lines=30, interactive=False)
                view_memory_btn = gr.Button("View Memory Contents")
                clear_memory_btn = gr.Button("Clear Memory")
                summarize_memory_btn = gr.Button("Summarize Memory")
                memory_status = gr.Textbox(label="Memory Status", lines=2, interactive=False)

    def _summarize_and_report():
        # Wrapper for the summarize button: the original lambda returned a
        # (None, str) 2-tuple for a SINGLE output component, which mismatches
        # Gradio's output mapping. Return only the status string.
        summarize_memory()
        return "Memory summarized successfully."

    # Wire up the memory-management buttons.
    view_memory_btn.click(view_memory, inputs=[], outputs=[memory_display])
    clear_memory_btn.click(clear_memory_action, inputs=[], outputs=[memory_status])
    summarize_memory_btn.click(_summarize_and_report, inputs=[], outputs=[memory_status])

    # Footer
    with gr.Row(elem_classes="footer-row"):
        gr.Markdown(f"""
<div style="text-align: center; margin-top: 20px; padding: 10px; background-color: #f0f4f8; border-radius: 8px;">
<p style="margin: 0; color: #2c3e50;">
Developed by Dhiraj and Swaroop | © {datetime.datetime.now().year} | Version 1.0
</p>
</div>
""")

if __name__ == "__main__":
    demo.launch()