---
license: mit
tags:
- video-classification
- medical
- microscopy
- sperm-analysis
- pytorch
library_name: pytorch
---

# Sperm Normality Rate Classifier

## Model Description

This model classifies human sperm microscopy videos into normality rate categories using a 3D ResNet18 architecture.

- **Architecture**: 3D ResNet18
- **Task**: Video classification
- **Classes**: 6 normality rate categories (0%, 60%, 70%, 80%, 85%, 90%)
- **Input**: 90 frames (3 seconds at 30 fps), 224x224 RGB
- **Framework**: PyTorch

## Intended Use

This model is designed for analyzing human sperm microscopy videos to predict normality rates. It is intended for research and diagnostic support in reproductive medicine.

## Model Details

- **Input Format**: Video clips of 90 frames (3 seconds at 30 fps)
- **Preprocessing**:
  - Frames resized to 224x224
  - ImageNet normalization combined with per-video standardization for robustness
- **Output**: Probability distribution over the 6 normality rate classes

## Training Details

- **Dataset**: Balanced dataset with 100 clips per class
- **Normalization**: Combined ImageNet + per-video standardization
- **Loss Function**: Focal loss with class weights
- **Optimizer**: Adam (lr = 1e-4)
- **Early Stopping**: Patience of 10 epochs

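The exact focal-loss variant and its hyperparameters are not published here; a common formulation with optional per-class weights (the `alpha` values and `gamma=2.0` below are hypothetical placeholders, not the trained values) looks like:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Focal loss: down-weights well-classified examples via (1 - p_t)^gamma,
    with optional per-class weights alpha."""
    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.alpha = alpha  # tensor of shape (num_classes,), or None
        self.gamma = gamma

    def forward(self, logits, targets):
        ce = F.cross_entropy(logits, targets, reduction="none")
        pt = torch.exp(-ce)  # predicted probability of the true class
        focal = (1.0 - pt) ** self.gamma * ce
        if self.alpha is not None:
            focal = focal * self.alpha[targets]
        return focal.mean()

# Placeholder class weights for 6 classes (illustrative only)
criterion = FocalLoss(alpha=torch.tensor([2.0, 1.0, 1.0, 1.0, 1.0, 1.5]), gamma=2.0)
loss = criterion(torch.randn(4, 6), torch.tensor([0, 2, 5, 1]))
```

With `gamma = 0` and `alpha = None` this reduces to ordinary cross-entropy; increasing `gamma` progressively focuses training on hard, misclassified clips.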
46
+
47
+ ```python
48
+ import torch
49
+ import torch.nn as nn
50
+ import torchvision.models.video as video_models
51
+ import cv2
52
+ import numpy as np
53
+
54
+ # Define model architecture
55
+ class VideoClassifier(nn.Module):
56
+ def __init__(self, num_classes=6):
57
+ super(VideoClassifier, self).__init__()
58
+ self.backbone = video_models.r3d_18(pretrained=False)
59
+ in_features = self.backbone.fc.in_features
60
+ self.backbone.fc = nn.Linear(in_features, num_classes)
61
+
62
+ def forward(self, x):
63
+ return self.backbone(x)
64
+
65
+ # Load model
66
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
67
+ model = VideoClassifier(num_classes=6)
68
+ model.load_state_dict(torch.load('best_model.pth', map_location=device))
69
+ model.to(device)
70
+ model.eval()
71
+
72
+ # Preprocess video
73
+ def preprocess_video(video_path, num_frames=90, target_size=(224, 224)):
74
+ cap = cv2.VideoCapture(video_path)
75
+ frames = []
76
+
77
+ while True:
78
+ ret, frame = cap.read()
79
+ if not ret:
80
+ break
81
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
82
+ frame = cv2.resize(frame, target_size)
83
+ frame = frame.astype(np.float32) / 255.0
84
+ frames.append(frame)
85
+
86
+ cap.release()
87
+
88
+ # Sample to 90 frames
89
+ if len(frames) < num_frames:
90
+ repeat_factor = int(np.ceil(num_frames / len(frames)))
91
+ frames = (frames * repeat_factor)[:num_frames]
92
+ elif len(frames) > num_frames:
93
+ indices = np.linspace(0, len(frames) - 1, num_frames).astype(int)
94
+ frames = [frames[i] for i in indices]
95
+
96
+ # Convert to tensor (C, T, H, W)
97
+ frames = torch.FloatTensor(np.array(frames)).permute(3, 0, 1, 2)
98
+
99
+ # Apply normalization
100
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
101
+ std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1, 1)
102
+ frames = (frames - mean) / std
103
+
104
+ # Per-video standardization
105
+ video_mean = frames.mean()
106
+ video_std = frames.std()
107
+ if video_std > 0:
108
+ frames = (frames - video_mean) / video_std
109
+
110
+ return frames.unsqueeze(0) # Add batch dimension
111
+
112
+ # Inference
113
+ video_tensor = preprocess_video("path/to/video.mp4")
114
+ video_tensor = video_tensor.to(device)
115
+
116
+ with torch.no_grad():
117
+ outputs = model(video_tensor)
118
+ probabilities = torch.softmax(outputs, dim=1)
119
+ predicted_class = torch.argmax(probabilities, dim=1).item()
120
+
121
+ class_names = ["0%", "60%", "70%", "80%", "85%", "90%"]
122
+ print(f"Predicted normality rate: {class_names[predicted_class]}")
123
+ print(f"Confidence: {probabilities[0][predicted_class].item():.4f}")
124
+ ```
## Limitations

- Trained on specific microscopy equipment and protocols
- Performance may vary under different imaging conditions
- Should be used as diagnostic support, not as a sole decision-making tool
- Requires the video preprocessing pipeline described above

## Citation

If you use this model, please cite:

```
@misc{sperm-normality-classifier,
  author = {Raid Athmane Benlala},
  title = {Sperm Normality Rate Classifier},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/raidAthmaneBenlala/normality-rate-classifier}}
}
```