|
|
""" |
|
|
Cervical Type Classification Model |
|
|
|
|
|
This module contains the BaseCNN model for classifying cervical images |
|
|
into 3 transformation zone types. |
|
|
|
|
|
Usage: |
|
|
from model import BaseCNN |
|
|
|
|
|
# Load pretrained model |
|
|
model = BaseCNN.from_pretrained("./") |
|
|
|
|
|
# Or create from scratch |
|
|
model = BaseCNN( |
|
|
layers=[32, 64, 128, 256], |
|
|
fc_layers=[256, 128], |
|
|
nr_classes=3 |
|
|
) |
|
|
""" |
|
|
|
|
|
import json
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn

try:
    from safetensors.torch import load_file, save_file
    HAS_SAFETENSORS = True
except ImportError:
    HAS_SAFETENSORS = False
|
|
|
|
|
|
|
|
class BaseCNN(nn.Module):
    """
    Simple CNN for cervical type classification.

    Classifies cervical images into 3 transformation zone types:
    - Type 1: Transformation zone fully visible on ectocervix
    - Type 2: Transformation zone partially visible
    - Type 3: Transformation zone not visible (within endocervical canal)

    Args:
        layers: List of output channels for each conv layer. Default: [32, 64, 128, 256]
        kernel: Kernel size for conv layers. Default: 3
        padding: Padding for conv layers. Default: 1
        stride: Stride for conv layers. Default: 1
        batchnorm: Whether to use batch normalization. Default: True
        bn_pre_activ: Whether to apply BN before activation. Default: True
        activation: Activation function name (an attribute of torch.nn). Default: 'ReLU'
        dropout: Dropout rate for FC layers. Default: 0.4
        pool: Whether to use max pooling after each conv. Default: True
        fc_layers: List of FC layer sizes. Default: [256, 128]
        nr_classes: Number of output classes. Default: 3
        in_channels: Number of input channels. Default: 3

    Raises:
        ValueError: If ``activation`` is not the name of a ``torch.nn`` module.
    """

    def __init__(
        self,
        layers: Optional[list] = None,
        kernel: int = 3,
        padding: int = 1,
        stride: int = 1,
        batchnorm: bool = True,
        bn_pre_activ: bool = True,
        activation: str = 'ReLU',
        dropout: float = 0.4,
        pool: bool = True,
        fc_layers: Optional[list] = None,
        nr_classes: int = 3,
        in_channels: int = 3,
    ):
        super().__init__()

        # Persist the effective (defaults-resolved) configuration so that
        # save_pretrained()/from_pretrained() can round-trip the architecture.
        self.config = {
            'layers': layers or [32, 64, 128, 256],
            'kernel': kernel,
            'padding': padding,
            'stride': stride,
            'batchnorm': batchnorm,
            'bn_pre_activ': bn_pre_activ,
            'activation': activation,
            'dropout': dropout,
            'pool': pool,
            'fc_layers': fc_layers or [256, 128],
            'nr_classes': nr_classes,
            'in_channels': in_channels,
        }

        layers = self.config['layers']
        fc_layers = self.config['fc_layers']

        # Fail early with a clear message instead of an opaque AttributeError
        # from getattr() for a typo'd activation name.
        if not hasattr(nn, activation):
            raise ValueError(
                f"Unknown activation {activation!r}: must be the name of a "
                f"torch.nn module (e.g. 'ReLU')"
            )
        activation_fn = getattr(nn, activation)

        # Convolutional feature extractor.
        self.conv_layers = self._build_conv_stack(
            layers, in_channels, kernel, stride, padding,
            batchnorm, bn_pre_activ, activation_fn, pool,
        )

        # Collapses any spatial size to 1x1, so input resolution is flexible.
        self.adaptive_pool = nn.AdaptiveAvgPool2d(1)

        # Fully-connected head; input width equals the last conv channel count.
        self.fc_layers, prev_features = self._build_fc_stack(
            fc_layers, layers[-1], activation_fn, dropout,
        )

        # Final projection to class logits.
        self.classifier = nn.Linear(prev_features, nr_classes)

    @staticmethod
    def _build_conv_stack(layers, in_channels, kernel, stride, padding,
                          batchnorm, bn_pre_activ, activation_fn, pool) -> nn.ModuleList:
        """Build the conv feature extractor as a flat ModuleList: conv [+BN] act [+BN] [pool]."""
        modules = nn.ModuleList()
        prev_channels = in_channels
        for out_channels in layers:
            modules.append(
                nn.Conv2d(prev_channels, out_channels, kernel, stride, padding)
            )
            # BatchNorm placement is configurable: before or after the activation.
            if batchnorm and bn_pre_activ:
                modules.append(nn.BatchNorm2d(out_channels))
            modules.append(activation_fn())
            if batchnorm and not bn_pre_activ:
                modules.append(nn.BatchNorm2d(out_channels))
            if pool:
                modules.append(nn.MaxPool2d(2, 2))
            prev_channels = out_channels
        return modules

    @staticmethod
    def _build_fc_stack(fc_layers, in_features, activation_fn, dropout):
        """Build the FC head (Linear/activation/Dropout triples); return (ModuleList, out_features)."""
        modules = nn.ModuleList()
        prev_features = in_features
        for fc_size in fc_layers:
            modules.append(nn.Linear(prev_features, fc_size))
            modules.append(activation_fn())
            modules.append(nn.Dropout(dropout))
            prev_features = fc_size
        return modules, prev_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (batch_size, in_channels, H, W).
               Any spatial size works thanks to the adaptive pooling stage.

        Returns:
            Logits tensor of shape (batch_size, nr_classes)
        """
        for layer in self.conv_layers:
            x = layer(x)

        x = self.adaptive_pool(x)
        x = x.view(x.size(0), -1)  # flatten (B, C, 1, 1) -> (B, C)

        for layer in self.fc_layers:
            x = layer(x)

        return self.classifier(x)

    @classmethod
    def from_pretrained(cls, model_path: str, device: str = 'cpu') -> 'BaseCNN':
        """
        Load a pretrained model from a directory.

        Prefers ``model.safetensors`` (when the safetensors package is
        available) and falls back to ``pytorch_model.bin``.

        Args:
            model_path: Path to directory containing model files
            device: Device to load model on ('cpu' or 'cuda')

        Returns:
            Loaded model in eval mode

        Raises:
            FileNotFoundError: If no loadable weight file is found.
        """
        model_path = Path(model_path)

        # Rebuild the architecture exactly as it was saved.
        config_path = model_path / 'config.json'
        with open(config_path, 'r') as f:
            config = json.load(f)

        model = cls(**config['model_config'])

        safetensors_path = model_path / 'model.safetensors'
        pytorch_path = model_path / 'pytorch_model.bin'

        if safetensors_path.exists() and HAS_SAFETENSORS:
            state_dict = load_file(str(safetensors_path), device=device)
        elif pytorch_path.exists():
            # weights_only=True avoids executing arbitrary pickled code.
            state_dict = torch.load(pytorch_path, map_location=device, weights_only=True)
        else:
            msg = f"No model weights found in {model_path}"
            if safetensors_path.exists() and not HAS_SAFETENSORS:
                # Weights are present but unreadable — say so instead of
                # claiming nothing was found.
                msg += " (model.safetensors exists but the 'safetensors' package is not installed)"
            raise FileNotFoundError(msg)

        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        return model

    def save_pretrained(self, save_path: str) -> None:
        """
        Save model in Hugging Face compatible format.

        Writes ``config.json``, ``pytorch_model.bin`` and, when the
        safetensors package is available, ``model.safetensors``.

        Args:
            save_path: Directory to save model files (created if missing)
        """
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        config = {
            'model_type': 'BaseCNN',
            'model_config': self.config,
            'num_labels': self.config['nr_classes'],
            'id2label': {
                '0': 'Type 1',
                '1': 'Type 2',
                '2': 'Type 3'
            },
            'label2id': {
                'Type 1': 0,
                'Type 2': 1,
                'Type 3': 2
            }
        }
        with open(save_path / 'config.json', 'w') as f:
            json.dump(config, f, indent=2)

        # safetensors requires contiguous tensors.
        state_dict = {k: v.contiguous() for k, v in self.state_dict().items()}

        if HAS_SAFETENSORS:
            save_file(state_dict, str(save_path / 'model.safetensors'))

        # Always also write the plain PyTorch checkpoint for portability.
        torch.save(state_dict, save_path / 'pytorch_model.bin')
|
|
|
|
|
|
|
|
|
|
|
# Canonical label mappings for the 3 transformation zone types.
ID2LABEL = {i: f'Type {i + 1}' for i in range(3)}
# Inverse mapping, derived so the two can never drift apart.
LABEL2ID = {label: idx for idx, label in ID2LABEL.items()}
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build the default model and push one dummy batch through it.
    net = BaseCNN()
    n_params = sum(p.numel() for p in net.parameters())
    print(f"Model parameters: {n_params:,}")

    dummy = torch.randn(1, 3, 256, 256)
    logits = net(dummy)
    print(f"Input shape: {dummy.shape}")
    print(f"Output shape: {logits.shape}")
|
|
|