"""Image classification models: a partially fine-tuned ResNet-18 and a small CNN."""

import torch
import torch.nn as nn
from torchvision import models


class ResNet18Classifier(nn.Module):
    """ResNet-18 backbone with a dropout/MLP classification head.

    The original ``backbone.fc`` is replaced by
    Dropout -> Linear(in_features, fc_dim) -> ReLU -> Dropout -> Linear(fc_dim, num_classes).
    All backbone parameters are frozen except those of ``layer4`` and of the
    new head, which remain trainable.

    Args:
        num_classes: Number of output logits.
        dropout: Dropout probability applied before each linear layer of the head.
        fc_dim: Width of the hidden layer of the head.
        pretrained: If True (default, original behavior), load
            ``models.ResNet18_Weights.DEFAULT``, which may download weights on
            first use. If False, the backbone is randomly initialised
            (useful offline or in tests).
    """

    def __init__(
        self,
        num_classes: int,
        dropout: float = 0.4,
        fc_dim: int = 256,
        pretrained: bool = True,
    ):
        super().__init__()
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        self.backbone = models.resnet18(weights=weights)
        in_features = self.backbone.fc.in_features

        self.backbone.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features, fc_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fc_dim, num_classes),
        )

        # Freeze the whole network except layer4 and the classification head.
        # NOTE: requires_grad=False only stops gradient updates; BatchNorm
        # layers in the frozen part still update their running statistics
        # whenever the module is in train() mode.
        trainable_prefixes = ("layer4.", "fc.")
        for name, param in self.backbone.named_parameters():
            param.requires_grad = name.startswith(trainable_prefixes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images.

        Assumes ``x`` is (N, 3, H, W), as expected by torchvision's ResNet-18
        stem — TODO confirm against the data pipeline.
        """
        return self.backbone(x)


class SimpleCNN(nn.Module):
    """Small CNN: stacked conv blocks, global average pooling, MLP head.

    Each conv block is Conv2d -> (optional BatchNorm2d) -> ReLU -> MaxPool2d(2, 2),
    so every block halves (with flooring) the spatial resolution. Block ``i``
    has ``min(base_filters * 2**i, max_filters)`` output channels.

    Args:
        num_classes: Number of output logits.
        num_conv_blocks: Number of conv blocks; must be >= 0. With odd kernels,
            H and W should be at least ``2**num_conv_blocks`` so the feature map
            does not shrink to zero before pooling.
        base_filters: Output channels of the first block.
        kernel_size: Convolution kernel size. Padding is ``kernel_size // 2``,
            which preserves spatial size only for odd kernel sizes.
        use_batchnorm: Insert a BatchNorm2d after each convolution.
        dropout: Dropout probability applied before each linear layer of the head.
        fc_dim: Width of the hidden layer of the head.
        in_channels: Number of input image channels (default 3).
        max_filters: Upper bound on the channel count of any block (default 512).

    Raises:
        ValueError: If ``num_conv_blocks`` is negative.
    """

    def __init__(
        self,
        num_classes: int,
        num_conv_blocks: int = 3,
        base_filters: int = 32,
        kernel_size: int = 3,
        use_batchnorm: bool = True,
        dropout: float = 0.4,
        fc_dim: int = 256,
        in_channels: int = 3,
        max_filters: int = 512,
    ):
        super().__init__()
        if num_conv_blocks < 0:
            raise ValueError(f"num_conv_blocks must be >= 0, got {num_conv_blocks}")

        padding = kernel_size // 2
        layers = []
        channels = in_channels
        for i in range(num_conv_blocks):
            # Filters double at each block, capped at max_filters.
            out_channels = min(base_filters * (2 ** i), max_filters)
            layers.append(nn.Conv2d(channels, out_channels, kernel_size, padding=padding))
            if use_batchnorm:
                layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.MaxPool2d(2, 2))
            channels = out_channels
        self.features = nn.Sequential(*layers)

        # Global pooling: makes the head independent of the input spatial size.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(channels, fc_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(fc_dim, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for ``x``, assumed (N, in_channels, H, W)."""
        x = self.features(x)
        x = self.pool(x)    # (N, C, 1, 1)
        x = x.flatten(1)    # (N, C)
        return self.classifier(x)