Spaces:
Sleeping
Sleeping
File size: 4,265 Bytes
549c270 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
# src/models/image_encoder.py
from __future__ import annotations

import hashlib
import io
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import open_clip  # pip install open-clip-torch
import requests
import torch
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm
class ImageEncoder:
    """
    Thin wrapper around OpenCLIP for image embeddings.

    - Model: ViT-B/32 (laion2b_s34b_b79k) by default (fast & lightweight)
    - GPU if available
    - Takes PIL Image or URL strings
    - Returns L2-normalized float32 numpy vectors
    """

    def __init__(
        self,
        model_name: str = "ViT-B-32",
        pretrained: str = "laion2b_s34b_b79k",
        device: Optional[str] = None,
    ):
        """
        Load the OpenCLIP model and its preprocessing transform.

        Args:
            model_name: OpenCLIP architecture name.
            pretrained: OpenCLIP pretrained-weights tag.
            device: Torch device; defaults to "cuda" when available, else "cpu".
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # The second returned value is discarded; only the model and the last
        # transform are kept.
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            model_name, pretrained=pretrained, device=self.device
        )
        self.model.eval()

    @property
    def dim(self) -> int:
        """
        Embedding dimensionality reported by the model's vision tower.

        Raises:
            AttributeError: if the vision tower exposes no known size attribute.
        """
        visual = self.model.visual
        # OpenCLIP vision towers publish the size as `output_dim`; the
        # `output_dims` spelling this property read before is kept as a fallback.
        for attr in ("output_dim", "output_dims"):
            size = getattr(visual, attr, None)
            if size is not None:
                return int(size)
        raise AttributeError(
            f"{type(visual).__name__} exposes neither 'output_dim' nor 'output_dims'"
        )

    # -------------------- image IO helpers -------------------- #
    @staticmethod
    def _read_image_from_url(url: str, timeout: float = 10.0) -> Optional["Image.Image"]:
        """Download *url* and return it as an RGB image, or None on any failure."""
        try:
            # The context manager closes the streamed response even when
            # raise_for_status() raises before the body has been read.
            with requests.get(url, timeout=timeout, stream=True) as r:
                r.raise_for_status()
                payload = r.content
            with Image.open(io.BytesIO(payload)) as img:
                return img.convert("RGB")
        except Exception:
            # Deliberately best-effort: network and decoding errors of any
            # kind are reported to the caller as "no image".
            return None

    @staticmethod
    def _read_image_from_path(path: Path) -> Optional["Image.Image"]:
        """Open the file at *path* as an RGB image, or return None if unreadable."""
        try:
            with Image.open(path) as img:
                return img.convert("RGB")
        except (FileNotFoundError, UnidentifiedImageError, OSError):
            return None

    @staticmethod
    def _cache_file(cache_dir: Path, url: str) -> Path:
        """Return the cache location for *url*: SHA-256 of the URL, stored as PNG."""
        digest = hashlib.sha256(url.encode("utf-8")).hexdigest()
        return cache_dir / f"{digest}.png"

    def _load_url(
        self, url: str, cache_dir: Optional[Path], timeout: float
    ) -> Optional["Image.Image"]:
        """Return the image for *url*, reading/writing the disk cache when enabled."""
        cache_file = self._cache_file(cache_dir, url) if cache_dir else None

        if cache_file is not None and cache_file.is_file():
            cached = self._read_image_from_path(cache_file)
            if cached is not None:
                return cached
            # Unreadable cache entry: fall through and download again.

        img = self._read_image_from_url(url, timeout=timeout)
        if img is not None and cache_file is not None:
            try:
                # Write to a temp name first so an interrupted save never
                # leaves a truncated file under the final cache name.
                tmp = cache_file.with_suffix(".tmp")
                img.save(tmp, format="PNG")
                tmp.replace(cache_file)
            except OSError:
                # Caching is an optimization only; the fetched image is still used.
                pass
        return img

    # -------------------- main encode API -------------------- #
    def encode_pils(
        self, images: List["Image.Image"], batch_size: int = 64, desc: str = "Encoding"
    ) -> np.ndarray:
        """
        Encode a list of PIL images.

        Args:
            images: Images to embed.
            batch_size: Number of images per forward pass; must be positive.
            desc: Label for the progress bar.

        Returns:
            float32 array of shape [len(images), D] with L2-normalized rows.

        Raises:
            ValueError: if *images* is non-empty and batch_size is not positive.
        """
        if len(images) == 0:
            return np.zeros((0, self.dim), dtype=np.float32)
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        # Compare the device *type* so "cuda:0" or a torch.device also enables autocast.
        use_amp = torch.device(self.device).type == "cuda"

        vecs = []
        for start in tqdm(range(0, len(images), batch_size), desc=desc):
            chunk = images[start : start + batch_size]
            tensors = torch.stack([self.preprocess(img) for img in chunk]).to(self.device)
            with torch.no_grad():
                with torch.cuda.amp.autocast(enabled=use_amp):
                    feats = self.model.encode_image(tensors)
                # Normalize in float32, outside autocast, so the output dtype
                # does not depend on mixed-precision promotion rules.
                feats = feats.float()
                feats = feats / feats.norm(dim=-1, keepdim=True)
            vecs.append(feats.cpu().numpy())
        return np.vstack(vecs)

    def encode_urls(
        self,
        urls: List[str],
        batch_size: int = 64,
        cache_dir: Optional[Path] = None,
        desc: str = "Encoding",
        timeout: float = 10.0,
    ) -> Tuple[np.ndarray, List[bool]]:
        """
        Encode images from URLs. Optionally cache to disk.

        Args:
            urls: Image URLs; empty/None entries are treated as failures.
            batch_size: Number of images per forward pass.
            cache_dir: If given, downloaded images are stored here (one PNG per
                URL, named by the URL's SHA-256) and reused on later calls.
            desc: Label for the encoding progress bar.
            timeout: Per-request timeout in seconds.

        Returns:
            - ndarray [N, D]; rows for failed entries are all zeros
            - list[bool] whether entry was successfully encoded
        """
        cache_dir = Path(cache_dir) if cache_dir else None
        if cache_dir:
            cache_dir.mkdir(parents=True, exist_ok=True)

        pils: List[Optional["Image.Image"]] = []
        for url in tqdm(urls, desc="Fetching"):
            pils.append(self._load_url(url, cache_dir, timeout) if url else None)
        ok: List[bool] = [im is not None for im in pils]

        # Encode only the images that loaded; failed rows stay zero so callers
        # can mask them out with `ok`.
        good = [i for i, loaded in enumerate(ok) if loaded]
        if good:
            encoded = self.encode_pils(
                [pils[i] for i in good], batch_size=batch_size, desc=desc
            )
            vecs = np.zeros((len(pils), encoded.shape[1]), dtype=np.float32)
            vecs[good] = encoded
        else:
            vecs = np.zeros((len(pils), self.dim), dtype=np.float32)
        return vecs, ok
# Public API of this module.
__all__ = ["ImageEncoder"]