# skin-type-classifier / modeling_skin_classifier.py
# Uploaded by: anismizi
# Commit message: Initial model upload
# Commit: 2385a75 (verified)
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import ImageClassifierOutput
from typing import Optional
class SkinClassifierConfig(PretrainedConfig):
    """Configuration class for SkinClassifier model.

    Args:
        num_labels: Number of output classes. Forwarded to
            ``PretrainedConfig`` so that an explicitly supplied ``id2label``
            (e.g. one read back from a saved ``config.json``) is preserved;
            when no ``id2label`` is given, default ``LABEL_i`` maps of this
            size are generated.
        image_size: Nominal input image size, stored on the config.
            ``SkinClassifierModel`` does not read it (the ResNet backbone
            itself does not fix an input resolution).
        num_channels: Number of input image channels.
        **kwargs: Remaining ``PretrainedConfig`` keyword arguments.
    """

    model_type = "skin-classifier"

    def __init__(
        self,
        num_labels: int = 2,
        image_size: int = 224,
        num_channels: int = 3,
        **kwargs,
    ):
        # Let the base class derive the label maps: a provided ``id2label``
        # wins, otherwise ``num_labels`` builds the defaults. Assigning
        # ``self.num_labels`` after the fact would overwrite a loaded
        # ``id2label`` whose length differs from the default of 2, because
        # ``num_labels`` is a derived property and is not serialized.
        super().__init__(num_labels=num_labels, **kwargs)
        self.image_size = image_size
        self.num_channels = num_channels
class SkinClassifierModel(PreTrainedModel):
    """
    Skin Type Classification Model based on ResNet50.

    This model classifies skin images into two categories:
    - dry (label 0)
    - oily (label 1)

    The backbone is a torchvision ResNet50 constructed without pretrained
    weights (``weights=None``). Its final ``fc`` layer is replaced by a linear
    head with ``config.num_labels`` outputs. If ``config.num_channels``
    differs from the stem's 3 input channels, the stem convolution is rebuilt
    with the same geometry to accept that many channels.
    """

    config_class = SkinClassifierConfig
    # Image models take ``pixel_values``; the PreTrainedModel default is
    # ``input_ids``, so utilities keyed on the main input (e.g. the Trainer's
    # token/FLOP estimate) would not find this model's input otherwise.
    main_input_name = "pixel_values"

    def __init__(self, config):
        # PreTrainedModel.__init__ stores ``config`` as ``self.config``.
        super().__init__(config)

        # ResNet50 architecture only; no pretrained weights are loaded here.
        self.resnet = resnet50(weights=None)

        # The torchvision stem expects 3-channel input. Honor the config's
        # channel count by rebuilding the stem with identical geometry.
        num_channels = getattr(config, "num_channels", 3)
        stem = self.resnet.conv1
        if num_channels != stem.in_channels:
            self.resnet.conv1 = nn.Conv2d(
                num_channels,
                stem.out_channels,
                kernel_size=stem.kernel_size,
                stride=stem.stride,
                padding=stem.padding,
                bias=stem.bias is not None,
            )

        # Replace the final classification layer with a num_labels-way head.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, config.num_labels)

        # Run PreTrainedModel's post-construction hook (weight initialization).
        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> ImageClassifierOutput:
        """
        Forward pass of the model.

        Args:
            pixel_values: Tensor of shape (batch_size, num_channels, height, width).
            labels: Optional tensor of shape (batch_size,) with class indices;
                when given, a cross-entropy loss is computed.
            **kwargs: Extra keyword arguments are accepted and ignored.

        Returns:
            ImageClassifierOutput with ``logits`` of shape
            (batch_size, num_labels) and ``loss`` (None when ``labels`` is None).
        """
        logits = self.resnet(pixel_values)

        loss = None
        if labels is not None:
            # Put labels on the logits' device (they may arrive on CPU or the
            # model may be placed on a different device).
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels.to(logits.device))

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
        )