|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision.models import resnet50, ResNet50_Weights |
|
|
from transformers import PreTrainedModel, PretrainedConfig |
|
|
from transformers.modeling_outputs import ImageClassifierOutput |
|
|
from typing import Optional |
|
|
|
|
|
|
|
|
class SkinClassifierConfig(PretrainedConfig):
    """Configuration class for the SkinClassifier model.

    Args:
        num_labels: Number of output classes. When an ``id2label`` mapping is
            also supplied (which is what happens when a saved config is
            reloaded), ``PretrainedConfig`` keeps that mapping and derives the
            label count from it, logging a warning on a mismatch.
        image_size: Expected input image height/width in pixels. Stored on
            the config; inputs are not resized based on it here.
        num_channels: Number of channels in the input images.
        **kwargs: Forwarded to ``transformers.PretrainedConfig`` (e.g.
            ``id2label``, ``label2id``).
    """

    model_type = "skin-classifier"

    def __init__(
        self,
        num_labels: int = 2,
        image_size: int = 224,
        num_channels: int = 3,
        **kwargs,
    ):
        # Hand num_labels to PretrainedConfig so it is reconciled with any
        # id2label in kwargs. Assigning self.num_labels after super().__init__
        # would go through the num_labels property setter, which replaces an
        # id2label of a different length with generic LABEL_i names. Because
        # num_labels itself is not serialized, a saved config with e.g. 3
        # labels would otherwise silently reload as a 2-label config.
        super().__init__(num_labels=num_labels, **kwargs)
        self.image_size = image_size
        self.num_channels = num_channels
|
|
|
|
|
|
|
|
class SkinClassifierModel(PreTrainedModel):
    """Skin-type image classifier built on a torchvision ResNet-50 backbone.

    The backbone is created without ImageNet weights, and its final fully
    connected layer is replaced by a ``config.num_labels``-way linear head.
    The default two-class setup is:

    - dry (label 0)
    - oily (label 1)

    For a particular checkpoint, ``config.id2label`` gives the label names.
    """

    config_class = SkinClassifierConfig
    # transformers utilities (e.g. Trainer's input-size estimation) look up
    # the model input by this name; the PreTrainedModel default is "input_ids".
    main_input_name = "pixel_values"

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # weights=None: randomly initialised backbone, no ImageNet download.
        self.resnet = resnet50(weights=None)

        # torchvision's ResNet-50 stem expects 3 input channels; rebuild it
        # with identical geometry when the config declares a different count,
        # so the model accepts the inputs its own config describes.
        num_channels = getattr(config, "num_channels", 3)
        stem = self.resnet.conv1
        if num_channels != stem.in_channels:
            self.resnet.conv1 = nn.Conv2d(
                num_channels,
                stem.out_channels,
                kernel_size=stem.kernel_size,
                stride=stem.stride,
                padding=stem.padding,
                bias=stem.bias is not None,
            )

        # Replace the 1000-way ImageNet head with a num_labels-way head.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, config.num_labels)

        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> ImageClassifierOutput:
        """Run the classifier and optionally compute the cross-entropy loss.

        Args:
            pixel_values: Image batch of shape
                ``(batch_size, num_channels, height, width)``.
            labels: Optional integer class indices of shape ``(batch_size,)``.
                A trailing singleton dimension ``(batch_size, 1)`` is also
                accepted. Floating-point class-probability targets of shape
                ``(batch_size, num_labels)`` are passed to the loss unchanged.
            **kwargs: Extra keyword arguments (e.g. ``return_dict``) are
                accepted for caller compatibility and ignored.

        Returns:
            ``ImageClassifierOutput`` with ``logits`` of shape
            ``(batch_size, num_labels)``. ``loss`` is set only when
            ``labels`` is given.
        """
        logits = self.resnet(pixel_values)

        loss = None
        if labels is not None:
            # Labels may live on a different device than the logits (e.g.
            # when the model is dispatched across devices).
            labels = labels.to(logits.device)
            if not labels.is_floating_point():
                # Class-index targets: accept (batch,) as well as (batch, 1).
                labels = labels.reshape(-1)
            loss = nn.CrossEntropyLoss()(logits, labels)

        return ImageClassifierOutput(loss=loss, logits=logits)
|
|
|