File size: 9,733 Bytes
13c3f2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
"""
Redis Cache Service for Gapura AI
Provides caching layer for Google Sheets data and predictions
"""

import hashlib
import json
import logging
import os
import time
from datetime import timedelta
from functools import wraps
from typing import Any, Callable, Optional

import redis
from redis.exceptions import RedisError

logger = logging.getLogger(__name__)


class CacheService:
    """Two-tier cache: an L1 in-process dict backed by an optional L2 Redis store.

    Configuration comes from environment variables:
      - ``REDIS_URL``:     Redis connection string (default ``redis://localhost:6379/0``)
      - ``CACHE_ENABLED``: ``"true"``/``"false"`` master switch (default ``"true"``)
      - ``CACHE_BACKEND``: ``"redis"``, ``"memory"``, or ``""`` for auto-detect

    On Hugging Face Spaces (``SPACE_ID``/``HF_TOKEN`` present) with no explicit
    ``REDIS_URL``, the backend auto-selects ``"memory"``. When Redis is
    unreachable the service degrades to memory-only instead of failing.
    """

    # Soft cap on L1 entries before eviction kicks in.
    _L1_MAX_ITEMS = 1000
    # Fallback L1 TTL (seconds) when Redis reports no usable TTL on backfill.
    _L1_DEFAULT_TTL = 300

    def __init__(self):
        self.redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
        self.client = None  # Redis client, or None when running memory-only
        # L1 cache stores: key -> (value, expiry_unix_timestamp)
        self.in_memory_cache = {}
        self.enabled = os.getenv("CACHE_ENABLED", "true").lower() == "true"
        self.backend = os.getenv("CACHE_BACKEND", "").lower()  # 'redis' | 'memory' | ''

        # Auto-detect HF Spaces and prefer the memory backend when no
        # explicit REDIS_URL is configured.
        if (os.getenv("SPACE_ID") or os.getenv("HF_TOKEN")) and not os.getenv("REDIS_URL"):
            if self.backend == "":
                self.backend = "memory"

        if self.enabled and self.backend != "memory":
            self._connect()

    def _connect(self):
        """Connect to Redis; on any failure stay memory-only (client = None)."""
        if not self.enabled:
            logger.info("Cache disabled via environment variable")
            return
        if self.backend == "memory":
            logger.info("Cache backend set to memory; skipping Redis connection")
            self.client = None
            return

        try:
            self.client = redis.from_url(
                self.redis_url,
                decode_responses=True,
                socket_connect_timeout=2,
                socket_timeout=2,
            )
            self.client.ping()
            logger.info(f"Connected to Redis at {self.redis_url}")
        except (RedisError, ValueError) as e:
            # ValueError covers malformed REDIS_URL strings from redis.from_url;
            # without it a bad URL would crash construction instead of falling back.
            logger.warning(f"Failed to connect to Redis: {e}. Falling back to in-memory cache.")
            self.client = None

    def get(self, key: str) -> Optional[Any]:
        """Look up *key*: L1 memory first, then L2 Redis.

        Returns the cached value, or None on a miss or when caching is
        disabled. A Redis hit is backfilled into L1 with the remaining TTL.
        """
        if not self.enabled:
            return None

        # 1. L1 in-memory lookup.
        if key in self.in_memory_cache:
            entry = self.in_memory_cache[key]
            if isinstance(entry, tuple) and len(entry) == 2:
                value, expiry = entry
                if expiry > time.time():
                    logger.debug(f"L1 Memory HIT: {key}")
                    return value
                # Expired: drop it and fall through to Redis.
                del self.in_memory_cache[key]
            else:
                # Legacy format (bare value, no expiry tuple): return as-is.
                logger.debug(f"L1 Memory HIT (Legacy): {key}")
                return entry

        # 2. L2 Redis lookup.
        if self.client:
            try:
                value_str = self.client.get(key)
                if value_str:
                    logger.debug(f"L2 Redis HIT: {key}")
                    data = json.loads(value_str)

                    # Backfill L1, syncing its expiry with the Redis TTL when
                    # available (TTL <= 0 means missing key or no expiry).
                    try:
                        ttl = self.client.ttl(key)
                    except RedisError:
                        ttl = -1
                    if ttl > 0:
                        self.in_memory_cache[key] = (data, time.time() + ttl)
                    else:
                        self.in_memory_cache[key] = (data, time.time() + self._L1_DEFAULT_TTL)

                    return data
            except (RedisError, json.JSONDecodeError) as e:
                logger.warning(f"Redis get error for {key}: {e}")

        logger.debug(f"Cache MISS: {key}")
        return None

    def set(self, key: str, value: Any, ttl_seconds: int = 300) -> bool:
        """Store *value* under *key* in both layers for *ttl_seconds*.

        Returns False only when caching is disabled. The L1 write always
        succeeds, so the overall set succeeds even if the Redis write fails
        (value must be JSON-serializable for L2; non-serializable values are
        stringified via ``default=str``, and hard failures are logged).
        """
        if not self.enabled:
            return False

        # 1. L2 Redis (best effort).
        if self.client:
            try:
                serialized = json.dumps(value, default=str)
                self.client.setex(key, ttl_seconds, serialized)
                logger.debug(f"L2 Redis SET: {key}")
            except (RedisError, TypeError) as e:
                logger.warning(f"Redis set error for {key}: {e}")

        # 2. L1 memory (always succeeds).
        self.in_memory_cache[key] = (value, time.time() + ttl_seconds)
        self._evict_l1()
        return True

    def _evict_l1(self):
        """Keep L1 bounded: drop expired entries first, then oldest-inserted (FIFO)."""
        if len(self.in_memory_cache) <= self._L1_MAX_ITEMS:
            return

        now = time.time()
        expired_keys = [k for k, v in self.in_memory_cache.items()
                        if isinstance(v, tuple) and v[1] < now]
        for k in expired_keys:
            del self.in_memory_cache[k]

        # Still too big: evict the oldest insertion (dicts preserve order).
        if len(self.in_memory_cache) > self._L1_MAX_ITEMS:
            self.in_memory_cache.pop(next(iter(self.in_memory_cache)))

    def delete(self, key: str) -> bool:
        """Remove *key* from both layers. Returns False only when disabled."""
        if not self.enabled:
            return False

        if self.client:
            try:
                self.client.delete(key)
            except RedisError:
                # Best effort: the L1 removal below still happens.
                pass

        self.in_memory_cache.pop(key, None)
        return True

    def delete_pattern(self, pattern: str) -> int:
        """Delete every key matching *pattern* (glob-style, e.g. ``"sheet:*"``).

        Returns an approximate count: the larger of the Redis and in-memory
        deletion counts. In-memory matching is a plain substring test on the
        pattern with ``"*"`` stripped out.
        """
        if not self.enabled:
            return 0

        # 1. Redis layer. SCAN instead of KEYS: KEYS blocks the server on
        # large keyspaces, SCAN iterates incrementally.
        redis_deleted = 0
        if self.client:
            try:
                keys = list(self.client.scan_iter(match=pattern))
                if keys:
                    redis_deleted = self.client.delete(*keys)
                    logger.debug(f"Redis DELETE pattern {pattern}: {redis_deleted} keys")
            except RedisError as e:
                logger.warning(f"Cache delete pattern error for {pattern}: {e}")

        # 2. In-memory layer (simple substring match for now).
        mem_deleted = 0
        token = pattern.replace("*", "")
        for k in list(self.in_memory_cache.keys()):
            if token in k:
                del self.in_memory_cache[k]
                mem_deleted += 1
        if mem_deleted > 0:
            logger.debug(f"In-memory DELETE pattern {pattern}: {mem_deleted} keys")

        return max(redis_deleted, mem_deleted)

    def health_check(self) -> dict:
        """Report cache status: backend in use, L1 size, and Redis health."""
        if not self.enabled:
            return {"status": "disabled", "message": "Caching is disabled"}

        status = {
            "backend": self.backend if self.backend else ("redis" if self.client else "memory"),
            "l1_items": len(self.in_memory_cache),
        }

        if self.client:
            try:
                self.client.ping()
                info = self.client.info("memory")
                status.update({
                    "redis_status": "connected",
                    "redis_used_memory": info.get("used_memory_human", "unknown"),
                    "redis_clients": len(self.client.client_list()),
                })
            except RedisError as e:
                status["redis_status"] = f"error: {str(e)}"
        else:
            status["redis_status"] = "not_configured"

        return status


def cached(key_prefix: str, ttl_seconds: int = 300):
    """
    Decorator for caching function results in the shared CacheService.

    Usage: @cached("my_prefix", ttl_seconds=300)

    Callers may pass ``bypass_cache=True`` to skip the cache for one call;
    the kwarg is consumed and not forwarded to the wrapped function.
    ``None`` results are never cached, so failed lookups can be retried.
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            bypass_cache = kwargs.pop("bypass_cache", False)

            # Use singleton instance so all decorated functions share one cache.
            cache = get_cache()

            if bypass_cache:
                logger.debug(f"Cache bypassed for {key_prefix}")
                return func(*args, **kwargs)

            # Build a deterministic cache key from the call arguments.
            # A digest is used instead of builtin hash() because hash() is
            # randomized per process (PYTHONHASHSEED), which made keys
            # unstable across restarts and between workers sharing Redis.
            arg_str = str(args) + str(sorted(kwargs.items()))
            digest = hashlib.sha1(arg_str.encode("utf-8", "replace")).hexdigest()
            cache_key = f"{key_prefix}:{digest}"

            cached_result = cache.get(cache_key)
            if cached_result is not None:
                return cached_result

            result = func(*args, **kwargs)

            if result is not None:
                cache.set(cache_key, result, ttl_seconds)

            return result

        return wrapper

    return decorator


_cache_instance: Optional[CacheService] = None


def get_cache() -> CacheService:
    """Return the process-wide CacheService, constructing it lazily on first use."""
    global _cache_instance
    if _cache_instance is not None:
        return _cache_instance
    _cache_instance = CacheService()
    return _cache_instance