|
|
""" |
|
|
Cervical Cancer Classification Model |
|
|
|
|
|
This file provides the model architecture for easy import. |
|
|
|
|
|
Usage: |
|
|
from model import CervicalCancerCNN, load_model, predict |
|
|
|
|
|
model = load_model("model.safetensors") |
|
|
result = predict(model, image_tensor) |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
class CervicalCancerCNN(nn.Module):
    """
    CNN for cervical cancer classification.

    Classifies cervical colposcopy images into 4 severity classes:
        - 0: Normal - Healthy cervical tissue
        - 1: LSIL - Low-grade Squamous Intraepithelial Lesion
        - 2: HSIL - High-grade Squamous Intraepithelial Lesion
        - 3: Cancer - Invasive cervical cancer

    Architecture (default config):
        Conv[32,64,128,256] -> AvgPool -> FC[256,128] -> Classifier[4]

        Each conv stage is Conv3x3(padding=1) -> BatchNorm2d -> ReLU -> MaxPool2x2.
        Each FC stage is Linear -> ReLU -> Dropout.

    Input:
        Tensor of shape (batch, 3, 224, 298). The adaptive average pool makes
        the head independent of spatial size. Other resolutions also work if
        each side is at least 2 ** len(conv_layers) pixels, so that every
        max-pool stage still has a non-empty output.

    Output:
        Logits of shape (batch, num_classes), which is (batch, 4) by default.
    """

    CLASSES = {
        0: "Normal",
        1: "LSIL",
        2: "HSIL",
        3: "Cancer",
    }

    def __init__(self, config=None):
        """
        Build the network.

        Args:
            config: Optional dict overriding the defaults. Keys read here:
                "conv_layers": conv output channels per stage (default [32, 64, 128, 256]),
                "fc_layers": hidden fully connected sizes (default [256, 128]),
                "dropout": dropout probability after each FC layer (default 0.5),
                "num_classes": number of output logits (default 4).
        """
        super().__init__()

        config = config or {}
        conv_channels = config.get("conv_layers", [32, 64, 128, 256])
        fc_sizes = config.get("fc_layers", [256, 128])
        dropout = config.get("dropout", 0.5)
        num_classes = config.get("num_classes", 4)

        # Feature extractor: each stage halves the spatial resolution (floor).
        layers = []
        in_channels = 3
        for out_channels in conv_channels:
            layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ])
            in_channels = out_channels

        self.conv_layers = nn.Sequential(*layers)
        # Collapse whatever spatial extent remains to 1x1.
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # Use the running channel count rather than conv_channels[-1], so an
        # empty "conv_layers" list (pooling the raw 3-channel input) still works.
        fc_blocks = []
        in_features = in_channels
        for fc_size in fc_sizes:
            fc_blocks.extend([
                nn.Linear(in_features, fc_size),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ])
            in_features = fc_size

        self.fc_layers = nn.Sequential(*fc_blocks)
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for input x of shape (batch, 3, H, W)."""
        x = self.conv_layers(x)
        x = self.avgpool(x)          # (batch, C, 1, 1)
        x = torch.flatten(x, 1)      # (batch, C)
        x = self.fc_layers(x)
        return self.classifier(x)

    def predict_class(self, x):
        """
        Predict class labels and probabilities for a batch.

        Args:
            x: Input tensor of shape (batch, 3, H, W).

        Returns:
            Tuple (preds, probs):
                preds: (batch,) tensor of predicted class indices.
                probs: (batch, num_classes) tensor of softmax probabilities.

        The module runs in eval mode without autograd. It is then restored to
        the mode it was in before the call, so using this mid-training does not
        silently disable dropout or freeze BatchNorm statistics afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self(x)
        finally:
            self.train(was_training)

        probs = torch.softmax(logits, dim=1)
        preds = torch.argmax(logits, dim=1)
        return preds, probs
|
|
|
|
|
|
|
|
def load_model(model_path, device="cpu", *, weights_only=False):
    """
    Load a CervicalCancerCNN (default architecture) from a weights file.

    Args:
        model_path: Path to the model weights.
            - A ``.safetensors`` suffix (case-insensitive) is read with safetensors.
            - Anything else (e.g. ``.bin``/``.pth``) is read with ``torch.load``.
              If the loaded object is a dict containing ``"model_state_dict"``,
              it is treated as a training checkpoint and that entry is used.
              Otherwise the loaded object itself is used as the state dict.
        device: Device to load the model on (e.g. "cpu" or "cuda").
        weights_only: Forwarded to ``torch.load`` for non-safetensors files.
            - Defaults to False to preserve the previous behaviour. With False,
              the file may execute arbitrary pickled code.
            - Pass True for any file you do not fully trust.
            - Has no effect on ``.safetensors`` files.

    Returns:
        The model with weights loaded, moved to ``device``, in eval mode.

    Raises:
        RuntimeError: from ``load_state_dict`` when the stored keys or shapes
            do not match the default architecture.
    """
    model = CervicalCancerCNN()
    model_path = Path(model_path)

    if model_path.suffix.lower() == ".safetensors":
        # Local import: safetensors is only required for this format.
        from safetensors.torch import load_file
        # load_file's default target is CPU; model.to(device) below moves it.
        state_dict = load_file(str(model_path))
    else:
        # SECURITY: torch.load with weights_only=False unpickles arbitrary
        # objects and can run code embedded in the file. Only use that mode
        # for trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=device, weights_only=weights_only)
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
|
|
|
|
|
|
|
|
def predict(model, image_tensor, device="cpu"):
    """
    Run prediction on a single preprocessed image.

    Args:
        model: Loaded CervicalCancerCNN model (expected to already be on ``device``).
        image_tensor: Preprocessed image tensor of shape (1, 3, 224, 298).
            An unbatched (3, 224, 298) tensor is also accepted and treated
            as a batch of one.
        device: Device for inference; the input is moved here.

    Returns:
        Dictionary with:
            "class_id": predicted class index (int),
            "class_name": its label from CervicalCancerCNN.CLASSES,
            "confidence": softmax probability of the predicted class (float),
            "probabilities": mapping label -> probability for every model output.
        Output indices with no entry in CLASSES are labelled with str(index).

    The model runs in eval mode without autograd. It is then restored to the
    training/eval mode it had before the call.
    """
    # Unbatched input would otherwise fail inside BatchNorm2d (expects 4D).
    if image_tensor.dim() == 3:
        image_tensor = image_tensor.unsqueeze(0)
    image_tensor = image_tensor.to(device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(image_tensor)
    finally:
        model.train(was_training)

    probs = torch.softmax(logits, dim=1)[0]
    # .item() raises for a batch larger than one: this helper is single-image only.
    pred_class = torch.argmax(logits, dim=1).item()

    labels = CervicalCancerCNN.CLASSES
    return {
        "class_id": pred_class,
        "class_name": labels.get(pred_class, str(pred_class)),
        "confidence": probs[pred_class].item(),
        # Iterate the actual outputs instead of a hard-coded 4, so models
        # built with a different num_classes are handled.
        "probabilities": {
            labels.get(i, str(i)): p.item() for i, p in enumerate(probs)
        },
    }
|
|
|
|
|
|
|
|
def get_preprocessing_transform():
    """
    Build the torchvision pipeline that turns an input image into a model-ready tensor.

    The pipeline:
        1. Resizes to (224, 298), i.e. height x width, per torchvision's
           Resize(sequence) convention.
        2. Converts to a float tensor.
        3. Normalizes each channel with the mean/std below.

    The mean/std match the values commonly used for ImageNet-pretrained
    models. Presumably training used the same values; confirm before changing.

    Returns:
        torchvision.transforms.Compose object
    """
    from torchvision import transforms

    target_size = (224, 298)
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    pipeline_steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline_steps)
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the default model, report its size and architecture,
    # and check that an input of the documented shape yields (1, 4) logits.
    model = CervicalCancerCNN()
    print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")

    print("\nArchitecture:")
    print(model)

    dummy_input = torch.randn(1, 3, 224, 298)
    # Inference only: eval mode disables dropout, and no_grad skips autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        output = model(dummy_input)

    print(f"\nInput shape: {dummy_input.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Output classes: {list(CervicalCancerCNN.CLASSES.values())}")
|
|
|