---
license: mit
tags:
- video-classification
- medical
- microscopy
- sperm-analysis
- pytorch
library_name: pytorch
---

# Sperm Normality Rate Classifier

## Model Description

This model classifies microscopy videos of human sperm into normality-rate categories using a 3D ResNet18 architecture.

**Architecture**: 3D ResNet18

**Task**: Video classification

**Classes**: 6 normality-rate categories (0%, 60%, 70%, 80%, 85%, 90%)

**Input**: 90 frames (3 seconds at 30 fps), 224x224 RGB

**Framework**: PyTorch

## Intended Use

This model analyzes microscopy videos of human sperm to predict normality rates. It is intended for research and diagnostic support in reproductive medicine.

## Model Details

- **Input Format**: Video clips of 90 frames (3 seconds at 30 fps)
- **Preprocessing**:
  - Frames resized to 224x224
  - ImageNet normalization followed by per-video standardization
  - The two normalizations are combined for robustness across videos
- **Output**: Probability distribution over 6 normality-rate classes

## Training Details

- **Dataset**: Balanced dataset with 100 clips per class
- **Normalization**: Combined ImageNet + per-video standardization
- **Loss Function**: Focal loss with class weights
- **Optimizer**: Adam (lr=1e-4)
- **Early Stopping**: Patience of 10 epochs
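
The training code itself is not part of this card; the loss can be sketched as follows (a minimal sketch assuming the standard focal-loss formulation — the `gamma` value and the uniform class weights here are illustrative, not the values used in training):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Focal loss with optional per-class weights:
    FL = -w[y] * (1 - p_y)^gamma * log(p_y)."""
    def __init__(self, class_weights=None, gamma=2.0):
        super().__init__()
        self.gamma = gamma
        self.class_weights = class_weights  # shape (num_classes,) or None

    def forward(self, logits, targets):
        log_probs = F.log_softmax(logits, dim=1)
        # Log-probability of the true class for each sample
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()
        # Down-weight easy examples (pt close to 1)
        loss = -((1.0 - pt) ** self.gamma) * log_pt
        if self.class_weights is not None:
            loss = loss * self.class_weights[targets]
        return loss.mean()

# 6 classes; equal weights shown only for illustration
criterion = FocalLoss(class_weights=torch.ones(6), gamma=2.0)
logits = torch.randn(4, 6)
targets = torch.tensor([0, 2, 5, 1])
loss = criterion(logits, targets)
```

With `gamma=0` and the weights baked in, this reduces to weighted cross-entropy; the `(1 - pt)^gamma` factor focuses training on hard, misclassified clips.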

## Usage

```python
import torch
import torch.nn as nn
import torchvision.models.video as video_models
import cv2
import numpy as np

# Define the model architecture
class VideoClassifier(nn.Module):
    def __init__(self, num_classes=6):
        super(VideoClassifier, self).__init__()
        self.backbone = video_models.r3d_18(pretrained=False)
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.backbone(x)

# Load the trained weights
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VideoClassifier(num_classes=6)
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.to(device)
model.eval()

def preprocess_video(video_path, num_frames=90, target_size=(224, 224)):
    """Read a video, sample/pad to num_frames, and normalize to a (1, C, T, H, W) tensor."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, target_size)
        frames.append(frame.astype(np.float32) / 255.0)
    cap.release()

    if not frames:
        raise ValueError(f"No frames could be read from {video_path}")

    # Pad short clips by repetition; subsample long clips uniformly
    if len(frames) < num_frames:
        repeat_factor = int(np.ceil(num_frames / len(frames)))
        frames = (frames * repeat_factor)[:num_frames]
    elif len(frames) > num_frames:
        indices = np.linspace(0, len(frames) - 1, num_frames).astype(int)
        frames = [frames[i] for i in indices]

    # (T, H, W, C) -> (C, T, H, W)
    frames = torch.FloatTensor(np.array(frames)).permute(3, 0, 1, 2)

    # ImageNet normalization
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1, 1)
    frames = (frames - mean) / std

    # Per-video standardization
    video_mean = frames.mean()
    video_std = frames.std()
    if video_std > 0:
        frames = (frames - video_mean) / video_std

    return frames.unsqueeze(0)  # add batch dimension

# Inference
video_tensor = preprocess_video("path/to/video.mp4").to(device)

with torch.no_grad():
    outputs = model(video_tensor)
    probabilities = torch.softmax(outputs, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()

class_names = ["0%", "60%", "70%", "80%", "85%", "90%"]
print(f"Predicted normality rate: {class_names[predicted_class]}")
print(f"Confidence: {probabilities[0][predicted_class].item():.4f}")
```

## Limitations

- Trained on specific microscopy equipment and protocols
- Performance may vary with different imaging conditions
- Should be used as diagnostic support, not as a sole decision-making tool
- Requires proper video preprocessing

## Citation

If you use this model, please cite:

```
@misc{sperm-normality-classifier,
  author       = {Raid Athmane Benlala},
  title        = {Sperm Normality Rate Classifier},
  year         = {2025},
  publisher    = {Hugging Face},
  howpublished = {\url{https://huggingface.co/raidAthmaneBenlala/normality-rate-classifier}}
}
```