File size: 2,070 Bytes
64d7fdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import json
import hashlib
from typing import Optional
from app.db.redis_client import redis_client
from ingestion.embedder import embedder
from app.config import config
from app.utils.logger import logger


class SemanticCache:
    """Redis-backed response cache keyed on the exact query text.

    Despite the class name, this implementation does not compare embeddings:
    the key is an MD5 digest of the raw query string, namespaced by whether
    the response was produced with RAG context (``cache:rag:``) or without
    it (``cache:no-rag:``).

    Every Redis operation is best-effort. Exceptions are logged and
    swallowed, so a cache outage degrades to misses and no-op writes
    instead of failing the caller.
    """

    def __init__(self):
        # Both settings come from the ``rag.cache`` section of the app config.
        # ``ttl`` is passed to Redis SETEX, which interprets it as seconds.
        self.ttl = config["rag"]["cache"]["ttl"]
        self.enabled = config["rag"]["cache"]["enabled"]

    async def _get_cache_key(self, query: str, use_context: bool = True) -> str:
        """Return the Redis key under which the response for ``query`` is stored.

        This performs no I/O. It stays ``async`` so that existing
        ``await self._get_cache_key(...)`` call sites keep working.

        Args:
            query: The user query; hashed verbatim, with no normalization.
            use_context: Selects the ``rag`` or ``no-rag`` key namespace.
        """
        context_flag = "rag" if use_context else "no-rag"
        # MD5 only yields a compact, fixed-length key here; it is not a
        # security boundary.
        digest = hashlib.md5(query.encode("utf-8")).hexdigest()
        return f"cache:{context_flag}:{digest}"

    async def get(self, query: str, use_context: bool = True) -> Optional[str]:
        """Look up a cached response for ``query``.

        Args:
            query: The user query.
            use_context: Whether to look in the RAG or non-RAG namespace.

        Returns:
            The stored value, or ``None`` on a miss, when caching is
            disabled, or when Redis raises (the error is logged).
        """
        if not self.enabled:
            return None

        try:
            cache_key = await self._get_cache_key(query, use_context)
            redis = await redis_client.get_client()
            cached = await redis.get(cache_key)

            # Compare against None rather than truthiness, so a legitimately
            # cached empty-string response counts as a hit rather than a miss.
            if cached is not None:
                logger.info(f"Cache hit for query: {query[:50]}...")
                return cached

            return None
        except Exception as e:
            logger.error(f"Cache get error: {e}")
            return None

    async def set(self, query: str, response: str, use_context: bool = True):
        """Store ``response`` for ``query`` with an expiry of ``self.ttl``.

        This does nothing when caching is disabled. Redis errors are logged,
        not raised.

        Args:
            query: The user query the response answers.
            response: The value to cache.
            use_context: Whether to store it in the RAG or non-RAG namespace.
        """
        if not self.enabled:
            return

        try:
            cache_key = await self._get_cache_key(query, use_context)
            redis = await redis_client.get_client()
            # SETEX writes the value and its expiry in a single command.
            await redis.setex(cache_key, self.ttl, response)
            logger.info(f"Cached response for query: {query[:50]}...")
        except Exception as e:
            logger.error(f"Cache set error: {e}")

    async def clear(self, batch_size: int = 500):
        """Delete every Redis key matching ``cache:*``.

        This uses incremental SCAN instead of KEYS. KEYS walks the whole
        keyspace in one blocking call, which can stall Redis on large
        databases. Matching keys are deleted in batches of at most
        ``batch_size``, so no single DEL carries an unbounded argument list.

        Note: this runs even when caching is disabled. The pattern also
        matches any other ``cache:``-prefixed keys in the same database.
        Redis errors are logged, not raised.

        Args:
            batch_size: The SCAN ``COUNT`` hint, and the maximum number of
                keys sent per DEL. Must be >= 1.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        try:
            redis = await redis_client.get_client()
            cleared = 0
            batch = []
            # SCAN allows keys to be deleted while iteration is in progress.
            async for key in redis.scan_iter(match="cache:*", count=batch_size):
                batch.append(key)
                if len(batch) >= batch_size:
                    await redis.delete(*batch)
                    cleared += len(batch)
                    batch = []
            if batch:
                await redis.delete(*batch)
                cleared += len(batch)
            if cleared:
                logger.info(f"Cleared {cleared} cache entries")
        except Exception as e:
            logger.error(f"Cache clear error: {e}")


# Shared module-level cache instance. Constructing it reads the
# ``rag.cache`` config section when this module is first imported.
semantic_cache: SemanticCache = SemanticCache()