# Ryan2219's picture
# Upload 70 files
# e1ced8e verified
"""In-session crop cache β€” avoids redundant Gemini API calls for identical crops.
Stored in ``st.session_state`` so it persists across questions within a single
Streamlit session, but is discarded when the session ends.
Matching strategy:
- **Exact match** on ``(page_num, crop_instruction)`` is the primary lookup.
- **Fuzzy match** with a simple normalized overlap score handles cases where
the planner rephrases slightly (e.g., "Crop the gymnasium area" vs
"Crop gymnasium area showing diffusers"). Only matches above a high
threshold (``CropCache.FUZZY_THRESHOLD``, currently 0.70) are considered hits
to avoid false positives.
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
from state import ImageRef
logger = logging.getLogger(__name__)
@dataclass
class CachedCrop:
    """A cached crop entry with its original instruction and result."""
    # PDF page number the crop was taken from (part of the exact-match key).
    page_num: int
    # Verbatim planner instruction that produced the crop (exact-match key).
    crop_instruction: str
    # Human-readable label attached to the crop.
    label: str
    # Resulting cropped image reference (project type from ``state``;
    # accessed as a mapping with an "id" key elsewhere in this module).
    image_ref: ImageRef
    # Normalised token set for fuzzy matching (computed once at insert time)
    _tokens: frozenset[str] = field(default_factory=frozenset, repr=False)
def _normalise_tokens(text: str) -> frozenset[str]:
"""Lowercase, strip punctuation, split into a token set."""
cleaned = re.sub(r"[^a-z0-9\s]", "", text.lower())
return frozenset(cleaned.split())
def _token_overlap(a: frozenset[str], b: frozenset[str]) -> float:
"""Jaccard-style overlap: |intersection| / |union|."""
if not a or not b:
return 0.0
return len(a & b) / len(a | b)
class CropCache:
    """Session-scoped crop cache: (page, instruction) → ImageRef.

    Reads are safe under CPython's GIL (plain dict lookups); all writes
    happen on Streamlit's single main thread, so no extra locking is needed.
    """

    # Minimum token-overlap score for a fuzzy hit.  Chosen so that light
    # rephrasing (e.g. dropping "the"/"all", ~0.78 overlap) still matches,
    # while genuinely different instructions (~0.06-0.42) do not.
    FUZZY_THRESHOLD = 0.70

    def __init__(self) -> None:
        # Exact-match index: (page_num, instruction) → CachedCrop.
        self._exact: dict[tuple[int, str], CachedCrop] = {}
        # The same CachedCrop objects in insertion order, for fuzzy scans.
        self._entries: list[CachedCrop] = []
        self._hit_count = 0
        self._miss_count = 0

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def lookup(self, page_num: int, crop_instruction: str) -> ImageRef | None:
        """Return the cached ImageRef for a matching crop, else None.

        An exact (page, instruction) hit wins; failing that, the instruction
        is fuzzy-matched by token overlap against entries on the same page.
        """
        exact_key = (page_num, crop_instruction)

        # 1. Exact lookup.
        cached = self._exact.get(exact_key)
        if cached is not None:
            self._hit_count += 1
            logger.info(
                "CropCache HIT (exact) page=%d instruction='%s' → %s",
                page_num, crop_instruction[:60], cached.image_ref["id"],
            )
            return cached.image_ref

        # 2. Fuzzy lookup, restricted to entries recorded for the same page.
        query_tokens = _normalise_tokens(crop_instruction)
        same_page = [e for e in self._entries if e.page_num == page_num]
        if same_page:
            # max() returns the first maximal entry, matching the original
            # first-wins tie-breaking of a strict ">" comparison loop.
            best = max(
                same_page,
                key=lambda e: _token_overlap(query_tokens, e._tokens),
            )
            best_score = _token_overlap(query_tokens, best._tokens)
            if best_score >= self.FUZZY_THRESHOLD:
                self._hit_count += 1
                logger.info(
                    "CropCache HIT (fuzzy %.2f) page=%d instruction='%s' → %s",
                    best_score, page_num, crop_instruction[:60],
                    best.image_ref["id"],
                )
                return best.image_ref

        self._miss_count += 1
        return None

    def register(
        self,
        page_num: int,
        crop_instruction: str,
        label: str,
        image_ref: ImageRef,
        *,
        is_fallback: bool = False,
    ) -> None:
        """Record a successful crop so identical/similar requests reuse it.

        Parameters
        ----------
        is_fallback
            True when the crop is a full-page fallback (Gemini failed to
            crop).  Such results are deliberately NOT cached because they
            don't represent a useful targeted crop.
        """
        if is_fallback:
            logger.debug(
                "CropCache SKIP (fallback) page=%d instruction='%s'",
                page_num, crop_instruction[:60],
            )
            return

        exact_key = (page_num, crop_instruction)
        if exact_key in self._exact:
            # Duplicate registration — the first entry wins.
            return

        new_entry = CachedCrop(
            page_num=page_num,
            crop_instruction=crop_instruction,
            label=label,
            image_ref=image_ref,
            _tokens=_normalise_tokens(crop_instruction),
        )
        self._exact[exact_key] = new_entry
        self._entries.append(new_entry)
        logger.info(
            "CropCache REGISTER page=%d instruction='%s' → %s",
            page_num, crop_instruction[:60], image_ref["id"],
        )

    @property
    def size(self) -> int:
        """Number of cached crops."""
        return len(self._entries)

    @property
    def stats(self) -> str:
        """One-line human-readable hit/miss summary."""
        attempts = self._hit_count + self._miss_count
        rate = (self._hit_count / attempts * 100) if attempts > 0 else 0
        return (
            f"CropCache: {self.size} entries, "
            f"{self._hit_count} hits / {self._miss_count} misses "
            f"({rate:.0f}% hit rate)"
        )

    def clear(self) -> None:
        """Drop every entry and reset counters (e.g. when a new PDF is loaded)."""
        self._exact.clear()
        self._entries.clear()
        self._hit_count = 0
        self._miss_count = 0