| import torch |
| import torch.nn as nn |
| from transformers import AutoModel |
|
|
class DinoV3LinearMultiLinear(nn.Module):
    """Image classifier: a (typically frozen) vision-transformer backbone
    followed by a three-layer linear head applied to the CLS token.

    NOTE: the three stacked ``Linear`` layers have no nonlinearity between
    them, so they are mathematically equivalent to a single ``Linear``;
    they are kept as three layers for checkpoint compatibility
    (state-dict keys ``linear1`` / ``linear2`` / ``linear3``).
    """

    def __init__(self, backbone: nn.Module, hidden_size: int, num_classes: int, freeze_backbone: bool = True):
        """
        Args:
            backbone: Pretrained vision model (e.g. a DINOv3 ``AutoModel``);
                its forward must accept ``pixel_values=`` and return an
                object exposing ``last_hidden_state``.
            hidden_size: Dimensionality of the backbone's hidden states.
            num_classes: Number of output classes.
            freeze_backbone: If True, backbone parameters are frozen and the
                backbone is kept in eval mode even when the model is trained
                (see the ``train`` override below).
        """
        super().__init__()
        self.backbone = backbone
        self.num_classes = num_classes
        # Remember the flag so train() can keep the backbone in eval mode.
        self.freeze_backbone = freeze_backbone
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False
            self.backbone.eval()

        self.linear1 = nn.Linear(hidden_size, 256)
        self.linear2 = nn.Linear(256, 128)
        self.linear3 = nn.Linear(128, self.num_classes)

    def train(self, mode: bool = True):
        """Set training mode, but keep a frozen backbone in eval mode.

        Without this override, ``model.train()`` would flip the backbone
        back to training mode, letting BatchNorm statistics / Dropout
        update even though its weights are frozen.
        """
        super().train(mode)
        if self.freeze_backbone:
            self.backbone.eval()
        return self

    def print_num_trainable_parameters(self):
        """Print the count of parameters with ``requires_grad=True``."""
        n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Number of trainable parameters: {n_trainable}")

    def forward(self, pixel_values):
        """Compute classification logits.

        Args:
            pixel_values: Preprocessed image tensor, shape (batch, 3, H, W).

        Returns:
            Logits of shape (batch, num_classes).
        """
        outputs = self.backbone(pixel_values=pixel_values)
        # Use the CLS token (position 0) as the global image representation.
        cls_token = outputs.last_hidden_state[:, 0]
        return self.linear3(self.linear2(self.linear1(cls_token)))

    def predict(self, pixel_values, temperature=1.3):
        """Generate probability predictions for a batch of images.

        Runs under ``torch.no_grad()`` — this is an inference-only path, so
        building an autograd graph would only waste memory.

        Args:
            pixel_values: Preprocessed image tensor (batch_size, 3, H, W).
            temperature: Temperature for softmax calibration (default 1.3).

        Returns:
            probs: Probability distribution over classes,
                shape (batch_size, num_classes); each row sums to 1.
        """
        with torch.no_grad():
            logits = self.forward(pixel_values)
            probs = torch.softmax(logits / temperature, dim=1)
        return probs