|
|
---
license: mit
language:
- en
tags:
- image-classification
- medical
- cervical-cancer
- pytorch
- cnn
- colposcopy
datasets:
- custom
metrics:
- accuracy
- f1
pipeline_tag: image-classification
library_name: pytorch
---
|
|
|
|
|
# Cervical Cancer Classification CNN |
|
|
|
|
|
A CNN model for classifying cervical colposcopy images into 4 severity classes for cervical cancer screening. |
|
|
|
|
|
## Model Description |
|
|
|
|
|
This model classifies cervical images into: |
|
|
|
|
|
| Class | Label | Description | Clinical Action |
|-------|-------|-------------|-----------------|
| 0 | Normal | Healthy cervical tissue | Routine screening in 3-5 years |
| 1 | LSIL | Low-grade Squamous Intraepithelial Lesion | Monitor, repeat test in 6-12 months |
| 2 | HSIL | High-grade Squamous Intraepithelial Lesion | Colposcopy, biopsy, treatment required |
| 3 | Cancer | Invasive cervical cancer | Immediate oncology referral |
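For downstream code, the table above can be captured as a simple index-to-label mapping. This is a convenience sketch using the names from the table; it is not taken from the repository's `config.json`:

```python
# Class index -> (label, clinical action), as listed in the table above.
CLASSES = {
    0: ("Normal", "Routine screening in 3-5 years"),
    1: ("LSIL", "Monitor, repeat test in 6-12 months"),
    2: ("HSIL", "Colposcopy, biopsy, treatment required"),
    3: ("Cancer", "Immediate oncology referral"),
}

def label_for(index: int) -> str:
    """Return the human-readable label for a predicted class index."""
    return CLASSES[index][0]
```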
|
|
|
|
|
--- |
|
|
|
|
|
## Model Architecture |
|
|
|
|
|
### Architecture Diagram |
|
|
|
|
|
```
INPUT IMAGE (3 × 224 × 298)
        │
        ▼
CONV BLOCK 1
  ├── Conv2d(3 → 32, kernel=3×3, padding=1)
  ├── BatchNorm2d(32)
  ├── ReLU
  └── MaxPool2d(2×2)
  Output: 32 × 112 × 149
        │
        ▼
CONV BLOCK 2
  ├── Conv2d(32 → 64, kernel=3×3, padding=1)
  ├── BatchNorm2d(64)
  ├── ReLU
  └── MaxPool2d(2×2)
  Output: 64 × 56 × 74
        │
        ▼
CONV BLOCK 3
  ├── Conv2d(64 → 128, kernel=3×3, padding=1)
  ├── BatchNorm2d(128)
  ├── ReLU
  └── MaxPool2d(2×2)
  Output: 128 × 28 × 37
        │
        ▼
CONV BLOCK 4
  ├── Conv2d(128 → 256, kernel=3×3, padding=1)
  ├── BatchNorm2d(256)
  ├── ReLU
  └── MaxPool2d(2×2)
  Output: 256 × 14 × 18
        │
        ▼
GLOBAL AVERAGE POOLING
  └── AdaptiveAvgPool2d(1×1)
  Output: 256 × 1 × 1 → Flatten → 256
        │
        ▼
FC BLOCK 1
  ├── Linear(256 → 256)
  ├── ReLU
  └── Dropout(0.5)
        │
        ▼
FC BLOCK 2
  ├── Linear(256 → 128)
  ├── ReLU
  └── Dropout(0.5)
        │
        ▼
CLASSIFIER
  └── Linear(128 → 4)
  Output: 4 class logits
        │
        ▼
[Normal, LSIL, HSIL, Cancer]
```
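The spatial sizes in the diagram follow directly from the four 2×2 max-pool layers: each one halves the height and width with floor division, while the 3×3 convolutions (padding 1) leave the spatial size unchanged:

```python
# Trace the spatial dimensions through the four MaxPool2d(2, 2) layers.
h, w = 224, 298
shapes = []
for _ in range(4):
    h, w = h // 2, w // 2  # each pool halves both dimensions (floor division)
    shapes.append((h, w))
print(shapes)  # [(112, 149), (56, 74), (28, 37), (14, 18)]
```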
|
|
|
|
|
### Architecture Summary Table |
|
|
|
|
|
| Layer | Type | Input Shape | Output Shape | Parameters |
|-------|------|-------------|--------------|------------|
| conv_layers.0 | Conv2d | (3, 224, 298) | (32, 224, 298) | 896 |
| conv_layers.1 | BatchNorm2d | (32, 224, 298) | (32, 224, 298) | 64 |
| conv_layers.2 | ReLU | - | - | 0 |
| conv_layers.3 | MaxPool2d | (32, 224, 298) | (32, 112, 149) | 0 |
| conv_layers.4 | Conv2d | (32, 112, 149) | (64, 112, 149) | 18,496 |
| conv_layers.5 | BatchNorm2d | (64, 112, 149) | (64, 112, 149) | 128 |
| conv_layers.6 | ReLU | - | - | 0 |
| conv_layers.7 | MaxPool2d | (64, 112, 149) | (64, 56, 74) | 0 |
| conv_layers.8 | Conv2d | (64, 56, 74) | (128, 56, 74) | 73,856 |
| conv_layers.9 | BatchNorm2d | (128, 56, 74) | (128, 56, 74) | 256 |
| conv_layers.10 | ReLU | - | - | 0 |
| conv_layers.11 | MaxPool2d | (128, 56, 74) | (128, 28, 37) | 0 |
| conv_layers.12 | Conv2d | (128, 28, 37) | (256, 28, 37) | 295,168 |
| conv_layers.13 | BatchNorm2d | (256, 28, 37) | (256, 28, 37) | 512 |
| conv_layers.14 | ReLU | - | - | 0 |
| conv_layers.15 | MaxPool2d | (256, 28, 37) | (256, 14, 18) | 0 |
| avgpool | AdaptiveAvgPool2d | (256, 14, 18) | (256, 1, 1) | 0 |
| fc_layers.0 | Linear | 256 | 256 | 65,792 |
| fc_layers.1 | ReLU | - | - | 0 |
| fc_layers.2 | Dropout | - | - | 0 |
| fc_layers.3 | Linear | 256 | 128 | 32,896 |
| fc_layers.4 | ReLU | - | - | 0 |
| fc_layers.5 | Dropout | - | - | 0 |
| classifier | Linear | 128 | 4 | 516 |
| **Total** | | | | **488,580** |
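The parameter counts above can be verified by hand: a `Conv2d` layer has `in·out·k² + out` parameters (weights plus biases), a `BatchNorm2d` layer has `2·channels` (scale and shift), and a `Linear` layer has `in·out + out`:

```python
def conv2d_params(c_in, c_out, k=3):
    return c_in * c_out * k * k + c_out  # weights + biases

def batchnorm_params(c):
    return 2 * c  # gamma + beta

def linear_params(f_in, f_out):
    return f_in * f_out + f_out  # weights + biases

total = sum(
    conv2d_params(c_in, c_out) + batchnorm_params(c_out)
    for c_in, c_out in [(3, 32), (32, 64), (64, 128), (128, 256)]
) + linear_params(256, 256) + linear_params(256, 128) + linear_params(128, 4)

print(total)  # 488580
```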
|
|
|
|
|
### PyTorch Model Code |
|
|
|
|
|
```python
import torch
import torch.nn as nn


class CervicalCancerCNN(nn.Module):
    def __init__(self):
        super().__init__()

        # Convolutional layers: [32, 64, 128, 256]
        self.conv_layers = nn.Sequential(
            # Block 1: 3 -> 32
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            # Block 2: 32 -> 64
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            # Block 3: 64 -> 128
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            # Block 4: 128 -> 256
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )

        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # Fully connected layers: [256, 128] -> 4
        self.fc_layers = nn.Sequential(
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
        )

        self.classifier = nn.Linear(128, 4)

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flatten (N, 256, 1, 1) -> (N, 256)
        x = self.fc_layers(x)
        x = self.classifier(x)
        return x
```
|
|
|
|
|
--- |
|
|
|
|
|
## Performance |
|
|
|
|
|
### Overall Metrics |
|
|
|
|
|
| Metric | Value |
|--------|-------|
| **Accuracy** | 59.52% |
| **Macro F1** | 59.85% |
| **Parameters** | 488,580 |
|
|
|
|
|
### Per-Class Metrics |
|
|
|
|
|
| Class | Precision | Recall | F1 Score | Support |
|-------|-----------|--------|----------|---------|
| Normal | 0.595 | 0.595 | 0.595 | 84 |
| LSIL | 0.521 | 0.583 | 0.551 | 84 |
| HSIL | 0.446 | 0.440 | 0.443 | 84 |
| Cancer | 0.853 | 0.762 | 0.805 | 84 |
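Because the test split is balanced, the macro F1 reported above is simply the unweighted mean of the four per-class F1 scores:

```python
# Per-class F1 scores from the table above.
f1_scores = {"Normal": 0.595, "LSIL": 0.551, "HSIL": 0.443, "Cancer": 0.805}

# Macro F1 = unweighted mean over classes.
macro_f1 = sum(f1_scores.values()) / len(f1_scores)
print(f"{macro_f1:.2%}")  # 59.85%
```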
|
|
|
|
|
### Confusion Matrix |
|
|
|
|
|
```
Predicted →    Normal   LSIL   HSIL   Cancer
Actual ↓
Normal             50      9     17        8
LSIL               24     49     11        0
HSIL                9     35     37        3
Cancer              1      1     18       64
```
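The overall accuracy and the per-class precision/recall reported earlier can be recomputed directly from this matrix (rows are actual classes, columns are predictions):

```python
# Confusion matrix: rows = actual class, columns = predicted class.
# Order: Normal, LSIL, HSIL, Cancer.
cm = [
    [50,  9, 17,  8],   # Normal
    [24, 49, 11,  0],   # LSIL
    [ 9, 35, 37,  3],   # HSIL
    [ 1,  1, 18, 64],   # Cancer
]

total = sum(sum(row) for row in cm)                      # 336 test samples
accuracy = sum(cm[i][i] for i in range(4)) / total
print(f"{accuracy:.2%}")  # 59.52%

# Recall = diagonal / row sum; precision = diagonal / column sum.
recall = [cm[i][i] / sum(cm[i]) for i in range(4)]
precision = [cm[i][i] / sum(row[i] for row in cm) for i in range(4)]
print([round(r, 3) for r in recall])     # [0.595, 0.583, 0.44, 0.762]
print([round(p, 3) for p in precision])  # [0.595, 0.521, 0.446, 0.853]
```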
|
|
|
|
|
--- |
|
|
|
|
|
## Usage |
|
|
|
|
|
### Installation |
|
|
|
|
|
```bash
pip install torch torchvision safetensors huggingface_hub
```
|
|
|
|
|
### Loading the Model |
|
|
|
|
|
```python
import json

import torch
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

# Download model files
model_file = hf_hub_download("toderian/cerviguard_lesion", "model.safetensors")
config_file = hf_hub_download("toderian/cerviguard_lesion", "config.json")

# Load config
with open(config_file) as f:
    config = json.load(f)

# Define model (copy from above or download modeling_cervical.py)
model = CervicalCancerCNN()

# Load weights
state_dict = load_file(model_file)
model.load_state_dict(state_dict)
model.eval()
```
|
|
|
|
|
### Inference |
|
|
|
|
|
```python
from PIL import Image
import torchvision.transforms as T

# Preprocessing: resize to 224×298 and apply ImageNet normalization
transform = T.Compose([
    T.Resize((224, 298)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load and preprocess image
image = Image.open("cervical_image.jpg").convert("RGB")
input_tensor = transform(image).unsqueeze(0)  # add batch dimension

# Inference
with torch.no_grad():
    output = model(input_tensor)
    probabilities = torch.softmax(output, dim=1)
    prediction = output.argmax(dim=1).item()

classes = ["Normal", "LSIL", "HSIL", "Cancer"]
print(f"Prediction: {classes[prediction]}")
print(f"Confidence: {probabilities[0][prediction]:.2%}")
```
|
|
|
|
|
--- |
|
|
|
|
|
## Training Details |
|
|
|
|
|
| Parameter | Value |
|-----------|-------|
| Learning Rate | 1e-4 |
| Batch Size | 32 |
| Optimizer | Adam |
| Loss | CrossEntropyLoss |
| Dropout | 0.5 |
| Epochs | 34 (early stopping at 24) |
|
|
|
|
|
### Dataset |
|
|
|
|
|
| Split | Samples | Distribution |
|-------|---------|--------------|
| Train | 3,003 | Imbalanced [1540, 469, 854, 140] |
| Test | 336 | Balanced [84, 84, 84, 84] |
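Given the imbalanced training split, one standard mitigation is inverse-frequency class weighting in the loss. The card does not state whether this was used during training, so treat the following as a hypothetical sketch of how such weights could be derived from the counts above:

```python
# Inverse-frequency class weights: total / (n_classes * class_count).
# Counts per class in the training split: Normal, LSIL, HSIL, Cancer.
counts = [1540, 469, 854, 140]
total = sum(counts)  # 3003
weights = [total / (len(counts) * c) for c in counts]
print(weights)
# Rarer classes (Cancer, LSIL) get proportionally larger weights; the list
# could be passed e.g. as torch.nn.CrossEntropyLoss(weight=torch.tensor(weights)).
```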
|
|
|
|
|
--- |
|
|
|
|
|
## Limitations |
|
|
|
|
|
- Trained on a limited dataset (~3k samples)
- The HSIL class has the lowest performance (F1 = 0.443)
- Should not be used as a sole diagnostic tool
- Intended for research and screening assistance only
|
|
|
|
|
## Medical Disclaimer |
|
|
|
|
|
β οΈ **This model is for research purposes only.** It should not be used as a substitute for professional medical diagnosis. Always consult qualified healthcare professionals for cervical cancer screening and diagnosis. |
|
|
|
|
|
--- |
|
|
|
|
|
## Files in This Repository |
|
|
|
|
|
| File | Description |
|------|-------------|
| `model.safetensors` | Model weights (safetensors format) |
| `pytorch_model.bin` | Model weights (legacy PyTorch format) |
| `config.json` | Model configuration |
| `preprocessor_config.json` | Image preprocessing settings |
| `modeling_cervical.py` | Model class definition |
| `example_inference.py` | Example inference script |
|
|
|
|
|
--- |
|
|
|
|
|
## Citation |
|
|
|
|
|
```bibtex
@misc{cervical-cancer-cnn-2025,
  author    = {Toderian},
  title     = {Cervical Cancer Classification CNN},
  year      = {2025},
  publisher = {Hugging Face},
  url       = {https://huggingface.co/toderian/cerviguard_lesion}
}
```
|
|
|