Sperm Normality Rate Classifier
Model Description
This model classifies microscopy videos of human sperm into normality rate categories using a 3D ResNet18 backbone (torchvision `r3d_18`).
- Architecture: 3D ResNet18
- Task: Video classification
- Classes: 6 normality rate categories (0%, 60%, 70%, 80%, 85%, 90%)
- Input: 90 frames (3 seconds at 30 fps), 224x224 RGB
- Framework: PyTorch
Intended Use
This model is designed to predict normality rates from microscopy videos of human sperm. It is intended for research and for diagnostic support in reproductive medicine.
Model Details
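- Input Format: Video clips of 90 frames (3 seconds at 30 fps)
- Preprocessing:
  - Frames converted to RGB and resized to 224x224
  - ImageNet normalization (per-channel mean and standard deviation), followed by per-video standardization (a single mean and standard deviation computed over the whole clip)
  - The two normalizations are combined for robustness; after the per-video step, every clip has zero mean and unit standard deviation
- Output: 6 logits; a softmax turns them into a probability distribution over the 6 normality rate classes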
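The two normalization steps can be sketched in NumPy, mirroring `preprocess_video` in the Usage section. First, the per-channel ImageNet mean and standard deviation are applied. Then a single mean and standard deviation are computed over the whole clip. `combined_normalize` is a hypothetical name. `ddof=1` is an assumption chosen to match `torch.Tensor.std`, which uses the unbiased estimator by default.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def combined_normalize(frames: np.ndarray) -> np.ndarray:
    """Normalize a clip of shape (T, H, W, 3) with RGB values in [0, 1].

    Step 1: ImageNet normalization, applied per color channel.
    Step 2: per-video standardization with one mean/std for the whole clip.
    """
    # Broadcasting over the last (channel) axis
    x = (frames.astype(np.float32) - IMAGENET_MEAN) / IMAGENET_STD
    video_mean = x.mean()
    video_std = x.std(ddof=1)  # unbiased, like torch.Tensor.std()
    if video_std > 0:  # skip constant clips to avoid dividing by zero
        x = (x - video_mean) / video_std
    return x
```

Step 2 rescales the clip as a whole, so the output always has mean 0 and standard deviation 1. Step 1 still determines how the three color channels are weighted relative to each other.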
Training Details
- Dataset: Balanced, with 100 clips per class (600 clips in total)
- Normalization: ImageNet normalization combined with per-video standardization
- Loss Function: Focal Loss with class weights
- Optimizer: Adam (lr=1e-4)
- Early Stopping: Patience of 10 epochs
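The training objective and stopping rule can be sketched as follows. This is a minimal sketch, not the author's training code. The card gives only the loss name, the optimizer, and the patience, so several details here are assumptions: `gamma=2.0`, averaging the weighted per-sample losses, and monitoring validation loss for early stopping.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Multi-class focal loss with optional per-class weights.

    loss_i = -w[y_i] * (1 - p_i) ** gamma * log(p_i),
    where p_i is the softmax probability of the true class y_i.
    gamma=2.0 is an assumed default; the value used in training is not stated.
    """

    def __init__(self, class_weights=None, gamma: float = 2.0):
        super().__init__()
        self.gamma = gamma
        # Buffer so the weights follow .to(device); None means unweighted
        self.register_buffer("class_weights", class_weights)

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Log-probability of the true class for each sample
        log_p = F.log_softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
        p = log_p.exp()
        loss = -((1 - p) ** self.gamma) * log_p  # down-weights easy, confident samples
        if self.class_weights is not None:
            loss = loss * self.class_weights[targets]
        return loss.mean()


class EarlyStopping:
    """Signal a stop when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

In a training loop, these pieces would be combined with `torch.optim.Adam(model.parameters(), lr=1e-4)`. After each epoch, the loop would call `stopper.step(val_loss)` and save the weights whenever `stopper.best` improves. With `gamma=0` and no class weights, the focal loss reduces to ordinary cross-entropy.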
Usage
```python
import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.models.video as video_models

# Define the model architecture (must match the one used for training)
class VideoClassifier(nn.Module):
    def __init__(self, num_classes=6):
        super().__init__()
        # weights=None: the trained weights are loaded from best_model.pth below
        self.backbone = video_models.r3d_18(weights=None)
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        # x: (batch, 3, frames, height, width)
        return self.backbone(x)

# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VideoClassifier(num_classes=6)
model.load_state_dict(torch.load("best_model.pth", map_location=device, weights_only=True))
model.to(device)
model.eval()

# Preprocess a video into a (1, 3, 90, 224, 224) tensor
def preprocess_video(video_path, num_frames=90, target_size=(224, 224)):
    cap = cv2.VideoCapture(video_path)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes frames as BGR
        frame = cv2.resize(frame, target_size)  # target_size is (width, height)
        frames.append(frame.astype(np.float32) / 255.0)
    cap.release()

    if not frames:
        raise ValueError(f"Could not read any frames from {video_path}")

    # Sample to exactly num_frames frames
    if len(frames) < num_frames:
        # Loop the video until it is long enough, then truncate
        repeat_factor = int(np.ceil(num_frames / len(frames)))
        frames = (frames * repeat_factor)[:num_frames]
    elif len(frames) > num_frames:
        # Keep evenly spaced frames across the whole video
        indices = np.linspace(0, len(frames) - 1, num_frames).astype(int)
        frames = [frames[i] for i in indices]

    # (T, H, W, C) -> (C, T, H, W)
    frames = torch.from_numpy(np.stack(frames)).permute(3, 0, 1, 2)

    # ImageNet normalization (per channel)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1, 1)
    frames = (frames - mean) / std

    # Per-video standardization (one mean/std over the whole clip)
    video_mean = frames.mean()
    video_std = frames.std()
    if video_std > 0:
        frames = (frames - video_mean) / video_std

    return frames.unsqueeze(0)  # add batch dimension -> (1, C, T, H, W)

# Inference
video_tensor = preprocess_video("path/to/video.mp4").to(device)

with torch.no_grad():
    outputs = model(video_tensor)
    probabilities = torch.softmax(outputs, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()

class_names = ["0%", "60%", "70%", "80%", "85%", "90%"]
print(f"Predicted normality rate: {class_names[predicted_class]}")
print(f"Confidence: {probabilities[0, predicted_class].item():.4f}")
```
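The frame-sampling rule in `preprocess_video` can be isolated as a pure function that returns which decoded frames end up in the clip. `sample_frame_indices` is a hypothetical helper that mirrors the code above. Short videos are looped, and long ones are subsampled uniformly with `np.linspace`.

```python
import numpy as np


def sample_frame_indices(num_available: int, num_frames: int = 90) -> list[int]:
    """Indices of decoded frames used to build a clip of exactly num_frames frames.

    Mirrors preprocess_video: loop short videos, uniformly subsample long ones.
    """
    if num_available <= 0:
        raise ValueError("video has no frames")
    if num_available < num_frames:
        # Repeating the frame list and truncating is the same as cycling through indices
        return [i % num_available for i in range(num_frames)]
    if num_available > num_frames:
        return np.linspace(0, num_available - 1, num_frames).astype(int).tolist()
    return list(range(num_frames))
```

For example, a 1-second recording at 30 fps (30 frames) is played three times. A 6-second recording (180 frames) keeps roughly every second frame: `sample_frame_indices(180)[:4]` is `[0, 2, 4, 6]`.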
Limitations
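- Trained on videos from specific microscopy equipment and acquisition protocols
- Performance may differ under other imaging conditions, such as a different microscope, magnification, or frame rate
- Intended as diagnostic support only, not as the sole basis for clinical decisions
- Inputs must be preprocessed exactly as in the Usage example (RGB, 224x224, 90 frames, ImageNet normalization followed by per-video standardization)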
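The model only sees what `preprocess_video` produces, so a cheap sanity check on the input tensor can catch mismatched preprocessing before inference. `check_clip_tensor` and its tolerance are hypothetical. It tests only the shape and the zero-mean, unit-variance property that per-video standardization guarantees. It therefore cannot detect every preprocessing error, such as BGR instead of RGB channel order.

```python
import torch

EXPECTED_SHAPE = (3, 90, 224, 224)  # (C, T, H, W) of one preprocessed clip


def check_clip_tensor(x: torch.Tensor, expected_shape=EXPECTED_SHAPE, tol: float = 1e-3) -> None:
    """Raise ValueError if x does not look like output of the documented preprocessing."""
    if x.dim() != 5 or tuple(x.shape[1:]) != tuple(expected_shape):
        raise ValueError(
            f"expected (B, C, T, H, W) with (C, T, H, W) = {tuple(expected_shape)}, got {tuple(x.shape)}"
        )
    for i, clip in enumerate(x):
        # Per-video standardization should leave each clip at mean 0, std 1
        mean, std = clip.mean().item(), clip.std().item()
        if abs(mean) > tol or abs(std - 1.0) > tol:
            raise ValueError(f"clip {i}: mean={mean:.4f}, std={std:.4f}; expected about 0 and 1")
```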
Citation
If you use this model, please cite:
```bibtex
@misc{sperm-normality-classifier,
  author = {Raid Athmane Benlala},
  title = {Sperm Normality Rate Classifier},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/raidAthmaneBenlala/normality-rate-classifier}}
}
```