| """ | |
| @author: bvk1ng (Adityam Ghosh) | |
| Date: 12/28/2023 | |
| """ | |
from typing import Any, List, Tuple, Dict, Union, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNNModel(nn.Module):
    """Convolutional classifier head built from lazily-initialized layers.

    Architecture: a stack of LazyConv2d + ReLU pairs, a Flatten, a stack of
    LazyLinear + ReLU pairs, and a final LazyLinear producing K outputs.
    Lazy modules infer their input sizes on the first forward pass, so only
    output sizes need to be specified here.

    Args:
        K: number of output units of the final linear layer (e.g. classes).
        cnn_params: iterable of (out_channels, kernel_size, stride) tuples,
            one per convolutional layer.
        fully_connected_params: iterable of hidden-layer widths, one per
            fully connected layer before the final head.
    """

    def __init__(self, K: int, cnn_params: List, fully_connected_params: List):
        super().__init__()
        layers = nn.Sequential()

        # Convolutional trunk: conv -> ReLU per configured tuple.
        for i, (channels, kernel, stride) in enumerate(cnn_params):
            conv = nn.LazyConv2d(
                out_channels=channels,
                kernel_size=kernel,
                stride=stride,
            )
            layers.add_module(f"conv2d_{i}", conv)
            layers.add_module(f"activation_{i}", nn.ReLU())

        # Collapse spatial dimensions before the fully connected stack.
        layers.add_module("flatten", nn.Flatten())

        # Fully connected stack: linear -> ReLU per configured width.
        for i, width in enumerate(fully_connected_params):
            layers.add_module(f"fc_{i}", nn.LazyLinear(out_features=width))
            layers.add_module(f"fc_activation_{i}", nn.ReLU())

        # Final head: K raw (un-activated) outputs.
        layers.add_module("final_layer", nn.LazyLinear(out_features=K))
        self.network = layers

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Run a batch of images through the network and return the K logits."""
        return self.network(X)