---
license: mit
tags:
- video-classification
- medical
- microscopy
- sperm-analysis
- pytorch
library_name: pytorch
---
# Sperm Normality Rate Classifier
## Model Description
This model classifies microscopy videos of human sperm into normality rate categories using a 3D ResNet18 architecture.
- **Architecture**: 3D ResNet18
- **Task**: Video classification
- **Classes**: 6 normality rate categories (0%, 60%, 70%, 80%, 85%, 90%)
- **Input**: 90 frames (3 seconds at 30 fps), 224x224 RGB
- **Framework**: PyTorch
## Intended Use
This model is designed to analyze microscopy videos of human sperm and predict their normality rate. It is intended for research and as diagnostic support in reproductive medicine.
## Model Details
- **Input Format**: Video clips of 90 frames (3 seconds at 30 fps)
- **Preprocessing**:
  - Frames resized to 224x224
  - ImageNet normalization followed by per-video standardization (combined normalization, used for robustness)
- **Output**: Probability distribution over the 6 normality rate classes
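The combined normalization can be sketched independently of any deep learning framework. This is a minimal NumPy sketch that mirrors the order used in the Usage example (ImageNet statistics per channel, then a single mean and standard deviation over the whole clip). It assumes frames are already decoded, resized, and scaled to [0, 1] in (T, H, W, 3) layout; `combined_normalize` is a hypothetical name, not part of the released code.

```python
import numpy as np

# ImageNet per-channel RGB statistics (same values as in the Usage example)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def combined_normalize(frames: np.ndarray) -> np.ndarray:
    """Normalize a clip of shape (T, H, W, 3) whose RGB values lie in [0, 1]."""
    # Step 1: ImageNet normalization; the (3,) arrays broadcast over the channel axis
    x = (frames - IMAGENET_MEAN) / IMAGENET_STD
    # Step 2: per-video standardization with a single mean/std over the whole clip.
    # ddof=1 matches the unbiased default of torch.Tensor.std used in the Usage example.
    video_mean = x.mean()
    video_std = x.std(ddof=1)
    if video_std > 0:  # guard against division by zero for a degenerate clip
        x = (x - video_mean) / video_std
    return x.astype(np.float32)


# Example on a small random clip (8 frames of 32x32) to keep it cheap
clip = np.random.default_rng(0).random((8, 32, 32, 3), dtype=np.float32)
normalized = combined_normalize(clip)
print(normalized.shape)  # (8, 32, 32, 3)
print(float(normalized.mean()), float(normalized.std(ddof=1)))
```

After step 2 the clip has approximately zero mean and unit standard deviation, regardless of the overall brightness of the original recording.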
## Training Details
- **Dataset**: Balanced dataset with 100 clips per class
- **Normalization**: Combined ImageNet + per-video standardization
- **Loss Function**: Focal Loss with class weights
- **Optimizer**: Adam (lr=1e-4)
- **Early Stopping**: Patience of 10 epochs
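The loss, optimizer, and stopping rule listed above can be sketched in PyTorch as follows. This is an illustration under stated assumptions, not the training code used for this model: the focal exponent `gamma=2.0`, the plain batch mean, the all-ones class weights, and monitoring validation loss for early stopping are all assumptions, and `WeightedFocalLoss` and `EarlyStopping` are hypothetical names. The toy loop uses a stand-in linear model instead of the video classifier.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightedFocalLoss(nn.Module):
    """Multi-class focal loss: -w[y] * (1 - p_y)^gamma * log(p_y), averaged over the batch."""

    def __init__(self, class_weights=None, gamma: float = 2.0):
        super().__init__()
        self.gamma = gamma  # assumed value; not documented for this model
        # Buffer so the weights move with the module on .to(device); None means unweighted
        self.register_buffer("class_weights", class_weights)

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        log_probs = F.log_softmax(logits, dim=1)  # (N, C)
        log_p_t = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # log-prob of true class, (N,)
        p_t = log_p_t.exp()
        loss = -((1.0 - p_t) ** self.gamma) * log_p_t  # down-weights confident, correct samples
        if self.class_weights is not None:
            loss = loss * self.class_weights[targets]  # per-class weighting
        return loss.mean()


class EarlyStopping:
    """Returns True from step() once the monitored loss has not improved for `patience` epochs."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Toy wiring with a stand-in model (hypothetical); the real model would be VideoClassifier
torch.manual_seed(0)
model = nn.Linear(16, 6)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = WeightedFocalLoss(class_weights=torch.ones(6))  # placeholder weights
stopper = EarlyStopping(patience=10)

x, y = torch.randn(32, 16), torch.randint(0, 6, (32,))
for epoch in range(100):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    val_loss = loss.item()  # stand-in: evaluate on a held-out validation set in practice
    if stopper.step(val_loss):
        break
```

With `gamma=0` and no class weights, this loss reduces to ordinary cross-entropy, which is a quick sanity check for any focal-loss implementation.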
## Usage
```python
import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.models.video as video_models

# Define the model architecture (must match the one used for training)
class VideoClassifier(nn.Module):
    def __init__(self, num_classes=6):
        super().__init__()
        # 3D ResNet18 backbone; no pretrained weights, since the fine-tuned weights are loaded below
        # (weights=None requires torchvision >= 0.13; older versions use pretrained=False)
        self.backbone = video_models.r3d_18(weights=None)
        in_features = self.backbone.fc.in_features
        # Replace the classification head with a 6-class head
        self.backbone.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        # x: (batch, channels, frames, height, width)
        return self.backbone(x)

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VideoClassifier(num_classes=6)
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model.to(device)
model.eval()

# Preprocess video
def preprocess_video(video_path, num_frames=90, target_size=(224, 224)):
    cap = cv2.VideoCapture(video_path)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes frames as BGR
        frame = cv2.resize(frame, target_size)  # target_size is (width, height)
        frame = frame.astype(np.float32) / 255.0  # scale to [0, 1]
        frames.append(frame)
    cap.release()

    if not frames:
        raise ValueError(f"Could not read any frames from {video_path}")

    # Sample to exactly num_frames frames
    if len(frames) < num_frames:
        # Short clip: loop it until it is long enough, then truncate
        repeat_factor = int(np.ceil(num_frames / len(frames)))
        frames = (frames * repeat_factor)[:num_frames]
    elif len(frames) > num_frames:
        # Long clip: keep num_frames evenly spaced frames
        indices = np.linspace(0, len(frames) - 1, num_frames).astype(int)
        frames = [frames[i] for i in indices]

    # Convert (T, H, W, C) to (C, T, H, W)
    frames = torch.from_numpy(np.stack(frames)).permute(3, 0, 1, 2)

    # ImageNet normalization (per channel)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1, 1)
    frames = (frames - mean) / std

    # Per-video standardization (one mean/std over the whole clip)
    video_mean = frames.mean()
    video_std = frames.std()
    if video_std > 0:
        frames = (frames - video_mean) / video_std

    return frames.unsqueeze(0)  # add batch dimension: (1, C, T, H, W)

# Inference
video_tensor = preprocess_video("path/to/video.mp4").to(device)
with torch.no_grad():
    outputs = model(video_tensor)
    probabilities = torch.softmax(outputs, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()

class_names = ["0%", "60%", "70%", "80%", "85%", "90%"]
print(f"Predicted normality rate: {class_names[predicted_class]}")
print(f"Confidence: {probabilities[0, predicted_class].item():.4f}")
```
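Because adjacent classes are closely spaced (for example 80% and 85%), it can be useful to inspect more than the single argmax prediction. A minimal sketch, assuming `probabilities` has the (1, 6) shape produced by the inference code above; `rank_predictions` is a hypothetical helper, not part of the released model, and the example input is synthetic.

```python
import torch

CLASS_NAMES = ["0%", "60%", "70%", "80%", "85%", "90%"]


def rank_predictions(probabilities: torch.Tensor, k: int = 3):
    """Return the k most likely (class name, probability) pairs for one clip."""
    top_p, top_idx = torch.topk(probabilities[0], k)  # sorted, highest first
    return [(CLASS_NAMES[i], p) for i, p in zip(top_idx.tolist(), top_p.tolist())]


# Synthetic example in place of real model output
example = torch.softmax(torch.tensor([[0.1, 2.0, 0.3, 0.2, 1.5, 0.0]]), dim=1)
for name, p in rank_predictions(example, k=3):
    print(f"{name}: {p:.4f}")  # lines for 60%, 85%, 70%, in that order
```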
## Limitations
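- Trained on data from specific microscopy equipment and acquisition protocols
- Performance may vary under different imaging conditions
- Intended as diagnostic support, not as the sole basis for clinical decisions
- Requires the same preprocessing used in training (90 frames at 224x224, ImageNet normalization followed by per-video standardization)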
## Citation
If you use this model, please cite:
```bibtex
@misc{sperm-normality-classifier,
author = {Raid Athmane Benlala},
title = {Sperm Normality Rate Classifier},
year = {2025},
publisher = {Hugging Face},
howpublished = {\url{https://huggingface.co/raidAthmaneBenlala/normality-rate-classifier}}
}
```