File size: 2,057 Bytes
217bd11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from transformers import AutoImageProcessor, AutoModel
import torch

def get_dinov2_model(model_type="vits14"):
    """Load the pretrained DINOv2 checkpoint for a short variant name.

    Args:
        model_type: One of ``'vits14'``, ``'vitb14'``, ``'vitl14'`` or
            ``'vitg14'``. Each maps to the corresponding
            ``facebook/dinov2-*`` checkpoint id.

    Returns:
        The object returned by ``AutoModel.from_pretrained`` for that
        checkpoint.

    Raises:
        KeyError: If ``model_type`` is not one of the names above.
    """
    checkpoint_ids = {
        'vits14': 'facebook/dinov2-small',
        'vitb14': 'facebook/dinov2-base',
        'vitl14': 'facebook/dinov2-large',
        'vitg14': 'facebook/dinov2-giant',
    }
    # Plain indexing (not .get) so an unsupported variant fails with KeyError.
    checkpoint = checkpoint_ids[model_type]
    return AutoModel.from_pretrained(checkpoint)

def get_feature_dim(model_type):
    """Return the hidden size associated with a DINOv2 variant name.

    Args:
        model_type: One of ``'vits14'``, ``'vitb14'``, ``'vitl14'`` or
            ``'vitg14'``.

    Returns:
        384, 768, 1024 or 1536 respectively.

    Raises:
        KeyError: If ``model_type`` is not one of the names above.
    """
    hidden_size_by_variant = dict(
        vits14=384,
        vitb14=768,
        vitl14=1024,
        vitg14=1536,
    )
    # Plain indexing so an unsupported variant surfaces as KeyError.
    return hidden_size_by_variant[model_type]

def extract_features(image_features, pooling_type='cls'):
    """Pool per-token embeddings into a single feature vector per image.

    Args:
        image_features: Last hidden states, expected with shape
            ``[batch_size, num_patches + 1, hidden_dim]``: index 0 along the
            second axis is the CLS token, the remaining indices are patch
            tokens.
        pooling_type: Pooling strategy, one of:
            ``'cls'``     - the CLS token;
            ``'avg'``     - mean over the patch tokens;
            ``'max'``     - element-wise max over the patch tokens;
            ``'cls_max'`` - CLS token concatenated with the max-pooled patches;
            ``'cls_avg'`` - CLS token concatenated with the mean-pooled patches.

    Returns:
        A tensor whose last dimension is ``hidden_dim`` for ``'cls'``,
        ``'avg'`` and ``'max'``, and ``2 * hidden_dim`` for the two
        concatenating modes.

    Raises:
        ValueError: If ``pooling_type`` is not one of the supported names.
    """
    supported = ('cls', 'avg', 'max', 'cls_max', 'cls_avg')
    # Reject unknown strategies before touching the tensor at all.
    if pooling_type not in supported:
        raise ValueError(f"Unknown pooling type: {pooling_type}")

    cls_token = image_features[:, 0]
    if pooling_type == 'cls':
        return cls_token

    # Everything after index 0 on the token axis is a patch embedding.
    patch_tokens = image_features[:, 1:]
    if pooling_type in ('avg', 'cls_avg'):
        pooled = torch.mean(patch_tokens, dim=1)
    else:
        # torch.max along a dim returns (values, indices); only values are used.
        pooled = torch.max(patch_tokens, dim=1).values

    if pooling_type in ('cls_max', 'cls_avg'):
        return torch.cat([cls_token, pooled], dim=-1)
    return pooled

def get_pooling_dim(base_dim, pooling_type):
    """Return the feature size produced by a pooling strategy.

    Args:
        base_dim: Feature size before pooling.
        pooling_type: ``'cls'``, ``'avg'`` or ``'max'`` keep ``base_dim``;
            ``'cls_max'`` and ``'cls_avg'`` double it.

    Returns:
        ``base_dim`` or ``base_dim * 2`` depending on ``pooling_type``.

    Raises:
        ValueError: If ``pooling_type`` is not one of the names above.
    """
    same_size_modes = ('cls', 'avg', 'max')
    doubled_modes = ('cls_max', 'cls_avg')

    if pooling_type in same_size_modes:
        return base_dim
    if pooling_type in doubled_modes:
        return base_dim * 2
    raise ValueError(f"Unknown pooling type: {pooling_type}")