import torch
import torch.nn as nn
from transformers import AutoModel


class DinoV3LinearMultiLinear(nn.Module):
    """Backbone feature extractor followed by a three-layer linear head.

    The backbone is called as ``backbone(pixel_values=...)`` and must return
    an object exposing ``last_hidden_state``. The token at sequence index 0
    (presumably the CLS token -- confirm for the backbone in use) is fed
    through ``hidden_size -> 256 -> 128 -> num_classes``.

    Args:
        backbone: Pre-instantiated backbone model (a DINOv3 model, going by
            the class name).
        hidden_size: Size of the last dimension of the backbone's
            ``last_hidden_state``.
        num_classes: Number of output classes (size of the logits).
        freeze_backbone: If True, backbone parameters get
            ``requires_grad = False`` and the backbone is kept in eval mode,
            including after later ``model.train()`` calls.
    """

    def __init__(self, backbone: AutoModel, hidden_size: int, num_classes: int, freeze_backbone: bool = True):
        super().__init__()
        self.backbone = backbone
        self.num_classes = num_classes
        # Remembered so that train() can keep a frozen backbone in eval mode.
        self._freeze_backbone = freeze_backbone

        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False
            self.backbone.eval()

        # three linear layers like in the original syke-pic model
        # hidden size -> 256 -> 128 -> num_classes
        # NOTE(review): there is no non-linearity between these layers, so the
        # head is mathematically a single affine map. Left unchanged to stay
        # compatible with the original model and existing checkpoints.
        self.linear1 = nn.Linear(hidden_size, 256)
        self.linear2 = nn.Linear(256, 128)
        self.linear3 = nn.Linear(128, self.num_classes)

    def train(self, mode: bool = True):
        """Set training mode, keeping a frozen backbone in eval mode.

        ``nn.Module.train`` recurses into every submodule, which would undo
        the ``backbone.eval()`` call made in ``__init__`` and re-enable
        train-time behaviour (e.g. dropout) inside a backbone that is meant
        to be a fixed feature extractor.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            self, matching ``nn.Module.train``.
        """
        super().train(mode)
        # getattr: instances unpickled from before this attribute existed
        # fall back to the stock nn.Module behaviour instead of failing.
        if getattr(self, "_freeze_backbone", False):
            self.backbone.eval()
        return self

    def print_num_trainable_parameters(self):
        """Print the total element count of parameters with requires_grad."""
        print(f"Number of trainable parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad)}")

    def forward(self, pixel_values):
        """Compute class logits for a batch of images.

        Args:
            pixel_values: Input tensor passed to the backbone as
                ``pixel_values``.

        Returns:
            Logits produced by the linear head from the index-0 token of the
            backbone's ``last_hidden_state``.
        """
        outputs = self.backbone(pixel_values=pixel_values)
        last_hidden = outputs.last_hidden_state
        # First token along the sequence axis; assumed to be the CLS token.
        cls = last_hidden[:, 0]
        logits = self.linear3(self.linear2(self.linear1(cls)))
        return logits

    def predict(self, pixel_values, temperature=1.3):
        """
        Generate probability predictions for a batch of images.

        This method neither disables gradient tracking nor switches the
        model to eval mode; callers doing inference should call
        ``model.eval()`` and wrap the call in ``torch.no_grad()`` themselves.

        Args:
            pixel_values: Preprocessed image tensor (batch_size, 3, H, W)
            temperature: Temperature for softmax calibration (default 1.3).
                Must be positive when given as a plain number.

        Returns:
            probs: Probability distribution over classes (shape: [batch_size, num_classes])

        Raises:
            ValueError: If ``temperature`` is a plain number that is <= 0.
        """
        # Only plain numbers are validated; tensor temperatures are passed
        # through untouched so existing broadcasting use keeps working.
        if isinstance(temperature, (int, float)) and temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        logits = self.forward(pixel_values)
        probs = torch.softmax(logits / temperature, dim=1)
        return probs