"""
Cervical Cancer Classification Model

Custom CNN model for classifying cervical images into 4 severity classes.
"""

from pathlib import Path

import torch
import torch.nn as nn


class CervicalCancerCNN(nn.Module):
    """
    CNN for cervical cancer classification.

    Classifies cervical images into 4 classes by default:
    - 0: Normal
    - 1: LSIL (Low-grade Squamous Intraepithelial Lesion)
    - 2: HSIL (High-grade Squamous Intraepithelial Lesion)
    - 3: Cancer

    Architecture: a stack of [Conv3x3 -> BatchNorm -> ReLU -> MaxPool2x2]
    blocks, global average pooling, a stack of [Linear -> ReLU -> Dropout]
    blocks, and a final linear layer producing class logits.

    Args:
        config: Optional configuration dict. Keys that are given override the
            defaults; missing keys fall back to them:
            - conv_layers: List of conv channel sizes (default: [32, 64, 128, 256])
            - fc_layers: List of FC layer sizes (default: [256, 128])
            - num_classes: Number of output classes (default: 4)
            - dropout: Dropout rate (default: 0.5)
            - input_channels: Number of input image channels (default: 3)
            - id2label: Optional mapping of class id -> label name. If omitted,
              the four names above are used when num_classes == 4, otherwise
              generic "class_<i>" names.
    """

    # Label names for the default 4-class setup.
    _DEFAULT_ID2LABEL = {0: "Normal", 1: "LSIL", 2: "HSIL", 3: "Cancer"}

    @staticmethod
    def _default_config():
        """Return a fresh copy of the default hyper-parameters."""
        return {
            "conv_layers": [32, 64, 128, 256],
            "fc_layers": [256, 128],
            "num_classes": 4,
            "dropout": 0.5,
            "input_channels": 3,
        }

    def __init__(self, config=None):
        super().__init__()
        # Merge over fresh defaults so partial configs work and self.config
        # records every hyper-parameter actually used to build the network.
        self.config = {**self._default_config(), **(config or {})}

        conv_channels = self.config["conv_layers"]
        fc_sizes = self.config["fc_layers"]
        dropout = self.config["dropout"]
        num_classes = self.config["num_classes"]
        input_channels = self.config["input_channels"]

        # Build convolutional layers; each block halves the spatial size.
        layers = []
        in_channels = input_channels
        for out_channels in conv_channels:
            layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ])
            in_channels = out_channels
        self.conv_layers = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # Build fully connected layers. After global pooling the feature width
        # equals the last conv width, or input_channels when there are no
        # conv layers (avoids IndexError on an empty conv_layers list).
        fc_blocks = []
        in_features = in_channels
        for fc_size in fc_sizes:
            fc_blocks.extend([
                nn.Linear(in_features, fc_size),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ])
            in_features = fc_size
        self.fc_layers = nn.Sequential(*fc_blocks)

        self.classifier = nn.Linear(in_features, num_classes)

        # Class labels, kept consistent with the classifier's output size.
        labels = self.config.get("id2label")
        if labels is not None:
            # int() tolerates string keys, e.g. from a JSON-loaded config.
            self.id2label = {int(k): v for k, v in labels.items()}
        elif num_classes == len(self._DEFAULT_ID2LABEL):
            self.id2label = dict(self._DEFAULT_ID2LABEL)
        else:
            self.id2label = {i: f"class_{i}" for i in range(num_classes)}
        self.label2id = {v: k for k, v in self.id2label.items()}

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: Input tensor of shape (batch, input_channels, height, width).
                Height and width should be at least 2 ** len(conv_layers) so
                that every 2x2 max-pool receives a non-empty input.

        Returns:
            Logits tensor of shape (batch, num_classes)
        """
        x = self.conv_layers(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc_layers(x)
        return self.classifier(x)

    def predict(self, x):
        """
        Predict class labels.

        Runs the model in eval mode without gradient tracking, then restores
        the module's previous train/eval mode, so calling this inside a
        training loop does not silently disable BatchNorm/Dropout training.

        Args:
            x: Input tensor of shape (batch, input_channels, height, width)

        Returns:
            Tuple of (predicted_class_ids, probabilities)
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self(x)
        finally:
            self.train(was_training)
        probs = torch.softmax(logits, dim=1)
        preds = torch.argmax(logits, dim=1)
        return preds, probs

    @classmethod
    def from_pretrained(cls, model_path, device="cpu", config=None):
        """
        Load pretrained model.

        Args:
            model_path: Path to model directory or checkpoint file. A directory
                is searched for "model.safetensors", then "pytorch_model.bin".
                A file is read as safetensors if its name ends in
                ".safetensors", otherwise with torch.load.
            device: Device to load model on
            config: Optional architecture config passed to the constructor.
                It must match the checkpoint; defaults to the default config.

        Returns:
            Loaded model, moved to ``device`` and set to eval mode.

        Raises:
            FileNotFoundError: If no weights file is found.
        """
        model_path = Path(model_path)

        # Try different file formats
        if model_path.is_dir():
            if (model_path / "model.safetensors").exists():
                weights_path = model_path / "model.safetensors"
                use_safetensors = True
            elif (model_path / "pytorch_model.bin").exists():
                weights_path = model_path / "pytorch_model.bin"
                use_safetensors = False
            else:
                raise FileNotFoundError(f"No model weights found in {model_path}")
        else:
            if not model_path.exists():
                raise FileNotFoundError(f"No model weights found at {model_path}")
            weights_path = model_path
            use_safetensors = str(model_path).endswith(".safetensors")

        # Create model with the requested architecture
        model = cls(config)

        # Load weights
        if use_safetensors:
            from safetensors.torch import load_file
            state_dict = load_file(str(weights_path))
        else:
            state_dict = torch.load(weights_path, map_location=device, weights_only=True)

        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        return model