MedCheck_Agent / cache /semantic_cache.py
Ashwin
Introducing the logger and caching logics
582e09f
Raw
History Blame Contribute Delete
7.43 kB
"""
cache/semantic_cache.py — Semantic cache for MedCheck Agent
How it works
────────────
1. Every (task, medicine_name) response is stored with a key derived
from a simple normalised string. For medicine names, "aspirin",
"Aspirin", "ASPIRIN" all map to the same key — so an exact match
is tried first (O(1)).
2. If no exact match is found, a lightweight fuzzy/phonetic similarity
check is performed against all cached keys for the same task type.
This catches common misspellings (Asprin → Aspirin already in cache).
3. Entries expire after a configurable TTL (default 24 h) so stale
medical data is never served indefinitely.
No heavy ML model is required β€” the similarity is intentionally kept
simple (difflib.SequenceMatcher ratio) so there are zero extra dependencies.
If you later want vector embeddings, swap `_similarity()` for cosine
similarity on sentence-transformer vectors.
"""
from __future__ import annotations
import hashlib
import re
import time
from dataclasses import dataclass, field
from difflib import SequenceMatcher
from typing import Optional
# ─────────────────────────────────────────────
# Cache entry
# ─────────────────────────────────────────────
@dataclass
class CacheEntry:
    """
    One stored response together with the metadata needed to expire it.

    Timestamps are wall-clock values from ``time.time()``; both properties
    below re-read the clock on every access.
    """

    # Normalised medicine name used for exact and fuzzy comparison.
    key: str
    # Task label: "spell" | "interactions" | "contraindications".
    task: str
    # Full LLM response text that was cached.
    result: str
    # time.time() value captured when the entry was stored.
    created_at: float
    # Lifetime in seconds, counted from created_at.
    ttl: float

    @property
    def is_expired(self) -> bool:
        """Return True once the current time is strictly past created_at + ttl."""
        deadline = self.created_at + self.ttl
        return deadline < time.time()

    @property
    def age_seconds(self) -> float:
        """Return how many seconds have elapsed since the entry was created."""
        now = time.time()
        return now - self.created_at
# ─────────────────────────────────────────────
# Semantic (fuzzy) cache
# ─────────────────────────────────────────────
class SemanticCache:
    """
    In-process fuzzy cache keyed by (task, normalised_medicine_name).

    A lookup first tries the exact composite key. If that misses, every
    live entry stored for the same task is compared with the requested
    name and the most similar one is returned, provided its similarity
    ratio reaches ``similarity_threshold``.

    Names that contain no letters or digits at all (empty strings,
    punctuation only) are never cached: they would otherwise all share a
    single key and be served each other's results.

    Parameters
    ----------
    ttl_seconds : float
        How long entries live before being considered stale.
    similarity_threshold : float
        Minimum SequenceMatcher ratio (0 to 1) for a fuzzy hit.
        The default of 0.82 is intended to catch single-letter typos
        while avoiding false positives between chemically different names.
    max_size : int
        Maximum number of entries. LRU eviction when exceeded.
    """

    def __init__(
        self,
        ttl_seconds: float = 86_400,  # 24 hours
        similarity_threshold: float = 0.82,
        max_size: int = 512,
    ) -> None:
        self._ttl = ttl_seconds
        self._threshold = similarity_threshold
        self._max_size = max_size
        self._store: dict[str, CacheEntry] = {}  # composite_key -> entry
        # Composite keys ordered least- to most-recently used (LRU eviction).
        self._access_order: list[str] = []

    # ── public API ───────────────────────────
    def get(self, task: str, medicine: str) -> Optional[tuple[str, bool]]:
        """
        Look up a cached result.

        Expired entries are purged before the lookup. A successful lookup
        marks the matched entry as most recently used.

        Parameters
        ----------
        task : str
            Task label the result was stored under.
        medicine : str
            Medicine name as typed by the caller; it is normalised here.

        Returns
        -------
        (result_text, is_fuzzy_hit) if found and not expired, else None.
        is_fuzzy_hit=False means exact key match;
        is_fuzzy_hit=True means the result was stored under a different
        but sufficiently similar name.
        """
        self._evict_expired()
        norm = self._normalise(medicine)
        if not norm:
            # No letters or digits survived normalisation, so there is
            # nothing to identify the medicine by: always a miss.
            return None

        # 1. Exact match
        ck = self._composite(task, norm)
        entry = self._store.get(ck)
        if entry is not None and not entry.is_expired:
            self._touch(ck)
            return entry.result, False

        # 2. Fuzzy match against same-task entries
        best_ratio = 0.0
        best_entry: Optional[CacheEntry] = None
        best_ck: Optional[str] = None
        for key, candidate in self._store.items():
            if candidate.task != task or candidate.is_expired:
                continue
            ratio = self._similarity(norm, candidate.key)
            # Strict '>' keeps the earliest-inserted entry on ties.
            if ratio > best_ratio:
                best_ratio = ratio
                best_entry = candidate
                best_ck = key

        if best_entry is not None and best_ratio >= self._threshold:
            self._touch(best_ck)
            return best_entry.result, True
        return None

    def set(self, task: str, medicine: str, result: str) -> None:
        """
        Store a result for (task, medicine).

        An existing entry under the same key is replaced and its TTL
        restarts. If the store then exceeds ``max_size`` the least
        recently used entries are evicted. Names that normalise to an
        empty string are ignored.
        """
        norm = self._normalise(medicine)
        if not norm:
            # Storing under an empty key would make every unidentifiable
            # name share this one result.
            return
        ck = self._composite(task, norm)
        self._store[ck] = CacheEntry(
            key=norm,
            task=task,
            result=result,
            created_at=time.time(),
            ttl=self._ttl,
        )
        self._touch(ck)
        self._evict_lru_if_needed()

    def invalidate(self, task: str, medicine: str) -> None:
        """Remove the entry stored under the exact (task, medicine) key, if any."""
        ck = self._composite(task, self._normalise(medicine))
        self._store.pop(ck, None)
        if ck in self._access_order:
            self._access_order.remove(ck)

    def clear(self) -> None:
        """Drop every entry and reset the LRU bookkeeping."""
        self._store.clear()
        self._access_order.clear()

    @property
    def size(self) -> int:
        """Number of stored entries, including expired ones not yet purged."""
        return len(self._store)

    def stats(self) -> dict:
        """Return entry counts (total / valid / expired) and the configuration."""
        total = len(self._store)
        valid = sum(1 for e in self._store.values() if not e.is_expired)
        return {
            "total_entries": total,
            "valid_entries": valid,
            "expired_entries": total - valid,
            "max_size": self._max_size,
            "ttl_seconds": self._ttl,
            "similarity_threshold": self._threshold,
        }

    # ── internals ────────────────────────────
    @staticmethod
    def _normalise(text: str) -> str:
        """
        Lower-case *text* and strip everything that is not a letter or digit.

        The pattern is Unicode-aware, so names written in non-Latin scripts
        keep their letters instead of collapsing to an empty string. For
        plain ASCII input the result is the same as keeping only ``a-z0-9``.
        """
        # \W covers whitespace and punctuation; '_' counts as a word
        # character and therefore has to be listed explicitly.
        return re.sub(r"[\W_]", "", text.lower())

    @staticmethod
    def _composite(task: str, norm: str) -> str:
        """Build the store key from a task label and a normalised name."""
        return f"{task}::{norm}"

    @staticmethod
    def _similarity(a: str, b: str) -> float:
        """Return the difflib SequenceMatcher ratio (0.0 to 1.0) of *a* and *b*."""
        return SequenceMatcher(None, a, b).ratio()

    def _touch(self, ck: Optional[str]) -> None:
        """Mark *ck* as most recently used; a None key is ignored."""
        if ck is None:
            return
        if ck in self._access_order:
            self._access_order.remove(ck)
        self._access_order.append(ck)

    def _evict_expired(self) -> None:
        """Remove every expired entry from the store and the LRU list."""
        # Collect first: the dict must not change size while being iterated.
        expired = [k for k, e in self._store.items() if e.is_expired]
        for k in expired:
            self._store.pop(k, None)
            if k in self._access_order:
                self._access_order.remove(k)

    def _evict_lru_if_needed(self) -> None:
        """Evict least recently used entries until the store fits max_size."""
        while len(self._store) > self._max_size and self._access_order:
            oldest = self._access_order.pop(0)
            self._store.pop(oldest, None)
# ─────────────────────────────────────────────
# Module-level singleton shared across requests
# ─────────────────────────────────────────────
# Built once at import time with the default SemanticCache settings;
# get_cache() returns this same instance on every call.
_cache: SemanticCache = SemanticCache()
def get_cache() -> SemanticCache:
    """
    Give callers access to the module-wide cache instance.

    Returns
    -------
    SemanticCache
        The ``_cache`` object created when this module is imported; the
        same instance is returned on every call.
    """
    return _cache