File size: 4,945 Bytes
052f26d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import tqdm
import torch
import numpy as np
import torch.nn as nn
from PIL import Image
from pathlib import Path
from torchvision import transforms
from typing import List, Union, Tuple, Optional
from torch.utils.data import DataLoader, Dataset


class ImageEmbeddingDataset(Dataset):
    """Dataset that loads images from disk for batch embedding generation.

    Each item is an ``(image, path)`` pair: the file is opened, converted to
    RGB and passed through ``transform``; the path is returned as a string.

    Args:
        image_paths: Paths of the images to load. Each entry is wrapped in
            ``pathlib.Path``.
        transform: Callable applied to every RGB PIL image. When ``None``,
            :meth:`default_transform` is used.
    """

    def __init__(
        self,
        image_paths: List[Union[str, Path]],
        transform=None
    ):
        self.image_paths = [Path(p) for p in image_paths]
        # Compare against None explicitly: a caller-supplied transform may be
        # falsy (e.g. an empty nn.Sequential has length 0) and must not be
        # silently replaced by the default pipeline.
        self.transform = (
            transform if transform is not None else self.default_transform()
        )

    @staticmethod
    def default_transform():
        """Return the default preprocessing pipeline.

        Bicubic resize to 224, 224x224 center crop, conversion to a tensor,
        and per-channel normalization with mean 0.5 and std 0.5.
        """
        # I-JEPA uses mean=0.5 and std=0.5 normalization
        return transforms.Compose([
            transforms.Resize(
                224, interpolation=transforms.InterpolationMode.BICUBIC
            ),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
            )
        ])

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, convert and transform the image at position ``idx``.

        Returns:
            A ``(transformed_image, path_string)`` tuple.
        """
        img_path = self.image_paths[idx]
        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle deterministically once convert() has produced
        # an independent in-memory RGB copy.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        image = self.transform(image)
        return image, str(img_path)


class EmbeddingGenerator:
    """Generate embeddings for image database using batch inference.

    The model is moved to ``device``, switched to eval mode and frozen.
    Embeddings are obtained by averaging the model's output over dim 1 and
    L2-normalizing the result, so dot products between embeddings are cosine
    similarities.

    Args:
        model: Network used to extract features. Unless ``layer_strategy`` is
            ``"last"``, it must provide
            ``get_layer_representations(images, strategy=..., specific_indices=...)``.
        device: Device the model and input batches are moved to.
        batch_size: Batch size used by :meth:`generate_embeddings`.
        num_workers: Number of DataLoader worker processes.
        layer_strategy: ``"last"`` calls the model directly; any other value
            is forwarded to ``model.get_layer_representations``.
        specific_indices: Forwarded to ``model.get_layer_representations``.
    """

    def __init__(
        self,
        model: nn.Module,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        batch_size: int = 4,
        num_workers: int = 1,
        layer_strategy: str = "second_last",
        specific_indices: Optional[List[int]] = None,
    ):
        self.model = model.to(device).eval()
        self.device = device
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.layer_strategy = layer_strategy
        self.specific_indices = specific_indices

        # Freeze model
        for param in self.model.parameters():
            param.requires_grad = False

    def _get_features(self, images: torch.Tensor) -> torch.Tensor:
        """Central routing method - all forward calls go through here."""
        if self.layer_strategy == "last":
            return self.model(images)
        return self.model.get_layer_representations(
            images,
            strategy=self.layer_strategy,
            specific_indices=self.specific_indices,
        )

    def _embed(self, images: torch.Tensor) -> torch.Tensor:
        """Turn a batch of preprocessed images into L2-normalized embeddings."""
        # Average pool patch tokens for a global representation.
        features = self._get_features(images)   # expected (B, N, D)
        pooled = features.mean(dim=1)           # (B, D)
        # L2 normalization for cosine similarity
        return nn.functional.normalize(pooled, p=2, dim=1)

    @torch.no_grad()
    def generate_embeddings(
        self,
        image_paths: List[Union[str, Path]],
        return_paths: bool = True,
        show_progress: bool = True,
    ) -> Union[np.ndarray, Tuple[np.ndarray, List[str]]]:
        """
        Generate embeddings for all images.

        Args:
            image_paths: Paths of the images to embed.
            return_paths: Also return the image paths, in embedding order.
            show_progress: Wrap the DataLoader in a tqdm progress bar.

        Returns:
            embeddings: (N, D) array of embeddings. For an empty
                ``image_paths`` an empty array of shape (0, 0) is returned,
                because the embedding width is unknown without a forward pass.
            paths: (optional) list of image paths
        """
        image_paths = list(image_paths)
        if not image_paths:
            # np.vstack raises ValueError on an empty list, so handle the
            # "nothing to embed" case up front.
            empty = np.empty((0, 0), dtype=np.float32)
            return (empty, []) if return_paths else empty

        print("   3.1 Image Embedding Dataset...")
        dataset = ImageEmbeddingDataset(image_paths)
        print("   3.2 DataLoader...")
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            # Pinned host memory only helps host-to-CUDA copies; requesting it
            # on other devices is useless and triggers a warning.
            pin_memory=torch.device(self.device).type == "cuda",
        )

        all_embeddings = []
        all_paths = []

        print("   3.3 tqdm iterator...\n")
        iterator = tqdm.tqdm(dataloader, desc="Generating embeddings") if show_progress else dataloader

        print("   3.4 for loop...")
        for batch_images, batch_paths in iterator:
            batch_images = batch_images.to(self.device, non_blocking=True)
            embeddings = self._embed(batch_images)

            all_embeddings.append(embeddings.cpu().numpy())
            all_paths.extend(batch_paths)

        embeddings = np.vstack(all_embeddings)

        if return_paths:
            return embeddings, all_paths
        return embeddings

    def generate_single_embedding(
        self, image: Union[str, Path, Image.Image, torch.Tensor]
    ) -> np.ndarray:
        """Generate embedding for a single image.

        Args:
            image: A file path, a PIL image, or an already preprocessed
                tensor. Paths and PIL images go through
                ``ImageEmbeddingDataset.default_transform``; tensors are used
                as given, and a 3-D tensor gets a leading batch dimension.

        Returns:
            Array of L2-normalized embeddings with one row per input image.

        Raises:
            TypeError: If ``image`` is none of the supported types.
        """
        if not isinstance(image, torch.Tensor):
            if isinstance(image, (str, Path)):
                # Close the file handle as soon as the RGB copy exists.
                with Image.open(image) as opened:
                    image = opened.convert('RGB')
            if not isinstance(image, Image.Image):
                raise TypeError(
                    "image must be a path, a PIL image or a torch.Tensor, "
                    f"got {type(image).__name__}"
                )
            # Build the preprocessing pipeline only when it is needed.
            image = ImageEmbeddingDataset.default_transform()(image)

        if image.dim() == 3:
            image = image.unsqueeze(0)

        image = image.to(self.device)

        with torch.no_grad():
            embedding = self._embed(image)

        return embedding.cpu().numpy()