chatgpt2api / services /protocol /chat_completion_cache.py
peijun1's picture
clean deploy
dc03b36
Raw
History Blame Contribute Delete
8.06 kB
from __future__ import annotations
import copy
import hashlib
import json
import threading
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Iterable, Iterator
from services.config import config
# Request-body fields that take part in the cache key. ``canonical_body``
# copies exactly these keys (when present in the request) into the payload
# that is hashed; every other body field is ignored for caching purposes.
CACHEABLE_TEXT_KEYS = {
    "frequency_penalty",
    "max_completion_tokens",
    "max_tokens",
    "metadata",
    "model",
    "presence_penalty",
    "reasoning_effort",
    "response_format",
    "seed",
    "stop",
    "temperature",
    "tool_choice",
    "tools",
    "top_p",
    "user",
}
@dataclass
class CacheEntry:
    """A stored result paired with the moment it stops being valid."""

    # Absolute expiry timestamp; the cache compares it against ``time.time()``.
    expires_at: float
    # The cached payload (a response dict or a list of stream chunks).
    value: Any
@dataclass
class InflightCall:
condition: threading.Condition = field(default_factory=lambda: threading.Condition(threading.RLock()))
done: bool = False
value: Any = None
error: BaseException | None = None
def _json_safe(value: Any) -> Any:
if isinstance(value, bytes):
return {"__bytes_sha256__": hashlib.sha256(value).hexdigest(), "length": len(value)}
if isinstance(value, bytearray):
data = bytes(value)
return {"__bytes_sha256__": hashlib.sha256(data).hexdigest(), "length": len(data)}
if isinstance(value, dict):
return {str(key): _json_safe(item) for key, item in value.items()}
if isinstance(value, (list, tuple)):
return [_json_safe(item) for item in value]
return value
def canonical_body(body: dict[str, Any], messages: list[dict[str, Any]], *, stream: bool) -> dict[str, Any]:
    """Assemble the payload that identifies a request for caching.

    Only the fields named in ``CACHEABLE_TEXT_KEYS`` that are actually present
    in *body* are carried over. *messages* is stored under ``"messages"`` and
    the truthiness of *stream* under ``"stream"``.
    """
    canonical: dict[str, Any] = {}
    for name in CACHEABLE_TEXT_KEYS:
        if name in body:
            canonical[name] = body.get(name)
    canonical["messages"] = messages
    canonical["stream"] = bool(stream)
    return canonical
def cache_key(body: dict[str, Any], messages: list[dict[str, Any]], *, stream: bool) -> str:
    """Return the SHA-256 hex digest identifying a request.

    The canonical body is made JSON-safe, serialised compactly with sorted
    keys (so key order in the input does not matter), encoded as UTF-8 and
    hashed.
    """
    payload = _json_safe(canonical_body(body, messages, stream=stream))
    serialized = json.dumps(
        payload,
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ":"),
    )
    return hashlib.sha256(serialized.encode("utf-8")).hexdigest()
def _message_signature(message: dict[str, Any]) -> str:
    """Return a compact, key-sorted JSON string used to compare messages for equality."""
    safe_message = _json_safe(message)
    return json.dumps(
        safe_message,
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ":"),
    )
def normalize_text_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Optionally trim the message list before it is used for caching.

    Behaviour is driven by the chat-completion cache settings:

    * ``normalize_messages`` falsy: *messages* is returned as-is (same object).
    * ``drop_assistant_history`` truthy: messages whose role is ``"assistant"``
      are skipped.
    * ``drop_adjacent_duplicates`` truthy: a message is skipped when it is
      identical to the previously kept message. Dropped assistant messages do
      not count as "previous", so two equal messages separated only by a
      dropped assistant message collapse into one.

    Returns:
        The original list when normalisation is disabled, otherwise a new list
        containing the kept message dicts (not copies).
    """
    settings = config.get_chat_completion_cache_settings()
    if not settings.get("normalize_messages"):
        return messages

    # The flags cannot change during the loop, so read them once.
    drop_assistant = bool(settings.get("drop_assistant_history"))
    drop_duplicates = bool(settings.get("drop_adjacent_duplicates"))

    normalized: list[dict[str, Any]] = []
    previous_signature: str | None = None
    for message in messages:
        if drop_assistant and str(message.get("role") or "") == "assistant":
            continue
        if drop_duplicates:
            # Serialising a message is the expensive step; only do it when
            # the duplicate check is actually enabled.
            signature = _message_signature(message)
            if signature == previous_signature:
                continue
            previous_signature = signature
        normalized.append(message)
    return normalized
class ChatCompletionCache:
    """In-process, TTL-bounded cache for chat completion results.

    Results are stored per key together with an expiry time. When the
    ``dedupe_inflight`` setting is on, concurrent requests for the same key
    share one computation: the first caller (the "owner") runs it and the
    others wait for its outcome.

    ``self._lock`` guards ``_entries`` and ``_inflight``. It is never held
    while user code runs or while a generator is suspended at a ``yield``.
    """

    def __init__(self) -> None:
        self._lock = threading.RLock()
        self._entries: dict[str, CacheEntry] = {}
        self._inflight: dict[str, InflightCall] = {}

    def clear(self) -> None:
        """Drop all cached entries and forget all registered in-flight calls."""
        with self._lock:
            self._entries.clear()
            self._inflight.clear()

    def _settings(self) -> dict[str, object]:
        """Fetch the current cache settings from the project configuration."""
        return config.get_chat_completion_cache_settings()

    def _prune_locked(self, now: float, max_entries: int) -> None:
        """Remove expired entries, then evict until at most *max_entries* remain.

        Must be called with ``self._lock`` held. Eviction removes the entry
        with the earliest expiry first. A negative *max_entries* is treated
        as zero, so the loop stops on an empty cache instead of calling
        ``min()`` on it.
        """
        expired = [key for key, item in self._entries.items() if item.expires_at <= now]
        for key in expired:
            self._entries.pop(key, None)
        limit = max(max_entries, 0)
        while len(self._entries) > limit:
            oldest_key = min(self._entries, key=lambda key: self._entries[key].expires_at)
            self._entries.pop(oldest_key, None)

    @staticmethod
    def _copy(value: Any) -> Any:
        """Deep-copy *value* so callers cannot mutate what the cache holds."""
        return copy.deepcopy(value)

    def _release_inflight_locked(self, key: str, inflight: "InflightCall") -> None:
        """Unregister *inflight* for *key*; requires ``self._lock`` to be held.

        The registration is removed only if it is still this exact object, so
        a newer registration made after ``clear()`` is left in place.
        """
        if self._inflight.get(key) is inflight:
            del self._inflight[key]

    def _lookup_or_claim(
        self, key: str, max_entries: int, dedupe: bool
    ) -> "tuple[bool, Any, InflightCall | None, bool]":
        """Check the cache and decide this caller's role, all under the lock.

        Returns ``(hit, value, inflight, owner)``:

        * ``hit`` is true when a live entry exists; ``value`` is then a private
          copy of it and ``inflight`` is ``None``.
        * otherwise ``inflight`` is the call record to use and ``owner`` tells
          whether this caller must run the computation (true) or wait for
          another caller's result (false).
        """
        now = time.time()
        with self._lock:
            self._prune_locked(now, max_entries)
            entry = self._entries.get(key)
            if entry is not None and entry.expires_at > now:
                return True, self._copy(entry.value), None, False
            existing = self._inflight.get(key) if dedupe else None
            if existing is not None:
                return False, None, existing, False
            inflight = InflightCall()
            if dedupe:
                self._inflight[key] = inflight
            return False, None, inflight, True

    def _wait_for_owner(self, inflight: "InflightCall") -> Any:
        """Block until the owner finishes; return a copy of its value or raise its error."""
        with inflight.condition:
            while not inflight.done:
                inflight.condition.wait()
            if inflight.error is not None:
                raise inflight.error
            return self._copy(inflight.value)

    def _publish_success(
        self, key: str, inflight: "InflightCall", value: Any, ttl: int, max_entries: int
    ) -> None:
        """Store *value* in the cache and hand a copy to any waiters."""
        expires_at = time.time() + ttl
        # Copy before taking any lock so lock hold times stay short.
        stored = self._copy(value)
        shared = self._copy(value)
        with self._lock:
            self._entries[key] = CacheEntry(expires_at=expires_at, value=stored)
            self._prune_locked(time.time(), max_entries)
            self._release_inflight_locked(key, inflight)
        with inflight.condition:
            inflight.value = shared
            inflight.done = True
            inflight.condition.notify_all()

    def _publish_failure(self, key: str, inflight: "InflightCall", exc: BaseException) -> None:
        """Unregister the call and wake waiters with the failure."""
        waiter_error: BaseException = exc
        if isinstance(exc, GeneratorExit):
            # GeneratorExit only means the owner's consumer closed the stream
            # early. Re-raising it inside a waiter would escape ordinary
            # ``except Exception`` handling, so waiters get a regular error.
            waiter_error = RuntimeError("in-flight chat completion was cancelled before it finished")
            waiter_error.__cause__ = exc
        with self._lock:
            self._release_inflight_locked(key, inflight)
        with inflight.condition:
            if not inflight.done:  # never overwrite an already published outcome
                inflight.error = waiter_error
                inflight.done = True
                inflight.condition.notify_all()

    def get_or_compute_response(self, key: str, compute: Callable[[], dict[str, Any]]) -> dict[str, Any]:
        """Return the response for *key*, computing and caching it if needed.

        Caching is bypassed (``compute()`` is called directly) when the cache
        is disabled or ``ttl_seconds`` is not positive.

        Args:
            key: Cache key identifying the request.
            compute: Zero-argument callable producing the response.

        Returns:
            The freshly computed value for the owner, or a deep copy of the
            cached / in-flight value for everyone else.

        Raises:
            BaseException: Whatever ``compute`` raised, both in the owner and
                in any callers waiting on it.
        """
        settings = self._settings()
        if not settings.get("enabled"):
            return compute()
        ttl = int(settings.get("ttl_seconds") or 0)
        if ttl <= 0:
            return compute()
        max_entries = int(settings.get("max_entries") or 1)
        dedupe = bool(settings.get("dedupe_inflight"))

        hit, cached, inflight, owner = self._lookup_or_claim(key, max_entries, dedupe)
        if hit:
            return cached
        if not owner:
            return self._wait_for_owner(inflight)

        try:
            value = compute()
            self._publish_success(key, inflight, value, ttl, max_entries)
        except BaseException as exc:
            # Covers failures while publishing too, so waiters are always released.
            self._publish_failure(key, inflight, exc)
            raise
        return value

    def get_or_compute_stream(self, key: str, compute: Callable[[], Iterable[dict[str, Any]]]) -> Iterator[dict[str, Any]]:
        """Yield the stream chunks for *key*, computing and caching them if needed.

        This is a generator, so nothing runs until the first item is
        requested. Caching is bypassed when the cache is disabled,
        ``stream_cache`` is off, or ``ttl_seconds`` is not positive.

        The owner yields chunks as ``compute()`` produces them and caches the
        full list only if the stream is consumed to the end. Cache hits and
        waiters replay a deep copy of the complete chunk list; waiters block
        until the owner has finished. No lock is held while chunks are yielded.

        Raises:
            BaseException: Whatever ``compute`` raised. Waiters receive the
                same exception, except that an owner stream closed early by
                its consumer is reported to waiters as ``RuntimeError``.
        """
        settings = self._settings()
        if not settings.get("enabled") or not settings.get("stream_cache"):
            yield from compute()
            return
        ttl = int(settings.get("ttl_seconds") or 0)
        if ttl <= 0:
            yield from compute()
            return
        max_entries = int(settings.get("max_entries") or 1)
        dedupe = bool(settings.get("dedupe_inflight"))

        hit, cached, inflight, owner = self._lookup_or_claim(key, max_entries, dedupe)
        if hit:
            # The copy was taken under the lock; replay it after release.
            yield from cached
            return
        if not owner:
            # Wait and copy first, then yield with no lock held.
            replay = self._wait_for_owner(inflight)
            yield from replay
            return

        chunks: list[dict[str, Any]] = []
        try:
            for chunk in compute():
                chunks.append(self._copy(chunk))
                yield chunk
            self._publish_success(key, inflight, chunks, ttl, max_entries)
        except BaseException as exc:
            # Also reached with GeneratorExit when the consumer closes early.
            self._publish_failure(key, inflight, exc)
            raise
# Module-level cache instance created at import time.
# NOTE(review): presumably the shared singleton used by request handlers — confirm against callers.
chat_completion_cache = ChatCompletionCache()