File size: 5,700 Bytes
fcaa164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import asyncio
import json
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from camel.storages.key_value_storages import BaseKeyValueStorage

if TYPE_CHECKING:
    from redis.asyncio import Redis

# Module-level logger; reports the missing-`redis` install hint and any
# errors raised while saving, loading or clearing records.
logger = logging.getLogger(__name__)


class RedisStorage(BaseKeyValueStorage):
    r"""A concrete implementation of the :obj:`BaseKeyValueStorage` using
    Redis as the backend. This is suitable for distributed cache systems that
    require persistence and high availability.

    All records of one instance are serialized together as a single JSON list
    stored under the Redis key ``sid``. The public :meth:`save`,
    :meth:`load` and :meth:`clear` methods are synchronous wrappers that run
    the asynchronous ``redis.asyncio`` client on an event loop.
    """

    def __init__(
        self,
        sid: str,
        url: str = "redis://localhost:6379",
        loop: Optional[asyncio.AbstractEventLoop] = None,
        **kwargs,
    ) -> None:
        r"""Initializes the RedisStorage instance with the provided URL and
        options.

        Args:
            sid (str): The ID for the storage instance to identify the
                record space. It is used as the Redis key under which all
                records are stored.
            url (str): The URL for connecting to the Redis server.
            loop (Optional[asyncio.AbstractEventLoop]): The event loop used
                to execute the async Redis calls behind the synchronous API.
                If ``None``, the current event loop is used; if the calling
                thread has no current loop, a new private loop is created.
            **kwargs: Additional keyword arguments for Redis client
                configuration, forwarded to ``redis.asyncio.from_url``.

        Raises:
            ImportError: If the `redis.asyncio` module is not installed.
        """
        try:
            import redis.asyncio as aredis
        except ImportError as exc:
            logger.error(
                "Please install `redis` first. You can install it by "
                "running `pip install redis`."
            )
            raise exc

        self._client: Optional[aredis.Redis] = None
        self._url = url
        self._sid = sid

        if loop is None:
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                # `get_event_loop` raises when the calling thread has no
                # current loop (possible in any non-main thread, and in the
                # main thread on Python 3.14+). Use a dedicated loop instead
                # of failing construction.
                loop = asyncio.new_event_loop()
        self._loop: asyncio.AbstractEventLoop = loop

        self._create_client(**kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Synchronously close the async client when leaving a `with` block.
        self._run_async(self.close())

    async def close(self) -> None:
        r"""Closes the Redis client asynchronously."""
        if self._client is not None:
            # redis-py >= 5.0.1 deprecates `close()` in favour of `aclose()`;
            # fall back to `close()` for older client versions.
            closer = getattr(self._client, "aclose", None)
            if closer is None:
                closer = self._client.close
            await closer()

    def _create_client(self, **kwargs) -> None:
        r"""Creates the Redis client with the provided URL and options.

        Args:
            **kwargs: Additional keyword arguments for Redis client
                configuration.
        """
        import redis.asyncio as aredis

        self._client = aredis.from_url(self._url, **kwargs)

    @property
    def client(self) -> Optional["Redis"]:
        r"""Returns the Redis client instance.

        Returns:
            Optional[redis.asyncio.Redis]: The Redis client instance, or
                ``None`` if it has not been created.
        """
        return self._client

    def save(
        self, records: List[Dict[str, Any]], expire: Optional[int] = None
    ) -> None:
        r"""Saves a batch of records to the key-value storage system.

        The serialized list overwrites whatever is currently stored under
        this instance's ``sid``. Errors are logged, not raised.

        Args:
            records (List[Dict[str, Any]]): JSON-serializable records to
                store.
            expire (Optional[int]): Expiry in seconds passed to Redis
                ``SETEX``. If falsy (``None`` or ``0``), the key is stored
                without an expiry.
        """
        try:
            self._run_async(self._async_save(records, expire))
        except Exception as e:
            logger.error(f"Error in save: {e}")

    def load(self) -> List[Dict[str, Any]]:
        r"""Loads all stored records from the key-value storage system.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries, where each dictionary
                represents a stored record. An empty list is returned when
                nothing is stored or an error occurs (errors are logged).
        """
        try:
            return self._run_async(self._async_load())
        except Exception as e:
            logger.error(f"Error in load: {e}")
            return []

    def clear(self) -> None:
        r"""Removes all records from the key-value storage system.

        Errors are logged, not raised.
        """
        try:
            self._run_async(self._async_clear())
        except Exception as e:
            logger.error(f"Error in clear: {e}")

    async def _async_save(
        self, records: List[Dict[str, Any]], expire: Optional[int] = None
    ) -> None:
        r"""Serializes ``records`` to JSON and stores them under ``sid``.

        Raises:
            ValueError: If the Redis client is not initialized. Redis and
                serialization errors are logged instead of raised.
        """
        if self._client is None:
            raise ValueError("Redis client is not initialized")
        try:
            value = json.dumps(records)
            # A falsy expire (None or 0) means "no TTL"; SETEX rejects 0.
            if expire:
                await self._client.setex(self._sid, expire, value)
            else:
                await self._client.set(self._sid, value)
        except Exception as e:
            logger.error(f"Error saving records: {e}")

    async def _async_load(self) -> List[Dict[str, Any]]:
        r"""Fetches and JSON-decodes the records stored under ``sid``.

        Returns:
            List[Dict[str, Any]]: The decoded records, or ``[]`` if the key
                is missing/empty or an error occurred (errors are logged).

        Raises:
            ValueError: If the Redis client is not initialized.
        """
        if self._client is None:
            raise ValueError("Redis client is not initialized")
        try:
            value = await self._client.get(self._sid)
            if value:
                return json.loads(value)
            return []
        except Exception as e:
            logger.error(f"Error loading records: {e}")
            return []

    async def _async_clear(self) -> None:
        r"""Deletes the Redis key ``sid``.

        Raises:
            ValueError: If the Redis client is not initialized. Redis errors
                are logged instead of raised.
        """
        if self._client is None:
            raise ValueError("Redis client is not initialized")
        try:
            await self._client.delete(self._sid)
        except Exception as e:
            logger.error(f"Error clearing records: {e}")

    def _run_async(self, coro):
        r"""Runs ``coro`` on this storage's event loop and waits for it.

        If the loop is idle, it is driven directly with
        ``run_until_complete``. If it is already running in another thread,
        the coroutine is submitted with ``run_coroutine_threadsafe`` and the
        calling thread blocks on the result.

        Args:
            coro: The coroutine to execute.

        Returns:
            Any: The coroutine's return value.

        Raises:
            RuntimeError: If called from within the loop's own thread while
                it is running, where blocking would deadlock the loop.
        """
        if not self._loop.is_running():
            return self._loop.run_until_complete(coro)

        try:
            current_loop = asyncio.get_running_loop()
        except RuntimeError:
            current_loop = None
        if current_loop is self._loop:
            # Waiting on `future.result()` here would block the only thread
            # able to complete the future, hanging forever. Fail loudly and
            # close the coroutine to avoid a "never awaited" warning.
            coro.close()
            raise RuntimeError(
                "Cannot run RedisStorage synchronous methods from inside "
                "their own running event loop; call them from another "
                "thread or await the Redis client directly."
            )

        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        return future.result()