farrell236 committed on
Commit
d0344ce
·
1 Parent(s): 7629975
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .idea
app.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Simple Gradio app for testing an EyeQ QC model.
4
+
5
+ Example
6
+ -------
7
+ python app.py \
8
+ --checkpoint ./checkpoints/eyeq_vit_base/best.pt
9
+
10
+ Then open the printed local URL in your browser.
11
+ """
12
+
13
+ import argparse
14
+ from pathlib import Path
15
+
16
+ import gradio as gr
17
+ import numpy as np
18
+ import torch
19
+ from PIL import Image
20
+ from torchvision import transforms
21
+ import timm
22
+
23
+
24
# Default class-index -> label mapping; load_model() falls back to it when the
# checkpoint has no "id_to_label" entry.
ID_TO_LABEL = {0: "Good", 1: "Usable", 2: "Reject"}
25
+
26
+
27
def build_transform(img_size: int):
    """Build the inference preprocessing pipeline.

    The image is resized to ``img_size`` x ``img_size``, converted to a
    tensor and normalized per channel with the fixed mean/std below.
    """
    resize = transforms.Resize((img_size, img_size))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )
    return transforms.Compose([resize, to_tensor, normalize])
33
+
34
+
35
def load_model(checkpoint_path: str, device: torch.device):
    """Load a trained classifier and its preprocessing from a checkpoint.

    The checkpoint is read as a dict holding a ``"model"`` state dict and,
    optionally, ``"args"`` (with ``model`` and ``img_size``) and
    ``"id_to_label"``. Missing optional entries fall back to
    ``vit_base_patch16_224``, 224 and ``ID_TO_LABEL`` respectively.

    Args:
        checkpoint_path: Path of the checkpoint file.
        device: Device the model is moved to.

    Returns:
        ``(model, tfm, id_to_label, model_name, img_size)``; the model is in
        eval mode and ``id_to_label`` maps ``int`` class IDs to ``str`` labels.

    Raises:
        KeyError: If the checkpoint has no ``"model"`` entry.
    """
    ckpt = torch.load(checkpoint_path, map_location="cpu")

    # ``or {}`` also covers a checkpoint that stored ``"args": None``.
    ckpt_args = ckpt.get("args") or {}
    model_name = ckpt_args.get("model", "vit_base_patch16_224")
    img_size = int(ckpt_args.get("img_size", 224))

    # Normalize keys to int and labels to str so the mapping can be indexed by
    # argmax results and used directly as display labels.
    id_to_label = ckpt.get("id_to_label", ID_TO_LABEL)
    id_to_label = {int(k): str(v) for k, v in id_to_label.items()}

    model = timm.create_model(
        model_name,
        pretrained=False,
        num_classes=len(id_to_label),
    )
    model.load_state_dict(ckpt["model"], strict=True)
    model.to(device)
    model.eval()

    tfm = build_transform(img_size)
    return model, tfm, id_to_label, model_name, img_size
56
+
57
+
58
def get_eyeq_class_ids(id_to_label):
    """Look up the class IDs of Good, Usable and Reject, in that order.

    Labels are compared case-insensitively. A label that is absent from
    ``id_to_label`` falls back to its standard EyeQ index
    (Good=0, Usable=1, Reject=2).
    """
    ids_by_name = {}
    for class_id, label in id_to_label.items():
        ids_by_name[str(label).lower()] = int(class_id)

    fallbacks = (("good", 0), ("usable", 1), ("reject", 2))
    return tuple(ids_by_name.get(name, default) for name, default in fallbacks)
71
+
72
+
73
def soft_eyeq_decision(probs, id_to_label, reject_threshold=0.60, reject_margin=0.15):
    """Pick a class using a conservative rule for Reject.

    Reject is chosen only if P(Reject) is at least ``reject_threshold`` AND
    exceeds the larger of P(Good)/P(Usable) by at least ``reject_margin``.
    In every other case the more probable of Good and Usable is returned
    (Good wins a tie).

    Returns:
        ``(pred_id, pred_label, decision)`` where ``decision`` is a short
        explanation of which branch was taken.
    """
    good_id, usable_id, reject_id = get_eyeq_class_ids(id_to_label)

    p_good = float(probs[good_id])
    p_usable = float(probs[usable_id])
    p_reject = float(probs[reject_id])

    fallback_id = usable_id if p_good < p_usable else good_id
    fallback_prob = max(p_good, p_usable)

    reject_wins = (
        p_reject >= reject_threshold
        and (p_reject - fallback_prob) >= reject_margin
    )
    if not reject_wins:
        return (
            fallback_id,
            id_to_label[fallback_id],
            "Soft rule: Reject was not confident enough, so prediction was forced to Good/Usable.",
        )

    return (
        reject_id,
        id_to_label[reject_id],
        "Soft rule: Reject threshold and margin were both satisfied.",
    )
102
+
103
+
104
def update_margin_slider(reject_threshold, reject_margin):
    """Clamp the margin slider to the current reject threshold.

    The slider maximum becomes ``min(0.50, reject_threshold)`` and the current
    margin is capped at that maximum.
    """
    margin_cap = min(0.50, float(reject_threshold))
    clamped_margin = min(float(reject_margin), margin_cap)
    return gr.update(maximum=margin_cap, value=clamped_margin)
113
+
114
+
115
@torch.no_grad()
def predict_quality(
    image: Image.Image,
    model,
    tfm,
    id_to_label,
    device,
    reject_threshold=0.60,
    reject_margin=0.15,
):
    """Classify one image and report both the raw and the soft decision.

    Returns:
        ``(label, prob_dict, detail)``: the soft-rule label, a mapping of
        label -> probability, and a multi-line explanation. When ``image`` is
        ``None`` a placeholder triple is returned instead.
    """
    if image is None:
        return None, {}, "Upload an image to run QC."

    # Single-image batch on the target device.
    batch = tfm(image.convert("RGB")).unsqueeze(0).to(device)
    probs = torch.softmax(model(batch), dim=1)[0].detach().cpu().numpy()

    raw_pred_label = id_to_label[int(np.argmax(probs))]

    _, soft_pred_label, decision = soft_eyeq_decision(
        probs=probs,
        id_to_label=id_to_label,
        reject_threshold=reject_threshold,
        reject_margin=reject_margin,
    )

    prob_dict = {id_to_label[idx]: float(p) for idx, p in enumerate(probs)}

    detail = (
        f"Raw argmax: {raw_pred_label}\n"
        f"Soft decision: {soft_pred_label}\n"
        f"Reject threshold: {reject_threshold:.2f} | Reject margin: {reject_margin:.2f}\n"
        f"{decision}"
    )

    return soft_pred_label, prob_dict, detail
157
+
158
+
159
def make_app(checkpoint_path: str):
    """Build the Gradio Blocks UI around a loaded checkpoint.

    Loads the model once (on CUDA when available, otherwise CPU) and wires the
    image input, the two soft-reject sliders and the run button to
    ``predict_quality``.

    Args:
        checkpoint_path: Path of the checkpoint passed to ``load_model``.

    Returns:
        The ``gr.Blocks`` app; it is not launched here.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, tfm, id_to_label, model_name, img_size = load_model(checkpoint_path, device)

    # Closure that binds the loaded model/transform so the Gradio callbacks
    # only have to supply the UI values.
    def run(image, reject_threshold, reject_margin):
        pred_label, prob_dict, detail = predict_quality(
            image=image,
            model=model,
            tfm=tfm,
            id_to_label=id_to_label,
            device=device,
            reject_threshold=reject_threshold,
            reject_margin=reject_margin,
        )
        return pred_label, prob_dict, detail

    with gr.Blocks(title="EyeQ CFP Quality Control") as demo:
        gr.Markdown("# EyeQ CFP Quality Control")
        gr.Markdown(
            f"Model: `{model_name}` \n"
            f"Input size: `{img_size} × {img_size}` \n"
            f"Device: `{device}` \n"
            f"Checkpoint: `{checkpoint_path}`"
        )

        with gr.Row():
            # Left column: image input and decision-rule controls.
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Input CFP",
                    type="pil",
                    height=520,
                )
                with gr.Accordion("Soft Reject rule", open=True):
                    reject_threshold = gr.Slider(
                        minimum=0.40,
                        maximum=0.95,
                        value=0.60,
                        step=0.01,
                        label="Reject threshold",
                        info="Minimum Reject probability required before an image can be called Reject.",
                    )
                    reject_margin = gr.Slider(
                        minimum=0.00,
                        maximum=0.50,
                        value=0.15,
                        step=0.01,
                        label="Reject margin",
                        info="Reject must beat both Good and Usable by at least this much.",
                    )

                run_button = gr.Button("Run QC", variant="primary")

            # Right column: prediction, per-class probabilities, explanation.
            with gr.Column(scale=1):
                pred_output = gr.Label(label="Predicted quality")
                prob_output = gr.Label(label="Class probabilities", num_top_classes=3)
                decision_output = gr.Textbox(
                    label="Decision details",
                    lines=4,
                    interactive=False,
                )

        run_inputs = [image_input, reject_threshold, reject_margin]
        run_outputs = [pred_output, prob_output, decision_output]

        run_button.click(
            fn=run,
            inputs=run_inputs,
            outputs=run_outputs,
        )

        # Re-run whenever the image changes (including when it is cleared).
        image_input.change(
            fn=run,
            inputs=run_inputs,
            outputs=run_outputs,
        )

        # Changing the threshold first re-clamps the margin slider, then re-runs.
        # NOTE(review): if the clamp changes the margin value, the margin's own
        # change handler below may also fire, running inference twice — confirm.
        reject_threshold.change(
            fn=update_margin_slider,
            inputs=[reject_threshold, reject_margin],
            outputs=reject_margin,
        ).then(
            fn=run,
            inputs=run_inputs,
            outputs=run_outputs,
        )

        reject_margin.change(
            fn=run,
            inputs=run_inputs,
            outputs=run_outputs,
        )

    return demo
252
+
253
+
254
def parse_args():
    """Parse the command-line options of the Gradio app."""
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--checkpoint", {"type": str, "default": "./checkpoints/eyeq_vit_base/eyeq_deploy.pt"}),
        ("--host", {"type": str, "default": "0.0.0.0"}),
        ("--port", {"type": int, "default": 7860}),
        ("--share", {"action": "store_true"}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
261
+
262
+
263
def main():
    """Parse the CLI options, build the Gradio app and launch it.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    args = parse_args()

    checkpoint_path = Path(args.checkpoint)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    demo = make_app(str(checkpoint_path))
    # Forward --host/--port/--share; they were previously parsed but never
    # passed to launch(), so the flags had no effect.
    demo.launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
    )
276
+
277
+
278
+ if __name__ == "__main__":
279
+ main()
checkpoints/eyeq_vit_base/best_report.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Best epoch: 3
2
+ Best test balanced accuracy: 0.8573
3
+
4
+ precision recall f1-score support
5
+
6
+ Good 0.9262 0.9337 0.9299 8471
7
+ Usable 0.7829 0.7760 0.7794 4558
8
+ Reject 0.8697 0.8621 0.8659 3220
9
+
10
+ accuracy 0.8753 16249
11
+ macro avg 0.8596 0.8573 0.8584 16249
12
+ weighted avg 0.8748 0.8753 0.8750 16249
13
+
14
+ Confusion matrix rows=true cols=pred, labels=[Good, Usable, Reject]
15
+ [[7909 556 6]
16
+ [ 611 3537 410]
17
+ [ 19 425 2776]]
checkpoints/eyeq_vit_base/eyeq_deploy.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71226f3e62eeffe52af548f99d90730cd23009e06cf3b9aafe2e555c58752bc3
3
+ size 343261042
checkpoints/eyeq_vit_base/test_eval/test_confusion_matrix.csv ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ,pred_Good,pred_Usable,pred_Reject
2
+ true_Good,7908,557,6
3
+ true_Usable,611,3537,410
4
+ true_Reject,19,426,2775
checkpoints/eyeq_vit_base/test_eval/test_predictions.csv ADDED
The diff for this file is too large to render. See raw diff
 
checkpoints/eyeq_vit_base/test_eval/test_report.txt ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Checkpoint: checkpoints/eyeq_vit_base/best.pt
2
+ Test CSV: /data/MIDS/datasets/retina/EyeQ/data/Label_EyeQ_test.csv
3
+ Test images: /data/MIDS/datasets/retina/EyePACS/test
4
+ Model: vit_base_patch16_224
5
+ Image size: 224
6
+ Device: cuda
7
+
8
+ test_loss=0.312007
9
+ test_acc=0.875131
10
+ test_bal_acc=0.857112
11
+
12
+ precision recall f1-score support
13
+
14
+ Good 0.9262 0.9335 0.9299 8471
15
+ Usable 0.7825 0.7760 0.7792 4558
16
+ Reject 0.8696 0.8618 0.8657 3220
17
+
18
+ accuracy 0.8751 16249
19
+ macro avg 0.8595 0.8571 0.8583 16249
20
+ weighted avg 0.8747 0.8751 0.8749 16249
21
+
22
+ Confusion matrix rows=true cols=pred, labels=[Good, Usable, Reject]
23
+ [[7908 557 6]
24
+ [ 611 3537 410]
25
+ [ 19 426 2775]]
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ albumentations
2
+ gradio
3
+ huggingface_hub
4
+ numpy
5
+ opencv-python
6
+ pandas
7
+ pillow
8
+ pydantic
9
+ timm
10
+ torch
11
+ torchvision
12
+ torchaudio
13
+ tqdm
test.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Evaluate an EyeQ CFP image-quality-control model on Label_EyeQ_test.csv.
4
+
5
+ Example
6
+ -------
7
+ python test.py \
8
+ --images_dir /data/MIDS/datasets/retina/EyePACS \
9
+ --csv_dir /data/MIDS/datasets/retina/EyeQ/data \
10
+ --checkpoint ./checkpoints/eyeq_vit_base/best.pt \
11
+ --output_dir ./checkpoints/eyeq_vit_base/test_eval \
12
+ --batch_size 32 \
13
+ --num_workers 24
14
+ """
15
+
16
+ import argparse
17
+ from pathlib import Path
18
+ from typing import Dict, Tuple
19
+
20
+ import numpy as np
21
+ import pandas as pd
22
+ from PIL import Image
23
+
24
+ import torch
25
+ import torch.nn as nn
26
+ from torch.utils.data import Dataset, DataLoader
27
+ from torchvision import transforms
28
+
29
+ import timm
30
+ from sklearn.metrics import (
31
+ accuracy_score,
32
+ balanced_accuracy_score,
33
+ classification_report,
34
+ confusion_matrix,
35
+ )
36
+ from tqdm import tqdm
37
+
38
+
39
# Class index -> display label.
ID_TO_LABEL = {0: "Good", 1: "Usable", 2: "Reject"}
# Lower-cased raw label text -> class index; accepts both the label names and
# the digit strings "0"/"1"/"2".
LABEL_TO_ID: Dict[str, int] = {
    "good": 0,
    "usable": 1,
    "reject": 2,
    "0": 0,
    "1": 1,
    "2": 2,
}
48
+
49
+
50
class EyeQDataset(Dataset):
    """Dataset over a label table, yielding ``(image, label, image_name)``.

    ``df`` must provide the columns ``image`` (file name relative to
    ``images_dir``) and ``quality`` (integer-convertible label).
    """

    def __init__(self, df: pd.DataFrame, images_dir: str, transform=None):
        self.transform = transform
        self.images_dir = Path(images_dir)
        # Positional 0..n-1 index so that ``iloc`` lookups match sample indices.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        image_name = str(record["image"])

        image = Image.open(self.images_dir / image_name).convert("RGB")
        label = int(record["quality"])

        if self.transform is None:
            return image, label, image_name
        return self.transform(image), label, image_name
72
+
73
+
74
def normalize_quality_label(x) -> int:
    """Map a raw quality label to its integer class ID (0, 1 or 2).

    Accepted inputs are the label names Good/Usable/Reject (case-insensitive,
    surrounding whitespace ignored) and numbers equal to 0, 1 or 2, e.g.
    ``1``, ``"2"`` or ``1.0``.

    Raises:
        ValueError: For anything else, including non-integral numbers such as
            ``1.5`` and non-finite values such as ``nan`` or ``inf``.
    """
    key = str(x).strip().lower()

    if key in LABEL_TO_ID:
        return LABEL_TO_ID[key]

    try:
        number = float(key)
    except ValueError:
        number = None

    # Require an exact match: truncating with int() would silently turn
    # "-0.5" into Good or "1.7" into Usable, and int(inf) raises OverflowError.
    # NaN and infinities never compare equal to a class ID.
    if number is not None and number in (0.0, 1.0, 2.0):
        return int(number)

    raise ValueError(f"Unknown quality label: {x}. Expected 0/1/2 or Good/Usable/Reject.")
88
+
89
+
90
def load_eyeq_csv(csv_path: str, images_dir: str) -> pd.DataFrame:
    """Read a label CSV and keep only rows whose image file exists.

    The result has the columns ``image`` (str), ``quality`` (normalized class
    ID) and, when the CSV has it, ``DR_grade``.

    Raises:
        ValueError: If the ``image`` or ``quality`` column is missing.
        RuntimeError: If no row refers to an existing image file.
    """
    table = pd.read_csv(csv_path)

    if "image" not in table.columns:
        raise ValueError(f"CSV must contain an 'image' column. Found columns: {list(table.columns)}")
    if "quality" not in table.columns:
        raise ValueError(f"CSV must contain a 'quality' column. Found columns: {list(table.columns)}")

    # DR_grade is carried along only for optional downstream inspection.
    wanted = ["image", "quality"]
    if "DR_grade" in table.columns:
        wanted.append("DR_grade")

    table = table[wanted].copy()
    table["image"] = table["image"].astype(str)
    table["quality"] = table["quality"].apply(normalize_quality_label)

    root = Path(images_dir)
    on_disk = table["image"].apply(lambda name: (root / name).exists())

    n_missing = int((~on_disk).sum())
    if n_missing > 0:
        print(f"Warning: dropping {n_missing} rows with missing image files from {csv_path}")
        print(f" searched in: {root}")

    table = table.loc[on_disk].reset_index(drop=True)

    if len(table) == 0:
        raise RuntimeError(f"No valid images found for {csv_path}. Searched in: {root}")

    return table
121
+
122
+
123
def build_transform(img_size: int):
    """Build the evaluation preprocessing pipeline.

    The image is resized to ``img_size`` x ``img_size``, converted to a
    tensor and normalized per channel with the fixed mean/std below.
    """
    steps = [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
    ]
    steps.append(
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    )
    return transforms.Compose(steps)
132
+
133
+
134
def load_model(checkpoint_path: str, device: torch.device):
    """Rebuild the classifier stored in a checkpoint and put it in eval mode.

    Architecture name and image size come from the checkpoint's ``"args"``
    entry (defaults: ``vit_base_patch16_224`` and 224); labels come from
    ``"id_to_label"`` or fall back to ``ID_TO_LABEL``.

    Returns:
        ``(model, id_to_label, model_name, img_size, ckpt)``.
    """
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    train_args = checkpoint.get("args", {})
    arch = train_args.get("model", "vit_base_patch16_224")
    side = int(train_args.get("img_size", 224))

    stored_labels = checkpoint.get("id_to_label", ID_TO_LABEL)
    labels = {int(class_id): str(name) for class_id, name in stored_labels.items()}

    net = timm.create_model(arch, pretrained=False, num_classes=len(labels))

    # strict=True: the checkpoint must match the architecture exactly.
    net.load_state_dict(checkpoint["model"], strict=True)
    net.to(device)
    net.eval()

    return net, labels, arch, side, checkpoint
155
+
156
+
157
@torch.no_grad()
def evaluate(model, loader, criterion, device, amp=False):
    """Run the model over ``loader`` and collect metrics and per-sample outputs.

    The loader must yield ``(images, labels, image_names)`` batches. Autocast
    is enabled only when ``amp`` is set and the device is CUDA.

    Returns:
        ``(test_loss, acc, bal_acc, y_true, y_pred, probs, image_names)``
        where the loss is the sample-weighted mean over the dataset.
    """
    model.eval()

    loss_sum = 0.0
    seen_labels, seen_preds, seen_probs, seen_names = [], [], [], []

    for images, labels, image_names in tqdm(loader, desc="Test"):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=amp and device.type == "cuda"):
            logits = model(images)
            loss = criterion(logits, labels)
            batch_probs = torch.softmax(logits, dim=1)

        batch_preds = batch_probs.argmax(dim=1)

        # Weight by batch size so a smaller last batch does not skew the mean.
        loss_sum += loss.item() * images.size(0)

        seen_labels += labels.detach().cpu().numpy().tolist()
        seen_preds += batch_preds.detach().cpu().numpy().tolist()
        seen_probs += batch_probs.detach().cpu().numpy().tolist()
        seen_names += list(image_names)

    test_loss = loss_sum / len(loader.dataset)

    y_true = np.array(seen_labels)
    y_pred = np.array(seen_preds)
    probs = np.array(seen_probs)

    acc = accuracy_score(y_true, y_pred)
    bal_acc = balanced_accuracy_score(y_true, y_pred)

    return test_loss, acc, bal_acc, y_true, y_pred, probs, seen_names
195
+
196
+
197
def print_label_counts(name: str, df: pd.DataFrame):
    """Print the row count of ``df`` followed by the count per quality class."""
    print(f"{name}: {len(df)}")
    for label_id in (0, 1, 2):
        matches = df["quality"] == label_id
        count = int(matches.sum())
        print(f" {ID_TO_LABEL[label_id]} ({label_id}): {count}")
202
+
203
+
204
def parse_args():
    """Parse the command-line options of the evaluation script."""
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # Data and checkpoint locations.
    add("--images_dir", type=str, required=True,
        help="EyePACS root containing train/ and test/ folders.")
    add("--csv_dir", type=str, required=True,
        help="Directory containing Label_EyeQ_test.csv.")
    add("--checkpoint", type=str, default="./checkpoints/eyeq_vit_base/best.pt")
    add("--output_dir", type=str, default=None)

    # Runtime options; AMP is on by default and --no_amp switches it off.
    add("--batch_size", type=int, default=32)
    add("--num_workers", type=int, default=8)
    add("--amp", action="store_true", default=True)
    add("--no_amp", dest="amp", action="store_false")
    add("--cpu", action="store_true")

    return cli.parse_args()
221
+
222
+
223
def main():
    """Evaluate a checkpoint on the EyeQ test split and save the results.

    Writes ``test_report.txt``, ``test_confusion_matrix.csv`` and
    ``test_predictions.csv`` into the output directory.

    Raises:
        FileNotFoundError: If the checkpoint, the test image directory or the
            test CSV does not exist.
    """
    args = parse_args()

    images_root = Path(args.images_dir)
    csv_root = Path(args.csv_dir)
    checkpoint_path = Path(args.checkpoint)

    # Fixed layout: <images_dir>/test and <csv_dir>/Label_EyeQ_test.csv.
    test_images_dir = images_root / "test"
    test_csv = csv_root / "Label_EyeQ_test.csv"

    # Without --output_dir, results go next to the checkpoint.
    if args.output_dir is None:
        output_dir = checkpoint_path.parent / "test_eval"
    else:
        output_dir = Path(args.output_dir)

    # NOTE(review): the directory is created before the inputs are validated,
    # so a failed run can leave an empty output directory behind.
    output_dir.mkdir(parents=True, exist_ok=True)

    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    if not test_images_dir.exists():
        raise FileNotFoundError(f"Test image directory not found: {test_images_dir}")
    if not test_csv.exists():
        raise FileNotFoundError(f"Test CSV not found: {test_csv}")

    # --cpu forces CPU even when CUDA is available.
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")

    model, id_to_label, model_name, img_size, ckpt = load_model(str(checkpoint_path), device)
    transform = build_transform(img_size)

    test_df = load_eyeq_csv(str(test_csv), str(test_images_dir))
    test_ds = EyeQDataset(test_df, str(test_images_dir), transform)

    # shuffle=False keeps loader order equal to test_df row order; the
    # per-image predictions below rely on that.
    test_loader = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=(device.type == "cuda"),
        persistent_workers=(args.num_workers > 0),
    )

    criterion = nn.CrossEntropyLoss()

    print("Evaluation summary")
    print(f"Checkpoint: {checkpoint_path}")
    print(f"Test CSV: {test_csv}")
    print(f"Test images: {test_images_dir}")
    print(f"Output dir: {output_dir}")
    print(f"Model: {model_name}")
    print(f"Image size: {img_size}")
    print(f"Device: {device}")
    print(f"Labels: {id_to_label}")
    print_label_counts("Test", test_df)

    test_loss, acc, bal_acc, y_true, y_pred, probs, image_names = evaluate(
        model=model,
        loader=test_loader,
        criterion=criterion,
        device=device,
        amp=args.amp,
    )

    # Reports always cover class IDs 0, 1 and 2, named via the checkpoint labels.
    target_names = [id_to_label[i] for i in [0, 1, 2]]

    report = classification_report(
        y_true,
        y_pred,
        labels=[0, 1, 2],
        target_names=target_names,
        digits=4,
    )

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])

    print()
    print(f"test_loss={test_loss:.4f}")
    print(f"test_acc={acc:.4f}")
    print(f"test_bal_acc={bal_acc:.4f}")
    print()
    print(report)
    print("Confusion matrix rows=true cols=pred, labels=[Good, Usable, Reject]")
    print(cm)

    # Save text report
    with open(output_dir / "test_report.txt", "w") as f:
        f.write(f"Checkpoint: {checkpoint_path}\n")
        f.write(f"Test CSV: {test_csv}\n")
        f.write(f"Test images: {test_images_dir}\n")
        f.write(f"Model: {model_name}\n")
        f.write(f"Image size: {img_size}\n")
        f.write(f"Device: {device}\n\n")
        f.write(f"test_loss={test_loss:.6f}\n")
        f.write(f"test_acc={acc:.6f}\n")
        f.write(f"test_bal_acc={bal_acc:.6f}\n\n")
        f.write(report)
        f.write("\nConfusion matrix rows=true cols=pred, labels=[Good, Usable, Reject]\n")
        f.write(str(cm))
        f.write("\n")

    # Save confusion matrix CSV
    cm_df = pd.DataFrame(
        cm,
        index=[f"true_{name}" for name in target_names],
        columns=[f"pred_{name}" for name in target_names],
    )
    cm_df.to_csv(output_dir / "test_confusion_matrix.csv")

    # Save per-image predictions
    # NOTE(review): predictions are attached by row position; image_names from
    # evaluate() is not used to cross-check the order against test_df.
    pred_df = test_df.copy()
    pred_df["pred_quality"] = y_pred
    pred_df["true_label"] = [id_to_label[int(x)] for x in y_true]
    pred_df["pred_label"] = [id_to_label[int(x)] for x in y_pred]
    # Column names assume class IDs 0/1/2 mean Good/Usable/Reject.
    pred_df["prob_good"] = probs[:, 0]
    pred_df["prob_usable"] = probs[:, 1]
    pred_df["prob_reject"] = probs[:, 2]
    pred_df["correct"] = pred_df["quality"].values == pred_df["pred_quality"].values

    pred_df.to_csv(output_dir / "test_predictions.csv", index=False)

    print()
    print(f"Saved report: {output_dir / 'test_report.txt'}")
    print(f"Saved confusion: {output_dir / 'test_confusion_matrix.csv'}")
    print(f"Saved predictions: {output_dir / 'test_predictions.csv'}")
346
+
347
+
348
+ if __name__ == "__main__":
349
+ main()
train.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Train a CFP image-quality-control model on EyeQ / EyePACS-style data.
4
+
5
+ Expected dataset format
6
+ -----------------------
7
+ EyePACS/
8
+ train/
9
+ 10009_left.jpeg
10
+ 10009_right.jpeg
11
+ ...
12
+ test/
13
+ ...
14
+
15
+ data/
16
+ Label_EyeQ_train.csv
17
+ Label_EyeQ_test.csv
18
+
19
+ Label CSV format:
20
+ ,image,quality,DR_grade
21
+ 0,10009_left.jpeg,0,0
22
+ 1,10009_right.jpeg,0,0
23
+ 2,10014_left.jpeg,2,0
24
+
25
+ For EyeQ, this script assumes:
26
+ quality = 0 -> Good
27
+ quality = 1 -> Usable
28
+ quality = 2 -> Reject
29
+
30
+ DR_grade is ignored because this script trains only the image-quality model.
31
+
32
+ Example
33
+ -------
34
+ python train.py \
35
+ --images_dir /path/to/EyePACS \
36
+ --csv_dir /path/to/data \
37
+ --output_dir ./runs/eyeq_vit_base \
38
+ --epochs 30 \
39
+ --batch_size 32 \
40
+ --lr 3e-5
41
+ """
42
+
43
+ import argparse
44
+ import random
45
+ from pathlib import Path
46
+ from typing import Dict, Tuple
47
+
48
+ import numpy as np
49
+ import pandas as pd
50
+ from PIL import Image
51
+
52
+ import torch
53
+ import torch.nn as nn
54
+ from torch.utils.data import Dataset, DataLoader
55
+ from torchvision import transforms
56
+
57
+ import timm
58
+ from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report, confusion_matrix
59
+ from tqdm import tqdm
60
+
61
+
62
# Class index -> display label; also stored in every checkpoint by
# save_checkpoint().
ID_TO_LABEL = {0: "Good", 1: "Usable", 2: "Reject"}
# Lower-cased raw label text -> class index; accepts both the label names and
# the digit strings "0"/"1"/"2".
LABEL_TO_ID: Dict[str, int] = {
    "good": 0,
    "usable": 1,
    "reject": 2,
    "0": 0,
    "1": 1,
    "2": 2,
}
71
+
72
+
73
class EyeQDataset(Dataset):
    """Dataset over a label table, yielding ``(image, label)`` pairs.

    ``df`` must provide the columns ``image`` (file name relative to
    ``images_dir``) and ``quality`` (integer-convertible label).
    """

    def __init__(self, df: pd.DataFrame, images_dir: str, transform=None):
        self.transform = transform
        self.images_dir = Path(images_dir)
        # Positional 0..n-1 index so that ``iloc`` lookups match sample indices.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        source = self.images_dir / str(record["image"])
        image = Image.open(source).convert("RGB")
        label = int(record["quality"])

        if self.transform is None:
            return image, label
        return self.transform(image), label
92
+
93
+
94
def seed_everything(seed: int):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Also switches cuDNN benchmark mode on.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.benchmark = True
100
+
101
+
102
def normalize_quality_label(x) -> int:
    """Map a raw quality label to its integer class ID (0, 1 or 2).

    Accepted inputs are the label names Good/Usable/Reject (case-insensitive,
    surrounding whitespace ignored) and numbers equal to 0, 1 or 2, e.g.
    ``1``, ``"2"`` or ``1.0``.

    Raises:
        ValueError: For anything else, including non-integral numbers such as
            ``1.5`` and non-finite values such as ``nan`` or ``inf``.
    """
    key = str(x).strip().lower()
    if key in LABEL_TO_ID:
        return LABEL_TO_ID[key]

    try:
        number = float(key)
    except ValueError:
        number = None

    # Require an exact match: truncating with int() would silently turn
    # "-0.5" into Good or "1.7" into Usable, and int(inf) raises OverflowError.
    # NaN and infinities never compare equal to a class ID.
    if number is not None and number in (0.0, 1.0, 2.0):
        return int(number)

    raise ValueError(f"Unknown quality label: {x}. Expected 0/1/2 or Good/Usable/Reject.")
113
+
114
+
115
def load_eyeq_csv(csv_path: str, images_dir: str) -> pd.DataFrame:
    """Read a label CSV and keep only rows whose image file exists.

    The result has exactly the columns ``image`` (str) and ``quality``
    (normalized class ID).

    Raises:
        ValueError: If the ``image`` or ``quality`` column is missing.
        RuntimeError: If no row refers to an existing image file.
    """
    table = pd.read_csv(csv_path)

    if "image" not in table.columns:
        raise ValueError(f"CSV must contain an 'image' column. Found columns: {list(table.columns)}")
    if "quality" not in table.columns:
        raise ValueError(f"CSV must contain a 'quality' column. Found columns: {list(table.columns)}")

    table = table[["image", "quality"]].copy()
    table["image"] = table["image"].astype(str)
    table["quality"] = table["quality"].apply(normalize_quality_label)

    root = Path(images_dir)
    on_disk = table["image"].apply(lambda name: (root / name).exists())
    n_missing = int((~on_disk).sum())
    if n_missing > 0:
        print(f"Warning: dropping {n_missing} rows with missing image files from {csv_path}")
        print(f" searched in: {root}")
    table = table.loc[on_disk].reset_index(drop=True)

    if len(table) == 0:
        raise RuntimeError(f"No valid images found for {csv_path}. Searched in: {root}")

    return table
139
+
140
+
141
def build_transforms(img_size: int) -> Tuple[transforms.Compose, transforms.Compose]:
    """Build the ``(train, test)`` preprocessing pipelines.

    Both resize to ``img_size`` x ``img_size`` and apply the same
    normalization. The training pipeline additionally applies a random
    horizontal flip, colour jitter (with probability 0.8) and Gaussian blur
    (with probability 0.15).
    """
    target_size = (img_size, img_size)
    norm_mean = (0.485, 0.456, 0.406)
    norm_std = (0.229, 0.224, 0.225)

    maybe_jitter = transforms.RandomApply(
        [transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.10, hue=0.02)],
        p=0.8,
    )
    maybe_blur = transforms.RandomApply(
        [transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0))],
        p=0.15,
    )

    train_tfms = transforms.Compose([
        transforms.Resize(target_size),
        transforms.RandomHorizontalFlip(p=0.5),
        maybe_jitter,
        maybe_blur,
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])

    test_tfms = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])

    return train_tfms, test_tfms
167
+
168
+
169
def build_model(model_name: str, num_classes: int, pretrained: bool):
    """Create the timm model ``model_name`` with a ``num_classes``-way head."""
    options = {"pretrained": pretrained, "num_classes": num_classes}
    return timm.create_model(model_name, **options)
175
+
176
+
177
def train_one_epoch(model, loader, criterion, optimizer, scaler, device, epoch):
    """Train the model for one pass over ``loader``.

    Mixed precision is used exactly when a gradient ``scaler`` is supplied.

    Returns:
        ``(epoch_loss, acc, bal_acc)`` over the training samples seen, where
        the loss is the sample-weighted mean.
    """
    model.train()
    use_amp = scaler is not None
    loss_sum = 0.0
    seen_preds, seen_labels = [], []

    progress = tqdm(loader, desc=f"Train {epoch}", leave=False)
    for images, labels in progress:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(images)
            loss = criterion(logits, labels)

        if use_amp:
            # Scale the loss before backward; the scaler then steps and updates.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        batch_loss = loss.item()
        loss_sum += batch_loss * images.size(0)
        seen_preds += logits.argmax(dim=1).detach().cpu().numpy().tolist()
        seen_labels += labels.detach().cpu().numpy().tolist()

        progress.set_postfix(loss=f"{batch_loss:.4f}")

    epoch_loss = loss_sum / len(loader.dataset)
    return (
        epoch_loss,
        accuracy_score(seen_labels, seen_preds),
        balanced_accuracy_score(seen_labels, seen_preds),
    )
213
+
214
+
215
@torch.no_grad()
def evaluate(model, loader, criterion, device, split_name="Test"):
    """Evaluate the model over ``loader`` without gradient tracking.

    Returns:
        ``(val_loss, acc, bal_acc, y_true, y_pred)`` with the labels and
        predictions as NumPy arrays and the loss as a sample-weighted mean.
    """
    model.eval()
    loss_sum = 0.0
    seen_preds, seen_labels = [], []

    for images, labels in tqdm(loader, desc=split_name, leave=False):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        logits = model(images)
        loss = criterion(logits, labels)

        loss_sum += loss.item() * images.size(0)
        seen_preds += logits.argmax(dim=1).detach().cpu().numpy().tolist()
        seen_labels += labels.detach().cpu().numpy().tolist()

    val_loss = loss_sum / len(loader.dataset)
    acc = accuracy_score(seen_labels, seen_preds)
    bal_acc = balanced_accuracy_score(seen_labels, seen_preds)
    return val_loss, acc, bal_acc, np.array(seen_labels), np.array(seen_preds)
239
+
240
+
241
def save_checkpoint(path, model, optimizer, scheduler, epoch, best_metric, args):
    """Write a training checkpoint to ``path`` with ``torch.save``.

    The payload holds the epoch, the model/optimizer/scheduler state dicts
    (scheduler state is ``None`` when no scheduler is given), the best metric,
    the parsed arguments as a dict and the ``ID_TO_LABEL`` mapping.
    """
    payload = {
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": None if scheduler is None else scheduler.state_dict(),
        "best_metric": best_metric,
        "args": vars(args),
        "id_to_label": ID_TO_LABEL,
    }
    torch.save(payload, path)
251
+
252
+
253
def parse_args():
    """Parse the command-line options of the training script."""
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # Data and output locations.
    add("--images_dir", type=str, required=True, help="EyePACS root containing train/ and test/ folders.")
    add("--csv_dir", type=str, required=True, help="Directory containing Label_EyeQ_train.csv and Label_EyeQ_test.csv.")
    add("--output_dir", type=str, default="./runs/eyeq_vit_base")

    # Model; pretrained weights are on by default, --no_pretrained disables them.
    add("--model", type=str, default="vit_base_patch16_224")
    add("--img_size", type=int, default=224)
    add("--pretrained", action="store_true", default=True)
    add("--no_pretrained", dest="pretrained", action="store_false")

    # Optimization; AMP is on by default, --no_amp disables it.
    add("--epochs", type=int, default=30)
    add("--batch_size", type=int, default=32)
    add("--num_workers", type=int, default=8)
    add("--lr", type=float, default=3e-5)
    add("--weight_decay", type=float, default=1e-4)
    add("--seed", type=int, default=42)
    add("--amp", action="store_true", default=True)
    add("--no_amp", dest="amp", action="store_false")
    add("--class_weights", action="store_true", help="Use inverse-frequency class weights.")

    return cli.parse_args()
275
+
276
+
277
def print_label_counts(name: str, df: pd.DataFrame):
    """Print the number of rows in ``df`` and a per-label count.

    One line is printed for each label id 0, 1 and 2, counting the rows
    whose ``"quality"`` column equals that id.
    """
    print(f"{name}: {len(df)}")
    quality = df["quality"]
    for label_id in (0, 1, 2):
        count = int(quality.eq(label_id).sum())
        print(f" {ID_TO_LABEL[label_id]} ({label_id}): {count}")
282
+
283
+
284
def _write_best_report(output_dir, epoch, best_bal_acc, y_true, y_pred):
    """Write the classification report and confusion matrix to best_report.txt."""
    class_ids = [0, 1, 2]
    report = classification_report(
        y_true,
        y_pred,
        labels=class_ids,
        target_names=[ID_TO_LABEL[i] for i in class_ids],
        digits=4,
    )
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    # Explicit encoding keeps the report independent of the platform locale.
    with open(output_dir / "best_report.txt", "w", encoding="utf-8") as f:
        f.write(f"Best epoch: {epoch}\n")
        f.write(f"Best test balanced accuracy: {best_bal_acc:.4f}\n\n")
        f.write(report)
        f.write("\nConfusion matrix rows=true cols=pred, labels=[Good, Usable, Reject]\n")
        f.write(str(cm))
        f.write("\n")


def main():
    """Train the classifier and keep the last and best checkpoints.

    Steps: parse arguments, seed, load the train/test CSVs, build datasets
    and loaders, build the model, loss, optimizer, scheduler and optional
    gradient scaler, then train for ``args.epochs`` epochs. After each epoch
    the test split is evaluated, ``last.pt`` is written, and when the test
    balanced accuracy improves ``best.pt`` and ``best_report.txt`` are
    refreshed.
    """
    args = parse_args()
    seed_everything(args.seed)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    images_root = Path(args.images_dir)
    csv_root = Path(args.csv_dir)

    train_images_dir = images_root / "train"
    test_images_dir = images_root / "test"
    train_csv = csv_root / "Label_EyeQ_train.csv"
    test_csv = csv_root / "Label_EyeQ_test.csv"

    train_df = load_eyeq_csv(str(train_csv), str(train_images_dir))
    test_df = load_eyeq_csv(str(test_csv), str(test_images_dir))

    train_tfms, test_tfms = build_transforms(args.img_size)

    train_ds = EyeQDataset(train_df, str(train_images_dir), train_tfms)
    test_ds = EyeQDataset(test_df, str(test_images_dir), test_tfms)

    # Resolve the device before building the loaders so pinned memory is
    # only requested when a CUDA device will consume it; on CPU-only hosts
    # pin_memory=True just produces a warning.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pin_memory = device.type == "cuda"

    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
        drop_last=True,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )

    model = build_model(args.model, num_classes=3, pretrained=args.pretrained).to(device)

    if args.class_weights:
        # Inverse-frequency weights; a class absent from the training CSV is
        # given a count of 1 so the division below stays finite.
        counts = train_df["quality"].value_counts().sort_index().reindex([0, 1, 2], fill_value=1).values
        weights = counts.sum() / (len(counts) * counts)
        weights = torch.tensor(weights, dtype=torch.float32, device=device)
        criterion = nn.CrossEntropyLoss(weight=weights)
        print(f"Using class weights: {weights.detach().cpu().numpy().round(3).tolist()}")
    else:
        criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    # The gradient scaler is only created for CUDA runs with AMP enabled.
    scaler = torch.cuda.amp.GradScaler() if args.amp and device.type == "cuda" else None

    print("Dataset summary")
    print(f"Train CSV: {train_csv}")
    print(f"Test CSV: {test_csv}")
    print(f"Train images: {train_images_dir}")
    print(f"Test images: {test_images_dir}")
    print_label_counts("Train", train_df)
    print_label_counts("Test", test_df)
    print(f"Model: {args.model}")
    print(f"Device: {device}")

    best_bal_acc = -1.0

    # NOTE(review): the best checkpoint is selected on the test split, so the
    # reported best test metric is optimistically biased. Consider a held-out
    # validation split for model selection.
    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc, train_bal_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, device, epoch
        )
        test_loss, test_acc, test_bal_acc, y_true, y_pred = evaluate(
            model, test_loader, criterion, device, split_name="Test"
        )
        scheduler.step()

        print(
            f"Epoch {epoch:03d}/{args.epochs} | "
            f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} train_bal_acc={train_bal_acc:.4f} | "
            f"test_loss={test_loss:.4f} test_acc={test_acc:.4f} test_bal_acc={test_bal_acc:.4f}"
        )

        # Update the running best before writing last.pt so that its
        # "best_metric" field includes this epoch's result.
        is_best = test_bal_acc > best_bal_acc
        if is_best:
            best_bal_acc = test_bal_acc

        save_checkpoint(output_dir / "last.pt", model, optimizer, scheduler, epoch, best_bal_acc, args)

        if is_best:
            best_path = output_dir / "best.pt"
            save_checkpoint(best_path, model, optimizer, scheduler, epoch, best_bal_acc, args)
            _write_best_report(output_dir, epoch, best_bal_acc, y_true, y_pred)
            print(f" Saved new best checkpoint: {best_path}")

    print(f"Training complete. Best test balanced accuracy: {best_bal_acc:.4f}")
    print(f"Outputs saved to: {output_dir}")
394
+
395
+
396
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()