# Spaces status (captured alongside this file): Runtime error
from typing import Callable, Optional

import lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from loguru import logger
from timm import create_model
from torch.optim.lr_scheduler import StepLR
from torchmetrics.classification import Accuracy
# Module import side effect: register a loguru file sink for messages at INFO
# level and above, rotating "logs/model.log" once it reaches 1 MB.
logger.add(
    "logs/model.log",
    level="INFO",
    rotation="1 MB",
)
class LitEfficientNet(pl.LightningModule):
    """Image classifier built from a TIMM backbone, trained with Lightning.

    Each stage (train / val / test) owns its own TorchMetrics ``Accuracy``
    instance, so the accumulated metric state of one stage never leaks into
    another.
    """

    def __init__(
        self,
        model_name: str = "tf_efficientnet_lite0",
        num_classes: int = 10,
        lr: float = 1e-3,
        custom_loss: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        pretrained: bool = True,
        in_chans: int = 1,
    ) -> None:
        """
        Initializes a CNN model from TIMM and integrates TorchMetrics.

        Args:
            model_name (str): TIMM model name (e.g., "tf_efficientnet_lite0").
            num_classes (int): Number of output classes (e.g., 10 for MNIST digits 0–9).
            lr (float): Learning rate for the Adam optimizer.
            custom_loss (callable, optional): Loss called as ``loss(logits, targets)``.
                Defaults to ``nn.CrossEntropyLoss()`` when ``None``.
            pretrained (bool): Whether TIMM loads pretrained weights. Defaults to True.
            in_chans (int): Number of input image channels. Defaults to 1 (grayscale).
        """
        super().__init__()
        # Store the constructor arguments in checkpoints so `load_from_checkpoint`
        # can rebuild the same architecture (e.g. a non-default num_classes).
        # The loss is excluded because it may be an arbitrary, possibly
        # unpicklable callable.
        self.save_hyperparameters(ignore=["custom_loss"])
        self.lr = lr
        self.model = create_model(
            model_name,
            pretrained=pretrained,
            num_classes=num_classes,
            in_chans=in_chans,  # 1 channel by default, for grayscale input
        )
        # Compare against None explicitly. A callable object that happens to be
        # falsy (e.g. an empty nn.Sequential, which defines __len__) must not be
        # silently replaced by the default loss.
        self.loss_fn = custom_loss if custom_loss is not None else nn.CrossEntropyLoss()
        self.train_acc = Accuracy(num_classes=num_classes, task="multiclass")
        self.val_acc = Accuracy(num_classes=num_classes, task="multiclass")
        self.test_acc = Accuracy(num_classes=num_classes, task="multiclass")
        logger.info(f"Model initialized with TIMM backbone: {model_name}")
        logger.info(f"Number of output classes: {num_classes}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Model predictions (raw outputs of the TIMM model).
        """
        return self.model(x)

    def _shared_step(self, batch, metric: Accuracy) -> torch.Tensor:
        """Run the model on one batch, update ``metric``, and return the loss.

        ``batch`` is unpacked as ``(inputs, targets)``.
        """
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        metric.update(y_hat, y)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and accuracy; return the loss."""
        loss = self._shared_step(batch, self.train_acc)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        # Passing the Metric object (not a tensor) lets Lightning handle the
        # metric's compute/reset cycle.
        self.log("train_acc", self.train_acc, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and accuracy."""
        loss = self._shared_step(batch, self.val_acc)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log("val_acc", self.val_acc, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss and accuracy (mirrors validation_step)."""
        loss = self._shared_step(batch, self.test_acc)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        self.log("test_acc", self.test_acc, prog_bar=True, logger=True)

    def configure_optimizers(self):
        """Build Adam over all parameters plus a StepLR schedule.

        StepLR multiplies the learning rate by 0.9 on every scheduler step
        (step_size=1).

        Returns:
            tuple[list, list]: ``[optimizer], [scheduler]``, in the list form
            Lightning accepts.
        """
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        scheduler = StepLR(optimizer, step_size=1, gamma=0.9)
        logger.info(f"Optimizer: Adam, Learning Rate: {self.lr}")
        logger.info("Scheduler: StepLR with step_size=1 and gamma=0.9")
        return [optimizer], [scheduler]