| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from transformers import SiglipVisionModel, SiglipPreTrainedModel, SiglipVisionConfig |
| from transformers.utils import ModelOutput |
|
|
|
|
@dataclass
class SiglipForImageClassifierOutput(ModelOutput):
    """Output container returned by ``SiglipForImageClassification.forward``.

    Because this is a ``ModelOutput`` dataclass, the fields can be read by
    attribute, by key, or by tuple position. Their declaration order is
    therefore part of the public interface and must not change.
    """

    # Training loss; None when no loss was computed.
    loss: torch.FloatTensor | None = None
    # Output of the classification head applied to ``pooler_output``.
    logits: torch.FloatTensor | None = None
    # Pooled embedding produced by the SigLIP vision backbone.
    pooler_output: torch.FloatTensor | None = None
    # Per-layer hidden states forwarded from the backbone output; None when
    # the backbone did not return them.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Per-layer attention maps forwarded from the backbone output; None when
    # the backbone did not return them.
    attentions: tuple[torch.FloatTensor, ...] | None = None
|
|
|
|
class SiglipForImageClassification(SiglipPreTrainedModel):
    """SigLIP vision encoder with a linear classification head on the pooled output.

    Images are encoded by ``SiglipVisionModel``. Its ``pooler_output`` is fed
    to a ``nn.Linear(hidden_size, num_labels)`` head. When
    ``config.num_labels`` is 0 the head is ``nn.Identity``, so the returned
    "logits" are the pooled features themselves (feature-extraction mode).
    """

    config_class = SiglipVisionConfig
    # Name of the primary input tensor, used by Hugging Face utilities.
    main_input_name = "pixel_values"

    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.siglip = SiglipVisionModel(config)

        # A positive label count gets a linear head. With 0 labels the model
        # returns the pooled embedding unchanged.
        self.classifier = (
            nn.Linear(config.hidden_size, config.num_labels)
            if config.num_labels > 0
            else nn.Identity()
        )

        # Standard HF hook: weight initialization and final setup.
        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: torch.LongTensor | None = None,
        *,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
    ) -> SiglipForImageClassifierOutput:
        """Encode ``pixel_values`` and classify the pooled embedding.

        Args:
            pixel_values: Image batch passed straight to ``SiglipVisionModel``.
                Presumably (batch, channels, height, width) as produced by the
                SigLIP image processor; confirm against callers.
            labels: Optional targets. When given, a loss is computed according
                to ``config.problem_type``. If that is unset, it is inferred:
                1 label means regression; integer labels mean single-label
                classification; anything else means multi-label
                classification.
            output_attentions: Forwarded to the backbone. None defers to the
                backbone's config.
            output_hidden_states: Forwarded to the backbone. None defers to
                the backbone's config.

        Returns:
            ``SiglipForImageClassifierOutput`` with ``loss`` (None when no
            labels are given), ``logits``, ``pooler_output`` and any hidden
            states or attentions the backbone returned.

        Raises:
            ValueError: if ``labels`` are given while ``num_labels == 0``, or
                if ``config.problem_type`` holds an unsupported value.
        """
        outputs = self.siglip(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooler_output = outputs.pooler_output
        logits = self.classifier(pooler_output)

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        return SiglipForImageClassifierOutput(
            loss=loss,
            logits=logits,
            pooler_output=pooler_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _compute_loss(
        self, logits: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        """Return the loss of ``logits`` against ``labels`` per ``config.problem_type``."""
        if self.num_labels == 0:
            # The head is nn.Identity, so "logits" are raw features and no
            # supervised loss is defined.
            raise ValueError(
                "labels were provided but config.num_labels == 0; the model has "
                "no classification head to compute a loss for"
            )

        # Move labels next to the logits; this matters when the model is
        # split across devices.
        labels = labels.to(logits.device)

        problem_type = self.config.problem_type
        if problem_type is None:
            if self.num_labels == 1:
                problem_type = "regression"
            elif labels.dtype in (torch.long, torch.int):
                problem_type = "single_label_classification"
            else:
                problem_type = "multi_label_classification"
            # Record the inferred type on the config, as other HF heads do.
            self.config.problem_type = problem_type

        if problem_type == "regression":
            targets = labels.to(logits.dtype)
            if self.num_labels == 1:
                return nn.MSELoss()(logits.squeeze(), targets.squeeze())
            return nn.MSELoss()(logits, targets)
        if problem_type == "single_label_classification":
            return nn.CrossEntropyLoss()(
                logits.reshape(-1, self.num_labels), labels.reshape(-1)
            )
        if problem_type == "multi_label_classification":
            # BCE-with-logits needs floating-point targets matching the logits.
            return nn.BCEWithLogitsLoss()(logits, labels.to(logits.dtype))
        raise ValueError(f"Unsupported config.problem_type: {problem_type!r}")
|
|