File size: 7,658 Bytes
53bec59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
"""

Redis Cache Implementation for Production

"""

import functools
import hashlib
import json
from datetime import timedelta
from typing import Any, Optional, Union

import redis.asyncio as aioredis

from src.core.config import settings
from src.core.logging import logger
from src.core.exceptions import CacheError


class RedisCache:
    """Redis cache manager with async support.

    Wraps a ``redis.asyncio`` client and offers JSON get/set helpers,
    prediction caching keyed by a hash of the model input, and a fixed-window
    rate-limit counter. Every data operation is best-effort: if caching is
    disabled, the client is not connected, or Redis raises, the method logs
    and returns a neutral value (``None``, ``False`` or ``0``) instead of
    raising.
    """

    def __init__(self):
        # ``None`` until ``connect`` succeeds; every operation is a no-op while unset.
        self.redis: Optional[aioredis.Redis] = None
        # Master switch from settings; also turned off when connecting fails.
        self.enabled = settings.CACHE_PREDICTIONS

    async def connect(self):
        """Create the Redis client and verify it with ``PING``.

        Does nothing (besides logging) when caching is disabled.

        Raises:
            CacheError: if the client cannot be created or the ping fails.
                Before raising, caching is disabled and any partially created
                client is closed and dropped so it does not leak connections.
        """
        if not self.enabled:
            logger.info("Redis cache is disabled")
            return

        try:
            self.redis = await aioredis.from_url(
                settings.REDIS_URL,
                encoding="utf-8",
                decode_responses=True,
                max_connections=50,
            )
            # Fail fast here rather than on the first real request.
            await self.redis.ping()
            logger.info(f"Connected to Redis at {settings.REDIS_HOST}:{settings.REDIS_PORT}")
        except Exception as e:
            logger.error(f"Failed to connect to Redis: {e}")
            self.enabled = False
            client, self.redis = self.redis, None
            if client is not None:
                try:
                    await client.close()
                except Exception as close_err:
                    # Best-effort cleanup; the connection error below is what matters.
                    logger.debug(f"Error closing Redis client after failed connect: {close_err}")
            raise CacheError(f"Redis connection failed: {e}") from e

    async def disconnect(self):
        """Close the Redis client, if one is open, and forget it."""
        if self.redis is not None:
            # Drop the reference first so later calls see "not connected"
            # instead of reusing a closed client.
            client, self.redis = self.redis, None
            await client.close()
            logger.info("Disconnected from Redis")

    def _generate_cache_key(self, prefix: str, data: Union[str, dict]) -> str:
        """Build a ``"<prefix>:<hash>"`` key from ``data``.

        Dicts are serialised with sorted keys, so key order does not affect the
        hash. Any other value goes through ``str``. The hash is the first 16 hex
        characters (64 bits) of the SHA-256 digest.

        Raises:
            TypeError: if ``data`` is a dict that ``json.dumps`` cannot serialise.
        """
        if isinstance(data, dict):
            data_str = json.dumps(data, sort_keys=True)
        else:
            data_str = str(data)

        hash_value = hashlib.sha256(data_str.encode()).hexdigest()[:16]
        return f"{prefix}:{hash_value}"

    async def get(self, key: str) -> Optional[Any]:
        """Return the JSON-decoded value stored at ``key``.

        Returns ``None`` on a miss, when caching is unavailable, or on any
        Redis/decoding error (which is logged as a warning).
        """
        if not self.enabled or not self.redis:
            return None

        try:
            value = await self.redis.get(key)
            if value:
                logger.debug(f"Cache hit: {key}")
                return json.loads(value)
            logger.debug(f"Cache miss: {key}")
            return None
        except Exception as e:
            logger.warning(f"Cache get error for {key}: {e}")
            return None

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """JSON-encode ``value`` and store it at ``key`` with an expiry.

        Args:
            key: Redis key to write.
            value: Any JSON-serialisable value.
            ttl: Expiry in seconds. A falsy value (``None`` or ``0``) falls back
                to ``settings.CACHE_TTL``.

        Returns:
            ``True`` if the write succeeded. ``False`` if caching is unavailable
            or encoding/Redis failed (the error is logged).
        """
        if not self.enabled or not self.redis:
            return False

        try:
            ttl = ttl or settings.CACHE_TTL
            value_json = json.dumps(value)
            await self.redis.setex(key, ttl, value_json)
            logger.debug(f"Cache set: {key} (TTL: {ttl}s)")
            return True
        except Exception as e:
            logger.warning(f"Cache set error for {key}: {e}")
            return False

    async def delete(self, key: str) -> bool:
        """Delete ``key``.

        Returns ``True`` whenever the command ran, even if the key did not
        exist. Returns ``False`` if caching is unavailable or Redis failed.
        """
        if not self.enabled or not self.redis:
            return False

        try:
            await self.redis.delete(key)
            logger.debug(f"Cache delete: {key}")
            return True
        except Exception as e:
            logger.warning(f"Cache delete error for {key}: {e}")
            return False

    async def get_prediction(self, model_type: str, input_data: Union[str, dict]) -> Optional[dict]:
        """Return the cached prediction for ``input_data`` under ``model_type``, or ``None``."""
        key = self._generate_cache_key(f"pred:{model_type}", input_data)
        return await self.get(key)

    async def set_prediction(
        self,
        model_type: str,
        input_data: Union[str, dict],
        result: dict,
        ttl: Optional[int] = None,
    ) -> bool:
        """Cache ``result`` for ``input_data`` under ``model_type``.

        Uses the same key as ``get_prediction``. ``ttl`` has the same meaning
        as in ``set``.
        """
        key = self._generate_cache_key(f"pred:{model_type}", input_data)
        return await self.set(key, result, ttl)

    async def increment_rate_limit(self, identifier: str, window_seconds: int) -> int:
        """Increment the fixed-window request counter for ``identifier``.

        The counter key expires ``window_seconds`` after the first request of a
        window. Later requests in the same window do not push the expiry back.

        Returns:
            The counter value after incrementing. Returns ``0`` if caching is
            unavailable or Redis failed (the error is logged).
        """
        if not self.enabled or not self.redis:
            return 0

        try:
            key = f"ratelimit:{identifier}"
            # INCR and TTL run in one MULTI/EXEC transaction. The expiry is only
            # set when the key has none (TTL == -1): on the first hit of a
            # window, or if an earlier EXPIRE was lost. Re-arming EXPIRE on every
            # call would slide the window forward under steady traffic, so the
            # counter would never reset.
            pipe = self.redis.pipeline(transaction=True)
            pipe.incr(key)
            pipe.ttl(key)
            count, remaining = await pipe.execute()
            if remaining == -1:
                await self.redis.expire(key, window_seconds)
            logger.debug(f"Rate limit count for {identifier}: {count}")
            return count
        except Exception as e:
            logger.warning(f"Rate limit increment error: {e}")
            return 0

    async def get_rate_limit_count(self, identifier: str) -> int:
        """Return the current counter for ``identifier``.

        Returns ``0`` if the counter is absent, caching is unavailable, or
        Redis failed.
        """
        if not self.enabled or not self.redis:
            return 0

        try:
            key = f"ratelimit:{identifier}"
            count = await self.redis.get(key)
            return int(count) if count else 0
        except Exception as e:
            logger.warning(f"Rate limit get error: {e}")
            return 0

    async def clear_all(self) -> bool:
        """Run FLUSHDB on the connected database (use with caution!).

        This removes every key in the selected Redis database, including the
        rate-limit counters, not just cached values.
        """
        if not self.enabled or not self.redis:
            return False

        try:
            await self.redis.flushdb()
            logger.warning("All cache cleared!")
            return True
        except Exception as e:
            logger.error(f"Cache clear error: {e}")
            return False


# Module-wide shared cache. Until ``await cache.connect()`` has run, every
# operation on it is a no-op that returns its neutral value.
cache: RedisCache = RedisCache()


# Decorator for caching function results
def cached(prefix: str, ttl: Optional[int] = None):
    """Decorator that caches an async function's result in the global ``cache``.

    The key is built from ``prefix``, the wrapped function's module-qualified
    name, and its arguments rendered with ``str``. Keyword arguments are sorted
    first, so call-site ordering does not change the key.

    NOTE(review): arguments whose ``str`` embeds an object address (a default
    ``repr``, e.g. a bound ``self``) produce a new key on every call and never
    hit. Results that are not JSON-serialisable are simply not cached (the
    cache logs a warning). A ``None`` result cannot be told apart from a miss,
    so it is recomputed each time.

    Args:
        prefix: Key namespace for this decorator.
        ttl: Expiry in seconds; ``None`` uses the cache's default TTL.
    """
    def decorator(func):
        # Part of the key, so two functions sharing a prefix cannot collide.
        func_id = f"{func.__module__}.{func.__qualname__}"

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            cache_data = {
                "func": func_id,
                "args": str(args),
                "kwargs": str(sorted(kwargs.items())),
            }
            cache_key = cache._generate_cache_key(prefix, cache_data)

            cached_result = await cache.get(cache_key)
            if cached_result is not None:
                return cached_result

            result = await func(*args, **kwargs)
            await cache.set(cache_key, result, ttl)
            return result

        return wrapper

    return decorator


if __name__ == "__main__":
    import asyncio

    async def test_cache():
        """Manual smoke test of the global cache against a live Redis."""
        await cache.connect()

        # Round-trip a plain JSON value.
        await cache.set("test_key", {"value": 123}, ttl=60)
        fetched = await cache.get("test_key")
        print(f"Retrieved: {fetched}")

        # Round-trip a prediction through the hashed-key helpers.
        model_name = "deepfake"
        sample_input = {"image": "test.jpg"}
        await cache.set_prediction(
            model_name,
            sample_input,
            {"prediction": "FAKE", "confidence": 0.95},
            ttl=300,
        )
        prediction = await cache.get_prediction(model_name, sample_input)
        print(f"Cached prediction: {prediction}")

        # Hit one rate-limit bucket several times and show the counter.
        for request_no in range(1, 6):
            hits = await cache.increment_rate_limit("user:123", 60)
            print(f"Request {request_no}: Rate limit count = {hits}")

        await cache.disconnect()

    asyncio.run(test_cache())