modelforge-backend / backend /services /inference_cache.py
ModelForge CI
deploy: 2026-06-19 19:24 UTC
6761f70
Raw
History Blame Contribute Delete
5.48 kB
"""
LRU in-memory model cache for the /infer endpoint.
Holds up to MAX_CACHED loaded models; evicts the least-recently-used on overflow.
All model loads and forward passes run in a thread pool via asyncio.to_thread().
asyncio.Lock() ensures only one coroutine loads/evicts at a time — prevents
double-loading and cache corruption under concurrent requests.
"""
from __future__ import annotations
import asyncio
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Maximum number of models kept loaded at once; a load beyond this evicts the
# least-recently-used entry first.
MAX_CACHED: int = 3

# Hard cap, in seconds, on a single inference call (enforced with asyncio.wait_for).
_INFER_TIMEOUT: int = 30
@dataclass
class _CachedModel:
    """A single cache slot: a loaded model, its tokenizer and the label names to report.

    `last_used` holds a time.monotonic() timestamp. It is stamped at construction,
    and the cache refreshes it on each lookup; the entry with the smallest value
    is the one evicted when the cache is full.
    """

    # Training-run identifier; also the key this entry is stored under.
    run_id: str
    # Loaded model object (a transformers sequence-classification model when
    # produced by ModelCache._blocking_load).
    model: Any
    # Tokenizer loaded from the same artifact directory as the model.
    tokenizer: Any
    # Class names to report; presumably index-aligned with the model's logits — confirm.
    label_names: list[str]
    # Recency timestamp used for LRU eviction.
    last_used: float = field(default_factory=time.monotonic)
class ModelCache:
    """LRU cache of loaded sequence-classification models, keyed by run id.

    Cache hits are served without taking the lock. Misses are serialized through
    an asyncio.Lock, so each run is loaded at most once and loads/evictions never
    interleave. Model loading and forward passes run in worker threads via
    asyncio.to_thread(), keeping the event loop responsive.
    """

    def __init__(self) -> None:
        # run_id -> loaded entry; holds at most MAX_CACHED entries after any load.
        self._cache: dict[str, _CachedModel] = {}
        # Serializes loads and evictions.
        # NOTE(review): on Python < 3.10, an asyncio.Lock created before the event
        # loop starts binds to whatever loop was current at creation — confirm the
        # runtime is 3.10+ if this object is constructed at import time.
        self._lock = asyncio.Lock()

    async def predict(
        self,
        *,
        run_id: str,
        text: str,
        artifact_path: str,
        label_names: list[str],
    ) -> dict[str, Any]:
        """
        Run classification on `text` with the model for `run_id`.

        On first use the model is loaded from `artifact_path`; later calls reuse
        the cached entry.

        Returns:
            A dict with "predicted_label", "confidence" and "all_scores"
            (see `_infer`).

        Raises:
            FileNotFoundError: `artifact_path` does not exist (cache miss only).
            RuntimeError: torch/transformers are not installed (cache miss only).
            asyncio.TimeoutError: inference exceeded _INFER_TIMEOUT seconds. The
                worker thread cannot be cancelled and keeps running to completion
                in the background.
        """
        # Fast path: a plain dict lookup contains no await, so it cannot interleave
        # with a load/eviction in another coroutine. Cache hits therefore never wait
        # behind an unrelated model load that is holding the lock.
        entry = self._cache.get(run_id)
        if entry is None:
            async with self._lock:
                # Re-check: another coroutine may have loaded this run while we waited.
                entry = self._cache.get(run_id)
                if entry is None:
                    entry = await self._load(
                        run_id=run_id,
                        artifact_path=artifact_path,
                        label_names=label_names,
                    )
        entry.last_used = time.monotonic()
        # _infer is CPU/GPU bound — it runs outside the lock so other requests aren't blocked.
        try:
            return await asyncio.wait_for(
                asyncio.to_thread(self._infer, entry, text),
                timeout=_INFER_TIMEOUT,
            )
        except asyncio.TimeoutError as exc:
            raise asyncio.TimeoutError(
                f"Inference timed out after {_INFER_TIMEOUT}s. "
                "Try shorter text or a lighter model."
            ) from exc

    async def _load(self, *, run_id: str, artifact_path: str, label_names: list[str]) -> _CachedModel:
        """
        Load model + tokenizer for `run_id` from disk and insert them into the cache.

        The artifact path is validated BEFORE anything is evicted, so a missing
        path never costs a cached model. Eviction still happens before the
        (blocking) load, which keeps at most MAX_CACHED models resident. If the
        load itself fails after eviction, the evicted entry is simply reloaded
        on its next use. Caller must hold _lock.

        Raises:
            FileNotFoundError: `artifact_path` does not exist.
            RuntimeError: torch/transformers are not installed.
        """
        model_dir = Path(artifact_path)
        if not model_dir.exists():
            raise FileNotFoundError(
                f"Model files not found at '{artifact_path}'. "
                "The server may have restarted since training completed."
            )
        if len(self._cache) >= MAX_CACHED:
            lru_id = min(self._cache, key=lambda k: self._cache[k].last_used)
            logger.info("ModelCache: evicting run %s from cache", lru_id)
            del self._cache[lru_id]
        logger.info("ModelCache: loading model from %s", model_dir)
        model, tokenizer = await asyncio.to_thread(
            self._blocking_load, str(model_dir), label_names
        )
        entry = _CachedModel(
            run_id=run_id,
            model=model,
            tokenizer=tokenizer,
            label_names=label_names,
        )
        self._cache[run_id] = entry
        logger.info("ModelCache: loaded %d model(s) in cache", len(self._cache))
        return entry

    @staticmethod
    def _is_placeholder_id2label(id2label: dict[Any, Any] | None) -> bool:
        """Return True if `id2label` is empty or holds only auto-generated "LABEL_<i>" names."""
        if not id2label:
            return True
        return all(str(name) == f"LABEL_{idx}" for idx, name in id2label.items())

    @staticmethod
    def _blocking_load(model_dir: str, label_names: list[str]) -> tuple:
        """
        Load (model, tokenizer) from `model_dir` synchronously; run via asyncio.to_thread.

        If the model config only carries placeholder label names and `label_names`
        supplies one name per existing entry (or the config has none), the config's
        id2label/label2id are replaced with the caller's names.

        Raises:
            RuntimeError: torch/transformers are not installed.
        """
        try:
            import torch  # noqa: F401  -- availability check; the model needs torch
            from transformers import AutoTokenizer, AutoModelForSequenceClassification
        except ImportError as exc:
            raise RuntimeError(
                "Inference libraries not installed. pip install torch transformers"
            ) from exc
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        model = AutoModelForSequenceClassification.from_pretrained(model_dir)
        model.eval()
        id2label = model.config.id2label
        # Only replace auto-generated names, and only with a list of matching length,
        # so the mapping keeps the same number of classes as the model's head.
        if (
            label_names
            and ModelCache._is_placeholder_id2label(id2label)
            and (not id2label or len(id2label) == len(label_names))
        ):
            model.config.id2label = {i: lbl for i, lbl in enumerate(label_names)}
            model.config.label2id = {lbl: i for i, lbl in enumerate(label_names)}
        return model, tokenizer

    @staticmethod
    def _infer(entry: _CachedModel, text: str) -> dict[str, Any]:
        """
        Classify a single `text` with a loaded cache entry (blocking; run in a thread).

        Returns:
            {"predicted_label": str, "confidence": float, "all_scores": list}, where
            all_scores holds {"label", "score", "pct"} dicts sorted by descending score.
        """
        import torch
        import torch.nn.functional as F

        # NOTE(review): HF "fast" tokenizers have been reported to raise "Already
        # borrowed" when one instance is called from several threads at once with
        # truncation/padding. Concurrent requests for the same run share this
        # tokenizer — confirm whether a per-entry threading.Lock is needed.
        inputs = entry.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        )
        with torch.no_grad():
            logits = entry.model(**inputs).logits
        probs = F.softmax(logits, dim=-1).squeeze().tolist()
        # A single-class output squeezes to a 0-d tensor, which .tolist() turns into a float.
        if isinstance(probs, float):
            probs = [probs]
        given = entry.label_names or []
        id2label = entry.model.config.id2label
        # One name per class index: caller-supplied name first, then the model
        # config, then a generic fallback. A too-short label list therefore never
        # drops classes from scoring (zip() would have silently truncated them).
        names = [
            given[i] if i < len(given) else id2label.get(i, f"class_{i}")
            for i in range(len(probs))
        ]
        scored = sorted(
            (
                {"label": lbl, "score": round(float(p), 6), "pct": round(float(p) * 100, 1)}
                for lbl, p in zip(names, probs)
            ),
            key=lambda x: x["score"],
            reverse=True,
        )
        return {
            "predicted_label": scored[0]["label"],
            "confidence": scored[0]["score"],
            "all_scores": scored,
        }

    def evict(self, run_id: str) -> None:
        """
        Drop `run_id` from the cache if present; no-op otherwise.

        Synchronous and lock-free: it contains no await, so it cannot interleave
        with other coroutines' cache operations. An inference already running on
        the entry keeps its own reference and finishes normally. A load for
        `run_id` that is in progress will still insert the entry when it completes.
        """
        self._cache.pop(run_id, None)
# Single shared ModelCache instance for this process; imported and used by main.py.
cache: ModelCache = ModelCache()