File size: 1,718 Bytes
dff7e68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import torch.nn as nn
from transformers import AutoModel

class DinoV3LinearMultiLinear(nn.Module):
    """DINOv3 backbone with a 3-layer linear classification head.

    The head maps the backbone CLS token through
    hidden_size -> 256 -> 128 -> num_classes (mirroring the original
    syke-pic model, which stacks plain linear layers with no activation
    in between).
    """

    def __init__(self, backbone: nn.Module, hidden_size: int, num_classes: int, freeze_backbone: bool = True):
        """
        Args:
            backbone: A pretrained vision transformer (e.g. the model returned
                by ``AutoModel.from_pretrained``). Its forward must accept a
                ``pixel_values`` keyword and return an object with a
                ``last_hidden_state`` attribute.
            hidden_size: Embedding dimension of the backbone's hidden states.
            num_classes: Number of output classes.
            freeze_backbone: If True, disable gradients for all backbone
                parameters and keep the backbone permanently in eval mode.
        """
        super().__init__()
        self.backbone = backbone
        self.num_classes = num_classes
        # Remember the freeze choice so train() can keep the backbone in
        # eval mode even after callers toggle the whole model's mode.
        self._backbone_frozen = freeze_backbone
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False
            self.backbone.eval()
        # three linear layers like in the original syke-pic model
        # hidden size -> 256 -> 128 -> num_classes
        self.linear1 = nn.Linear(hidden_size, 256)
        self.linear2 = nn.Linear(256, 128)
        self.linear3 = nn.Linear(128, self.num_classes)

    def train(self, mode: bool = True):
        """Set training mode, but never take a frozen backbone out of eval.

        Without this override, a later ``model.train()`` would silently undo
        the ``backbone.eval()`` done in ``__init__`` (re-enabling dropout /
        batch-norm updates inside the frozen backbone).
        """
        super().train(mode)
        if self._backbone_frozen:
            self.backbone.eval()
        return self

    def print_num_trainable_parameters(self):
        """Print the count of parameters with requires_grad=True."""
        print(f"Number of trainable parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad)}")

    def forward(self, pixel_values):
        """Compute class logits for a batch of preprocessed images.

        Args:
            pixel_values: Image tensor of shape (batch_size, 3, H, W),
                already preprocessed for the backbone.

        Returns:
            Logits tensor of shape (batch_size, num_classes).
        """
        outputs = self.backbone(pixel_values=pixel_values)
        last_hidden = outputs.last_hidden_state
        # CLS token (position 0) summarizes the whole image.
        cls = last_hidden[:, 0]
        logits = self.linear3(self.linear2(self.linear1(cls)))
        return logits

    def predict(self, pixel_values, temperature=1.3):
        """
        Generate probability predictions for a batch of images.

        Args:
            pixel_values: Preprocessed image tensor (batch_size, 3, H, W)
            temperature: Temperature for softmax calibration (default 1.3)

        Returns:
            probs: Probability distribution over classes (shape: [batch_size, num_classes])
        """
        # Inference-only path: skip autograd graph construction to save
        # memory and compute.
        with torch.no_grad():
            logits = self.forward(pixel_values)
            probs = torch.softmax(logits / temperature, dim=1)
        return probs