from typing import Optional

import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import ImageClassifierOutput


class SkinClassifierConfig(PretrainedConfig):
    """Configuration class for SkinClassifier model.

    Args:
        num_labels: Number of output classes. When ``None`` (the default) the
            count is taken from an ``id2label`` mapping passed through
            ``kwargs`` (as happens when a saved config is reloaded); if no
            mapping is given either, it falls back to 2 (dry / oily).
        image_size: Expected input image side length. Only stored on the
            config; the model below does not read it.
        num_channels: Number of channels in the input images.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "skin-classifier"

    def __init__(
        self,
        num_labels: Optional[int] = None,
        image_size: int = 224,
        num_channels: int = 3,
        **kwargs,
    ):
        # Remember whether a label mapping was supplied before the base class
        # consumes it; a saved config stores id2label but not num_labels.
        has_label_map = kwargs.get("id2label") is not None
        super().__init__(**kwargs)
        if num_labels is not None:
            # An explicit argument wins; the base-class setter regenerates
            # id2label/label2id when their size disagrees with it.
            self.num_labels = num_labels
        elif not has_label_map:
            # Nothing to derive the count from: default to binary dry/oily.
            self.num_labels = 2
        # Otherwise keep the count implied by the supplied id2label instead of
        # silently truncating/extending it to a hard-coded default.
        self.image_size = image_size
        self.num_channels = num_channels


class SkinClassifierModel(PreTrainedModel):
    """
    Skin Type Classification Model based on ResNet50.

    This model classifies skin images into two categories:
    - dry (label 0)
    - oily (label 1)

    The backbone is torchvision's ``resnet50`` built with ``weights=None``
    (random initialization). Its final ``fc`` layer is replaced by a linear
    head with ``config.num_labels`` outputs, and its first convolution is
    rebuilt when ``config.num_channels`` differs from the stem's input
    channel count.
    """

    config_class = SkinClassifierConfig
    # Name of the primary forward() input used by Hugging Face utilities;
    # the PreTrainedModel default ("input_ids") is wrong for an image model.
    main_input_name = "pixel_values"

    def __init__(self, config: SkinClassifierConfig):
        # The base constructor stores ``config`` on ``self.config``.
        super().__init__(config)

        # Initialize ResNet50 backbone (no pretrained torchvision weights;
        # presumably trained weights are loaded via from_pretrained).
        self.resnet = resnet50(weights=None)

        # Honour config.num_channels: rebuild the stem convolution with the
        # same hyper-parameters but a different number of input channels.
        stem = self.resnet.conv1
        if config.num_channels != stem.in_channels:
            self.resnet.conv1 = nn.Conv2d(
                config.num_channels,
                stem.out_channels,
                kernel_size=stem.kernel_size,
                stride=stem.stride,
                padding=stem.padding,
                bias=stem.bias is not None,
            )

        # Replace the final classification layer
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, config.num_labels)

        # Initialize weights / final processing via the PreTrainedModel hook
        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> ImageClassifierOutput:
        """
        Forward pass of the model.

        Args:
            pixel_values: Tensor of shape (batch_size, num_channels, height, width)
            labels: Optional targets for ``nn.CrossEntropyLoss``, typically class
                indices of shape (batch_size,). When given, the loss is computed.
            **kwargs: Accepted for compatibility with generic callers; ignored.

        Returns:
            ImageClassifierOutput with ``logits`` of shape
            (batch_size, num_labels) and ``loss`` (``None`` without labels).
        """
        # Forward pass through ResNet
        logits = self.resnet(pixel_values)

        loss = None
        if labels is not None:
            # Move targets next to the logits in case the caller left them on
            # another device (e.g. CPU labels with GPU logits).
            labels = labels.to(logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
        )