File size: 5,700 Bytes
fcaa164 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import asyncio
import json
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from camel.storages.key_value_storages import BaseKeyValueStorage
# Imported for static type annotations only, so importing this module does
# not require `redis` to be installed; the runtime import happens lazily
# inside `RedisStorage`.
if TYPE_CHECKING:
    from redis.asyncio import Redis
logger = logging.getLogger(__name__)
class RedisStorage(BaseKeyValueStorage):
    r"""A concrete implementation of the :obj:`BaseKeyValueStorage` using
    Redis as the backend. This is suitable for distributed cache systems that
    require persistence and high availability.

    All records of one instance are serialized as a single JSON document and
    stored under the Redis key :obj:`sid`. The public :meth:`save`,
    :meth:`load` and :meth:`clear` methods are synchronous wrappers that drive
    the asynchronous Redis client on the configured event loop.
    """

    def __init__(
        self,
        sid: str,
        url: str = "redis://localhost:6379",
        loop: Optional[asyncio.AbstractEventLoop] = None,
        **kwargs,
    ) -> None:
        r"""Initializes the RedisStorage instance with the provided URL and
        options.

        Args:
            sid (str): The ID for the storage instance to identify the
                record space. It is used as the Redis key.
            url (str): The URL for connecting to the Redis server.
                (default: :obj:`"redis://localhost:6379"`)
            loop (Optional[asyncio.AbstractEventLoop]): The event loop used
                to run the asynchronous Redis calls. If not given, the
                current event loop is used, and a new one is created when
                the calling thread has none. (default: :obj:`None`)
            **kwargs: Additional keyword arguments for Redis client
                configuration. They are forwarded to
                `redis.asyncio.from_url`.

        Raises:
            ImportError: If the `redis.asyncio` module is not installed.
        """
        try:
            import redis.asyncio as aredis
        except ImportError as exc:
            logger.error(
                "Please install `redis` first. You can install it by "
                "running `pip install redis`."
            )
            raise exc

        self._client: Optional[aredis.Redis] = None
        self._url = url
        self._sid = sid

        if loop is None:
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                # `get_event_loop()` raises when the current thread has no
                # event loop (worker threads, and the main thread as well on
                # recent Python versions). Fall back to a private loop so
                # the storage stays usable from plain synchronous code.
                loop = asyncio.new_event_loop()
        self._loop = loop

        self._create_client(**kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Close the client when leaving the `with` block; exceptions raised
        # inside the block are not suppressed (implicit `None` return).
        self._run_async(self.close())

    async def close(self) -> None:
        r"""Closes the Redis client asynchronously."""
        if self._client is None:
            return
        # Newer redis-py releases expose `aclose()` and deprecate the
        # coroutine `close()`; prefer the former and fall back to the latter
        # for older releases.
        closer = getattr(self._client, "aclose", None)
        if closer is None:
            closer = self._client.close
        await closer()

    def _create_client(self, **kwargs) -> None:
        r"""Creates the Redis client with the provided URL and options.

        Args:
            **kwargs: Additional keyword arguments for Redis client
                configuration.
        """
        import redis.asyncio as aredis

        self._client = aredis.from_url(self._url, **kwargs)

    @property
    def client(self) -> Optional["Redis"]:
        r"""Returns the Redis client instance.

        Returns:
            Optional[redis.asyncio.Redis]: The Redis client instance, or
                :obj:`None` if no client has been created.
        """
        return self._client

    def save(
        self, records: List[Dict[str, Any]], expire: Optional[int] = None
    ) -> None:
        r"""Saves a batch of records to the key-value storage system.

        Errors are logged and not propagated to the caller.

        Args:
            records (List[Dict[str, Any]]): The records to store. They
                replace whatever is currently stored under the instance's
                key and must be JSON-serializable.
            expire (Optional[int]): Expiration time passed to Redis
                `SETEX`. A falsy value (:obj:`None` or :obj:`0`) stores the
                records without expiration. (default: :obj:`None`)
        """
        try:
            self._run_async(self._async_save(records, expire))
        except Exception as e:
            logger.error(f"Error in save: {e}")

    def load(self) -> List[Dict[str, Any]]:
        r"""Loads all stored records from the key-value storage system.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries, where each dictionary
                represents a stored record. An empty list is returned when
                nothing is stored or when loading fails (the error is
                logged).
        """
        try:
            return self._run_async(self._async_load())
        except Exception as e:
            logger.error(f"Error in load: {e}")
            return []

    def clear(self) -> None:
        r"""Removes all records from the key-value storage system.

        Errors are logged and not propagated to the caller.
        """
        try:
            self._run_async(self._async_clear())
        except Exception as e:
            logger.error(f"Error in clear: {e}")

    async def _async_save(
        self, records: List[Dict[str, Any]], expire: Optional[int] = None
    ) -> None:
        r"""Serializes :obj:`records` to JSON and writes them under the
        instance's key, with an expiration when :obj:`expire` is truthy.

        Raises:
            ValueError: If the Redis client is not initialized.
        """
        if self._client is None:
            raise ValueError("Redis client is not initialized")
        try:
            value = json.dumps(records)
            if expire:
                await self._client.setex(self._sid, expire, value)
            else:
                await self._client.set(self._sid, value)
        except Exception as e:
            logger.error(f"Error saving records: {e}")

    async def _async_load(self) -> List[Dict[str, Any]]:
        r"""Reads the JSON document stored under the instance's key and
        decodes it; returns an empty list when the key is missing or empty,
        or when reading/decoding fails.

        Raises:
            ValueError: If the Redis client is not initialized.
        """
        if self._client is None:
            raise ValueError("Redis client is not initialized")
        try:
            value = await self._client.get(self._sid)
            if value:
                return json.loads(value)
            return []
        except Exception as e:
            logger.error(f"Error loading records: {e}")
            return []

    async def _async_clear(self) -> None:
        r"""Deletes the instance's key from Redis.

        Raises:
            ValueError: If the Redis client is not initialized.
        """
        if self._client is None:
            raise ValueError("Redis client is not initialized")
        try:
            await self._client.delete(self._sid)
        except Exception as e:
            logger.error(f"Error clearing records: {e}")

    def _run_async(self, coro):
        r"""Runs :obj:`coro` on the storage's event loop and blocks until it
        finishes.

        Args:
            coro: The coroutine to execute.

        Returns:
            The result of the coroutine.

        Raises:
            RuntimeError: If called from the thread that is currently
                running the storage's event loop, where blocking for the
                result would deadlock.
        """
        if not self._loop.is_running():
            return self._loop.run_until_complete(coro)

        # The loop is already running. Blocking on the future is only safe
        # from a different thread; from the loop's own thread it would stall
        # the loop that has to produce the result.
        try:
            running_here = asyncio.get_running_loop()
        except RuntimeError:
            running_here = None
        if running_here is self._loop:
            # Avoid a "coroutine was never awaited" warning.
            coro.close()
            raise RuntimeError(
                "Cannot block on the storage's event loop from within that "
                "loop; call this method from synchronous code or from "
                "another thread."
            )

        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        return future.result()
|