# cerviguard_lesion / model.py
# Author: toderian
# Commit 84ddecf (verified): "Upload folder using huggingface_hub"
"""
Cervical Cancer Classification Model
This file provides the model architecture for easy import.
Usage:
from model import CervicalCancerCNN, load_model, predict
model = load_model("model.safetensors")
result = predict(model, image_tensor)
"""
import torch
import torch.nn as nn
from pathlib import Path
class CervicalCancerCNN(nn.Module):
    """
    CNN for cervical cancer classification.

    Classifies cervical colposcopy images into 4 severity classes:
    - 0: Normal - Healthy cervical tissue
    - 1: LSIL - Low-grade Squamous Intraepithelial Lesion
    - 2: HSIL - High-grade Squamous Intraepithelial Lesion
    - 3: Cancer - Invasive cervical cancer

    Architecture (default config):
        [Conv3x3 -> BatchNorm -> ReLU -> MaxPool2x2] x [32, 64, 128, 256]
        -> AdaptiveAvgPool(1)
        -> [Linear -> ReLU -> Dropout] x [256, 128]
        -> Linear classifier[4]

    Args:
        config: Optional dict overriding the defaults. Recognised keys:
            "conv_layers" (list of conv output channels),
            "fc_layers" (list of hidden fully-connected sizes),
            "dropout" (dropout probability),
            "num_classes" (number of output logits).

    Input:
        Tensor of shape (batch, 3, H, W). The model is intended for
        224 x 298 inputs. Thanks to the adaptive average pool, other sizes
        also run, provided each spatial dim is at least
        2 ** len(conv_layers), so that the 2x2 max-pools never reach zero.
    Output:
        Logits of shape (batch, num_classes)
    """

    # Class labels (index -> name) for the default 4-class configuration.
    CLASSES = {
        0: "Normal",
        1: "LSIL",
        2: "HSIL",
        3: "Cancer",
    }

    def __init__(self, config=None):
        super().__init__()
        # Default configuration; any key missing from `config` falls back here.
        config = config or {}
        conv_channels = config.get("conv_layers", [32, 64, 128, 256])
        fc_sizes = config.get("fc_layers", [256, 128])
        dropout = config.get("dropout", 0.5)
        num_classes = config.get("num_classes", 4)

        # Convolutional feature extractor: each stage halves H and W.
        layers = []
        in_channels = 3
        for out_channels in conv_channels:
            layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ])
            in_channels = out_channels
        self.conv_layers = nn.Sequential(*layers)
        # Collapse whatever spatial extent remains to 1x1 per channel.
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # Fully connected head. `in_channels` now holds the channel count
        # produced by the last conv stage, or 3 when no conv stages are
        # configured. This avoids an IndexError on an empty "conv_layers".
        fc_blocks = []
        in_features = in_channels
        for fc_size in fc_sizes:
            fc_blocks.extend([
                nn.Linear(in_features, fc_size),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ])
            in_features = fc_size
        self.fc_layers = nn.Sequential(*fc_blocks)
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for a batch ``x``."""
        x = self.conv_layers(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # (batch, C, 1, 1) -> (batch, C)
        x = self.fc_layers(x)
        return self.classifier(x)

    def predict_class(self, x):
        """
        Predict class labels and probabilities for a batch.

        The forward pass runs in eval mode without gradient tracking. The
        module's previous train/eval mode is restored afterwards, so calling
        this during training does not silently leave dropout and BatchNorm
        in inference mode.

        Returns:
            (preds, probs): preds has shape (batch,),
            probs has shape (batch, num_classes).
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self.forward(x)
                probs = torch.softmax(logits, dim=1)
                preds = torch.argmax(logits, dim=1)
        finally:
            self.train(was_training)
        return preds, probs
def load_model(model_path, device="cpu", config=None, weights_only=False):
    """
    Load a CervicalCancerCNN from a weights file.

    Args:
        model_path: Path to model weights (.safetensors or .bin/.pth).
        device: Device to load model on ("cpu" or "cuda").
        config: Optional architecture config forwarded to CervicalCancerCNN.
            It must match the config the weights were trained with.
            Defaults to the standard architecture.
        weights_only: Forwarded to ``torch.load`` for non-safetensors files.
            Defaults to False to preserve the original behaviour. Pass True
            when the file does not come from a trusted source.

    Returns:
        Loaded model in eval mode, moved to ``device``.
    """
    model = CervicalCancerCNN(config)
    model_path = Path(model_path)

    if model_path.suffix.lower() == ".safetensors":
        from safetensors.torch import load_file
        state_dict = load_file(str(model_path))
    else:
        # SECURITY: with weights_only=False, torch.load unpickles arbitrary
        # objects and can execute code embedded in the file. Only use that
        # mode on trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=device, weights_only=weights_only)
        # Accept either a bare state_dict or a training checkpoint that
        # wraps it under "model_state_dict".
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
def predict(model, image_tensor, device="cpu"):
    """
    Run prediction on a single preprocessed image.

    Args:
        model: Loaded CervicalCancerCNN model (expected to already be on
            ``device``).
        image_tensor: Preprocessed image tensor of shape (1, 3, H, W), e.g.
            (1, 3, 224, 298). An unbatched (3, H, W) tensor is also
            accepted. If a larger batch is passed, only its first item is
            reported.
        device: Device for inference.

    Returns:
        Dictionary for the first image with keys "class_id", "class_name",
        "confidence" and "probabilities" (class name -> probability).
    """
    # Add the batch dimension for a single unbatched image.
    if image_tensor.dim() == 3:
        image_tensor = image_tensor.unsqueeze(0)
    image_tensor = image_tensor.to(device)

    # Infer in eval mode, then restore whatever mode the caller had set.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(image_tensor)
    finally:
        model.train(was_training)

    first_logits = logits[0]
    probs = torch.softmax(first_logits, dim=0)
    # Take argmax over the first image only. The original called .item()
    # on the whole batch's argmax, which crashes when batch > 1.
    pred_class = int(torch.argmax(first_logits).item())

    # Size the label list from the model output rather than assuming 4
    # classes. Fall back to a generic name for ids not in CLASSES.
    names = [
        CervicalCancerCNN.CLASSES.get(i, f"class_{i}")
        for i in range(probs.numel())
    ]
    return {
        "class_id": pred_class,
        "class_name": names[pred_class],
        "confidence": probs[pred_class].item(),
        "probabilities": dict(zip(names, probs.tolist())),
    }
def get_preprocessing_transform():
    """
    Build the preprocessing pipeline that turns an input image into a model tensor.

    Steps: resize to (height=224, width=298), convert to a float tensor,
    then normalize per channel with the widely used ImageNet mean/std values.

    Returns:
        torchvision.transforms.Compose object
    """
    import torchvision.transforms as T

    target_hw = (224, 298)
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    pipeline = [
        T.Resize(target_hw),
        T.ToTensor(),
        T.Normalize(mean=channel_mean, std=channel_std),
    ]
    return T.Compose(pipeline)
# Quick usage example
if __name__ == "__main__":
    # Build a randomly initialised model; no weights are loaded here.
    model = CervicalCancerCNN()
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Model created with {num_params:,} parameters")

    # Print architecture
    print("\nArchitecture:")
    print(model)

    # Smoke-test a forward pass. Use eval mode and no_grad so the random
    # input does not update BatchNorm running statistics or build a graph.
    dummy_input = torch.randn(1, 3, 224, 298)
    model.eval()
    with torch.no_grad():
        output = model(dummy_input)
    print(f"\nInput shape: {dummy_input.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Output classes: {list(CervicalCancerCNN.CLASSES.values())}")