"""
Shared model loading and embedding extraction utilities.

All evaluation scripts that need to load GAP-CLIP, the Fashion-CLIP baseline,
or the specialized color model should import from here instead of duplicating
the loading logic.
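
Example (hypothetical checkpoint path; adjust the import to match where this
module lives in the repository):

    import torch
    from model_loading import load_gap_clip, get_text_embedding

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, processor = load_gap_clip("checkpoints/gap_clip.pth", device)
    query = get_text_embedding(model, processor, device, "red floral midi dress")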
"""

from __future__ import annotations

import sys
from pathlib import Path
from typing import Tuple

import torch
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPModel as CLIPModelTransformers
from transformers import CLIPProcessor

# Make project root importable when running evaluation scripts directly.
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))


# ---------------------------------------------------------------------------
# GAP-CLIP (main model)
# ---------------------------------------------------------------------------

def load_gap_clip(
    model_path: str,
    device: torch.device,
) -> Tuple[CLIPModelTransformers, CLIPProcessor]:
    """Load GAP-CLIP (LAION CLIP + fine-tuned checkpoint) and its processor.

    Args:
        model_path: Path to the `gap_clip.pth` checkpoint.
        device: Target device.

    Returns:
        (model, processor) ready for inference.
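
    Note:
        The checkpoint may be either a raw state_dict or a dict that wraps it
        under "model_state_dict". Either of the following (a sketch, not the
        project's training code) produces a file this loader accepts:

            torch.save(model.state_dict(), path)
            torch.save({"model_state_dict": model.state_dict()}, path)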
    """
    model = CLIPModelTransformers.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model.load_state_dict(checkpoint)

    model = model.to(device)
    model.eval()
    processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
    return model, processor


# ---------------------------------------------------------------------------
# Fashion-CLIP baseline
# ---------------------------------------------------------------------------

def load_baseline_fashion_clip(
    device: torch.device,
) -> Tuple[CLIPModelTransformers, CLIPProcessor]:
    """Load the Fashion-CLIP baseline (patrickjohncyh/fashion-clip).

    Returns:
        (model, processor) ready for inference.
    """
    model_name = "patrickjohncyh/fashion-clip"
    processor = CLIPProcessor.from_pretrained(model_name)
    model = CLIPModelTransformers.from_pretrained(model_name).to(device)
    model.eval()
    return model, processor


# ---------------------------------------------------------------------------
# Specialized 16D color model
# ---------------------------------------------------------------------------

def load_color_model(
    color_model_path: str,
    device: torch.device,
):
    """Load the specialized 16D color model (CLIP-backbone).

    Returns:
        (color_model, None)  -- second element kept for API compatibility
    """
    from training.color_model import ColorCLIP  # type: ignore

    print("Loading ColorCLIP (CLIP-backbone, 16D) ...")
    color_model = ColorCLIP.from_checkpoint(color_model_path, device=device)
    print("Color model loaded successfully")
    return color_model, None


def load_hierarchy_model(
    hierarchy_model_path: str,
    device: torch.device,
):
    """Load the hierarchy model (CLIP-backbone).

    Returns:
        hierarchy_model ready for inference.
    """
    from training.hierarchy_model import HierarchyModel  # type: ignore

    print("Loading HierarchyModel (CLIP-backbone, 64D) ...")
    model = HierarchyModel.from_checkpoint(hierarchy_model_path, device=device)
    print("Hierarchy model loaded successfully")
    return model


# ---------------------------------------------------------------------------
# Core encoding helpers (same as notebook)
# ---------------------------------------------------------------------------

def encode_text(model, processor, text_queries, device):
    """Encode text queries into embeddings (unnormalized)."""
    if isinstance(text_queries, str):
        text_queries = [text_queries]
    inputs = processor(text=text_queries, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        text_features = model.get_text_features(**inputs)
    return text_features


def encode_image(model, processor, images, device):
    """Encode images into embeddings (unnormalized)."""
    if not isinstance(images, list):
        images = [images]
    inputs = processor(images=images, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)
    return image_features
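

# Convenience helper (not part of the original eval scripts): combines the two
# encoders above into a cosine-similarity matrix, which is the usual way the
# unnormalized features are consumed downstream.
def text_image_similarity(model, processor, text_queries, images, device):
    """Cosine-similarity matrix of shape [num_texts, num_images]."""
    text_features = F.normalize(encode_text(model, processor, text_queries, device), dim=-1)
    image_features = F.normalize(encode_image(model, processor, images, device), dim=-1)
    return text_features @ image_features.T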


# ---------------------------------------------------------------------------
# Normalized wrappers (preserve old call signatures used across eval scripts)
# ---------------------------------------------------------------------------

def get_text_embedding(model, processor, device, text):
    """Single normalized text embedding (shape: [512])."""
    return F.normalize(encode_text(model, processor, text, device), dim=-1).squeeze(0)


def get_text_embeddings_batch(model, processor, device, texts):
    """Normalized text embeddings for a batch (shape: [N, 512])."""
    return F.normalize(encode_text(model, processor, texts, device), dim=-1)


def get_image_embedding_from_pil(model, processor, device, pil_image):
    """Normalized image embedding from a PIL image (shape: [512])."""
    return F.normalize(encode_image(model, processor, pil_image, device), dim=-1).squeeze(0)
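

# Minimal smoke test: because the wrappers above return L2-normalized embeddings,
# the cosine similarity between a text query and an image is just a dot product.
# Runs only when this module is executed directly; "example.jpg" is a placeholder
# path, not a file shipped with the project.
if __name__ == "__main__":
    _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _model, _processor = load_baseline_fashion_clip(_device)  # no local checkpoint required
    _text_emb = get_text_embedding(_model, _processor, _device, "a red floral midi dress")
    _image_emb = get_image_embedding_from_pil(
        _model, _processor, _device, Image.open("example.jpg").convert("RGB")
    )
    print("cosine similarity:", float(_text_emb @ _image_emb))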