| from typing import Dict, List, Optional, Tuple | |
| import timm | |
| import torch | |
| from pytorch_lightning import LightningModule | |
class TemplateClassifier(LightningModule):
    """LightningModule that wraps a ``timm`` backbone for classification.

    The backbone is created from the ``config`` mapping passed to the
    constructor; ``forward`` applies that backbone to its input.
    """

    def __init__(self, config: dict):
        """Build the backbone network described by ``config``.

        Args:
            config: Model settings. Keys read by this constructor:

                * ``"model_name"`` (required): architecture name handed to
                  ``timm.create_model``.
                * ``"pretrained"`` (optional, default ``False``): whether
                  timm should load pretrained weights.
                * ``"num_classes"`` (optional): forwarded to timm only when
                  present and not ``None``; otherwise timm's own default
                  for the chosen architecture applies.

                NOTE(review): the original template left the model
                arguments unspecified, so these key names are a choice made
                here -- confirm they match the project's config schema.

        Raises:
            KeyError: If ``config`` has no ``"model_name"`` entry.
        """
        super().__init__()

        # Fail early with an explicit message instead of letting
        # ``timm.create_model`` raise an opaque error about its arguments.
        if "model_name" not in config:
            raise KeyError(
                "config must provide 'model_name' (a timm architecture name)"
            )

        # NN architecture
        backbone_kwargs = {"pretrained": bool(config.get("pretrained", False))}
        num_classes = config.get("num_classes")
        if num_classes is not None:
            backbone_kwargs["num_classes"] = num_classes

        self.backbone = timm.create_model(
            config["model_name"],
            **backbone_kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone and return its output.

        Args:
            x: Input tensor passed unchanged to the backbone
                (presumably a batch of images -- TODO confirm expected shape).

        Returns:
            The tensor produced by ``self.backbone(x)``.
        """
        return self.backbone(x)