---
license: mit
tags:
- video-classification
- medical
- microscopy
- sperm-analysis
- pytorch
library_name: pytorch
---

# Sperm Normality Rate Classifier

## Model Description

This model classifies human sperm microscopy videos into normality-rate categories using a 3D ResNet18 architecture.

- **Architecture**: 3D ResNet18
- **Task**: Video classification
- **Classes**: 6 normality-rate categories (0%, 60%, 70%, 80%, 85%, 90%)
- **Input**: 90 frames (3 seconds at 30 fps), 224x224 RGB
- **Framework**: PyTorch

## Intended Use

This model analyzes human sperm microscopy videos to predict normality rates. It is intended for research and diagnostic support in reproductive medicine.

## Model Details

- **Input Format**: Video clips of 90 frames (3 seconds at 30 fps)
- **Preprocessing**:
  - Frames resized to 224x224
  - ImageNet normalization followed by per-video standardization
  - The combined normalization improves robustness to illumination differences between recordings
- **Output**: Probability distribution over the 6 normality-rate classes

## Training Details

- **Dataset**: Balanced dataset with 100 clips per class
- **Normalization**: Combined ImageNet + per-video standardization
- **Loss Function**: Focal Loss with class weights
- **Optimizer**: Adam (lr=1e-4)
- **Early Stopping**: Patience of 10 epochs

## Usage

```python
import torch
import torch.nn as nn
import torchvision.models.video as video_models
import cv2
import numpy as np

# Define the model architecture
class VideoClassifier(nn.Module):
    def __init__(self, num_classes=6):
        super().__init__()
        self.backbone = video_models.r3d_18(pretrained=False)  # weights loaded below
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.backbone(x)

# Load the trained weights
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VideoClassifier(num_classes=6)
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.to(device)
model.eval()

# Read a video, resample it to 90 frames, and normalize it
def preprocess_video(video_path, num_frames=90, target_size=(224, 224)):
    cap = cv2.VideoCapture(video_path)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, target_size)
        frame = frame.astype(np.float32) / 255.0
        frames.append(frame)
    cap.release()

    if not frames:
        raise ValueError(f"No frames could be read from {video_path}")

    # Resample to exactly num_frames: repeat short clips, subsample long ones
    if len(frames) < num_frames:
        repeat_factor = int(np.ceil(num_frames / len(frames)))
        frames = (frames * repeat_factor)[:num_frames]
    elif len(frames) > num_frames:
        indices = np.linspace(0, len(frames) - 1, num_frames).astype(int)
        frames = [frames[i] for i in indices]

    # (T, H, W, C) -> (C, T, H, W)
    frames = torch.FloatTensor(np.array(frames)).permute(3, 0, 1, 2)

    # ImageNet normalization
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1, 1)
    frames = (frames - mean) / std

    # Per-video standardization
    video_mean = frames.mean()
    video_std = frames.std()
    if video_std > 0:
        frames = (frames - video_mean) / video_std

    return frames.unsqueeze(0)  # add batch dimension

# Inference
video_tensor = preprocess_video("path/to/video.mp4").to(device)
with torch.no_grad():
    outputs = model(video_tensor)
    probabilities = torch.softmax(outputs, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()

class_names = ["0%", "60%", "70%", "80%", "85%", "90%"]
print(f"Predicted normality rate: {class_names[predicted_class]}")
print(f"Confidence: {probabilities[0][predicted_class].item():.4f}")
```

## Limitations

- Trained on specific microscopy equipment and protocols
- Performance may vary with different imaging conditions
- Should be used as diagnostic support, not as a sole decision-making tool
- Requires the exact video preprocessing shown above

## Citation

If you use this model, please cite:

```
@misc{sperm-normality-classifier,
  author = {Raid Athmane Benlala},
  title = {Sperm Normality Rate Classifier},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/raidAthmaneBenlala/normality-rate-classifier}}
}
```
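
## Appendix: Focal Loss Sketch

The Training Details above name Focal Loss with class weights as the training objective. A minimal sketch of that loss is shown below; the `gamma` value and the class-weight vector are illustrative assumptions, as the exact values used in training are not specified in this card.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Focal loss with optional per-class weights.

    gamma down-weights well-classified (high-confidence) samples;
    gamma=0 with no weights reduces to plain cross-entropy.
    Note: with non-uniform weights, .mean() averages over samples,
    unlike PyTorch's weighted-mean convention in cross_entropy.
    """
    def __init__(self, gamma=2.0, weight=None):  # gamma=2.0 is an assumed default
        super().__init__()
        self.gamma = gamma
        self.weight = weight  # shape (num_classes,) or None

    def forward(self, logits, targets):
        log_probs = F.log_softmax(logits, dim=1)
        # Probability of the true class, from the unweighted softmax
        pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1).exp()
        # Per-sample (optionally class-weighted) cross-entropy
        ce = F.nll_loss(log_probs, targets, weight=self.weight, reduction='none')
        return ((1.0 - pt) ** self.gamma * ce).mean()
```

In a training loop, this criterion would take the place of `nn.CrossEntropyLoss`, e.g. `criterion = FocalLoss(gamma=2.0, weight=class_weights)` where `class_weights` is a hypothetical tensor derived from class frequencies.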