Sperm Normality Rate Classifier

Model Description

This model classifies microscopy videos of human sperm into normality rate categories using a 3D ResNet18 (R3D-18) architecture.

Architecture: 3D ResNet18
Task: Video Classification
Classes: 6 normality rate categories (0%, 60%, 70%, 80%, 85%, 90%)
Input: 90 frames (3 seconds at 30 fps), 224x224 RGB
Framework: PyTorch
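The usage code below fixes every clip at 90 frames: clips shorter than 90 frames are looped, and longer ones are uniformly subsampled. That rule can be isolated as a small index-selection helper. This is a sketch that mirrors the usage code; `sample_frame_indices` is a hypothetical name, not part of the released code.

```python
import numpy as np


def sample_frame_indices(num_available: int, num_frames: int = 90) -> np.ndarray:
    """Choose which decoded frames to keep so the clip has exactly num_frames.

    Hypothetical helper mirroring the usage code: short (or exact-length)
    clips are looped, long clips are uniformly subsampled.
    """
    if num_available <= 0:
        raise ValueError("video contains no frames")
    if num_available <= num_frames:
        # Loop the clip: 0, 1, ..., n-1, 0, 1, ... truncated to num_frames
        return np.arange(num_frames) % num_available
    # Evenly spaced indices from the first to the last frame
    return np.linspace(0, num_available - 1, num_frames).astype(int)
```

A 3-second recording at 30 fps (90 frames) passes through unchanged. Any other length is mapped onto 90 frames, so the temporal scale seen by the model changes for longer or shorter inputs.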

Intended Use

This model is designed to analyze microscopy videos of human sperm and predict their normality rate. It is intended for research and diagnostic support in reproductive medicine.

Model Details

  • Input Format: Video clips of 90 frames (3 seconds at 30 fps)
  • Preprocessing:
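    • Frame sampling: shorter videos are looped and longer videos are uniformly subsampled to exactly 90 frames
    • Frames converted from BGR to RGB, resized to 224x224, and scaled to [0, 1]
    • Normalization: ImageNet per-channel mean/std, followed by per-video standardization (zero mean, unit variance over the whole clip)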
  • Output: Probability distribution over 6 normality rate classes
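The combined normalization can be written as a NumPy-only sketch that runs without OpenCV or PyTorch. It assumes the frames are already decoded, converted to RGB, and resized. `normalize_clip` is a hypothetical helper. It uses `ddof=1` so that the standard deviation matches the unbiased default of `torch.Tensor.std` used in the usage code.

```python
import numpy as np

# ImageNet channel statistics (RGB order), as in the usage code
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_clip(frames_rgb: np.ndarray) -> np.ndarray:
    """Apply ImageNet normalization, then per-video standardization.

    frames_rgb: uint8 array of shape (T, H, W, 3), already resized.
    Returns a float32 array of shape (1, 3, T, H, W) for the 3D CNN.
    """
    x = frames_rgb.astype(np.float32) / 255.0   # scale to [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD      # per-channel, broadcast over last axis
    std = x.std(ddof=1)                         # unbiased, like torch.Tensor.std()
    if std > 0:
        x = (x - x.mean()) / std                # one mean/std for the whole clip
    x = np.transpose(x, (3, 0, 1, 2))           # (T, H, W, C) -> (C, T, H, W)
    return x[np.newaxis].astype(np.float32)     # add batch dimension
```

Because the second step uses a single scalar mean and standard deviation, it keeps the relative differences between channels produced by the ImageNet step but removes clip-level brightness and contrast offsets.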

Training Details

  • Dataset: Balanced, with 100 clips per class (600 clips in total)
  • Normalization: ImageNet normalization followed by per-video standardization (same as at inference)
  • Loss Function: Focal Loss with class weights
  • Optimizer: Adam (lr=1e-4)
  • Early Stopping: Patience of 10 epochs
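The card does not state the focal-loss hyperparameters. The NumPy sketch below shows the standard multi-class form, FL = -alpha_t * (1 - p_t)^gamma * log(p_t), where alpha holds the per-class weights. The focusing parameter `gamma=2.0` and any weight values are assumptions, not the values used in training.

```python
import numpy as np


def focal_loss(logits: np.ndarray, targets: np.ndarray,
               class_weights: np.ndarray, gamma: float = 2.0) -> float:
    """Multi-class focal loss with per-class weights, averaged over the batch.

    logits: (N, C) raw scores; targets: (N,) integer labels;
    class_weights: (C,) alpha weights. gamma=2.0 is an assumed default.
    """
    # Numerically stable log-softmax
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    rows = np.arange(len(targets))
    log_pt = log_probs[rows, targets]   # log-probability of the true class
    pt = np.exp(log_pt)
    alpha_t = class_weights[targets]    # weight of each sample's true class
    # (1 - pt)^gamma down-weights samples the model already classifies well
    loss = -alpha_t * (1.0 - pt) ** gamma * log_pt
    return float(loss.mean())
```

With `gamma=0` and all weights equal to 1, this reduces to ordinary cross-entropy. In PyTorch, `log_pt` can be obtained as `-F.cross_entropy(logits, targets, reduction='none')`.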

Usage

import torch
import torch.nn as nn
import torchvision.models.video as video_models
import cv2
import numpy as np

# Define model architecture
class VideoClassifier(nn.Module):
    def __init__(self, num_classes=6):
        super().__init__()
        # weights=None: random init; trained weights are loaded from the checkpoint below
        self.backbone = video_models.r3d_18(weights=None)
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(in_features, num_classes)
    
    def forward(self, x):
        return self.backbone(x)

# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VideoClassifier(num_classes=6)
model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True))
model.to(device)
model.eval()

# Preprocess video
def preprocess_video(video_path, num_frames=90, target_size=(224, 224)):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Could not open video: {video_path}")
    frames = []
    
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes frames as BGR
        frame = cv2.resize(frame, target_size)          # target_size is (width, height)
        frame = frame.astype(np.float32) / 255.0        # scale pixel values to [0, 1]
        frames.append(frame)
    
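    cap.release()
    if not frames:
        raise ValueError(f"No frames decoded from: {video_path}")
    
    # Resample to exactly num_frames: loop short clips, uniformly subsample long ones
    if len(frames) < num_frames:
        repeat_factor = int(np.ceil(num_frames / len(frames)))
        frames = (frames * repeat_factor)[:num_frames]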
    elif len(frames) > num_frames:
        indices = np.linspace(0, len(frames) - 1, num_frames).astype(int)
        frames = [frames[i] for i in indices]
    
    # Stack to (T, H, W, C), then permute to (C, T, H, W) as expected by r3d_18
    frames = torch.from_numpy(np.stack(frames)).permute(3, 0, 1, 2)
    
    # ImageNet per-channel normalization
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1, 1)
    frames = (frames - mean) / std
    
    # Per-video standardization
    video_mean = frames.mean()
    video_std = frames.std()
    if video_std > 0:
        frames = (frames - video_mean) / video_std
    
    return frames.unsqueeze(0)  # Add batch dimension

# Inference
video_tensor = preprocess_video("path/to/video.mp4")
video_tensor = video_tensor.to(device)

with torch.no_grad():
    outputs = model(video_tensor)
    probabilities = torch.softmax(outputs, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()

class_names = ["0%", "60%", "70%", "80%", "85%", "90%"]
print(f"Predicted normality rate: {class_names[predicted_class]}")
print(f"Confidence: {probabilities[0][predicted_class].item():.4f}")
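The usage code reports only the top class and its probability. To inspect the full distribution over the six categories, you can apply a small helper to the raw logits, for example `ranked_predictions(outputs[0].cpu().numpy())`. `ranked_predictions` is a hypothetical name, and this is a sketch rather than part of the released code.

```python
import numpy as np

CLASS_NAMES = ["0%", "60%", "70%", "80%", "85%", "90%"]


def ranked_predictions(logits, class_names=CLASS_NAMES):
    """Return (class_name, probability) pairs sorted from most to least likely."""
    logits = np.asarray(logits, dtype=np.float64)
    exp = np.exp(logits - logits.max())  # numerically stable softmax for one clip
    probs = exp / exp.sum()
    order = np.argsort(probs)[::-1]      # indices by descending probability
    return [(class_names[i], float(probs[i])) for i in order]
```

Seeing the runner-up classes can help when neighbouring categories (e.g. 80% and 85%) receive similar probabilities.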

Limitations

  • Trained on specific microscopy equipment and protocols
  • Performance may vary with different imaging conditions
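  • Intended as diagnostic support, not as the sole basis for clinical decisions
  • Inputs must be preprocessed exactly as during training (90 frames, 224x224 RGB, ImageNet normalization followed by per-video standardization)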

Citation

If you use this model, please cite:

@misc{sperm-normality-classifier,
  author = {Raid Athmane Benlala},
  title = {Sperm Normality Rate Classifier},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/raidAthmaneBenlala/normality-rate-classifier}}
}