File size: 7,801 Bytes
586aed7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fbc687
586aed7
 
 
 
 
 
4fbc687
 
 
 
586aed7
 
 
 
 
 
 
 
 
4fbc687
586aed7
4fbc687
586aed7
 
4fbc687
586aed7
 
4fbc687
586aed7
4fbc687
586aed7
 
 
 
e7dfd0c
 
 
 
586aed7
a9fdc98
4fbc687
 
a9fdc98
4fbc687
a9fdc98
 
 
 
4fbc687
 
586aed7
 
 
 
 
 
 
 
 
 
 
 
4fbc687
 
586aed7
 
 
 
 
 
 
 
 
 
 
 
4fbc687
 
586aed7
4fbc687
586aed7
 
 
4fbc687
586aed7
 
 
 
 
4fbc687
586aed7
4fbc687
586aed7
 
 
a9fdc98
 
 
586aed7
a9fdc98
586aed7
 
a9fdc98
 
586aed7
 
a9fdc98
586aed7
 
4fbc687
 
586aed7
4fbc687
586aed7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fbc687
586aed7
 
4fbc687
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
586aed7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
"""Image → detection-ready embedding.

Loads CLIP (ViT-B/32) and the trained ``CLIPProjector`` and exposes
``get_image_embedding`` (plus ``get_image_embeddings_batch`` for lists of
images), which encodes a PIL image and projects it into the DETree embedding
space — ready to be passed to ``detect_embedding``.

Usage::

    from PIL import Image
    from Apps.image_embedder import get_image_embedding
    from Apps.detector import detect_embedding

    pil_img = Image.open("photo.jpg")
    emb     = get_image_embedding(pil_img)
    result  = detect_embedding(emb, mode="image")
    # {"predicted_class": "Real"|"AI", "confidence": 0.91}
"""

from __future__ import annotations

import os
import sys
from typing import Optional
import logging
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from huggingface_hub import hf_hub_download


# Configure root logging at INFO with a "LEVEL [logger] message" layout so the
# diagnostics below are emitted (basicConfig is a no-op if the root logger
# already has handlers).
logging.basicConfig(
    format="%(levelname)s [%(name)s] %(message)s",
    level=logging.INFO,
)
log = logging.getLogger("image_embedder")

# ---------------------------------------------------------------------------
# Make the local 'detree' package importable
# ---------------------------------------------------------------------------
# This file's directory is appended (once) to sys.path so that a sibling
# ``detree`` package can be imported below.
_current_dir = os.path.abspath(__file__)
_current_dir = os.path.dirname(_current_dir)
if _current_dir not in sys.path:
    sys.path.append(_current_dir)

# Optional dependencies: a missing package is logged and its name bound to
# None; ``_init`` checks for None and leaves embedding disabled in that case.
try:
    import clip as _clip_lib
except ImportError:
    log.error("'clip' package not found β€” image embedding will return zeros.")
    _clip_lib = None
else:
    log.info("clip package imported successfully.")


try:
    from detree.model.clip_projector import CLIPProjector
except ImportError as _e:
    log.error(f"Could not import CLIPProjector: {_e} β€” image embedding will return zeros.")
    CLIPProjector = None
else:
    log.info("CLIPProjector imported successfully.")

# Hugging Face Hub: repository hosting the trained projector checkpoint.
_BASE_DIR = "MAS-AI-0000/Authentica"

# Despite its name, ``_PROJECTOR_DIR`` normally holds the local *file* path
# returned by ``hf_hub_download``; ``_init`` accepts either a file or a
# directory. A failed download (offline, missing file, auth, ...) used to
# abort the whole import; instead it is logged and the path set to "" so that
# ``_init`` takes its "projector path not found -> embedding disabled" branch,
# matching how missing packages are handled above.
try:
    _PROJECTOR_DIR = hf_hub_download(
        repo_id=_BASE_DIR,
        filename="Lib/Models/Image/clip_projector.pt",
    )
except Exception as _exc:  # top-level boundary: log and degrade, don't crash import
    log.exception(f"[paths] could not download projector from {_BASE_DIR!r}: {_exc}")
    _PROJECTOR_DIR = ""


# Diagnostics: report what the projector path resolved to.
log.info(f"[paths] _BASE_DIR      = {_BASE_DIR!r}")
log.info(f"[paths] _PROJECTOR_DIR = {_PROJECTOR_DIR!r}  exists={os.path.exists(_PROJECTOR_DIR)}")
if os.path.isdir(_PROJECTOR_DIR):
    log.info(f"[paths] _PROJECTOR_DIR contents: {os.listdir(_PROJECTOR_DIR)}")
elif os.path.isfile(_PROJECTOR_DIR):
    log.info("[paths] _PROJECTOR_DIR is a file (hf_hub_download path), not a directory.")
else:
    log.warning("[paths] _PROJECTOR_DIR does not exist.")


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
CLIP_MODEL = "ViT-B/32"
DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"

REPO_ID = "MAS-AI-0000/Authentica"
CLIP_PROJECTOR_FILENAME = "Lib/Models/Image/clip_projector.pt"

# ==== Load assets ====
# REPO_ID / CLIP_PROJECTOR_FILENAME name exactly the repository and file that
# were already fetched into ``_PROJECTOR_DIR`` above, so reuse that path rather
# than calling ``hf_hub_download`` a second time (a redundant Hub request and
# a second place the import can fail).
# NOTE: if REPO_ID or CLIP_PROJECTOR_FILENAME ever diverge from the values used
# for ``_PROJECTOR_DIR``, this must become its own download again.
clip_projector_path = _PROJECTOR_DIR

log.info(f"[config] device={DEVICE!r}  clip_model={CLIP_MODEL!r}")

# ---------------------------------------------------------------------------
# Module-level initialisation
# ---------------------------------------------------------------------------

# Populated once by ``_init()`` at import time; each stays None when its
# package is missing or its loading step fails.
_clip_model:  Optional[object] = None  # CLIP model from clip.load(...)
_clip_prep:   Optional[object] = None  # preprocessing callable from clip.load(...)
_projector:   Optional[object] = None  # trained CLIPProjector


def _init() -> None:
    """Load the CLIP backbone and the projector into the module globals.

    Each stage logs and returns early on failure, leaving any globals it did
    not reach as None:

    1. give up if ``clip`` or ``CLIPProjector`` could not be imported;
    2. load CLIP ``CLIP_MODEL`` on ``DEVICE``, switch it to eval mode and
       turn off ``requires_grad`` on all its parameters;
    3. load ``CLIPProjector`` from ``_PROJECTOR_DIR`` when it is a directory,
       otherwise from that file's parent directory, then set eval mode.
    """
    global _clip_model, _clip_prep, _projector

    log.info("_init: starting ImageEmbedder initialisation.")

    deps_available = _clip_lib is not None and CLIPProjector is not None
    if not deps_available:
        log.error("_init: required packages unavailable β€” embedding disabled.")
        return

    # --- Stage 2: frozen CLIP backbone ---
    log.info(f"_init: loading CLIP model {CLIP_MODEL!r} on device={DEVICE!r} ...")
    try:
        _clip_model, _clip_prep = _clip_lib.load(CLIP_MODEL, jit=False, device=DEVICE)
        _clip_model.eval()
        for weight in _clip_model.parameters():
            weight.requires_grad = False
        log.info(f"_init: CLIP ({CLIP_MODEL}) loaded OK on {DEVICE!r}")
    except Exception as err:
        log.exception(f"_init: error loading CLIP: {err}")
        return

    # --- Stage 3: projector ---
    # A directory (e.g. a local copy) is passed to from_pretrained unchanged;
    # a file path (e.g. an hf_hub_download result) is replaced by its parent.
    if not os.path.exists(_PROJECTOR_DIR):
        log.error(f"_init: projector path not found at {_PROJECTOR_DIR!r} β€” embedding disabled.")
        return

    if os.path.isdir(_PROJECTOR_DIR):
        source_dir = _PROJECTOR_DIR
    else:
        source_dir = os.path.dirname(_PROJECTOR_DIR)

    log.info(f"_init: loading CLIPProjector from {source_dir!r} ...")
    try:
        loaded = CLIPProjector.from_pretrained(source_dir, device=DEVICE)
        _projector = loaded.to(DEVICE)
        _projector.eval()
        log.info(f"_init: CLIPProjector loaded OK. "
                 f"clip_dim={_projector.clip_dim}  target_dim={_projector.target_dim}")
    except Exception as err:
        log.exception(f"_init: error loading CLIPProjector: {err}")


_init()


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

@torch.no_grad()
def get_image_embedding(image: Image.Image) -> np.ndarray:
    """Embed a single PIL image into the DETree embedding space.

    Pipeline: RGB conversion -> CLIP preprocessing -> ``encode_image`` ->
    L2 normalisation -> ``CLIPProjector`` (called with ``normalize=True``).
    The result is intended for ``detect_embedding(emb, mode="image")``.

    Args:
        image: Any ``PIL.Image.Image``; it is converted to RGB first.

    Returns:
        A float32 ``np.ndarray`` of shape ``(1, embedding_dim)``. When the
        CLIP model or the projector is not loaded, or inference raises, the
        problem is logged and a ``(1, 1)`` float32 zero array is returned.
    """
    models_ready = _clip_model is not None and _projector is not None
    if not models_ready:
        log.error("get_image_embedding: clip_model or projector is None β€” returning zeros. Check _init logs.")
        return np.zeros((1, 1), dtype=np.float32)

    log.info(f"get_image_embedding: input image size={image.size}  mode={image.mode!r}")
    try:
        rgb = image.convert("RGB")
        batch = _clip_prep(rgb).unsqueeze(0).to(DEVICE)
        log.info(f"get_image_embedding: preprocessed tensor shape={tuple(batch.shape)}")

        # Raw CLIP features (as float32), then scaled to unit length.
        raw = _clip_model.encode_image(batch).float()
        log.info(f"get_image_embedding: raw CLIP embedding shape={tuple(raw.shape)}  "
                 f"norm={raw.norm(dim=-1).item():.4f}")
        unit = F.normalize(raw, dim=-1)

        # Map into the DETree space; the projector normalises its output.
        projected = _projector(unit, normalize=True)
        log.info(f"get_image_embedding: projected shape={tuple(projected.shape)}  "
                 f"norm={projected.norm(dim=-1).item():.4f}")
    except Exception as err:
        log.exception(f"get_image_embedding: failed during inference: {err}")
        return np.zeros((1, 1), dtype=np.float32)
    else:
        # Outside the try on purpose: conversion errors here still propagate.
        return projected.cpu().numpy().astype(np.float32)


@torch.no_grad()
def get_image_embeddings_batch(images: list[Image.Image]) -> np.ndarray:
    """Embed a list of PIL images into the DETree space in one forward pass.

    Same pipeline as the single-image path: RGB conversion, CLIP
    preprocessing, ``encode_image``, L2 normalisation, then the projector
    with ``normalize=True``.

    Args:
        images: List of ``PIL.Image.Image`` objects (any mode; each is
            converted to RGB).

    Returns:
        ``np.ndarray`` of shape ``(N, embedding_dim)`` and dtype float32, with
        ``N == len(images)``. An empty list yields a ``(0, embedding_dim)``
        array. If the CLIP model or projector is not loaded, the problem is
        logged and an ``(N, 1)`` zero array is returned.

    Raises:
        Any exception raised by preprocessing, CLIP or the projector; errors
        during batch inference are not swallowed.
    """
    if _clip_model is None or _projector is None:
        log.error("get_image_embeddings_batch: clip_model or projector is None - "
                  "returning zeros. Check _init logs.")
        return np.zeros((len(images), 1), dtype=np.float32)

    if not images:
        # torch.stack() rejects an empty sequence, so short-circuit with an
        # empty batch. Presumably ``target_dim`` is the projector's output
        # width (it is the attribute _init logs alongside clip_dim).
        return np.zeros((0, int(_projector.target_dim)), dtype=np.float32)

    pixels = torch.stack(
        [_clip_prep(img.convert("RGB")) for img in images]
    ).to(DEVICE)

    clip_embs = F.normalize(_clip_model.encode_image(pixels).float(), dim=-1)
    projected = _projector(clip_embs, normalize=True)

    return projected.cpu().numpy().astype(np.float32)