File size: 4,890 Bytes
17daa0b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
"""
Cervical Cancer Classification Model
Custom CNN model for classifying cervical images into 4 severity classes.
"""
import torch
import torch.nn as nn
class CervicalCancerCNN(nn.Module):
    """
    CNN for cervical cancer classification.

    Classifies cervical images into 4 classes by default:
    - 0: Normal
    - 1: LSIL (Low-grade Squamous Intraepithelial Lesion)
    - 2: HSIL (High-grade Squamous Intraepithelial Lesion)
    - 3: Cancer

    Architecture: a stack of [Conv3x3 -> BatchNorm -> ReLU -> MaxPool2x2]
    blocks, global average pooling, a stack of [Linear -> ReLU -> Dropout]
    blocks, and a final linear layer that produces raw logits.

    Args:
        config: Optional configuration dict. Keys that are not supplied fall
            back to their defaults; the merged result is stored in
            ``self.config`` (the caller's dict is never aliased or mutated).
            Recognised keys:
            - conv_layers: List of conv channel sizes (default: [32, 64, 128, 256])
            - fc_layers: List of FC layer sizes (default: [256, 128])
            - num_classes: Number of output classes (default: 4)
            - dropout: Dropout rate (default: 0.5)
            - input_channels: Number of input image channels (default: 3)
            - id2label: Optional mapping of class index -> label name. Keys may
              be ints or int-like strings (e.g. as loaded from JSON). When
              omitted, the built-in 4-class mapping above is used.
    """

    @staticmethod
    def _default_config():
        """Return a fresh default configuration (lists are never shared)."""
        return {
            "conv_layers": [32, 64, 128, 256],
            "fc_layers": [256, 128],
            "num_classes": 4,
            "dropout": 0.5,
            "input_channels": 3,
        }

    def __init__(self, config=None):
        super().__init__()
        # Merge overrides onto fresh defaults so self.config always reflects
        # the effective settings, even for a partial config.
        self.config = {**self._default_config(), **(config or {})}

        conv_channels = self.config["conv_layers"]
        fc_sizes = self.config["fc_layers"]
        dropout = self.config["dropout"]
        num_classes = self.config["num_classes"]
        input_channels = self.config["input_channels"]

        # Build convolutional layers (each block halves spatial resolution).
        layers = []
        in_channels = input_channels
        for out_channels in conv_channels:
            layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ])
            in_channels = out_channels
        self.conv_layers = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # Build fully connected layers. After global pooling the feature width
        # is the last conv's channel count -- or the raw input channel count
        # when conv_layers is empty (indexing conv_channels[-1] would crash).
        fc_blocks = []
        in_features = in_channels
        for fc_size in fc_sizes:
            fc_blocks.extend([
                nn.Linear(in_features, fc_size),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ])
            in_features = fc_size
        self.fc_layers = nn.Sequential(*fc_blocks)
        self.classifier = nn.Linear(in_features, num_classes)

        # Class labels: configurable, defaulting to the 4 severity classes.
        id2label = self.config.get("id2label")
        if id2label is None:
            id2label = {0: "Normal", 1: "LSIL", 2: "HSIL", 3: "Cancer"}
        # int() normalises JSON-style string keys ("0" -> 0).
        self.id2label = {int(k): v for k, v in id2label.items()}
        self.label2id = {v: k for k, v in self.id2label.items()}

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: Input tensor of shape (batch, input_channels, height, width).

        Returns:
            Logits tensor of shape (batch, num_classes).
        """
        x = self.conv_layers(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # (batch, C, 1, 1) -> (batch, C)
        x = self.fc_layers(x)
        return self.classifier(x)

    def predict(self, x):
        """
        Predict class labels without tracking gradients.

        The model runs in eval mode for the prediction; its previous
        train/eval mode is restored afterwards, so calling this mid-training
        does not silently disable dropout / BatchNorm statistic updates.

        Args:
            x: Input tensor of shape (batch, input_channels, height, width).

        Returns:
            Tuple of (predicted_class_ids, probabilities), where probabilities
            is the softmax over dim 1 of the logits.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self(x)
                probs = torch.softmax(logits, dim=1)
                preds = torch.argmax(logits, dim=1)
        finally:
            self.train(was_training)
        return preds, probs

    @classmethod
    def from_pretrained(cls, model_path, device="cpu", config=None):
        """
        Load a pretrained model.

        Args:
            model_path: Path to a model directory (containing
                ``model.safetensors`` or ``pytorch_model.bin``, checked in
                that order) or to a checkpoint file. A file ending in
                ``.safetensors`` is read with safetensors; anything else with
                ``torch.load(weights_only=True)``.
            device: Device to load the model on.
            config: Optional config dict used to build the architecture; must
                match the checkpoint. Defaults to the standard configuration.

        Returns:
            Loaded model, moved to ``device`` and set to eval mode.

        Raises:
            FileNotFoundError: If no weights file can be found.
        """
        from pathlib import Path

        model_path = Path(model_path)

        # Resolve the weights file and its format.
        if model_path.is_dir():
            if (model_path / "model.safetensors").exists():
                weights_path = model_path / "model.safetensors"
                use_safetensors = True
            elif (model_path / "pytorch_model.bin").exists():
                weights_path = model_path / "pytorch_model.bin"
                use_safetensors = False
            else:
                raise FileNotFoundError(f"No model weights found in {model_path}")
        else:
            if not model_path.is_file():
                raise FileNotFoundError(f"No model weights found at {model_path}")
            weights_path = model_path
            use_safetensors = str(model_path).endswith(".safetensors")

        # Build the architecture described by `config` (defaults if None).
        model = cls(config)

        # Load weights
        if use_safetensors:
            from safetensors.torch import load_file
            state_dict = load_file(str(weights_path))
        else:
            state_dict = torch.load(weights_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        return model
|