# danielaivanova's picture
# Upload folder using huggingface_hub
# dff7e68 verified
import torch
import torch.nn as nn
from transformers import AutoModel
class DinoV3LinearMultiLinear(nn.Module):
    """Classification head made of three stacked linear layers on a backbone.

    The backbone is called with ``pixel_values=...`` and must return an object
    exposing ``last_hidden_state``; the token at index 0 of that tensor is fed
    through ``hidden_size -> 256 -> 128 -> num_classes`` linear layers.

    Args:
        backbone: Feature extractor module (e.g. a model loaded through
            ``AutoModel.from_pretrained``).
        hidden_size: Size of the last dimension of ``last_hidden_state``.
        num_classes: Number of output classes.
        freeze_backbone: If true, backbone parameters are excluded from
            gradient updates and the backbone is kept in eval mode, even
            after ``model.train()`` is called.
    """

    def __init__(
        self,
        backbone: nn.Module,
        hidden_size: int,
        num_classes: int,
        freeze_backbone: bool = True,
    ):
        super().__init__()
        self.backbone = backbone
        self.num_classes = num_classes
        self.freeze_backbone = freeze_backbone

        if freeze_backbone:
            self.backbone.requires_grad_(False)
            self.backbone.eval()

        # Three linear layers like in the original syke-pic model:
        # hidden size -> 256 -> 128 -> num_classes.
        # NOTE: there is no activation between them (see ``forward``), so the
        # stack is mathematically a single affine map. It is left as-is to
        # mirror the original head; the attribute names also define the
        # state_dict keys, so do not rename them.
        self.linear1 = nn.Linear(hidden_size, 256)
        self.linear2 = nn.Linear(256, 128)
        self.linear3 = nn.Linear(128, self.num_classes)

    def train(self, mode: bool = True) -> "DinoV3LinearMultiLinear":
        """Set the training mode while keeping a frozen backbone in eval mode.

        ``nn.Module.train`` recurses into every child, which would switch a
        frozen backbone back to training behaviour (dropout, stochastic
        depth, ...). Re-apply ``eval()`` on the backbone after the recursive
        call so that freezing is not silently undone by ``model.train()``.
        """
        super().train(mode)
        # getattr: instances unpickled from before this attribute existed
        # simply fall back to the stock nn.Module behaviour.
        if getattr(self, "freeze_backbone", False):
            self.backbone.eval()
        return self

    def print_num_trainable_parameters(self) -> None:
        """Print how many parameters currently have ``requires_grad`` set."""
        print(f"Number of trainable parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad)}")

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for a batch of preprocessed images.

        Args:
            pixel_values: Input tensor forwarded to the backbone as the
                ``pixel_values`` keyword argument.

        Returns:
            Logits whose last dimension has size ``num_classes``.
        """
        outputs = self.backbone(pixel_values=pixel_values)
        last_hidden = outputs.last_hidden_state
        # Token at sequence index 0 (presumably the CLS token of the backbone).
        cls = last_hidden[:, 0]
        logits = self.linear3(self.linear2(self.linear1(cls)))
        return logits

    def predict(self, pixel_values: torch.Tensor, temperature: float = 1.3) -> torch.Tensor:
        """Generate probability predictions for a batch of images.

        This method neither switches the module to eval mode nor disables
        gradient tracking; callers doing pure inference should call
        ``model.eval()`` and wrap the call in ``torch.no_grad()``.

        Args:
            pixel_values: Preprocessed image tensor (batch_size, 3, H, W)
            temperature: Temperature for softmax calibration (default 1.3).
                Must be strictly positive.

        Returns:
            probs: Probability distribution over classes
                (shape: [batch_size, num_classes])

        Raises:
            ValueError: If ``temperature`` is a plain number that is not
                strictly positive.
        """
        # Only plain numbers are validated; tensor temperatures are passed
        # through untouched.
        if isinstance(temperature, (int, float)) and temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        # Call the module (not .forward directly) so registered hooks run.
        logits = self(pixel_values)
        probs = torch.softmax(logits / temperature, dim=1)
        return probs