| | import torch.nn as nn |
| |
|
| | |
| |
|
| | from transformers import PreTrainedModel |
| |
|
| | from .configuration_spice_cnn import SpiceCNNConfig |
| |
|
| |
|
class SpiceCNNModelForImageClassification(PreTrainedModel):
    """Small convolutional image classifier wrapped as a ``PreTrainedModel``.

    The network is three Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d stages
    (16, 32 and 64 output channels) followed by a flatten and a two-layer
    fully connected head that emits ``config.num_classes`` logits.
    """

    config_class = SpiceCNNConfig

    def __init__(self, config: SpiceCNNConfig):
        """Build the layer stack from ``config``.

        Reads ``in_channels``, ``kernel_size``, ``pooling_size`` and
        ``num_classes`` from ``config``.
        """
        super().__init__(config)

        kernel = config.kernel_size
        pool = config.pooling_size
        # Channel width entering the first stage, then leaving each stage.
        widths = (config.in_channels, 16, 32, 64)

        stack = []
        for c_in, c_out in zip(widths, widths[1:]):
            stack += [
                nn.Conv2d(c_in, c_out, kernel_size=kernel, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=pool),
            ]

        # Classifier head. The first Linear expects 64 * 3 * 3 flattened
        # features, i.e. a 3x3 map out of the last pooling stage.
        # NOTE(review): presumably sized for 28x28 inputs with
        # kernel_size=3 and pooling_size=2 -- confirm against the config
        # defaults and the dataset.
        stack += [
            nn.Flatten(),
            nn.Linear(64 * 3 * 3, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, config.num_classes),
        ]

        # Kept as one flat Sequential under ``self.model`` so parameter names
        # stay ``model.<index>.*``.
        self.model = nn.Sequential(*stack)

    def forward(self, tensor, labels=None):
        """Run the network and optionally compute the classification loss.

        Args:
            tensor: Input batch passed straight to the layer stack.
            labels: Optional targets; when given, a cross-entropy loss
                against the logits is computed.

        Returns:
            A dict with ``"logits"``, plus ``"loss"`` when ``labels`` is
            not ``None``.
        """
        logits = self.model(tensor)
        if labels is None:
            return {"logits": logits}

        loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}
| |
|
| |
|
| | |
| | |
| | |
| |
|