Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import copy | |
| import hashlib | |
| import json | |
| import threading | |
| import time | |
| from dataclasses import dataclass, field | |
| from typing import Any, Callable, Iterable, Iterator | |
| from services.config import config | |
# Request-body keys that canonical_body() copies into the payload hashed by
# cache_key(); keys of the request body outside this set are left out of the
# canonical payload.
CACHEABLE_TEXT_KEYS = {
    "frequency_penalty",
    "max_completion_tokens",
    "max_tokens",
    "metadata",
    "model",
    "presence_penalty",
    "reasoning_effort",
    "response_format",
    "seed",
    "stop",
    "temperature",
    "tool_choice",
    "tools",
    "top_p",
    "user",
}
@dataclass
class CacheEntry:
    """A cached value paired with the moment it stops being valid.

    Declared as a dataclass so it can be built with keyword arguments
    (``CacheEntry(expires_at=..., value=...)``); without the decorator the
    annotated names are only annotations and the constructor takes no
    arguments.
    """

    # Absolute expiry timestamp, compared against ``time.time()`` by the cache.
    expires_at: float
    # The cached payload (a single response or a list of stream chunks).
    value: Any
@dataclass
class InflightCall:
    """Rendezvous record for one computation that other callers may wait on.

    The computing caller fills in ``value`` or ``error``, sets ``done`` and
    notifies ``condition``; waiting callers block on ``condition`` until
    ``done`` is true.

    Declared as a dataclass so that ``field(default_factory=...)`` gives
    every instance its own ``Condition``; without the decorator the class
    attribute would be a single shared ``Field`` object.
    """

    # Per-instance condition variable (backed by a re-entrant lock).
    condition: threading.Condition = field(default_factory=lambda: threading.Condition(threading.RLock()))
    # True once the computation finished, successfully or not.
    done: bool = False
    # Result of the computation when it succeeded.
    value: Any = None
    # Exception raised by the computation when it failed.
    error: BaseException | None = None
| def _json_safe(value: Any) -> Any: | |
| if isinstance(value, bytes): | |
| return {"__bytes_sha256__": hashlib.sha256(value).hexdigest(), "length": len(value)} | |
| if isinstance(value, bytearray): | |
| data = bytes(value) | |
| return {"__bytes_sha256__": hashlib.sha256(data).hexdigest(), "length": len(data)} | |
| if isinstance(value, dict): | |
| return {str(key): _json_safe(item) for key, item in value.items()} | |
| if isinstance(value, (list, tuple)): | |
| return [_json_safe(item) for item in value] | |
| return value | |
def canonical_body(body: dict[str, Any], messages: list[dict[str, Any]], *, stream: bool) -> dict[str, Any]:
    """Build the payload that identifies a request for caching purposes.

    Only keys listed in ``CACHEABLE_TEXT_KEYS`` and actually present in
    *body* are carried over; ``messages`` and a boolean ``stream`` flag are
    then added (overwriting nothing, since neither is a cacheable key).
    """
    selected: dict[str, Any] = {}
    for name in CACHEABLE_TEXT_KEYS:
        if name in body:
            selected[name] = body.get(name)
    selected["messages"] = messages
    selected["stream"] = bool(stream)
    return selected
def cache_key(body: dict[str, Any], messages: list[dict[str, Any]], *, stream: bool) -> str:
    """Return the SHA-256 hex digest identifying a request.

    The canonical payload is made JSON-safe, serialised compactly with
    sorted keys (so key order in *body* does not matter), encoded as UTF-8
    and hashed.
    """
    payload = canonical_body(body, messages, stream=stream)
    serialised = json.dumps(
        _json_safe(payload),
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ":"),
    )
    return hashlib.sha256(serialised.encode("utf-8")).hexdigest()
def _message_signature(message: dict[str, Any]) -> str:
    """Return a compact, key-sorted JSON string used to compare messages."""
    safe_message = _json_safe(message)
    return json.dumps(
        safe_message,
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ":"),
    )
def normalize_text_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Optionally filter *messages* according to the cache settings.

    When the ``normalize_messages`` setting is falsy the input list itself
    is returned. Otherwise a new list is built in which:

    * messages whose role is ``"assistant"`` are skipped if
      ``drop_assistant_history`` is set;
    * a message whose signature equals that of the previously kept message
      is skipped if ``drop_adjacent_duplicates`` is set.
    """
    settings = config.get_chat_completion_cache_settings()
    if not settings.get("normalize_messages"):
        return messages

    skip_assistant = settings.get("drop_assistant_history")
    skip_repeats = settings.get("drop_adjacent_duplicates")

    kept: list[dict[str, Any]] = []
    last_kept_signature = ""
    for item in messages:
        role = str(item.get("role") or "")
        if skip_assistant and role == "assistant":
            continue
        # The signature is computed for every remaining message, even when
        # duplicate dropping is off, exactly as before.
        current_signature = _message_signature(item)
        if skip_repeats and current_signature == last_kept_signature:
            continue
        kept.append(item)
        last_kept_signature = current_signature
    return kept
class ChatCompletionCache:
    """In-memory TTL cache for chat completion results with in-flight de-duplication.

    Results are stored per cache key together with an expiry timestamp.
    When the ``dedupe_inflight`` setting is on, concurrent callers asking
    for the same key share one computation: the first caller (the "owner")
    runs ``compute`` while the others wait for its outcome.

    All access to the internal dictionaries is guarded by ``self._lock``.
    Values are deep-copied on the way into and out of the cache so callers
    never share mutable state with a stored entry.
    """

    def __init__(self) -> None:
        self._lock = threading.RLock()
        self._entries: dict[str, CacheEntry] = {}
        self._inflight: dict[str, InflightCall] = {}

    def clear(self) -> None:
        """Drop every cached entry and forget every registered in-flight call."""
        with self._lock:
            self._entries.clear()
            self._inflight.clear()

    def _settings(self) -> dict[str, object]:
        """Fetch the current cache settings from the project configuration."""
        return config.get_chat_completion_cache_settings()

    def _prune_locked(self, now: float, max_entries: int) -> None:
        """Remove expired entries, then evict until at most *max_entries* remain.

        Must be called with ``self._lock`` held. Eviction removes the entry
        with the smallest ``expires_at`` first.
        """
        expired = [key for key, item in self._entries.items() if item.expires_at <= now]
        for key in expired:
            self._entries.pop(key, None)
        # Guard on non-emptiness: with a negative max_entries the size test
        # alone stays true on an empty dict and min() would raise ValueError.
        while self._entries and len(self._entries) > max_entries:
            oldest_key = min(self._entries, key=lambda key: self._entries[key].expires_at)
            self._entries.pop(oldest_key, None)

    @staticmethod
    def _copy(value: Any) -> Any:
        """Return a deep copy of *value*.

        A static method because it takes no ``self`` yet is invoked as
        ``self._copy(...)`` throughout the class.
        """
        return copy.deepcopy(value)

    def _claim(
        self,
        key: str,
        settings: dict[str, object],
        max_entries: int,
    ) -> tuple[CacheEntry | None, InflightCall | None, bool]:
        """Look *key* up and decide the caller's role.

        Returns ``(entry, inflight, owner)``:

        * cache hit: ``(entry, None, False)``;
        * another caller is already computing: ``(None, inflight, False)``;
        * this caller must compute: ``(None, inflight, True)``. The new
          record is registered only when ``dedupe_inflight`` is set.
        """
        now = time.time()
        dedupe = settings.get("dedupe_inflight")
        with self._lock:
            self._prune_locked(now, max_entries)
            entry = self._entries.get(key)
            if entry is not None and entry.expires_at > now:
                return entry, None, False
            existing = self._inflight.get(key) if dedupe else None
            if existing is not None:
                return None, existing, False
            created = InflightCall()
            if dedupe:
                self._inflight[key] = created
            return None, created, True

    def _wait_for(self, inflight: InflightCall) -> Any:
        """Block until *inflight* is done; re-raise its error or return a copy of its value."""
        with inflight.condition:
            while not inflight.done:
                inflight.condition.wait()
        if inflight.error is not None:
            raise inflight.error
        return self._copy(inflight.value)

    def _release_locked(self, key: str, inflight: InflightCall) -> None:
        """Unregister *inflight* for *key* (``self._lock`` must be held).

        Only removes the record if it is still the one registered, so an
        owner finishing after ``clear()`` cannot remove a newer caller's
        record for the same key.
        """
        if self._inflight.get(key) is inflight:
            del self._inflight[key]

    def _publish_failure(self, key: str, inflight: InflightCall, error: BaseException) -> None:
        """Unregister *inflight* and wake its waiters with *error*."""
        with self._lock:
            self._release_locked(key, inflight)
        with inflight.condition:
            inflight.error = error
            inflight.done = True
            inflight.condition.notify_all()

    def _publish_success(
        self,
        key: str,
        inflight: InflightCall,
        value: Any,
        settings: dict[str, object],
        max_entries: int,
    ) -> None:
        """Store a copy of *value* under *key*, unregister *inflight* and wake its waiters."""
        expires_at = time.time() + int(settings.get("ttl_seconds") or 0)
        with self._lock:
            self._entries[key] = CacheEntry(expires_at=expires_at, value=self._copy(value))
            self._prune_locked(time.time(), max_entries)
            self._release_locked(key, inflight)
        with inflight.condition:
            # Waiters get their own copy, independent of the stored entry.
            inflight.value = self._copy(value)
            inflight.done = True
            inflight.condition.notify_all()

    def get_or_compute_response(self, key: str, compute: Callable[[], dict[str, Any]]) -> dict[str, Any]:
        """Return the response for *key*, computing and caching it on a miss.

        Args:
            key: Cache key identifying the request.
            compute: Zero-argument callable producing the response.

        Returns:
            A deep copy of the cached or shared result, or, for the caller
            that actually ran ``compute``, the object it returned.

        Raises:
            Whatever ``compute`` raises. Waiters sharing the computation
            re-raise the same exception object.
        """
        settings = self._settings()
        # Caching is bypassed entirely when disabled or when the TTL is not positive.
        if not settings.get("enabled") or int(settings.get("ttl_seconds") or 0) <= 0:
            return compute()
        max_entries = int(settings.get("max_entries") or 1)

        entry, inflight, owner = self._claim(key, settings, max_entries)
        if entry is not None:
            return self._copy(entry.value)
        if not owner:
            return self._wait_for(inflight)

        try:
            value = compute()
        except BaseException as exc:
            self._publish_failure(key, inflight, exc)
            raise
        self._publish_success(key, inflight, value, settings, max_entries)
        return value

    def get_or_compute_stream(self, key: str, compute: Callable[[], Iterable[dict[str, Any]]]) -> Iterator[dict[str, Any]]:
        """Yield the chunks for *key*, replaying a cached stream when available.

        This is a generator: nothing happens until the first ``next()``.
        The owner yields chunks as ``compute()`` produces them while
        recording copies; the full list is cached only after the stream is
        exhausted. Waiters receive the chunks only once the owner finished.

        Args:
            key: Cache key identifying the request.
            compute: Zero-argument callable returning an iterable of chunks.

        Yields:
            Stream chunks (deep copies when replayed from cache or shared).

        Raises:
            Whatever ``compute()`` or its iteration raises. Waiters sharing
            the computation re-raise the same exception object.
        """
        settings = self._settings()
        if (
            not settings.get("enabled")
            or not settings.get("stream_cache")
            or int(settings.get("ttl_seconds") or 0) <= 0
        ):
            yield from compute()
            return
        max_entries = int(settings.get("max_entries") or 1)

        entry, inflight, owner = self._claim(key, settings, max_entries)
        if entry is not None:
            # Replay happens after _claim released the lock, so a slow
            # consumer never blocks other threads using the cache.
            yield from self._copy(entry.value)
            return
        if not owner:
            yield from self._wait_for(inflight)
            return

        chunks: list[dict[str, Any]] = []
        try:
            for chunk in compute():
                chunks.append(self._copy(chunk))
                yield chunk
        except BaseException as exc:
            # NOTE(review): this also catches GeneratorExit when the consumer
            # closes the stream early; waiters then re-raise that same
            # exception object. Confirm this is the intended outcome.
            self._publish_failure(key, inflight, exc)
            raise
        self._publish_success(key, inflight, chunks, settings, max_entries)
# Module-level cache instance created at import time.
chat_completion_cache = ChatCompletionCache()