import torch
import torch.nn as nn


class SimpleCNN(nn.Module):
    """Small configurable CNN image classifier.

    Architecture: ``num_conv_blocks`` blocks of
    ``Conv2d -> [BatchNorm2d] -> ReLU -> MaxPool2d(2, 2)``, then global
    average pooling, then a two-layer MLP head
    (``Dropout -> Linear -> ReLU -> Dropout -> Linear``).

    Args:
        num_classes: Number of output logits.
        num_conv_blocks: Number of convolutional blocks. Each block ends with
            ``MaxPool2d(2, 2)``, which halves height and width (floor).
        base_filters: Output channels of the first block; doubled at every
            subsequent block and capped at ``max_filters``.
        kernel_size: Convolution kernel size. Padding is ``kernel_size // 2``,
            which preserves spatial size only for odd kernel sizes.
        use_batchnorm: If True, insert ``BatchNorm2d`` after each convolution.
        dropout: Dropout probability applied before each linear layer.
        fc_dim: Width of the hidden layer of the classifier head.
        in_channels: Number of channels of the input images (keyword-only;
            defaults to 3, the value previously hard-coded).
        max_filters: Upper bound on the channel count of any block
            (keyword-only; defaults to 512, the value previously hard-coded).
    """

    def __init__(
        self,
        num_classes: int,
        num_conv_blocks: int = 3,
        base_filters: int = 32,
        kernel_size: int = 3,
        use_batchnorm: bool = True,
        dropout: float = 0.4,
        fc_dim: int = 256,
        *,
        in_channels: int = 3,
        max_filters: int = 512,
    ) -> None:
        super().__init__()
        padding = kernel_size // 2

        layers = []
        channels = in_channels
        for i in range(num_conv_blocks):
            # Filters double at each block, capped at max_filters.
            out_channels = min(base_filters * (2 ** i), max_filters)
            layers.append(
                nn.Conv2d(channels, out_channels, kernel_size, padding=padding)
            )
            if use_batchnorm:
                layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.MaxPool2d(2, 2))
            channels = out_channels
        self.features = nn.Sequential(*layers)

        # Global pooling: the head's input size depends only on the channel
        # count, not on the spatial size of the input images.
        self.pool = nn.AdaptiveAvgPool2d(1)

        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(channels, fc_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(fc_dim, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (unnormalized) class logits for a batch of images.

        Expects a batched 4-D tensor ``(N, in_channels, H, W)``; ``flatten(1)``
        relies on the leading batch dimension. With an odd ``kernel_size``,
        H and W must each be at least ``2 ** num_conv_blocks`` so that every
        ``MaxPool2d`` still has a non-empty output. Returns ``(N, num_classes)``.
        """
        x = self.features(x)
        x = self.pool(x)      # (N, C, 1, 1)
        x = x.flatten(1)      # (N, C)
        return self.classifier(x)