Spaces:
Runtime error
Runtime error
import asyncio
import functools
from typing import Any, Dict, List, Optional

import requests
from langchain_core.messages import get_buffer_string

from langchain.memory.chat_memory import BaseChatMemory
# Hosted Motorhead endpoint operated by Metal (https://getmetal.io). Requests
# to this URL require an API key and client ID. To use a self-hosted Motorhead
# server instead, pass its base URL (e.g. ``url="http://localhost:8080"``)
# when constructing ``MotorheadMemory``.
MANAGED_URL = "https://api.getmetal.io/v1/motorhead"
class MotorheadMemory(BaseChatMemory):
    """Chat message memory backed by a Motorhead service.

    Messages are persisted through the server's
    ``/sessions/{session_id}/memory`` endpoint and mirrored locally in
    ``chat_memory``. When ``url`` is the managed endpoint (``MANAGED_URL``),
    both ``api_key`` and ``client_id`` must be set; they are sent as
    ``x-metal-*`` headers on every request.
    """

    url: str = MANAGED_URL
    # NOTE(review): ``requests`` interprets ``timeout`` in seconds, so this
    # default allows a single request to wait ~50 minutes. The value looks
    # like it was meant as milliseconds -- confirm before changing it.
    timeout: int = 3000
    memory_key: str = "history"
    session_id: str
    # Server-provided "context" string, if any (populated by ``init``).
    context: Optional[str] = None

    # Managed Params (only used when ``url == MANAGED_URL``)
    api_key: Optional[str] = None
    client_id: Optional[str] = None

    def __get_headers(self) -> Dict[str, str]:
        """Build the HTTP headers for a Motorhead request.

        Returns:
            A JSON content-type header, plus the Metal credentials when
            targeting the managed endpoint.

        Raises:
            ValueError: If ``url`` is the managed endpoint and either
                ``api_key`` or ``client_id`` is missing.
        """
        headers = {"Content-Type": "application/json"}
        if self.url != MANAGED_URL:
            # Self-hosted servers get no authentication headers.
            return headers
        if not (self.api_key and self.client_id):
            # Both credentials are required, so the message says "and".
            raise ValueError(
                "You must provide both an API key and a client ID to use the "
                "managed version of Motorhead. Visit https://getmetal.io for "
                "more information."
            )
        headers["x-metal-api-key"] = self.api_key
        headers["x-metal-client-id"] = self.client_id
        return headers

    def _memory_url(self) -> str:
        """Return the memory endpoint URL for this session."""
        return f"{self.url}/sessions/{self.session_id}/memory"

    async def init(self) -> None:
        """Load this session's stored messages and context from the server.

        The blocking ``requests`` call runs in the event loop's default
        executor so it does not stall other coroutines.

        Raises:
            ValueError: If managed credentials are missing (see headers).
        """
        # Resolve headers up front so a credentials error is raised directly
        # in the caller's coroutine rather than inside the worker thread.
        headers = self.__get_headers()
        loop = asyncio.get_running_loop()
        res = await loop.run_in_executor(
            None,
            functools.partial(
                requests.get,
                self._memory_url(),
                timeout=self.timeout,
                headers=headers,
            ),
        )
        res_data = res.json()
        # The managed version wraps its payload in a "data" envelope.
        res_data = res_data.get("data", res_data)
        messages = res_data.get("messages", [])
        context = res_data.get("context", "NONE")

        # NOTE(review): messages are replayed in reverse order, which assumes
        # the server returns them newest-first -- confirm against the API.
        for message in reversed(messages):
            if message["role"] == "AI":
                self.chat_memory.add_ai_message(message["content"])
            else:
                self.chat_memory.add_user_message(message["content"])

        # "NONE" is treated as the "no context" sentinel (also the default above).
        if context and context != "NONE":
            self.context = context

    def load_memory_variables(self, values: Dict[str, Any]) -> Dict[str, Any]:
        """Return the buffered chat history under ``memory_key``.

        ``values`` is not used. The history is returned as message objects
        when ``return_messages`` is true, otherwise as one formatted string.
        """
        messages = self.chat_memory.messages
        if self.return_messages:
            return {self.memory_key: messages}
        return {self.memory_key: get_buffer_string(messages)}

    @property
    def memory_variables(self) -> List[str]:
        """Keys this memory adds to chain inputs.

        This must be a property: the base class declares it as one, and
        callers use it as an attribute (e.g. ``key in memory.memory_variables``).
        """
        return [self.memory_key]

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Persist one Human/AI exchange on the server, then buffer it locally."""
        input_str, output_str = self._get_input_output(inputs, outputs)
        # NOTE(review): the response status is not checked, so a failed POST
        # goes unnoticed while the exchange is still buffered locally.
        requests.post(
            self._memory_url(),
            timeout=self.timeout,
            json={
                "messages": [
                    {"role": "Human", "content": f"{input_str}"},
                    {"role": "AI", "content": f"{output_str}"},
                ]
            },
            headers=self.__get_headers(),
        )
        super().save_context(inputs, outputs)

    def delete_session(self) -> None:
        """Delete this session's memory on the Motorhead server.

        Raises:
            ValueError: If managed credentials are missing (see headers).
        """
        # Send the same auth headers and timeout as every other call. Without
        # them, managed deletions were unauthenticated and could block forever.
        requests.delete(
            self._memory_url(),
            timeout=self.timeout,
            headers=self.__get_headers(),
        )