grok2api / app /platform /storage /media_cache.py
FUCAT's picture
Deploy grok2api to HF Spaces (Docker)
7e55e53
Raw
History Blame Contribute Delete
16.8 kB
"""Local media cache helpers with optional capacity enforcement."""
import asyncio
import os
import sqlite3
import threading
import time
import uuid
from contextlib import closing, contextmanager
from pathlib import Path
from typing import Any, Callable, Literal
from app.platform.config.snapshot import get_config
from app.platform.logging.logger import logger
from .media_paths import (
image_files_dir,
local_media_cache_db_path,
local_media_lock_path,
video_files_dir,
)
MediaType = Literal["image", "video"]
_LOW_WATERMARK_RATIO = 0.60
_TABLE = "local_media_files"
_IMAGE_EXTS = frozenset({".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"})
_VIDEO_EXTS = frozenset({".mp4", ".mov", ".m4v", ".webm", ".avi", ".mkv"})
class LocalMediaCacheStore:
"""Manage local media files and enforce per-type cache limits."""
def __init__(self, *, config_provider: Callable[[], Any] = get_config) -> None:
self._config_provider = config_provider
self._init_lock = threading.Lock()
self._thread_locks = {
"image": threading.Lock(),
"video": threading.Lock(),
}
self._initialized_dbs: set[Path] = set()
def save_image(self, raw: bytes, mime: str, file_id: str) -> str:
"""Persist an image and return the stable local file ID."""
ext = ".png" if "png" in mime.lower() else ".jpg"
self._save("image", file_id=file_id, raw=raw, suffix=ext)
return file_id
def save_video(self, raw: bytes, file_id: str) -> Path:
"""Persist a video and return the final file path."""
return self._save("video", file_id=file_id, raw=raw, suffix=".mp4")
def reconcile(self, media_type: MediaType) -> None:
"""Rebuild the on-disk index for one media type and enforce limits."""
max_bytes = self._limit_bytes(media_type)
if max_bytes <= 0:
return
with self._guard(media_type), closing(self._connect()) as conn:
rows = []
for path in self._iter_files(media_type):
try:
stat = path.stat()
except OSError:
continue
rows.append(
(
media_type,
path.name,
int(stat.st_size),
int(stat.st_mtime_ns),
int(stat.st_mtime_ns),
)
)
conn.execute(f"DELETE FROM {_TABLE} WHERE media_type = ?", (media_type,))
if rows:
conn.executemany(
f"""
INSERT INTO {_TABLE} (
media_type,
name,
size_bytes,
created_at_ns,
updated_at_ns
) VALUES (?, ?, ?, ?, ?)
""",
rows,
)
self._enforce_limit_locked(conn, media_type)
conn.commit()
def delete(self, media_type: MediaType, name: str) -> bool:
"""Delete a single local media file and keep the index consistent."""
safe_name = self._validate_name(media_type, name)
path = self._path_for_name(media_type, safe_name)
with self._guard(media_type):
existed = path.is_file()
if existed:
path.unlink(missing_ok=True)
self._delete_index_row_if_present(media_type, safe_name)
return existed
def clear(self, media_type: MediaType) -> int:
"""Delete all tracked local media files for one media type."""
removed = 0
with self._guard(media_type):
for path in self._iter_files(media_type):
if path.is_file():
path.unlink(missing_ok=True)
removed += 1
self._delete_index_rows_if_present(media_type)
return removed
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _save(
self,
media_type: MediaType,
*,
file_id: str,
raw: bytes,
suffix: str,
) -> Path:
path = self._media_dir(media_type) / f"{file_id}{suffix}"
if self._limit_bytes(media_type) <= 0:
self._write_if_missing(path, raw)
return path
with self._guard(media_type), closing(self._connect()) as conn:
if path.exists():
self._upsert_existing_row(conn, media_type, path)
else:
self._atomic_write(path, raw)
self._upsert_new_row(conn, media_type, path)
self._enforce_limit_locked(conn, media_type, protected_names={path.name})
conn.commit()
return path
def _limit_bytes(self, media_type: MediaType) -> int:
cfg = self._config_provider()
limit_mb = max(0, int(cfg.get_int(f"cache.local.{media_type}_max_mb", 0)))
return limit_mb * 1024 * 1024
def _target_bytes(self, max_bytes: int) -> int:
return max(0, int(max_bytes * _LOW_WATERMARK_RATIO))
def _media_dir(self, media_type: MediaType) -> Path:
return image_files_dir() if media_type == "image" else video_files_dir()
def _allowed_exts(self, media_type: MediaType) -> frozenset[str]:
return _IMAGE_EXTS if media_type == "image" else _VIDEO_EXTS
def _iter_files(self, media_type: MediaType):
allowed = self._allowed_exts(media_type)
for path in self._media_dir(media_type).glob("*"):
if path.is_file() and path.suffix.lower() in allowed:
yield path
def _path_for_name(self, media_type: MediaType, name: str) -> Path:
return self._media_dir(media_type) / name
def _validate_name(self, media_type: MediaType, name: str) -> str:
value = (name or "").strip()
if not value:
raise ValueError("missing file name")
if Path(value).name != value:
raise ValueError("invalid file name")
if Path(value).suffix.lower() not in self._allowed_exts(media_type):
raise ValueError("unsupported file type")
return value
def _write_if_missing(self, path: Path, raw: bytes) -> None:
if path.exists():
return
self._atomic_write(path, raw)
def _atomic_write(self, path: Path, raw: bytes) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
if path.exists():
return
tmp = path.parent / f".{path.name}.{uuid.uuid4().hex}.part"
try:
with tmp.open("wb") as handle:
handle.write(raw)
if path.exists():
return
os.replace(tmp, path)
finally:
tmp.unlink(missing_ok=True)
@contextmanager
def _guard(self, media_type: MediaType):
thread_lock = self._thread_locks[media_type]
thread_lock.acquire()
fd: int | None = None
try:
lock_path = local_media_lock_path(media_type)
lock_path.parent.mkdir(parents=True, exist_ok=True)
try:
import fcntl
fd = os.open(str(lock_path), os.O_CREAT | os.O_RDWR, 0o644)
fcntl.flock(fd, fcntl.LOCK_EX)
except (ImportError, OSError, AttributeError):
fd = None
yield
finally:
if fd is not None:
try:
import fcntl
fcntl.flock(fd, fcntl.LOCK_UN)
os.close(fd)
except (ImportError, OSError):
pass
thread_lock.release()
def _connect(self) -> sqlite3.Connection:
db_path = local_media_cache_db_path()
db_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(db_path, check_same_thread=False)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA synchronous=NORMAL")
conn.execute("PRAGMA busy_timeout=5000")
self._ensure_schema(conn, db_path)
return conn
def _ensure_schema(self, conn: sqlite3.Connection, db_path: Path) -> None:
if db_path in self._initialized_dbs:
return
with self._init_lock:
if db_path in self._initialized_dbs:
return
conn.executescript(
f"""
CREATE TABLE IF NOT EXISTS {_TABLE} (
media_type TEXT NOT NULL,
name TEXT NOT NULL,
size_bytes INTEGER NOT NULL,
created_at_ns INTEGER NOT NULL,
updated_at_ns INTEGER NOT NULL,
PRIMARY KEY (media_type, name)
);
CREATE INDEX IF NOT EXISTS idx_local_media_order
ON {_TABLE} (media_type, created_at_ns, name);
"""
)
conn.commit()
self._initialized_dbs.add(db_path)
def _upsert_existing_row(
self,
conn: sqlite3.Connection,
media_type: MediaType,
path: Path,
) -> None:
try:
stat = path.stat()
except OSError:
return
created_at_ns = self._lookup_created_at_ns(
conn,
media_type,
path.name,
fallback=int(stat.st_mtime_ns),
)
conn.execute(
f"""
INSERT INTO {_TABLE} (
media_type,
name,
size_bytes,
created_at_ns,
updated_at_ns
) VALUES (?, ?, ?, ?, ?)
ON CONFLICT(media_type, name) DO UPDATE SET
size_bytes = excluded.size_bytes,
updated_at_ns = excluded.updated_at_ns
""",
(
media_type,
path.name,
int(stat.st_size),
created_at_ns,
int(stat.st_mtime_ns),
),
)
def _upsert_new_row(
self,
conn: sqlite3.Connection,
media_type: MediaType,
path: Path,
) -> None:
stat = path.stat()
now_ns = time.time_ns()
conn.execute(
f"""
INSERT INTO {_TABLE} (
media_type,
name,
size_bytes,
created_at_ns,
updated_at_ns
) VALUES (?, ?, ?, ?, ?)
ON CONFLICT(media_type, name) DO UPDATE SET
size_bytes = excluded.size_bytes,
created_at_ns = excluded.created_at_ns,
updated_at_ns = excluded.updated_at_ns
""",
(
media_type,
path.name,
int(stat.st_size),
now_ns,
now_ns,
),
)
def _lookup_created_at_ns(
self,
conn: sqlite3.Connection,
media_type: MediaType,
name: str,
*,
fallback: int,
) -> int:
row = conn.execute(
f"""
SELECT created_at_ns
FROM {_TABLE}
WHERE media_type = ? AND name = ?
""",
(media_type, name),
).fetchone()
return int(row["created_at_ns"]) if row else int(fallback)
def _usage_bytes(self, conn: sqlite3.Connection, media_type: MediaType) -> int:
row = conn.execute(
f"SELECT COALESCE(SUM(size_bytes), 0) AS total FROM {_TABLE} WHERE media_type = ?",
(media_type,),
).fetchone()
return int(row["total"]) if row else 0
def _newest_name(self, conn: sqlite3.Connection, media_type: MediaType) -> str:
row = conn.execute(
f"""
SELECT name
FROM {_TABLE}
WHERE media_type = ?
ORDER BY created_at_ns DESC, name DESC
LIMIT 1
""",
(media_type,),
).fetchone()
return str(row["name"]) if row else ""
def _enforce_limit_locked(
self,
conn: sqlite3.Connection,
media_type: MediaType,
*,
protected_names: set[str] | None = None,
) -> None:
max_bytes = self._limit_bytes(media_type)
if max_bytes <= 0:
return
usage_bytes = self._usage_bytes(conn, media_type)
if usage_bytes <= max_bytes:
return
protected = set(protected_names or ())
newest_name = self._newest_name(conn, media_type)
if newest_name:
protected.add(newest_name)
target_bytes = self._target_bytes(max_bytes)
rows = conn.execute(
f"""
SELECT name, size_bytes
FROM {_TABLE}
WHERE media_type = ?
ORDER BY created_at_ns ASC, name ASC
""",
(media_type,),
).fetchall()
removed = 0
for row in rows:
if usage_bytes <= target_bytes:
break
name = str(row["name"])
if name in protected:
continue
path = self._path_for_name(media_type, name)
size_bytes = int(row["size_bytes"])
if path.exists():
try:
path.unlink(missing_ok=True)
except OSError as exc:
logger.warning(
"local media cache delete failed: media_type={} name={} error={}",
media_type,
name,
exc,
)
continue
conn.execute(
f"DELETE FROM {_TABLE} WHERE media_type = ? AND name = ?",
(media_type, name),
)
usage_bytes = max(0, usage_bytes - size_bytes)
removed += 1
if removed:
logger.info(
"local media cache trimmed: media_type={} removed={} usage_bytes={} limit_bytes={} target_bytes={}",
media_type,
removed,
usage_bytes,
max_bytes,
target_bytes,
)
elif usage_bytes > max_bytes:
logger.info(
"local media cache kept newest oversized file: media_type={} usage_bytes={} limit_bytes={}",
media_type,
usage_bytes,
max_bytes,
)
def _delete_index_row_if_present(self, media_type: MediaType, name: str) -> None:
db_path = local_media_cache_db_path()
if not db_path.exists():
return
with closing(self._connect()) as conn:
conn.execute(
f"DELETE FROM {_TABLE} WHERE media_type = ? AND name = ?",
(media_type, name),
)
conn.commit()
def _delete_index_rows_if_present(self, media_type: MediaType) -> None:
db_path = local_media_cache_db_path()
if not db_path.exists():
return
with closing(self._connect()) as conn:
conn.execute(f"DELETE FROM {_TABLE} WHERE media_type = ?", (media_type,))
conn.commit()
local_media_cache = LocalMediaCacheStore()
def save_local_image(raw: bytes, mime: str, file_id: str) -> str:
"""Persist an image to local cache and return the file ID."""
return local_media_cache.save_image(raw, mime, file_id)
def save_local_video(raw: bytes, file_id: str) -> Path:
"""Persist a video to local cache and return the file path."""
return local_media_cache.save_video(raw, file_id)
def clear_local_media_files(media_type: MediaType) -> int:
"""Delete all local media files for the requested type."""
return local_media_cache.clear(media_type)
def delete_local_media_file(media_type: MediaType, name: str) -> bool:
"""Delete one local media file by file name."""
return local_media_cache.delete(media_type, name)
async def reconcile_local_media_cache_async(
media_type: MediaType | None = None,
) -> None:
"""Rebuild local media cache indexes and enforce configured limits."""
if media_type is None:
for item in ("image", "video"):
await asyncio.to_thread(local_media_cache.reconcile, item)
return
await asyncio.to_thread(local_media_cache.reconcile, media_type)
__all__ = [
"LocalMediaCacheStore",
"clear_local_media_files",
"delete_local_media_file",
"local_media_cache",
"reconcile_local_media_cache_async",
"save_local_image",
"save_local_video",
]