| |
| import logging |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| from mmengine.runner import load_checkpoint |
|
|
|
|
class AlexNet(nn.Module):
    """AlexNet backbone.

    Args:
        num_classes (int): Number of classes for classification. When
            ``num_classes <= 0`` (the default is ``-1``), no classifier head
            is built and :meth:`forward` returns the convolutional feature
            map instead of class scores.
    """

    def __init__(self, num_classes: int = -1):
        super().__init__()
        self.num_classes = num_classes
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # The classifier head is only built when a positive class count is
        # requested; otherwise this module acts as a pure feature extractor.
        if self.num_classes > 0:
            self.classifier = nn.Sequential(
                nn.Dropout(),
                nn.Linear(256 * 6 * 6, 4096),
                nn.ReLU(inplace=True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(inplace=True),
                nn.Linear(4096, num_classes),
            )

    def init_weights(self, pretrained: Optional[str] = None) -> None:
        """Optionally load pretrained weights into this module.

        Args:
            pretrained (str, optional): Checkpoint location handed to
                mmengine's ``load_checkpoint``. The checkpoint is loaded with
                ``strict=False``, so missing or unexpected keys are tolerated.
                ``None`` leaves the default PyTorch initialization untouched.

        Raises:
            TypeError: If ``pretrained`` is neither a ``str`` nor ``None``.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            # Keep the layers' default initializers.
            pass
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone, followed by the classifier head if one was built.

        Args:
            x (torch.Tensor): Input batch of shape ``(N, 3, H, W)``.

        Returns:
            torch.Tensor: Unnormalized class scores of shape
            ``(N, num_classes)`` when a classifier head exists; otherwise the
            ``(N, 256, H', W')`` feature map (``H' = W' = 6`` for a 224x224
            input).
        """
        x = self.features(x)
        if self.num_classes > 0:
            # The first Linear layer expects a 256x6x6 feature map, which is
            # what a 224x224 input yields. Adaptive pooling is an exact
            # identity in that case and lets other resolutions work as well.
            x = nn.functional.adaptive_avg_pool2d(x, (6, 6))
            # torch.flatten (unlike .view) also copes with non-contiguous
            # tensors such as channels_last feature maps.
            x = torch.flatten(x, 1)
            x = self.classifier(x)
        return x
|
|