""" cache/semantic_cache.py — Semantic cache for MedCheck Agent How it works ──────────── 1. Every (task, medicine_name) response is stored with a key derived from a simple normalised string. For medicine names, "aspirin", "Aspirin", "ASPIRIN" all hash to the same key — so an exact match is tried first (O(1)). 2. If no exact match is found, a lightweight fuzzy/phonetic similarity check is performed against all cached keys for the same task type. This catches common misspellings (Asprin → Aspirin already in cache). 3. Entries expire after a configurable TTL (default 24 h) so stale medical data is never served indefinitely. No heavy ML model is required — the similarity is intentionally kept simple (edit-distance ratio) so there are zero extra dependencies. If you later want vector embeddings, swap `_similarity()` for cosine similarity on sentence-transformer vectors. """ from __future__ import annotations import hashlib import re import time from dataclasses import dataclass, field from difflib import SequenceMatcher from typing import Optional # ───────────────────────────────────────────── # Cache entry # ───────────────────────────────────────────── @dataclass class CacheEntry: key: str # normalised lookup key task: str # "spell" | "interactions" | "contraindications" result: str # the full LLM response text created_at: float # time.time() when stored ttl: float # seconds until expiry @property def is_expired(self) -> bool: return time.time() > self.created_at + self.ttl @property def age_seconds(self) -> float: return time.time() - self.created_at # ───────────────────────────────────────────── # Semantic (fuzzy) cache # ───────────────────────────────────────────── class SemanticCache: """ In-process semantic cache keyed by (task, normalised_medicine_name). Parameters ---------- ttl_seconds : float How long entries live before being considered stale. similarity_threshold : float Minimum SequenceMatcher ratio (0–1) for a fuzzy hit. 0.82 catches single-letter typos while avoiding false positives between chemically different names. max_size : int Maximum number of entries. LRU eviction when exceeded. """ def __init__( self, ttl_seconds: float = 86_400, # 24 hours similarity_threshold: float = 0.82, max_size: int = 512, ) -> None: self._ttl = ttl_seconds self._threshold = similarity_threshold self._max_size = max_size self._store: dict[str, CacheEntry] = {} # composite_key → entry self._access_order: list[str] = [] # for LRU eviction # ── public API ─────────────────────────── def get(self, task: str, medicine: str) -> Optional[tuple[str, bool]]: """ Lookup a cached result. Returns ------- (result_text, is_fuzzy_hit) if found and not expired, else None. is_fuzzy_hit=False means exact key match; is_fuzzy_hit=True means fuzzy/semantic match. """ self._evict_expired() norm = self._normalise(medicine) # 1. Exact match ck = self._composite(task, norm) entry = self._store.get(ck) if entry and not entry.is_expired: self._touch(ck) return entry.result, False # 2. Fuzzy match against same-task entries best_ratio = 0.0 best_entry: Optional[CacheEntry] = None best_ck: Optional[str] = None for key, e in self._store.items(): if e.task != task or e.is_expired: continue ratio = self._similarity(norm, e.key) if ratio > best_ratio: best_ratio = ratio best_entry = e best_ck = key if best_ratio >= self._threshold and best_entry is not None: self._touch(best_ck) return best_entry.result, True return None def set(self, task: str, medicine: str, result: str) -> None: """Store a result for (task, medicine).""" norm = self._normalise(medicine) ck = self._composite(task, norm) entry = CacheEntry( key=norm, task=task, result=result, created_at=time.time(), ttl=self._ttl, ) self._store[ck] = entry self._touch(ck) self._evict_lru_if_needed() def invalidate(self, task: str, medicine: str) -> None: ck = self._composite(task, self._normalise(medicine)) self._store.pop(ck, None) if ck in self._access_order: self._access_order.remove(ck) def clear(self) -> None: self._store.clear() self._access_order.clear() @property def size(self) -> int: return len(self._store) def stats(self) -> dict: valid = sum(1 for e in self._store.values() if not e.is_expired) return { "total_entries": len(self._store), "valid_entries": valid, "expired_entries": len(self._store) - valid, "max_size": self._max_size, "ttl_seconds": self._ttl, "similarity_threshold": self._threshold, } # ── internals ──────────────────────────── @staticmethod def _normalise(text: str) -> str: """Lower-case, strip punctuation/spaces for key comparison.""" return re.sub(r"[^a-z0-9]", "", text.lower().strip()) @staticmethod def _composite(task: str, norm: str) -> str: return f"{task}::{norm}" @staticmethod def _similarity(a: str, b: str) -> float: return SequenceMatcher(None, a, b).ratio() def _touch(self, ck: Optional[str]) -> None: if ck is None: return if ck in self._access_order: self._access_order.remove(ck) self._access_order.append(ck) def _evict_expired(self) -> None: expired = [k for k, e in self._store.items() if e.is_expired] for k in expired: self._store.pop(k, None) if k in self._access_order: self._access_order.remove(k) def _evict_lru_if_needed(self) -> None: while len(self._store) > self._max_size and self._access_order: oldest = self._access_order.pop(0) self._store.pop(oldest, None) # ───────────────────────────────────────────── # Module-level singleton shared across requests # ───────────────────────────────────────────── _cache = SemanticCache() def get_cache() -> SemanticCache: """Return the shared cache singleton.""" return _cache