File size: 2,773 Bytes
ed4e653
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""
Feature extractor using WideResNet-101-2 backbone (timm).
Extracts multi-scale patch descriptors from layer2 (stride 8) and layer3 (stride 16).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm


class PatchFeatureExtractor(nn.Module):
    """
    Extracts L2-normalized patch descriptors from WideResNet-101-2.

    For a 224×224 input:
      - layer2 output: [B, 512,  28, 28]
      - layer3 output: [B, 1024, 14, 14]

    Both are spatially aligned to 28×28 via adaptive avg pooling on layer3,
    then concatenated → [B, 1536, 28, 28].
    Each of the 28×28 = 784 spatial locations yields a 1536-dim descriptor.

    No training — runs in eval mode with torch.no_grad().
    """

    def __init__(self, backbone: str = "wide_resnet101_2", pretrained: bool = True):
        super().__init__()
        # features_only=True exposes intermediate feature maps
        self.backbone = timm.create_model(
            backbone,
            pretrained=pretrained,
            features_only=True,
            out_indices=(1, 2),   # layer2=index1 (512ch), layer3=index2 (1024ch)
        )
        self.backbone.eval()
        for param in self.backbone.parameters():
            param.requires_grad_(False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, 3, 224, 224] ImageNet-normalised images.
        Returns:
            patch_features: [B, H*W, 1536] L2-normalised patch descriptors
                            where H=W=28 for 224px input.
        """
        features = self.backbone(x)   # list: [layer2, layer3]
        f2 = features[0]              # [B, 512,  28, 28]
        f3 = features[1]              # [B, 1024, 14, 14]

        # Upsample layer3 to match layer2 spatial size
        f3_up = F.adaptive_avg_pool2d(f3, output_size=f2.shape[-2:])  # [B, 1024, 28, 28]

        # Concatenate along channel dim
        combined = torch.cat([f2, f3_up], dim=1)  # [B, 1536, 28, 28]

        B, C, H, W = combined.shape
        # Reshape to [B, H*W, C]
        patch_features = combined.permute(0, 2, 3, 1).reshape(B, H * W, C)

        # L2 normalize each descriptor
        patch_features = F.normalize(patch_features, p=2, dim=-1)

        return patch_features

    @torch.no_grad()
    def extract_patch_features(self, images: torch.Tensor) -> torch.Tensor:
        """
        Convenience method: flatten all patches across batch.
        Args:
            images: [B, 3, 224, 224]
        Returns:
            [B * H * W, 1536]  — ready for coreset subsampling or memory bank
        """
        patch_features = self.forward(images)  # [B, 784, 1536]
        B, N, C = patch_features.shape
        return patch_features.reshape(B * N, C)