toderian commited on
Commit
17daa0b
·
verified ·
1 Parent(s): eb359e3

Upload folder using huggingface_hub

Browse files
Files changed (8) hide show
  1. README.md +297 -210
  2. config.json +31 -25
  3. example_inference.py +189 -0
  4. model.py +200 -0
  5. model.safetensors +3 -0
  6. modeling_cervical.py +166 -0
  7. preprocessor_config.json +14 -0
  8. pytorch_model.bin +3 -0
README.md CHANGED
@@ -1,270 +1,357 @@
1
- # Cervical Type Classification - Model Training
2
-
3
- ## Overview
4
-
5
- This project classifies cervical images into 3 transformation zone types:
6
- - **Type 1**: Fully visible squamocolumnar junction (SCJ)
7
- - **Type 2**: Partially visible SCJ
8
- - **Type 3**: SCJ not visible (inside cervical canal)
9
-
10
- ## Best Model Summary
11
-
12
- | Metric | Value |
13
- |--------|-------|
14
- | **Validation Accuracy** | **65.52%** |
15
- | **Macro F1** | **65.61%** |
16
- | Best Epoch | 34 |
17
- | Total Parameters | 1,327,235 |
18
-
19
  ---
20
 
21
- ## Best Model Configuration
22
-
23
- **Run Name:** `L32_64_128_256_Res_SE_lr5e-04_d0.3`
24
 
25
- ### Architecture
26
 
27
- | Component | Value |
28
- |-----------|-------|
29
- | Conv Layers | [32, 64, 128, 256] |
30
- | FC Layers | [256, 128] |
31
- | Kernel Size | 3x3 |
32
- | Pooling | MaxPool 2x2 |
33
- | Batch Normalization | Yes |
34
- | Activation | ReLU |
35
- | Residual Connections | **Yes** |
36
- | SE Attention | **Yes** |
37
 
38
- ### Training Settings
39
 
40
- | Parameter | Value |
41
- |-----------|-------|
42
- | Learning Rate | 5e-4 |
43
- | Weight Decay | 1e-4 |
44
- | Dropout | 0.3 |
45
- | Batch Size | 32 |
46
- | Focal Loss Gamma | 2.0 |
47
- | Label Smoothing | 0.1 |
48
- | Data Augmentation | Yes |
49
 
50
  ---
51
 
52
- ## Performance Metrics
53
-
54
- ### Per-Class Metrics
55
 
56
- | Class | Precision | Recall | F1-Score | Support |
57
- |-------|-----------|--------|----------|---------|
58
- | **Type 1** | 79.26% | 61.49% | 69.26% | 348 |
59
- | **Type 2** | 58.09% | 75.29% | 65.58% | 348 |
60
- | **Type 3** | 64.40% | 59.77% | 62.00% | 348 |
61
- | **Macro Avg** | 67.25% | 65.52% | **65.61%** | 1044 |
62
-
63
- ### Confusion Matrix
64
 
65
  ```
66
- Predicted
67
- Type 1 Type 2 Type 3
68
- Actual Type 1 214 84 50
69
- Type 2 21 262 65
70
- Type 3 35 105 208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  ```
72
 
73
- ### Interpretation
74
-
75
- | Finding | Implication |
76
- |---------|-------------|
77
- | Type 1 has highest precision (79%) | When model predicts Type 1, it's usually correct |
78
- | Type 2 has highest recall (75%) | Model catches most Type 2 cases |
79
- | Type 3 has lowest metrics | Hardest to classify - often confused with Type 2 |
80
- | Type 2 Type 3 confusion is common | 105 Type 3 misclassified as Type 2 |
81
-
82
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- ## Grid Search Results
85
-
86
- A grid search of 32 configurations was performed on January 17, 2026.
87
-
88
- ### Search Space
89
-
90
- | Parameter | Values Tested |
91
- |-----------|---------------|
92
- | Conv Layers | [32,64,128,256], [64,128,256] |
93
- | Learning Rate | 5e-4, 1e-4 |
94
- | Dropout | 0.3, 0.4 |
95
- | Residual | Yes, No |
96
- | SE Attention | Yes, No |
97
-
98
- ### Top 10 Configurations
99
-
100
- | Rank | Configuration | Accuracy | Key Features |
101
- |------|--------------|----------|--------------|
102
- | 1 | L32_64_128_256_Res_SE_lr5e-04_d0.3 | **65.52%** | 4-layer, Res+SE |
103
- | 2 | L64_128_256_Res_SE_lr5e-04_d0.3 | 65.04% | 3-layer, Res+SE |
104
- | 3 | L32_64_128_256_Res_SE_lr1e-04_d0.3 | 64.94% | 4-layer, lower LR |
105
- | 4 | L64_128_256_Res_SE_lr1e-04_d0.3 | 64.37% | 3-layer, lower LR |
106
- | 5 | L32_64_128_256_Res_SE_lr5e-04_d0.4 | 64.18% | Higher dropout |
107
- | 6 | L32_64_128_256_Res_lr5e-04_d0.4 | 64.08% | No SE |
108
- | 7 | L32_64_128_256_Res_lr1e-04_d0.3 | 63.60% | No SE, lower LR |
109
- | 8 | L32_64_128_256_Res_SE_lr1e-04_d0.4 | 63.51% | Lower LR, higher dropout |
110
- | 9 | L64_128_256_Res_SE_lr5e-04_d0.4 | 63.22% | 3-layer, higher dropout |
111
- | 10 | L64_128_256_Res_SE_lr1e-04_d0.4 | 63.12% | 3-layer, lower LR |
112
-
113
- ### Key Findings
114
-
115
- | Finding | Evidence |
116
- |---------|----------|
117
- | **Residual + SE is critical** | Top 10 models all use residual connections; top 4 use both Res+SE |
118
- | **4-layer network is better** | [32,64,128,256] outperforms [64,128,256] |
119
- | **Higher LR (5e-4) preferred** | 5e-4 consistently beats 1e-4 |
120
- | **Lower dropout (0.3) preferred** | 0.3 dropout outperforms 0.4 |
121
- | **Plain CNN performs worst** | Models without Res or SE are at the bottom |
122
-
123
- ### What Worked vs What Didn't
124
-
125
- | Worked | Didn't Work |
126
- |--------|-------------|
127
- | Residual connections | Plain convolutions |
128
- | SE attention blocks | No attention |
129
- | 4 conv layers | 3 conv layers |
130
- | LR = 5e-4 | LR = 1e-4 (too slow) |
131
- | Dropout = 0.3 | Dropout = 0.4 (too aggressive) |
132
- | Focal Loss | - |
133
- | Label smoothing 0.1 | - |
 
 
 
 
 
 
 
134
 
135
  ---
136
 
137
- ## Data
138
 
139
- | Split | Samples | Classes | Distribution |
140
- |-------|---------|---------|--------------|
141
- | Train | ~7,000 | 3 | Balanced after augmentation |
142
- | Test | 1,044 | 3 | [348, 348, 348] |
143
 
144
- ### Image Specifications
145
-
146
- - Size: Variable (resized during training)
147
- - Channels: 3 (RGB)
148
- - Source: Colposcopy images
149
-
150
- ---
151
 
152
- ## Model Files
153
 
154
- ### Best Model Location
 
 
 
 
 
155
 
156
- ```
157
- ./best_model.pth (this folder)
158
- ```
159
 
160
- Original training output:
161
  ```
162
- /data/downloads/cervical_type/_output/grid_search_v2_20260117_212011/run_001_L32_64_128_256_Res_SE_lr5e-04_d0.3/
 
 
 
 
 
163
  ```
164
 
165
- ### Checkpoint Contents
166
 
167
- ```python
168
- {
169
- "epoch": 34,
170
- "model_state_dict": ...,
171
- "optimizer_state_dict": ...,
172
- "scheduler_state_dict": ...,
173
- "metrics": {...},
174
- "model_config": {...}
175
- }
176
- ```
177
 
178
- ### Files in This Folder
179
 
180
- | File | Description |
181
- |------|-------------|
182
- | `best_model.pth` | Model checkpoint (weights + optimizer state) |
183
- | `config.json` | Training configuration used |
184
- | `training_history.json` | Loss/accuracy per epoch |
185
- | `grid_search_summary.json` | All 32 grid search results |
186
- | `README.md` | This file |
187
 
188
  ### Loading the Model
189
 
190
  ```python
191
  import torch
 
 
 
 
 
 
 
192
 
193
- # Load checkpoint (from this folder)
194
- checkpoint = torch.load('best_model.pth', weights_only=False)
 
195
 
196
- # Create model with same config
197
- model = BaseCNN(
198
- conv_layers=[32, 64, 128, 256],
199
- fc_layers=[256, 128],
200
- num_classes=3,
201
- dropout=0.3,
202
- use_residual=True,
203
- use_se_attention=True
204
- )
205
 
206
  # Load weights
207
- model.load_state_dict(checkpoint['model_state_dict'])
 
208
  model.eval()
209
  ```
210
 
211
- ---
212
-
213
- ## Output Structure
214
 
215
- ```
216
- _output/
217
- └── grid_search_v2_20260117_212011/
218
- ├── grid_search_config.json # Search space definition
219
- ├── all_results.json # All 32 run results
220
- ├── summary.json # Sorted results + best run
221
- ├── logs/
222
- │ └── grid_search.log
223
- └── run_001_.../ # Best run
224
- ├── checkpoints/
225
- │ ├── best_model.pth # Best validation accuracy
226
- │ ├── latest.pth # Final epoch
227
- │ └── epoch_*.pth # Periodic saves
228
- └── logs/
229
- ├── run_config.json
230
- └── training_history.json
 
 
 
 
 
 
 
 
231
  ```
232
 
233
  ---
234
 
235
- ## Comparison with v1 Baseline
 
 
 
 
 
 
 
 
 
236
 
237
- | Version | Accuracy | Improvement |
238
- |---------|----------|-------------|
239
- | v1 Baseline | 61.69% | - |
240
- | **v2 Best (Res+SE)** | **65.52%** | **+3.83%** |
241
 
242
- The addition of residual connections and SE attention improved accuracy by nearly 4%.
 
 
 
243
 
244
  ---
245
 
246
- ## Recommendations for Future Work
247
 
248
- 1. **Try deeper networks** - Add 5th conv layer [32, 64, 128, 256, 512]
249
- 2. **Transfer learning** - Use pretrained EfficientNet or ResNet backbone
250
- 3. **Address Type 3 confusion** - Type 3 is often misclassified as Type 2
251
- 4. **Ensemble methods** - Combine top 3-5 models
252
- 5. **Test Time Augmentation** - Average predictions over augmented versions
253
- 6. **More training data** - Current ~7k samples may be limiting
254
 
255
- ---
256
 
257
- ## Quick Start
258
 
259
- ```bash
260
- # Run the best configuration
261
- python train_grid_v2.py
262
 
263
- # Or load and evaluate the best model
264
- python evaluate.py --model /path/to/best_model.pth
265
- ```
 
 
 
 
 
 
 
266
 
267
  ---
268
 
269
- *Last updated: January 2026*
270
- *Grid search: 32 configurations, ~15 hours on single GPU*
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ tags:
6
+ - image-classification
7
+ - medical
8
+ - cervical-cancer
9
+ - pytorch
10
+ - cnn
11
+ - colposcopy
12
+ datasets:
13
+ - custom
14
+ metrics:
15
+ - accuracy
16
+ - f1
17
+ pipeline_tag: image-classification
18
+ library_name: pytorch
19
  ---
20
 
21
+ # Cervical Cancer Classification CNN
 
 
22
 
23
+ A CNN model for classifying cervical colposcopy images into 4 severity classes for cervical cancer screening.
24
 
25
+ ## Model Description
 
 
 
 
 
 
 
 
 
26
 
27
+ This model classifies cervical images into:
28
 
29
+ | Class | Label | Description | Clinical Action |
30
+ |-------|-------|-------------|-----------------|
31
+ | 0 | Normal | Healthy cervical tissue | Routine screening in 3-5 years |
32
+ | 1 | LSIL | Low-grade Squamous Intraepithelial Lesion | Monitor, repeat test in 6-12 months |
33
+ | 2 | HSIL | High-grade Squamous Intraepithelial Lesion | Colposcopy, biopsy, treatment required |
34
+ | 3 | Cancer | Invasive cervical cancer | Immediate oncology referral |
 
 
 
35
 
36
  ---
37
 
38
+ ## Model Architecture
 
 
39
 
40
+ ### Architecture Diagram
 
 
 
 
 
 
 
41
 
42
  ```
43
+ ┌─────────────────────────────────────────────────────────────┐
44
+ │ INPUT IMAGE │
45
+ │ (3 × 224 × 298) │
46
+ └─────────────────────────┬───────────────────────────────────┘
47
+
48
+ ┌─────────────────────────▼───────────────────────────────────┐
49
+ │ CONV BLOCK 1 │
50
+ │ ├── Conv2d(3 → 32, kernel=3×3, padding=1) │
51
+ │ ├── BatchNorm2d(32) │
52
+ │ ├── ReLU │
53
+ │ └── MaxPool2d(2×2) │
54
+ │ Output: 32 × 112 × 149 │
55
+ └─────────────────────────┬───────────────────────────────────┘
56
+
57
+ ┌─────────────────────────▼───────────────────────────────────┐
58
+ │ CONV BLOCK 2 │
59
+ │ ├── Conv2d(32 → 64, kernel=3×3, padding=1) │
60
+ │ ├── BatchNorm2d(64) │
61
+ │ ├── ReLU │
62
+ │ └── MaxPool2d(2×2) │
63
+ │ Output: 64 × 56 × 74 │
64
+ └─────────────────────────┬───────────────────────────────────┘
65
+
66
+ ┌─────────────────────────▼───────────────────────────────────┐
67
+ │ CONV BLOCK 3 │
68
+ │ ├── Conv2d(64 → 128, kernel=3×3, padding=1) │
69
+ │ ├── BatchNorm2d(128) │
70
+ │ ├── ReLU │
71
+ │ └── MaxPool2d(2×2) │
72
+ │ Output: 128 × 28 × 37 │
73
+ └─────────────────────────┬───────────────────────────────────┘
74
+
75
+ ┌─────────────────────────▼───────────────────────────────────┐
76
+ │ CONV BLOCK 4 │
77
+ │ ├── Conv2d(128 → 256, kernel=3×3, padding=1) │
78
+ │ ├── BatchNorm2d(256) │
79
+ │ ├── ReLU │
80
+ │ └── MaxPool2d(2×2) │
81
+ │ Output: 256 × 14 × 18 │
82
+ └─────────────────────────┬───────────────────────────────────┘
83
+
84
+ ┌─────────────────────────▼───────────────────────────────────┐
85
+ │ GLOBAL AVERAGE POOLING │
86
+ │ └── AdaptiveAvgPool2d(1×1) │
87
+ │ Output: 256 × 1 × 1 → Flatten → 256 │
88
+ └─────────────────────────┬───────────────────────────────────┘
89
+
90
+ ┌─────────────────────────▼───────────────────────────────────┐
91
+ │ FC BLOCK 1 │
92
+ │ ├── Linear(256 → 256) │
93
+ │ ├── ReLU │
94
+ │ └── Dropout(0.5) │
95
+ └─────────────────────────┬───────────────────────────────────┘
96
+
97
+ ┌─────────────────────────▼───────────────────────────────────┐
98
+ │ FC BLOCK 2 │
99
+ │ ├── Linear(256 → 128) │
100
+ │ ├── ReLU │
101
+ │ └── Dropout(0.5) │
102
+ └─────────────────────────┬───────────────────────────────────┘
103
+
104
+ ┌─────────────────────────▼───────────────────────────────────┐
105
+ │ CLASSIFIER │
106
+ │ └── Linear(128 → 4) │
107
+ │ Output: 4 class logits │
108
+ └─────────────────────────┬───────────────────────────────────┘
109
+
110
+
111
+ [Normal, LSIL, HSIL, Cancer]
112
  ```
113
 
114
+ ### Architecture Summary Table
115
+
116
+ | Layer | Type | Input Shape | Output Shape | Parameters |
117
+ |-------|------|-------------|--------------|------------|
118
+ | conv_layers.0 | Conv2d | (3, 224, 298) | (32, 224, 298) | 896 |
119
+ | conv_layers.1 | BatchNorm2d | (32, 224, 298) | (32, 224, 298) | 64 |
120
+ | conv_layers.2 | ReLU | - | - | 0 |
121
+ | conv_layers.3 | MaxPool2d | (32, 224, 298) | (32, 112, 149) | 0 |
122
+ | conv_layers.4 | Conv2d | (32, 112, 149) | (64, 112, 149) | 18,496 |
123
+ | conv_layers.5 | BatchNorm2d | (64, 112, 149) | (64, 112, 149) | 128 |
124
+ | conv_layers.6 | ReLU | - | - | 0 |
125
+ | conv_layers.7 | MaxPool2d | (64, 112, 149) | (64, 56, 74) | 0 |
126
+ | conv_layers.8 | Conv2d | (64, 56, 74) | (128, 56, 74) | 73,856 |
127
+ | conv_layers.9 | BatchNorm2d | (128, 56, 74) | (128, 56, 74) | 256 |
128
+ | conv_layers.10 | ReLU | - | - | 0 |
129
+ | conv_layers.11 | MaxPool2d | (128, 56, 74) | (128, 28, 37) | 0 |
130
+ | conv_layers.12 | Conv2d | (128, 28, 37) | (256, 28, 37) | 295,168 |
131
+ | conv_layers.13 | BatchNorm2d | (256, 28, 37) | (256, 28, 37) | 512 |
132
+ | conv_layers.14 | ReLU | - | - | 0 |
133
+ | conv_layers.15 | MaxPool2d | (256, 28, 37) | (256, 14, 18) | 0 |
134
+ | avgpool | AdaptiveAvgPool2d | (256, 14, 18) | (256, 1, 1) | 0 |
135
+ | fc_layers.0 | Linear | 256 | 256 | 65,792 |
136
+ | fc_layers.1 | ReLU | - | - | 0 |
137
+ | fc_layers.2 | Dropout | - | - | 0 |
138
+ | fc_layers.3 | Linear | 256 | 128 | 32,896 |
139
+ | fc_layers.4 | ReLU | - | - | 0 |
140
+ | fc_layers.5 | Dropout | - | - | 0 |
141
+ | classifier | Linear | 128 | 4 | 516 |
142
+ | **Total** | | | | **488,580** |
143
+
144
+ ### PyTorch Model Code
145
 
146
+ ```python
147
+ import torch
148
+ import torch.nn as nn
149
+
150
+ class CervicalCancerCNN(nn.Module):
151
+ def __init__(self):
152
+ super().__init__()
153
+
154
+ # Convolutional layers: [32, 64, 128, 256]
155
+ self.conv_layers = nn.Sequential(
156
+ # Block 1: 3 -> 32
157
+ nn.Conv2d(3, 32, kernel_size=3, padding=1),
158
+ nn.BatchNorm2d(32),
159
+ nn.ReLU(inplace=True),
160
+ nn.MaxPool2d(2, 2),
161
+
162
+ # Block 2: 32 -> 64
163
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
164
+ nn.BatchNorm2d(64),
165
+ nn.ReLU(inplace=True),
166
+ nn.MaxPool2d(2, 2),
167
+
168
+ # Block 3: 64 -> 128
169
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
170
+ nn.BatchNorm2d(128),
171
+ nn.ReLU(inplace=True),
172
+ nn.MaxPool2d(2, 2),
173
+
174
+ # Block 4: 128 -> 256
175
+ nn.Conv2d(128, 256, kernel_size=3, padding=1),
176
+ nn.BatchNorm2d(256),
177
+ nn.ReLU(inplace=True),
178
+ nn.MaxPool2d(2, 2),
179
+ )
180
+
181
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
182
+
183
+ # Fully connected layers: [256, 128] -> 4
184
+ self.fc_layers = nn.Sequential(
185
+ nn.Linear(256, 256),
186
+ nn.ReLU(inplace=True),
187
+ nn.Dropout(0.5),
188
+ nn.Linear(256, 128),
189
+ nn.ReLU(inplace=True),
190
+ nn.Dropout(0.5),
191
+ )
192
+
193
+ self.classifier = nn.Linear(128, 4)
194
+
195
+ def forward(self, x):
196
+ x = self.conv_layers(x)
197
+ x = self.avgpool(x)
198
+ x = x.view(x.size(0), -1)
199
+ x = self.fc_layers(x)
200
+ x = self.classifier(x)
201
+ return x
202
+ ```
203
 
204
  ---
205
 
206
+ ## Performance
207
 
208
+ ### Overall Metrics
 
 
 
209
 
210
+ | Metric | Value |
211
+ |--------|-------|
212
+ | **Accuracy** | 59.52% |
213
+ | **Macro F1** | 59.85% |
214
+ | **Parameters** | 488,580 |
 
 
215
 
216
+ ### Per-Class Metrics
217
 
218
+ | Class | Precision | Recall | F1 Score | Support |
219
+ |-------|-----------|--------|----------|---------|
220
+ | Normal | 0.595 | 0.595 | 0.595 | 84 |
221
+ | LSIL | 0.521 | 0.583 | 0.551 | 84 |
222
+ | HSIL | 0.446 | 0.440 | 0.443 | 84 |
223
+ | Cancer | 0.853 | 0.762 | 0.805 | 84 |
224
 
225
+ ### Confusion Matrix
 
 
226
 
 
227
  ```
228
+ Predicted → Normal LSIL HSIL Cancer
229
+ Actual ↓
230
+ Normal 50 9 17 8
231
+ LSIL 24 49 11 0
232
+ HSIL 9 35 37 3
233
+ Cancer 1 1 18 64
234
  ```
235
 
236
+ ---
237
 
238
+ ## Usage
 
 
 
 
 
 
 
 
 
239
 
240
+ ### Installation
241
 
242
+ ```bash
243
+ pip install torch torchvision safetensors huggingface_hub
244
+ ```
 
 
 
 
245
 
246
  ### Loading the Model
247
 
248
  ```python
249
  import torch
250
+ from safetensors.torch import load_file
251
+ from huggingface_hub import hf_hub_download
252
+ import json
253
+
254
+ # Download model files
255
+ model_file = hf_hub_download("toderian/cerviguard_lesion", "model.safetensors")
256
+ config_file = hf_hub_download("toderian/cerviguard_lesion", "config.json")
257
 
258
+ # Load config
259
+ with open(config_file) as f:
260
+ config = json.load(f)
261
 
262
+ # Define model (copy from above or download modeling_cervical.py)
263
+ model = CervicalCancerCNN()
 
 
 
 
 
 
 
264
 
265
  # Load weights
266
+ state_dict = load_file(model_file)
267
+ model.load_state_dict(state_dict)
268
  model.eval()
269
  ```
270
 
271
+ ### Inference
 
 
272
 
273
+ ```python
274
+ from PIL import Image
275
+ import torchvision.transforms as T
276
+
277
+ # Preprocessing
278
+ transform = T.Compose([
279
+ T.Resize((224, 298)),
280
+ T.ToTensor(),
281
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
282
+ ])
283
+
284
+ # Load and preprocess image
285
+ image = Image.open("cervical_image.jpg").convert("RGB")
286
+ input_tensor = transform(image).unsqueeze(0)
287
+
288
+ # Inference
289
+ with torch.no_grad():
290
+ output = model(input_tensor)
291
+ probabilities = torch.softmax(output, dim=1)
292
+ prediction = output.argmax(dim=1).item()
293
+
294
+ classes = ["Normal", "LSIL", "HSIL", "Cancer"]
295
+ print(f"Prediction: {classes[prediction]}")
296
+ print(f"Confidence: {probabilities[0][prediction]:.2%}")
297
  ```
298
 
299
  ---
300
 
301
+ ## Training Details
302
+
303
+ | Parameter | Value |
304
+ |-----------|-------|
305
+ | Learning Rate | 1e-4 |
306
+ | Batch Size | 32 |
307
+ | Optimizer | Adam |
308
+ | Loss | CrossEntropyLoss |
309
+ | Dropout | 0.5 |
310
+ | Epochs | 34 (early stopping at 24) |
311
 
312
+ ### Dataset
 
 
 
313
 
314
+ | Split | Samples | Distribution |
315
+ |-------|---------|--------------|
316
+ | Train | 3,003 | Imbalanced [1540, 469, 854, 140] |
317
+ | Test | 336 | Balanced [84, 84, 84, 84] |
318
 
319
  ---
320
 
321
+ ## Limitations
322
 
323
+ - Trained on limited dataset (~3k samples)
324
+ - HSIL class has lowest performance (F1=0.443)
325
+ - Should not be used as sole diagnostic tool
326
+ - Intended for research and screening assistance only
 
 
327
 
328
+ ## Medical Disclaimer
329
 
330
+ ⚠️ **This model is for research purposes only.** It should not be used as a substitute for professional medical diagnosis. Always consult qualified healthcare professionals for cervical cancer screening and diagnosis.
331
 
332
+ ---
 
 
333
 
334
+ ## Files in This Repository
335
+
336
+ | File | Description |
337
+ |------|-------------|
338
+ | `model.safetensors` | Model weights (safetensors format) |
339
+ | `pytorch_model.bin` | Model weights (legacy PyTorch format) |
340
+ | `config.json` | Model configuration |
341
+ | `preprocessor_config.json` | Image preprocessing settings |
342
+ | `modeling_cervical.py` | Model class definition |
343
+ | `example_inference.py` | Example inference script |
344
 
345
  ---
346
 
347
+ ## Citation
348
+
349
+ ```bibtex
350
+ @misc{cervical-cancer-cnn-2025,
351
+ author = {Toderian},
352
+ title = {Cervical Cancer Classification CNN},
353
+ year = {2025},
354
+ publisher = {Hugging Face},
355
+ url = {https://huggingface.co/toderian/cerviguard_lesion}
356
+ }
357
+ ```
config.json CHANGED
@@ -1,26 +1,32 @@
1
  {
2
- "batch_size": 32,
3
- "learning_rate": 0.0005,
4
- "weight_decay": 0.0001,
5
- "layers": [
6
- 32,
7
- 64,
8
- 128,
9
- 256
10
- ],
11
- "use_residual": true,
12
- "use_se_attention": true,
13
- "focal_gamma": 2.0,
14
- "label_smoothing": 0.1,
15
- "dropout": 0.3,
16
- "kernel": 3,
17
- "batchnorm": true,
18
- "activation": "ReLU",
19
- "pool": true,
20
- "fc_multipliers": [
21
- 1.0,
22
- 0.5
23
- ],
24
- "nr_classes": 3,
25
- "augmentation": true
26
- }
 
 
 
 
 
 
 
1
  {
2
+ "architectures": ["CervicalCancerCNN"],
3
+ "model_type": "cervical-cancer-cnn",
4
+ "auto_map": {
5
+ "AutoModel": "modeling_cervical.CervicalCancerCNN"
6
+ },
7
+ "num_labels": 4,
8
+ "num_classes": 4,
9
+ "id2label": {
10
+ "0": "Normal",
11
+ "1": "LSIL",
12
+ "2": "HSIL",
13
+ "3": "Cancer"
14
+ },
15
+ "label2id": {
16
+ "Normal": 0,
17
+ "LSIL": 1,
18
+ "HSIL": 2,
19
+ "Cancer": 3
20
+ },
21
+ "conv_layers": [32, 64, 128, 256],
22
+ "fc_layers": [256, 128],
23
+ "dropout": 0.5,
24
+ "input_channels": 3,
25
+ "input_size": {
26
+ "height": 224,
27
+ "width": 298
28
+ },
29
+ "total_parameters": 488580,
30
+ "problem_type": "single_label_classification",
31
+ "torch_dtype": "float32"
32
+ }
example_inference.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Example inference script for Cervical Cancer Classification model.
3
+
4
+ Usage:
5
+ # From local directory:
6
+ python example_inference.py --image path/to/image.jpg --model ./
7
+
8
+ # From Hugging Face Hub:
9
+ python example_inference.py --image path/to/image.jpg --model toderian/cerviguard_lesion
10
+ """
11
+
12
+ import argparse
13
+ import torch
14
+ import torch.nn as nn
15
+ from PIL import Image
16
+ import torchvision.transforms as T
17
+ from pathlib import Path
18
+ import json
19
+
20
+
21
+ class CervicalCancerCNN(nn.Module):
22
+ """CNN for cervical cancer classification."""
23
+
24
+ def __init__(self, config=None):
25
+ super().__init__()
26
+
27
+ config = config or {}
28
+ conv_channels = config.get("conv_layers", [32, 64, 128, 256])
29
+ fc_sizes = config.get("fc_layers", [256, 128])
30
+ dropout = config.get("dropout", 0.5)
31
+ num_classes = config.get("num_classes", 4)
32
+
33
+ # Convolutional layers
34
+ layers = []
35
+ in_channels = 3
36
+ for out_channels in conv_channels:
37
+ layers.extend([
38
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
39
+ nn.BatchNorm2d(out_channels),
40
+ nn.ReLU(inplace=True),
41
+ nn.MaxPool2d(kernel_size=2, stride=2),
42
+ ])
43
+ in_channels = out_channels
44
+
45
+ self.conv_layers = nn.Sequential(*layers)
46
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
47
+
48
+ # FC layers
49
+ fc_blocks = []
50
+ in_features = conv_channels[-1]
51
+ for fc_size in fc_sizes:
52
+ fc_blocks.extend([
53
+ nn.Linear(in_features, fc_size),
54
+ nn.ReLU(inplace=True),
55
+ nn.Dropout(dropout),
56
+ ])
57
+ in_features = fc_size
58
+
59
+ self.fc_layers = nn.Sequential(*fc_blocks)
60
+ self.classifier = nn.Linear(in_features, num_classes)
61
+
62
+ def forward(self, x):
63
+ x = self.conv_layers(x)
64
+ x = self.avgpool(x)
65
+ x = x.view(x.size(0), -1)
66
+ x = self.fc_layers(x)
67
+ x = self.classifier(x)
68
+ return x
69
+
70
+
71
+ def load_model_local(model_dir, device="cpu"):
72
+ """Load model from local directory."""
73
+ model_dir = Path(model_dir)
74
+
75
+ # Load config
76
+ config_path = model_dir / "config.json"
77
+ config = {}
78
+ if config_path.exists():
79
+ with open(config_path) as f:
80
+ config = json.load(f)
81
+
82
+ # Create model
83
+ model = CervicalCancerCNN(config)
84
+
85
+ # Load weights
86
+ if (model_dir / "model.safetensors").exists():
87
+ from safetensors.torch import load_file
88
+ state_dict = load_file(str(model_dir / "model.safetensors"))
89
+ model.load_state_dict(state_dict)
90
+ elif (model_dir / "pytorch_model.bin").exists():
91
+ state_dict = torch.load(model_dir / "pytorch_model.bin", map_location=device, weights_only=True)
92
+ model.load_state_dict(state_dict)
93
+ else:
94
+ raise FileNotFoundError(f"No model weights found in {model_dir}")
95
+
96
+ model.to(device)
97
+ model.eval()
98
+ return model, config
99
+
100
+
101
+ def load_model_hub(repo_id, device="cpu"):
102
+ """Load model from Hugging Face Hub."""
103
+ from huggingface_hub import hf_hub_download, snapshot_download
104
+
105
+ # Download model files
106
+ model_dir = snapshot_download(repo_id=repo_id)
107
+ return load_model_local(model_dir, device)
108
+
109
+
110
+ def load_model(model_path, device="cpu"):
111
+ """Load model from local path or Hugging Face Hub."""
112
+ model_path = Path(model_path)
113
+
114
+ if model_path.exists():
115
+ return load_model_local(model_path, device)
116
+ else:
117
+ # Assume it's a Hugging Face repo ID
118
+ return load_model_hub(str(model_path), device)
119
+
120
+
121
+ def get_preprocessor(config):
122
+ """Get image preprocessing transform."""
123
+ # Get size from config or use defaults
124
+ input_size = config.get("input_size", {"height": 224, "width": 298})
125
+ height = input_size.get("height", 224)
126
+ width = input_size.get("width", 298)
127
+
128
+ return T.Compose([
129
+ T.Resize((height, width)),
130
+ T.ToTensor(),
131
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
132
+ ])
133
+
134
+
135
+ def predict(model, image_tensor, config):
136
+ """Run inference and return prediction."""
137
+ # Get label mapping from config
138
+ id2label = config.get("id2label", {
139
+ "0": "Normal",
140
+ "1": "LSIL",
141
+ "2": "HSIL",
142
+ "3": "Cancer"
143
+ })
144
+
145
+ with torch.no_grad():
146
+ output = model(image_tensor)
147
+ probabilities = torch.softmax(output, dim=1)[0]
148
+ prediction = output.argmax(dim=1).item()
149
+
150
+ return {
151
+ "class_id": prediction,
152
+ "class_name": id2label.get(str(prediction), f"Class {prediction}"),
153
+ "probabilities": {
154
+ id2label.get(str(i), f"Class {i}"): f"{prob:.2%}"
155
+ for i, prob in enumerate(probabilities.tolist())
156
+ },
157
+ "confidence": f"{probabilities[prediction]:.2%}"
158
+ }
159
+
160
+
161
+ def main():
162
+ parser = argparse.ArgumentParser(description="Cervical Cancer Classification")
163
+ parser.add_argument("--image", required=True, help="Path to input image")
164
+ parser.add_argument("--model", default="./", help="Path to model dir or HF repo ID")
165
+ parser.add_argument("--device", default="cpu", help="Device (cpu/cuda)")
166
+ args = parser.parse_args()
167
+
168
+ print(f"Loading model from {args.model}...")
169
+ model, config = load_model(args.model, args.device)
170
+
171
+ print(f"Processing image: {args.image}")
172
+ transform = get_preprocessor(config)
173
+ image = Image.open(args.image).convert('RGB')
174
+ image_tensor = transform(image).unsqueeze(0).to(args.device)
175
+
176
+ result = predict(model, image_tensor, config)
177
+
178
+ print("\n" + "=" * 50)
179
+ print("PREDICTION RESULT")
180
+ print("=" * 50)
181
+ print(f"Class: {result['class_name']}")
182
+ print(f"Confidence: {result['confidence']}")
183
+ print("\nAll probabilities:")
184
+ for cls, prob in result['probabilities'].items():
185
+ print(f" {cls}: {prob}")
186
+
187
+
188
+ if __name__ == "__main__":
189
+ main()
model.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cervical Cancer Classification Model
3
+
4
+ This file provides the model architecture for easy import.
5
+
6
+ Usage:
7
+ from model import CervicalCancerCNN, load_model, predict
8
+
9
+ model = load_model("model.safetensors")
10
+ result = predict(model, image_tensor)
11
+ """
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ from pathlib import Path
16
+
17
+
18
+ class CervicalCancerCNN(nn.Module):
19
+ """
20
+ CNN for cervical cancer classification.
21
+
22
+ Classifies cervical colposcopy images into 4 severity classes:
23
+ - 0: Normal - Healthy cervical tissue
24
+ - 1: LSIL - Low-grade Squamous Intraepithelial Lesion
25
+ - 2: HSIL - High-grade Squamous Intraepithelial Lesion
26
+ - 3: Cancer - Invasive cervical cancer
27
+
28
+ Architecture:
29
+ Conv[32,64,128,256] -> AvgPool -> FC[256,128] -> Classifier[4]
30
+
31
+ Input:
32
+ Tensor of shape (batch, 3, 224, 298)
33
+
34
+ Output:
35
+ Logits of shape (batch, 4)
36
+ """
37
+
38
+ # Class labels
39
+ CLASSES = {
40
+ 0: "Normal",
41
+ 1: "LSIL",
42
+ 2: "HSIL",
43
+ 3: "Cancer"
44
+ }
45
+
46
+ def __init__(self, config=None):
47
+ super().__init__()
48
+
49
+ # Default configuration
50
+ config = config or {}
51
+ conv_channels = config.get("conv_layers", [32, 64, 128, 256])
52
+ fc_sizes = config.get("fc_layers", [256, 128])
53
+ dropout = config.get("dropout", 0.5)
54
+ num_classes = config.get("num_classes", 4)
55
+
56
+ # Build convolutional layers
57
+ layers = []
58
+ in_channels = 3
59
+ for out_channels in conv_channels:
60
+ layers.extend([
61
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
62
+ nn.BatchNorm2d(out_channels),
63
+ nn.ReLU(inplace=True),
64
+ nn.MaxPool2d(kernel_size=2, stride=2),
65
+ ])
66
+ in_channels = out_channels
67
+
68
+ self.conv_layers = nn.Sequential(*layers)
69
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
70
+
71
+ # Build fully connected layers
72
+ fc_blocks = []
73
+ in_features = conv_channels[-1]
74
+ for fc_size in fc_sizes:
75
+ fc_blocks.extend([
76
+ nn.Linear(in_features, fc_size),
77
+ nn.ReLU(inplace=True),
78
+ nn.Dropout(dropout),
79
+ ])
80
+ in_features = fc_size
81
+
82
+ self.fc_layers = nn.Sequential(*fc_blocks)
83
+ self.classifier = nn.Linear(in_features, num_classes)
84
+
85
+ def forward(self, x):
86
+ """Forward pass."""
87
+ x = self.conv_layers(x)
88
+ x = self.avgpool(x)
89
+ x = x.view(x.size(0), -1)
90
+ x = self.fc_layers(x)
91
+ x = self.classifier(x)
92
+ return x
93
+
94
+ def predict_class(self, x):
95
+ """Predict class labels and probabilities."""
96
+ self.eval()
97
+ with torch.no_grad():
98
+ logits = self.forward(x)
99
+ probs = torch.softmax(logits, dim=1)
100
+ preds = torch.argmax(logits, dim=1)
101
+ return preds, probs
102
+
103
+
104
+ def load_model(model_path, device="cpu"):
105
+ """
106
+ Load model from file.
107
+
108
+ Args:
109
+ model_path: Path to model weights (.safetensors or .bin/.pth)
110
+ device: Device to load model on ("cpu" or "cuda")
111
+
112
+ Returns:
113
+ Loaded model in eval mode
114
+ """
115
+ model = CervicalCancerCNN()
116
+
117
+ model_path = Path(model_path)
118
+
119
+ if model_path.suffix == ".safetensors":
120
+ from safetensors.torch import load_file
121
+ state_dict = load_file(str(model_path))
122
+ else:
123
+ checkpoint = torch.load(model_path, map_location=device, weights_only=False)
124
+ if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
125
+ state_dict = checkpoint["model_state_dict"]
126
+ else:
127
+ state_dict = checkpoint
128
+
129
+ model.load_state_dict(state_dict)
130
+ model.to(device)
131
+ model.eval()
132
+
133
+ return model
134
+
135
+
136
+ def predict(model, image_tensor, device="cpu"):
137
+ """
138
+ Run prediction on an image tensor.
139
+
140
+ Args:
141
+ model: Loaded CervicalCancerCNN model
142
+ image_tensor: Preprocessed image tensor (1, 3, 224, 298)
143
+ device: Device for inference
144
+
145
+ Returns:
146
+ Dictionary with prediction results
147
+ """
148
+ model.eval()
149
+ image_tensor = image_tensor.to(device)
150
+
151
+ with torch.no_grad():
152
+ logits = model(image_tensor)
153
+ probs = torch.softmax(logits, dim=1)[0]
154
+ pred_class = torch.argmax(logits, dim=1).item()
155
+
156
+ return {
157
+ "class_id": pred_class,
158
+ "class_name": CervicalCancerCNN.CLASSES[pred_class],
159
+ "confidence": probs[pred_class].item(),
160
+ "probabilities": {
161
+ CervicalCancerCNN.CLASSES[i]: probs[i].item()
162
+ for i in range(4)
163
+ }
164
+ }
165
+
166
+
167
+ def get_preprocessing_transform():
168
+ """
169
+ Get the preprocessing transform for input images.
170
+
171
+ Returns:
172
+ torchvision.transforms.Compose object
173
+ """
174
+ import torchvision.transforms as T
175
+
176
+ return T.Compose([
177
+ T.Resize((224, 298)),
178
+ T.ToTensor(),
179
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
180
+ ])
181
+
182
+
183
+ # Quick usage example
184
+ if __name__ == "__main__":
185
+ import sys
186
+
187
+ # Create model
188
+ model = CervicalCancerCNN()
189
+ print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")
190
+
191
+ # Print architecture
192
+ print("\nArchitecture:")
193
+ print(model)
194
+
195
+ # Test forward pass
196
+ dummy_input = torch.randn(1, 3, 224, 298)
197
+ output = model(dummy_input)
198
+ print(f"\nInput shape: {dummy_input.shape}")
199
+ print(f"Output shape: {output.shape}")
200
+ print(f"Output classes: {list(CervicalCancerCNN.CLASSES.values())}")
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d1f4af0e010e669105d0b7c8bb09d2781b3df73572281f2666fb1d054aeb0eeb
3
+ size 1961104
modeling_cervical.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cervical Cancer Classification Model
3
+
4
+ Custom CNN model for classifying cervical images into 4 severity classes.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class CervicalCancerCNN(nn.Module):
12
+ """
13
+ CNN for cervical cancer classification.
14
+
15
+ Classifies cervical images into 4 classes:
16
+ - 0: Normal
17
+ - 1: LSIL (Low-grade Squamous Intraepithelial Lesion)
18
+ - 2: HSIL (High-grade Squamous Intraepithelial Lesion)
19
+ - 3: Cancer
20
+
21
+ Args:
22
+ config: Optional configuration dict with keys:
23
+ - conv_layers: List of conv channel sizes (default: [32, 64, 128, 256])
24
+ - fc_layers: List of FC layer sizes (default: [256, 128])
25
+ - num_classes: Number of output classes (default: 4)
26
+ - dropout: Dropout rate (default: 0.5)
27
+ """
28
+
29
+ def __init__(self, config=None):
30
+ super().__init__()
31
+
32
+ # Default config
33
+ self.config = config or {
34
+ "conv_layers": [32, 64, 128, 256],
35
+ "fc_layers": [256, 128],
36
+ "num_classes": 4,
37
+ "dropout": 0.5,
38
+ "input_channels": 3,
39
+ }
40
+
41
+ conv_channels = self.config.get("conv_layers", [32, 64, 128, 256])
42
+ fc_sizes = self.config.get("fc_layers", [256, 128])
43
+ dropout = self.config.get("dropout", 0.5)
44
+ num_classes = self.config.get("num_classes", 4)
45
+ input_channels = self.config.get("input_channels", 3)
46
+
47
+ # Build convolutional layers
48
+ layers = []
49
+ in_channels = input_channels
50
+
51
+ for out_channels in conv_channels:
52
+ layers.extend([
53
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
54
+ nn.BatchNorm2d(out_channels),
55
+ nn.ReLU(inplace=True),
56
+ nn.MaxPool2d(kernel_size=2, stride=2),
57
+ ])
58
+ in_channels = out_channels
59
+
60
+ self.conv_layers = nn.Sequential(*layers)
61
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
62
+
63
+ # Build fully connected layers
64
+ fc_blocks = []
65
+ in_features = conv_channels[-1]
66
+
67
+ for fc_size in fc_sizes:
68
+ fc_blocks.extend([
69
+ nn.Linear(in_features, fc_size),
70
+ nn.ReLU(inplace=True),
71
+ nn.Dropout(dropout),
72
+ ])
73
+ in_features = fc_size
74
+
75
+ self.fc_layers = nn.Sequential(*fc_blocks)
76
+ self.classifier = nn.Linear(in_features, num_classes)
77
+
78
+ # Class labels
79
+ self.id2label = {
80
+ 0: "Normal",
81
+ 1: "LSIL",
82
+ 2: "HSIL",
83
+ 3: "Cancer"
84
+ }
85
+ self.label2id = {v: k for k, v in self.id2label.items()}
86
+
87
+ def forward(self, x):
88
+ """
89
+ Forward pass.
90
+
91
+ Args:
92
+ x: Input tensor of shape (batch, 3, height, width)
93
+
94
+ Returns:
95
+ Logits tensor of shape (batch, num_classes)
96
+ """
97
+ x = self.conv_layers(x)
98
+ x = self.avgpool(x)
99
+ x = x.view(x.size(0), -1)
100
+ x = self.fc_layers(x)
101
+ x = self.classifier(x)
102
+ return x
103
+
104
+ def predict(self, x):
105
+ """
106
+ Predict class labels.
107
+
108
+ Args:
109
+ x: Input tensor of shape (batch, 3, height, width)
110
+
111
+ Returns:
112
+ Tuple of (predicted_class_ids, probabilities)
113
+ """
114
+ self.eval()
115
+ with torch.no_grad():
116
+ logits = self.forward(x)
117
+ probs = torch.softmax(logits, dim=1)
118
+ preds = torch.argmax(logits, dim=1)
119
+ return preds, probs
120
+
121
+ @classmethod
122
+ def from_pretrained(cls, model_path, device="cpu"):
123
+ """
124
+ Load pretrained model.
125
+
126
+ Args:
127
+ model_path: Path to model directory or checkpoint file
128
+ device: Device to load model on
129
+
130
+ Returns:
131
+ Loaded model
132
+ """
133
+ import os
134
+ from pathlib import Path
135
+
136
+ model_path = Path(model_path)
137
+
138
+ # Try different file formats
139
+ if model_path.is_dir():
140
+ if (model_path / "model.safetensors").exists():
141
+ weights_path = model_path / "model.safetensors"
142
+ use_safetensors = True
143
+ elif (model_path / "pytorch_model.bin").exists():
144
+ weights_path = model_path / "pytorch_model.bin"
145
+ use_safetensors = False
146
+ else:
147
+ raise FileNotFoundError(f"No model weights found in {model_path}")
148
+ else:
149
+ weights_path = model_path
150
+ use_safetensors = str(model_path).endswith(".safetensors")
151
+
152
+ # Create model
153
+ model = cls()
154
+
155
+ # Load weights
156
+ if use_safetensors:
157
+ from safetensors.torch import load_file
158
+ state_dict = load_file(str(weights_path))
159
+ else:
160
+ state_dict = torch.load(weights_path, map_location=device, weights_only=True)
161
+
162
+ model.load_state_dict(state_dict)
163
+ model.to(device)
164
+ model.eval()
165
+
166
+ return model
preprocessor_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "do_rescale": true,
4
+ "do_resize": true,
5
+ "image_mean": [0.485, 0.456, 0.406],
6
+ "image_std": [0.229, 0.224, 0.225],
7
+ "image_processor_type": "ImageProcessor",
8
+ "resample": 3,
9
+ "rescale_factor": 0.00392156862745098,
10
+ "size": {
11
+ "height": 224,
12
+ "width": 298
13
+ }
14
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d4556ca826cfb058a2aa99352adfd5783a6c5f1186931943a56ddbb7ac83f7a
3
+ size 1969965