|
|
""" |
|
|
Cervical Cancer Classification Model |
|
|
|
|
|
Custom CNN model for classifying cervical images into 4 severity classes. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class CervicalCancerCNN(nn.Module):
    """
    CNN for cervical cancer classification.

    By default, classifies cervical images into 4 classes:
        - 0: Normal
        - 1: LSIL (Low-grade Squamous Intraepithelial Lesion)
        - 2: HSIL (High-grade Squamous Intraepithelial Lesion)
        - 3: Cancer

    Architecture: one [Conv3x3 -> BatchNorm -> ReLU -> MaxPool2x2] stage per
    entry in ``conv_layers``, global average pooling, one
    [Linear -> ReLU -> Dropout] block per entry in ``fc_layers``, and a final
    linear ``classifier`` that produces logits.

    Args:
        config: Optional configuration dict. Missing keys fall back to the
            defaults below; the resolved (merged) config is stored on
            ``self.config`` and never aliases the caller's dict.
            - conv_layers: List of conv channel sizes (default: [32, 64, 128, 256])
            - fc_layers: List of FC layer sizes (default: [256, 128])
            - num_classes: Number of output classes (default: 4)
            - dropout: Dropout rate (default: 0.5)
            - input_channels: Number of input image channels (default: 3)
            - id2label: Optional mapping of class id -> label name. Keys may
              be ints or int-like strings (e.g. loaded from JSON). Defaults
              to the 4-class mapping above.
    """

    @staticmethod
    def _default_config():
        """Return a fresh dict of default hyper-parameters (safe to mutate)."""
        return {
            "conv_layers": [32, 64, 128, 256],
            "fc_layers": [256, 128],
            "num_classes": 4,
            "dropout": 0.5,
            "input_channels": 3,
        }

    @staticmethod
    def _default_id2label():
        """Return a fresh copy of the default 4-class id -> label mapping."""
        return {0: "Normal", 1: "LSIL", 2: "HSIL", 3: "Cancer"}

    def __init__(self, config=None):
        super().__init__()

        # Merge overrides onto defaults so self.config always describes the
        # full architecture actually built (useful when saving/reloading).
        self.config = {**self._default_config(), **(config or {})}

        conv_channels = list(self.config["conv_layers"])
        fc_sizes = list(self.config["fc_layers"])
        dropout = self.config["dropout"]
        num_classes = self.config["num_classes"]
        input_channels = self.config["input_channels"]

        # NOTE: the attribute names and layer order below define the
        # state_dict keys; keep them stable so saved checkpoints still load.
        self.conv_layers, feature_dim = self._build_conv_stack(
            input_channels, conv_channels
        )
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc_layers, feature_dim = self._build_fc_stack(
            feature_dim, fc_sizes, dropout
        )
        self.classifier = nn.Linear(feature_dim, num_classes)

        id2label = self.config.get("id2label") or self._default_id2label()
        # int() normalises string keys coming from JSON-serialised configs.
        self.id2label = {int(k): v for k, v in id2label.items()}
        self.label2id = {v: k for k, v in self.id2label.items()}

    @staticmethod
    def _build_conv_stack(in_channels, channel_sizes):
        """Build the conv stages; return (nn.Sequential, output channel count).

        With no stages the Sequential is an identity and the output channel
        count is ``in_channels`` itself.
        """
        layers = []
        for out_channels in channel_sizes:
            layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ])
            in_channels = out_channels
        return nn.Sequential(*layers), in_channels

    @staticmethod
    def _build_fc_stack(in_features, fc_sizes, dropout):
        """Build the hidden FC blocks; return (nn.Sequential, output feature count)."""
        blocks = []
        for fc_size in fc_sizes:
            blocks.extend([
                nn.Linear(in_features, fc_size),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ])
            in_features = fc_size
        return nn.Sequential(*blocks), in_features

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: Input tensor of shape (batch, input_channels, height, width).

        Returns:
            Logits tensor of shape (batch, num_classes).
        """
        x = self.conv_layers(x)
        x = self.avgpool(x)  # (batch, C, 1, 1)
        x = torch.flatten(x, 1)  # (batch, C)
        x = self.fc_layers(x)
        return self.classifier(x)

    def predict(self, x):
        """
        Predict class labels.

        Runs in eval mode without gradient tracking, then restores the
        module's previous train/eval mode.

        Args:
            x: Input tensor of shape (batch, input_channels, height, width).

        Returns:
            Tuple of (predicted_class_ids, probabilities), where
            probabilities is the softmax over dim 1 of the logits.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self.forward(x)
                probs = torch.softmax(logits, dim=1)
                preds = torch.argmax(logits, dim=1)
        finally:
            # Restore the caller's mode so predicting mid-training does not
            # silently disable dropout / batch-norm statistic updates.
            self.train(was_training)
        return preds, probs

    @classmethod
    def from_pretrained(cls, model_path, device="cpu", config=None):
        """
        Load pretrained model.

        Args:
            model_path: Path to a model directory (containing
                ``model.safetensors`` or ``pytorch_model.bin``, checked in
                that order) or directly to a weights file. Files ending in
                ``.safetensors`` are read with safetensors; anything else is
                read with ``torch.load(weights_only=True)``.
            device: Device to load model on.
            config: Optional architecture config passed to the constructor;
                must match the one the weights were trained with. ``None``
                uses the defaults.

        Returns:
            Loaded model, moved to ``device`` and set to eval mode.

        Raises:
            FileNotFoundError: If no weights file can be found.
        """
        from pathlib import Path

        model_path = Path(model_path)

        if model_path.is_dir():
            if (model_path / "model.safetensors").exists():
                weights_path = model_path / "model.safetensors"
                use_safetensors = True
            elif (model_path / "pytorch_model.bin").exists():
                weights_path = model_path / "pytorch_model.bin"
                use_safetensors = False
            else:
                raise FileNotFoundError(f"No model weights found in {model_path}")
        else:
            weights_path = model_path
            use_safetensors = str(model_path).endswith(".safetensors")

        # Fail early with a consistent error type regardless of backend.
        if not weights_path.is_file():
            raise FileNotFoundError(f"No model weights found at {weights_path}")

        model = cls(config)

        if use_safetensors:
            # Optional dependency: only imported when actually needed.
            from safetensors.torch import load_file

            state_dict = load_file(str(weights_path))
        else:
            state_dict = torch.load(weights_path, map_location=device, weights_only=True)

        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        return model
|
|
|