"""
Shared model loading and embedding extraction utilities.
All evaluation scripts that need to load GAP-CLIP, the Fashion-CLIP baseline,
or the specialized color model should import from here instead of duplicating
the loading logic.
"""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Tuple
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPModel as CLIPModelTransformers
from transformers import CLIPProcessor
# Make project root importable when running evaluation scripts directly.
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))
# ---------------------------------------------------------------------------
# GAP-CLIP (main model)
# ---------------------------------------------------------------------------
def load_gap_clip(
    model_path: str,
    device: torch.device,
) -> Tuple[CLIPModelTransformers, CLIPProcessor]:
    """Load GAP-CLIP (LAION CLIP + fine-tuned checkpoint) and its processor.

    Args:
        model_path: Path to the `gap_clip.pth` checkpoint.
        device: Target device.

    Returns:
        (model, processor) ready for inference.
    """
    model = CLIPModelTransformers.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    # Checkpoints may be saved either as a raw state dict or wrapped in a
    # training dict under the "model_state_dict" key.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model.load_state_dict(checkpoint)
    model = model.to(device)
    model.eval()
    processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
    return model, processor
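
# The branching above handles both common checkpoint layouts. A hedged sketch
# of how either file might have been produced by a fine-tuning run (file name
# and dict keys are assumptions):
#
#     torch.save(model.state_dict(), "gap_clip.pth")                        # raw state dict
#     torch.save({"model_state_dict": model.state_dict()}, "gap_clip.pth")  # wrapped dict
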
# ---------------------------------------------------------------------------
# Fashion-CLIP baseline
# ---------------------------------------------------------------------------
def load_baseline_fashion_clip(
    device: torch.device,
) -> Tuple[CLIPModelTransformers, CLIPProcessor]:
    """Load the Fashion-CLIP baseline (patrickjohncyh/fashion-clip).

    Returns:
        (model, processor) ready for inference.
    """
    model_name = "patrickjohncyh/fashion-clip"
    processor = CLIPProcessor.from_pretrained(model_name)
    model = CLIPModelTransformers.from_pretrained(model_name).to(device)
    model.eval()
    return model, processor
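
# Example: loading the baseline next to GAP-CLIP for a side-by-side evaluation
# (sketch; the checkpoint path is an assumption):
#
#     baseline_model, baseline_processor = load_baseline_fashion_clip(device)
#     gap_model, gap_processor = load_gap_clip("checkpoints/gap_clip.pth", device)
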
# ---------------------------------------------------------------------------
# Specialized 16D color model
# ---------------------------------------------------------------------------
def load_color_model(
    color_model_path: str,
    device: torch.device,
):
    """Load the specialized 16D color model (CLIP backbone).

    Returns:
        (color_model, None) -- second element kept for API compatibility.
    """
    from training.color_model import ColorCLIP  # type: ignore

    print("Loading ColorCLIP (CLIP-backbone, 16D) ...")
    color_model = ColorCLIP.from_checkpoint(color_model_path, device=device)
    print("Color model loaded successfully")
    return color_model, None
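
# Example (sketch; the checkpoint path is an assumption). Note the `, _`
# unpacking: the second element exists only for API compatibility:
#
#     color_model, _ = load_color_model("checkpoints/color_clip.pth", device)
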
def load_hierarchy_model(
    hierarchy_model_path: str,
    device: torch.device,
):
    """Load the hierarchy model (CLIP backbone).

    Returns:
        hierarchy_model ready for inference.
    """
    from training.hierarchy_model import HierarchyModel  # type: ignore

    print("Loading HierarchyModel (CLIP-backbone, 64D) ...")
    model = HierarchyModel.from_checkpoint(hierarchy_model_path, device=device)
    print("Hierarchy model loaded successfully")
    return model
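
# Example (sketch; the checkpoint path is an assumption):
#
#     hierarchy_model = load_hierarchy_model("checkpoints/hierarchy_model.pth", device)
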
# ---------------------------------------------------------------------------
# Core encoding helpers (same as notebook)
# ---------------------------------------------------------------------------
def encode_text(model, processor, text_queries, device):
    """Encode text queries into embeddings (unnormalized)."""
    if isinstance(text_queries, str):
        text_queries = [text_queries]
    inputs = processor(text=text_queries, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        text_features = model.get_text_features(**inputs)
    return text_features
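
# Example (sketch): a single string and a list of strings are both accepted;
# for the ViT-B/32 backbones loaded above the result has shape [N, 512]:
#
#     feats = encode_text(model, processor, ["red dress", "blue jeans"], device)
#     feats.shape  # torch.Size([2, 512])
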
def encode_image(model, processor, images, device):
    """Encode images into embeddings (unnormalized)."""
    if not isinstance(images, list):
        images = [images]
    inputs = processor(images=images, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)
    return image_features
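
# Example (sketch; the image path is an assumption):
#
#     img = Image.open("samples/dress.jpg").convert("RGB")
#     feats = encode_image(model, processor, img, device)  # shape [1, 512]
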
# ---------------------------------------------------------------------------
# Normalized wrappers (preserve old call signatures used across eval scripts)
# ---------------------------------------------------------------------------
def get_text_embedding(model, processor, device, text):
    """Single normalized text embedding (shape: [512])."""
    return F.normalize(encode_text(model, processor, text, device), dim=-1).squeeze(0)


def get_text_embeddings_batch(model, processor, device, texts):
    """Normalized text embeddings for a batch (shape: [N, 512])."""
    return F.normalize(encode_text(model, processor, texts, device), dim=-1)


def get_image_embedding_from_pil(model, processor, device, pil_image):
    """Normalized image embedding from a PIL image (shape: [512])."""
    return F.normalize(encode_image(model, processor, pil_image, device), dim=-1).squeeze(0)
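
# Because the wrappers above return L2-normalized embeddings, cosine similarity
# between a query and an image reduces to a dot product (sketch; the query and
# image objects are illustrative):
#
#     text_emb = get_text_embedding(model, processor, device, "red floral dress")
#     image_emb = get_image_embedding_from_pil(model, processor, device, img)
#     similarity = torch.dot(text_emb, image_emb).item()  # cosine similarity in [-1, 1]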