farrell236 committed on
Commit
b8c9192
·
1 Parent(s): 1691c67
Files changed (15) hide show
  1. .gitignore +148 -0
  2. LICENCE.md +21 -0
  3. README.md +41 -1
  4. RESULTS.md +87 -0
  5. app.py +206 -0
  6. augmentations.py +85 -0
  7. checkpoints/advamd.pt +3 -0
  8. checkpoints/drus.pt +3 -0
  9. checkpoints/pig.pt +3 -0
  10. dataloader.py +55 -0
  11. model.py +61 -0
  12. requirements.txt +10 -0
  13. run_inference.py +99 -0
  14. test.py +368 -0
  15. train.py +366 -0
.gitignore ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Created by .ignore support plugin (hsz.mobi)
2
+
3
+ deep-learning-models
4
+ .pytest_cache/
5
+ backup
6
+ examples-local
7
+
8
+ ### Python template
9
+ # Byte-compiled / optimized / DLL files
10
+ __pycache__/
11
+ *.py[cod]
12
+ *$py.class
13
+
14
+ # C extensions
15
+ *.so
16
+
17
+ # Distribution / packaging
18
+ .Python
19
+ env/
20
+ build/
21
+ develop-eggs/
22
+ dist/
23
+ downloads/
24
+ eggs/
25
+ .eggs/
26
+ lib/
27
+ lib64/
28
+ parts/
29
+ sdist/
30
+ var/
31
+ wheels/
32
+ *.egg-info/
33
+ .installed.cfg
34
+ *.egg
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ htmlcov/
48
+ .tox/
49
+ .coverage
50
+ .coverage.*
51
+ .cache
52
+ nosetests.xml
53
+ coverage.xml
54
+ *,cover
55
+ .hypothesis/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+
65
+ # Flask stuff:
66
+ instance/
67
+ .webassets-cache
68
+
69
+ # Scrapy stuff:
70
+ .scrapy
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # PyBuilder
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # pyenv
82
+ .python-version
83
+
84
+ # celery beat schedule file
85
+ celerybeat-schedule
86
+
87
+ # SageMath parsed files
88
+ *.sage.py
89
+
90
+ # dotenv
91
+ .env
92
+
93
+ # virtualenv
94
+ .venv
95
+ venv/
96
+ ENV/
97
+
98
+ # Spyder project settings
99
+ .spyderproject
100
+
101
+ # Rope project settings
102
+ .ropeproject
103
+ ### JetBrains template
104
+ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
105
+ # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
106
+
107
+ # User-specific stuff:
108
+ .idea
109
+ .idea/**/workspace.xml
110
+ .idea/**/tasks.xml
111
+ .idea/dictionaries
112
+
113
+ # Sensitive or high-churn files:
114
+ .idea/**/dataSources/
115
+ .idea/**/dataSources.ids
116
+ .idea/**/dataSources.xml
117
+ .idea/**/dataSources.local.xml
118
+ .idea/**/sqlDataSources.xml
119
+ .idea/**/dynamic.xml
120
+ .idea/**/uiDesigner.xml
121
+
122
+ # Gradle:
123
+ .idea/**/gradle.xml
124
+ .idea/**/libraries
125
+
126
+ # Mongo Explorer plugin:
127
+ .idea/**/mongoSettings.xml
128
+
129
+ ## File-based project format:
130
+ *.iws
131
+
132
+ ## Plugin-specific files:
133
+
134
+ # IntelliJ
135
+ /out/
136
+
137
+ # mpeltonen/sbt-idea plugin
138
+ .idea_modules/
139
+
140
+ # JIRA plugin
141
+ atlassian-ide-plugin.xml
142
+
143
+ # Crashlytics plugin (for Android Studio and IntelliJ)
144
+ com_crashlytics_export_strings.xml
145
+ crashlytics.properties
146
+ crashlytics-build.properties
147
+ fabric.properties
148
+
LICENCE.md ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 NIH/DIR
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -12,4 +12,44 @@ license: mit
12
  short_description: Framework for Classifying patient-based AMD in CFP images
13
  ---
14
 
15
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  short_description: Framework for Classifying patient-based AMD in CFP images
13
  ---
14
 
15
+ # DeepSeeNet PyTorch
16
+
17
+ This repository is a PyTorch reimplementation of the original DeepSeeNet model:
18
+
19
+ https://github.com/ncbi-nlp/DeepSeeNet
20
+
21
+ DeepSeeNet predicts patient-level AREDS Simplified Severity Scale scores for age-related macular degeneration (AMD) from bilateral color fundus photographs. The model follows the original DeepSeeNet design by first predicting eye-level AMD risk factors, then combining predictions from both eyes into a patient-level simplified severity score.
22
+
23
+ ## Tasks
24
+
25
+ The implementation trains three image-level subnetworks:
26
+
27
+ | Task | Classes | Output |
28
+ |---|---:|---|
29
+ | `ADVAMD` | 2 | late AMD absent / present |
30
+ | `DRUS` | 3 | small/none, medium, large drusen |
31
+ | `PIG` | 2 | pigmentary abnormality absent / present |
32
+
33
+ The final AREDS simplified score is computed from bilateral predictions:
34
+
35
+ - score `5` if late AMD is predicted in either eye
36
+ - otherwise, score is based on large drusen and pigmentary abnormalities across both eyes
37
+ - bilateral medium drusen contributes one point
38
+
39
+ ## Citation
40
+
41
+ If you use this repository, please cite the original DeepSeeNet paper:
42
+
43
+ ```bibtex
44
+ @article{peng2019deepseenet,
45
+ title={DeepSeeNet: A Deep Learning Model for Automated Classification of Patient-based Age-related Macular Degeneration Severity from Color Fundus Photographs},
46
+ author={Peng, Yifan and Dharssi, Shazia and Chen, Qingyu and Keenan, Tiarnan D. and Agr\'{o}n, Elvira and Wong, Wai T. and Chew, Emily Y. and Lu, Zhiyong},
47
+ journal={Ophthalmology},
48
+ volume={126},
49
+ number={4},
50
+ pages={565--575},
51
+ year={2019},
52
+ publisher={Elsevier},
53
+ doi={10.1016/j.ophtha.2018.11.015}
54
+ }
55
+ ```
RESULTS.md ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Checkpoint Results
2
+
3
+ ## ADVAMD
4
+
5
+ ```text
6
+ Task: ADVAMD | endpoint: late_amd | positive_class=1
7
+
8
+ Metrics
9
+ -------
10
+ overall_accuracy 0.9658 (0.9628-0.9689)
11
+ sensitivity 0.8417 (0.8255-0.8576)
12
+ specificity 0.9852 (0.9831-0.9874)
13
+ kappa 0.8498 (0.8367-0.8632)
14
+ auc 0.9811 (0.9777-0.9844)
15
+
16
+ Classifier metrics
17
+ ------------------
18
+ loss 0.1119
19
+ exact_accuracy 0.9658 (0.9628-0.9689)
20
+ exact_kappa 0.8498 (0.8367-0.8632)
21
+
22
+ Confusion matrix (rows=true, cols=pred):
23
+ [[11217 168]
24
+ [ 282 1499]]
25
+
26
+ Binary confusion matrix (rows=true, cols=pred):
27
+ [[11217 168]
28
+ [ 282 1499]]
29
+ ```
30
+
31
+ ## DRUS
32
+
33
+ ```text
34
+ Task: DRUS | endpoint: large_drusen | positive_class=2
35
+
36
+ Metrics
37
+ -------
38
+ overall_accuracy 0.8816 (0.8763-0.8869)
39
+ sensitivity 0.7708 (0.7588-0.7832)
40
+ specificity 0.9368 (0.9319-0.9418)
41
+ kappa 0.7263 (0.7144-0.7386)
42
+ auc 0.9489 (0.9452-0.9524)
43
+
44
+ Classifier metrics
45
+ ------------------
46
+ loss 0.5903
47
+ exact_accuracy 0.7471 (0.7400-0.7542)
48
+ exact_kappa 0.6170 (0.6066-0.6280)
49
+ macro_ovr_auc 0.8960 (0.8919-0.9001)
50
+
51
+ Confusion matrix (rows=true, cols=pred):
52
+ [[4205 820 115]
53
+ [ 951 2255 440]
54
+ [ 182 822 3376]]
55
+
56
+ Binary confusion matrix (rows=true, cols=pred):
57
+ [[8231 555]
58
+ [1004 3376]]
59
+ ```
60
+
61
+ ## PIG
62
+
63
+ ```text
64
+ Task: PIG | endpoint: pigmentary_abnormality | positive_class=1
65
+
66
+ Metrics
67
+ -------
68
+ overall_accuracy 0.8925 (0.8874-0.8976)
69
+ sensitivity 0.8606 (0.8502-0.8701)
70
+ specificity 0.9113 (0.9053-0.9171)
71
+ kappa 0.7702 (0.7594-0.7811)
72
+ auc 0.9498 (0.9460-0.9536)
73
+
74
+ Classifier metrics
75
+ ------------------
76
+ loss 0.2734
77
+ exact_accuracy 0.8925 (0.8874-0.8976)
78
+ exact_kappa 0.7702 (0.7594-0.7811)
79
+
80
+ Confusion matrix (rows=true, cols=pred):
81
+ [[7541 734]
82
+ [ 682 4209]]
83
+
84
+ Binary confusion matrix (rows=true, cols=pred):
85
+ [[7541 734]
86
+ [ 682 4209]]
87
+ ```
app.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from augmentations import get_val_transforms
10
+ from model import DeepSeeNet
11
+
12
+
13
+ N_CLASSES = {
14
+ "ADVAMD": 2,
15
+ "DRUS": 3,
16
+ "PIG": 2,
17
+ }
18
+
19
+ LABELS = {
20
+ "ADVAMD": ["no_late_amd", "late_amd"],
21
+ "DRUS": ["small_none", "medium", "large"],
22
+ "PIG": ["no_pigment", "pigment"],
23
+ }
24
+
25
+
26
class AlbumentationsTransform:
    """Adapt a keyword-style transform so it can be called with one image.

    The wrapped callable is invoked as ``transform(image=<ndarray>)`` and its
    result is expected to be a mapping that holds the output under ``"image"``.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, image):
        # The wrapped transform receives a NumPy view/copy of the input.
        pixels = np.asarray(image)
        outcome = self.transform(image=pixels)
        return outcome["image"]
32
+
33
+
34
def parse_args():
    """Parse the demo's command-line options from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``checkpoint_folder``, ``backbone``,
        ``image_size``, ``server_name``, ``server_port`` and ``share``.
    """
    cli = argparse.ArgumentParser()
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--checkpoint-folder", {"default": "./checkpoints"}),
        ("--backbone", {"default": "inception_v3"}),
        ("--image-size", {"type": int, "default": 1024}),
        ("--server-name", {"default": "127.0.0.1"}),
        ("--server-port", {"type": int, "default": 7860}),
        ("--share", {"action": "store_true"}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
43
+
44
+
45
def load_model(path, task, backbone, device):
    """Build a ``DeepSeeNet`` for ``task`` and restore its weights from ``path``.

    The checkpoint is read with ``torch.load`` and must store the state dict
    under ``"model"``. If it also carries an ``"args"`` mapping containing a
    ``"backbone"`` entry, that value overrides the ``backbone`` argument.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    saved = torch.load(path, map_location=device)
    saved_args = saved.get("args", {})

    n_outputs = N_CLASSES[task]
    # Prefer the backbone recorded in the checkpoint over the caller's value.
    backbone_name = saved_args.get("backbone", backbone)
    net = DeepSeeNet(n_classes=n_outputs, backbone=backbone_name, pretrained=False)
    net = net.to(device)

    net.load_state_dict(saved["model"])
    net.eval()
    return net
58
+
59
+
60
def load_image(image, transform, device):
    """Convert an uploaded image into a single-item batch on ``device``.

    The image is converted to RGB, passed through ``transform``, given a
    leading batch dimension and moved to ``device``.

    Raises:
        ValueError: If ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        raise ValueError("Please upload both left and right images.")

    rgb = image.convert("RGB")
    tensor = transform(rgb)
    batch = tensor.unsqueeze(0)
    return batch.to(device)
66
+
67
+
68
@torch.no_grad()
def predict(model, image, task):
    """Run ``model`` on ``image`` and describe the result for ``task``.

    Only the first item of the model output is used. Softmax is taken over
    that item's logits and the arg-max index is reported.

    Returns:
        A dict with ``prediction`` (class index), ``label`` (name from
        ``LABELS[task]``), ``confidence`` (probability of the predicted
        class) and ``probabilities`` (label name -> probability).
    """
    first_logits = model(image)[0].detach().cpu()
    class_probs = F.softmax(first_logits, dim=0)
    winner = int(torch.argmax(first_logits).item())
    names = LABELS[task]

    return {
        "prediction": winner,
        "label": names[winner],
        "confidence": float(class_probs[winner]),
        "probabilities": {
            name: float(class_probs[position])
            for position, name in enumerate(names)
        },
    }
83
+
84
+
85
def simplified_score(scores):
    """Combine per-eye predictions into one patient-level score.

    ``scores[task][eye]["prediction"]`` is read for the tasks ``ADVAMD``,
    ``PIG`` and ``DRUS`` and the eyes ``left`` and ``right``.

    Returns:
        5 when ``ADVAMD`` is predicted as class 1 in either eye. Otherwise
        one point for each eye with ``PIG`` == 1, one for each eye with
        ``DRUS`` == 2, and one more when ``DRUS`` == 1 in both eyes
        (capped at 5).
    """

    def predicted(task, eye):
        return scores[task][eye]["prediction"]

    if predicted("ADVAMD", "left") == 1 or predicted("ADVAMD", "right") == 1:
        return 5

    eyes = ("left", "right")
    points = 0
    for eye in eyes:
        points += predicted("PIG", eye) == 1
    for eye in eyes:
        points += predicted("DRUS", eye) == 2

    # Medium drusen only count when present in both eyes.
    bilateral_medium = (
        predicted("DRUS", "left") == 1
        and predicted("DRUS", "right") == 1
    )
    points += bilateral_medium

    return int(min(points, 5))
100
+
101
+
102
def format_probs(probabilities):
    """Render a label -> probability mapping as ``"label: 0.123 | ..."``."""
    pieces = []
    for name, value in probabilities.items():
        pieces.append(f"{name}: {value:.3f}")
    return " | ".join(pieces)
107
+
108
+
109
def model_info(args, device):
    """Return the Markdown/HTML header block shown at the top of the demo.

    It lists the backbone name, the square input size, the device type and
    the checkpoint folder taken from ``args`` and ``device``.
    """
    backbone = args.backbone
    side = args.image_size
    device_kind = device.type
    folder = args.checkpoint_folder
    return f"""
# DeepSeeNet

<div style="display: grid; grid-template-columns: repeat(4, max-content); gap: 0.75rem 2rem; align-items: center;">
<div><b>Model</b><br><code>{backbone}</code></div>
<div><b>Input size</b><br><code>{side} × {side}</code></div>
<div><b>Device</b><br><code>{device_kind}</code></div>
<div><b>Checkpoint folder</b><br><code>{folder}</code></div>
</div>
"""
120
+
121
+
122
def make_app(args):
    """Load the three task models and assemble the Gradio demo.

    Checkpoints ``advamd.pt``, ``drus.pt`` and ``pig.pt`` are read from
    ``args.checkpoint_folder``. Uploaded images are preprocessed with the
    validation transforms at ``args.image_size``.

    Returns:
        The ``gr.Blocks`` application (not yet launched).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint_folder = Path(args.checkpoint_folder)
    transform = AlbumentationsTransform(get_val_transforms(args.image_size))

    checkpoint_files = (
        ("ADVAMD", "advamd.pt"),
        ("DRUS", "drus.pt"),
        ("PIG", "pig.pt"),
    )
    models = {
        task: load_model(checkpoint_folder / filename, task, args.backbone, device)
        for task, filename in checkpoint_files
    }

    def run(left_image, right_image):
        """Score one pair of images; returns (summary rows, detail rows)."""
        left = load_image(left_image, transform, device)
        right = load_image(right_image, transform, device)

        scores = {
            task: {
                "left": predict(model, left, task),
                "right": predict(model, right, task),
            }
            for task, model in models.items()
        }

        score = simplified_score(scores)

        def eye_summary(eye):
            # Order shown per eye: drusen, pigment, late AMD.
            drusen = scores["DRUS"][eye]["label"]
            pigment = scores["PIG"][eye]["label"]
            late_amd = scores["ADVAMD"][eye]["label"]
            return f"{drusen}, {pigment}, {late_amd}"

        summary_rows = [
            ["AREDS simplified score", score],
            ["Left eye", eye_summary("left")],
            ["Right eye", eye_summary("right")],
        ]

        detail_rows = [
            [
                task,
                eye,
                scores[task][eye]["label"],
                f"{scores[task][eye]['confidence']:.3f}",
                format_probs(scores[task][eye]["probabilities"]),
            ]
            for task in ["ADVAMD", "DRUS", "PIG"]
            for eye in ["left", "right"]
        ]

        return summary_rows, detail_rows

    with gr.Blocks(title="DeepSeeNet") as demo:
        gr.Markdown(model_info(args, device))

        with gr.Row():
            left_image = gr.Image(type="pil", label="Left image")
            right_image = gr.Image(type="pil", label="Right image")

        button = gr.Button("Run")

        summary = gr.Dataframe(headers=["Item", "Result"], label="Summary")
        details = gr.Dataframe(
            headers=["Task", "Eye", "Prediction", "Confidence", "Probabilities"],
            label="Model outputs",
        )

        button.click(run, inputs=[left_image, right_image], outputs=[summary, details])

    return demo
193
+
194
+
195
def main():
    """Parse the CLI options, build the demo and start the Gradio server."""
    options = parse_args()
    app = make_app(options)
    launch_settings = {
        "server_name": options.server_name,
        "server_port": options.server_port,
        "share": options.share,
    }
    app.launch(**launch_settings)
203
+
204
+
205
+ if __name__ == "__main__":
206
+ main()
augmentations.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ augmentations.py
3
+
4
+ Simple camera-style augmentations for color fundus photography (CFP)
5
+ classification.
6
+
7
+ Expected input:
8
+ RGB NumPy image, shape (H, W, 3)
9
+
10
+ Dependencies:
11
+ pip install albumentations opencv-python
12
+ """
13
+
14
+ import cv2
15
+ import albumentations as A
16
+ from albumentations.pytorch import ToTensorV2
17
+
18
+
19
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
20
+ IMAGENET_STD = (0.229, 0.224, 0.225)
21
+
22
+
23
def get_train_transforms(
    image_size=1024,
    mean=IMAGENET_MEAN,
    std=IMAGENET_STD,
):
    """Build the training-time augmentation pipeline.

    Stages, in order: resize to ``image_size`` x ``image_size``, random
    horizontal flip, a small shift/scale/rotate, mild brightness/contrast
    and gamma jitter, an occasional quality degradation (one of blur,
    downscale or compression), normalization with ``mean``/``std`` and
    conversion to a tensor. No hue/saturation augmentation is applied.
    """
    steps = [A.Resize(image_size, image_size)]

    # Geometric changes.
    steps.append(A.HorizontalFlip(p=0.5))
    steps.append(
        A.ShiftScaleRotate(
            shift_limit=0.02,
            scale_limit=0.03,
            rotate_limit=5,
            border_mode=0,
            value=0,
            p=0.3,
        )
    )

    # Photometric changes are kept deliberately weak.
    steps.append(
        A.RandomBrightnessContrast(
            brightness_limit=0.08,
            contrast_limit=0.08,
            p=0.3,
        )
    )
    steps.append(A.RandomGamma(gamma_limit=(95, 105), p=0.2))

    # Hue shifts are intentionally left out of this pipeline.

    # At most one mild image-quality perturbation per sample.
    quality_choices = [
        A.GaussianBlur(blur_limit=(3, 5)),
        A.Downscale(scale_min=0.85, scale_max=0.95, interpolation=cv2.INTER_LINEAR),
        A.ImageCompression(quality_lower=80, quality_upper=100),
    ]
    steps.append(A.OneOf(quality_choices, p=0.15))

    steps.append(A.Normalize(mean=mean, std=std))
    steps.append(ToTensorV2())
    return A.Compose(steps)
71
+
72
+
73
def get_val_transforms(
    image_size=1024,
    mean=IMAGENET_MEAN,
    std=IMAGENET_STD,
):
    """Build the deterministic validation/test pipeline.

    Resizes to ``image_size`` x ``image_size``, normalizes with
    ``mean``/``std`` and converts the result to a tensor.
    """
    resize = A.Resize(image_size, image_size)
    normalize = A.Normalize(mean=mean, std=std)
    to_tensor = ToTensorV2()
    return A.Compose([resize, normalize, to_tensor])
checkpoints/advamd.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c600e3f70526da0c4d65d5e4d55a563f9a083d44da223385d7a8572e194e191e
3
+ size 89723328
checkpoints/drus.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efde754bd4d6be4b9afc4d0c17d016f138f52239ab36c4ce74ba6b24cef16245
3
+ size 89697156
checkpoints/pig.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7fcdafe1a939b5603b3fbf76276702a60c68b657b359a0702b4cf9ae74fea8ee
3
+ size 89696006
dataloader.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PyTorch datasets and dataloaders for AREDS fundus images."""
2
+
3
+ from pathlib import Path
4
+ from typing import Callable, Optional, Tuple, Union
5
+
6
+ import pandas as pd
7
+ import torch
8
+ from PIL import Image
9
+ from torch import Tensor
10
+ from torch.utils.data import Dataset
11
+ from torchvision import transforms
12
+
13
+
14
+ TASKS = ("ADVAMD", "DRUS", "PIG")
15
+
16
+
17
+ DEFAULT_TRANSFORM = transforms.Compose(
18
+ [
19
+ transforms.Resize(224),
20
+ transforms.CenterCrop(224),
21
+ transforms.ToTensor(),
22
+ transforms.Normalize(
23
+ mean=(0.485, 0.456, 0.406),
24
+ std=(0.229, 0.224, 0.225),
25
+ ),
26
+ ]
27
+ )
28
+
29
+
30
class AREDSDataset(Dataset):
    """Image/label dataset driven by a CSV file.

    Every CSV row provides a ``pathname`` (joined onto ``image_root``) and a
    column named after the selected task that holds the integer label.
    """

    def __init__(
        self,
        csv_path: Union[str, Path],
        image_root: Union[str, Path],
        task: str,
        transform: Optional[Callable[[Image.Image], Tensor]] = None,
    ) -> None:
        """Read ``csv_path`` and remember how to load and label its rows.

        Args:
            csv_path: CSV file read with ``pandas.read_csv``.
            image_root: Directory that row pathnames are relative to.
            task: One of ``TASKS``; matched case-insensitively.
            transform: Callable applied to each RGB image; falls back to
                ``DEFAULT_TRANSFORM`` when not given.

        Raises:
            ValueError: If ``task`` is not one of ``TASKS``.
        """
        normalized = task.upper()
        if normalized not in TASKS:
            raise ValueError(f"task must be one of {TASKS}")
        self.image_root = Path(image_root)
        self.task = normalized
        self.transform = transform if transform else DEFAULT_TRANSFORM
        self.data = pd.read_csv(csv_path)

    def __len__(self) -> int:
        """Return the number of CSV rows."""
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
        """Return ``(transformed image, label)`` for the row at ``index``."""
        record = self.data.iloc[index]
        source = self.image_root / record.pathname
        rgb = Image.open(source).convert("RGB")
        sample = self.transform(rgb)
        target = torch.tensor(int(record[self.task]), dtype=torch.long)
        return sample, target
model.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DeepSeeNet model definition."""
2
+
3
+ from torch import Tensor, nn
4
+
5
+ try:
6
+ import timm
7
+ except ImportError: # pragma: no cover - handled when timm is absent.
8
+ timm = None
9
+
10
+
11
class DeepSeeNet(nn.Module):
    """Risk-factor classifier: a timm feature extractor plus an MLP head.

    Args:
        n_classes: Number of output classes; must be at least 1.
        backbone: Name passed to ``timm.create_model``. The model is created
            with ``num_classes=0`` and average pooling, so it must support
            that configuration. Defaults to InceptionV3.
        pretrained: Forwarded to ``timm.create_model``.
        dropout: Dropout probability used after each hidden layer of the head.
        freeze_backbone: When true, gradients are disabled for the backbone
            so only the classifier head is trainable.

    Raises:
        ValueError: If ``n_classes`` is smaller than 1.
        ImportError: If ``timm`` could not be imported.
    """

    def __init__(
        self,
        n_classes: int = 2,
        backbone: str = "inception_v3",
        pretrained: bool = True,
        dropout: float = 0.5,
        freeze_backbone: bool = False,
    ) -> None:
        super().__init__()
        if n_classes < 1:
            raise ValueError("n_classes must be positive")
        if timm is None:
            raise ImportError("timm is required to build DeepSeeNet")

        self.backbone_name = backbone
        self.backbone = timm.create_model(
            backbone,
            pretrained=pretrained,
            num_classes=0,
            global_pool="avg",
        )

        # Head: features -> 256 -> 128 -> n_classes. The layer order fixes
        # the Sequential indices that appear in saved state dicts.
        feature_dim = self.backbone.num_features
        head_layers = [
            nn.Linear(feature_dim, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, n_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        if freeze_backbone:
            self.backbone.requires_grad_(False)

    def forward(self, x: Tensor) -> Tensor:
        """Return the classifier output for the backbone features of ``x``."""
        return self.classifier(self.backbone(x))
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ timm
4
+ albumentations
5
+ numpy
6
+ pandas
7
+ scikit-learn
8
+ tqdm
9
+ gradio
10
+ pillow
run_inference.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Run DeepSeeNet inference for AREDS simplified score."""
2
+
3
+ import argparse
4
+ import json
5
+
6
+ import torch
7
+ from PIL import Image
8
+
9
+ from dataloader import DEFAULT_TRANSFORM
10
+ from model import DeepSeeNet
11
+
12
+
13
+ N_CLASSES = {
14
+ "ADVAMD": 2,
15
+ "DRUS": 3,
16
+ "PIG": 2,
17
+ }
18
+
19
+
20
def parse_args() -> argparse.Namespace:
    """Parse the inference script's command-line options from ``sys.argv``.

    The two image paths and the three checkpoint paths are required;
    ``--backbone`` defaults to ``inception_v3``.
    """
    cli = argparse.ArgumentParser()
    required_flags = (
        "--left-image",
        "--right-image",
        "--advamd-checkpoint",
        "--drus-checkpoint",
        "--pig-checkpoint",
    )
    for flag in required_flags:
        cli.add_argument(flag, required=True)
    cli.add_argument("--backbone", default="inception_v3")
    return cli.parse_args()
29
+
30
+
31
def load_model(checkpoint_path: str, task: str, backbone: str, device) -> DeepSeeNet:
    """Build a ``DeepSeeNet`` for ``task`` and load weights from a checkpoint.

    The checkpoint must hold the state dict under ``"model"``. A
    ``"backbone"`` entry in its optional ``"args"`` mapping takes precedence
    over the ``backbone`` argument.

    Returns:
        The model on ``device`` in eval mode.
    """
    saved = torch.load(checkpoint_path, map_location=device)
    saved_args = saved.get("args", {})
    n_outputs = N_CLASSES[task]
    backbone_name = saved_args.get("backbone", backbone)
    net = DeepSeeNet(n_classes=n_outputs, backbone=backbone_name, pretrained=False)
    net = net.to(device)
    net.load_state_dict(saved["model"])
    net.eval()
    return net
42
+
43
+
44
def load_image(path: str, device, image_size: int = 1024) -> torch.Tensor:
    """Read an image file and return a normalized single-image batch.

    The image is preprocessed with the validation pipeline
    (``get_val_transforms``) at ``image_size``, the same preprocessing and
    default size that ``app.py`` and ``test.py`` use, instead of the
    224-pixel resize/center-crop of ``DEFAULT_TRANSFORM``. This keeps CLI
    predictions consistent with the demo and the evaluation script.

    Args:
        path: Image file to open.
        device: Device the returned tensor is moved to.
        image_size: Side length passed to ``get_val_transforms``.

    Returns:
        The transformed image with a leading batch dimension, on ``device``.
    """
    # Function-scope imports: the module header does not import these names.
    import numpy as np

    from augmentations import get_val_transforms

    # The context manager closes the file handle as soon as pixels are read.
    with Image.open(path) as handle:
        pixels = np.asarray(handle.convert("RGB"))

    tensor = get_val_transforms(image_size)(image=pixels)["image"]
    return tensor.unsqueeze(0).to(device)
47
+
48
+
49
@torch.no_grad()
def predict(model: DeepSeeNet, image: torch.Tensor) -> int:
    """Return the arg-max class index for a batch holding a single image."""
    logits = model(image)
    winner = logits.argmax(dim=1)
    return int(winner.item())
52
+
53
+
54
def simplified_score(scores: dict[str, tuple[int, int]]) -> int:
    """Combine per-eye class indices into one patient-level score.

    ``scores`` maps ``"ADVAMD"``, ``"PIG"`` and ``"DRUS"`` to a pair of
    predictions (index 0 and index 1, one per eye).

    Returns:
        5 when either ``ADVAMD`` entry is truthy. Otherwise one point per
        eye with ``PIG`` == 1, one per eye with ``DRUS`` == 2, plus one when
        ``DRUS`` == 1 for both eyes (capped at 5).
    """
    advanced = scores["ADVAMD"]
    if advanced[0] or advanced[1]:
        return 5

    pigment = scores["PIG"]
    drusen = scores["DRUS"]
    points = 0
    for eye in (0, 1):
        points += pigment[eye] == 1
    for eye in (0, 1):
        points += drusen[eye] == 2
    # Medium drusen only count when both eyes have them.
    points += drusen[0] == 1 and drusen[1] == 1
    return min(points, 5)
64
+
65
+
66
def main() -> None:
    """Score one left/right image pair and print the result as JSON.

    The printed object has ``simplified_score`` and ``risk_factors``; the
    latter maps each task to its (left, right) predicted class indices.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    left_batch = load_image(args.left_image, device)
    right_batch = load_image(args.right_image, device)
    images = {"left": left_batch, "right": right_batch}

    checkpoints = {
        "ADVAMD": args.advamd_checkpoint,
        "DRUS": args.drus_checkpoint,
        "PIG": args.pig_checkpoint,
    }

    scores = {}
    for task in checkpoints:
        # Models are loaded one task at a time and used for both eyes.
        net = load_model(checkpoints[task], task, args.backbone, device)
        left_pred = predict(net, images["left"])
        right_pred = predict(net, images["right"])
        scores[task] = (left_pred, right_pred)

    report = {
        "simplified_score": simplified_score(scores),
        "risk_factors": scores,
    }
    print(json.dumps(report, indent=2))
96
+
97
+
98
+ if __name__ == "__main__":
99
+ main()
test.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ from pathlib import Path
6
+ from typing import Any, Callable
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch.utils.data import DataLoader
13
+ from tqdm import tqdm
14
+
15
+ try:
16
+ from sklearn.metrics import (
17
+ accuracy_score,
18
+ cohen_kappa_score,
19
+ confusion_matrix,
20
+ recall_score,
21
+ roc_auc_score,
22
+ )
23
+ except ImportError as exc:
24
+ raise ImportError(
25
+ "This evaluation script needs scikit-learn. Install with: pip install scikit-learn"
26
+ ) from exc
27
+
28
+ from augmentations import get_val_transforms
29
+ from dataloader import AREDSDataset
30
+ from model import DeepSeeNet
31
+
32
+
33
+ N_CLASSES = {
34
+ "ADVAMD": 2,
35
+ "DRUS": 3,
36
+ "PIG": 2,
37
+ }
38
+
39
+ DEFAULT_POSITIVE_CLASS = {
40
+ "ADVAMD": 1,
41
+ "DRUS": 2,
42
+ "PIG": 1,
43
+ }
44
+
45
+ ENDPOINT_NAME = {
46
+ "ADVAMD": "late_amd",
47
+ "DRUS": "large_drusen",
48
+ "PIG": "pigmentary_abnormality",
49
+ }
50
+
51
+
52
def parse_args() -> argparse.Namespace:
    """Parse the evaluation script's command-line options from ``sys.argv``.

    ``--test-csv``, ``--image-root``, ``--checkpoint`` and ``--task`` are
    required. ``--task`` is upper-cased before being checked against the
    keys of ``N_CLASSES``.
    """
    cli = argparse.ArgumentParser()

    # Data, checkpoint and model options.
    for flag in ("--test-csv", "--image-root", "--checkpoint"):
        cli.add_argument(flag, required=True)
    cli.add_argument("--task", required=True, type=str.upper, choices=N_CLASSES)
    cli.add_argument("--backbone", default="inception_v3")
    for flag, default in (("--image-size", 1024), ("--batch-size", 32), ("--num-workers", 16)):
        cli.add_argument(flag, type=int, default=default)

    # Metric and bootstrap options.
    cli.add_argument("--positive-class", type=int, default=None)
    cli.add_argument("--bootstrap-iters", type=int, default=2000)
    cli.add_argument("--seed", type=int, default=123)
    cli.add_argument("--bootstrap-unit-column", default=None)
    cli.add_argument("--output-dir", default=None)
    return cli.parse_args()
69
+
70
+
71
class AlbumentationsTransform:
    """Adapt a keyword-style transform so a dataset can call it with one image.

    The wrapped callable is invoked as ``transform(image=<ndarray>)`` and its
    result is expected to be a mapping that holds the output under ``"image"``.
    """

    def __init__(self, transform) -> None:
        self.transform = transform

    def __call__(self, image):
        pixels = np.asarray(image)
        outcome = self.transform(image=pixels)
        return outcome["image"]
77
+
78
+
79
@torch.no_grad()
def collect_predictions(model: torch.nn.Module, loader: DataLoader, device: torch.device) -> dict[str, np.ndarray | float]:
    """Run ``model`` over ``loader`` and gather labels, logits and loss.

    The model is put in eval mode. Cross-entropy is accumulated per batch,
    weighted by batch size, and averaged over all samples seen.

    Returns:
        A dict with ``loss`` (sample-weighted mean cross-entropy), ``labels``
        (int array), ``logits``, ``probs`` (softmax over axis 1) and
        ``preds`` (arg-max of ``probs``).
    """
    model.eval()
    loss_sum = 0.0
    seen = 0
    label_chunks: list[np.ndarray] = []
    logit_chunks: list[np.ndarray] = []

    for batch_images, batch_labels in tqdm(loader, desc="test"):
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        output = model(batch_images)
        # When the model returns a tuple/list, only its first entry is used.
        logits = output[0] if isinstance(output, (tuple, list)) else output

        batch_loss = F.cross_entropy(logits, batch_labels)
        count = batch_labels.size(0)
        loss_sum += batch_loss.item() * count
        seen += count

        label_chunks.append(batch_labels.detach().cpu().numpy())
        logit_chunks.append(logits.detach().cpu().numpy())

    labels_np = np.concatenate(label_chunks).astype(int)
    logits_np = np.concatenate(logit_chunks, axis=0)
    probs_np = torch.softmax(torch.from_numpy(logits_np), dim=1).numpy()
    preds_np = probs_np.argmax(axis=1).astype(int)

    # max(..., 1) guards the division when no samples were counted.
    mean_loss = float(loss_sum / max(seen, 1))
    return {
        "loss": mean_loss,
        "labels": labels_np,
        "logits": logits_np,
        "probs": probs_np,
        "preds": preds_np,
    }
115
+
116
+
117
def specificity_score(y_true_bin: np.ndarray, y_pred_bin: np.ndarray) -> float:
    """Return TN / (TN + FP) for 0/1 arrays, or NaN when there are no negatives."""
    is_negative = y_true_bin == 0
    true_negatives = np.sum(is_negative & (y_pred_bin == 0))
    false_positives = np.sum(is_negative & (y_pred_bin == 1))
    total = true_negatives + false_positives
    if not total:
        return float("nan")
    return float(true_negatives / total)
122
+
123
+
124
def safe_auc(y_true_bin: np.ndarray, y_score: np.ndarray) -> float:
    """Return the ROC AUC, or NaN when ``y_true_bin`` has fewer than two distinct values."""
    distinct_labels = np.unique(y_true_bin)
    if len(distinct_labels) >= 2:
        return float(roc_auc_score(y_true_bin, y_score))
    return float("nan")
128
+
129
+
130
def compute_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    probs: np.ndarray,
    n_classes: int,
    positive_class: int,
) -> dict[str, float]:
    """Compute multi-class and binary-endpoint metrics for one set of predictions.

    The binary endpoint treats ``positive_class`` as positive and every other
    class as negative; its score is the ``positive_class`` column of ``probs``.

    Returns:
        A dict with ``loss`` (always NaN here, as a placeholder),
        ``exact_accuracy``/``exact_kappa`` on the original classes,
        ``overall_accuracy``, ``sensitivity``, ``specificity``, ``kappa`` and
        ``auc`` on the binary endpoint, and ``macro_ovr_auc`` when there are
        more than two classes and ``y_true`` has more than one distinct value.
    """
    true_binary = (y_true == positive_class).astype(int)
    pred_binary = (y_pred == positive_class).astype(int)
    positive_scores = probs[:, positive_class]

    results: dict[str, float] = {"loss": float("nan")}

    # Metrics on the original class labels.
    results["exact_accuracy"] = float(accuracy_score(y_true, y_pred))
    results["exact_kappa"] = float(cohen_kappa_score(y_true, y_pred))

    # Metrics on the positive-vs-rest endpoint.
    results["overall_accuracy"] = float(accuracy_score(true_binary, pred_binary))
    results["sensitivity"] = float(
        recall_score(true_binary, pred_binary, pos_label=1, zero_division=0)
    )
    results["specificity"] = specificity_score(true_binary, pred_binary)
    results["kappa"] = float(cohen_kappa_score(true_binary, pred_binary))
    results["auc"] = safe_auc(true_binary, positive_scores)

    multiclass = n_classes > 2 and len(np.unique(y_true)) > 1
    if multiclass:
        class_ids = list(range(n_classes))
        try:
            macro_auc = float(
                roc_auc_score(y_true, probs, labels=class_ids, multi_class="ovr", average="macro")
            )
        except ValueError:
            # scikit-learn could not compute the score for this sample.
            macro_auc = float("nan")
        results["macro_ovr_auc"] = macro_auc

    return results
161
+
162
+
163
def make_bootstrap_indices(
    n: int,
    n_iters: int,
    rng: np.random.Generator,
    units: np.ndarray | None = None,
) -> list[np.ndarray]:
    """Draw row-index arrays for ``n_iters`` bootstrap resamples.

    Without ``units`` each resample is ``n`` row indices drawn with
    replacement from ``range(n)``. With ``units`` (one value per row), whole
    units are drawn with replacement and every row of each drawn unit is
    included, so a resample's length can differ from ``n``.

    Returns:
        One index array per iteration; an empty list when ``n_iters`` <= 0.
    """
    if n_iters <= 0:
        return []

    if units is None:
        row_samples = []
        for _ in range(n_iters):
            row_samples.append(rng.integers(0, n, size=n))
        return row_samples

    unique_units = np.array(pd.unique(units))
    rows_of_unit = {}
    for unit in unique_units:
        rows_of_unit[unit] = np.where(units == unit)[0]
    n_units = len(unique_units)

    def draw_once() -> np.ndarray:
        drawn = rng.choice(unique_units, size=n_units, replace=True)
        return np.concatenate([rows_of_unit[u] for u in drawn])

    return [draw_once() for _ in range(n_iters)]
182
+
183
+
184
def bootstrap_ci(
    metric_fn: Callable[[np.ndarray], dict[str, float]],
    indices: list[np.ndarray],
) -> dict[str, dict[str, float]]:
    """Turn per-resample metrics into 2.5th-97.5th percentile intervals.

    Args:
        metric_fn: Called once per index array; returns metric name -> value
            for that resample.
        indices: The resample index arrays; an empty list yields ``{}``.

    Returns:
        Metric name -> ``{"ci_low": ..., "ci_high": ...}``. NaN samples are
        ignored by the percentile computation.
    """
    intervals: dict[str, dict[str, float]] = {}
    if not indices:
        return intervals

    # Gather every metric's value across resamples; a metric that is absent
    # from some resamples simply ends up with fewer samples.
    samples: dict[str, list[float]] = {}
    for resample in tqdm(indices, desc="bootstrap", leave=False):
        for name, sample in metric_fn(resample).items():
            if name not in samples:
                samples[name] = []
            samples[name].append(sample)

    for name in samples:
        collected = np.asarray(samples[name], dtype=float)
        lower = float(np.nanpercentile(collected, 2.5))
        upper = float(np.nanpercentile(collected, 97.5))
        intervals[name] = {"ci_low": lower, "ci_high": upper}
    return intervals
205
+
206
+
207
def combine_with_ci(metrics: dict[str, float], ci: dict[str, dict[str, float]]) -> dict[str, Any]:
    """Attach confidence-interval fields to each point estimate.

    Every metric becomes ``{"value": float(metric)}``, extended with the
    fields of ``ci[name]`` when an interval exists for that name. Intervals
    without a matching metric are ignored; neither input is modified.
    """
    no_interval: dict[str, float] = {}
    return {
        name: {"value": float(estimate), **ci.get(name, no_interval)}
        for name, estimate in metrics.items()
    }
214
+
215
+
216
def print_metric_table(metrics_with_ci: dict[str, Any]) -> None:
    """Print the metrics as two plain-text sections on stdout.

    The "Metrics" section lists five positive-vs-rest metrics and requires
    all of them to be present (a missing one raises ``KeyError``). The
    "Classifier metrics" section lists further entries and skips any that
    are absent. An interval is appended whenever an entry has ``"ci_low"``.
    """

    def emit(key: str) -> None:
        # One table row: name padded to 20 columns, value, optional interval.
        item = metrics_with_ci[key]
        if "ci_low" not in item:
            print(f"{key:20s} {item['value']:.4f}")
            return
        print(f"{key:20s} {item['value']:.4f} ({item['ci_low']:.4f}-{item['ci_high']:.4f})")

    print("\nMetrics")
    print("-------")
    for required in ("overall_accuracy", "sensitivity", "specificity", "kappa", "auc"):
        emit(required)

    print("\nClassifier metrics")
    print("------------------")
    for optional in ("loss", "exact_accuracy", "exact_kappa", "macro_ovr_auc"):
        if optional in metrics_with_ci:
            emit(optional)
236
+
237
+
238
def main() -> None:
    """Evaluate a DeepSeeNet checkpoint on a test CSV and report metrics.

    Steps: build the dataset and model, load the checkpoint weights, collect
    predictions, compute point metrics, optionally bootstrap confidence
    intervals (row-level, or grouped by a CSV column), print the results and,
    when ``args.output_dir`` is set, write metrics, confusion matrices and
    per-row predictions to that directory.

    Raises:
        ValueError: If the positive class is outside ``[0, n_classes)``, or if
            the CSV row count differs from the number of predictions when a
            bootstrap unit column is requested.
        KeyError: If the requested bootstrap unit column is not in the CSV.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    task = args.task.upper()
    n_classes = N_CLASSES[task]
    # An explicit args.positive_class overrides the per-task default.
    positive_class = DEFAULT_POSITIVE_CLASS[task] if args.positive_class is None else args.positive_class
    if not 0 <= positive_class < n_classes:
        raise ValueError(f"positive_class={positive_class} is invalid for task={task} with {n_classes} classes")

    dataset = AREDSDataset(
        args.test_csv,
        args.image_root,
        task,
        transform=AlbumentationsTransform(get_val_transforms(args.image_size)),
    )
    # shuffle=False keeps predictions in dataset order; the code below relies
    # on that order matching the CSV rows (NOTE(review): confirm AREDSDataset
    # neither filters nor reorders rows).
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
    )

    # pretrained=False: the weights are taken from the checkpoint loaded below.
    model = DeepSeeNet(
        n_classes=n_classes,
        backbone=args.backbone,
        pretrained=False,
    ).to(device)
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint["model"])

    # collect_predictions returns a mapping with "labels", "preds", "probs"
    # and "loss" entries (as read below).
    pred_dict = collect_predictions(model, loader, device)
    y_true = pred_dict["labels"]
    y_pred = pred_dict["preds"]
    probs = pred_dict["probs"]

    metrics = compute_metrics(y_true, y_pred, probs, n_classes=n_classes, positive_class=positive_class)
    # compute_metrics leaves "loss" as NaN; fill in the measured value.
    metrics["loss"] = float(pred_dict["loss"])

    # Optional grouping column for the bootstrap: rows sharing a value are
    # resampled together.
    units = None
    if args.bootstrap_unit_column:
        df_for_units = pd.read_csv(args.test_csv)
        if args.bootstrap_unit_column not in df_for_units.columns:
            raise KeyError(
                f"--bootstrap-unit-column {args.bootstrap_unit_column!r} not found in {args.test_csv}. "
                f"Available columns: {list(df_for_units.columns)}"
            )
        if len(df_for_units) != len(y_true):
            raise ValueError(
                "CSV length does not match dataset length. "
                f"CSV rows={len(df_for_units)}, dataset rows={len(y_true)}"
            )
        units = df_for_units[args.bootstrap_unit_column].to_numpy()

    rng = np.random.default_rng(args.seed)
    bs_indices = make_bootstrap_indices(
        n=len(y_true),
        n_iters=args.bootstrap_iters,
        rng=rng,
        units=units,
    )

    def metric_fn(idx: np.ndarray) -> dict[str, float]:
        # Recompute the metrics on one resample; "loss" is dropped because it
        # is only a NaN placeholder at this level.
        out = compute_metrics(
            y_true[idx],
            y_pred[idx],
            probs[idx],
            n_classes=n_classes,
            positive_class=positive_class,
        )
        out.pop("loss", None)
        return out

    ci = bootstrap_ci(metric_fn, bs_indices)
    metrics_with_ci = combine_with_ci(metrics, ci)

    # Full multi-class confusion matrix plus the positive-vs-rest 2x2 matrix.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(n_classes)))
    endpoint_cm = confusion_matrix(
        (y_true == positive_class).astype(int),
        (y_pred == positive_class).astype(int),
        labels=[0, 1],
    )

    # Run description stored next to the metrics in metrics.json.
    meta = {
        "task": task,
        "endpoint": ENDPOINT_NAME[task],
        "positive_class": int(positive_class),
        "n_classes": int(n_classes),
        "n_samples": int(len(y_true)),
        "bootstrap_iters": int(args.bootstrap_iters),
        "bootstrap_unit_column": args.bootstrap_unit_column,
    }

    print(f"\nTask: {task} | endpoint: {ENDPOINT_NAME[task]} | positive_class={positive_class}")
    print_metric_table(metrics_with_ci)
    print("\nConfusion matrix (rows=true, cols=pred):")
    print(cm)
    print("\nBinary confusion matrix (rows=true, cols=pred):")
    print(endpoint_cm)

    if args.output_dir:
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        with (output_dir / "metrics.json").open("w") as f:
            json.dump({"meta": meta, "metrics": metrics_with_ci}, f, indent=2)

        pd.DataFrame(cm).to_csv(output_dir / "confusion_matrix.csv", index=False)
        pd.DataFrame(endpoint_cm, index=["true_neg", "true_pos"], columns=["pred_neg", "pred_pos"]).to_csv(
            output_dir / "endpoint_confusion_matrix.csv"
        )

        # Append predictions to the original CSV columns when the row counts
        # line up; otherwise fall back to a bare frame with only predictions.
        pred_df = pd.read_csv(args.test_csv)
        if len(pred_df) == len(y_true):
            pred_df = pred_df.copy()
        else:
            pred_df = pd.DataFrame(index=np.arange(len(y_true)))
        pred_df["y_true"] = y_true
        pred_df["y_pred"] = y_pred
        pred_df[f"y_true_{ENDPOINT_NAME[task]}"] = (y_true == positive_class).astype(int)
        pred_df[f"y_pred_{ENDPOINT_NAME[task]}"] = (y_pred == positive_class).astype(int)
        for c in range(n_classes):
            pred_df[f"prob_class_{c}"] = probs[:, c]
        pred_df.to_csv(output_dir / "predictions.csv", index=False)

        print(f"\nSaved outputs to: {output_dir}")
365
+
366
+
367
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
train.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import random
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch.utils.data import DataLoader
9
+ from tqdm import tqdm
10
+
11
+ from augmentations import get_train_transforms, get_val_transforms
12
+ from dataloader import AREDSDataset
13
+ from model import DeepSeeNet
14
+
15
+
16
# Number of output classes for each supported task name (upper-case keys).
N_CLASSES = {
    "ADVAMD": 2,
    "DRUS": 3,
    "PIG": 2,
}
21
+
22
+
23
class AlbumentationsTransform:
    """Adapter that lets a keyword-style transform be called with one image.

    The wrapped ``transform`` is invoked as ``transform(image=<array>)`` and
    must return a mapping; the adapter hands back its ``"image"`` entry.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, image):
        # The input is converted with np.asarray before it reaches the
        # wrapped transform.
        pixels = np.asarray(image)
        augmented = self.transform(image=pixels)
        return augmented["image"]
29
+
30
+
31
def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and CUDA)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
36
+
37
+
38
def get_class_weights(dataset, task, device):
    """Compute inverse-frequency class weights from the dataset's labels.

    Labels are read from ``dataset.data[task]``. Each class gets the weight
    ``total / (num_classes * count)``; classes with no samples are counted
    as 1 so the division stays finite. The result is moved to ``device``.
    """
    targets = torch.tensor(dataset.data[task].to_numpy(), dtype=torch.long)
    per_class = torch.bincount(targets, minlength=N_CLASSES[task]).clamp_min(1)
    total = per_class.sum()
    n_bins = len(per_class)
    balanced = total / (n_bins * per_class)
    return balanced.to(device)
43
+
44
+
45
def build_scheduler(optimizer, args):
    """Create the learning-rate scheduler selected by ``args.scheduler``.

    ``"cosine"`` gives ``CosineAnnealingLR`` (``T_max=args.epochs``,
    ``eta_min=args.min_lr``); ``"step"`` gives ``StepLR``
    (``step_size=args.step_size``, ``gamma=args.gamma``); any other value
    returns ``None``.
    """
    schedulers = torch.optim.lr_scheduler
    kind = args.scheduler

    if kind == "cosine":
        chosen = schedulers.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=args.min_lr)
    elif kind == "step":
        chosen = schedulers.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    else:
        chosen = None
    return chosen
61
+
62
+
63
def train_one_epoch(
    model,
    loader,
    optimizer,
    scaler,
    criterion,
    device,
    use_amp=True,
    grad_clip=0.0,
):
    """Run one training pass over ``loader`` and return ``(loss, accuracy)``.

    Args:
        model: Network to train; switched to train mode.
        loader: Iterable of ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        scaler: Gradient scaler for mixed precision, or ``None`` for a plain
            backward/step.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device the batches are moved to.
        use_amp: Request autocast; it is only enabled on a CUDA device.
        grad_clip: Max gradient norm; clipping is skipped unless ``> 0``.

    Returns:
        Sample-weighted mean loss and accuracy over the epoch.
    """
    model.train()
    autocast_on = use_amp and device.type == "cuda"

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(loader, desc="Train", leave=False)
    for images, labels in progress:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast("cuda", enabled=autocast_on):
            logits = model(images)
            loss = criterion(logits, labels)

        if scaler is None:
            loss.backward()
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
        else:
            scaler.scale(loss).backward()
            if grad_clip > 0:
                # Gradients are unscaled first so the clip threshold applies
                # to their true magnitude.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()

        n_batch = labels.size(0)
        # criterion yields a per-batch loss; weight it by batch size so the
        # epoch figure is a per-sample mean.
        loss_sum += loss.item() * n_batch
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += n_batch

        progress.set_postfix(
            loss=f"{loss_sum / n_seen:.4f}",
            acc=f"{n_correct / n_seen:.4f}",
        )

    return loss_sum / n_seen, n_correct / n_seen
119
+
120
+
121
@torch.no_grad()
def evaluate(model, loader, criterion, device, use_amp=True):
    """Run one evaluation pass over ``loader`` and return ``(loss, accuracy)``.

    The model is put in eval mode and gradients are disabled. Autocast is
    only enabled when ``use_amp`` is set and ``device`` is CUDA. Both
    returned values are sample-weighted means over all batches.
    """
    model.eval()
    autocast_on = use_amp and device.type == "cuda"

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(loader, desc="Val", leave=False)
    for images, labels in progress:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        with torch.amp.autocast("cuda", enabled=autocast_on):
            logits = model(images)
            loss = criterion(logits, labels)

        n_batch = labels.size(0)
        loss_sum += loss.item() * n_batch
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += n_batch

        progress.set_postfix(
            loss=f"{loss_sum / n_seen:.4f}",
            acc=f"{n_correct / n_seen:.4f}",
        )

    return loss_sum / n_seen, n_correct / n_seen
150
+
151
+
152
def save_checkpoint(path, model, optimizer, epoch, best_val_loss, args):
    """Write a training checkpoint to ``path``, creating parent directories.

    The saved dictionary holds the epoch, the model and optimizer state
    dicts, the best validation loss so far, and ``vars(args)``.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    state = {
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "best_val_loss": best_val_loss,
        "args": vars(args),
    }
    torch.save(state, destination)
166
+
167
+
168
def main(args):
    """Train DeepSeeNet for ``args.task`` and write checkpoints.

    Builds the train/validation datasets and loaders, the model, the
    (optionally class-weighted) loss, an AdamW optimizer and an optional
    LR scheduler, then trains for ``args.epochs`` epochs. ``best.pt`` is
    written whenever the validation loss improves, ``epoch_XXX.pt`` every
    ``args.save_every`` epochs (if ``> 0``), and ``last.pt`` at the end.

    Args:
        args: Parsed command-line namespace (see ``parse_args``).
    """
    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision is only used when requested AND running on CUDA.
    use_amp = args.amp and device.type == "cuda"

    train_dataset = AREDSDataset(
        args.train_csv,
        args.image_root,
        args.task,
        transform=AlbumentationsTransform(get_train_transforms(args.image_size)),
    )

    val_dataset = AREDSDataset(
        args.valid_csv,
        args.image_root,
        args.task,
        transform=AlbumentationsTransform(get_val_transforms(args.image_size)),
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
    )

    model = DeepSeeNet(
        n_classes=N_CLASSES[args.task],
        backbone=args.backbone,
        pretrained=not args.no_pretrained,
        freeze_backbone=args.freeze_backbone,
    ).to(device)

    # Class weights are derived from the training labels unless disabled.
    class_weights = None
    if not args.no_class_weights:
        class_weights = get_class_weights(train_dataset, args.task, device)

    criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )

    scheduler = build_scheduler(optimizer, args)
    # No scaler means train_one_epoch takes the plain backward/step path.
    scaler = torch.amp.GradScaler("cuda") if use_amp else None

    # wandb is imported lazily so it is only required when logging is enabled.
    wandb = None
    if args.wandb:
        import wandb

        wandb.init(project=args.wandb_project, config=vars(args))

    output_dir = Path(args.output_dir)
    best_val_loss = float("inf")

    print(f"Device: {device}")
    print(f"Task: {args.task}")
    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")
    print(f"Image size: {args.image_size}")
    print(f"Batch size: {args.batch_size}")
    print(f"Pretrained: {not args.no_pretrained}")
    if class_weights is not None:
        print(f"Class weights: {class_weights.detach().cpu().tolist()}")

    for epoch in range(1, args.epochs + 1):
        print(f"\nEpoch [{epoch:03d}/{args.epochs}]")

        train_loss, train_acc = train_one_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            scaler=scaler,
            criterion=criterion,
            device=device,
            use_amp=args.amp,
            grad_clip=args.grad_clip,
        )

        # Validation uses an unweighted cross-entropy, so val_loss is not
        # directly comparable to a class-weighted train_loss.
        val_loss, val_acc = evaluate(
            model=model,
            loader=val_loader,
            criterion=torch.nn.CrossEntropyLoss(),
            device=device,
            use_amp=args.amp,
        )

        # Read before scheduler.step(): this is the rate used in this epoch.
        lr = optimizer.param_groups[0]["lr"]

        print(
            f"train_loss={train_loss:.4f} "
            f"train_acc={train_acc:.4f} "
            f"val_loss={val_loss:.4f} "
            f"val_acc={val_acc:.4f} "
            f"lr={lr:.2e}"
        )

        if wandb is not None:
            wandb.log(
                {
                    "epoch": epoch,
                    "lr": lr,
                    "train_loss": train_loss,
                    "train_acc": train_acc,
                    "val_loss": val_loss,
                    "val_acc": val_acc,
                }
            )

        # Model selection: keep the checkpoint with the lowest validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(
                output_dir / "best.pt",
                model,
                optimizer,
                epoch,
                best_val_loss,
                args,
            )
            print(f"Saved best checkpoint: val_loss={best_val_loss:.4f}")

        # Optional periodic snapshots, independent of validation performance.
        if args.save_every > 0 and epoch % args.save_every == 0:
            save_checkpoint(
                output_dir / f"epoch_{epoch:03d}.pt",
                model,
                optimizer,
                epoch,
                best_val_loss,
                args,
            )

        if scheduler is not None:
            scheduler.step()

    # Final weights are always saved, whether or not they were the best.
    save_checkpoint(
        output_dir / "last.pt",
        model,
        optimizer,
        args.epochs,
        best_val_loss,
        args,
    )

    print("Training complete.")
    print(f"Best val loss: {best_val_loss:.4f}")
325
+
326
+
327
def parse_args():
    """Define the training command-line options and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with data paths, model, optimisation,
        scheduler, runtime and logging options.
    """
    parser = argparse.ArgumentParser(description="Train DeepSeeNet.")

    # (flag, add_argument keyword arguments), in help-display order.
    options = (
        # Data and output locations.
        ("--train-csv", {"required": True}),
        ("--valid-csv", {"required": True}),
        ("--image-root", {"required": True}),
        # The task name is upper-cased before it is checked against N_CLASSES.
        ("--task", {"required": True, "type": str.upper, "choices": N_CLASSES}),
        ("--output-dir", {"default": "checkpoints/deepseenet"}),
        # Model and loader settings.
        ("--backbone", {"default": "inception_v3"}),
        ("--image-size", {"type": int, "default": 1024}),
        ("--epochs", {"type": int, "default": 20}),
        ("--batch-size", {"type": int, "default": 32}),
        ("--num-workers", {"type": int, "default": 4}),
        # Optimisation.
        ("--lr", {"type": float, "default": 1e-4}),
        ("--weight-decay", {"type": float, "default": 1e-4}),
        ("--no-pretrained", {"action": "store_true"}),
        ("--freeze-backbone", {"action": "store_true"}),
        ("--no-class-weights", {"action": "store_true"}),
        # Learning-rate schedule.
        ("--scheduler", {"choices": ("none", "cosine", "step"), "default": "cosine"}),
        ("--min-lr", {"type": float, "default": 1e-6}),
        ("--step-size", {"type": int, "default": 5}),
        ("--gamma", {"type": float, "default": 0.5}),
        # Runtime behaviour.
        ("--amp", {"action": "store_true"}),
        ("--grad-clip", {"type": float, "default": 0.0}),
        ("--save-every", {"type": int, "default": 0}),
        ("--seed", {"type": int, "default": 42}),
        # Experiment tracking.
        ("--wandb", {"action": "store_true"}),
        ("--wandb-project", {"default": "deepseenet"}),
    )
    for flag, spec in options:
        parser.add_argument(flag, **spec)

    return parser.parse_args()
362
+
363
+
364
# Script entry point: parse the command line, then run training.
if __name__ == "__main__":
    args = parse_args()
    main(args)