Spaces:
Sleeping
Sleeping
| import tqdm | |
| import torch | |
| import numpy as np | |
| import torch.nn as nn | |
| from PIL import Image | |
| from pathlib import Path | |
| from torchvision import transforms | |
| from typing import List, Union, Tuple, Optional | |
| from torch.utils.data import DataLoader, Dataset | |
| class ImageEmbeddingDataset(Dataset): | |
| """Dataset for batch image embedding generation""" | |
| def __init__( | |
| self, | |
| image_paths: List[Union[str, Path]], | |
| transform=None | |
| ): | |
| self.image_paths = [Path(p) for p in image_paths] | |
| self.transform = transform or self.default_transform() | |
| def default_transform(): | |
| # I-JEPA uses mean=05 and std=0.5 normalization | |
| return transforms.Compose([ | |
| transforms.Resize( | |
| 224, interpolation=transforms.InterpolationMode.BICUBIC | |
| ), | |
| transforms.CenterCrop(224), | |
| transforms.ToTensor(), | |
| transforms.Normalize( | |
| mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5] | |
| ) | |
| ]) | |
| def __len__(self): | |
| return len(self.image_paths) | |
| def __getitem__(self, idx): | |
| img_path = self.image_paths[idx] | |
| image = Image.open(img_path).convert('RGB') | |
| image = self.transform(image) | |
| return image, str(img_path) | |
| class EmbeddingGenerator: | |
| """Generate embeddings for image database using batch inference.""" | |
| def __init__( | |
| self, | |
| model: nn.Module, | |
| device: str = "cuda" if torch.cuda.is_available() else "cpu", | |
| batch_size: int = 4, | |
| num_workers: int = 1, | |
| layer_strategy: str = "second_last", | |
| specific_indices: Optional[List[int]] = None, | |
| ): | |
| self.model = model.to(device).eval() | |
| self.device = device | |
| self.batch_size = batch_size | |
| self.num_workers = num_workers | |
| self.layer_strategy = layer_strategy | |
| self.specific_indices = specific_indices | |
| # Freeze model | |
| for param in self.model.parameters(): | |
| param.requires_grad = False | |
| def _get_features(self, images: torch.Tensor) -> torch.Tensor: | |
| """Central routing method - all forward calss go through here.""" | |
| if self.layer_strategy == "last": | |
| return self.model(images) | |
| return self.model.get_layer_representations( | |
| images, | |
| strategy=self.layer_strategy, | |
| specific_indices=self.specific_indices, | |
| ) | |
| def generate_embeddings( | |
| self, | |
| image_paths: List[Union[str, Path]], | |
| return_paths: bool = True, | |
| show_progress: bool = True, | |
| ) -> Union[np.ndarray, Tuple[np.ndarray, List[str]]]: | |
| """ | |
| Generate embeddings for all images. | |
| Returns: | |
| embeddings: (N, D) array of embeddings | |
| paths: (optional) list of image paths | |
| """ | |
| print(" 3.1 Image Embedding Dataset...") | |
| dataset = ImageEmbeddingDataset(image_paths) | |
| print(" 3.2 DataLoader...") | |
| dataloader = DataLoader( | |
| dataset, | |
| batch_size=self.batch_size, | |
| shuffle=False, | |
| num_workers=self.num_workers, | |
| pin_memory=True, | |
| ) | |
| all_embeddings = [] | |
| all_paths = [] | |
| print(" 3.3 tqdm iterator...\n") | |
| iterator = tqdm.tqdm(dataloader, desc="Generating embeddings") if show_progress else dataloader | |
| print(" 3.4 for loop...") | |
| for batch_images, batch_paths in iterator: | |
| batch_images = batch_images.to(self.device, non_blocking=True) | |
| # Get embeddings: average pool patch tokens for global representation | |
| features = self._get_features(batch_images) # (B, N, D) | |
| embeddings = features.mean(dim=1) # (B, D) | |
| # L2 normalization for cosine similarity | |
| embeddings = nn.functional.normalize(embeddings, p=2, dim=1) | |
| all_embeddings.append(embeddings.cpu().numpy()) | |
| all_paths.extend(batch_paths) | |
| embeddings = np.vstack(all_embeddings) | |
| if return_paths: | |
| return embeddings, all_paths | |
| return embeddings | |
| def generate_single_embedding( | |
| self, image: Union[str, Path, Image.Image, torch.Tensor] | |
| ) -> np.ndarray: | |
| """Generate embedding for a single image""" | |
| transform = ImageEmbeddingDataset.default_transform() | |
| if isinstance(image, (str, Path)): | |
| image = Image.open(image).convert('RGB') | |
| if isinstance(image, Image.Image): | |
| image = transform(image) | |
| if isinstance(image, torch.Tensor): | |
| image = image.unsqueeze(0) if image.dim() == 3 else image | |
| image = image.to(self.device) | |
| with torch.no_grad(): | |
| features = self._get_features(image) | |
| embedding = features.mean(dim=1) | |
| embedding = nn.functional.normalize(embedding, p=2, dim=1) | |
| return embedding.cpu().numpy() | |