---
license: mit
tags:
- video-classification
- medical
- microscopy
- sperm-analysis
- pytorch
library_name: pytorch
---
# Sperm Normality Rate Classifier
## Model Description
This model classifies microscopy videos of human sperm into normality-rate categories using a 3D ResNet18 architecture.

- **Architecture**: 3D ResNet18
- **Task**: Video Classification
- **Classes**: 6 normality-rate categories (0%, 60%, 70%, 80%, 85%, 90%)
- **Input**: 90 frames (3 seconds at 30 fps), 224x224 RGB
- **Framework**: PyTorch
## Intended Use
This model is designed to analyze microscopy videos of human sperm and predict their normality rate. It is intended for research and as a diagnostic aid in reproductive medicine.
## Model Details
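- **Input Format**: Video clips of 90 frames (3 seconds at 30 fps); shorter videos are looped and longer videos are uniformly subsampled to 90 frames
- **Preprocessing**:
  - Frames converted to RGB, resized to 224x224, and scaled to [0, 1]
  - ImageNet per-channel mean/std normalization
  - Per-video standardization (zero mean, unit standard deviation over the whole clip)
  - The two normalizations are applied in sequence ("combined normalization") for robustness
- **Output**: Logits for the 6 normality-rate classes; a softmax turns them into a probability distribution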
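The frame sampling and two-stage normalization described above can be sketched in plain NumPy, which makes them easy to check without a video file or PyTorch. This is a minimal sketch that mirrors the logic of the Usage code; the names `sample_frame_indices` and `normalize_clip` are hypothetical, and the clip standard deviation uses `ddof=1` on the assumption that it should match PyTorch's default `Tensor.std()`.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def sample_frame_indices(num_available, num_frames=90):
    """Hypothetical helper: indices that map a clip of any length to num_frames frames."""
    if num_available <= 0:
        raise ValueError("video contains no frames")
    if num_available < num_frames:
        # Too short: loop the clip (0, 1, ..., n-1, 0, 1, ...)
        return np.arange(num_frames) % num_available
    # Long enough: uniformly spaced frames (identity when lengths are equal)
    return np.linspace(0, num_available - 1, num_frames).astype(int)


def normalize_clip(frames):
    """Hypothetical helper. frames: (T, H, W, 3) float32 in [0, 1]; returns (3, T, H, W)."""
    # Stage 1: ImageNet per-channel normalization
    x = (frames - IMAGENET_MEAN) / IMAGENET_STD
    # Stage 2: per-video standardization over every value in the clip
    # (ddof=1 mirrors torch.Tensor.std(), which is unbiased by default)
    std = x.std(ddof=1)
    if std > 0:
        x = (x - x.mean()) / std
    # Channels first, then time, as expected by the 3D ResNet18
    return np.transpose(x, (3, 0, 1, 2)).astype(np.float32)
```

Looping via `np.arange(num_frames) % num_available` selects the same frames as the list-repetition approach in the Usage code.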
## Training Details
- **Dataset**: Balanced dataset with 100 clips per class (600 clips in total)
- **Normalization**: Combined ImageNet + per-video standardization
- **Loss Function**: Focal Loss with class weights
- **Optimizer**: Adam (lr=1e-4)
- **Early Stopping**: Patience of 10 epochs
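A minimal PyTorch sketch of the loss and stopping rule listed above. The card does not state the focal-loss γ, the class-weight values, or which metric early stopping monitors, so `gamma=2.0` (a commonly used default), the optional `class_weights` argument, and monitoring validation loss are all assumptions; `FocalLoss` and `EarlyStopping` are hypothetical names, not the author's training code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Multi-class focal loss: -w[y] * (1 - p_y)^gamma * log(p_y), averaged over the batch."""

    def __init__(self, class_weights=None, gamma=2.0):  # gamma=2.0 is an assumption
        super().__init__()
        self.gamma = gamma
        if class_weights is not None:
            class_weights = torch.as_tensor(class_weights, dtype=torch.float32)
        # Buffer so the weights follow the module across .to(device)
        self.register_buffer("class_weights", class_weights)

    def forward(self, logits, targets):
        # log p for every class, then log p_t for each sample's true class
        log_probs = F.log_softmax(logits, dim=1)
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()
        # (1 - p_t)^gamma shrinks the loss of easy, well-classified samples
        loss = -((1.0 - pt) ** self.gamma) * log_pt
        if self.class_weights is not None:
            # Per-class weighting; plain mean below (not divided by the weight sum)
            loss = loss * self.class_weights[targets]
        return loss.mean()


class EarlyStopping:
    """Signal a stop once the monitored loss has not improved for `patience` epochs."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

With `gamma=0` and no class weights the loss reduces to ordinary cross-entropy, which is a quick sanity check. In a training loop these would pair with `torch.optim.Adam(model.parameters(), lr=1e-4)` and `EarlyStopping(patience=10)`.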
## Usage
```python
import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.models.video as video_models


# Define the model architecture (must match the one used for training)
class VideoClassifier(nn.Module):
    def __init__(self, num_classes=6):
        super().__init__()
        # 3D ResNet18 backbone with no pretrained weights; the trained weights
        # come from the checkpoint (`weights=None` needs torchvision >= 0.13)
        self.backbone = video_models.r3d_18(weights=None)
        in_features = self.backbone.fc.in_features
        # Replace the head with one output per normality-rate class
        self.backbone.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        # x: (batch, channels, frames, height, width)
        return self.backbone(x)


# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VideoClassifier(num_classes=6)
# weights_only=True loads tensors only, without unpickling arbitrary objects
state_dict = torch.load("best_model.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()


# Preprocess video
def preprocess_video(video_path, num_frames=90, target_size=(224, 224)):
    cap = cv2.VideoCapture(video_path)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
        frame = cv2.resize(frame, target_size)  # target_size is (width, height)
        frames.append(frame.astype(np.float32) / 255.0)
    cap.release()
    if not frames:
        raise ValueError(f"Could not read any frames from {video_path}")

    # Sample to exactly num_frames frames
    if len(frames) < num_frames:
        # Too short: loop the clip until it is long enough
        repeat_factor = int(np.ceil(num_frames / len(frames)))
        frames = (frames * repeat_factor)[:num_frames]
    elif len(frames) > num_frames:
        # Too long: keep uniformly spaced frames
        indices = np.linspace(0, len(frames) - 1, num_frames).astype(int)
        frames = [frames[i] for i in indices]

    # (T, H, W, C) -> (C, T, H, W)
    frames = torch.from_numpy(np.stack(frames)).permute(3, 0, 1, 2)

    # ImageNet per-channel normalization
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1, 1)
    frames = (frames - mean) / std

    # Per-video standardization over the whole clip
    video_mean = frames.mean()
    video_std = frames.std()
    if video_std > 0:
        frames = (frames - video_mean) / video_std

    return frames.unsqueeze(0)  # Add batch dimension: (1, C, T, H, W)


# Inference
video_tensor = preprocess_video("path/to/video.mp4").to(device)
with torch.no_grad():
    outputs = model(video_tensor)  # raw logits, shape (1, 6)
    probabilities = torch.softmax(outputs, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()

class_names = ["0%", "60%", "70%", "80%", "85%", "90%"]
print(f"Predicted normality rate: {class_names[predicted_class]}")
print(f"Confidence: {probabilities[0, predicted_class].item():.4f}")
```
## Limitations
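- Trained on data from specific microscopy equipment and acquisition protocols
- Performance may degrade under different imaging conditions
- Intended as diagnostic support, not as the sole basis for clinical decisions
- Inputs must be preprocessed exactly as in the Usage section (90 frames, 224x224 RGB, combined ImageNet + per-video normalization)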
## Citation
If you use this model, please cite:
```bibtex
@misc{sperm-normality-classifier,
  author       = {Raid Athmane Benlala},
  title        = {Sperm Normality Rate Classifier},
  year         = {2025},
  publisher    = {Hugging Face},
  howpublished = {\url{https://huggingface.co/raidAthmaneBenlala/normality-rate-classifier}}
}
```