cardio-deploy commited on
Commit
08a909f
Β·
1 Parent(s): 1acfe01

Deploy CardioScan inference 2026-04-23T12:22:25Z

Browse files
.gitattributes CHANGED
@@ -1,35 +1,4 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
  *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.pth filter=lfs diff=lfs merge=lfs -text
2
+ *.pt filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
4
  *.ckpt filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Dockerfile ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # syntax=docker/dockerfile:1
2
+ #
3
+ # Hugging Face Spaces (Docker SDK) image for the CardioScan inference server.
4
+ #
5
+ # Spaces conventions:
6
+ # * Container must listen on the port declared in `README.md` (`app_port`).
7
+ # * Container is started as user `user` (UID 1000), not root.
8
+ # * /data is the only writable location persisted across restarts (we don't
9
+ # need persistence β€” checkpoints ship inside the image).
10
+ #
11
+ # Build context expected by `scripts/deploy-space.sh`:
12
+ # /
13
+ # β”œβ”€β”€ Dockerfile (this file)
14
+ # β”œβ”€β”€ README.md (HF Spaces metadata)
15
+ # β”œβ”€β”€ inference_server/ (FastAPI app + requirements.txt)
16
+ # └── model_training/
17
+ # β”œβ”€β”€ src/ (only the package; notebooks excluded)
18
+ # └── notebooks/results/ (.pth checkpoints + manifest + metrics)
19
+
20
+ FROM python:3.11-slim
21
+
22
+ ENV PYTHONDONTWRITEBYTECODE=1 \
23
+ PYTHONUNBUFFERED=1 \
24
+ PIP_NO_CACHE_DIR=1 \
25
+ PIP_DISABLE_PIP_VERSION_CHECK=1 \
26
+ HF_HOME=/app/.cache/huggingface \
27
+ TORCH_HOME=/app/.cache/torch \
28
+ XDG_CACHE_HOME=/app/.cache
29
+
30
+ # System deps for Pillow / scientific Python wheels.
31
+ RUN apt-get update && apt-get install -y --no-install-recommends \
32
+ libgomp1 \
33
+ ca-certificates \
34
+ curl \
35
+ && rm -rf /var/lib/apt/lists/*
36
+
37
+ # HF Spaces requires the runtime user to be UID 1000.
38
+ RUN useradd --create-home --uid 1000 user
39
+ WORKDIR /app
40
+
41
+ # Install Python deps first so layer caches when only code/weights change.
42
+ COPY --chown=user:user inference_server/requirements.txt /app/inference_server/requirements.txt
43
+ RUN pip install --upgrade pip && \
44
+ pip install --extra-index-url https://download.pytorch.org/whl/cpu \
45
+ torch==2.4.1 torchvision==0.19.1 && \
46
+ pip install -r /app/inference_server/requirements.txt
47
+
48
+ # Copy the actual code and weights.
49
+ COPY --chown=user:user inference_server/ /app/inference_server/
50
+ COPY --chown=user:user model_training/src/ /app/model_training/src/
51
+ COPY --chown=user:user model_training/notebooks/results/ /app/model_training/notebooks/results/
52
+
53
+ # Pre-create writable cache dirs owned by `user`.
54
+ RUN mkdir -p /app/.cache/torch /app/.cache/huggingface && chown -R user:user /app
55
+
56
+ USER user
57
+
58
+ # Spaces routes traffic to whatever app_port we declared in README.md (7860).
59
+ ENV PORT=7860 \
60
+ LOG_LEVEL=INFO \
61
+ MODEL_USE_TTA=true \
62
+ ALLOWED_ORIGIN_REGEX="https://([a-z0-9-]+\\.)*lovable\\.app|https://([a-z0-9-]+\\.)*lovableproject\\.com|https://([a-z0-9-]+\\.)*hf\\.space|http://localhost(:\\d+)?|http://127\\.0\\.0\\.1(:\\d+)?"
63
+
64
+ EXPOSE 7860
65
+
66
+ # server.py expects to be importable from /app/inference_server.
67
+ WORKDIR /app/inference_server
68
+ CMD ["sh", "-c", "uvicorn server:app --host 0.0.0.0 --port ${PORT}"]
README.md CHANGED
@@ -1,12 +1,42 @@
1
  ---
2
- title: Cardio Scan Api
3
- emoji: 🐒
4
- colorFrom: indigo
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 6.13.0
8
- app_file: app.py
9
  pinned: false
 
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: CardioScan Inference
3
+ emoji: πŸ«€
4
+ colorFrom: red
5
+ colorTo: pink
6
+ sdk: docker
7
+ app_port: 7860
 
8
  pinned: false
9
+ license: mit
10
+ short_description: Cardiomegaly screening API for the CardioScan frontend.
11
  ---
12
 
13
+ # CardioScan Inference
14
+
15
+ FastAPI service that serves the chest X-ray cardiomegaly ensemble trained in
16
+ [heart-scan-helper](https://github.com/) β€” a torchxrayvision DenseNet-121
17
+ 3-seed ensemble with optional 6-pass test-time augmentation.
18
+
19
+ ## Endpoints
20
+
21
+ - `GET /health`  β€” readiness + model metadata
22
+ - `POST /predict`  β€” `multipart/form-data`, field name `image`
23
+ - `POST /debug/predict`  β€” per-model, per-TTA logits (debugging)
24
+
25
+ Response shape:
26
+
27
+ ```json
28
+ {
29
+ "prediction": "Cardiomegaly",
30
+ "confidence": 0.873,
31
+ "heatmap_url": null,
32
+ "source": "model",
33
+ "threshold": 0.504486,
34
+ "ensemble_size": 3,
35
+ "use_tta": true
36
+ }
37
+ ```
38
+
39
+ ## Environment
40
+
41
+ CORS already allows `*.lovable.app`, `*.lovableproject.com`, `*.hf.space` and
42
+ `localhost`. Override with `ALLOWED_ORIGIN_REGEX` if you need to lock it down.
inference_server/README.md ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inference_server
2
+
3
+ FastAPI service that wraps the **trained ensemble** in
4
+ `model_training/notebooks/results/` and exposes a single `POST /predict`
5
+ endpoint for the React frontend.
6
+
7
+ This folder is strictly an inference layer β€” **nothing inside
8
+ `model_training/` is modified**. We only import `src.model.build_model` and
9
+ `src.model.cardio_logit` to recreate the architecture before loading the
10
+ saved state dicts.
11
+
12
+ ## 1. Install
13
+
14
+ ```bash
15
+ cd inference_server
16
+ python -m venv .venv
17
+ source .venv/bin/activate # Windows: .venv\Scripts\activate
18
+ pip install -r requirements.txt
19
+ ```
20
+
21
+ > The backbone is **auto-detected from the first checkpoint** referenced by
22
+ > `ensemble_manifest.csv`, so there is never a mismatch between
23
+ > architecture and weights. The shipped checkpoints in
24
+ > `model_training/notebooks/results/` are `torchxrayvision densenet121`
25
+ > (trained before `CFG.backbone` in `model_training/src/config.py` was
26
+ > changed to `efficientnet_b0`); the server correctly identifies and
27
+ > uses them. You can still force a specific backbone via `MODEL_BACKBONE`
28
+ > if you train a new model β€” see _Configuration_ below.
29
+
30
+ ## 2. Run
31
+
32
+ **You must use the venv** where `pip install -r requirements.txt` was run. If
33
+ you see `ModuleNotFoundError: No module named 'torch'`, you started `uvicorn`
34
+ with the system Python instead of `inference_server/.venv`.
35
+
36
+ ```bash
37
+ cd inference_server
38
+ source .venv/bin/activate # Windows: .venv\Scripts\activate
39
+ uvicorn server:app --host 0.0.0.0 --port 8000
40
+ ```
41
+
42
+ Or, without activating (always uses the project interpreter):
43
+
44
+ ```bash
45
+ cd inference_server
46
+ ./.venv/bin/uvicorn server:app --host 0.0.0.0 --port 8000
47
+ ```
48
+
49
+ If port 8000 is still taken, use **8001** (or any free port) and point the
50
+ frontend at the same port:
51
+
52
+ ```bash
53
+ PORT=8001 ./run.sh
54
+ # in project root .env:
55
+ # VITE_PREDICT_API_URL=http://127.0.0.1:8001
56
+ ```
57
+
58
+ If your shell prompt already shows `inference_server` in the path, you are
59
+ **inside** that folderβ€”do not run `cd inference_server` again (you will get
60
+ `no such file`).
61
+
62
+ **`[Errno 48] address already in use` on port 8000** means something else is
63
+ already bound there (usually an older uvicorn you forgot to stop). On macOS:
64
+
65
+ ```bash
66
+ lsof -i :8000
67
+ # note the PID in the second column, then:
68
+ kill <PID>
69
+ # if it does not exit:
70
+ kill -9 <PID>
71
+ ```
72
+
73
+ Then start uvicorn again. To use another port without killing the other process
74
+ (e.g. 8001), add `--port 8001` and set `VITE_PREDICT_API_URL` in the frontend
75
+ `.env` to match.
76
+
77
+ > **IPv4/IPv6 gotcha.** Bind with `--host 0.0.0.0` (not `127.0.0.1`). On
78
+ > macOS, browsers often resolve `localhost` to IPv6 `::1` first, and
79
+ > `uvicorn --host 127.0.0.1` only listens on IPv4, which surfaces in the
80
+ > browser as a bare `Network Error` even though `curl 127.0.0.1:8000` works.
81
+ > The frontend's `.env` uses `http://127.0.0.1:8000` (not `localhost`) for
82
+ > the same reason.
83
+
84
+ On startup the server will:
85
+
86
+ 1. Read `model_training/notebooks/results/ensemble_manifest.csv`.
87
+ 2. For each row, rebuild the EfficientNet-B0 architecture and load the
88
+ corresponding `model_seed*.pth` checkpoint.
89
+ 3. Move every model to CUDA / MPS / CPU (auto-detected).
90
+
91
+ If the manifest is missing it falls back to
92
+ `model_training/notebooks/results/best_model.pth`.
93
+
94
+ ## 3. Frontend wiring
95
+
96
+ The app's `.env` already points at this server:
97
+
98
+ ```
99
+ VITE_PREDICT_API_URL=http://localhost:8000
100
+ ```
101
+
102
+ `src/services/predict.ts` posts the uploaded file to `/predict` as
103
+ `multipart/form-data` (field name: `image`). The response shape is exactly
104
+ what the frontend expects:
105
+
106
+ ```json
107
+ {
108
+ "prediction": "Cardiomegaly",
109
+ "confidence": 0.873,
110
+ "heatmap_url": null,
111
+ "source": "model",
112
+ "threshold": 0.504486,
113
+ "ensemble_size": 3,
114
+ "use_tta": true
115
+ }
116
+ ```
117
+
118
+ The frontend uses `source: "model"` to render a green "Real model" badge on
119
+ each result card, so there is no ambiguity about whether a prediction came
120
+ from the real trained ensemble.
121
+
122
+ ## 4. Configuration (env vars)
123
+
124
+ | Variable | Default | Purpose |
125
+ | ------------------ | ---------------------- | -------------------------------------------------------- |
126
+ | `MODEL_BACKBONE` | `CFG.backbone` | Must match the architecture used for training |
127
+ | `MODEL_IMG_SIZE` | `CFG.img_size` | Must match training (224 for EfficientNet-B0, 518 for RAD-DINO) |
128
+ | `MODEL_THRESHOLD` | `val_metrics_final.json::threshold` (fallback `0.5`) | Cut-off used when choosing the label string |
129
+ | `MODEL_USE_TTA` | `true` | `true` β†’ run the 6-pass TTA used at training evaluation time |
130
+ | `ALLOWED_ORIGINS` | localhost dev origins | Comma-separated CORS origins (exact match) |
131
+ | `ALLOWED_ORIGIN_REGEX` | _(unset)_ | Regex for origins, e.g. `https://.*\.lovable\.app` for Lovable preview URLs |
132
+ | `LOG_LEVEL` | `INFO` | Standard Python logging level |
133
+
134
+ Example:
135
+
136
+ ```bash
137
+ MODEL_USE_TTA=true MODEL_THRESHOLD=0.504 uvicorn server:app --port 8000
138
+ ```
139
+
140
+ ## 5. Smoke-test
141
+
142
+ ```bash
143
+ curl -s http://localhost:8000/health | jq .
144
+ curl -s -X POST -F "image=@/path/to/xray.png" http://localhost:8000/predict | jq .
145
+
146
+ # Full transparency: per-model + per-TTA raw logits so you can compare
147
+ # against val_predictions.csv / test_predictions.csv in the notebook:
148
+ curl -s -X POST -F "image=@/path/to/xray.png" http://localhost:8000/debug/predict | jq .
149
+ ```
150
+
151
+ Every `/predict` call is also logged in the server terminal with the
152
+ filename, per-model mean logits, and the final probability β€” useful to
153
+ confirm the frontend is actually hitting the server.
154
+
155
+ ## 6. Deploying to production (Lovable + separate inference host)
156
+
157
+ Lovable hosts the React frontend, but it cannot run this Python server.
158
+ You need a separate Python host for the inference server. In all cases the
159
+ steps are the same:
160
+
161
+ 1. **Push this repo** (including `model_training/notebooks/results/*.pth`)
162
+ to the chosen host. The whole monorepo is self-contained.
163
+ 2. **Start command** (the host's "start" or "web" command):
164
+ ```bash
165
+ uvicorn server:app --host 0.0.0.0 --port $PORT --app-dir inference_server
166
+ ```
167
+ Most PaaS hosts inject `PORT`; if yours doesn't, use `8000`.
168
+ 3. **Requirements**: point the host at `inference_server/requirements.txt`.
169
+ 4. **Set env vars on the inference host**:
170
+ ```
171
+ ALLOWED_ORIGINS=https://your-app.lovable.app,https://your-custom-domain.com
172
+ # Or, if Lovable assigns preview URLs with a hash prefix:
173
+ ALLOWED_ORIGIN_REGEX=https://.*\.lovable\.app
174
+ MODEL_USE_TTA=true
175
+ ```
176
+ 5. **Set env var on Lovable (frontend)**:
177
+ ```
178
+ VITE_PREDICT_API_URL=https://your-inference-server-url
179
+ ```
180
+ Then redeploy the Lovable frontend so Vite bakes the new URL into the
181
+ bundle.
182
+
183
+ ### Suggested hosts
184
+
185
+ | Host | Free tier | Notes |
186
+ | ---- | --------- | ----- |
187
+ | **Hugging Face Spaces** | 2 vCPU / 16 GB, always-on free | Ideal for this model. Create a Space with the "FastAPI" SDK, push the repo, and point it at `inference_server/server.py`. |
188
+ | **Render.com** | Web service, spins down after 15 min idle | Simple Git-push deploy. Cold start ~30 s while weights load. |
189
+ | **Fly.io** | Shared-CPU 256 MB free | Docker-based. Dockerfile is trivial (Python base β†’ pip install β†’ CMD uvicorn). |
190
+ | **Railway / Modal** | Paid / pay-per-second | Always-on, fastest DX. |
191
+
192
+ ### Keep model checkpoints in Git LFS (optional)
193
+
194
+ `model_training/notebooks/results/model_seed*.pth` are ~30 MB each. On
195
+ GitHub free plans this is fine, but you may want to move them to Git LFS
196
+ to keep repo clones small.
197
+
198
+ ## 7. Notes
199
+
200
+ - The ensemble is loaded once at startup (one-time cost of a few seconds).
201
+ - **Auto-detection**: the server inspects the first checkpoint on startup and
202
+ picks the matching backbone (`densenet121`, `efficientnet_b0`,
203
+ `mobilenet_v3_large`, or `rad-dino`). No `CFG` / env-var bookkeeping required.
204
+ - **Correct preprocessing per backbone**: the server delegates to
205
+ `model_training/src/dataset.py::get_normalize_fn` so the normalization
206
+ matches training exactly β€” `xrv_normalize_np` (grayscale, [-1024, 1024]) for
207
+ torchxrayvision DenseNet-121, `imagenet_normalize_np` (3-channel) for every
208
+ other backbone.
209
+ - **No pretrained-weight downloads**: torchvision and torchxrayvision
210
+ constructors are monkey-patched so they skip their pretrained-weight
211
+ download entirely β€” our trained checkpoint fully overwrites those weights
212
+ anyway. This means the server works offline and in sandboxed environments.
213
+ - **Fail-fast checkpoint loading**: if `state_dict` keys don't match the
214
+ architecture, startup aborts with a clear error listing the mismatch.
215
+ - Each request is ~50–150 ms on CPU without TTA (3 Γ— DenseNet-121 forward
216
+ passes). With `MODEL_USE_TTA=true` that becomes ~0.5–1.5 s per image.
217
+ - **Verified**: the server reproduces `notebooks/results/val_predictions.csv`
218
+ probabilities to 6 decimal places (zero delta) on the validation set.
inference_server/requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FastAPI + ASGI server
2
+ fastapi>=0.110
3
+ uvicorn[standard]>=0.27
4
+ python-multipart>=0.0.9
5
+
6
+ # ML stack (must match `model_training/requirements.txt` so state-dict loads cleanly)
7
+ torch>=2.1.0
8
+ torchvision>=0.16.0
9
+ numpy>=1.24
10
+ pandas>=2.0
11
+ pillow>=9.0
12
+
13
+ # `model_training/src/model.py` imports this at module level
14
+ torchxrayvision>=1.2.0
15
+
16
+ # Needed only if MODEL_BACKBONE=rad-dino; harmless otherwise
17
+ transformers>=4.40.0
inference_server/run.sh ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Run the API with the project venv (avoids: ModuleNotFoundError: No module named 'torch')
3
+ set -euo pipefail
4
+ cd "$(dirname "$0")"
5
+ if [[ ! -d .venv ]]; then
6
+ echo "No .venv here. First run: python3 -m venv .venv && .venv/bin/pip install -r requirements.txt" >&2
7
+ exit 1
8
+ fi
9
+ PORT="${PORT:-8000}"
10
+ exec .venv/bin/uvicorn server:app --host 0.0.0.0 --port "$PORT" "$@"
inference_server/server.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI inference server for the Cardiomegaly classifier.
2
+
3
+ Loads the multi-seed ensemble trained in ``model_training/`` and exposes a
4
+ single ``POST /predict`` endpoint that the frontend (`src/services/predict.ts`)
5
+ already knows how to consume.
6
+
7
+ Nothing inside ``model_training/`` is modified β€” we only *import* the model
8
+ factory (``src.model.build_model``) to rebuild the exact architecture that was
9
+ saved to disk, then load the weights on top.
10
+
11
+ Run locally
12
+ -----------
13
+ cd inference_server
14
+ pip install -r requirements.txt
15
+ uvicorn server:app --host 0.0.0.0 --port 8000
16
+
17
+ Environment overrides (optional)
18
+ --------------------------------
19
+ MODEL_BACKBONE default: CFG.backbone (e.g. "efficientnet_b0")
20
+ MODEL_IMG_SIZE default: CFG.img_size (e.g. 224)
21
+ MODEL_THRESHOLD default: 0.5 (binary cut-off for the label)
22
+ MODEL_USE_TTA default: "false" ("true" β†’ 6-pass TTA per image)
23
+ ALLOWED_ORIGINS comma-separated CORS origins (exact match)
24
+ ALLOWED_ORIGIN_REGEX regex origin whitelist (e.g. Lovable preview URLs:
25
+ "https://.*\\.lovable\\.app")
26
+ LOG_LEVEL default: INFO
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import io
32
+ import logging
33
+ import os
34
+ import sys
35
+ from pathlib import Path
36
+ from typing import List
37
+
38
+ import numpy as np
39
+ import pandas as pd
40
+ import torch
41
+ import torch.nn as nn
42
+ import torchvision.transforms as T
43
+ from fastapi import FastAPI, File, HTTPException, UploadFile
44
+ from fastapi.middleware.cors import CORSMiddleware
45
+ from PIL import Image
46
+
47
+ # ---------------------------------------------------------------------------
48
+ # Paths β€” make `from src.model import ...` resolvable without touching
49
+ # `model_training/`. We prepend the training directory to sys.path so its
50
+ # internal `from src.config import CFG` style imports keep working.
51
+ # ---------------------------------------------------------------------------
52
+ REPO_ROOT = Path(__file__).resolve().parent.parent
53
+ TRAINING_DIR = REPO_ROOT / "model_training"
54
+ NOTEBOOKS_DIR = TRAINING_DIR / "notebooks"
55
+ RESULTS_DIR = NOTEBOOKS_DIR / "results"
56
+
57
+ if str(TRAINING_DIR) not in sys.path:
58
+ sys.path.insert(0, str(TRAINING_DIR))
59
+
60
+ # Point torch's hub cache to a writable in-project location so the server
61
+ # works in sandboxed environments where ``~/.cache`` is read-only. Setting
62
+ # this BEFORE importing torchvision is critical.
63
+ os.environ.setdefault("TORCH_HOME", str(REPO_ROOT / ".torch-cache"))
64
+
65
+ # `build_model` in ``model_training/src/model.py`` constructs torchvision or
66
+ # torchxrayvision backbones WITH their pretrained weights. Those weights are
67
+ # irrelevant at inference time because we immediately overwrite them with the
68
+ # trained checkpoint from ``model_training/notebooks/results/``. We monkey-
69
+ # patch the constructors so the server skips every pretrained-weight
70
+ # download. This avoids needless bandwidth AND cache-dir permission errors
71
+ # when running in sandboxed environments.
72
+ import torchvision.models as _tvm # noqa: E402 pylint: disable=wrong-import-position
73
+ import torchxrayvision as _xrv # noqa: E402 pylint: disable=wrong-import-position
74
+
75
+ for _fn_name in ("efficientnet_b0", "efficientnet_b3", "mobilenet_v3_large"):
76
+ _orig = getattr(_tvm, _fn_name, None)
77
+ if _orig is None:
78
+ continue
79
+
80
+ def _no_download_builder(*args, __orig=_orig, **kwargs):
81
+ kwargs["weights"] = None
82
+ return __orig(*args, **kwargs)
83
+
84
+ setattr(_tvm, _fn_name, _no_download_builder)
85
+
86
+ # torchxrayvision DenseNet also attempts a download when weights="..." is set.
87
+ # We wrap its __init__ so the caller's weights argument is remembered, but
88
+ # the actual download is skipped. We still restore the canonical label list
89
+ # (``self.pathologies`` / ``self.targets``) that downstream code in
90
+ # ``model_training/src/model.py::cardio_logit`` relies on to locate the
91
+ # Cardiomegaly output index.
92
+ _orig_xrv_densenet_init = _xrv.models.DenseNet.__init__
93
+
94
+
95
+ def _xrv_densenet_init_no_download(self, *args, **kwargs):
96
+ requested_weights = kwargs.get("weights")
97
+ kwargs["weights"] = None
98
+ _orig_xrv_densenet_init(self, *args, **kwargs)
99
+ if requested_weights and requested_weights in _xrv.models.model_urls:
100
+ labels = _xrv.models.model_urls[requested_weights]["labels"]
101
+ self.targets = labels
102
+ self.pathologies = labels
103
+
104
+
105
+ _xrv.models.DenseNet.__init__ = _xrv_densenet_init_no_download
106
+
107
+ from src.config import CFG # noqa: E402 pylint: disable=wrong-import-position
108
+ from src.model import build_model, cardio_logit # noqa: E402 pylint: disable=wrong-import-position
109
+ from src.dataset import get_normalize_fn # noqa: E402 pylint: disable=wrong-import-position
110
+
111
+
112
+ def _detect_backbone_from_checkpoint(ckpt_path: Path) -> str:
113
+ """Inspect a saved state_dict and guess which backbone produced it.
114
+
115
+ Rules:
116
+ * torchxrayvision DenseNet-121 β†’ has ``features.denseblockN.*`` keys
117
+ * torchvision EfficientNet β†’ top-level ``features.0.0.weight`` (stem conv)
118
+ and depth β‰₯ 9 feature groups
119
+ * torchvision MobileNetV3-Large β†’ ``features.0.0.weight`` with depth ~17
120
+ * microsoft/rad-dino β†’ keys under ``features.embeddings`` /
121
+ ``features.encoder.layer.``
122
+ Defaults to ``CFG.backbone`` if no signature matches.
123
+ """
124
+ state = torch.load(ckpt_path, map_location="cpu", weights_only=True)
125
+ if isinstance(state, dict) and "state_dict" in state:
126
+ state = state["state_dict"]
127
+ keys = list(state.keys())
128
+
129
+ if any("denseblock" in k for k in keys):
130
+ return "densenet121"
131
+ if any(k.startswith("features.embeddings.") for k in keys) or any(
132
+ k.startswith("features.encoder.layer.") for k in keys
133
+ ):
134
+ return "rad-dino"
135
+ # torchvision feature indices
136
+ feature_indices = {
137
+ int(k.split(".")[1])
138
+ for k in keys
139
+ if k.startswith("features.") and k.split(".")[1].isdigit()
140
+ }
141
+ if feature_indices:
142
+ # EfficientNet-B0 has 9 groups (features.0 … features.8)
143
+ # MobileNetV3-Large has 17 groups (features.0 … features.16)
144
+ if max(feature_indices) >= 12:
145
+ return "mobilenet_v3_large"
146
+ if max(feature_indices) >= 7:
147
+ return "efficientnet_b0"
148
+ return CFG.backbone
149
+
150
+
151
+ # ---------------------------------------------------------------------------
152
+ # Backbone + image size: auto-detected from the checkpoint so the server never
153
+ # runs with a mismatched architecture. Can still be forced via env vars.
154
+ # ---------------------------------------------------------------------------
155
+ def _first_checkpoint_path() -> Path:
156
+ manifest = RESULTS_DIR / "ensemble_manifest.csv"
157
+ if manifest.exists():
158
+ df = pd.read_csv(manifest)
159
+ first = df["checkpoint"].iloc[0]
160
+ p = Path(first)
161
+ if p.is_absolute() and p.exists():
162
+ return p
163
+ for candidate in (NOTEBOOKS_DIR / first, RESULTS_DIR / Path(first).name):
164
+ if candidate.exists():
165
+ return candidate
166
+ fallback = RESULTS_DIR / "best_model.pth"
167
+ if fallback.exists():
168
+ return fallback
169
+ raise FileNotFoundError("No checkpoints found under model_training/notebooks/results/")
170
+
171
+
172
+ _DETECTED_BACKBONE = _detect_backbone_from_checkpoint(_first_checkpoint_path())
173
+ # DenseNet-121 (torchxrayvision) is trained on 224x224; ViT-B/14 needs 518.
174
+ _DEFAULT_IMG_SIZE = 518 if _DETECTED_BACKBONE == "rad-dino" else 224
175
+
176
+ BACKBONE: str = os.environ.get("MODEL_BACKBONE", _DETECTED_BACKBONE)
177
+ IMG_SIZE: int = int(os.environ.get("MODEL_IMG_SIZE", str(_DEFAULT_IMG_SIZE)))
178
+ USE_TTA: bool = os.environ.get("MODEL_USE_TTA", "true").lower() in {"1", "true", "yes"}
179
+
180
+
181
+ def _default_threshold() -> float:
182
+ """Use the training-selected threshold when available."""
183
+ metrics_path = RESULTS_DIR / "val_metrics_final.json"
184
+ if metrics_path.exists():
185
+ try:
186
+ import json
187
+
188
+ with open(metrics_path, "r", encoding="utf-8") as f:
189
+ data = json.load(f)
190
+ thr = float(data.get("threshold", 0.5))
191
+ if 0.0 <= thr <= 1.0:
192
+ return thr
193
+ except Exception: # noqa: BLE001
194
+ pass
195
+ return 0.5
196
+
197
+
198
+ DECISION_THRESHOLD: float = float(os.environ.get("MODEL_THRESHOLD", str(_default_threshold())))
199
+
200
+ _DEFAULT_ORIGINS = (
201
+ "http://localhost:3000,"
202
+ "http://localhost:5173,"
203
+ "http://localhost:8080,"
204
+ "http://127.0.0.1:3000,"
205
+ "http://127.0.0.1:5173,"
206
+ "http://127.0.0.1:8080"
207
+ )
208
+ ALLOWED_ORIGINS: list[str] = [
209
+ o.strip()
210
+ for o in os.environ.get("ALLOWED_ORIGINS", _DEFAULT_ORIGINS).split(",")
211
+ if o.strip()
212
+ ]
213
+ # Optional regex list β€” useful when the production frontend is served from a
214
+ # hash-based preview URL (e.g. Lovable / Vercel preview deployments).
215
+ # By default we allow:
216
+ # * any *.lovable.app and *.lovableproject.com subdomain (deployed Lovable apps)
217
+ # * any *.ngrok-free.app / *.ngrok.app / *.ngrok.io subdomain (when the user
218
+ # forwards the dev server through ngrok and previews the app from anywhere)
219
+ # Override with `ALLOWED_ORIGIN_REGEX` to lock things down in production.
220
+ # Include common private LAN dev URLs (Vite "Network" URL is often
221
+ # `http://192.168.x.x:8080` β€” the Origin header is not localhost, so
222
+ # it must be accepted here or the browser will block with "Network Error").
223
+ _DEFAULT_ORIGIN_REGEX = (
224
+ r"https://([a-z0-9-]+\.)*lovable\.app"
225
+ r"|https://([a-z0-9-]+\.)*lovableproject\.com"
226
+ r"|https://([a-z0-9-]+\.)*ngrok-free\.app"
227
+ r"|https://([a-z0-9-]+\.)*ngrok\.app"
228
+ r"|https://([a-z0-9-]+\.)*ngrok\.io"
229
+ r"|http://(192\.168\.\d{1,3}\.\d{1,3}|10\.\d{1,3}\.\d{1,3}\.\d{1,3}):\d+"
230
+ )
231
+ _ORIGIN_REGEX: str | None = os.environ.get("ALLOWED_ORIGIN_REGEX", _DEFAULT_ORIGIN_REGEX) or None
232
+
233
+ DEVICE: torch.device = torch.device(
234
+ "cuda" if torch.cuda.is_available()
235
+ else "mps" if torch.backends.mps.is_available()
236
+ else "cpu"
237
+ )
238
+
239
+ POSITIVE_LABEL = "Cardiomegaly"
240
+ NEGATIVE_LABEL = "No Cardiomegaly indication"
241
+
242
+ # ---------------------------------------------------------------------------
243
+ # Logging
244
+ # ---------------------------------------------------------------------------
245
+ logging.basicConfig(
246
+ level=os.environ.get("LOG_LEVEL", "INFO"),
247
+ format="%(asctime)s %(levelname)-5s %(message)s",
248
+ )
249
+ log = logging.getLogger("inference")
250
+
251
+ # ---------------------------------------------------------------------------
252
+ # Preprocessing β€” delegate to the SAME normalization functions the training
253
+ # dataset uses (`xrv_normalize_np` for densenet121, `imagenet_normalize_np`
254
+ # for every other backbone). This guarantees byte-for-byte identical
255
+ # preprocessing between training and inference.
256
+ # ---------------------------------------------------------------------------
257
+ _normalize_fn = get_normalize_fn(BACKBONE)
258
+
259
+
260
+ def _pil_hflip(img: Image.Image) -> Image.Image:
261
+ return img.transpose(Image.FLIP_LEFT_RIGHT)
262
+
263
+
264
+ def _tta_pipelines(size: int) -> List[T.Compose]:
265
+ """Match `src.transforms.make_tta_transforms` (6 deterministic passes)."""
266
+ s = (size, size)
267
+ return [
268
+ T.Compose([T.Resize(s)]),
269
+ T.Compose([T.Resize(s), T.Lambda(_pil_hflip)]),
270
+ T.Compose([T.Resize((size + 20, size + 20)), T.CenterCrop(s)]),
271
+ T.Compose([T.Resize((size - 20, size - 20)), T.Pad(10, fill=0), T.CenterCrop(s)]),
272
+ T.Compose([T.Resize(s), T.RandomAffine(degrees=(6, 6), fill=0)]),
273
+ T.Compose([T.Resize(s), T.RandomAffine(degrees=(-6, -6), fill=0)]),
274
+ ]
275
+
276
+
277
+ def _single_eval_pipeline(size: int) -> T.Compose:
278
+ return T.Compose([T.Resize((size, size))])
279
+
280
+
281
+ # ---------------------------------------------------------------------------
282
+ # Ensemble loading
283
+ # ---------------------------------------------------------------------------
284
+ def _resolve_checkpoint(p: str) -> Path:
285
+ """Manifest paths are stored relative to ``model_training/notebooks/``."""
286
+ path = Path(p)
287
+ if path.is_absolute() and path.exists():
288
+ return path
289
+ for candidate in (NOTEBOOKS_DIR / p, RESULTS_DIR / Path(p).name):
290
+ if candidate.exists():
291
+ return candidate
292
+ raise FileNotFoundError(f"Checkpoint not found: {p!r}")
293
+
294
+
295
+ def _load_ensemble() -> List[nn.Module]:
296
+ # Align CFG so build_model() reads the right backbone/size internally.
297
+ CFG.backbone = BACKBONE
298
+ CFG.img_size = IMG_SIZE
299
+
300
+ manifest = RESULTS_DIR / "ensemble_manifest.csv"
301
+ if manifest.exists():
302
+ df = pd.read_csv(manifest)
303
+ checkpoint_paths = [_resolve_checkpoint(p) for p in df["checkpoint"].tolist()]
304
+ log.info("Loading ensemble of %d models from %s", len(checkpoint_paths), manifest.name)
305
+ else:
306
+ best = RESULTS_DIR / "best_model.pth"
307
+ if not best.exists():
308
+ raise FileNotFoundError(
309
+ f"Neither {manifest} nor {best} exist. Train a model before starting the server."
310
+ )
311
+ checkpoint_paths = [best]
312
+ log.info("No manifest found, falling back to single checkpoint: %s", best.name)
313
+
314
+ models: list[nn.Module] = []
315
+ for ckpt_path in checkpoint_paths:
316
+ log.info(" β†’ loading %s", ckpt_path.name)
317
+ model = build_model(BACKBONE)
318
+ state = torch.load(ckpt_path, map_location=DEVICE)
319
+ if isinstance(state, dict) and "state_dict" in state:
320
+ state = state["state_dict"]
321
+ missing, unexpected = model.load_state_dict(state, strict=False)
322
+ if missing or unexpected:
323
+ raise RuntimeError(
324
+ "Checkpoint architecture mismatch. "
325
+ f"backbone={BACKBONE!r}, checkpoint={ckpt_path.name!r}, "
326
+ f"missing_keys={len(missing)}, unexpected_keys={len(unexpected)}. "
327
+ "Use the correct MODEL_BACKBONE / MODEL_IMG_SIZE and ensure "
328
+ "ensemble_manifest.csv points to checkpoints from that training run."
329
+ )
330
+ model.to(DEVICE).eval()
331
+ models.append(model)
332
+
333
+ log.info(
334
+ "Ensemble ready β€” %d model(s) Β· device=%s Β· backbone=%s (detected=%s) Β· "
335
+ "normalize=%s Β· img_size=%d Β· tta=%s Β· threshold=%.4f",
336
+ len(models), DEVICE, BACKBONE, _DETECTED_BACKBONE,
337
+ _normalize_fn.__name__, IMG_SIZE, USE_TTA, DECISION_THRESHOLD,
338
+ )
339
+ return models
340
+
341
+
342
+ # ---------------------------------------------------------------------------
343
+ # FastAPI app
344
+ # ---------------------------------------------------------------------------
345
+ app = FastAPI(title="CardioScan inference", version="1.0")
346
+
347
+ app.add_middleware(
348
+ CORSMiddleware,
349
+ allow_origins=ALLOWED_ORIGINS,
350
+ allow_origin_regex=_ORIGIN_REGEX,
351
+ allow_credentials=True,
352
+ allow_methods=["*"],
353
+ allow_headers=["*"],
354
+ )
355
+
356
+ _ensemble: list[nn.Module] = []
357
+ _loaded_checkpoints: list[str] = []
358
+
359
+
360
+ @app.on_event("startup")
361
+ def _startup() -> None:
362
+ global _ensemble, _loaded_checkpoints
363
+ manifest = RESULTS_DIR / "ensemble_manifest.csv"
364
+ if manifest.exists():
365
+ df = pd.read_csv(manifest)
366
+ _loaded_checkpoints = [Path(p).name for p in df["checkpoint"].tolist()]
367
+ else:
368
+ _loaded_checkpoints = ["best_model.pth"]
369
+ _ensemble = _load_ensemble()
370
+
371
+
372
+ @app.get("/health")
373
+ def health() -> dict:
374
+ return {
375
+ "ok": bool(_ensemble),
376
+ "models": len(_ensemble),
377
+ "checkpoints": _loaded_checkpoints,
378
+ "backbone": BACKBONE,
379
+ "detected_backbone": _DETECTED_BACKBONE,
380
+ "normalization": _normalize_fn.__name__,
381
+ "img_size": IMG_SIZE,
382
+ "device": str(DEVICE),
383
+ "use_tta": USE_TTA,
384
+ "threshold": DECISION_THRESHOLD,
385
+ }
386
+
387
+
388
+ @torch.no_grad()
389
+ def _predict_probability_detailed(pil_gray: Image.Image) -> dict:
390
+ """Run ensemble (+ optional TTA) on a single PIL image.
391
+
392
+ Returns a dict with per-model / per-TTA logits for transparency.
393
+ Matches `tta_predict` / `tta_predict_ensemble` in ``src.train`` exactly:
394
+ average logits across TTA (per model), then average across models,
395
+ then sigmoid.
396
+ """
397
+ pipelines = _tta_pipelines(IMG_SIZE) if USE_TTA else [_single_eval_pipeline(IMG_SIZE)]
398
+
399
+ tensors = [_normalize_fn(pipeline(pil_gray)) for pipeline in pipelines]
400
+ batch = torch.stack(tensors, dim=0).to(DEVICE) # (num_tta, 3, H, W)
401
+
402
+ per_model_tta_logits: list[np.ndarray] = []
403
+ per_model_mean_logit: list[float] = []
404
+ for model in _ensemble:
405
+ logit_vec = cardio_logit(model, batch).float().cpu().numpy() # (num_tta,)
406
+ per_model_tta_logits.append(logit_vec)
407
+ per_model_mean_logit.append(float(np.mean(logit_vec)))
408
+
409
+ ensemble_mean_logit = float(np.mean(per_model_mean_logit))
410
+ probability = float(1.0 / (1.0 + np.exp(-ensemble_mean_logit)))
411
+
412
+ return {
413
+ "probability": probability,
414
+ "ensemble_mean_logit": ensemble_mean_logit,
415
+ "per_model_mean_logit": {
416
+ name: lg for name, lg in zip(_loaded_checkpoints, per_model_mean_logit)
417
+ },
418
+ "per_model_tta_logits": {
419
+ name: lg.tolist() for name, lg in zip(_loaded_checkpoints, per_model_tta_logits)
420
+ },
421
+ "num_tta_passes": batch.shape[0],
422
+ }
423
+
424
+
425
+ @app.post("/predict")
426
+ async def predict(image: UploadFile = File(...)) -> dict:
427
+ if not _ensemble:
428
+ raise HTTPException(status_code=503, detail="Model not ready")
429
+
430
+ raw = await image.read()
431
+ if not raw:
432
+ raise HTTPException(status_code=400, detail="Empty upload")
433
+
434
+ try:
435
+ pil = Image.open(io.BytesIO(raw)).convert("L")
436
+ except Exception as exc: # noqa: BLE001
437
+ raise HTTPException(status_code=400, detail=f"Could not decode image: {exc}") from exc
438
+
439
+ try:
440
+ details = _predict_probability_detailed(pil)
441
+ except Exception as exc: # noqa: BLE001
442
+ log.exception("Inference failed")
443
+ raise HTTPException(status_code=500, detail=f"Inference error: {exc}") from exc
444
+
445
+ probability = details["probability"]
446
+ is_positive = probability >= DECISION_THRESHOLD
447
+
448
+ log.info(
449
+ "/predict file=%s size=%d prob=%.4f thr=%.4f -> %s (per-model=%s, tta=%d)",
450
+ image.filename,
451
+ len(raw),
452
+ probability,
453
+ DECISION_THRESHOLD,
454
+ "Cardiomegaly" if is_positive else "Negative",
455
+ {k: round(v, 4) for k, v in details["per_model_mean_logit"].items()},
456
+ details["num_tta_passes"],
457
+ )
458
+
459
+ return {
460
+ "prediction": POSITIVE_LABEL if is_positive else NEGATIVE_LABEL,
461
+ "confidence": probability,
462
+ "heatmap_url": None,
463
+ "source": "model",
464
+ "threshold": DECISION_THRESHOLD,
465
+ "ensemble_size": len(_ensemble),
466
+ "use_tta": USE_TTA,
467
+ }
468
+
469
+
470
+ @app.post("/debug/predict")
471
+ async def debug_predict(image: UploadFile = File(...)) -> dict:
472
+ """Same as /predict but returns per-model and per-TTA raw logits for
473
+ verification against the training notebook's val/test CSVs."""
474
+ if not _ensemble:
475
+ raise HTTPException(status_code=503, detail="Model not ready")
476
+
477
+ raw = await image.read()
478
+ if not raw:
479
+ raise HTTPException(status_code=400, detail="Empty upload")
480
+
481
+ try:
482
+ pil = Image.open(io.BytesIO(raw)).convert("L")
483
+ except Exception as exc: # noqa: BLE001
484
+ raise HTTPException(status_code=400, detail=f"Could not decode image: {exc}") from exc
485
+
486
+ details = _predict_probability_detailed(pil)
487
+ details["prediction"] = (
488
+ POSITIVE_LABEL if details["probability"] >= DECISION_THRESHOLD else NEGATIVE_LABEL
489
+ )
490
+ details["threshold"] = DECISION_THRESHOLD
491
+ details["use_tta"] = USE_TTA
492
+ details["checkpoints"] = _loaded_checkpoints
493
+ return details
model_training/notebooks/results/best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9bede2c09fc26c3fb5260c036ef0a0517d21de942dc509bf4425c0a0919bab7b
3
+ size 43341137
model_training/notebooks/results/daily_submission-1.csv ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image_file,prob,pred
2
+ IMG_000018.png,0.816906,1
3
+ IMG_000021.png,0.009301819,0
4
+ IMG_000039.png,0.33299014,0
5
+ IMG_000044.png,0.11926782,0
6
+ IMG_000047.png,0.3057367,0
7
+ IMG_000054.png,0.9174445,1
8
+ IMG_000077.png,0.0068948073,0
9
+ IMG_000094.png,0.010871509,0
10
+ IMG_000096.png,0.97423095,1
11
+ IMG_000108.png,0.03271585,0
12
+ IMG_000111.png,0.37760112,0
13
+ IMG_000128.png,0.8483817,1
14
+ IMG_000129.png,0.98759496,1
15
+ IMG_000135.png,0.8694341,1
16
+ IMG_000149.png,0.20452695,0
17
+ IMG_000160.png,0.8993582,1
18
+ IMG_000174.png,0.9729403,1
19
+ IMG_000180.png,0.3869212,0
20
+ IMG_000232.png,0.10533079,0
21
+ IMG_000242.png,0.064271316,0
22
+ IMG_000248.png,0.055668276,0
23
+ IMG_000249.png,6.6870154e-05,0
24
+ IMG_000255.png,0.011946698,0
25
+ IMG_000273.png,0.0054892018,0
26
+ IMG_000277.png,0.9634049,1
27
+ IMG_000299.png,0.96502954,1
28
+ IMG_000302.png,0.047373675,0
29
+ IMG_000309.png,0.9813534,1
30
+ IMG_000313.png,0.008047514,0
31
+ IMG_000314.png,0.008340702,0
32
+ IMG_000324.png,0.96767646,1
33
+ IMG_000343.png,0.023570525,0
34
+ IMG_000346.png,0.9816032,1
35
+ IMG_000353.png,0.45696247,0
36
+ IMG_000367.png,0.35300097,0
37
+ IMG_000404.png,0.4996281,0
38
+ IMG_000408.png,0.8722059,1
39
+ IMG_000415.png,0.74312115,1
40
+ IMG_000416.png,0.94917625,1
41
+ IMG_000420.png,0.22852376,0
42
+ IMG_000423.png,0.7794522,1
43
+ IMG_000429.png,0.8098528,1
44
+ IMG_000432.png,0.9711745,1
45
+ IMG_000479.png,0.909901,1
46
+ IMG_000483.png,0.832377,1
47
+ IMG_000485.png,0.744525,1
48
+ IMG_000522.png,0.9405937,1
49
+ IMG_000527.png,0.8822646,1
50
+ IMG_000532.png,0.97545457,1
51
+ IMG_000536.png,0.89281857,1
52
+ IMG_000538.png,0.9655915,1
53
+ IMG_000547.png,0.3647764,0
54
+ IMG_000548.png,0.47809702,0
55
+ IMG_000549.png,0.97563666,1
56
+ IMG_000565.png,0.074912064,0
57
+ IMG_000568.png,0.08629721,0
58
+ IMG_000575.png,0.9126054,1
59
+ IMG_000576.png,0.76525265,1
60
+ IMG_000579.png,0.00989682,0
61
+ IMG_000580.png,0.036182728,0
62
+ IMG_000593.png,0.90472263,1
63
+ IMG_000622.png,0.002541845,0
64
+ IMG_000627.png,0.0029941453,0
65
+ IMG_000632.png,0.07461243,0
66
+ IMG_000633.png,0.07481881,0
67
+ IMG_000653.png,0.18598385,0
68
+ IMG_000660.png,0.20837037,0
69
+ IMG_000671.png,0.0016912636,0
70
+ IMG_000675.png,0.1421747,0
71
+ IMG_000683.png,0.025120575,0
72
+ IMG_000688.png,0.57493603,1
73
+ IMG_000694.png,0.94961226,1
74
+ IMG_000696.png,0.3008333,0
75
+ IMG_000708.png,0.012850549,0
76
+ IMG_000710.png,0.9830056,1
77
+ IMG_000773.png,0.0031617412,0
78
+ IMG_000785.png,0.82301766,1
79
+ IMG_000787.png,0.9569242,1
80
+ IMG_000792.png,0.9786517,1
81
+ IMG_000797.png,0.0004634541,0
82
+ IMG_000821.png,0.024412306,0
83
+ IMG_000854.png,0.19797583,0
84
+ IMG_000862.png,0.9750999,1
85
+ IMG_000870.png,0.8686767,1
86
+ IMG_000880.png,0.0010660468,0
87
+ IMG_000893.png,0.96038836,1
88
+ IMG_000910.png,0.07809025,0
89
+ IMG_000914.png,0.26980433,0
90
+ IMG_000917.png,0.98292184,1
91
+ IMG_000925.png,0.0027379196,0
92
+ IMG_000936.png,0.96947944,1
93
+ IMG_000942.png,0.82471883,1
94
+ IMG_000945.png,0.00028870278,0
95
+ IMG_000946.png,0.07984374,0
96
+ IMG_000951.png,0.26128584,0
97
+ IMG_000953.png,0.0010449657,0
98
+ IMG_000981.png,0.090446874,0
99
+ IMG_000985.png,0.018414477,0
100
+ IMG_000988.png,0.9778957,1
101
+ IMG_000991.png,0.036471717,0
102
+ IMG_001026.png,0.59295475,1
103
+ IMG_001027.png,0.9421946,1
104
+ IMG_001056.png,0.5752065,1
105
+ IMG_001063.png,0.028953826,0
106
+ IMG_001071.png,0.024936734,0
107
+ IMG_001083.png,0.95445114,1
108
+ IMG_001084.png,0.90794367,1
109
+ IMG_001085.png,0.7487882,1
110
+ IMG_001100.png,0.92409223,1
111
+ IMG_001112.png,0.1766397,0
112
+ IMG_001113.png,0.93751234,1
113
+ IMG_001121.png,0.80051255,1
114
+ IMG_001148.png,0.1581435,0
115
+ IMG_001154.png,0.001443795,0
116
+ IMG_001155.png,0.99754816,1
117
+ IMG_001158.png,0.21377504,0
118
+ IMG_001160.png,0.8512862,1
119
+ IMG_001162.png,0.55319506,1
120
+ IMG_001166.png,0.013784781,0
121
+ IMG_001168.png,0.6532163,1
122
+ IMG_001179.png,0.11473991,0
123
+ IMG_001180.png,0.057606816,0
124
+ IMG_001184.png,0.3746491,0
125
+ IMG_001205.png,0.0096385535,0
126
+ IMG_001206.png,0.3501062,0
127
+ IMG_001239.png,0.016910465,0
128
+ IMG_001259.png,0.9424403,1
129
+ IMG_001271.png,0.8103772,1
130
+ IMG_001290.png,0.22025856,0
131
+ IMG_001291.png,0.133331,0
132
+ IMG_001304.png,0.9841086,1
133
+ IMG_001317.png,0.75506747,1
134
+ IMG_001321.png,0.014835258,0
135
+ IMG_001334.png,0.03285374,0
136
+ IMG_001336.png,0.36698562,0
137
+ IMG_001349.png,0.03645498,0
138
+ IMG_001352.png,0.72149867,1
139
+ IMG_001363.png,0.00038983554,0
140
+ IMG_001365.png,0.6854751,1
141
+ IMG_001368.png,0.9896269,1
142
+ IMG_001384.png,0.71602064,1
143
+ IMG_001393.png,0.017989887,0
144
+ IMG_001394.png,0.0023077542,0
145
+ IMG_001405.png,0.8804799,1
146
+ IMG_001412.png,0.5575575,1
147
+ IMG_001417.png,0.71472263,1
148
+ IMG_001436.png,0.9894505,1
149
+ IMG_001443.png,0.00088119216,0
150
+ IMG_001495.png,0.005760317,0
151
+ IMG_001504.png,0.017389266,0
152
+ IMG_001506.png,0.012142896,0
153
+ IMG_001508.png,0.81739867,1
154
+ IMG_001511.png,0.9034313,1
155
+ IMG_001528.png,0.025507784,0
156
+ IMG_001536.png,0.5163997,1
157
+ IMG_001570.png,0.014479928,0
158
+ IMG_001572.png,0.004765836,0
159
+ IMG_001577.png,0.03651942,0
160
+ IMG_001581.png,0.0067347363,0
161
+ IMG_001588.png,0.7685501,1
162
+ IMG_001593.png,0.92315125,1
163
+ IMG_001605.png,0.028199496,0
164
+ IMG_001606.png,0.092707306,0
165
+ IMG_001621.png,0.00926323,0
166
+ IMG_001628.png,0.33566543,0
167
+ IMG_001638.png,0.0071623144,0
168
+ IMG_001645.png,0.0010000179,0
169
+ IMG_001671.png,0.0020201257,0
170
+ IMG_001684.png,0.87558657,1
171
+ IMG_001688.png,0.86326087,1
172
+ IMG_001702.png,0.014059455,0
173
+ IMG_001710.png,0.97905207,1
174
+ IMG_001718.png,0.00522173,0
175
+ IMG_001726.png,0.26297852,0
176
+ IMG_001731.png,0.56691897,1
177
+ IMG_001743.png,0.9594168,1
model_training/notebooks/results/ensemble_manifest.csv ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ seed,best_val_auc,checkpoint
2
+ 8,0.9496841803043353,results/model_seed8.pth
3
+ 7,0.9425782371518806,results/model_seed7.pth
4
+ 2024,0.94717197817973,results/model_seed2024.pth
model_training/notebooks/results/model_seed2024.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f97f75f36fc98bb7a585e9167b2494fbb0e39b59c57faccf818021bace913189
3
+ size 28464092
model_training/notebooks/results/model_seed42.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a10caab6adc4408c708209adebe0293a060943c3156b06ed1c19f794a7d8504c
3
+ size 17015941
model_training/notebooks/results/model_seed7.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd24fddc1efd99010271a20b31d316f9459213cd5fbc5142654737ca825b6577
3
+ size 28461893
model_training/notebooks/results/model_seed8.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94f02549afb37082890d35231782b2a61117f535c752c8acd6881a69ebe880bb
3
+ size 28461893
model_training/notebooks/results/test_metrics_final.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "threshold": 0.504486,
3
+ "auc": 0.9391329313809934,
4
+ "sensitivity": 0.8703703703703703,
5
+ "specificity": 0.8527131782945736,
6
+ "accuracy": 0.8607594936708861,
7
+ "youden": 0.7230835486649441,
8
+ "composite": 0.9003373528567327,
9
+ "precision": 0.831858407079646,
10
+ "tp": 94,
11
+ "tn": 110,
12
+ "fp": 19,
13
+ "fn": 14
14
+ }
model_training/notebooks/results/test_predictions.csv ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ filename,y_true,prob,pred,correct
2
+ IMG_000824.png,0,0.021219134,0,1
3
+ IMG_000990.png,1,0.7771948,1,1
4
+ IMG_001690.png,1,0.85049516,1,1
5
+ IMG_000879.png,1,0.8450817,1,1
6
+ IMG_000884.png,0,0.11847936,0,1
7
+ IMG_001538.png,1,0.8167436,1,1
8
+ IMG_001404.png,1,0.799158,1,1
9
+ IMG_001043.png,0,0.030076103,0,1
10
+ IMG_001194.png,0,0.20013258,0,1
11
+ IMG_000345.png,0,0.003733192,0,1
12
+ IMG_001019.png,0,0.4092049,0,1
13
+ IMG_000233.png,1,0.754443,1,1
14
+ IMG_000994.png,1,0.8932385,1,1
15
+ IMG_000219.png,0,0.74411124,1,0
16
+ IMG_000758.png,0,0.88436824,1,0
17
+ IMG_001127.png,0,0.0010003505,0,1
18
+ IMG_000185.png,1,0.43144453,0,0
19
+ IMG_000344.png,1,0.7200385,1,1
20
+ IMG_000462.png,1,0.6866032,1,1
21
+ IMG_001243.png,0,0.005522695,0,1
22
+ IMG_000110.png,0,0.00616389,0,1
23
+ IMG_000178.png,1,0.38014814,0,0
24
+ IMG_001282.png,0,0.047400977,0,1
25
+ IMG_000451.png,1,0.95899737,1,1
26
+ IMG_001312.png,1,0.87483233,1,1
27
+ IMG_001298.png,1,0.89103687,1,1
28
+ IMG_000514.png,1,0.9204423,1,1
29
+ IMG_001297.png,0,0.026098764,0,1
30
+ IMG_000941.png,1,0.84286815,1,1
31
+ IMG_000218.png,1,0.32278225,0,0
32
+ IMG_001117.png,0,0.42146167,0,1
33
+ IMG_000001.png,0,0.0846162,0,1
34
+ IMG_001105.png,0,0.34927747,0,1
35
+ IMG_000079.png,0,0.0015609249,0,1
36
+ IMG_000400.png,0,0.06923158,0,1
37
+ IMG_000260.png,1,0.9263501,1,1
38
+ IMG_000998.png,1,0.6549944,1,1
39
+ IMG_001281.png,1,0.71612954,1,1
40
+ IMG_000690.png,0,0.030174967,0,1
41
+ IMG_001627.png,1,0.95443106,1,1
42
+ IMG_001107.png,1,0.7405456,1,1
43
+ IMG_000763.png,0,0.14266923,0,1
44
+ IMG_000545.png,0,0.022861702,0,1
45
+ IMG_001343.png,1,0.78821915,1,1
46
+ IMG_001175.png,1,0.8847534,1,1
47
+ IMG_000682.png,1,0.18611808,0,0
48
+ IMG_001398.png,0,0.010194484,0,1
49
+ IMG_001556.png,1,0.80163234,1,1
50
+ IMG_001225.png,1,0.95753545,1,1
51
+ IMG_001460.png,0,0.39417872,0,1
52
+ IMG_001746.png,1,0.822455,1,1
53
+ IMG_001406.png,1,0.9840471,1,1
54
+ IMG_001052.png,0,0.24067384,0,1
55
+ IMG_001640.png,0,0.022804145,0,1
56
+ IMG_001728.png,1,0.9836304,1,1
57
+ IMG_001592.png,0,0.36900538,0,1
58
+ IMG_001473.png,1,0.9562594,1,1
59
+ IMG_000376.png,0,0.007365662,0,1
60
+ IMG_001390.png,0,0.019877983,0,1
61
+ IMG_000169.png,0,0.030604592,0,1
62
+ IMG_001310.png,0,0.011923165,0,1
63
+ IMG_000151.png,0,0.6607677,1,0
64
+ IMG_001240.png,0,0.013350104,0,1
65
+ IMG_000864.png,1,0.7523066,1,1
66
+ IMG_001741.png,1,0.9743217,1,1
67
+ IMG_001364.png,1,0.35030937,0,0
68
+ IMG_000121.png,0,0.0022226572,0,1
69
+ IMG_000359.png,0,0.004687663,0,1
70
+ IMG_001075.png,1,0.8264759,1,1
71
+ IMG_000187.png,0,0.17870605,0,1
72
+ IMG_000237.png,1,0.9433076,1,1
73
+ IMG_001562.png,1,0.96882564,1,1
74
+ IMG_001114.png,1,0.98810416,1,1
75
+ IMG_000223.png,0,0.021592166,0,1
76
+ IMG_001745.png,0,0.95945805,1,0
77
+ IMG_000826.png,0,0.029931583,0,1
78
+ IMG_001704.png,0,0.23410255,0,1
79
+ IMG_000406.png,0,0.14752638,0,1
80
+ IMG_000634.png,1,0.6518697,1,1
81
+ IMG_001238.png,0,0.08548438,0,1
82
+ IMG_000993.png,1,0.7866735,1,1
83
+ IMG_001736.png,0,0.40755022,0,1
84
+ IMG_000481.png,1,0.001395701,0,0
85
+ IMG_001395.png,1,0.7820107,1,1
86
+ IMG_000553.png,0,0.02335425,0,1
87
+ IMG_000754.png,1,0.7037294,1,1
88
+ IMG_001018.png,0,0.95035726,1,0
89
+ IMG_001492.png,1,0.88814825,1,1
90
+ IMG_000737.png,0,0.00022573896,0,1
91
+ IMG_001617.png,0,0.017841075,0,1
92
+ IMG_000411.png,1,0.46054697,0,0
93
+ IMG_000764.png,1,0.9863303,1,1
94
+ IMG_001198.png,0,0.53324735,1,0
95
+ IMG_001237.png,1,0.40191436,0,0
96
+ IMG_001639.png,1,0.6333021,1,1
97
+ IMG_000829.png,1,0.9704466,1,1
98
+ IMG_001192.png,1,0.8689759,1,1
99
+ IMG_001143.png,0,0.12318405,0,1
100
+ IMG_000860.png,0,0.22503993,0,1
101
+ IMG_001308.png,1,0.9235162,1,1
102
+ IMG_000067.png,1,0.8142528,1,1
103
+ IMG_000899.png,0,0.0091903545,0,1
104
+ IMG_000602.png,1,0.7867315,1,1
105
+ IMG_000022.png,0,0.12565576,0,1
106
+ IMG_001553.png,0,0.5937602,1,0
107
+ IMG_000689.png,0,0.0035498547,0,1
108
+ IMG_001719.png,1,0.97812426,1,1
109
+ IMG_000074.png,0,0.017582327,0,1
110
+ IMG_001584.png,0,0.29081395,0,1
111
+ IMG_000127.png,0,0.13165514,0,1
112
+ IMG_000204.png,0,0.044497356,0,1
113
+ IMG_001663.png,1,0.7866542,1,1
114
+ IMG_001116.png,0,0.031116765,0,1
115
+ IMG_000964.png,1,0.9310901,1,1
116
+ IMG_001677.png,1,0.927338,1,1
117
+ IMG_001418.png,1,0.8989984,1,1
118
+ IMG_000974.png,1,0.9450868,1,1
119
+ IMG_000477.png,1,0.962653,1,1
120
+ IMG_001126.png,0,0.00042581963,0,1
121
+ IMG_001373.png,1,0.84743226,1,1
122
+ IMG_000589.png,1,0.8598813,1,1
123
+ IMG_001389.png,1,0.7277787,1,1
124
+ IMG_001292.png,0,0.30514425,0,1
125
+ IMG_000240.png,1,0.9254172,1,1
126
+ IMG_000437.png,0,0.07736599,0,1
127
+ IMG_000751.png,0,0.21614793,0,1
128
+ IMG_001278.png,0,0.00081374723,0,1
129
+ IMG_000724.png,1,0.8973737,1,1
130
+ IMG_000944.png,1,0.9696045,1,1
131
+ IMG_000930.png,1,0.06275673,0,0
132
+ IMG_001586.png,1,0.85567933,1,1
133
+ IMG_001415.png,0,0.42133018,0,1
134
+ IMG_000235.png,0,0.00044268242,0,1
135
+ IMG_000789.png,0,0.51844907,1,0
136
+ IMG_001722.png,0,0.02285177,0,1
137
+ IMG_000457.png,1,0.9203535,1,1
138
+ IMG_001385.png,1,0.33224016,0,0
139
+ IMG_001545.png,0,0.33094287,0,1
140
+ IMG_001327.png,0,0.0026560111,0,1
141
+ IMG_000693.png,1,0.7572429,1,1
142
+ IMG_001427.png,0,0.18905798,0,1
143
+ IMG_001660.png,0,0.059407346,0,1
144
+ IMG_001189.png,0,0.047472887,0,1
145
+ IMG_000257.png,0,0.8905588,1,0
146
+ IMG_001369.png,0,0.36061272,0,1
147
+ IMG_001691.png,1,0.964177,1,1
148
+ IMG_001732.png,1,0.9945857,1,1
149
+ IMG_001520.png,0,0.22373733,0,1
150
+ IMG_001376.png,1,0.7934236,1,1
151
+ IMG_000436.png,0,0.71353745,1,0
152
+ IMG_000741.png,1,0.98955476,1,1
153
+ IMG_000446.png,0,0.76797557,1,0
154
+ IMG_000876.png,0,0.43126056,0,1
155
+ IMG_001462.png,0,0.013134937,0,1
156
+ IMG_001066.png,0,0.00046867327,0,1
157
+ IMG_000114.png,0,0.3608086,0,1
158
+ IMG_001015.png,0,0.04901976,0,1
159
+ IMG_001372.png,0,0.018457958,0,1
160
+ IMG_000949.png,1,0.99760324,1,1
161
+ IMG_000919.png,0,0.016946174,0,1
162
+ IMG_001172.png,1,0.8392962,1,1
163
+ IMG_001068.png,0,0.8429959,1,0
164
+ IMG_000087.png,0,0.6321575,1,0
165
+ IMG_000888.png,0,0.84551364,1,0
166
+ IMG_000286.png,1,0.9955705,1,1
167
+ IMG_000214.png,1,0.73011214,1,1
168
+ IMG_001065.png,0,0.047320206,0,1
169
+ IMG_000510.png,0,0.00021625705,0,1
170
+ IMG_000315.png,0,0.040586166,0,1
171
+ IMG_000422.png,1,0.86234355,1,1
172
+ IMG_000468.png,1,0.9084235,1,1
173
+ IMG_000749.png,0,0.048509527,0,1
174
+ IMG_000822.png,0,0.00459977,0,1
175
+ IMG_001025.png,0,0.1423657,0,1
176
+ IMG_000777.png,1,0.94110733,1,1
177
+ IMG_001235.png,1,0.955069,1,1
178
+ IMG_001265.png,1,0.4247596,0,0
179
+ IMG_000476.png,0,0.0074078897,0,1
180
+ IMG_000922.png,1,0.4249936,0,0
181
+ IMG_001294.png,0,0.5523263,1,0
182
+ IMG_001181.png,0,0.021967547,0,1
183
+ IMG_000179.png,1,0.6198387,1,1
184
+ IMG_000844.png,0,0.0009595261,0,1
185
+ IMG_001062.png,0,0.02210043,0,1
186
+ IMG_000647.png,0,0.048218727,0,1
187
+ IMG_000629.png,1,0.7496184,1,1
188
+ IMG_000743.png,1,0.94583535,1,1
189
+ IMG_000639.png,0,0.31978914,0,1
190
+ IMG_000207.png,0,0.0010147754,0,1
191
+ IMG_000116.png,0,0.4486534,0,1
192
+ IMG_000412.png,0,0.00072938023,0,1
193
+ IMG_001076.png,0,0.00037972903,0,1
194
+ IMG_000911.png,0,0.21701539,0,1
195
+ IMG_000226.png,1,0.9124851,1,1
196
+ IMG_001434.png,0,0.1508207,0,1
197
+ IMG_001534.png,0,0.007111399,0,1
198
+ IMG_000239.png,1,0.97189856,1,1
199
+ IMG_000729.png,0,0.024547786,0,1
200
+ IMG_000362.png,0,0.0368409,0,1
201
+ IMG_001563.png,1,0.92587185,1,1
202
+ IMG_000850.png,0,0.031102972,0,1
203
+ IMG_001283.png,0,0.0041644424,0,1
204
+ IMG_001077.png,0,0.0043652514,0,1
205
+ IMG_000109.png,1,0.94027257,1,1
206
+ IMG_001039.png,0,0.31002668,0,1
207
+ IMG_000855.png,1,0.9954782,1,1
208
+ IMG_000581.png,0,0.58953196,1,0
209
+ IMG_001670.png,1,0.41889447,0,0
210
+ IMG_001489.png,0,0.0039128903,0,1
211
+ IMG_001488.png,1,0.977498,1,1
212
+ IMG_001247.png,1,0.9217244,1,1
213
+ IMG_001340.png,1,0.9252595,1,1
214
+ IMG_000562.png,0,0.00028323507,0,1
215
+ IMG_001714.png,1,0.9137611,1,1
216
+ IMG_000113.png,1,0.83810467,1,1
217
+ IMG_000519.png,0,0.24291025,0,1
218
+ IMG_000352.png,1,0.89757895,1,1
219
+ IMG_000060.png,0,0.12018548,0,1
220
+ IMG_001689.png,1,0.88260335,1,1
221
+ IMG_001332.png,0,0.021517713,0,1
222
+ IMG_001227.png,0,0.5806698,1,0
223
+ IMG_000952.png,1,0.91132796,1,1
224
+ IMG_001223.png,0,0.0054034726,0,1
225
+ IMG_001662.png,1,0.9796165,1,1
226
+ IMG_001242.png,1,0.89553756,1,1
227
+ IMG_001542.png,1,0.2224111,0,0
228
+ IMG_000150.png,0,0.038930204,0,1
229
+ IMG_001021.png,0,0.00035879802,0,1
230
+ IMG_001516.png,0,0.62845725,1,0
231
+ IMG_001074.png,0,0.7126184,1,0
232
+ IMG_001648.png,1,0.9681084,1,1
233
+ IMG_000300.png,1,0.66993195,1,1
234
+ IMG_000194.png,0,0.16441628,0,1
235
+ IMG_000728.png,0,0.005496844,0,1
236
+ IMG_000818.png,1,0.92287654,1,1
237
+ IMG_001712.png,0,0.29255372,0,1
238
+ IMG_000719.png,0,0.075776406,0,1
model_training/notebooks/results/training_history.csv ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed,stage,epoch,train_loss,train_auc,train_acc,train_composite,val_loss,val_auc,val_acc,val_sens,val_spec,val_composite,lr
2
+ 8,frozen,1,0.9582490903990609,0.7691765873015873,0.6177536231884058,0.6771676587301587,0.7091568447649479,0.85099052540913,0.6708860759493671,0.37037037037037035,0.9224806201550387,0.7487080103359173,0.001
3
+ 8,frozen,2,0.6246700976576124,0.783531746031746,0.6838768115942029,0.726031746031746,0.48649174720048904,0.8657766293425209,0.7552742616033755,0.6574074074074074,0.8372093023255814,0.8065424921045077,0.001
4
+ 8,frozen,3,0.5868643377508436,0.7798181216931217,0.6920289855072463,0.7326074735449736,0.45043402910232544,0.8748923341946597,0.7721518987341772,0.7037037037037037,0.8294573643410853,0.8207364341085271,0.001
5
+ 8,frozen,4,0.5721473864146641,0.7899371693121694,0.7038043478260869,0.7456233465608466,0.43230215832591057,0.878409417169107,0.7805907172995781,0.7314814814814815,0.8217054263565892,0.8275014355440712,0.001
6
+ 8,frozen,5,0.54250961627279,0.8040046296296297,0.7246376811594203,0.7627959656084657,0.4336203671991825,0.8802038472581108,0.7763713080168776,0.7129629629629629,0.8294573643410853,0.8257070054550675,0.001
7
+ 8,frozen,6,0.534449862582343,0.8101421957671958,0.7336956521739131,0.7705076058201058,0.40960250422358513,0.8915446454206145,0.7890295358649789,0.7407407407407407,0.8294573643410853,0.8383218489807638,0.001
8
+ 8,frozen,7,0.530328972850527,0.814781746031746,0.7454710144927537,0.7780853174603175,0.40932638198137283,0.8911139821992536,0.7932489451476793,0.7407407407407407,0.8372093023255814,0.8400445018662074,0.001
9
+ 8,frozen,8,0.5051169097423553,0.8304497354497355,0.7373188405797102,0.7829629629629631,0.4084223732352257,0.8912575366063739,0.7974683544303798,0.75,0.8372093023255814,0.8424310938845824,0.001
10
+ 8,frozen,9,0.5022006383963994,0.8348148148148148,0.7672101449275363,0.7998478835978836,0.4024202488362789,0.895205282802182,0.7890295358649789,0.7314814814814815,0.8372093023255814,0.8397753373528568,0.001
11
+ 8,frozen,10,0.49878278289522443,0.8352347883597884,0.7518115942028986,0.7921808862433863,0.39458586275577545,0.8993683606086705,0.7890295358649789,0.7314814814814815,0.8372093023255814,0.8418568762561011,0.001
12
+ 8,finetune,1,0.4544305384159088,0.8672321428571428,0.7835144927536232,0.8254613095238095,0.38248663768172264,0.919250645994832,0.8270042194092827,0.75,0.8914728682170543,0.8699935400516796,0.0009891830623632338
13
+ 8,finetune,2,0.38789899774960107,0.9041567460317461,0.8342391304347826,0.8695585317460318,0.2955472208559513,0.9496841803043353,0.8860759493670886,0.9444444444444444,0.8372093023255814,0.9202555268446742,0.0009572050015330873
14
+ 8,finetune,3,0.3507444177355085,0.9221329365079365,0.8469202898550725,0.8848561507936508,0.3256516754627228,0.937123169681309,0.8818565400843882,0.8703703703703703,0.8914728682170543,0.9090223944875107,0.0009054634122155989
15
+ 8,finetune,4,0.33837775332587106,0.9301190476190476,0.8641304347826086,0.8962896825396826,0.33792560175061226,0.9344674131495836,0.8607594936708861,0.8425925925925926,0.875968992248062,0.8968741027849554,0.0008362196501476348
16
+ 8,finetune,5,0.30863836577960424,0.9406580687830688,0.8586956521739131,0.90088458994709,0.35501030273735523,0.9326729830605799,0.8565400843881856,0.7962962962962963,0.9069767441860465,0.8921547516508757,0.0007524999999999999
17
+ 8,finetune,6,0.32643192623342787,0.9347982804232804,0.8523550724637681,0.892974537037037,0.31595958210527897,0.9378409417169108,0.8860759493670886,0.8796296296296297,0.8914728682170543,0.9116960953201263,0.0006579634122155989
18
+ 7,frozen,1,0.9370916630540576,0.7815476190476189,0.6204710144927537,0.6849999999999999,0.6874790899455547,0.8518518518518519,0.6708860759493671,0.3888888888888889,0.9069767441860465,0.7498923341946597,0.001
19
+ 7,frozen,2,0.6154342787606376,0.801253306878307,0.6983695652173914,0.7406068121693122,0.5005612336099148,0.863982199253517,0.7510548523206751,0.6481481481481481,0.8372093023255814,0.803330462245191,0.001
20
+ 7,frozen,3,0.5728206719670977,0.7850958994708995,0.7028985507246377,0.7411193783068782,0.4355204217135906,0.8814958369221936,0.7848101265822784,0.7407407407407407,0.8217054263565892,0.8313594602354293,0.001
21
+ 7,frozen,4,0.5409952683108193,0.803399470899471,0.7119565217391305,0.7561838624338624,0.4371259845793247,0.8789836347975883,0.7805907172995781,0.7222222222222222,0.8294573643410853,0.8274117140396211,0.001
22
+ 7,frozen,5,0.5573380930083138,0.7994576719576719,0.7346014492753623,0.7652645502645502,0.41805607080459595,0.8866637955785241,0.7805907172995781,0.7407407407407407,0.813953488372093,0.8320054550674706,0.001
23
+ 7,frozen,6,0.5511363804340362,0.8038194444444444,0.7355072463768116,0.7692113095238096,0.4179871119558811,0.886017800746483,0.7974683544303798,0.75,0.8372093023255814,0.8398112259546369,0.001
24
+ 7,frozen,7,0.5211189142295293,0.8182903439153438,0.7336956521739131,0.7744229497354497,0.40560774877667427,0.8934108527131782,0.7974683544303798,0.75,0.8372093023255814,0.8435077519379844,0.001
25
+ 7,frozen,8,0.5251553058624268,0.8196527777777778,0.7391304347826086,0.7777628968253969,0.40271764248609543,0.8944157335630204,0.8059071729957806,0.7685185185185185,0.8372093023255814,0.8486398219925352,0.001
26
+ 7,frozen,9,0.5052971592971257,0.8322817460317459,0.7445652173913043,0.7871329365079364,0.3970758095383644,0.8983634797588287,0.8016877637130801,0.7592592592592593,0.8372093023255814,0.8482988802756245,0.001
27
+ 7,frozen,10,0.5330768167972565,0.8135284391534392,0.7373188405797102,0.7737086640211641,0.3852919824421406,0.904392764857881,0.810126582278481,0.7870370370370371,0.8294573643410853,0.8563199827734711,0.001
28
+ 7,finetune,1,0.47272565024239677,0.8557539682539683,0.7817028985507246,0.8182539682539682,0.44458993524312973,0.8919753086419753,0.7763713080168776,0.6481481481481481,0.8837209302325582,0.8289549239161642,0.0009891830623632338
29
+ 7,finetune,2,0.38450789025851656,0.9061078042328042,0.8170289855072463,0.8615062830687831,0.3831586390733719,0.9171691070915876,0.8354430379746836,0.7685185185185185,0.8914728682170543,0.873582400229687,0.0009572050015330873
30
+ 7,finetune,3,0.35653283638613564,0.922037037037037,0.8505434782608695,0.8862367724867725,0.31522940658032894,0.9425782371518806,0.8818565400843882,0.9074074074074074,0.8604651162790697,0.9132572494975595,0.0009054634122155989
31
+ 7,finetune,4,0.33924023721899305,0.9283333333333333,0.8505434782608695,0.8892261904761904,0.3659691587090492,0.9249210450760839,0.8312236286919831,0.7592592592592593,0.8914728682170543,0.8751435544071202,0.0008362196501476348
32
+ 7,finetune,5,0.33330591917037966,0.9319179894179894,0.8623188405797102,0.8970701058201058,0.32207747735083103,0.939420040195234,0.8649789029535865,0.8981481481481481,0.8372093023255814,0.9035493827160493,0.0007524999999999999
33
+ 7,finetune,6,0.31548327335289544,0.9377645502645502,0.8596014492753623,0.8990608465608465,0.35484062135219574,0.9318116566178581,0.8438818565400844,0.7962962962962963,0.8837209302325582,0.8859101349411427,0.0006579634122155989
34
+ 7,finetune,7,0.287742834006037,0.9492956349206351,0.8731884057971014,0.9116319444444445,0.3438141830265522,0.930304335343095,0.8481012658227848,0.7962962962962963,0.8914728682170543,0.887094458799885,0.0005567415893174885
35
+ 2024,frozen,1,0.9445781009537834,0.7710317460317461,0.6304347826086957,0.6849603174603175,0.7147812955081463,0.8519236290554119,0.6666666666666666,0.37037037037037035,0.9147286821705426,0.7472365776629342,0.001
36
+ 2024,frozen,2,0.6475456161158425,0.7795568783068784,0.6902173913043478,0.7264847883597884,0.4829690493643284,0.8683606086706861,0.7510548523206751,0.6481481481481481,0.8372093023255814,0.8055196669537755,0.001
37
+ 2024,frozen,3,0.557839333159583,0.7974801587301588,0.7119565217391305,0.7520337301587301,0.45181746035814285,0.8722365776629342,0.7721518987341772,0.7037037037037037,0.8294573643410853,0.8194085558426644,0.001
38
+ 2024,frozen,4,0.553470607314791,0.7997123015873016,0.717391304347826,0.7562053571428571,0.43246644735336304,0.8798449612403101,0.7805907172995781,0.7314814814814815,0.8217054263565892,0.8282192075796727,0.001
39
+ 2024,frozen,5,0.5317457369395665,0.8138822751322751,0.7318840579710145,0.772417328042328,0.41545065492391586,0.8885300028710881,0.7890295358649789,0.75,0.8217054263565892,0.8371913580246914,0.001
40
+ 2024,frozen,6,0.5310433038643428,0.8142757936507937,0.7273550724637681,0.7686259920634921,0.41416966915130615,0.8903962101636521,0.7848101265822784,0.7222222222222222,0.8372093023255814,0.835055986218777,0.001
41
+ 2024,frozen,7,0.535319973741259,0.810681216931217,0.7336956521739131,0.7706183862433862,0.41076211631298065,0.88946310651737,0.810126582278481,0.7870370370370371,0.8294573643410853,0.8488551536032156,0.001
42
+ 2024,frozen,8,0.5011444449424743,0.8314996693121693,0.7472826086956522,0.7886268187830687,0.4013783372938633,0.8952770600057421,0.7932489451476793,0.7407407407407407,0.8372093023255814,0.8421260407694516,0.001
43
+ 2024,frozen,9,0.5317847013473511,0.8112169312169313,0.7336956521739131,0.7711243386243387,0.38976578786969185,0.9020241171403963,0.8227848101265823,0.8148148148148148,0.8294573643410853,0.8620801033591732,0.001
44
+ 2024,frozen,10,0.5140501362936837,0.8266732804232804,0.7336956521739131,0.7782175925925926,0.3864467516541481,0.9039621016365202,0.8143459915611815,0.7962962962962963,0.8294573643410853,0.8584194659776055,0.001
45
+ 2024,finetune,1,0.4775453839983259,0.853396164021164,0.7608695652173914,0.8070949074074074,0.3899293653666973,0.9338931955211025,0.8354430379746836,0.9444444444444444,0.7441860465116279,0.8891042204995693,0.0009891830623632338
46
+ 2024,finetune,2,0.38173683881759646,0.9084457671957672,0.8251811594202898,0.8663458994708995,0.3108560163527727,0.9433677863910422,0.890295358649789,0.8888888888888888,0.8914728682170543,0.9167743324720068,0.0009572050015330873
47
+ 2024,finetune,3,0.36059877617018565,0.918994708994709,0.8460144927536232,0.8829497354497355,0.30220645666122437,0.94717197817973,0.890295358649789,0.8981481481481481,0.8837209302325582,0.9190532586850415,0.0009054634122155989
48
+ 2024,finetune,4,0.3435957955462592,0.9252678571428571,0.8505434782608695,0.8876934523809523,0.32907535694539547,0.9417886879127189,0.8734177215189873,0.9351851851851852,0.8217054263565892,0.910116996841803,0.0008362196501476348
49
+ 2024,finetune,5,0.3113635071686336,0.9391468253968255,0.8731884057971014,0.9062400793650794,0.35001079365611076,0.9295147861039335,0.8523206751054853,0.8240740740740741,0.875968992248062,0.8897681596325007,0.0007524999999999999
50
+ 2024,finetune,6,0.3115046965224402,0.9414550264550263,0.8605072463768116,0.9017989417989417,0.3371146433055401,0.9364771748492678,0.8312236286919831,0.7870370370370371,0.8682170542635659,0.8820521102497847,0.0006579634122155989
51
+ 2024,finetune,7,0.267705562710762,0.9570701058201058,0.8786231884057971,0.9179398148148148,0.3176867663860321,0.9399942578237152,0.8649789029535865,0.8425925925925926,0.8837209302325582,0.9015755096181453,0.0005567415893174885
model_training/notebooks/results/val_metrics_final.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "threshold": 0.504486,
3
+ "auc": 0.9499712891185758,
4
+ "sensitivity": 0.9166666666666666,
5
+ "specificity": 0.8837209302325582,
6
+ "accuracy": 0.8987341772151899,
7
+ "youden": 0.8003875968992249,
8
+ "composite": 0.9250825437840942,
9
+ "precision": 0.868421052631579,
10
+ "tp": 99,
11
+ "tn": 114,
12
+ "fp": 15,
13
+ "fn": 9
14
+ }
model_training/notebooks/results/val_predictions.csv ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ filename,y_true,prob,pred,correct
2
+ IMG_000833.png,0,0.08514188,0,1
3
+ IMG_000868.png,0,0.07656998,0,1
4
+ IMG_000471.png,1,0.926362,1,1
5
+ IMG_001514.png,0,0.042604674,0,1
6
+ IMG_001123.png,0,0.010878541,0,1
7
+ IMG_000311.png,0,0.0019978306,0,1
8
+ IMG_000768.png,0,0.1100046,0,1
9
+ IMG_001354.png,0,0.38976616,0,1
10
+ IMG_000677.png,1,0.9424959,1,1
11
+ IMG_001694.png,1,0.82901657,1,1
12
+ IMG_001249.png,1,0.9801762,1,1
13
+ IMG_001543.png,1,0.5918354,1,1
14
+ IMG_001138.png,1,0.9144192,1,1
15
+ IMG_001279.png,0,0.0023716255,0,1
16
+ IMG_001425.png,1,0.9752167,1,1
17
+ IMG_001664.png,0,0.0030602103,0,1
18
+ IMG_000517.png,0,0.26428953,0,1
19
+ IMG_000238.png,0,0.07020499,0,1
20
+ IMG_001419.png,1,0.9230154,1,1
21
+ IMG_000357.png,0,0.023781383,0,1
22
+ IMG_001413.png,0,0.53677505,1,0
23
+ IMG_001438.png,0,0.003928518,0,1
24
+ IMG_000730.png,1,0.9570253,1,1
25
+ IMG_000043.png,1,0.9750734,1,1
26
+ IMG_000034.png,0,0.006953893,0,1
27
+ IMG_000657.png,0,0.07155775,0,1
28
+ IMG_001544.png,0,0.90743506,1,0
29
+ IMG_000216.png,0,0.029633842,0,1
30
+ IMG_001193.png,0,0.015726507,0,1
31
+ IMG_001046.png,0,0.039458692,0,1
32
+ IMG_001555.png,0,0.3316432,0,1
33
+ IMG_001622.png,1,0.9557772,1,1
34
+ IMG_000170.png,1,0.8454006,1,1
35
+ IMG_000584.png,1,0.7091627,1,1
36
+ IMG_000750.png,0,0.6704473,1,0
37
+ IMG_000168.png,1,0.90156275,1,1
38
+ IMG_000230.png,0,0.66135454,1,0
39
+ IMG_000534.png,1,0.9762868,1,1
40
+ IMG_001523.png,0,0.5044856,0,1
41
+ IMG_000926.png,1,0.96041846,1,1
42
+ IMG_000776.png,0,0.22449876,0,1
43
+ IMG_000842.png,1,0.9752196,1,1
44
+ IMG_001338.png,1,0.29997367,0,0
45
+ IMG_000126.png,1,0.8138403,1,1
46
+ IMG_000377.png,1,0.8908096,1,1
47
+ IMG_001713.png,1,0.92572457,1,1
48
+ IMG_000996.png,0,0.3828967,0,1
49
+ IMG_000278.png,0,0.009509609,0,1
50
+ IMG_001392.png,1,0.96734214,1,1
51
+ IMG_000666.png,0,0.005913499,0,1
52
+ IMG_001012.png,0,0.0072901817,0,1
53
+ IMG_000515.png,0,0.18308273,0,1
54
+ IMG_000177.png,1,0.0014123927,0,0
55
+ IMG_000099.png,1,0.71664655,1,1
56
+ IMG_001485.png,0,0.07025461,0,1
57
+ IMG_000599.png,1,0.7633101,1,1
58
+ IMG_001377.png,0,0.00014443095,0,1
59
+ IMG_001490.png,1,0.9397476,1,1
60
+ IMG_000612.png,0,0.04453209,0,1
61
+ IMG_001293.png,1,0.92839843,1,1
62
+ IMG_001679.png,1,0.75982463,1,1
63
+ IMG_000783.png,0,0.102033496,0,1
64
+ IMG_001396.png,0,0.06896914,0,1
65
+ IMG_001047.png,0,0.0027958185,0,1
66
+ IMG_000090.png,0,0.3977522,0,1
67
+ IMG_000601.png,1,0.37769222,0,0
68
+ IMG_001475.png,1,0.89420277,1,1
69
+ IMG_000494.png,1,0.63474864,1,1
70
+ IMG_001216.png,0,0.025922079,0,1
71
+ IMG_001564.png,1,0.97837263,1,1
72
+ IMG_000928.png,0,0.057339218,0,1
73
+ IMG_001161.png,0,0.016286755,0,1
74
+ IMG_000024.png,0,0.8985922,1,0
75
+ IMG_000319.png,0,0.648694,1,0
76
+ IMG_001416.png,0,0.04064005,0,1
77
+ IMG_000478.png,1,0.9229935,1,1
78
+ IMG_001122.png,0,0.049890943,0,1
79
+ IMG_001740.png,1,0.93487585,1,1
80
+ IMG_001500.png,0,0.09520903,0,1
81
+ IMG_000140.png,1,0.98620355,1,1
82
+ IMG_001289.png,1,0.9263547,1,1
83
+ IMG_000937.png,1,0.60724884,1,1
84
+ IMG_000184.png,1,0.56130826,1,1
85
+ IMG_001060.png,0,0.0008340298,0,1
86
+ IMG_001379.png,1,0.72570825,1,1
87
+ IMG_000051.png,1,0.9593886,1,1
88
+ IMG_000684.png,0,0.031647827,0,1
89
+ IMG_000173.png,1,0.54984164,1,1
90
+ IMG_001529.png,0,0.37173158,0,1
91
+ IMG_001540.png,0,0.0022377113,0,1
92
+ IMG_000469.png,0,0.011063429,0,1
93
+ IMG_000452.png,0,0.5156396,1,0
94
+ IMG_000732.png,0,0.40668148,0,1
95
+ IMG_000546.png,0,0.8375038,1,0
96
+ IMG_000506.png,1,0.9919884,1,1
97
+ IMG_001633.png,0,0.33360916,0,1
98
+ IMG_000082.png,0,0.371273,0,1
99
+ IMG_000206.png,0,0.16295624,0,1
100
+ IMG_000975.png,0,0.00403203,0,1
101
+ IMG_000426.png,1,0.98051214,1,1
102
+ IMG_000441.png,0,0.07469518,0,1
103
+ IMG_000117.png,0,0.031310383,0,1
104
+ IMG_000973.png,0,0.00051067735,0,1
105
+ IMG_000881.png,1,0.8859954,1,1
106
+ IMG_000247.png,1,0.94318277,1,1
107
+ IMG_000574.png,0,0.014506694,0,1
108
+ IMG_000381.png,1,0.9848424,1,1
109
+ IMG_001618.png,0,0.012155103,0,1
110
+ IMG_001053.png,1,0.24243784,0,0
111
+ IMG_001400.png,0,0.25488326,0,1
112
+ IMG_000590.png,1,0.95325905,1,1
113
+ IMG_000342.png,1,0.8786827,1,1
114
+ IMG_001213.png,0,0.0063149747,0,1
115
+ IMG_000075.png,0,0.18751769,0,1
116
+ IMG_000608.png,0,0.09344582,0,1
117
+ IMG_000493.png,1,0.74505854,1,1
118
+ IMG_000630.png,0,0.37733248,0,1
119
+ IMG_001676.png,1,0.50721115,1,1
120
+ IMG_001102.png,1,0.90345204,1,1
121
+ IMG_000095.png,0,0.0039746515,0,1
122
+ IMG_001303.png,0,0.02584564,0,1
123
+ IMG_001266.png,0,0.0027153727,0,1
124
+ IMG_000755.png,1,0.5398157,1,1
125
+ IMG_001295.png,1,0.89617956,1,1
126
+ IMG_000832.png,1,0.985641,1,1
127
+ IMG_001705.png,0,0.017828463,0,1
128
+ IMG_001306.png,0,0.0074857906,0,1
129
+ IMG_000246.png,0,0.33929592,0,1
130
+ IMG_001530.png,1,0.72100955,1,1
131
+ IMG_000970.png,0,0.32502642,0,1
132
+ IMG_000059.png,1,0.9853226,1,1
133
+ IMG_001149.png,1,0.92740005,1,1
134
+ IMG_001200.png,1,0.95928955,1,1
135
+ IMG_001055.png,1,0.76670074,1,1
136
+ IMG_000995.png,0,0.8288552,1,0
137
+ IMG_000038.png,0,0.20308095,0,1
138
+ IMG_000524.png,0,0.005141989,0,1
139
+ IMG_001568.png,0,0.14497586,0,1
140
+ IMG_000122.png,0,0.051262774,0,1
141
+ IMG_000333.png,1,0.8129171,1,1
142
+ IMG_001109.png,0,0.18362695,0,1
143
+ IMG_000698.png,0,0.01703381,0,1
144
+ IMG_001573.png,1,0.9943408,1,1
145
+ IMG_001625.png,1,0.91411126,1,1
146
+ IMG_000509.png,0,0.0028896236,0,1
147
+ IMG_001448.png,1,0.9428977,1,1
148
+ IMG_000331.png,1,0.9803683,1,1
149
+ IMG_000008.png,0,0.038054988,0,1
150
+ IMG_000521.png,0,0.0010172316,0,1
151
+ IMG_000012.png,1,0.9812463,1,1
152
+ IMG_000780.png,1,0.6871019,1,1
153
+ IMG_000428.png,1,0.49723867,0,0
154
+ IMG_000463.png,0,0.016233142,0,1
155
+ IMG_001423.png,1,0.8972441,1,1
156
+ IMG_000287.png,1,0.61366016,1,1
157
+ IMG_001173.png,1,0.17359304,0,0
158
+ IMG_001357.png,1,0.6145142,1,1
159
+ IMG_001329.png,1,0.9680075,1,1
160
+ IMG_001119.png,0,0.013847093,0,1
161
+ IMG_001675.png,0,0.01776683,0,1
162
+ IMG_000877.png,0,0.9286344,1,0
163
+ IMG_000402.png,0,0.25376308,0,1
164
+ IMG_001176.png,1,0.9541638,1,1
165
+ IMG_000068.png,0,0.0027813588,0,1
166
+ IMG_000091.png,0,0.06692752,0,1
167
+ IMG_000600.png,0,0.68098706,1,0
168
+ IMG_000765.png,1,0.9324583,1,1
169
+ IMG_000965.png,0,0.0155419195,0,1
170
+ IMG_000361.png,1,0.4985782,0,0
171
+ IMG_000414.png,1,0.95493746,1,1
172
+ IMG_000969.png,0,0.06441665,0,1
173
+ IMG_001087.png,0,0.0050179944,0,1
174
+ IMG_000124.png,1,0.6126341,1,1
175
+ IMG_000013.png,0,0.14956698,0,1
176
+ IMG_001140.png,1,0.98292345,1,1
177
+ IMG_001088.png,1,0.9412705,1,1
178
+ IMG_001655.png,0,0.10415509,0,1
179
+ IMG_000720.png,1,0.5776679,1,1
180
+ IMG_001687.png,0,0.054884706,0,1
181
+ IMG_000654.png,1,0.80313796,1,1
182
+ IMG_000200.png,0,0.016887484,0,1
183
+ IMG_001131.png,0,0.59535867,1,0
184
+ IMG_000236.png,1,0.9608375,1,1
185
+ IMG_001004.png,1,0.91360044,1,1
186
+ IMG_000869.png,1,0.9747172,1,1
187
+ IMG_001630.png,0,0.0015267566,0,1
188
+ IMG_000374.png,1,0.32337222,0,0
189
+ IMG_001203.png,0,0.00011447017,0,1
190
+ IMG_000853.png,0,0.06780861,0,1
191
+ IMG_000638.png,0,0.02331572,0,1
192
+ IMG_000620.png,1,0.9309563,1,1
193
+ IMG_000167.png,1,0.9887361,1,1
194
+ IMG_000811.png,0,0.26473084,0,1
195
+ IMG_001061.png,1,0.905426,1,1
196
+ IMG_000456.png,0,0.00062740117,0,1
197
+ IMG_000438.png,1,0.7869287,1,1
198
+ IMG_000292.png,1,0.9418937,1,1
199
+ IMG_000644.png,0,0.5747625,1,0
200
+ IMG_000491.png,0,0.014190746,0,1
201
+ IMG_000564.png,0,0.0036469798,0,1
202
+ IMG_000738.png,0,0.8187067,1,0
203
+ IMG_000886.png,1,0.7505375,1,1
204
+ IMG_000387.png,0,0.24593168,0,1
205
+ IMG_000492.png,0,0.33065897,0,1
206
+ IMG_001057.png,1,0.9879611,1,1
207
+ IMG_001195.png,0,0.0053326027,0,1
208
+ IMG_000088.png,1,0.91581327,1,1
209
+ IMG_000014.png,0,0.08517947,0,1
210
+ IMG_000830.png,0,0.02413497,0,1
211
+ IMG_000779.png,0,0.021322682,0,1
212
+ IMG_001510.png,0,0.06587305,0,1
213
+ IMG_000556.png,0,0.019615667,0,1
214
+ IMG_000744.png,1,0.8163302,1,1
215
+ IMG_001059.png,0,0.012555552,0,1
216
+ IMG_001013.png,1,0.12594076,0,0
217
+ IMG_000901.png,1,0.9905481,1,1
218
+ IMG_000983.png,0,0.13618241,0,1
219
+ IMG_001359.png,1,0.90301836,1,1
220
+ IMG_001604.png,0,0.0092756525,0,1
221
+ IMG_001050.png,0,0.41582817,0,1
222
+ IMG_000662.png,0,0.0009655981,0,1
223
+ IMG_000225.png,0,0.059564617,0,1
224
+ IMG_001261.png,0,0.00012916393,0,1
225
+ IMG_001401.png,1,0.8134201,1,1
226
+ IMG_001136.png,0,0.03423228,0,1
227
+ IMG_000031.png,1,0.9828577,1,1
228
+ IMG_001547.png,0,0.028171018,0,1
229
+ IMG_000023.png,1,0.96932787,1,1
230
+ IMG_000261.png,1,0.89618427,1,1
231
+ IMG_001125.png,1,0.9938845,1,1
232
+ IMG_000210.png,1,0.9844753,1,1
233
+ IMG_000305.png,0,0.96544015,1,0
234
+ IMG_001426.png,0,0.23876508,0,1
235
+ IMG_001204.png,1,0.9547577,1,1
236
+ IMG_000081.png,1,0.88376915,1,1
237
+ IMG_001683.png,1,0.7803761,1,1
238
+ IMG_000069.png,1,0.9368211,1,1
model_training/src/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # src package
model_training/src/config.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from dataclasses import dataclass, field
5
+ from typing import List
6
+
7
+ import torch
8
+
9
+
10
+ @dataclass
11
+ class Config:
12
+ # ── Data paths ────────────────────────────────────────────────────────
13
+ csv_path: str = "../../data/train_val.csv"
14
+ image_dir: str = "../../data/images"
15
+ submission_test_dir: str = "../../data/test_images"
16
+ output_dir: str = "results"
17
+ results_log_path: str = "results_log.csv" # global run log (one row per training run)
18
+
19
+ # ── Reproducibility ──────────────────────────────────────────────────
20
+ seed: int = 42
21
+
22
+ # ── Image / DataLoader ───────────────────────────────────────────────
23
+ img_size: int = 224 # torchxrayvision DenseNet-121 native resolution
24
+ batch_size: int = 32
25
+ num_workers: int = 4
26
+
27
+ # ── Train / val / test split ─────────────────────────────────────────
28
+ val_size: float = 0.15
29
+ test_size: float = 0.15
30
+
31
+ # ── Training schedule (two-stage) ────────────────────────────────────
32
+ frozen_epochs: int = 3 # stage 1: head-only warmup
33
+ finetune_epochs: int = 22 # stage 2: full unfreeze with cosine LR
34
+ early_stop_patience: int = 6 # early stop on val AUC during stage 2
35
+
36
+ # ── Optimiser ────────────────────────────────────────────────────────
37
+ head_lr: float = 3e-4 # classifier LR (both stages)
38
+ backbone_lr: float = 1e-4 # features LR (stage 2 only)
39
+ weight_decay: float = 1e-4
40
+ grad_clip: float = 1.0
41
+
42
+ # ── Data augmentation ────────────────────────────────────────────────
43
+ # Mixup: interpolates two samples and their labels in every training batch.
44
+ # mixup_alpha > 0 enables it; Ξ» ~ Beta(Ξ±, Ξ±). 0 = disabled.
45
+ # Typical range: 0.2 – 0.4.
46
+ mixup_alpha: float = 0.0
47
+ # Label smoothing: prevents overconfidence by softening hard {0,1} targets.
48
+ # y_smooth = y*(1-Ξ΅) + 0.5*Ξ΅. 0 = disabled. Typical range: 0.05 – 0.15.
49
+ label_smoothing: float = 0.0
50
+
51
+ # ── Architecture ─────────────────────────────────────────────────────
52
+ # Options: "densenet121" | "rad-dino" | "mobilenet_v3_large" | "efficientnet_b0" | "efficientnet_b3"
53
+ # densenet121 β€” torchxrayvision DenseNet-121, pretrained on ~1M chest X-rays (recommended)
54
+ # rad-dino β€” microsoft/rad-dino, DINOv2 ViT-B/14 pretrained on ~1M chest X-rays;
55
+ # use img_size=518 (native: 37Γ—37 patches at 14 px); 12 frozen_blocks max
56
+ # mobilenet_v3_large β€” torchvision MobileNetV3-Large, pretrained on ImageNet (faster, lighter)
57
+ # efficientnet_b0 β€” torchvision EfficientNet-B0, pretrained on ImageNet (good accuracy/size trade-off)
58
+ # efficientnet_b3 β€” torchvision EfficientNet-B3, pretrained on ImageNet (higher accuracy, more params)
59
+ backbone: str = "efficientnet_b0"
60
+
61
+ # ── Ensemble ─────────────────────────────────────────────────────────
62
+ # True: train one model per entry in `seeds` and average predictions
63
+ # False: train a single model using only `seed` (faster experimentation)
64
+ use_ensemble: bool = True
65
+
66
+ # ── Multi-seed ensemble ──────────────────────────────────────────────
67
+ seeds: List[int] = field(default_factory=lambda: [42, 7, 2024])
68
+
69
+ # ── Loss function ─────────────────────────────────────────────────────
70
+ # False: standard BCE | True: 0.5*BCE + 0.5*(1 - soft_composite)
71
+ use_composite_loss: bool = False
72
+ # Blend weight Ξ±: Ξ±Β·BCE + (1-Ξ±)Β·(1-soft_composite). 0 = pure composite, 1 = pure BCE.
73
+ composite_loss_alpha: float = 0.5
74
+ # Temperature for the pairwise-sigmoid soft-AUC term (higher β†’ sharper ranking signal)
75
+ composite_loss_gamma: float = 1.0
76
+
77
+ # ── Inference ────────────────────────────────────────────────────────
78
+ tta_passes: int = 6 # number of deterministic TTA transforms (max 6)
79
+ n_bootstrap: int = 1000 # bootstrap iterations for threshold stabilisation
80
+
81
+ # ── Device (auto-detected) ───────────────────────────────────────────
82
+ device: str = (
83
+ "cuda" if torch.cuda.is_available()
84
+ else "mps" if torch.backends.mps.is_available()
85
+ else "cpu"
86
+ )
87
+
88
+ def setup(self) -> "Config":
89
+ """Create output directory and return self (for chaining)."""
90
+ os.makedirs(self.output_dir, exist_ok=True)
91
+ return self
92
+
93
+
94
+ # Global singleton β€” import and use directly, or override fields before training
95
+ CFG = Config().setup()
model_training/src/data.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from typing import Tuple
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ from sklearn.model_selection import train_test_split
9
+
10
+ from src.config import CFG
11
+
12
+
13
+ # ---------------------------------------------------------------------------
14
+ # Column auto-detection
15
+ # ---------------------------------------------------------------------------
16
+ FILENAME_CANDIDATES = [
17
+ "image_name", "filename", "file", "image", "image_id", "img", "name",
18
+ "image index", "image_index", # NIH ChestX-ray14
19
+ ]
20
+ LABEL_CANDIDATES = [
21
+ "label", "cardiomegaly", "class", "target", "y",
22
+ "finding_labels", "finding labels", "finding", # NIH ChestX-ray14
23
+ "labels",
24
+ ]
25
+ POSITIVE_KEYWORD = "cardiomegaly"
26
+
27
+
28
+ def _autodetect(df: pd.DataFrame, candidates: list[str]) -> str:
29
+ """Return the first column in *df* whose lowercase name is in *candidates*."""
30
+ lower = {c.lower(): c for c in df.columns}
31
+ for cand in candidates:
32
+ if cand in lower:
33
+ return lower[cand]
34
+ raise ValueError(f"None of {candidates} found in columns: {list(df.columns)}")
35
+
36
+
37
+ def _coerce_to_binary(series: pd.Series) -> pd.Series:
38
+ """Map mixed label encodings (0/1, 'cardiomegaly', 'no finding', bool, ...) to 0/1."""
39
+ def to_int(v):
40
+ if pd.isna(v):
41
+ return 0
42
+ if isinstance(v, (int, np.integer)):
43
+ return int(v != 0)
44
+ if isinstance(v, (float, np.floating)):
45
+ return int(v != 0)
46
+ if isinstance(v, bool):
47
+ return int(v)
48
+ s = str(v).strip().lower()
49
+ if s in {"1", "true", "yes", "y", "positive", "pos"}:
50
+ return 1
51
+ if s in {"0", "false", "no", "n", "negative", "neg", "no finding", ""}:
52
+ return 0
53
+ return int(POSITIVE_KEYWORD in s)
54
+ return series.apply(to_int).astype(int)
55
+
56
+
57
+ def _resolve_filenames(df: pd.DataFrame, filename_col: str, image_dir: str) -> pd.DataFrame:
58
+ """Add an `image_path` column. Drops rows whose file cannot be found.
59
+
60
+ Tolerates different case, trailing spaces, and missing/wrong extensions.
61
+ """
62
+ disk: dict[str, str] = {}
63
+ for entry in os.scandir(image_dir):
64
+ if not entry.is_file():
65
+ continue
66
+ name = entry.name
67
+ disk[name.lower()] = name
68
+ stem = os.path.splitext(name)[0].lower()
69
+ disk.setdefault(stem, name)
70
+
71
+ resolved, missing = [], []
72
+ for fn in df[filename_col].astype(str):
73
+ raw = fn.strip()
74
+ raw_l = raw.lower()
75
+ hit = disk.get(raw_l) or disk.get(os.path.splitext(raw_l)[0])
76
+ if hit is None:
77
+ for ext in (".png", ".jpg", ".jpeg"):
78
+ if raw_l + ext in disk:
79
+ hit = disk[raw_l + ext]
80
+ break
81
+ if hit is None:
82
+ missing.append(raw)
83
+ resolved.append(None)
84
+ else:
85
+ resolved.append(os.path.join(image_dir, hit))
86
+
87
+ df = df.copy()
88
+ df["image_path"] = resolved
89
+ keep = df["image_path"].notna()
90
+ if (~keep).any():
91
+ print(f"Warning: {(~keep).sum()} rows dropped (file not found). Examples: {missing[:5]}")
92
+ return df[keep].reset_index(drop=True)
93
+
94
+
95
+ # ---------------------------------------------------------------------------
96
+ # Public API
97
+ # ---------------------------------------------------------------------------
98
+ def load_labels(csv_path: str, image_dir: str) -> pd.DataFrame:
99
+ """Read CSV, auto-detect filename + label columns, coerce labels, resolve paths.
100
+
101
+ Returned DataFrame columns: filename, label, image_path
102
+ """
103
+ df = pd.read_csv(csv_path)
104
+ fn_col = _autodetect(df, FILENAME_CANDIDATES)
105
+ lb_col = _autodetect(df, LABEL_CANDIDATES)
106
+ print(f"Detected filename column: {fn_col!r} label column: {lb_col!r}")
107
+
108
+ df = df[[fn_col, lb_col]].rename(columns={fn_col: "filename", lb_col: "label"})
109
+ df["label"] = _coerce_to_binary(df["label"])
110
+ df = _resolve_filenames(df, "filename", image_dir)
111
+ df = df.drop_duplicates(subset=["filename"]).reset_index(drop=True)
112
+
113
+ if len(df) == 0:
114
+ raise ValueError("No valid labelled images found.")
115
+
116
+ n_pos = int(df["label"].sum())
117
+ n_neg = int((df["label"] == 0).sum())
118
+ print(f"Loaded {len(df)} labelled images pos={n_pos} neg={n_neg}")
119
+ return df
120
+
121
+
122
+ def split_dataframe(
123
+ df: pd.DataFrame,
124
+ val_size: float | None = None,
125
+ test_size: float | None = None,
126
+ seed: int | None = None,
127
+ ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
128
+ """Stratified train / val / test split.
129
+
130
+ Falls back to CFG values when parameters are not supplied.
131
+ """
132
+ val_size = val_size if val_size is not None else CFG.val_size
133
+ test_size = test_size if test_size is not None else CFG.test_size
134
+ seed = seed if seed is not None else CFG.seed
135
+
136
+ train_tmp_df, test_df = train_test_split(
137
+ df, test_size=test_size, stratify=df["label"], random_state=seed,
138
+ )
139
+ rel_val = val_size / (1.0 - test_size)
140
+ train_df, val_df = train_test_split(
141
+ train_tmp_df, test_size=rel_val,
142
+ stratify=train_tmp_df["label"], random_state=seed,
143
+ )
144
+ return (
145
+ train_df.reset_index(drop=True),
146
+ val_df.reset_index(drop=True),
147
+ test_df.reset_index(drop=True),
148
+ )
model_training/src/dataset.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from typing import Optional
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ from PIL import Image
10
+ from torch.utils.data import Dataset
11
+ import torchvision.transforms as T
12
+
13
+ import torchxrayvision as xrv
14
+
15
+
16
+ # ---------------------------------------------------------------------------
17
+ # Normalisation functions (one per backbone family)
18
+ # ---------------------------------------------------------------------------
19
+ def xrv_normalize_np(pil_img: Image.Image) -> torch.Tensor:
20
+ """PIL grayscale β†’ (1, H, W) float tensor in [-1024, 1024] (torchxrayvision)."""
21
+ arr = np.array(pil_img, dtype=np.float32) # (H, W) in [0, 255]
22
+ arr = xrv.datasets.normalize(arr, 255) # β†’ [-1024, 1024]
23
+ arr = arr[None, ...] # (1, H, W)
24
+ return torch.from_numpy(arr).float()
25
+
26
+
27
+ _IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
28
+ _IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)
29
+
30
+
31
+ def imagenet_normalize_np(pil_img: Image.Image) -> torch.Tensor:
32
+ """PIL grayscale β†’ (3, H, W) float tensor normalized with ImageNet stats.
33
+
34
+ The single grayscale channel is replicated to 3 channels so that ImageNet-
35
+ pretrained backbones (MobileNet, EfficientNet) receive the expected input shape.
36
+ """
37
+ arr = np.array(pil_img, dtype=np.float32) / 255.0 # [0, 1]
38
+ arr = np.stack([arr, arr, arr], axis=0) # (3, H, W)
39
+ arr = (arr - _IMAGENET_MEAN) / _IMAGENET_STD
40
+ return torch.from_numpy(arr).float()
41
+
42
+
43
+ def get_normalize_fn(backbone: str):
44
+ """Return the correct normalization callable for the given backbone name.
45
+
46
+ "densenet121" / "densenet121-res224-all"
47
+ β†’ xrv_normalize_np (grayscale, [-1024, 1024])
48
+ "rad-dino"
49
+ β†’ imagenet_normalize_np (3-ch RGB replicated, ImageNet stats)
50
+ RAD-DINO is a ViT-B/14; feed at 518Γ—518 for best accuracy.
51
+ all other torchvision backbones
52
+ β†’ imagenet_normalize_np
53
+ """
54
+ if backbone in ("densenet121", "densenet121-res224-all"):
55
+ return xrv_normalize_np
56
+ return imagenet_normalize_np
57
+
58
+
59
+ # ---------------------------------------------------------------------------
60
+ # Labelled dataset (train / val / test)
61
+ # ---------------------------------------------------------------------------
62
+ class ChestXrayDataset(Dataset):
63
+ """Returns (image_tensor, label, filename) triples.
64
+
65
+ backbone controls the normalization applied after PIL transforms:
66
+ "densenet121" β†’ single-channel tensor in [-1024, 1024] (xrv)
67
+ any torchvision model β†’ 3-channel tensor with ImageNet normalization
68
+ """
69
+
70
+ def __init__(
71
+ self,
72
+ df: pd.DataFrame,
73
+ pil_transform=None,
74
+ use_erasing: bool = False,
75
+ backbone: str | None = None,
76
+ ) -> None:
77
+ from src.config import CFG
78
+ self.df = df.reset_index(drop=True)
79
+ self.pil_transform = pil_transform
80
+ self.use_erasing = use_erasing
81
+ self._normalize = get_normalize_fn(backbone or CFG.backbone)
82
+ self._erasing = T.RandomErasing(
83
+ p=0.5, scale=(0.02, 0.08), ratio=(0.3, 3.3), value=0
84
+ )
85
+
86
+ def __len__(self) -> int:
87
+ return len(self.df)
88
+
89
+ def __getitem__(self, idx: int):
90
+ row = self.df.iloc[idx]
91
+ img = Image.open(row["image_path"]).convert("L")
92
+ if self.pil_transform is not None:
93
+ img = self.pil_transform(img)
94
+ normalize = getattr(self, "_normalize", xrv_normalize_np)
95
+ tensor = normalize(img)
96
+ if self.use_erasing:
97
+ tensor = self._erasing(tensor)
98
+ label = torch.tensor(float(row["label"]), dtype=torch.float32)
99
+ return tensor, label, row["filename"]
100
+
101
+
102
+ # ---------------------------------------------------------------------------
103
+ # TTA dataset β€” flexible source (DataFrame with image_path, or external dir)
104
+ # ---------------------------------------------------------------------------
105
+ class TTADataset(Dataset):
106
+ """Used by inference passes (one TTA transform per pass)."""
107
+
108
+ def __init__(
109
+ self,
110
+ df: pd.DataFrame,
111
+ pil_transform,
112
+ image_dir: Optional[str] = None,
113
+ backbone: str | None = None,
114
+ ) -> None:
115
+ from src.config import CFG
116
+ self.df = df.reset_index(drop=True)
117
+ self.pil_transform = pil_transform
118
+ self.image_dir = image_dir
119
+ self._normalize = get_normalize_fn(backbone or CFG.backbone)
120
+
121
+ def __len__(self) -> int:
122
+ return len(self.df)
123
+
124
+ def __getitem__(self, idx: int):
125
+ row = self.df.iloc[idx]
126
+ if "image_path" in row and pd.notna(row.get("image_path")):
127
+ path = row["image_path"]
128
+ else:
129
+ path = os.path.join(self.image_dir, row["filename"])
130
+ img = Image.open(path).convert("L")
131
+ img = self.pil_transform(img)
132
+ normalize = getattr(self, "_normalize", xrv_normalize_np)
133
+ tensor = normalize(img)
134
+ label = float(row["label"]) if "label" in row and not pd.isna(row.get("label", np.nan)) else 0.0
135
+ name = row["filename"] if "filename" in row else os.path.basename(path)
136
+ return tensor, torch.tensor(label, dtype=torch.float32), name
137
+
138
+
139
+ # ---------------------------------------------------------------------------
140
+ # Submission dataset (unlabelled images in a flat directory)
141
+ # ---------------------------------------------------------------------------
142
+ class SubmissionDataset(Dataset):
143
+ """Unlabelled test images for final inference.
144
+
145
+ Returns (image_tensor, filename).
146
+ """
147
+
148
+ def __init__(
149
+ self,
150
+ image_dir: str,
151
+ pil_transform=None,
152
+ backbone: str | None = None,
153
+ ) -> None:
154
+ from src.config import CFG
155
+ self.image_dir = image_dir
156
+ self.pil_transform = pil_transform
157
+ self._normalize = get_normalize_fn(backbone or CFG.backbone)
158
+ self.image_files = sorted(
159
+ f for f in os.listdir(image_dir)
160
+ if f.lower().endswith((".png", ".jpg", ".jpeg"))
161
+ )
162
+
163
+ def __len__(self) -> int:
164
+ return len(self.image_files)
165
+
166
+ def __getitem__(self, idx: int):
167
+ fname = self.image_files[idx]
168
+ img = Image.open(os.path.join(self.image_dir, fname)).convert("L")
169
+ if self.pil_transform is not None:
170
+ img = self.pil_transform(img)
171
+ normalize = getattr(self, "_normalize", xrv_normalize_np)
172
+ tensor = normalize(img)
173
+ return tensor, fname
model_training/src/model.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import List
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torchxrayvision as xrv
8
+
9
+
10
+ # ---------------------------------------------------------------------------
11
+ # RAD-DINO wrapper
12
+ # ---------------------------------------------------------------------------
13
+ class RadDinoWrapper(nn.Module):
14
+ """microsoft/rad-dino β€” DINOv2 ViT-B/14 pretrained on ~1 M chest X-rays.
15
+
16
+ Wraps the HuggingFace model to expose the same ``.features`` / ``.classifier``
17
+ contract used by every other backbone, so freeze helpers and the two-stage
18
+ optimiser work without modification.
19
+
20
+ Architecture
21
+ ────────────
22
+ .features β€” the full Dinov2Model (embeddings + 12 transformer blocks + layernorm)
23
+ .classifier β€” nn.Linear(hidden_size=768, out_features=1)
24
+
25
+ Forward pass
26
+ ────────────
27
+ x : (B, 3, H, W) float tensor β€” ImageNet-normalised, any multiple of 14 px.
28
+ Recommended resolution: 518 Γ— 518 (native: 37 Γ— 37 patches at 14 px).
29
+ Returns (B,) logit tensor.
30
+
31
+ Freeze / unfreeze
32
+ ─────────────────
33
+ freeze_backbone() β†’ freezes .features (all 12 blocks + embeddings)
34
+ partial_unfreeze(N) β†’ unfreeze last (12 βˆ’ N) blocks + layernorm;
35
+ keep embeddings + first N blocks frozen.
36
+ """
37
+
38
+ def __init__(self) -> None:
39
+ super().__init__()
40
+ from transformers import AutoModel # lazy β€” only loaded when this backbone is used
41
+ dinov2 = AutoModel.from_pretrained("microsoft/rad-dino")
42
+ self.features = dinov2
43
+ self.classifier = nn.Linear(dinov2.config.hidden_size, 1)
44
+ nn.init.trunc_normal_(self.classifier.weight, std=0.02)
45
+ nn.init.zeros_(self.classifier.bias)
46
+
47
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
48
+ out = self.features(pixel_values=x) # Dinov2ModelOutput
49
+ cls = out.last_hidden_state[:, 0] # CLS token (B, 768)
50
+ return self.classifier(cls) # (B, 1)
51
+
52
+
53
+ # ---------------------------------------------------------------------------
54
+ # Backbone factory
55
+ # ---------------------------------------------------------------------------
56
+ def build_model(backbone: str | None = None) -> nn.Module:
57
+ """Build a backbone model for Cardiomegaly classification.
58
+
59
+ backbone options (also set via CFG.backbone):
60
+ "densenet121" β€” torchxrayvision DenseNet-121, pretrained on ~1M chest
61
+ X-rays; outputs raw Cardiomegaly logit via pathology index.
62
+ "rad-dino" β€” microsoft/rad-dino, DINOv2 ViT-B/14 pretrained on ~1M
63
+ chest X-rays (HuggingFace); 518Γ—518 recommended input.
64
+ "mobilenet_v3_large" β€” torchvision MobileNetV3-Large (ImageNet); final linear
65
+ replaced with a single-output head.
66
+ "efficientnet_b0" β€” torchvision EfficientNet-B0 (ImageNet); same replacement.
67
+ "efficientnet_b3" β€” torchvision EfficientNet-B3 (ImageNet); same replacement.
68
+
69
+ All returned models expose .features and .classifier so that freeze_backbone()
70
+ and the two-stage optimizer in train_one_seed() work unchanged.
71
+ Input tensor format differs by backbone β€” use dataset.get_normalize_fn(backbone).
72
+ """
73
+ from src.config import CFG # lazy to avoid circular import at module load
74
+ backbone = backbone or CFG.backbone
75
+
76
+ if backbone in ("densenet121", "densenet121-res224-all"):
77
+ model = xrv.models.DenseNet(weights="densenet121-res224-all")
78
+ model.op_threshs = None # raw logits at every output
79
+ model.apply_sigmoid = False # belt + suspenders
80
+ return model
81
+
82
+ if backbone == "rad-dino":
83
+ return RadDinoWrapper()
84
+
85
+ import torchvision.models as tvm
86
+
87
+ if backbone == "mobilenet_v3_large":
88
+ model = tvm.mobilenet_v3_large(weights=tvm.MobileNet_V3_Large_Weights.IMAGENET1K_V2)
89
+ in_features = model.classifier[-1].in_features
90
+ model.classifier[-1] = nn.Linear(in_features, 1)
91
+ return model
92
+
93
+ if backbone in ("efficientnet_b0", "efficientnet_b3"):
94
+ if backbone == "efficientnet_b0":
95
+ model = tvm.efficientnet_b0(weights=tvm.EfficientNet_B0_Weights.IMAGENET1K_V1)
96
+ else:
97
+ model = tvm.efficientnet_b3(weights=tvm.EfficientNet_B3_Weights.IMAGENET1K_V1)
98
+ in_features = model.classifier[-1].in_features
99
+ model.classifier[-1] = nn.Linear(in_features, 1)
100
+ return model
101
+
102
+ raise ValueError(
103
+ f"Unknown backbone: {backbone!r}. "
104
+ "Choose from: densenet121, rad-dino, mobilenet_v3_large, efficientnet_b0, efficientnet_b3"
105
+ )
106
+
107
+
108
+ def cardio_logit(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
109
+ """Forward pass returning a (B,) tensor of raw logits for Cardiomegaly.
110
+
111
+ For torchxrayvision DenseNet the logit is extracted from the pathology head.
112
+ For all other backbones (MobileNet, EfficientNet, RadDinoWrapper) the model
113
+ outputs (B, 1) which is squeezed to (B,).
114
+ """
115
+ if isinstance(model, xrv.models.DenseNet):
116
+ out = model(x) # (B, num_pathologies)
117
+ idx = model.pathologies.index("Cardiomegaly")
118
+ return out[:, idx]
119
+ return model(x).squeeze(1) # (B, 1) β†’ (B,)
120
+
121
+
122
+ # ---------------------------------------------------------------------------
123
+ # Backbone management helpers
124
+ # ---------------------------------------------------------------------------
125
+ def freeze_backbone(model: nn.Module) -> nn.Module:
126
+ """Freeze all params in .features; keep .classifier trainable."""
127
+ for p in model.features.parameters():
128
+ p.requires_grad = False
129
+ for p in model.classifier.parameters():
130
+ p.requires_grad = True
131
+ return model
132
+
133
+
134
+ def unfreeze_all(model: nn.Module) -> nn.Module:
135
+ """Unfreeze every parameter. Kept for backwards compatibility; prefer partial_unfreeze."""
136
+ for p in model.parameters():
137
+ p.requires_grad = True
138
+ return model
139
+
140
+
141
+ # DenseNet-121 block groups: (block_name, transition_name) for blocks 1–4
142
+ _DENSENET_BLOCK_GROUPS = [
143
+ ("denseblock1", "transition1"),
144
+ ("denseblock2", "transition2"),
145
+ ("denseblock3", "transition3"),
146
+ ("denseblock4", "norm5"),
147
+ ]
148
+
149
+
150
+ def partial_unfreeze(model: nn.Module, frozen_blocks: int = 0) -> nn.Module:
151
+ """Selectively unfreeze the model for stage-2 fine-tuning.
152
+
153
+ frozen_blocks β€” how many feature blocks to keep frozen:
154
+ 0 β†’ unfreeze everything (same as unfreeze_all)
155
+
156
+ DenseNet-121 (4 dense block groups):
157
+ 1 β†’ keep denseblock1 (+transition1) frozen
158
+ 2 β†’ keep denseblock1–2 frozen
159
+ 3 β†’ keep denseblock1–3 frozen
160
+ 4 β†’ keep all dense blocks frozen (only classifier trains)
161
+
162
+ RAD-DINO / ViT-B (12 transformer blocks):
163
+ 1–12 β†’ keep embeddings + first N transformer blocks frozen
164
+ (last 12βˆ’N blocks + layernorm are unfrozen)
165
+ β‰₯12 β†’ keep all transformer blocks frozen (only classifier trains)
166
+
167
+ torchvision models (MobileNet, EfficientNet):
168
+ N β†’ freeze first N indexed children of model.features.
169
+ """
170
+ for p in model.parameters():
171
+ p.requires_grad = True
172
+
173
+ if frozen_blocks <= 0:
174
+ return model
175
+
176
+ if isinstance(model, xrv.models.DenseNet):
177
+ frozen_names: set[str] = set()
178
+ for i in range(min(frozen_blocks, len(_DENSENET_BLOCK_GROUPS))):
179
+ frozen_names.update(_DENSENET_BLOCK_GROUPS[i])
180
+ for name, module in model.features.named_children():
181
+ if name in frozen_names:
182
+ for p in module.parameters():
183
+ p.requires_grad = False
184
+
185
+ elif isinstance(model, RadDinoWrapper):
186
+ # Always freeze the patch/position embeddings
187
+ for p in model.features.embeddings.parameters():
188
+ p.requires_grad = False
189
+ # Freeze the first `frozen_blocks` transformer blocks
190
+ encoder_layers = model.features.encoder.layer
191
+ for block in encoder_layers[:frozen_blocks]:
192
+ for p in block.parameters():
193
+ p.requires_grad = False
194
+
195
+ else:
196
+ for module in list(model.features.children())[:frozen_blocks]:
197
+ for p in module.parameters():
198
+ p.requires_grad = False
199
+
200
+ return model
201
+
202
+
203
+ def trainable_params(model: nn.Module) -> List[nn.Parameter]:
204
+ """List of parameters with `requires_grad=True` (for optimiser construction)."""
205
+ return [p for p in model.parameters() if p.requires_grad]
model_training/src/train.py ADDED
@@ -0,0 +1,697 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import json
5
+ import os
6
+ from typing import List, Optional, Tuple
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.optim as optim
13
+ from sklearn.metrics import confusion_matrix, roc_auc_score
14
+ from torch.utils.data import DataLoader
15
+
16
+ from src.config import CFG
17
+ from src.dataset import ChestXrayDataset, SubmissionDataset, TTADataset
18
+ from src.model import (
19
+ build_model,
20
+ cardio_logit,
21
+ freeze_backbone,
22
+ partial_unfreeze,
23
+ trainable_params,
24
+ unfreeze_all,
25
+ )
26
+ from src.transforms import make_tta_transforms
27
+ from src.utils import free_device_cache, log_run, set_seed
28
+
29
+
30
+ # ---------------------------------------------------------------------------
31
+ # Mixup helper
32
+ # ---------------------------------------------------------------------------
33
+ def mixup_data(
34
+ x: torch.Tensor,
35
+ y: torch.Tensor,
36
+ alpha: float = 0.4,
37
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
38
+ """Return a randomly mixed batch and the corresponding soft labels.
39
+
40
+ Ξ» ~ Beta(Ξ±, Ξ±). When Ξ± ≀ 0 the original batch is returned unchanged.
41
+
42
+ Args:
43
+ x: Image tensor (B, C, H, W) on the training device.
44
+ y: Label tensor (B,) – may already be soft (e.g. after smoothing).
45
+ alpha: Beta distribution parameter. Typical: 0.2 – 0.4.
46
+ """
47
+ if alpha <= 0:
48
+ return x, y
49
+ lam = float(np.random.beta(alpha, alpha))
50
+ idx = torch.randperm(x.size(0), device=x.device)
51
+ mixed_x = lam * x + (1.0 - lam) * x[idx]
52
+ mixed_y = lam * y + (1.0 - lam) * y[idx]
53
+ return mixed_x, mixed_y
54
+
55
+
56
+ # ---------------------------------------------------------------------------
57
+ # Differentiable composite loss
58
+ # ---------------------------------------------------------------------------
59
+ class SoftCompositeLoss(nn.Module):
60
+ """Differentiable approximation of composite = 0.5Β·AUC + 0.25Β·sens + 0.25Β·spec.
61
+
62
+ Minimises ``1 - soft_composite``, blended with standard BCE for stability.
63
+
64
+ Soft-AUC
65
+ Pairwise sigmoid over all (positive, negative) logit pairs in the batch:
66
+ ``soft_auc = mean( Οƒ(Ξ³ Β· (logit⁺ βˆ’ logit⁻)) )``
67
+ where Ξ³ (``auc_gamma``) is a sharpness temperature.
68
+
69
+ Soft-sens / soft-spec
70
+ ``soft_sens = mean( Οƒ(logit) | y=1 )``
71
+ ``soft_spec = mean( 1 βˆ’ Οƒ(logit) | y=0 )``
72
+
73
+ Total loss
74
+ ``Ξ± Β· BCE + (1 βˆ’ Ξ±) Β· (1 βˆ’ soft_composite)``
75
+
76
+ Args:
77
+ alpha: Weight of BCE in the blend (0 = pure composite, 1 = pure BCE).
78
+ auc_gamma: Temperature for the pairwise sigmoid (higher β†’ sharper AUC signal).
79
+ eps: Numerical stability floor.
80
+ """
81
+
82
+ def __init__(self, alpha: float = 0.5, auc_gamma: float = 1.0, eps: float = 1e-7):
83
+ super().__init__()
84
+ self.alpha = alpha
85
+ self.auc_gamma = auc_gamma
86
+ self.eps = eps
87
+ self._bce = nn.BCEWithLogitsLoss()
88
+
89
+ def forward(self, logit: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
90
+ bce_loss = self._bce(logit, target)
91
+
92
+ prob = torch.sigmoid(logit)
93
+ # Use > 0.5 so the masks work correctly for both hard labels {0,1}
94
+ # and soft targets produced by mixup or label smoothing.
95
+ pos_mask = (target > 0.5)
96
+ neg_mask = ~pos_mask
97
+ n_pos = pos_mask.sum()
98
+ n_neg = neg_mask.sum()
99
+
100
+ # ── Soft AUC (pairwise) ──────────────────────────────────────────────
101
+ if n_pos > 0 and n_neg > 0:
102
+ pos_logits = logit[pos_mask] # (n_pos,)
103
+ neg_logits = logit[neg_mask] # (n_neg,)
104
+ diff = pos_logits.unsqueeze(1) - neg_logits.unsqueeze(0) # (n_pos, n_neg)
105
+ soft_auc = torch.sigmoid(self.auc_gamma * diff).mean()
106
+ else:
107
+ soft_auc = torch.tensor(0.5, device=logit.device, dtype=logit.dtype)
108
+
109
+ # ── Soft sensitivity / specificity ──────────────────────────────────
110
+ soft_sens = prob[pos_mask].mean() if n_pos > 0 else torch.tensor(
111
+ 0.0, device=logit.device, dtype=logit.dtype)
112
+ soft_spec = (1.0 - prob[neg_mask]).mean() if n_neg > 0 else torch.tensor(
113
+ 0.0, device=logit.device, dtype=logit.dtype)
114
+
115
+ soft_composite = 0.5 * soft_auc + 0.25 * soft_sens + 0.25 * soft_spec
116
+ composite_loss = 1.0 - soft_composite
117
+
118
+ return self.alpha * bce_loss + (1.0 - self.alpha) * composite_loss
119
+
120
+
121
+ # ---------------------------------------------------------------------------
122
+ # Epoch runner
123
+ # ---------------------------------------------------------------------------
124
+ def run_one_epoch(
125
+ model: nn.Module,
126
+ loader: DataLoader,
127
+ criterion: Optional[nn.Module] = None,
128
+ optimizer: Optional[optim.Optimizer] = None,
129
+ scaler: Optional[torch.cuda.amp.GradScaler] = None,
130
+ grad_clip: Optional[float] = None,
131
+ mixup_alpha: float = 0.0,
132
+ label_smoothing: float = 0.0,
133
+ ) -> dict:
134
+ """Single forward pass over *loader*.
135
+
136
+ Pass ``optimizer=None`` for evaluation (mixup and smoothing are skipped).
137
+ Expects (image, label, name) batches. Uses CUDA AMP when available.
138
+
139
+ Hard original labels are always accumulated for metric computation;
140
+ the (potentially mixed + smoothed) soft labels are only used for the loss.
141
+ """
142
+ is_train = optimizer is not None
143
+ model.train(is_train)
144
+
145
+ losses, logits_all, labels_all, names_all = [], [], [], []
146
+ pin = (CFG.device == "cuda")
147
+ grad_clip = grad_clip if grad_clip is not None else CFG.grad_clip
148
+
149
+ amp_ctx = torch.cuda.amp.autocast(enabled=(CFG.device == "cuda"))
150
+ for x, y, names in loader:
151
+ x = x.to(CFG.device, non_blocking=pin)
152
+ y = y.to(CFG.device, non_blocking=pin)
153
+
154
+ # Keep hard labels for metric accumulation (before any augmentation)
155
+ y_hard = y.detach().clone()
156
+
157
+ if is_train:
158
+ # Mixup: interpolate two samples + their labels in-place
159
+ if mixup_alpha > 0.0:
160
+ x, y = mixup_data(x, y, alpha=mixup_alpha)
161
+ # Label smoothing: y_smooth = y*(1-Ξ΅) + 0.5*Ξ΅
162
+ if label_smoothing > 0.0:
163
+ y = y * (1.0 - label_smoothing) + 0.5 * label_smoothing
164
+
165
+ with torch.set_grad_enabled(is_train):
166
+ with amp_ctx:
167
+ logit = cardio_logit(model, x)
168
+ loss = criterion(logit, y) if criterion is not None else None
169
+
170
+ if is_train:
171
+ optimizer.zero_grad(set_to_none=True)
172
+ if scaler is not None and scaler.is_enabled():
173
+ scaler.scale(loss).backward()
174
+ scaler.unscale_(optimizer)
175
+ torch.nn.utils.clip_grad_norm_(trainable_params(model), grad_clip)
176
+ scaler.step(optimizer)
177
+ scaler.update()
178
+ else:
179
+ loss.backward()
180
+ torch.nn.utils.clip_grad_norm_(trainable_params(model), grad_clip)
181
+ optimizer.step()
182
+
183
+ if loss is not None:
184
+ losses.append(loss.item())
185
+ logits_all.append(logit.detach().float().cpu().numpy())
186
+ labels_all.append(y_hard.float().cpu().numpy()) # always hard labels
187
+ names_all.extend(list(names))
188
+
189
+ y_true = np.concatenate(labels_all)
190
+ y_logit = np.concatenate(logits_all)
191
+ y_prob = 1.0 / (1.0 + np.exp(-y_logit))
192
+ auc = roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else float("nan")
193
+
194
+ # Per-epoch metrics at threshold=0.5 (used for progress logging)
195
+ y_pred = (y_prob >= 0.5).astype(int)
196
+ tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
197
+ acc = float((tp + tn) / (tp + tn + fp + fn)) if (tp + tn + fp + fn) > 0 else float("nan")
198
+ sens = float(tp / (tp + fn)) if (tp + fn) > 0 else 0.0
199
+ spec = float(tn / (tn + fp)) if (tn + fp) > 0 else 0.0
200
+ composite = 0.5 * (auc if not np.isnan(auc) else 0.0) + 0.25 * sens + 0.25 * spec
201
+
202
+ return {
203
+ "loss": float(np.mean(losses)) if losses else float("nan"),
204
+ "auc": float(auc),
205
+ "acc": acc,
206
+ "sens": sens,
207
+ "spec": spec,
208
+ "composite": float(composite),
209
+ "y_true": y_true,
210
+ "y_prob": y_prob,
211
+ "names": names_all,
212
+ }
213
+
214
+
215
+ # ---------------------------------------------------------------------------
216
+ # Single-seed two-stage training
217
+ # ---------------------------------------------------------------------------
218
+ def train_one_seed(
219
+ seed: int,
220
+ train_loader: DataLoader,
221
+ val_loader: DataLoader,
222
+ output_dir: Optional[str] = None,
223
+ config=None,
224
+ ) -> Tuple[nn.Module, float, str, list[dict]]:
225
+ """Train ONE model end-to-end (frozen warmup β†’ full fine-tune).
226
+
227
+ Returns (best_model, best_val_auc, checkpoint_path, history).
228
+ """
229
+ cfg = config or CFG
230
+ output_dir = output_dir or cfg.output_dir
231
+ os.makedirs(output_dir, exist_ok=True)
232
+
233
+ set_seed(seed)
234
+
235
+ model = build_model(cfg.backbone).to(cfg.device)
236
+ total_params = sum(p.numel() for p in model.parameters())
237
+
238
+ print("\n" + "=" * 80)
239
+ print(f" Seed : {seed}")
240
+ print(f" Model : {model.__class__.__name__} ({total_params:,} total params)")
241
+ print(f" Backbone: {cfg.backbone}")
242
+ print(f" Device : {cfg.device}")
243
+ print("=" * 80)
244
+
245
+ if cfg.use_composite_loss:
246
+ criterion = SoftCompositeLoss(
247
+ alpha=cfg.composite_loss_alpha,
248
+ auc_gamma=cfg.composite_loss_gamma,
249
+ )
250
+ print(
251
+ f" Loss : SoftCompositeLoss "
252
+ f"(Ξ±={cfg.composite_loss_alpha}, Ξ³={cfg.composite_loss_gamma})"
253
+ )
254
+ else:
255
+ criterion = nn.BCEWithLogitsLoss()
256
+ print(" Loss : BCEWithLogitsLoss")
257
+
258
+ mixup_alpha = getattr(cfg, "mixup_alpha", 0.0)
259
+ label_smoothing = getattr(cfg, "label_smoothing", 0.0)
260
+ if mixup_alpha > 0:
261
+ print(f" Mixup : Ξ±={mixup_alpha}")
262
+ if label_smoothing > 0:
263
+ print(f" Smoothing: Ξ΅={label_smoothing}")
264
+
265
+ scaler = torch.cuda.amp.GradScaler(enabled=(cfg.device == "cuda"))
266
+ history: list[dict] = []
267
+
268
+ _aug_kw = dict(mixup_alpha=mixup_alpha, label_smoothing=label_smoothing)
269
+
270
+ # ── Stage 1: frozen backbone, head-only warmup ─────────────────────────
271
+ freeze_backbone(model)
272
+ n_trainable = sum(p.numel() for p in trainable_params(model))
273
+ print(f"\n Stage 1 β€” all backbone blocks FROZEN | {n_trainable:,} trainable params")
274
+ opt_frozen = optim.AdamW(
275
+ trainable_params(model), lr=cfg.head_lr, weight_decay=cfg.weight_decay,
276
+ )
277
+ for ep in range(1, cfg.frozen_epochs + 1):
278
+ t = run_one_epoch(model, train_loader, criterion, opt_frozen, scaler, **_aug_kw)
279
+ v = run_one_epoch(model, val_loader, criterion)
280
+ history.append({
281
+ "seed": seed, "stage": "frozen", "epoch": ep,
282
+ "train_loss": t["loss"], "train_auc": t["auc"],
283
+ "train_acc": t["acc"], "train_composite": t["composite"],
284
+ "val_loss": v["loss"], "val_auc": v["auc"],
285
+ "val_acc": v["acc"], "val_sens": v["sens"],
286
+ "val_spec": v["spec"], "val_composite": v["composite"],
287
+ "lr": opt_frozen.param_groups[0]["lr"],
288
+ })
289
+ print(
290
+ f" [frozen] {ep}/{cfg.frozen_epochs} "
291
+ f"loss={t['loss']:.4f} train_acc={t['acc']*100:.1f}% | "
292
+ f"val_auc={v['auc']:.4f} val_acc={v['acc']*100:.1f}% "
293
+ f"sens={v['sens']:.3f} spec={v['spec']:.3f} comp={v['composite']:.4f}"
294
+ )
295
+
296
+ # ── Stage 2: partial or full fine-tune ───────────────────────────────
297
+ frozen_blocks = getattr(cfg, "frozen_blocks", 0)
298
+ partial_unfreeze(model, frozen_blocks)
299
+ n_trainable = sum(p.numel() for p in trainable_params(model))
300
+ if frozen_blocks == 0:
301
+ stage2_label = "all blocks UNFROZEN"
302
+ else:
303
+ stage2_label = f"{frozen_blocks} block(s) still FROZEN"
304
+ print(f"\n Stage 2 β€” {stage2_label} | {n_trainable:,} trainable params")
305
+
306
+ opt_ft = optim.AdamW(
307
+ [
308
+ {"params": model.features.parameters(), "lr": cfg.backbone_lr},
309
+ {"params": model.classifier.parameters(), "lr": cfg.head_lr},
310
+ ],
311
+ weight_decay=cfg.weight_decay,
312
+ )
313
+ sched = optim.lr_scheduler.CosineAnnealingLR(
314
+ opt_ft, T_max=cfg.finetune_epochs, eta_min=cfg.backbone_lr * 0.01,
315
+ )
316
+
317
+ best_auc, best_state, patience_ctr = -1.0, None, 0
318
+ for ep in range(1, cfg.finetune_epochs + 1):
319
+ t = run_one_epoch(model, train_loader, criterion, opt_ft, scaler, **_aug_kw)
320
+ v = run_one_epoch(model, val_loader, criterion)
321
+ sched.step()
322
+ history.append({
323
+ "seed": seed, "stage": "finetune", "epoch": ep,
324
+ "train_loss": t["loss"], "train_auc": t["auc"],
325
+ "train_acc": t["acc"], "train_composite": t["composite"],
326
+ "val_loss": v["loss"], "val_auc": v["auc"],
327
+ "val_acc": v["acc"], "val_sens": v["sens"],
328
+ "val_spec": v["spec"], "val_composite": v["composite"],
329
+ "lr": opt_ft.param_groups[0]["lr"],
330
+ })
331
+ print(
332
+ f" [ft] {ep}/{cfg.finetune_epochs} "
333
+ f"loss={t['loss']:.4f} train_acc={t['acc']*100:.1f}% | "
334
+ f"val_auc={v['auc']:.4f} val_acc={v['acc']*100:.1f}% "
335
+ f"sens={v['sens']:.3f} spec={v['spec']:.3f} comp={v['composite']:.4f} "
336
+ f"lr={opt_ft.param_groups[0]['lr']:.2e}"
337
+ )
338
+
339
+ if v["auc"] > best_auc:
340
+ best_auc, best_state, patience_ctr = (
341
+ v["auc"], copy.deepcopy(model.state_dict()), 0
342
+ )
343
+ else:
344
+ patience_ctr += 1
345
+ if patience_ctr >= cfg.early_stop_patience:
346
+ print(f" [ft] early stop at epoch {ep} (best val AUC = {best_auc:.4f})")
347
+ break
348
+
349
+ if best_state is not None:
350
+ model.load_state_dict(best_state)
351
+
352
+ ckpt_path = os.path.join(output_dir, f"model_seed{seed}.pth")
353
+ torch.save(best_state if best_state is not None else model.state_dict(), ckpt_path)
354
+ print(f"[seed={seed}] Best val AUC = {best_auc:.4f} checkpoint β†’ {ckpt_path}")
355
+
356
+ return model, best_auc, ckpt_path, history
357
+
358
+
359
+ # ---------------------------------------------------------------------------
360
+ # Multi-seed ensemble training
361
+ # ---------------------------------------------------------------------------
362
+ def train_ensemble(
363
+ train_loader: DataLoader,
364
+ val_loader: DataLoader,
365
+ seeds: Optional[List[int]] = None,
366
+ output_dir: Optional[str] = None,
367
+ config=None,
368
+ ) -> Tuple[List[Tuple[int, nn.Module, float, str]], pd.DataFrame]:
369
+ """Train one model per seed and return (models_list, full_history_df).
370
+
371
+ `models_list` items: (seed, trained_model, best_val_auc, checkpoint_path).
372
+ """
373
+ cfg = config or CFG
374
+ seeds = seeds if seeds is not None else cfg.seeds
375
+ output_dir = output_dir or cfg.output_dir
376
+
377
+ print(f" ENSEMBLE TRAINING STARTED")
378
+
379
+ models, all_history = [], []
380
+ for seed in seeds:
381
+ m, auc, ckpt, hist = train_one_seed(
382
+ seed, train_loader, val_loader,
383
+ output_dir=output_dir, config=cfg,
384
+ )
385
+ models.append((seed, m, auc, ckpt))
386
+ all_history.extend(hist)
387
+ free_device_cache(cfg.device)
388
+
389
+ history_df = pd.DataFrame(all_history)
390
+ history_df.to_csv(os.path.join(output_dir, "training_history.csv"), index=False)
391
+
392
+ frozen_blocks = getattr(cfg, "frozen_blocks", 0)
393
+ stage2_label = "all blocks unfrozen" if frozen_blocks == 0 else f"{frozen_blocks} block(s) frozen"
394
+
395
+ print("\n" + "═" * 80)
396
+ print(f" ENSEMBLE COMPLETE")
397
+ print(f" Backbone : {cfg.backbone}")
398
+ print(f" Frozen blocks : {frozen_blocks} ({stage2_label} in Stage 2)")
399
+ print(f" Seeds trained : {len(models)}")
400
+ print(" Per-seed best val AUC:")
401
+ for seed, _, auc, _ in models:
402
+ print(f" seed {seed:>5} : {auc:.4f}")
403
+ best_seed = max(models, key=lambda x: x[2])
404
+ print(f" Best seed : {best_seed[0]} (AUC={best_seed[2]:.4f})")
405
+ print("═" * 80)
406
+
407
+ return models, history_df
408
+
409
+
410
+ # ---------------------------------------------------------------------------
411
+ # Unified entry point (respects CFG.use_ensemble)
412
+ # ---------------------------------------------------------------------------
413
+ def train(
414
+ train_loader: DataLoader,
415
+ val_loader: DataLoader,
416
+ output_dir: Optional[str] = None,
417
+ config=None,
418
+ ) -> Tuple[List[Tuple[int, nn.Module, float, str]], pd.DataFrame]:
419
+ """Train and return (models_list, history_df) β€” same format as train_ensemble.
420
+
421
+ Behaviour is controlled by CFG.use_ensemble:
422
+ True β†’ delegates to train_ensemble (one model per seed in CFG.seeds)
423
+ False β†’ trains a single model with CFG.seed and wraps result in the
424
+ same list format so the rest of the pipeline works unchanged.
425
+ """
426
+ cfg = config or CFG
427
+ if cfg.use_ensemble:
428
+ return train_ensemble(train_loader, val_loader, output_dir=output_dir, config=cfg)
429
+
430
+ m, auc, ckpt, hist = train_one_seed(
431
+ cfg.seed, train_loader, val_loader, output_dir=output_dir, config=cfg,
432
+ )
433
+ history_df = pd.DataFrame(hist)
434
+ history_df.to_csv(
435
+ os.path.join(output_dir or cfg.output_dir, "training_history.csv"), index=False,
436
+ )
437
+ return [(cfg.seed, m, auc, ckpt)], history_df
438
+
439
+
440
+ # ---------------------------------------------------------------------------
441
+ # TTA inference
442
+ # ---------------------------------------------------------------------------
443
+ def tta_predict(
444
+ model: nn.Module,
445
+ df: pd.DataFrame,
446
+ image_dir: Optional[str] = None,
447
+ has_labels: bool = True,
448
+ tta_transforms: Optional[List] = None,
449
+ config=None,
450
+ ) -> dict:
451
+ """Run TTA inference for ONE model on a DataFrame.
452
+
453
+ Predictions are averaged in **logit space** across all TTA passes.
454
+ """
455
+ cfg = config or CFG
456
+ tta_transforms = tta_transforms or make_tta_transforms(cfg.img_size)
457
+ tta_transforms = tta_transforms[:cfg.tta_passes]
458
+
459
+ all_logits: list[np.ndarray] = []
460
+ names_ref, labels_ref = None, None
461
+
462
+ pin = (cfg.device == "cuda")
463
+ amp_ctx = torch.cuda.amp.autocast(enabled=(cfg.device == "cuda"))
464
+
465
+ for tf in tta_transforms:
466
+ ds = TTADataset(df, tf, image_dir)
467
+ loader = DataLoader(
468
+ ds, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
469
+ pin_memory=pin, shuffle=False,
470
+ )
471
+ pass_logits, pass_names, pass_labels = [], [], []
472
+ model.eval()
473
+ with torch.no_grad(), amp_ctx:
474
+ for x, y, names in loader:
475
+ x = x.to(cfg.device, non_blocking=pin)
476
+ logit = cardio_logit(model, x).float().cpu().numpy()
477
+ pass_logits.append(logit)
478
+ pass_names.extend(list(names))
479
+ if has_labels:
480
+ pass_labels.append(y.numpy())
481
+ all_logits.append(np.concatenate(pass_logits))
482
+ if names_ref is None:
483
+ names_ref = pass_names
484
+ labels_ref = np.concatenate(pass_labels) if has_labels else None
485
+
486
+ mean_logit = np.stack(all_logits, axis=0).mean(axis=0)
487
+ mean_prob = (1.0 / (1.0 + np.exp(-mean_logit))).astype(np.float32)
488
+ return {
489
+ "names": names_ref,
490
+ "y_prob": mean_prob,
491
+ "y_true": labels_ref,
492
+ "mean_logit": mean_logit,
493
+ }
494
+
495
+
496
+ def tta_predict_ensemble(
497
+ models_list: List[Tuple[int, nn.Module, float, str]],
498
+ df: pd.DataFrame,
499
+ image_dir: Optional[str] = None,
500
+ has_labels: bool = True,
501
+ tta_transforms: Optional[List] = None,
502
+ config=None,
503
+ ) -> dict:
504
+ """Run TTA for every model in `models_list` and average in logit space."""
505
+ cfg = config or CFG
506
+ all_logits: list[np.ndarray] = []
507
+ names_ref, labels_ref = None, None
508
+
509
+ for (seed, model, _, _) in models_list:
510
+ print(f" TTA with seed={seed}...")
511
+ pred = tta_predict(
512
+ model, df, image_dir=image_dir, has_labels=has_labels,
513
+ tta_transforms=tta_transforms, config=cfg,
514
+ )
515
+ all_logits.append(pred["mean_logit"])
516
+ if names_ref is None:
517
+ names_ref = pred["names"]
518
+ labels_ref = pred["y_true"]
519
+
520
+ mean_logit = np.stack(all_logits, axis=0).mean(axis=0)
521
+ mean_prob = (1.0 / (1.0 + np.exp(-mean_logit))).astype(np.float32)
522
+ return {"names": names_ref, "y_prob": mean_prob, "y_true": labels_ref}
523
+
524
+
525
+ # ---------------------------------------------------------------------------
526
+ # Submission inference
527
+ # ---------------------------------------------------------------------------
528
+ def predict_submission(
529
+ models_list: List[Tuple[int, nn.Module, float, str]],
530
+ submission_dir: str,
531
+ tta_transforms: Optional[List] = None,
532
+ config=None,
533
+ ) -> dict:
534
+ """TTA + ensemble inference on an unlabelled submission directory.
535
+
536
+ Wraps the directory in a DataFrame so we can reuse `tta_predict_ensemble`.
537
+ """
538
+ cfg = config or CFG
539
+ files = sorted(
540
+ f for f in os.listdir(submission_dir)
541
+ if os.path.isfile(os.path.join(submission_dir, f))
542
+ and f.lower().endswith((".png", ".jpg", ".jpeg"))
543
+ )
544
+ sub_df = pd.DataFrame({"filename": files})
545
+ return tta_predict_ensemble(
546
+ models_list, sub_df,
547
+ image_dir=submission_dir, has_labels=False,
548
+ tta_transforms=tta_transforms, config=cfg,
549
+ )
550
+
551
+
552
+ # ---------------------------------------------------------------------------
553
+ # Metrics
554
+ # ---------------------------------------------------------------------------
555
+ def metrics_at_threshold(y_true, y_prob, threshold: float) -> dict:
556
+ """Composite-grading-aware metric set at a given threshold.
557
+
558
+ composite = 0.5Β·AUC + 0.25Β·sensitivity + 0.25Β·specificity
559
+ """
560
+ y_pred = (y_prob >= threshold).astype(int)
561
+ tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
562
+
563
+ sens = tp / (tp + fn) if (tp + fn) > 0 else 0.0
564
+ spec = tn / (tn + fp) if (tn + fp) > 0 else 0.0
565
+ acc = (tp + tn) / (tp + tn + fp + fn)
566
+ auc = roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else float("nan")
567
+ composite = 0.5 * auc + 0.25 * sens + 0.25 * spec
568
+ precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
569
+
570
+ return {
571
+ "threshold": float(threshold),
572
+ "auc": float(auc),
573
+ "sensitivity": float(sens),
574
+ "specificity": float(spec),
575
+ "accuracy": float(acc),
576
+ "youden": float(sens + spec - 1.0),
577
+ "composite": float(composite),
578
+ "precision": float(precision),
579
+ "tp": int(tp), "tn": int(tn), "fp": int(fp), "fn": int(fn),
580
+ }
581
+
582
+
583
+ # Backwards-compatible alias (used by older notebook cells)
584
+ compute_basic_metrics = metrics_at_threshold
585
+
586
+
587
+ def find_best_threshold(y_true, y_prob) -> Tuple[float, dict]:
588
+ """Pick the threshold that maximises sensitivity + specificity (Youden's J)."""
589
+ candidates = np.unique(np.round(np.concatenate([[0.0], y_prob, [1.0]]), 6))
590
+ best_score, best_row = -np.inf, None
591
+ for thr in candidates:
592
+ m = metrics_at_threshold(y_true, y_prob, thr)
593
+ score = m["sensitivity"] + m["specificity"]
594
+ if score > best_score:
595
+ best_score, best_row = score, m
596
+ return float(best_row["threshold"]), best_row
597
+
598
+
599
+ def bootstrap_threshold(
600
+ y_true, y_prob,
601
+ n_boot: Optional[int] = None,
602
+ seed: Optional[int] = None,
603
+ ) -> float:
604
+ """Bootstrap-stabilised threshold (median across resamples).
605
+
606
+ Generalises better than a single-shot pick on the raw val set.
607
+ """
608
+ n_boot = n_boot if n_boot is not None else CFG.n_bootstrap
609
+ seed = seed if seed is not None else CFG.seed
610
+
611
+ rng = np.random.RandomState(seed)
612
+ thrs: list[float] = []
613
+ n = len(y_true)
614
+ for _ in range(n_boot):
615
+ idx = rng.randint(0, n, size=n)
616
+ if len(np.unique(y_true[idx])) < 2:
617
+ continue
618
+ thr, _ = find_best_threshold(y_true[idx], y_prob[idx])
619
+ thrs.append(thr)
620
+ return float(np.median(thrs)) if thrs else 0.5
621
+
622
+
623
+ def select_threshold(y_true, y_prob, config=None) -> Tuple[float, dict, dict]:
624
+ """Pick the better of (single-shot) vs (bootstrap) thresholds on composite.
625
+
626
+ Bootstrap is preferred unless its composite is clearly worse (margin 0.005).
627
+ Returns (chosen_threshold, single_metrics, bootstrap_metrics).
628
+ """
629
+ cfg = config or CFG
630
+ thr_single, _ = find_best_threshold(y_true, y_prob)
631
+ thr_boot = bootstrap_threshold(y_true, y_prob, n_boot=cfg.n_bootstrap, seed=cfg.seed)
632
+ m_single = metrics_at_threshold(y_true, y_prob, thr_single)
633
+ m_boot = metrics_at_threshold(y_true, y_prob, thr_boot)
634
+ chosen = thr_boot if m_boot["composite"] >= m_single["composite"] - 0.005 else thr_single
635
+ return float(chosen), m_single, m_boot
636
+
637
+
638
+ # ---------------------------------------------------------------------------
639
+ # Saving results
640
+ # ---------------------------------------------------------------------------
641
+ def save_results(
642
+ models_list: List[Tuple[int, nn.Module, float, str]],
643
+ history: pd.DataFrame,
644
+ val_out: dict,
645
+ test_out: dict,
646
+ best_threshold: float,
647
+ output_dir: str,
648
+ model_name: str = "model",
649
+ config=None,
650
+ ) -> None:
651
+ """Persist per-seed checkpoints, history, metrics, predictions, and global log.
652
+
653
+ Per-seed `.pth` files are already written by `train_one_seed`; here we
654
+ only re-save them under the conventional name and write the metrics +
655
+ per-image prediction CSVs.
656
+ """
657
+ cfg = config or CFG
658
+ os.makedirs(output_dir, exist_ok=True)
659
+
660
+ # ── Metric files + per-image predictions ─────────────────────────────
661
+ val_metrics = metrics_at_threshold(val_out["y_true"], val_out["y_prob"], best_threshold)
662
+ test_metrics = metrics_at_threshold(test_out["y_true"], test_out["y_prob"], best_threshold)
663
+
664
+ for split_name, metrics in [("val", val_metrics), ("test", test_metrics)]:
665
+ with open(os.path.join(output_dir, f"{split_name}_metrics_final.json"), "w") as f:
666
+ json.dump(metrics, f, indent=2)
667
+
668
+ history.to_csv(os.path.join(output_dir, "training_history.csv"), index=False)
669
+
670
+ for split_name, out in [("val", val_out), ("test", test_out)]:
671
+ y_true = out["y_true"].astype(int)
672
+ y_pred = (out["y_prob"] >= best_threshold).astype(int)
673
+ pd.DataFrame({
674
+ "filename": out["names"],
675
+ "y_true": y_true,
676
+ "prob": out["y_prob"],
677
+ "pred": y_pred,
678
+ "correct": (y_pred == y_true).astype(int),
679
+ }).to_csv(os.path.join(output_dir, f"{split_name}_predictions.csv"), index=False)
680
+
681
+ # ── Ensemble manifest (which seeds + which checkpoints) ──────────────
682
+ pd.DataFrame([
683
+ {"seed": s, "best_val_auc": auc, "checkpoint": ckpt}
684
+ for (s, _, auc, ckpt) in models_list
685
+ ]).to_csv(os.path.join(output_dir, "ensemble_manifest.csv"), index=False)
686
+
687
+ print(f"Results saved β†’ {output_dir}")
688
+
689
+ # ── Append to global results log ─────────────────────────────────────
690
+ log_run(
691
+ model_name=model_name,
692
+ val_metrics=val_metrics,
693
+ test_metrics=test_metrics,
694
+ config=cfg,
695
+ n_seeds=len(models_list),
696
+ log_path=cfg.results_log_path,
697
+ )
model_training/src/transforms.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import List, Tuple
4
+
5
+ import torchvision.transforms as T
6
+ from PIL import Image
7
+
8
+ from src.config import CFG
9
+
10
+
11
+ # ---------------------------------------------------------------------------
12
+ # PIL helpers (TTA expects PIL β†’ PIL transforms; xrv normalisation is applied
13
+ # downstream inside the Dataset).
14
+ # ---------------------------------------------------------------------------
15
+ def _pil_hflip(img: Image.Image) -> Image.Image:
16
+ return img.transpose(Image.FLIP_LEFT_RIGHT)
17
+
18
+
19
+ # ---------------------------------------------------------------------------
20
+ # Training and evaluation transforms
21
+ # ---------------------------------------------------------------------------
22
+ def make_transforms(img_size: int | None = None) -> Tuple[T.Compose, T.Compose]:
23
+ """Return (train_transform, eval_transform) PIL-space pipelines.
24
+
25
+ All transforms produce a PIL grayscale image of size (img_size, img_size).
26
+ The downstream Dataset converts it to a single-channel xrv-normalised
27
+ tensor in [-1024, 1024].
28
+
29
+ Train pipeline: small affine, mild jitter, light hflip; random erasing
30
+ happens after xrv normalisation inside the Dataset.
31
+ Eval pipeline: deterministic resize.
32
+ """
33
+ img_size = img_size if img_size is not None else CFG.img_size
34
+
35
+ train_tf = T.Compose([
36
+ T.Resize((img_size + 16, img_size + 16)),
37
+ T.RandomCrop((img_size, img_size)),
38
+ T.RandomHorizontalFlip(p=0.5),
39
+ T.RandomAffine(
40
+ degrees=8,
41
+ translate=(0.04, 0.04),
42
+ scale=(0.95, 1.05),
43
+ fill=0,
44
+ ),
45
+ T.ColorJitter(brightness=0.15, contrast=0.15),
46
+ ])
47
+
48
+ eval_tf = T.Compose([
49
+ T.Resize((img_size, img_size)),
50
+ ])
51
+
52
+ return train_tf, eval_tf
53
+
54
+
55
+ # ---------------------------------------------------------------------------
56
+ # Test-time augmentation (TTA) transforms
57
+ # ---------------------------------------------------------------------------
58
+ def make_tta_transforms(img_size: int | None = None) -> List[T.Compose]:
59
+ """Six deterministic PIL-space transforms.
60
+
61
+ All end with a resized PIL image ready for xrv_normalize_np().
62
+ Predictions are averaged across all passes (in logit space) inside
63
+ `tta_predict` / `tta_predict_ensemble`.
64
+ """
65
+ img_size = img_size if img_size is not None else CFG.img_size
66
+ size = (img_size, img_size)
67
+
68
+ return [
69
+ T.Compose([T.Resize(size)]),
70
+ T.Compose([T.Resize(size), T.Lambda(_pil_hflip)]),
71
+ T.Compose([T.Resize((img_size + 20, img_size + 20)), T.CenterCrop(size)]),
72
+ T.Compose([T.Resize((img_size - 20, img_size - 20)),
73
+ T.Pad(10, fill=0), T.CenterCrop(size)]),
74
+ T.Compose([T.Resize(size),
75
+ T.RandomAffine(degrees=(6, 6), fill=0)]),
76
+ T.Compose([T.Resize(size),
77
+ T.RandomAffine(degrees=(-6, -6), fill=0)]),
78
+ ]
model_training/src/utils.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import os
5
+ import random
6
+ from datetime import datetime
7
+ from typing import TYPE_CHECKING
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ import torch
12
+
13
+ if TYPE_CHECKING:
14
+ from src.config import Config
15
+
16
+
17
+ def set_seed(seed: int) -> None:
18
+ """Set all relevant random seeds for reproducibility."""
19
+ random.seed(seed)
20
+ np.random.seed(seed)
21
+ torch.manual_seed(seed)
22
+ if torch.cuda.is_available():
23
+ torch.cuda.manual_seed_all(seed)
24
+ elif torch.backends.mps.is_available():
25
+ torch.mps.manual_seed(seed)
26
+
27
+
28
+ def free_device_cache(device: str) -> None:
29
+ """Release unused memory on GPU / MPS (useful between seeds / Optuna trials)."""
30
+ if device == "mps":
31
+ torch.mps.empty_cache()
32
+ elif device == "cuda":
33
+ torch.cuda.empty_cache()
34
+
35
+
36
+ def log_run(
37
+ model_name: str,
38
+ val_metrics: dict,
39
+ test_metrics: dict,
40
+ config: "Config",
41
+ n_seeds: int,
42
+ log_path: str = "results_log.csv",
43
+ ) -> pd.DataFrame:
44
+ """Append one training run to the global results log CSV.
45
+
46
+ Creates the file with a header if it does not exist yet, otherwise appends.
47
+
48
+ Columns
49
+ ───────
50
+ run_id, model_name, created_at,
51
+ <all Config fields except device/csv_path/image_dir/submission_test_dir/output_dir>,
52
+ n_seeds,
53
+ val_* (auc, sensitivity, specificity, accuracy, youden, composite,
54
+ precision, tp, tn, fp, fn, threshold)
55
+ test_* (same set)
56
+
57
+ Parameters
58
+ ──────────
59
+ model_name : human-readable name for this run (e.g. "xrv_densenet_ensemble")
60
+ val_metrics : dict returned by metrics_at_threshold on the validation split
61
+ test_metrics : dict returned by metrics_at_threshold on the test split
62
+ config : the Config instance used for this run
63
+ n_seeds : number of seeds in the ensemble
64
+ log_path : path to the CSV results log (created if missing)
65
+ """
66
+ cfg_dict = dataclasses.asdict(config)
67
+
68
+ # exclude path/device fields β€” not meaningful for comparison
69
+ skip = {"csv_path", "image_dir", "submission_test_dir", "output_dir", "device"}
70
+ hyperparams = {k: v for k, v in cfg_dict.items() if k not in skip}
71
+
72
+ # `seeds` is a list β€” flatten to a string so the CSV stays readable
73
+ if "seeds" in hyperparams and isinstance(hyperparams["seeds"], list):
74
+ hyperparams["seeds"] = ",".join(str(s) for s in hyperparams["seeds"])
75
+
76
+ row: dict = {
77
+ "run_id": datetime.now().strftime("%Y%m%d_%H%M%S"),
78
+ "model_name": model_name,
79
+ "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
80
+ "n_seeds": n_seeds,
81
+ **hyperparams,
82
+ }
83
+
84
+ for prefix, metrics in [("val", val_metrics), ("test", test_metrics)]:
85
+ for key, value in metrics.items():
86
+ if key != "threshold":
87
+ row[f"{prefix}_{key}"] = value
88
+ row[f"{prefix}_threshold"] = metrics.get("threshold", float("nan"))
89
+
90
+ new_row_df = pd.DataFrame([row])
91
+
92
+ if os.path.exists(log_path):
93
+ log_df = pd.read_csv(log_path)
94
+ log_df = pd.concat([log_df, new_row_df], ignore_index=True)
95
+ else:
96
+ log_df = new_row_df
97
+
98
+ log_df.to_csv(log_path, index=False)
99
+ print(f"Run logged β†’ {log_path} ({len(log_df)} total runs)")
100
+ return log_df