# Source: cervical_lesion/modeling_cervical.py (Hugging Face Hub, uploaded by toderian)
# Commit 17daa0b (verified): "Upload folder using huggingface_hub"
"""
Cervical Cancer Classification Model
Custom CNN model for classifying cervical images into 4 severity classes.
"""
import torch
import torch.nn as nn
class CervicalCancerCNN(nn.Module):
    """
    CNN for cervical cancer classification.

    The network is a stack of ``Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d``
    stages, followed by global average pooling, a stack of
    ``Linear -> ReLU -> Dropout`` stages and a final linear classifier.

    With the default configuration it classifies cervical images into
    4 classes:
        - 0: Normal
        - 1: LSIL (Low-grade Squamous Intraepithelial Lesion)
        - 2: HSIL (High-grade Squamous Intraepithelial Lesion)
        - 3: Cancer

    Args:
        config: Optional configuration mapping. Missing keys fall back to
            their defaults:
            - conv_layers: List of conv channel sizes
              (default: [32, 64, 128, 256]); may be empty
            - fc_layers: List of FC layer sizes (default: [256, 128]);
              may be empty
            - num_classes: Number of output classes (default: 4)
            - dropout: Dropout rate (default: 0.5)
            - input_channels: Number of input image channels (default: 3)
            - id2label: Optional mapping of class id to label name that
              replaces the default 4-class labels (keys may be ints or
              numeric strings, e.g. when read from JSON)

    Attributes:
        config: The effective configuration (defaults merged with the
            user-supplied values).
        id2label: Mapping of class id to label name.
        label2id: Inverse mapping of ``id2label``.
    """

    def __init__(self, config=None):
        super().__init__()

        defaults = {
            "conv_layers": [32, 64, 128, 256],
            "fc_layers": [256, 128],
            "num_classes": 4,
            "dropout": 0.5,
            "input_channels": 3,
        }
        # Merge into a fresh dict so that (a) a partial config still records
        # every value actually used, and (b) the caller's dict is not aliased.
        self.config = {**defaults, **(config or {})}

        conv_channels = list(self.config["conv_layers"])
        fc_sizes = list(self.config["fc_layers"])
        dropout = self.config["dropout"]
        num_classes = self.config["num_classes"]
        input_channels = self.config["input_channels"]

        # Build convolutional layers; each stage halves the spatial size.
        layers = []
        in_channels = input_channels
        for out_channels in conv_channels:
            layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ])
            in_channels = out_channels
        self.conv_layers = nn.Sequential(*layers)

        # Global average pooling makes the head independent of input size.
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # Build fully connected layers.
        # `in_channels` is the last conv width, or `input_channels` when the
        # conv stack is empty (indexing conv_channels[-1] would fail there).
        fc_blocks = []
        in_features = in_channels
        for fc_size in fc_sizes:
            fc_blocks.extend([
                nn.Linear(in_features, fc_size),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ])
            in_features = fc_size
        self.fc_layers = nn.Sequential(*fc_blocks)
        self.classifier = nn.Linear(in_features, num_classes)

        # Class labels: default 4-class names unless the config overrides them.
        custom_labels = self.config.get("id2label")
        if custom_labels:
            # JSON round-trips turn integer keys into strings; normalise.
            self.id2label = {int(k): v for k, v in custom_labels.items()}
        else:
            self.id2label = {
                0: "Normal",
                1: "LSIL",
                2: "HSIL",
                3: "Cancer",
            }
        self.label2id = {v: k for k, v in self.id2label.items()}

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: Input tensor of shape (batch, input_channels, height, width)

        Returns:
            Logits tensor of shape (batch, num_classes)
        """
        x = self.conv_layers(x)
        x = self.avgpool(x)
        # (batch, C, 1, 1) -> (batch, C)
        x = torch.flatten(x, 1)
        x = self.fc_layers(x)
        return self.classifier(x)

    def predict(self, x):
        """
        Predict class labels.

        Note:
            This switches the module to evaluation mode and leaves it there;
            call ``.train()`` afterwards if training should continue.

        Args:
            x: Input tensor of shape (batch, input_channels, height, width)

        Returns:
            Tuple of (predicted_class_ids, probabilities)
        """
        self.eval()
        with torch.no_grad():
            # Call the module (not .forward) so registered hooks still run.
            logits = self(x)
            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(logits, dim=1)
        return preds, probs

    @classmethod
    def from_pretrained(cls, model_path, device="cpu", config=None):
        """
        Load pretrained model.

        Args:
            model_path: Path to a model directory (containing
                ``model.safetensors`` or ``pytorch_model.bin``) or to a
                checkpoint file
            device: Device to load model on
            config: Optional configuration mapping used to build the model
                before the weights are loaded. Defaults to the standard
                architecture; pass it when the checkpoint was trained with
                a non-default configuration.

        Returns:
            Loaded model, moved to ``device`` and set to evaluation mode

        Raises:
            FileNotFoundError: If ``model_path`` is a directory that
                contains no recognised weights file.
            RuntimeError: Raised by ``load_state_dict`` when the checkpoint
                does not match the model architecture.
        """
        from pathlib import Path

        model_path = Path(model_path)

        # Resolve the weights file; safetensors is preferred when both exist.
        if model_path.is_dir():
            if (model_path / "model.safetensors").exists():
                weights_path = model_path / "model.safetensors"
                use_safetensors = True
            elif (model_path / "pytorch_model.bin").exists():
                weights_path = model_path / "pytorch_model.bin"
                use_safetensors = False
            else:
                raise FileNotFoundError(f"No model weights found in {model_path}")
        else:
            weights_path = model_path
            use_safetensors = str(model_path).endswith(".safetensors")

        # Create model
        model = cls(config)

        # Load weights
        if use_safetensors:
            # Imported lazily so safetensors is only required when used.
            from safetensors.torch import load_file
            state_dict = load_file(str(weights_path))
        else:
            # weights_only=True restricts unpickling to tensor data.
            state_dict = torch.load(weights_path, map_location=device, weights_only=True)

        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        return model