| """ |
| Cervical Type Classification Model |
| |
| This module contains the BaseCNN model for classifying cervical images |
| into 3 transformation zone types. |
| |
| Usage: |
| from model import BaseCNN |
| |
| # Load pretrained model |
| model = BaseCNN.from_pretrained("./") |
| |
| # Or create from scratch |
| model = BaseCNN( |
| layers=[32, 64, 128, 256], |
| fc_layers=[256, 128], |
| nr_classes=3 |
| ) |
| """ |
|
|
| import json |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
|
|
try:
    # safetensors is an optional dependency: when present, weights are
    # saved/loaded in the .safetensors format; otherwise the code falls
    # back to plain pickled torch weights (pytorch_model.bin).
    from safetensors.torch import load_file, save_file
    HAS_SAFETENSORS = True
except ImportError:
    HAS_SAFETENSORS = False
|
|
|
|
class BaseCNN(nn.Module):
    """
    Simple CNN for cervical type classification.

    Classifies cervical images into transformation zone types (3 by default):
    - Type 1: Transformation zone fully visible on ectocervix
    - Type 2: Transformation zone partially visible
    - Type 3: Transformation zone not visible (within endocervical canal)

    Architecture: a stack of Conv2d -> (BatchNorm) -> activation -> (MaxPool)
    blocks, global average pooling, then a fully connected head with dropout
    and a final linear classifier.

    Args:
        layers: List of output channels for each conv layer. Default: [32, 64, 128, 256]
        kernel: Kernel size for conv layers. Default: 3
        padding: Padding for conv layers. Default: 1
        stride: Stride for conv layers. Default: 1
        batchnorm: Whether to use batch normalization. Default: True
        bn_pre_activ: Whether to apply BN before activation. Default: True
        activation: Activation function name looked up on torch.nn (e.g. 'ReLU').
            Default: 'ReLU'
        dropout: Dropout rate for FC layers. Default: 0.4
        pool: Whether to use max pooling after each conv. Default: True
        fc_layers: List of FC layer sizes. Default: [256, 128]
        nr_classes: Number of output classes. Default: 3
        in_channels: Number of input channels. Default: 3
    """

    def __init__(
        self,
        layers: list = None,
        kernel: int = 3,
        padding: int = 1,
        stride: int = 1,
        batchnorm: bool = True,
        bn_pre_activ: bool = True,
        activation: str = 'ReLU',
        dropout: float = 0.4,
        pool: bool = True,
        fc_layers: list = None,
        nr_classes: int = 3,
        in_channels: int = 3,
    ):
        super().__init__()

        # Record the full architecture config so that save_pretrained() /
        # from_pretrained() can round-trip the model. `or` also maps an
        # empty list to the default, matching the original behavior.
        self.config = {
            'layers': layers or [32, 64, 128, 256],
            'kernel': kernel,
            'padding': padding,
            'stride': stride,
            'batchnorm': batchnorm,
            'bn_pre_activ': bn_pre_activ,
            'activation': activation,
            'dropout': dropout,
            'pool': pool,
            'fc_layers': fc_layers or [256, 128],
            'nr_classes': nr_classes,
            'in_channels': in_channels,
        }

        layers = self.config['layers']
        fc_layers = self.config['fc_layers']

        # Resolve the activation class by name, e.g. 'ReLU' -> nn.ReLU.
        # Raises AttributeError for names torch.nn does not define.
        activation_fn = getattr(nn, activation)

        # Convolutional feature extractor:
        # Conv2d -> (BN if pre-activation) -> activation -> (BN if post) -> (MaxPool)
        self.conv_layers = nn.ModuleList()
        prev_channels = in_channels

        for out_channels in layers:
            self.conv_layers.append(
                nn.Conv2d(prev_channels, out_channels, kernel, stride, padding)
            )
            if batchnorm and bn_pre_activ:
                self.conv_layers.append(nn.BatchNorm2d(out_channels))
            self.conv_layers.append(activation_fn())
            if batchnorm and not bn_pre_activ:
                self.conv_layers.append(nn.BatchNorm2d(out_channels))
            if pool:
                self.conv_layers.append(nn.MaxPool2d(2, 2))
            prev_channels = out_channels

        # Global average pooling makes the FC head independent of the
        # spatial size of the input image.
        self.adaptive_pool = nn.AdaptiveAvgPool2d(1)

        # Fully connected head: Linear -> activation -> Dropout per layer.
        self.fc_layers = nn.ModuleList()
        prev_features = layers[-1]

        for fc_size in fc_layers:
            self.fc_layers.append(nn.Linear(prev_features, fc_size))
            self.fc_layers.append(activation_fn())
            self.fc_layers.append(nn.Dropout(dropout))
            prev_features = fc_size

        # Final classifier producing raw logits (no softmax).
        self.classifier = nn.Linear(prev_features, nr_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (batch_size, in_channels, H, W);
               typically (batch_size, 3, 256, 256).

        Returns:
            Logits tensor of shape (batch_size, nr_classes)
        """
        for layer in self.conv_layers:
            x = layer(x)

        x = self.adaptive_pool(x)
        x = x.view(x.size(0), -1)  # flatten to (batch_size, channels)

        for layer in self.fc_layers:
            x = layer(x)

        x = self.classifier(x)
        return x

    def _label_maps(self) -> tuple:
        """
        Derive (id2label, label2id) from the configured number of classes.

        Labels follow the 'Type N' naming convention used for cervical
        transformation zones; keys of id2label are strings to match the
        Hugging Face config.json convention.

        Returns:
            Tuple of (id2label: dict[str, str], label2id: dict[str, int])
        """
        nr_classes = self.config['nr_classes']
        id2label = {str(i): f'Type {i + 1}' for i in range(nr_classes)}
        label2id = {label: int(idx) for idx, label in id2label.items()}
        return id2label, label2id

    @classmethod
    def from_pretrained(cls, model_path: str, device: str = 'cpu') -> 'BaseCNN':
        """
        Load a pretrained model from a directory.

        Prefers model.safetensors (when the safetensors package is
        installed), falling back to pytorch_model.bin.

        Args:
            model_path: Path to directory containing config.json and weights
            device: Device to load model on ('cpu' or 'cuda')

        Returns:
            Loaded model in eval mode, moved to `device`

        Raises:
            FileNotFoundError: If no supported weight file exists in the directory
        """
        model_path = Path(model_path)

        # Rebuild the architecture from the saved config.
        config_path = model_path / 'config.json'
        with open(config_path, 'r') as f:
            config = json.load(f)

        model = cls(**config['model_config'])

        safetensors_path = model_path / 'model.safetensors'
        pytorch_path = model_path / 'pytorch_model.bin'

        if safetensors_path.exists() and HAS_SAFETENSORS:
            state_dict = load_file(str(safetensors_path), device=device)
        elif pytorch_path.exists():
            # weights_only=True avoids arbitrary-code execution via pickle.
            state_dict = torch.load(pytorch_path, map_location=device, weights_only=True)
        else:
            raise FileNotFoundError(f"No model weights found in {model_path}")

        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        return model

    def save_pretrained(self, save_path: str) -> None:
        """
        Save model in Hugging Face compatible format.

        Writes config.json plus weights — model.safetensors when the
        safetensors package is available, and always pytorch_model.bin
        as a universally loadable fallback.

        Args:
            save_path: Directory to save model files (created if missing)
        """
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        # Derive label maps from nr_classes rather than hardcoding three
        # entries, so config.json stays consistent for any class count.
        id2label, label2id = self._label_maps()

        config = {
            'model_type': 'BaseCNN',
            'model_config': self.config,
            'num_labels': self.config['nr_classes'],
            'id2label': id2label,
            'label2id': label2id,
        }
        with open(save_path / 'config.json', 'w') as f:
            json.dump(config, f, indent=2)

        # safetensors requires contiguous tensors.
        state_dict = {k: v.contiguous() for k, v in self.state_dict().items()}

        if HAS_SAFETENSORS:
            save_file(state_dict, str(save_path / 'model.safetensors'))

        torch.save(state_dict, save_path / 'pytorch_model.bin')
|
|
|
|
| |
# Label mappings for the default 3-class transformation-zone setup.
ID2LABEL = {idx: f'Type {idx + 1}' for idx in range(3)}
LABEL2ID = {label: idx for idx, label in ID2LABEL.items()}
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build the default model and run a single forward pass.
    model = BaseCNN()
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {param_count:,}")

    x = torch.randn(1, 3, 256, 256)
    y = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {y.shape}")
|
|