"""
Cervical Type Classification Model

This module contains the BaseCNN model for classifying cervical images
into 3 transformation zone types.

Usage:
    from model import BaseCNN

    # Load pretrained model
    model = BaseCNN.from_pretrained("./")

    # Or create from scratch
    model = BaseCNN(
        layers=[32, 64, 128, 256],
        fc_layers=[256, 128],
        nr_classes=3
    )
"""

import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

try:
    from safetensors.torch import load_file, save_file
    HAS_SAFETENSORS = True
except ImportError:
    HAS_SAFETENSORS = False


# Label mappings for the standard 3-class transformation-zone task.
ID2LABEL = {0: 'Type 1', 1: 'Type 2', 2: 'Type 3'}
LABEL2ID = {'Type 1': 0, 'Type 2': 1, 'Type 3': 2}


def _label_maps(nr_classes: int) -> Tuple[Dict[str, str], Dict[str, int]]:
    """Return (id2label, label2id) maps with exactly ``nr_classes`` entries.

    The clinical 'Type N' names are used when the class count matches
    ``ID2LABEL``; otherwise generic ``LABEL_<i>`` names are generated so the
    saved maps always agree with ``num_labels``.
    """
    if nr_classes == len(ID2LABEL):
        names = dict(ID2LABEL)
    else:
        names = {idx: f'LABEL_{idx}' for idx in range(nr_classes)}
    # JSON object keys are strings, so id2label is keyed by str(index).
    id2label = {str(idx): name for idx, name in names.items()}
    label2id = {name: idx for idx, name in names.items()}
    return id2label, label2id


class BaseCNN(nn.Module):
    """
    Simple CNN for cervical type classification.

    Classifies cervical images into 3 transformation zone types:
    - Type 1: Transformation zone fully visible on ectocervix
    - Type 2: Transformation zone partially visible
    - Type 3: Transformation zone not visible (within endocervical canal)

    Args:
        layers: List of output channels for each conv layer.
            ``None`` selects the default [32, 64, 128, 256].
        kernel: Kernel size for conv layers. Default: 3
        padding: Padding for conv layers. Default: 1
        stride: Stride for conv layers. Default: 1
        batchnorm: Whether to use batch normalization. Default: True
        bn_pre_activ: Whether to apply BN before activation. Default: True
        activation: Name of an activation class in ``torch.nn``.
            Default: 'ReLU'
        dropout: Dropout rate for FC layers. Default: 0.4
        pool: Whether to use max pooling after each conv. Default: True
        fc_layers: List of hidden FC layer sizes. ``None`` selects the
            default [256, 128]; an empty list means no hidden FC layers
            (the classifier is applied directly to the pooled features).
        nr_classes: Number of output classes. Default: 3
        in_channels: Number of input channels. Default: 3
    """

    def __init__(
        self,
        layers: Optional[List[int]] = None,
        kernel: int = 3,
        padding: int = 1,
        stride: int = 1,
        batchnorm: bool = True,
        bn_pre_activ: bool = True,
        activation: str = 'ReLU',
        dropout: float = 0.4,
        pool: bool = True,
        fc_layers: Optional[List[int]] = None,
        nr_classes: int = 3,
        in_channels: int = 3,
    ):
        super().__init__()

        # Only None means "use the default": an explicit empty list is a
        # deliberate choice and must not be replaced. Copy into fresh lists
        # so the stored config never aliases the caller's objects.
        layers = [32, 64, 128, 256] if layers is None else list(layers)
        fc_layers = [256, 128] if fc_layers is None else list(fc_layers)

        # Store config for serialization
        self.config = {
            'layers': layers,
            'kernel': kernel,
            'padding': padding,
            'stride': stride,
            'batchnorm': batchnorm,
            'bn_pre_activ': bn_pre_activ,
            'activation': activation,
            'dropout': dropout,
            'pool': pool,
            'fc_layers': fc_layers,
            'nr_classes': nr_classes,
            'in_channels': in_channels,
        }

        # Activation class looked up by name in torch.nn; an unknown name
        # raises AttributeError.
        activation_fn = getattr(nn, activation)

        # Build convolutional layers. The flat ModuleList and the module
        # order inside it define the state-dict keys, so they must not change.
        self.conv_layers = nn.ModuleList()
        prev_channels = in_channels
        for out_channels in layers:
            self.conv_layers.append(
                nn.Conv2d(prev_channels, out_channels, kernel, stride, padding)
            )
            if batchnorm and bn_pre_activ:
                self.conv_layers.append(nn.BatchNorm2d(out_channels))
            self.conv_layers.append(activation_fn())
            if batchnorm and not bn_pre_activ:
                self.conv_layers.append(nn.BatchNorm2d(out_channels))
            if pool:
                self.conv_layers.append(nn.MaxPool2d(2, 2))
            prev_channels = out_channels

        # Global average pooling: one value per channel, any spatial size.
        self.adaptive_pool = nn.AdaptiveAvgPool2d(1)

        # Build fully connected layers. After global pooling the feature
        # width equals the channel count leaving the conv stack.
        self.fc_layers = nn.ModuleList()
        prev_features = prev_channels
        for fc_size in fc_layers:
            self.fc_layers.append(nn.Linear(prev_features, fc_size))
            self.fc_layers.append(activation_fn())
            self.fc_layers.append(nn.Dropout(dropout))
            prev_features = fc_size

        # Final classifier (kept as a separate attribute so its state-dict
        # keys stay 'classifier.*')
        self.classifier = nn.Linear(prev_features, nr_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (batch_size, in_channels, H, W).
                The spatial size is not fixed because the features are
                globally average-pooled before the FC layers.

        Returns:
            Logits tensor of shape (batch_size, nr_classes)
        """
        for layer in self.conv_layers:
            x = layer(x)

        x = self.adaptive_pool(x)
        # (batch, channels, 1, 1) -> (batch, channels)
        x = torch.flatten(x, 1)

        for layer in self.fc_layers:
            x = layer(x)

        return self.classifier(x)

    @classmethod
    def from_pretrained(cls, model_path: str, device: str = 'cpu') -> 'BaseCNN':
        """
        Load a pretrained model from a directory.

        The directory must contain 'config.json' (with a 'model_config'
        entry) and either 'model.safetensors' or 'pytorch_model.bin'.

        Args:
            model_path: Path to directory containing model files
            device: Device to load model on ('cpu' or 'cuda')

        Returns:
            Loaded model in eval mode

        Raises:
            FileNotFoundError: If 'config.json' is missing or no usable
                weights file is found.
        """
        model_path = Path(model_path)

        # Load config
        config_path = model_path / 'config.json'
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Create model
        model = cls(**config['model_config'])

        # Load weights (prefer safetensors)
        safetensors_path = model_path / 'model.safetensors'
        pytorch_path = model_path / 'pytorch_model.bin'

        if safetensors_path.exists() and HAS_SAFETENSORS:
            state_dict = load_file(str(safetensors_path), device=device)
        elif pytorch_path.exists():
            # weights_only=True restricts unpickling to tensor data.
            state_dict = torch.load(
                pytorch_path, map_location=device, weights_only=True
            )
        else:
            message = f"No model weights found in {model_path}"
            if safetensors_path.exists():
                # The file is there but cannot be read without the package.
                message += (
                    " ('model.safetensors' is present but the 'safetensors'"
                    " package is not installed)"
                )
            raise FileNotFoundError(message)

        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        return model

    def save_pretrained(self, save_path: str) -> None:
        """
        Save model in Hugging Face compatible format.

        Writes 'config.json' and 'pytorch_model.bin', plus
        'model.safetensors' when the safetensors package is available.

        Args:
            save_path: Directory to save model files (created if missing)
        """
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        # Save config; label maps are sized to match nr_classes.
        nr_classes = self.config['nr_classes']
        id2label, label2id = _label_maps(nr_classes)
        config = {
            'model_type': 'BaseCNN',
            'model_config': self.config,
            'num_labels': nr_classes,
            'id2label': id2label,
            'label2id': label2id,
        }
        with open(save_path / 'config.json', 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)

        # Save weights; tensors are made contiguous before serialization.
        state_dict = {k: v.contiguous() for k, v in self.state_dict().items()}

        # SafeTensors format (recommended)
        if HAS_SAFETENSORS:
            save_file(state_dict, str(save_path / 'model.safetensors'))

        # PyTorch format (backup)
        torch.save(state_dict, save_path / 'pytorch_model.bin')


if __name__ == '__main__':
    # Quick test
    model = BaseCNN()
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Test forward pass
    x = torch.randn(1, 3, 256, 256)
    y = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {y.shape}")