# toderian's picture
# Upload folder using huggingface_hub
# b915bae verified
"""
Cervical Type Classification Model

This module contains the BaseCNN model for classifying cervical images
into 3 transformation zone types.

Usage:
    from model import BaseCNN

    # Load pretrained model
    model = BaseCNN.from_pretrained("./")

    # Or create from scratch
    model = BaseCNN(
        layers=[32, 64, 128, 256],
        fc_layers=[256, 128],
        nr_classes=3
    )
"""
import json
from pathlib import Path
import torch
import torch.nn as nn
try:
from safetensors.torch import load_file, save_file
HAS_SAFETENSORS = True
except ImportError:
HAS_SAFETENSORS = False
class BaseCNN(nn.Module):
    """
    Simple CNN for cervical type classification.

    Classifies cervical images into 3 transformation zone types:
    - Type 1: Transformation zone fully visible on ectocervix
    - Type 2: Transformation zone partially visible
    - Type 3: Transformation zone not visible (within endocervical canal)

    Architecture: a stack of conv blocks
    (Conv2d -> [BatchNorm2d] -> activation -> [BatchNorm2d] -> [MaxPool2d]),
    global average pooling, a stack of (Linear -> activation -> Dropout)
    blocks, and a final linear classifier producing raw logits.

    Args:
        layers: List of output channels for each conv layer.
            Default (when None): [32, 64, 128, 256]
        kernel: Kernel size for conv layers. Default: 3
        padding: Padding for conv layers. Default: 1
        stride: Stride for conv layers. Default: 1
        batchnorm: Whether to use batch normalization. Default: True
        bn_pre_activ: Whether to apply BN before activation. Default: True
        activation: Name of an activation class in ``torch.nn``. Default: 'ReLU'
        dropout: Dropout rate for FC layers. Default: 0.4
        pool: Whether to use 2x2 max pooling after each conv. Default: True
        fc_layers: List of hidden FC layer sizes. An explicit empty list means
            "no hidden FC layers". Default (when None): [256, 128]
        nr_classes: Number of output classes. Default: 3
        in_channels: Number of input channels. Default: 3

    Raises:
        AttributeError: If ``activation`` is not an attribute of ``torch.nn``.
    """

    def __init__(
        self,
        layers: list = None,
        kernel: int = 3,
        padding: int = 1,
        stride: int = 1,
        batchnorm: bool = True,
        bn_pre_activ: bool = True,
        activation: str = 'ReLU',
        dropout: float = 0.4,
        pool: bool = True,
        fc_layers: list = None,
        nr_classes: int = 3,
        in_channels: int = 3,
    ):
        super().__init__()

        # Only ``None`` selects the defaults: an explicit empty list is a
        # legitimate request (e.g. fc_layers=[] -> classifier directly on the
        # pooled features). Copies are taken so that later mutation of the
        # caller's lists cannot desynchronise the saved config from the
        # network that was actually built.
        layers = [32, 64, 128, 256] if layers is None else list(layers)
        fc_layers = [256, 128] if fc_layers is None else list(fc_layers)

        # Store config for serialization (see save_pretrained/from_pretrained)
        self.config = {
            'layers': layers,
            'kernel': kernel,
            'padding': padding,
            'stride': stride,
            'batchnorm': batchnorm,
            'bn_pre_activ': bn_pre_activ,
            'activation': activation,
            'dropout': dropout,
            'pool': pool,
            'fc_layers': fc_layers,
            'nr_classes': nr_classes,
            'in_channels': in_channels,
        }

        # Activation class looked up by name in torch.nn
        activation_fn = getattr(nn, activation)

        # Build convolutional layers. The flat ModuleList and the exact
        # append order determine the state_dict keys ("conv_layers.<i>..."),
        # so they must not change or existing checkpoints stop loading.
        self.conv_layers = nn.ModuleList()
        prev_channels = in_channels
        for out_channels in layers:
            self.conv_layers.append(
                nn.Conv2d(prev_channels, out_channels, kernel, stride, padding)
            )
            if batchnorm and bn_pre_activ:
                self.conv_layers.append(nn.BatchNorm2d(out_channels))
            self.conv_layers.append(activation_fn())
            if batchnorm and not bn_pre_activ:
                self.conv_layers.append(nn.BatchNorm2d(out_channels))
            if pool:
                self.conv_layers.append(nn.MaxPool2d(2, 2))
            prev_channels = out_channels

        # Global average pooling: collapses H x W to 1 x 1 per channel
        self.adaptive_pool = nn.AdaptiveAvgPool2d(1)

        # Build fully connected layers (same flat-ModuleList constraint as
        # above). The pooled feature width equals the channel count leaving
        # the conv stack, which is ``in_channels`` when ``layers`` is empty.
        self.fc_layers = nn.ModuleList()
        prev_features = prev_channels
        for fc_size in fc_layers:
            self.fc_layers.append(nn.Linear(prev_features, fc_size))
            self.fc_layers.append(activation_fn())
            self.fc_layers.append(nn.Dropout(dropout))
            prev_features = fc_size

        # Final classifier, kept as a separate attribute ("classifier.*" keys)
        self.classifier = nn.Linear(prev_features, nr_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (batch_size, in_channels, H, W). The
                original documentation gives (batch_size, 3, 256, 256); the
                global average pooling makes the network independent of the
                exact spatial size as long as it survives the pooling stages.

        Returns:
            Logits tensor of shape (batch_size, nr_classes)
        """
        for layer in self.conv_layers:
            x = layer(x)
        x = self.adaptive_pool(x)
        # flatten (rather than view(N, -1)) also handles an empty batch,
        # where -1 cannot be inferred.
        x = torch.flatten(x, 1)
        for layer in self.fc_layers:
            x = layer(x)
        return self.classifier(x)

    @classmethod
    def from_pretrained(cls, model_path: str, device: str = 'cpu') -> 'BaseCNN':
        """
        Load a pretrained model from a directory.

        Expects ``config.json`` (with a ``model_config`` entry holding the
        constructor arguments) and either ``model.safetensors`` or
        ``pytorch_model.bin``. SafeTensors is preferred when both the file
        and the ``safetensors`` package are available.

        Args:
            model_path: Path to directory containing model files
            device: Device to load model on ('cpu' or 'cuda')

        Returns:
            Loaded model in eval mode

        Raises:
            FileNotFoundError: If ``config.json`` is missing, or if no weight
                file usable in the current environment is found.
        """
        model_path = Path(model_path)

        # Load config
        config_path = model_path / 'config.json'
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Create model
        model = cls(**config['model_config'])

        # Load weights (prefer safetensors)
        safetensors_path = model_path / 'model.safetensors'
        pytorch_path = model_path / 'pytorch_model.bin'
        if safetensors_path.exists() and HAS_SAFETENSORS:
            state_dict = load_file(str(safetensors_path), device=device)
        elif pytorch_path.exists():
            # weights_only=True restricts unpickling to tensor data
            state_dict = torch.load(pytorch_path, map_location=device, weights_only=True)
        elif safetensors_path.exists():
            # Weights are present but unreadable here; say so instead of
            # claiming nothing was found.
            raise FileNotFoundError(
                f"Found {safetensors_path} but the 'safetensors' package is not "
                f"installed, and no pytorch_model.bin fallback exists in {model_path}"
            )
        else:
            raise FileNotFoundError(f"No model weights found in {model_path}")

        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        return model

    def save_pretrained(self, save_path: str) -> None:
        """
        Save model in Hugging Face compatible format.

        Writes ``config.json``, ``pytorch_model.bin`` and, when the
        ``safetensors`` package is installed, ``model.safetensors``.

        Args:
            save_path: Directory to save model files (created if missing)
        """
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        # Label names follow the "Type <n>" scheme; generating them from
        # nr_classes keeps id2label/label2id consistent with num_labels
        # (identical to the fixed three-entry mapping when nr_classes == 3).
        nr_classes = self.config['nr_classes']
        label_names = [f'Type {idx + 1}' for idx in range(nr_classes)]

        # Save config
        config = {
            'model_type': 'BaseCNN',
            'model_config': self.config,
            'num_labels': nr_classes,
            'id2label': {str(idx): name for idx, name in enumerate(label_names)},
            'label2id': {name: idx for idx, name in enumerate(label_names)},
        }
        with open(save_path / 'config.json', 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)

        # Save weights; safetensors requires contiguous tensors
        state_dict = {k: v.contiguous() for k, v in self.state_dict().items()}

        # SafeTensors format (recommended)
        if HAS_SAFETENSORS:
            save_file(state_dict, str(save_path / 'model.safetensors'))

        # PyTorch format (backup)
        torch.save(state_dict, save_path / 'pytorch_model.bin')
# Label mappings: class index <-> human-readable transformation zone type.
# LABEL2ID is derived from ID2LABEL so the two can never drift apart.
ID2LABEL = dict(enumerate(('Type 1', 'Type 2', 'Type 3')))
LABEL2ID = {label: index for index, label in ID2LABEL.items()}
if __name__ == '__main__':
    # Smoke test: build the default network, report its size, and push one
    # random 256x256 RGB-shaped tensor through it to check the output shape.
    model = BaseCNN()
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {param_count:,}")

    # Single-sample batch of random data
    x = torch.randn(1, 3, 256, 256)
    y = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {y.shape}")