# cove-api/src/models/image_encoder.py
# mickey1976's picture
# Deploy: Minimal FastAPI backend for CoVE Space
# 549c270
# src/models/image_encoder.py
from __future__ import annotations
from typing import List, Optional, Tuple
from pathlib import Path
import io
import requests
from PIL import Image, UnidentifiedImageError
import torch
import numpy as np
from tqdm import tqdm
import open_clip # pip install open-clip-torch
class ImageEncoder:
    """
    Thin wrapper around OpenCLIP for image embeddings.

    - Model: ViT-B/32 (laion2b_s34b_b79k) by default (fast & lightweight)
    - GPU if available
    - Takes PIL Image or URL strings
    - Returns L2-normalized float32 numpy vectors
    """

    def __init__(
        self,
        model_name: str = "ViT-B-32",
        pretrained: str = "laion2b_s34b_b79k",
        device: Optional[str] = None,
    ):
        """Load the OpenCLIP model and its preprocessing transform.

        Args:
            model_name: OpenCLIP architecture name.
            pretrained: OpenCLIP pretrained-weights tag.
            device: Device to run on; when falsy, "cuda" is used if
                available, otherwise "cpu".
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Only the third returned transform is kept; it is applied to every
        # image before encoding.
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            model_name, pretrained=pretrained, device=self.device
        )
        self.model.eval()

    @property
    def dim(self) -> int:
        """Dimensionality of the image embeddings.

        Raises:
            AttributeError: If the vision tower exposes neither attribute.
        """
        visual = self.model.visual
        # OpenCLIP vision towers name this `output_dim`; the historical
        # `output_dims` spelling is kept as a fallback for compatibility.
        for attr in ("output_dim", "output_dims"):
            value = getattr(visual, attr, None)
            if value is not None:
                return int(value)
        raise AttributeError(
            f"{type(visual).__name__!s} exposes neither 'output_dim' nor 'output_dims'"
        )

    # -------------------- image IO helpers -------------------- #
    @staticmethod
    def _read_image_from_url(url: str, timeout: float = 10.0) -> Optional[Image.Image]:
        """Download *url* and decode it as an RGB image; return None on any failure."""
        try:
            # The context manager releases the connection even when
            # raise_for_status() or image decoding fails.
            with requests.get(url, timeout=timeout, stream=True) as resp:
                resp.raise_for_status()
                return Image.open(io.BytesIO(resp.content)).convert("RGB")
        except Exception:
            # Deliberately broad: fetching is best-effort and every failure
            # (network, HTTP status, undecodable bytes) is reported as None.
            return None

    @staticmethod
    def _read_image_from_path(path: Path) -> Optional[Image.Image]:
        """Open *path* and decode it as an RGB image; return None if unreadable."""
        try:
            # convert() returns an independent image, so the file handle can be
            # closed as soon as the context exits.
            with Image.open(path) as src:
                return src.convert("RGB")
        except (FileNotFoundError, UnidentifiedImageError, OSError):
            return None

    @staticmethod
    def _cache_path(cache_dir: Path, url: str) -> Path:
        """Return the cache file location for *url* inside *cache_dir*.

        The file name is the SHA-256 hex digest of the URL, so arbitrarily long
        URLs or URLs with path-hostile characters map to a safe fixed-length name.
        """
        import hashlib  # local import: only needed when caching is enabled

        digest = hashlib.sha256(url.encode("utf-8")).hexdigest()
        return cache_dir / f"{digest}.png"

    def _load_url(self, url: str, cache_dir: Optional[Path], timeout: float) -> Optional[Image.Image]:
        """Fetch one image, consulting and populating the disk cache if enabled."""
        cached = self._cache_path(cache_dir, url) if cache_dir else None
        if cached is not None and cached.is_file():
            img = self._read_image_from_path(cached)
            if img is not None:
                return img
            # Unreadable cache entry (e.g. partial write): fall through and refetch.
        img = self._read_image_from_url(url, timeout=timeout)
        if img is not None and cached is not None:
            try:
                img.save(cached, format="PNG")
            except OSError:
                # Caching is best-effort; a failed write must not fail encoding.
                pass
        return img

    # -------------------- main encode API -------------------- #
    def encode_pils(self, images: List[Image.Image], batch_size: int = 64, desc: str = "Encoding") -> np.ndarray:
        """Encode a list of PIL images.

        Args:
            images: Images to embed.
            batch_size: Number of images per forward pass (must be >= 1).
            desc: Label for the progress bar.

        Returns:
            float32 array of shape [len(images), dim] with L2-normalized rows.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if len(images) == 0:
            return np.zeros((0, self.dim), dtype=np.float32)
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # Compare on the device *type* so "cuda:0" and torch.device objects
        # also enable mixed precision.
        use_amp = torch.device(self.device).type == "cuda"
        vecs = []
        for start in tqdm(range(0, len(images), batch_size), desc=desc):
            chunk = images[start : start + batch_size]
            tensors = torch.stack([self.preprocess(img) for img in chunk]).to(self.device)
            with torch.no_grad():
                with torch.autocast(device_type="cuda", enabled=use_amp):
                    feats = self.model.encode_image(tensors)
                # Normalize in float32: under autocast the features are half
                # precision, which makes the unit-norm result noticeably inexact.
                feats = feats.float()
                feats = feats / feats.norm(dim=-1, keepdim=True).clamp_min(1e-12)
            vecs.append(feats.detach().cpu().numpy().astype(np.float32))
        return np.vstack(vecs)

    def encode_urls(
        self,
        urls: List[str],
        batch_size: int = 64,
        cache_dir: Optional[Path] = None,
        desc: str = "Encoding",
        timeout: float = 10.0,
    ) -> Tuple[np.ndarray, List[bool]]:
        """
        Encode images from URLs. Optionally cache to disk.

        Args:
            urls: Image URLs; falsy entries are treated as failures.
            batch_size: Number of images per forward pass.
            cache_dir: If given, downloaded images are stored there (one PNG
                per URL, named by the URL's SHA-256) and reused on later calls.
            desc: Label for the encoding progress bar.
            timeout: Per-request timeout in seconds.

        Returns:
            - ndarray [N, D]; rows of failed entries are all zeros
            - list[bool] whether entry was successfully encoded
        """
        cache_dir = Path(cache_dir) if cache_dir else None
        if cache_dir:
            cache_dir.mkdir(parents=True, exist_ok=True)

        fetched: List[Image.Image] = []
        ok: List[bool] = []
        for url in tqdm(urls, desc="Fetching"):
            img = self._load_url(url, cache_dir, timeout) if url else None
            ok.append(img is not None)
            if img is not None:
                fetched.append(img)

        # Encode only the images that loaded and scatter them into a zero
        # matrix, so failed rows stay zero without encoding placeholder images.
        if fetched:
            good = self.encode_pils(fetched, batch_size=batch_size, desc=desc)
            vecs = np.zeros((len(ok), good.shape[1]), dtype=np.float32)
            vecs[np.array(ok, dtype=bool)] = good
        else:
            vecs = np.zeros((len(ok), self.dim), dtype=np.float32)
        return vecs, ok
# Public API of this module: only ImageEncoder is exported by a star-import.
__all__ = ["ImageEncoder"]