simple-cnn / modeling_simplecnn.py
CezarCalin's picture
Upload SimpleCNN
09ae8a7 verified
raw
history blame contribute delete
843 Bytes
import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_simplecnn import SimpleCNNConfig
class SimpleCNN(PreTrainedModel):
    """Small two-stage convolutional classifier wrapped as a ``PreTrainedModel``.

    Architecture::

        Conv2d(input_channels -> 16, 3x3, padding=1) -> ReLU -> MaxPool2d(2)
        Conv2d(16 -> 32, 3x3, padding=1)             -> ReLU -> MaxPool2d(2)
        Flatten -> Linear(32 * H' * W' -> 64) -> ReLU -> Linear(64 -> num_classes)

    where ``H' = H // 4`` and ``W' = W // 4`` for an ``H x W`` input.

    Config attributes read:
        input_channels: number of channels of the input images.
        num_classes: number of output logits.
        image_size (optional): spatial size of the input images, either an
            int (square images) or an ``(height, width)`` pair. When absent it
            defaults to 28, which yields the original ``32 * 7 * 7`` flatten
            size, so existing configs/checkpoints keep working unchanged.
    """

    config_class = SimpleCNNConfig

    def __init__(self, config):
        super().__init__(config)
        self.conv_layers = nn.Sequential(
            nn.Conv2d(config.input_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # The first Linear must match the flattened conv output, which depends
        # on the input resolution; previously this was hard-coded to 7 x 7
        # (i.e. only 28..31 px inputs worked).
        feat_h, feat_w = self._feature_map_size(getattr(config, "image_size", 28))
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            # 32 = out_channels of the last Conv2d above.
            nn.Linear(32 * feat_h * feat_w, 64),
            nn.ReLU(),
            nn.Linear(64, config.num_classes),
        )

    @staticmethod
    def _feature_map_size(image_size):
        """Return ``(height, width)`` of the feature map produced by ``conv_layers``.

        Each 3x3 convolution with ``padding=1`` (stride 1) preserves the spatial
        size, and each ``MaxPool2d(2)`` floor-halves it, so two stages give
        ``size // 4`` per side.

        Args:
            image_size: int for square inputs, or a 2-item ``(height, width)``
                sequence (JSON-loaded configs may provide a list).

        Raises:
            ValueError: if either side becomes zero after pooling.
        """
        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size
        feat_h, feat_w = height // 4, width // 4
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"image_size {image_size!r} is too small: two MaxPool2d(2) "
                "stages need at least 4 pixels per side"
            )
        return feat_h, feat_w

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: image tensor fed to ``nn.Conv2d``; presumably shaped
                ``(batch, input_channels, H, W)`` with ``H, W`` matching
                ``config.image_size`` — TODO confirm against callers.

        Returns:
            Output of the final ``nn.Linear``; for a batched 4-D input this is
            ``(batch, num_classes)`` logits (no softmax applied).
        """
        features = self.conv_layers(x)
        return self.fc_layers(features)