""" Cervical Cancer Classification Model This file provides the model architecture for easy import. Usage: from model import CervicalCancerCNN, load_model, predict model = load_model("model.safetensors") result = predict(model, image_tensor) """ import torch import torch.nn as nn from pathlib import Path class CervicalCancerCNN(nn.Module): """ CNN for cervical cancer classification. Classifies cervical colposcopy images into 4 severity classes: - 0: Normal - Healthy cervical tissue - 1: LSIL - Low-grade Squamous Intraepithelial Lesion - 2: HSIL - High-grade Squamous Intraepithelial Lesion - 3: Cancer - Invasive cervical cancer Architecture: Conv[32,64,128,256] -> AvgPool -> FC[256,128] -> Classifier[4] Input: Tensor of shape (batch, 3, 224, 298) Output: Logits of shape (batch, 4) """ # Class labels CLASSES = { 0: "Normal", 1: "LSIL", 2: "HSIL", 3: "Cancer" } def __init__(self, config=None): super().__init__() # Default configuration config = config or {} conv_channels = config.get("conv_layers", [32, 64, 128, 256]) fc_sizes = config.get("fc_layers", [256, 128]) dropout = config.get("dropout", 0.5) num_classes = config.get("num_classes", 4) # Build convolutional layers layers = [] in_channels = 3 for out_channels in conv_channels: layers.extend([ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2), ]) in_channels = out_channels self.conv_layers = nn.Sequential(*layers) self.avgpool = nn.AdaptiveAvgPool2d(1) # Build fully connected layers fc_blocks = [] in_features = conv_channels[-1] for fc_size in fc_sizes: fc_blocks.extend([ nn.Linear(in_features, fc_size), nn.ReLU(inplace=True), nn.Dropout(dropout), ]) in_features = fc_size self.fc_layers = nn.Sequential(*fc_blocks) self.classifier = nn.Linear(in_features, num_classes) def forward(self, x): """Forward pass.""" x = self.conv_layers(x) x = self.avgpool(x) x = x.view(x.size(0), -1) x = self.fc_layers(x) x = self.classifier(x) return x def predict_class(self, x): """Predict class labels and probabilities.""" self.eval() with torch.no_grad(): logits = self.forward(x) probs = torch.softmax(logits, dim=1) preds = torch.argmax(logits, dim=1) return preds, probs def load_model(model_path, device="cpu"): """ Load model from file. Args: model_path: Path to model weights (.safetensors or .bin/.pth) device: Device to load model on ("cpu" or "cuda") Returns: Loaded model in eval mode """ model = CervicalCancerCNN() model_path = Path(model_path) if model_path.suffix == ".safetensors": from safetensors.torch import load_file state_dict = load_file(str(model_path)) else: checkpoint = torch.load(model_path, map_location=device, weights_only=False) if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint: state_dict = checkpoint["model_state_dict"] else: state_dict = checkpoint model.load_state_dict(state_dict) model.to(device) model.eval() return model def predict(model, image_tensor, device="cpu"): """ Run prediction on an image tensor. Args: model: Loaded CervicalCancerCNN model image_tensor: Preprocessed image tensor (1, 3, 224, 298) device: Device for inference Returns: Dictionary with prediction results """ model.eval() image_tensor = image_tensor.to(device) with torch.no_grad(): logits = model(image_tensor) probs = torch.softmax(logits, dim=1)[0] pred_class = torch.argmax(logits, dim=1).item() return { "class_id": pred_class, "class_name": CervicalCancerCNN.CLASSES[pred_class], "confidence": probs[pred_class].item(), "probabilities": { CervicalCancerCNN.CLASSES[i]: probs[i].item() for i in range(4) } } def get_preprocessing_transform(): """ Get the preprocessing transform for input images. Returns: torchvision.transforms.Compose object """ import torchvision.transforms as T return T.Compose([ T.Resize((224, 298)), T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # Quick usage example if __name__ == "__main__": import sys # Create model model = CervicalCancerCNN() print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters") # Print architecture print("\nArchitecture:") print(model) # Test forward pass dummy_input = torch.randn(1, 3, 224, 298) output = model(dummy_input) print(f"\nInput shape: {dummy_input.shape}") print(f"Output shape: {output.shape}") print(f"Output classes: {list(CervicalCancerCNN.CLASSES.values())}")