# Authentica / image_embedder.py
# Provenance (Hugging Face Hub listing): MAS-AI-0000 — "Update image_embedder.py",
# commit 247a9cd (verified). Kept as a comment so the file parses as Python.
"""Image β†’ detection-ready embedding.
Loads CLIP (ViT-B/32) and the trained ``CLIPProjector`` and exposes
``get_image_embedding``, which encodes a PIL image and projects it into the
DETree embedding space β€” ready to be passed to ``detect_embedding``.
Usage::
from PIL import Image
from Apps.image_embedder import get_image_embedding
from Apps.detector import detect_embedding
pil_img = Image.open("photo.jpg")
emb = get_image_embedding(pil_img)
result = detect_embedding(emb, mode="image")
# {"predicted_class": "Real"|"AI", "confidence": 0.91}
"""
from __future__ import annotations
import os
import sys
from typing import Optional
import logging
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from huggingface_hub import hf_hub_download
# Module logger; every message below is tagged [image_embedder].
log = logging.getLogger("image_embedder")
# NOTE(review): basicConfig at import time mutates process-wide logging config.
# Acceptable for an app entry module, but a reusable library would normally
# leave handler/level configuration to the caller.
logging.basicConfig(level=logging.INFO, format="%(levelname)s [%(name)s] %(message)s")
# ---------------------------------------------------------------------------
# Make the local 'detree' package importable
# ---------------------------------------------------------------------------
# The 'detree' package sits next to this file, so the file's own directory is
# appended to sys.path (append, not insert, so installed packages win).
_current_dir = os.path.dirname(os.path.abspath(__file__))
if _current_dir not in sys.path:
    sys.path.append(_current_dir)

# Optional dependencies: failures are logged, not raised, so importing this
# module never crashes — the public API then degrades to returning zero
# arrays (see get_image_embedding), with the sentinel value None signalling
# "unavailable" to _init().
try:
    import clip as _clip_lib
    log.info("clip package imported successfully.")
except ImportError:
    log.error("'clip' package not found — image embedding will return zeros.")
    _clip_lib = None

try:
    from detree.model.clip_projector import CLIPProjector
    log.info("CLIPProjector imported successfully.")
except ImportError as _e:
    log.error(f"Could not import CLIPProjector: {_e} — image embedding will return zeros.")
    CLIPProjector = None
# ---------------------------------------------------------------------------
# Hugging Face assets
# ---------------------------------------------------------------------------
_BASE_DIR = "MAS-AI-0000/Authentica"
try:
    # hf_hub_download returns the local cached *file* path of the weights.
    _PROJECTOR_DIR = hf_hub_download(
        repo_id=_BASE_DIR,
        filename="Lib/Models/Image/clip_projector.pt",
    )
except Exception as _dl_exc:
    # Degrade gracefully, consistent with the optional imports above: a
    # failed download must not crash module import. An empty path fails the
    # os.path.exists() check in _init(), which disables embedding and makes
    # the public API return zeros.
    log.error(f"Could not download projector weights: {_dl_exc}")
    _PROJECTOR_DIR = ""

log.info(f"[paths] _BASE_DIR = {_BASE_DIR!r}")
log.info(f"[paths] _PROJECTOR_DIR = {_PROJECTOR_DIR!r} exists={os.path.exists(_PROJECTOR_DIR)}")
if os.path.isdir(_PROJECTOR_DIR):
    log.info(f"[paths] _PROJECTOR_DIR contents: {os.listdir(_PROJECTOR_DIR)}")
elif os.path.isfile(_PROJECTOR_DIR):
    log.info(f"[paths] _PROJECTOR_DIR is a file (hf_hub_download path), not a directory.")
else:
    log.warning(f"[paths] _PROJECTOR_DIR does not exist.")
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
CLIP_MODEL = "ViT-B/32"  # CLIP backbone variant passed to clip.load()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
REPO_ID = "MAS-AI-0000/Authentica"
CLIP_PROJECTOR_FILENAME = "Lib/Models/Image/clip_projector.pt"

# ==== Load assets ====
# The projector weights were already fetched above from the same repo/file;
# reuse that path instead of issuing a second hf_hub_download call at import
# time. The name is kept for backward compatibility with any external code
# importing it.
clip_projector_path = _PROJECTOR_DIR

log.info(f"[config] device={DEVICE!r} clip_model={CLIP_MODEL!r}")
# ---------------------------------------------------------------------------
# Module-level initialisation
# ---------------------------------------------------------------------------
_clip_model: Optional[object] = None   # CLIP backbone (populated by _init)
_clip_prep: Optional[object] = None    # CLIP preprocessing transform
_projector: Optional[object] = None    # trained CLIPProjector


def _init() -> None:
    """Populate the module globals with CLIP and the trained projector.

    Runs exactly once at import time. Any failure leaves the affected
    globals at ``None``, in which case the public API returns zero
    arrays instead of raising.
    """
    global _clip_model, _clip_prep, _projector

    log.info("_init: starting ImageEmbedder initialisation.")
    if _clip_lib is None or CLIPProjector is None:
        log.error("_init: required packages unavailable — embedding disabled.")
        return

    # --- CLIP backbone ----------------------------------------------------
    log.info(f"_init: loading CLIP model {CLIP_MODEL!r} on device={DEVICE!r} ...")
    try:
        _clip_model, _clip_prep = _clip_lib.load(CLIP_MODEL, jit=False, device=DEVICE)
        _clip_model.eval()
        for p in _clip_model.parameters():   # inference-only: freeze weights
            p.requires_grad = False
        log.info(f"_init: CLIP ({CLIP_MODEL}) loaded OK on {DEVICE!r}")
    except Exception as err:
        log.exception(f"_init: error loading CLIP: {err}")
        return

    # --- CLIPProjector ----------------------------------------------------
    # _PROJECTOR_DIR may be a directory (local / Dockerfile copy) or a single
    # file (hf_hub_download result); from_pretrained expects a directory.
    if not os.path.exists(_PROJECTOR_DIR):
        log.error(f"_init: projector path not found at {_PROJECTOR_DIR!r} — embedding disabled.")
        return

    if os.path.isdir(_PROJECTOR_DIR):
        weights_dir = _PROJECTOR_DIR
    else:
        weights_dir = os.path.dirname(_PROJECTOR_DIR)

    log.info(f"_init: loading CLIPProjector from {weights_dir!r} ...")
    try:
        _projector = CLIPProjector.from_pretrained(weights_dir, device=DEVICE).to(DEVICE)
        _projector.eval()
        log.info(f"_init: CLIPProjector loaded OK. "
                 f"clip_dim={_projector.clip_dim} target_dim={_projector.target_dim}")
    except Exception as err:
        log.exception(f"_init: error loading CLIPProjector: {err}")


_init()
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
@torch.no_grad()
def get_image_embedding(image: Image.Image) -> np.ndarray:
    """Return a (1, embedding_dim) float32 numpy array for the given PIL image.

    The embedding is CLIP-encoded, cast to float32, L2-normalised, and
    projected through the trained ``CLIPProjector`` so it lives in the same
    space as the DETree database. Pass the result directly to
    ``detect_embedding(emb, mode="image")``.

    Args:
        image: A ``PIL.Image.Image`` object (any mode; converted to RGB
            internally).

    Returns:
        ``np.ndarray`` of shape ``(1, embedding_dim)`` and dtype float32.
        On failure (models unavailable, or an inference error) a ``(1, 1)``
        zero array is returned instead — callers should treat an all-zero
        result as "no embedding available".
    """
    if _clip_model is None or _projector is None:
        log.error("get_image_embedding: clip_model or projector is None — returning zeros. Check _init logs.")
        return np.zeros((1, 1), dtype=np.float32)

    log.info(f"get_image_embedding: input image size={image.size} mode={image.mode!r}")
    try:
        image = image.convert("RGB")
        image_tensor = _clip_prep(image).unsqueeze(0).to(DEVICE)
        log.info(f"get_image_embedding: preprocessed tensor shape={tuple(image_tensor.shape)}")

        # CLIP encode → float32 (encode_image may emit fp16 on CUDA) → L2-normalise.
        # F.normalize preserves dtype, so no second .float() cast is needed.
        clip_emb = _clip_model.encode_image(image_tensor).float()
        log.info(f"get_image_embedding: raw CLIP embedding shape={tuple(clip_emb.shape)} "
                 f"norm={clip_emb.norm(dim=-1).item():.4f}")
        clip_emb = F.normalize(clip_emb, dim=-1)

        # Project into the DETree embedding space (projector normalises output)
        projected = _projector(clip_emb, normalize=True)
        log.info(f"get_image_embedding: projected shape={tuple(projected.shape)} "
                 f"norm={projected.norm(dim=-1).item():.4f}")
    except Exception as exc:
        log.exception(f"get_image_embedding: failed during inference: {exc}")
        return np.zeros((1, 1), dtype=np.float32)

    return projected.cpu().numpy().astype(np.float32)
@torch.no_grad()
def get_image_embeddings_batch(images: list[Image.Image]) -> np.ndarray:
    """Return an (N, embedding_dim) float32 array for a list of PIL images.

    Same pipeline as :func:`get_image_embedding`, applied batched:
    CLIP-encode → float32 → L2-normalise → project.

    Args:
        images: List of ``PIL.Image.Image`` objects (any mode; converted to
            RGB internally). May be empty.

    Returns:
        ``np.ndarray`` of shape ``(N, embedding_dim)`` and dtype float32.
        When the models are unavailable an ``(N, 1)`` zero array is returned.
    """
    if _clip_model is None or _projector is None:
        return np.zeros((len(images), 1), dtype=np.float32)
    if not images:
        # torch.stack([]) raises RuntimeError; return an empty array with the
        # projector's output width instead (target_dim — presumably the
        # projected embedding_dim; matches _init's logging of the projector).
        return np.zeros((0, _projector.target_dim), dtype=np.float32)

    tensors = torch.stack(
        [_clip_prep(img.convert("RGB")) for img in images]
    ).to(DEVICE)
    clip_embs = _clip_model.encode_image(tensors).float()
    clip_embs = F.normalize(clip_embs, dim=-1)
    projected = _projector(clip_embs, normalize=True)
    return projected.cpu().numpy().astype(np.float32)