Spaces:
Sleeping
Sleeping
cardio-deploy commited on
Commit Β·
08a909f
1
Parent(s): 1acfe01
Deploy CardioScan inference 2026-04-23T12:22:25Z
Browse files- .gitattributes +2 -33
- Dockerfile +68 -0
- README.md +38 -8
- inference_server/README.md +218 -0
- inference_server/requirements.txt +17 -0
- inference_server/run.sh +10 -0
- inference_server/server.py +493 -0
- model_training/notebooks/results/best_model.pth +3 -0
- model_training/notebooks/results/daily_submission-1.csv +177 -0
- model_training/notebooks/results/ensemble_manifest.csv +4 -0
- model_training/notebooks/results/model_seed2024.pth +3 -0
- model_training/notebooks/results/model_seed42.pth +3 -0
- model_training/notebooks/results/model_seed7.pth +3 -0
- model_training/notebooks/results/model_seed8.pth +3 -0
- model_training/notebooks/results/test_metrics_final.json +14 -0
- model_training/notebooks/results/test_predictions.csv +238 -0
- model_training/notebooks/results/training_history.csv +51 -0
- model_training/notebooks/results/val_metrics_final.json +14 -0
- model_training/notebooks/results/val_predictions.csv +238 -0
- model_training/src/__init__.py +1 -0
- model_training/src/config.py +95 -0
- model_training/src/data.py +148 -0
- model_training/src/dataset.py +173 -0
- model_training/src/model.py +205 -0
- model_training/src/train.py +697 -0
- model_training/src/transforms.py +78 -0
- model_training/src/utils.py +100 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,4 @@
|
|
| 1 |
-
*.
|
| 2 |
-
*.
|
| 3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 4 |
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Dockerfile
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# syntax=docker/dockerfile:1
|
| 2 |
+
#
|
| 3 |
+
# Hugging Face Spaces (Docker SDK) image for the CardioScan inference server.
|
| 4 |
+
#
|
| 5 |
+
# Spaces conventions:
|
| 6 |
+
# * Container must listen on the port declared in `README.md` (`app_port`).
|
| 7 |
+
# * Container is started as user `user` (UID 1000), not root.
|
| 8 |
+
# * /data is the only writable location persisted across restarts (we don't
|
| 9 |
+
# need persistence β checkpoints ship inside the image).
|
| 10 |
+
#
|
| 11 |
+
# Build context expected by `scripts/deploy-space.sh`:
|
| 12 |
+
# /
|
| 13 |
+
# βββ Dockerfile (this file)
|
| 14 |
+
# βββ README.md (HF Spaces metadata)
|
| 15 |
+
# βββ inference_server/ (FastAPI app + requirements.txt)
|
| 16 |
+
# βββ model_training/
|
| 17 |
+
# βββ src/ (only the package; notebooks excluded)
|
| 18 |
+
# βββ notebooks/results/ (.pth checkpoints + manifest + metrics)
|
| 19 |
+
|
| 20 |
+
FROM python:3.11-slim
|
| 21 |
+
|
| 22 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 23 |
+
PYTHONUNBUFFERED=1 \
|
| 24 |
+
PIP_NO_CACHE_DIR=1 \
|
| 25 |
+
PIP_DISABLE_PIP_VERSION_CHECK=1 \
|
| 26 |
+
HF_HOME=/app/.cache/huggingface \
|
| 27 |
+
TORCH_HOME=/app/.cache/torch \
|
| 28 |
+
XDG_CACHE_HOME=/app/.cache
|
| 29 |
+
|
| 30 |
+
# System deps for Pillow / scientific Python wheels.
|
| 31 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 32 |
+
libgomp1 \
|
| 33 |
+
ca-certificates \
|
| 34 |
+
curl \
|
| 35 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 36 |
+
|
| 37 |
+
# HF Spaces requires the runtime user to be UID 1000.
|
| 38 |
+
RUN useradd --create-home --uid 1000 user
|
| 39 |
+
WORKDIR /app
|
| 40 |
+
|
| 41 |
+
# Install Python deps first so layer caches when only code/weights change.
|
| 42 |
+
COPY --chown=user:user inference_server/requirements.txt /app/inference_server/requirements.txt
|
| 43 |
+
RUN pip install --upgrade pip && \
|
| 44 |
+
pip install --extra-index-url https://download.pytorch.org/whl/cpu \
|
| 45 |
+
torch==2.4.1 torchvision==0.19.1 && \
|
| 46 |
+
pip install -r /app/inference_server/requirements.txt
|
| 47 |
+
|
| 48 |
+
# Copy the actual code and weights.
|
| 49 |
+
COPY --chown=user:user inference_server/ /app/inference_server/
|
| 50 |
+
COPY --chown=user:user model_training/src/ /app/model_training/src/
|
| 51 |
+
COPY --chown=user:user model_training/notebooks/results/ /app/model_training/notebooks/results/
|
| 52 |
+
|
| 53 |
+
# Pre-create writable cache dirs owned by `user`.
|
| 54 |
+
RUN mkdir -p /app/.cache/torch /app/.cache/huggingface && chown -R user:user /app
|
| 55 |
+
|
| 56 |
+
USER user
|
| 57 |
+
|
| 58 |
+
# Spaces routes traffic to whatever app_port we declared in README.md (7860).
|
| 59 |
+
ENV PORT=7860 \
|
| 60 |
+
LOG_LEVEL=INFO \
|
| 61 |
+
MODEL_USE_TTA=true \
|
| 62 |
+
ALLOWED_ORIGIN_REGEX="https://([a-z0-9-]+\\.)*lovable\\.app|https://([a-z0-9-]+\\.)*lovableproject\\.com|https://([a-z0-9-]+\\.)*hf\\.space|http://localhost(:\\d+)?|http://127\\.0\\.0\\.1(:\\d+)?"
|
| 63 |
+
|
| 64 |
+
EXPOSE 7860
|
| 65 |
+
|
| 66 |
+
# server.py expects to be importable from /app/inference_server.
|
| 67 |
+
WORKDIR /app/inference_server
|
| 68 |
+
CMD ["sh", "-c", "uvicorn server:app --host 0.0.0.0 --port ${PORT}"]
|
README.md
CHANGED
|
@@ -1,12 +1,42 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
-
sdk:
|
| 7 |
-
|
| 8 |
-
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: CardioScan Inference
|
| 3 |
+
emoji: π«
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: pink
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
|
|
|
| 8 |
pinned: false
|
| 9 |
+
license: mit
|
| 10 |
+
short_description: Cardiomegaly screening API for the CardioScan frontend.
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# CardioScan Inference
|
| 14 |
+
|
| 15 |
+
FastAPI service that serves the chest X-ray cardiomegaly ensemble trained in
|
| 16 |
+
[heart-scan-helper](https://github.com/) β a torchxrayvision DenseNet-121
|
| 17 |
+
3-seed ensemble with optional 6-pass test-time augmentation.
|
| 18 |
+
|
| 19 |
+
## Endpoints
|
| 20 |
+
|
| 21 |
+
- `GET /health` β readiness + model metadata
|
| 22 |
+
- `POST /predict` β `multipart/form-data`, field name `image`
|
| 23 |
+
- `POST /debug/predict` β per-model, per-TTA logits (debugging)
|
| 24 |
+
|
| 25 |
+
Response shape:
|
| 26 |
+
|
| 27 |
+
```json
|
| 28 |
+
{
|
| 29 |
+
"prediction": "Cardiomegaly",
|
| 30 |
+
"confidence": 0.873,
|
| 31 |
+
"heatmap_url": null,
|
| 32 |
+
"source": "model",
|
| 33 |
+
"threshold": 0.504486,
|
| 34 |
+
"ensemble_size": 3,
|
| 35 |
+
"use_tta": true
|
| 36 |
+
}
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
## Environment
|
| 40 |
+
|
| 41 |
+
CORS already allows `*.lovable.app`, `*.lovableproject.com`, `*.hf.space` and
|
| 42 |
+
`localhost`. Override with `ALLOWED_ORIGIN_REGEX` if you need to lock it down.
|
inference_server/README.md
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# inference_server
|
| 2 |
+
|
| 3 |
+
FastAPI service that wraps the **trained ensemble** in
|
| 4 |
+
`model_training/notebooks/results/` and exposes a single `POST /predict`
|
| 5 |
+
endpoint for the React frontend.
|
| 6 |
+
|
| 7 |
+
This folder is strictly an inference layer β **nothing inside
|
| 8 |
+
`model_training/` is modified**. We only import `src.model.build_model` and
|
| 9 |
+
`src.model.cardio_logit` to recreate the architecture before loading the
|
| 10 |
+
saved state dicts.
|
| 11 |
+
|
| 12 |
+
## 1. Install
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
cd inference_server
|
| 16 |
+
python -m venv .venv
|
| 17 |
+
source .venv/bin/activate # Windows: .venv\Scripts\activate
|
| 18 |
+
pip install -r requirements.txt
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
> The backbone is **auto-detected from the first checkpoint** referenced by
|
| 22 |
+
> `ensemble_manifest.csv`, so there is never a mismatch between
|
| 23 |
+
> architecture and weights. The shipped checkpoints in
|
| 24 |
+
> `model_training/notebooks/results/` are `torchxrayvision densenet121`
|
| 25 |
+
> (trained before `CFG.backbone` in `model_training/src/config.py` was
|
| 26 |
+
> changed to `efficientnet_b0`); the server correctly identifies and
|
| 27 |
+
> uses them. You can still force a specific backbone via `MODEL_BACKBONE`
|
| 28 |
+
> if you train a new model β see _Configuration_ below.
|
| 29 |
+
|
| 30 |
+
## 2. Run
|
| 31 |
+
|
| 32 |
+
**You must use the venv** where `pip install -r requirements.txt` was run. If
|
| 33 |
+
you see `ModuleNotFoundError: No module named 'torch'`, you started `uvicorn`
|
| 34 |
+
with the system Python instead of `inference_server/.venv`.
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
cd inference_server
|
| 38 |
+
source .venv/bin/activate # Windows: .venv\Scripts\activate
|
| 39 |
+
uvicorn server:app --host 0.0.0.0 --port 8000
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
Or, without activating (always uses the project interpreter):
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
cd inference_server
|
| 46 |
+
./.venv/bin/uvicorn server:app --host 0.0.0.0 --port 8000
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
If port 8000 is still taken, use **8001** (or any free port) and point the
|
| 50 |
+
frontend at the same port:
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
PORT=8001 ./run.sh
|
| 54 |
+
# in project root .env:
|
| 55 |
+
# VITE_PREDICT_API_URL=http://127.0.0.1:8001
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
If your shell prompt already shows `inference_server` in the path, you are
|
| 59 |
+
**inside** that folderβdo not run `cd inference_server` again (you will get
|
| 60 |
+
`no such file`).
|
| 61 |
+
|
| 62 |
+
**`[Errno 48] address already in use` on port 8000** means something else is
|
| 63 |
+
already bound there (usually an older uvicorn you forgot to stop). On macOS:
|
| 64 |
+
|
| 65 |
+
```bash
|
| 66 |
+
lsof -i :8000
|
| 67 |
+
# note the PID in the second column, then:
|
| 68 |
+
kill <PID>
|
| 69 |
+
# if it does not exit:
|
| 70 |
+
kill -9 <PID>
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
Then start uvicorn again. To use another port without killing the other process
|
| 74 |
+
(e.g. 8001), add `--port 8001` and set `VITE_PREDICT_API_URL` in the frontend
|
| 75 |
+
`.env` to match.
|
| 76 |
+
|
| 77 |
+
> **IPv4/IPv6 gotcha.** Bind with `--host 0.0.0.0` (not `127.0.0.1`). On
|
| 78 |
+
> macOS, browsers often resolve `localhost` to IPv6 `::1` first, and
|
| 79 |
+
> `uvicorn --host 127.0.0.1` only listens on IPv4, which surfaces in the
|
| 80 |
+
> browser as a bare `Network Error` even though `curl 127.0.0.1:8000` works.
|
| 81 |
+
> The frontend's `.env` uses `http://127.0.0.1:8000` (not `localhost`) for
|
| 82 |
+
> the same reason.
|
| 83 |
+
|
| 84 |
+
On startup the server will:
|
| 85 |
+
|
| 86 |
+
1. Read `model_training/notebooks/results/ensemble_manifest.csv`.
|
| 87 |
+
2. For each row, rebuild the EfficientNet-B0 architecture and load the
|
| 88 |
+
corresponding `model_seed*.pth` checkpoint.
|
| 89 |
+
3. Move every model to CUDA / MPS / CPU (auto-detected).
|
| 90 |
+
|
| 91 |
+
If the manifest is missing it falls back to
|
| 92 |
+
`model_training/notebooks/results/best_model.pth`.
|
| 93 |
+
|
| 94 |
+
## 3. Frontend wiring
|
| 95 |
+
|
| 96 |
+
The app's `.env` already points at this server:
|
| 97 |
+
|
| 98 |
+
```
|
| 99 |
+
VITE_PREDICT_API_URL=http://localhost:8000
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
`src/services/predict.ts` posts the uploaded file to `/predict` as
|
| 103 |
+
`multipart/form-data` (field name: `image`). The response shape is exactly
|
| 104 |
+
what the frontend expects:
|
| 105 |
+
|
| 106 |
+
```json
|
| 107 |
+
{
|
| 108 |
+
"prediction": "Cardiomegaly",
|
| 109 |
+
"confidence": 0.873,
|
| 110 |
+
"heatmap_url": null,
|
| 111 |
+
"source": "model",
|
| 112 |
+
"threshold": 0.504486,
|
| 113 |
+
"ensemble_size": 3,
|
| 114 |
+
"use_tta": true
|
| 115 |
+
}
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
The frontend uses `source: "model"` to render a green "Real model" badge on
|
| 119 |
+
each result card, so there is no ambiguity about whether a prediction came
|
| 120 |
+
from the real trained ensemble.
|
| 121 |
+
|
| 122 |
+
## 4. Configuration (env vars)
|
| 123 |
+
|
| 124 |
+
| Variable | Default | Purpose |
|
| 125 |
+
| ------------------ | ---------------------- | -------------------------------------------------------- |
|
| 126 |
+
| `MODEL_BACKBONE` | `CFG.backbone` | Must match the architecture used for training |
|
| 127 |
+
| `MODEL_IMG_SIZE` | `CFG.img_size` | Must match training (224 for EfficientNet-B0, 518 for RAD-DINO) |
|
| 128 |
+
| `MODEL_THRESHOLD` | `val_metrics_final.json::threshold` (fallback `0.5`) | Cut-off used when choosing the label string |
|
| 129 |
+
| `MODEL_USE_TTA` | `true` | `true` β run the 6-pass TTA used at training evaluation time |
|
| 130 |
+
| `ALLOWED_ORIGINS` | localhost dev origins | Comma-separated CORS origins (exact match) |
|
| 131 |
+
| `ALLOWED_ORIGIN_REGEX` | _(unset)_ | Regex for origins, e.g. `https://.*\.lovable\.app` for Lovable preview URLs |
|
| 132 |
+
| `LOG_LEVEL` | `INFO` | Standard Python logging level |
|
| 133 |
+
|
| 134 |
+
Example:
|
| 135 |
+
|
| 136 |
+
```bash
|
| 137 |
+
MODEL_USE_TTA=true MODEL_THRESHOLD=0.504 uvicorn server:app --port 8000
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
## 5. Smoke-test
|
| 141 |
+
|
| 142 |
+
```bash
|
| 143 |
+
curl -s http://localhost:8000/health | jq .
|
| 144 |
+
curl -s -X POST -F "image=@/path/to/xray.png" http://localhost:8000/predict | jq .
|
| 145 |
+
|
| 146 |
+
# Full transparency: per-model + per-TTA raw logits so you can compare
|
| 147 |
+
# against val_predictions.csv / test_predictions.csv in the notebook:
|
| 148 |
+
curl -s -X POST -F "image=@/path/to/xray.png" http://localhost:8000/debug/predict | jq .
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
Every `/predict` call is also logged in the server terminal with the
|
| 152 |
+
filename, per-model mean logits, and the final probability β useful to
|
| 153 |
+
confirm the frontend is actually hitting the server.
|
| 154 |
+
|
| 155 |
+
## 6. Deploying to production (Lovable + separate inference host)
|
| 156 |
+
|
| 157 |
+
Lovable hosts the React frontend, but it cannot run this Python server.
|
| 158 |
+
You need a separate Python host for the inference server. In all cases the
|
| 159 |
+
steps are the same:
|
| 160 |
+
|
| 161 |
+
1. **Push this repo** (including `model_training/notebooks/results/*.pth`)
|
| 162 |
+
to the chosen host. The whole monorepo is self-contained.
|
| 163 |
+
2. **Start command** (the host's "start" or "web" command):
|
| 164 |
+
```bash
|
| 165 |
+
uvicorn server:app --host 0.0.0.0 --port $PORT --app-dir inference_server
|
| 166 |
+
```
|
| 167 |
+
Most PaaS hosts inject `PORT`; if yours doesn't, use `8000`.
|
| 168 |
+
3. **Requirements**: point the host at `inference_server/requirements.txt`.
|
| 169 |
+
4. **Set env vars on the inference host**:
|
| 170 |
+
```
|
| 171 |
+
ALLOWED_ORIGINS=https://your-app.lovable.app,https://your-custom-domain.com
|
| 172 |
+
# Or, if Lovable assigns preview URLs with a hash prefix:
|
| 173 |
+
ALLOWED_ORIGIN_REGEX=https://.*\.lovable\.app
|
| 174 |
+
MODEL_USE_TTA=true
|
| 175 |
+
```
|
| 176 |
+
5. **Set env var on Lovable (frontend)**:
|
| 177 |
+
```
|
| 178 |
+
VITE_PREDICT_API_URL=https://your-inference-server-url
|
| 179 |
+
```
|
| 180 |
+
Then redeploy the Lovable frontend so Vite bakes the new URL into the
|
| 181 |
+
bundle.
|
| 182 |
+
|
| 183 |
+
### Suggested hosts
|
| 184 |
+
|
| 185 |
+
| Host | Free tier | Notes |
|
| 186 |
+
| ---- | --------- | ----- |
|
| 187 |
+
| **Hugging Face Spaces** | 2 vCPU / 16 GB, always-on free | Ideal for this model. Create a Space with the "FastAPI" SDK, push the repo, and point it at `inference_server/server.py`. |
|
| 188 |
+
| **Render.com** | Web service, spins down after 15 min idle | Simple Git-push deploy. Cold start ~30 s while weights load. |
|
| 189 |
+
| **Fly.io** | Shared-CPU 256 MB free | Docker-based. Dockerfile is trivial (Python base β pip install β CMD uvicorn). |
|
| 190 |
+
| **Railway / Modal** | Paid / pay-per-second | Always-on, fastest DX. |
|
| 191 |
+
|
| 192 |
+
### Keep model checkpoints in Git LFS (optional)
|
| 193 |
+
|
| 194 |
+
`model_training/notebooks/results/model_seed*.pth` are ~30 MB each. On
|
| 195 |
+
GitHub free plans this is fine, but you may want to move them to Git LFS
|
| 196 |
+
to keep repo clones small.
|
| 197 |
+
|
| 198 |
+
## 7. Notes
|
| 199 |
+
|
| 200 |
+
- The ensemble is loaded once at startup (one-time cost of a few seconds).
|
| 201 |
+
- **Auto-detection**: the server inspects the first checkpoint on startup and
|
| 202 |
+
picks the matching backbone (`densenet121`, `efficientnet_b0`,
|
| 203 |
+
`mobilenet_v3_large`, or `rad-dino`). No `CFG` / env-var bookkeeping required.
|
| 204 |
+
- **Correct preprocessing per backbone**: the server delegates to
|
| 205 |
+
`model_training/src/dataset.py::get_normalize_fn` so the normalization
|
| 206 |
+
matches training exactly β `xrv_normalize_np` (grayscale, [-1024, 1024]) for
|
| 207 |
+
torchxrayvision DenseNet-121, `imagenet_normalize_np` (3-channel) for every
|
| 208 |
+
other backbone.
|
| 209 |
+
- **No pretrained-weight downloads**: torchvision and torchxrayvision
|
| 210 |
+
constructors are monkey-patched so they skip their pretrained-weight
|
| 211 |
+
download entirely β our trained checkpoint fully overwrites those weights
|
| 212 |
+
anyway. This means the server works offline and in sandboxed environments.
|
| 213 |
+
- **Fail-fast checkpoint loading**: if `state_dict` keys don't match the
|
| 214 |
+
architecture, startup aborts with a clear error listing the mismatch.
|
| 215 |
+
- Each request is ~50β150 ms on CPU without TTA (3 Γ DenseNet-121 forward
|
| 216 |
+
passes). With `MODEL_USE_TTA=true` that becomes ~0.5β1.5 s per image.
|
| 217 |
+
- **Verified**: the server reproduces `notebooks/results/val_predictions.csv`
|
| 218 |
+
probabilities to 6 decimal places (zero delta) on the validation set.
|
inference_server/requirements.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FastAPI + ASGI server
|
| 2 |
+
fastapi>=0.110
|
| 3 |
+
uvicorn[standard]>=0.27
|
| 4 |
+
python-multipart>=0.0.9
|
| 5 |
+
|
| 6 |
+
# ML stack (must match `model_training/requirements.txt` so state-dict loads cleanly)
|
| 7 |
+
torch>=2.1.0
|
| 8 |
+
torchvision>=0.16.0
|
| 9 |
+
numpy>=1.24
|
| 10 |
+
pandas>=2.0
|
| 11 |
+
pillow>=9.0
|
| 12 |
+
|
| 13 |
+
# `model_training/src/model.py` imports this at module level
|
| 14 |
+
torchxrayvision>=1.2.0
|
| 15 |
+
|
| 16 |
+
# Needed only if MODEL_BACKBONE=rad-dino; harmless otherwise
|
| 17 |
+
transformers>=4.40.0
|
inference_server/run.sh
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Run the API with the project venv (avoids: ModuleNotFoundError: No module named 'torch')
|
| 3 |
+
set -euo pipefail
|
| 4 |
+
cd "$(dirname "$0")"
|
| 5 |
+
if [[ ! -d .venv ]]; then
|
| 6 |
+
echo "No .venv here. First run: python3 -m venv .venv && .venv/bin/pip install -r requirements.txt" >&2
|
| 7 |
+
exit 1
|
| 8 |
+
fi
|
| 9 |
+
PORT="${PORT:-8000}"
|
| 10 |
+
exec .venv/bin/uvicorn server:app --host 0.0.0.0 --port "$PORT" "$@"
|
inference_server/server.py
ADDED
|
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI inference server for the Cardiomegaly classifier.
|
| 2 |
+
|
| 3 |
+
Loads the multi-seed ensemble trained in ``model_training/`` and exposes a
|
| 4 |
+
single ``POST /predict`` endpoint that the frontend (`src/services/predict.ts`)
|
| 5 |
+
already knows how to consume.
|
| 6 |
+
|
| 7 |
+
Nothing inside ``model_training/`` is modified β we only *import* the model
|
| 8 |
+
factory (``src.model.build_model``) to rebuild the exact architecture that was
|
| 9 |
+
saved to disk, then load the weights on top.
|
| 10 |
+
|
| 11 |
+
Run locally
|
| 12 |
+
-----------
|
| 13 |
+
cd inference_server
|
| 14 |
+
pip install -r requirements.txt
|
| 15 |
+
uvicorn server:app --host 0.0.0.0 --port 8000
|
| 16 |
+
|
| 17 |
+
Environment overrides (optional)
|
| 18 |
+
--------------------------------
|
| 19 |
+
MODEL_BACKBONE default: CFG.backbone (e.g. "efficientnet_b0")
|
| 20 |
+
MODEL_IMG_SIZE default: CFG.img_size (e.g. 224)
|
| 21 |
+
MODEL_THRESHOLD default: 0.5 (binary cut-off for the label)
|
| 22 |
+
MODEL_USE_TTA default: "false" ("true" β 6-pass TTA per image)
|
| 23 |
+
ALLOWED_ORIGINS comma-separated CORS origins (exact match)
|
| 24 |
+
ALLOWED_ORIGIN_REGEX regex origin whitelist (e.g. Lovable preview URLs:
|
| 25 |
+
"https://.*\\.lovable\\.app")
|
| 26 |
+
LOG_LEVEL default: INFO
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from __future__ import annotations
|
| 30 |
+
|
| 31 |
+
import io
|
| 32 |
+
import logging
|
| 33 |
+
import os
|
| 34 |
+
import sys
|
| 35 |
+
from pathlib import Path
|
| 36 |
+
from typing import List
|
| 37 |
+
|
| 38 |
+
import numpy as np
|
| 39 |
+
import pandas as pd
|
| 40 |
+
import torch
|
| 41 |
+
import torch.nn as nn
|
| 42 |
+
import torchvision.transforms as T
|
| 43 |
+
from fastapi import FastAPI, File, HTTPException, UploadFile
|
| 44 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 45 |
+
from PIL import Image
|
| 46 |
+
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
# Paths β make `from src.model import ...` resolvable without touching
|
| 49 |
+
# `model_training/`. We prepend the training directory to sys.path so its
|
| 50 |
+
# internal `from src.config import CFG` style imports keep working.
|
| 51 |
+
# ---------------------------------------------------------------------------
|
| 52 |
+
REPO_ROOT = Path(__file__).resolve().parent.parent
|
| 53 |
+
TRAINING_DIR = REPO_ROOT / "model_training"
|
| 54 |
+
NOTEBOOKS_DIR = TRAINING_DIR / "notebooks"
|
| 55 |
+
RESULTS_DIR = NOTEBOOKS_DIR / "results"
|
| 56 |
+
|
| 57 |
+
if str(TRAINING_DIR) not in sys.path:
|
| 58 |
+
sys.path.insert(0, str(TRAINING_DIR))
|
| 59 |
+
|
| 60 |
+
# Point torch's hub cache to a writable in-project location so the server
|
| 61 |
+
# works in sandboxed environments where ``~/.cache`` is read-only. Setting
|
| 62 |
+
# this BEFORE importing torchvision is critical.
|
| 63 |
+
os.environ.setdefault("TORCH_HOME", str(REPO_ROOT / ".torch-cache"))
|
| 64 |
+
|
| 65 |
+
# `build_model` in ``model_training/src/model.py`` constructs torchvision or
|
| 66 |
+
# torchxrayvision backbones WITH their pretrained weights. Those weights are
|
| 67 |
+
# irrelevant at inference time because we immediately overwrite them with the
|
| 68 |
+
# trained checkpoint from ``model_training/notebooks/results/``. We monkey-
|
| 69 |
+
# patch the constructors so the server skips every pretrained-weight
|
| 70 |
+
# download. This avoids needless bandwidth AND cache-dir permission errors
|
| 71 |
+
# when running in sandboxed environments.
|
| 72 |
+
import torchvision.models as _tvm # noqa: E402 pylint: disable=wrong-import-position
|
| 73 |
+
import torchxrayvision as _xrv # noqa: E402 pylint: disable=wrong-import-position
|
| 74 |
+
|
| 75 |
+
for _fn_name in ("efficientnet_b0", "efficientnet_b3", "mobilenet_v3_large"):
|
| 76 |
+
_orig = getattr(_tvm, _fn_name, None)
|
| 77 |
+
if _orig is None:
|
| 78 |
+
continue
|
| 79 |
+
|
| 80 |
+
def _no_download_builder(*args, __orig=_orig, **kwargs):
|
| 81 |
+
kwargs["weights"] = None
|
| 82 |
+
return __orig(*args, **kwargs)
|
| 83 |
+
|
| 84 |
+
setattr(_tvm, _fn_name, _no_download_builder)
|
| 85 |
+
|
| 86 |
+
# torchxrayvision DenseNet also attempts a download when weights="..." is set.
|
| 87 |
+
# We wrap its __init__ so the caller's weights argument is remembered, but
|
| 88 |
+
# the actual download is skipped. We still restore the canonical label list
|
| 89 |
+
# (``self.pathologies`` / ``self.targets``) that downstream code in
|
| 90 |
+
# ``model_training/src/model.py::cardio_logit`` relies on to locate the
|
| 91 |
+
# Cardiomegaly output index.
|
| 92 |
+
_orig_xrv_densenet_init = _xrv.models.DenseNet.__init__
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _xrv_densenet_init_no_download(self, *args, **kwargs):
|
| 96 |
+
requested_weights = kwargs.get("weights")
|
| 97 |
+
kwargs["weights"] = None
|
| 98 |
+
_orig_xrv_densenet_init(self, *args, **kwargs)
|
| 99 |
+
if requested_weights and requested_weights in _xrv.models.model_urls:
|
| 100 |
+
labels = _xrv.models.model_urls[requested_weights]["labels"]
|
| 101 |
+
self.targets = labels
|
| 102 |
+
self.pathologies = labels
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
_xrv.models.DenseNet.__init__ = _xrv_densenet_init_no_download
|
| 106 |
+
|
| 107 |
+
from src.config import CFG # noqa: E402 pylint: disable=wrong-import-position
|
| 108 |
+
from src.model import build_model, cardio_logit # noqa: E402 pylint: disable=wrong-import-position
|
| 109 |
+
from src.dataset import get_normalize_fn # noqa: E402 pylint: disable=wrong-import-position
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _detect_backbone_from_checkpoint(ckpt_path: Path) -> str:
|
| 113 |
+
"""Inspect a saved state_dict and guess which backbone produced it.
|
| 114 |
+
|
| 115 |
+
Rules:
|
| 116 |
+
* torchxrayvision DenseNet-121 β has ``features.denseblockN.*`` keys
|
| 117 |
+
* torchvision EfficientNet β top-level ``features.0.0.weight`` (stem conv)
|
| 118 |
+
and depth β₯ 9 feature groups
|
| 119 |
+
* torchvision MobileNetV3-Large β ``features.0.0.weight`` with depth ~17
|
| 120 |
+
* microsoft/rad-dino β keys under ``features.embeddings`` /
|
| 121 |
+
``features.encoder.layer.``
|
| 122 |
+
Defaults to ``CFG.backbone`` if no signature matches.
|
| 123 |
+
"""
|
| 124 |
+
state = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
| 125 |
+
if isinstance(state, dict) and "state_dict" in state:
|
| 126 |
+
state = state["state_dict"]
|
| 127 |
+
keys = list(state.keys())
|
| 128 |
+
|
| 129 |
+
if any("denseblock" in k for k in keys):
|
| 130 |
+
return "densenet121"
|
| 131 |
+
if any(k.startswith("features.embeddings.") for k in keys) or any(
|
| 132 |
+
k.startswith("features.encoder.layer.") for k in keys
|
| 133 |
+
):
|
| 134 |
+
return "rad-dino"
|
| 135 |
+
# torchvision feature indices
|
| 136 |
+
feature_indices = {
|
| 137 |
+
int(k.split(".")[1])
|
| 138 |
+
for k in keys
|
| 139 |
+
if k.startswith("features.") and k.split(".")[1].isdigit()
|
| 140 |
+
}
|
| 141 |
+
if feature_indices:
|
| 142 |
+
# EfficientNet-B0 has 9 groups (features.0 β¦ features.8)
|
| 143 |
+
# MobileNetV3-Large has 17 groups (features.0 β¦ features.16)
|
| 144 |
+
if max(feature_indices) >= 12:
|
| 145 |
+
return "mobilenet_v3_large"
|
| 146 |
+
if max(feature_indices) >= 7:
|
| 147 |
+
return "efficientnet_b0"
|
| 148 |
+
return CFG.backbone
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# ---------------------------------------------------------------------------
|
| 152 |
+
# Backbone + image size: auto-detected from the checkpoint so the server never
|
| 153 |
+
# runs with a mismatched architecture. Can still be forced via env vars.
|
| 154 |
+
# ---------------------------------------------------------------------------
|
| 155 |
+
def _first_checkpoint_path() -> Path:
|
| 156 |
+
manifest = RESULTS_DIR / "ensemble_manifest.csv"
|
| 157 |
+
if manifest.exists():
|
| 158 |
+
df = pd.read_csv(manifest)
|
| 159 |
+
first = df["checkpoint"].iloc[0]
|
| 160 |
+
p = Path(first)
|
| 161 |
+
if p.is_absolute() and p.exists():
|
| 162 |
+
return p
|
| 163 |
+
for candidate in (NOTEBOOKS_DIR / first, RESULTS_DIR / Path(first).name):
|
| 164 |
+
if candidate.exists():
|
| 165 |
+
return candidate
|
| 166 |
+
fallback = RESULTS_DIR / "best_model.pth"
|
| 167 |
+
if fallback.exists():
|
| 168 |
+
return fallback
|
| 169 |
+
raise FileNotFoundError("No checkpoints found under model_training/notebooks/results/")
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
_DETECTED_BACKBONE = _detect_backbone_from_checkpoint(_first_checkpoint_path())
|
| 173 |
+
# DenseNet-121 (torchxrayvision) is trained on 224x224; ViT-B/14 needs 518.
|
| 174 |
+
_DEFAULT_IMG_SIZE = 518 if _DETECTED_BACKBONE == "rad-dino" else 224
|
| 175 |
+
|
| 176 |
+
BACKBONE: str = os.environ.get("MODEL_BACKBONE", _DETECTED_BACKBONE)
|
| 177 |
+
IMG_SIZE: int = int(os.environ.get("MODEL_IMG_SIZE", str(_DEFAULT_IMG_SIZE)))
|
| 178 |
+
USE_TTA: bool = os.environ.get("MODEL_USE_TTA", "true").lower() in {"1", "true", "yes"}
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def _default_threshold() -> float:
|
| 182 |
+
"""Use the training-selected threshold when available."""
|
| 183 |
+
metrics_path = RESULTS_DIR / "val_metrics_final.json"
|
| 184 |
+
if metrics_path.exists():
|
| 185 |
+
try:
|
| 186 |
+
import json
|
| 187 |
+
|
| 188 |
+
with open(metrics_path, "r", encoding="utf-8") as f:
|
| 189 |
+
data = json.load(f)
|
| 190 |
+
thr = float(data.get("threshold", 0.5))
|
| 191 |
+
if 0.0 <= thr <= 1.0:
|
| 192 |
+
return thr
|
| 193 |
+
except Exception: # noqa: BLE001
|
| 194 |
+
pass
|
| 195 |
+
return 0.5
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
DECISION_THRESHOLD: float = float(os.environ.get("MODEL_THRESHOLD", str(_default_threshold())))
|
| 199 |
+
|
| 200 |
+
_DEFAULT_ORIGINS = (
|
| 201 |
+
"http://localhost:3000,"
|
| 202 |
+
"http://localhost:5173,"
|
| 203 |
+
"http://localhost:8080,"
|
| 204 |
+
"http://127.0.0.1:3000,"
|
| 205 |
+
"http://127.0.0.1:5173,"
|
| 206 |
+
"http://127.0.0.1:8080"
|
| 207 |
+
)
|
| 208 |
+
ALLOWED_ORIGINS: list[str] = [
|
| 209 |
+
o.strip()
|
| 210 |
+
for o in os.environ.get("ALLOWED_ORIGINS", _DEFAULT_ORIGINS).split(",")
|
| 211 |
+
if o.strip()
|
| 212 |
+
]
|
| 213 |
+
# Optional regex list β useful when the production frontend is served from a
|
| 214 |
+
# hash-based preview URL (e.g. Lovable / Vercel preview deployments).
|
| 215 |
+
# By default we allow:
|
| 216 |
+
# * any *.lovable.app and *.lovableproject.com subdomain (deployed Lovable apps)
|
| 217 |
+
# * any *.ngrok-free.app / *.ngrok.app / *.ngrok.io subdomain (when the user
|
| 218 |
+
# forwards the dev server through ngrok and previews the app from anywhere)
|
| 219 |
+
# Override with `ALLOWED_ORIGIN_REGEX` to lock things down in production.
|
| 220 |
+
# Include common private LAN dev URLs (Vite "Network" URL is often
|
| 221 |
+
# `http://192.168.x.x:8080` β the Origin header is not localhost, so
|
| 222 |
+
# it must be accepted here or the browser will block with "Network Error").
|
| 223 |
+
_DEFAULT_ORIGIN_REGEX = (
|
| 224 |
+
r"https://([a-z0-9-]+\.)*lovable\.app"
|
| 225 |
+
r"|https://([a-z0-9-]+\.)*lovableproject\.com"
|
| 226 |
+
r"|https://([a-z0-9-]+\.)*ngrok-free\.app"
|
| 227 |
+
r"|https://([a-z0-9-]+\.)*ngrok\.app"
|
| 228 |
+
r"|https://([a-z0-9-]+\.)*ngrok\.io"
|
| 229 |
+
r"|http://(192\.168\.\d{1,3}\.\d{1,3}|10\.\d{1,3}\.\d{1,3}\.\d{1,3}):\d+"
|
| 230 |
+
)
|
| 231 |
+
_ORIGIN_REGEX: str | None = os.environ.get("ALLOWED_ORIGIN_REGEX", _DEFAULT_ORIGIN_REGEX) or None
|
| 232 |
+
|
| 233 |
+
DEVICE: torch.device = torch.device(
|
| 234 |
+
"cuda" if torch.cuda.is_available()
|
| 235 |
+
else "mps" if torch.backends.mps.is_available()
|
| 236 |
+
else "cpu"
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
POSITIVE_LABEL = "Cardiomegaly"
|
| 240 |
+
NEGATIVE_LABEL = "No Cardiomegaly indication"
|
| 241 |
+
|
| 242 |
+
# ---------------------------------------------------------------------------
|
| 243 |
+
# Logging
|
| 244 |
+
# ---------------------------------------------------------------------------
|
| 245 |
+
logging.basicConfig(
|
| 246 |
+
level=os.environ.get("LOG_LEVEL", "INFO"),
|
| 247 |
+
format="%(asctime)s %(levelname)-5s %(message)s",
|
| 248 |
+
)
|
| 249 |
+
log = logging.getLogger("inference")
|
| 250 |
+
|
| 251 |
+
# ---------------------------------------------------------------------------
|
| 252 |
+
# Preprocessing β delegate to the SAME normalization functions the training
|
| 253 |
+
# dataset uses (`xrv_normalize_np` for densenet121, `imagenet_normalize_np`
|
| 254 |
+
# for every other backbone). This guarantees byte-for-byte identical
|
| 255 |
+
# preprocessing between training and inference.
|
| 256 |
+
# ---------------------------------------------------------------------------
|
| 257 |
+
_normalize_fn = get_normalize_fn(BACKBONE)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def _pil_hflip(img: Image.Image) -> Image.Image:
|
| 261 |
+
return img.transpose(Image.FLIP_LEFT_RIGHT)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def _tta_pipelines(size: int) -> List[T.Compose]:
|
| 265 |
+
"""Match `src.transforms.make_tta_transforms` (6 deterministic passes)."""
|
| 266 |
+
s = (size, size)
|
| 267 |
+
return [
|
| 268 |
+
T.Compose([T.Resize(s)]),
|
| 269 |
+
T.Compose([T.Resize(s), T.Lambda(_pil_hflip)]),
|
| 270 |
+
T.Compose([T.Resize((size + 20, size + 20)), T.CenterCrop(s)]),
|
| 271 |
+
T.Compose([T.Resize((size - 20, size - 20)), T.Pad(10, fill=0), T.CenterCrop(s)]),
|
| 272 |
+
T.Compose([T.Resize(s), T.RandomAffine(degrees=(6, 6), fill=0)]),
|
| 273 |
+
T.Compose([T.Resize(s), T.RandomAffine(degrees=(-6, -6), fill=0)]),
|
| 274 |
+
]
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def _single_eval_pipeline(size: int) -> T.Compose:
|
| 278 |
+
return T.Compose([T.Resize((size, size))])
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
# ---------------------------------------------------------------------------
|
| 282 |
+
# Ensemble loading
|
| 283 |
+
# ---------------------------------------------------------------------------
|
| 284 |
+
def _resolve_checkpoint(p: str) -> Path:
|
| 285 |
+
"""Manifest paths are stored relative to ``model_training/notebooks/``."""
|
| 286 |
+
path = Path(p)
|
| 287 |
+
if path.is_absolute() and path.exists():
|
| 288 |
+
return path
|
| 289 |
+
for candidate in (NOTEBOOKS_DIR / p, RESULTS_DIR / Path(p).name):
|
| 290 |
+
if candidate.exists():
|
| 291 |
+
return candidate
|
| 292 |
+
raise FileNotFoundError(f"Checkpoint not found: {p!r}")
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _load_ensemble() -> List[nn.Module]:
|
| 296 |
+
# Align CFG so build_model() reads the right backbone/size internally.
|
| 297 |
+
CFG.backbone = BACKBONE
|
| 298 |
+
CFG.img_size = IMG_SIZE
|
| 299 |
+
|
| 300 |
+
manifest = RESULTS_DIR / "ensemble_manifest.csv"
|
| 301 |
+
if manifest.exists():
|
| 302 |
+
df = pd.read_csv(manifest)
|
| 303 |
+
checkpoint_paths = [_resolve_checkpoint(p) for p in df["checkpoint"].tolist()]
|
| 304 |
+
log.info("Loading ensemble of %d models from %s", len(checkpoint_paths), manifest.name)
|
| 305 |
+
else:
|
| 306 |
+
best = RESULTS_DIR / "best_model.pth"
|
| 307 |
+
if not best.exists():
|
| 308 |
+
raise FileNotFoundError(
|
| 309 |
+
f"Neither {manifest} nor {best} exist. Train a model before starting the server."
|
| 310 |
+
)
|
| 311 |
+
checkpoint_paths = [best]
|
| 312 |
+
log.info("No manifest found, falling back to single checkpoint: %s", best.name)
|
| 313 |
+
|
| 314 |
+
models: list[nn.Module] = []
|
| 315 |
+
for ckpt_path in checkpoint_paths:
|
| 316 |
+
log.info(" β loading %s", ckpt_path.name)
|
| 317 |
+
model = build_model(BACKBONE)
|
| 318 |
+
state = torch.load(ckpt_path, map_location=DEVICE)
|
| 319 |
+
if isinstance(state, dict) and "state_dict" in state:
|
| 320 |
+
state = state["state_dict"]
|
| 321 |
+
missing, unexpected = model.load_state_dict(state, strict=False)
|
| 322 |
+
if missing or unexpected:
|
| 323 |
+
raise RuntimeError(
|
| 324 |
+
"Checkpoint architecture mismatch. "
|
| 325 |
+
f"backbone={BACKBONE!r}, checkpoint={ckpt_path.name!r}, "
|
| 326 |
+
f"missing_keys={len(missing)}, unexpected_keys={len(unexpected)}. "
|
| 327 |
+
"Use the correct MODEL_BACKBONE / MODEL_IMG_SIZE and ensure "
|
| 328 |
+
"ensemble_manifest.csv points to checkpoints from that training run."
|
| 329 |
+
)
|
| 330 |
+
model.to(DEVICE).eval()
|
| 331 |
+
models.append(model)
|
| 332 |
+
|
| 333 |
+
log.info(
|
| 334 |
+
"Ensemble ready β %d model(s) Β· device=%s Β· backbone=%s (detected=%s) Β· "
|
| 335 |
+
"normalize=%s Β· img_size=%d Β· tta=%s Β· threshold=%.4f",
|
| 336 |
+
len(models), DEVICE, BACKBONE, _DETECTED_BACKBONE,
|
| 337 |
+
_normalize_fn.__name__, IMG_SIZE, USE_TTA, DECISION_THRESHOLD,
|
| 338 |
+
)
|
| 339 |
+
return models
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
# ---------------------------------------------------------------------------
|
| 343 |
+
# FastAPI app
|
| 344 |
+
# ---------------------------------------------------------------------------
|
| 345 |
+
app = FastAPI(title="CardioScan inference", version="1.0")
|
| 346 |
+
|
| 347 |
+
app.add_middleware(
|
| 348 |
+
CORSMiddleware,
|
| 349 |
+
allow_origins=ALLOWED_ORIGINS,
|
| 350 |
+
allow_origin_regex=_ORIGIN_REGEX,
|
| 351 |
+
allow_credentials=True,
|
| 352 |
+
allow_methods=["*"],
|
| 353 |
+
allow_headers=["*"],
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
_ensemble: list[nn.Module] = []
|
| 357 |
+
_loaded_checkpoints: list[str] = []
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
@app.on_event("startup")
|
| 361 |
+
def _startup() -> None:
|
| 362 |
+
global _ensemble, _loaded_checkpoints
|
| 363 |
+
manifest = RESULTS_DIR / "ensemble_manifest.csv"
|
| 364 |
+
if manifest.exists():
|
| 365 |
+
df = pd.read_csv(manifest)
|
| 366 |
+
_loaded_checkpoints = [Path(p).name for p in df["checkpoint"].tolist()]
|
| 367 |
+
else:
|
| 368 |
+
_loaded_checkpoints = ["best_model.pth"]
|
| 369 |
+
_ensemble = _load_ensemble()
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
@app.get("/health")
|
| 373 |
+
def health() -> dict:
|
| 374 |
+
return {
|
| 375 |
+
"ok": bool(_ensemble),
|
| 376 |
+
"models": len(_ensemble),
|
| 377 |
+
"checkpoints": _loaded_checkpoints,
|
| 378 |
+
"backbone": BACKBONE,
|
| 379 |
+
"detected_backbone": _DETECTED_BACKBONE,
|
| 380 |
+
"normalization": _normalize_fn.__name__,
|
| 381 |
+
"img_size": IMG_SIZE,
|
| 382 |
+
"device": str(DEVICE),
|
| 383 |
+
"use_tta": USE_TTA,
|
| 384 |
+
"threshold": DECISION_THRESHOLD,
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
@torch.no_grad()
|
| 389 |
+
def _predict_probability_detailed(pil_gray: Image.Image) -> dict:
|
| 390 |
+
"""Run ensemble (+ optional TTA) on a single PIL image.
|
| 391 |
+
|
| 392 |
+
Returns a dict with per-model / per-TTA logits for transparency.
|
| 393 |
+
Matches `tta_predict` / `tta_predict_ensemble` in ``src.train`` exactly:
|
| 394 |
+
average logits across TTA (per model), then average across models,
|
| 395 |
+
then sigmoid.
|
| 396 |
+
"""
|
| 397 |
+
pipelines = _tta_pipelines(IMG_SIZE) if USE_TTA else [_single_eval_pipeline(IMG_SIZE)]
|
| 398 |
+
|
| 399 |
+
tensors = [_normalize_fn(pipeline(pil_gray)) for pipeline in pipelines]
|
| 400 |
+
batch = torch.stack(tensors, dim=0).to(DEVICE) # (num_tta, 3, H, W)
|
| 401 |
+
|
| 402 |
+
per_model_tta_logits: list[np.ndarray] = []
|
| 403 |
+
per_model_mean_logit: list[float] = []
|
| 404 |
+
for model in _ensemble:
|
| 405 |
+
logit_vec = cardio_logit(model, batch).float().cpu().numpy() # (num_tta,)
|
| 406 |
+
per_model_tta_logits.append(logit_vec)
|
| 407 |
+
per_model_mean_logit.append(float(np.mean(logit_vec)))
|
| 408 |
+
|
| 409 |
+
ensemble_mean_logit = float(np.mean(per_model_mean_logit))
|
| 410 |
+
probability = float(1.0 / (1.0 + np.exp(-ensemble_mean_logit)))
|
| 411 |
+
|
| 412 |
+
return {
|
| 413 |
+
"probability": probability,
|
| 414 |
+
"ensemble_mean_logit": ensemble_mean_logit,
|
| 415 |
+
"per_model_mean_logit": {
|
| 416 |
+
name: lg for name, lg in zip(_loaded_checkpoints, per_model_mean_logit)
|
| 417 |
+
},
|
| 418 |
+
"per_model_tta_logits": {
|
| 419 |
+
name: lg.tolist() for name, lg in zip(_loaded_checkpoints, per_model_tta_logits)
|
| 420 |
+
},
|
| 421 |
+
"num_tta_passes": batch.shape[0],
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
@app.post("/predict")
|
| 426 |
+
async def predict(image: UploadFile = File(...)) -> dict:
|
| 427 |
+
if not _ensemble:
|
| 428 |
+
raise HTTPException(status_code=503, detail="Model not ready")
|
| 429 |
+
|
| 430 |
+
raw = await image.read()
|
| 431 |
+
if not raw:
|
| 432 |
+
raise HTTPException(status_code=400, detail="Empty upload")
|
| 433 |
+
|
| 434 |
+
try:
|
| 435 |
+
pil = Image.open(io.BytesIO(raw)).convert("L")
|
| 436 |
+
except Exception as exc: # noqa: BLE001
|
| 437 |
+
raise HTTPException(status_code=400, detail=f"Could not decode image: {exc}") from exc
|
| 438 |
+
|
| 439 |
+
try:
|
| 440 |
+
details = _predict_probability_detailed(pil)
|
| 441 |
+
except Exception as exc: # noqa: BLE001
|
| 442 |
+
log.exception("Inference failed")
|
| 443 |
+
raise HTTPException(status_code=500, detail=f"Inference error: {exc}") from exc
|
| 444 |
+
|
| 445 |
+
probability = details["probability"]
|
| 446 |
+
is_positive = probability >= DECISION_THRESHOLD
|
| 447 |
+
|
| 448 |
+
log.info(
|
| 449 |
+
"/predict file=%s size=%d prob=%.4f thr=%.4f -> %s (per-model=%s, tta=%d)",
|
| 450 |
+
image.filename,
|
| 451 |
+
len(raw),
|
| 452 |
+
probability,
|
| 453 |
+
DECISION_THRESHOLD,
|
| 454 |
+
"Cardiomegaly" if is_positive else "Negative",
|
| 455 |
+
{k: round(v, 4) for k, v in details["per_model_mean_logit"].items()},
|
| 456 |
+
details["num_tta_passes"],
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
return {
|
| 460 |
+
"prediction": POSITIVE_LABEL if is_positive else NEGATIVE_LABEL,
|
| 461 |
+
"confidence": probability,
|
| 462 |
+
"heatmap_url": None,
|
| 463 |
+
"source": "model",
|
| 464 |
+
"threshold": DECISION_THRESHOLD,
|
| 465 |
+
"ensemble_size": len(_ensemble),
|
| 466 |
+
"use_tta": USE_TTA,
|
| 467 |
+
}
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
@app.post("/debug/predict")
|
| 471 |
+
async def debug_predict(image: UploadFile = File(...)) -> dict:
|
| 472 |
+
"""Same as /predict but returns per-model and per-TTA raw logits for
|
| 473 |
+
verification against the training notebook's val/test CSVs."""
|
| 474 |
+
if not _ensemble:
|
| 475 |
+
raise HTTPException(status_code=503, detail="Model not ready")
|
| 476 |
+
|
| 477 |
+
raw = await image.read()
|
| 478 |
+
if not raw:
|
| 479 |
+
raise HTTPException(status_code=400, detail="Empty upload")
|
| 480 |
+
|
| 481 |
+
try:
|
| 482 |
+
pil = Image.open(io.BytesIO(raw)).convert("L")
|
| 483 |
+
except Exception as exc: # noqa: BLE001
|
| 484 |
+
raise HTTPException(status_code=400, detail=f"Could not decode image: {exc}") from exc
|
| 485 |
+
|
| 486 |
+
details = _predict_probability_detailed(pil)
|
| 487 |
+
details["prediction"] = (
|
| 488 |
+
POSITIVE_LABEL if details["probability"] >= DECISION_THRESHOLD else NEGATIVE_LABEL
|
| 489 |
+
)
|
| 490 |
+
details["threshold"] = DECISION_THRESHOLD
|
| 491 |
+
details["use_tta"] = USE_TTA
|
| 492 |
+
details["checkpoints"] = _loaded_checkpoints
|
| 493 |
+
return details
|
model_training/notebooks/results/best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9bede2c09fc26c3fb5260c036ef0a0517d21de942dc509bf4425c0a0919bab7b
|
| 3 |
+
size 43341137
|
model_training/notebooks/results/daily_submission-1.csv
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
image_file,prob,pred
|
| 2 |
+
IMG_000018.png,0.816906,1
|
| 3 |
+
IMG_000021.png,0.009301819,0
|
| 4 |
+
IMG_000039.png,0.33299014,0
|
| 5 |
+
IMG_000044.png,0.11926782,0
|
| 6 |
+
IMG_000047.png,0.3057367,0
|
| 7 |
+
IMG_000054.png,0.9174445,1
|
| 8 |
+
IMG_000077.png,0.0068948073,0
|
| 9 |
+
IMG_000094.png,0.010871509,0
|
| 10 |
+
IMG_000096.png,0.97423095,1
|
| 11 |
+
IMG_000108.png,0.03271585,0
|
| 12 |
+
IMG_000111.png,0.37760112,0
|
| 13 |
+
IMG_000128.png,0.8483817,1
|
| 14 |
+
IMG_000129.png,0.98759496,1
|
| 15 |
+
IMG_000135.png,0.8694341,1
|
| 16 |
+
IMG_000149.png,0.20452695,0
|
| 17 |
+
IMG_000160.png,0.8993582,1
|
| 18 |
+
IMG_000174.png,0.9729403,1
|
| 19 |
+
IMG_000180.png,0.3869212,0
|
| 20 |
+
IMG_000232.png,0.10533079,0
|
| 21 |
+
IMG_000242.png,0.064271316,0
|
| 22 |
+
IMG_000248.png,0.055668276,0
|
| 23 |
+
IMG_000249.png,6.6870154e-05,0
|
| 24 |
+
IMG_000255.png,0.011946698,0
|
| 25 |
+
IMG_000273.png,0.0054892018,0
|
| 26 |
+
IMG_000277.png,0.9634049,1
|
| 27 |
+
IMG_000299.png,0.96502954,1
|
| 28 |
+
IMG_000302.png,0.047373675,0
|
| 29 |
+
IMG_000309.png,0.9813534,1
|
| 30 |
+
IMG_000313.png,0.008047514,0
|
| 31 |
+
IMG_000314.png,0.008340702,0
|
| 32 |
+
IMG_000324.png,0.96767646,1
|
| 33 |
+
IMG_000343.png,0.023570525,0
|
| 34 |
+
IMG_000346.png,0.9816032,1
|
| 35 |
+
IMG_000353.png,0.45696247,0
|
| 36 |
+
IMG_000367.png,0.35300097,0
|
| 37 |
+
IMG_000404.png,0.4996281,0
|
| 38 |
+
IMG_000408.png,0.8722059,1
|
| 39 |
+
IMG_000415.png,0.74312115,1
|
| 40 |
+
IMG_000416.png,0.94917625,1
|
| 41 |
+
IMG_000420.png,0.22852376,0
|
| 42 |
+
IMG_000423.png,0.7794522,1
|
| 43 |
+
IMG_000429.png,0.8098528,1
|
| 44 |
+
IMG_000432.png,0.9711745,1
|
| 45 |
+
IMG_000479.png,0.909901,1
|
| 46 |
+
IMG_000483.png,0.832377,1
|
| 47 |
+
IMG_000485.png,0.744525,1
|
| 48 |
+
IMG_000522.png,0.9405937,1
|
| 49 |
+
IMG_000527.png,0.8822646,1
|
| 50 |
+
IMG_000532.png,0.97545457,1
|
| 51 |
+
IMG_000536.png,0.89281857,1
|
| 52 |
+
IMG_000538.png,0.9655915,1
|
| 53 |
+
IMG_000547.png,0.3647764,0
|
| 54 |
+
IMG_000548.png,0.47809702,0
|
| 55 |
+
IMG_000549.png,0.97563666,1
|
| 56 |
+
IMG_000565.png,0.074912064,0
|
| 57 |
+
IMG_000568.png,0.08629721,0
|
| 58 |
+
IMG_000575.png,0.9126054,1
|
| 59 |
+
IMG_000576.png,0.76525265,1
|
| 60 |
+
IMG_000579.png,0.00989682,0
|
| 61 |
+
IMG_000580.png,0.036182728,0
|
| 62 |
+
IMG_000593.png,0.90472263,1
|
| 63 |
+
IMG_000622.png,0.002541845,0
|
| 64 |
+
IMG_000627.png,0.0029941453,0
|
| 65 |
+
IMG_000632.png,0.07461243,0
|
| 66 |
+
IMG_000633.png,0.07481881,0
|
| 67 |
+
IMG_000653.png,0.18598385,0
|
| 68 |
+
IMG_000660.png,0.20837037,0
|
| 69 |
+
IMG_000671.png,0.0016912636,0
|
| 70 |
+
IMG_000675.png,0.1421747,0
|
| 71 |
+
IMG_000683.png,0.025120575,0
|
| 72 |
+
IMG_000688.png,0.57493603,1
|
| 73 |
+
IMG_000694.png,0.94961226,1
|
| 74 |
+
IMG_000696.png,0.3008333,0
|
| 75 |
+
IMG_000708.png,0.012850549,0
|
| 76 |
+
IMG_000710.png,0.9830056,1
|
| 77 |
+
IMG_000773.png,0.0031617412,0
|
| 78 |
+
IMG_000785.png,0.82301766,1
|
| 79 |
+
IMG_000787.png,0.9569242,1
|
| 80 |
+
IMG_000792.png,0.9786517,1
|
| 81 |
+
IMG_000797.png,0.0004634541,0
|
| 82 |
+
IMG_000821.png,0.024412306,0
|
| 83 |
+
IMG_000854.png,0.19797583,0
|
| 84 |
+
IMG_000862.png,0.9750999,1
|
| 85 |
+
IMG_000870.png,0.8686767,1
|
| 86 |
+
IMG_000880.png,0.0010660468,0
|
| 87 |
+
IMG_000893.png,0.96038836,1
|
| 88 |
+
IMG_000910.png,0.07809025,0
|
| 89 |
+
IMG_000914.png,0.26980433,0
|
| 90 |
+
IMG_000917.png,0.98292184,1
|
| 91 |
+
IMG_000925.png,0.0027379196,0
|
| 92 |
+
IMG_000936.png,0.96947944,1
|
| 93 |
+
IMG_000942.png,0.82471883,1
|
| 94 |
+
IMG_000945.png,0.00028870278,0
|
| 95 |
+
IMG_000946.png,0.07984374,0
|
| 96 |
+
IMG_000951.png,0.26128584,0
|
| 97 |
+
IMG_000953.png,0.0010449657,0
|
| 98 |
+
IMG_000981.png,0.090446874,0
|
| 99 |
+
IMG_000985.png,0.018414477,0
|
| 100 |
+
IMG_000988.png,0.9778957,1
|
| 101 |
+
IMG_000991.png,0.036471717,0
|
| 102 |
+
IMG_001026.png,0.59295475,1
|
| 103 |
+
IMG_001027.png,0.9421946,1
|
| 104 |
+
IMG_001056.png,0.5752065,1
|
| 105 |
+
IMG_001063.png,0.028953826,0
|
| 106 |
+
IMG_001071.png,0.024936734,0
|
| 107 |
+
IMG_001083.png,0.95445114,1
|
| 108 |
+
IMG_001084.png,0.90794367,1
|
| 109 |
+
IMG_001085.png,0.7487882,1
|
| 110 |
+
IMG_001100.png,0.92409223,1
|
| 111 |
+
IMG_001112.png,0.1766397,0
|
| 112 |
+
IMG_001113.png,0.93751234,1
|
| 113 |
+
IMG_001121.png,0.80051255,1
|
| 114 |
+
IMG_001148.png,0.1581435,0
|
| 115 |
+
IMG_001154.png,0.001443795,0
|
| 116 |
+
IMG_001155.png,0.99754816,1
|
| 117 |
+
IMG_001158.png,0.21377504,0
|
| 118 |
+
IMG_001160.png,0.8512862,1
|
| 119 |
+
IMG_001162.png,0.55319506,1
|
| 120 |
+
IMG_001166.png,0.013784781,0
|
| 121 |
+
IMG_001168.png,0.6532163,1
|
| 122 |
+
IMG_001179.png,0.11473991,0
|
| 123 |
+
IMG_001180.png,0.057606816,0
|
| 124 |
+
IMG_001184.png,0.3746491,0
|
| 125 |
+
IMG_001205.png,0.0096385535,0
|
| 126 |
+
IMG_001206.png,0.3501062,0
|
| 127 |
+
IMG_001239.png,0.016910465,0
|
| 128 |
+
IMG_001259.png,0.9424403,1
|
| 129 |
+
IMG_001271.png,0.8103772,1
|
| 130 |
+
IMG_001290.png,0.22025856,0
|
| 131 |
+
IMG_001291.png,0.133331,0
|
| 132 |
+
IMG_001304.png,0.9841086,1
|
| 133 |
+
IMG_001317.png,0.75506747,1
|
| 134 |
+
IMG_001321.png,0.014835258,0
|
| 135 |
+
IMG_001334.png,0.03285374,0
|
| 136 |
+
IMG_001336.png,0.36698562,0
|
| 137 |
+
IMG_001349.png,0.03645498,0
|
| 138 |
+
IMG_001352.png,0.72149867,1
|
| 139 |
+
IMG_001363.png,0.00038983554,0
|
| 140 |
+
IMG_001365.png,0.6854751,1
|
| 141 |
+
IMG_001368.png,0.9896269,1
|
| 142 |
+
IMG_001384.png,0.71602064,1
|
| 143 |
+
IMG_001393.png,0.017989887,0
|
| 144 |
+
IMG_001394.png,0.0023077542,0
|
| 145 |
+
IMG_001405.png,0.8804799,1
|
| 146 |
+
IMG_001412.png,0.5575575,1
|
| 147 |
+
IMG_001417.png,0.71472263,1
|
| 148 |
+
IMG_001436.png,0.9894505,1
|
| 149 |
+
IMG_001443.png,0.00088119216,0
|
| 150 |
+
IMG_001495.png,0.005760317,0
|
| 151 |
+
IMG_001504.png,0.017389266,0
|
| 152 |
+
IMG_001506.png,0.012142896,0
|
| 153 |
+
IMG_001508.png,0.81739867,1
|
| 154 |
+
IMG_001511.png,0.9034313,1
|
| 155 |
+
IMG_001528.png,0.025507784,0
|
| 156 |
+
IMG_001536.png,0.5163997,1
|
| 157 |
+
IMG_001570.png,0.014479928,0
|
| 158 |
+
IMG_001572.png,0.004765836,0
|
| 159 |
+
IMG_001577.png,0.03651942,0
|
| 160 |
+
IMG_001581.png,0.0067347363,0
|
| 161 |
+
IMG_001588.png,0.7685501,1
|
| 162 |
+
IMG_001593.png,0.92315125,1
|
| 163 |
+
IMG_001605.png,0.028199496,0
|
| 164 |
+
IMG_001606.png,0.092707306,0
|
| 165 |
+
IMG_001621.png,0.00926323,0
|
| 166 |
+
IMG_001628.png,0.33566543,0
|
| 167 |
+
IMG_001638.png,0.0071623144,0
|
| 168 |
+
IMG_001645.png,0.0010000179,0
|
| 169 |
+
IMG_001671.png,0.0020201257,0
|
| 170 |
+
IMG_001684.png,0.87558657,1
|
| 171 |
+
IMG_001688.png,0.86326087,1
|
| 172 |
+
IMG_001702.png,0.014059455,0
|
| 173 |
+
IMG_001710.png,0.97905207,1
|
| 174 |
+
IMG_001718.png,0.00522173,0
|
| 175 |
+
IMG_001726.png,0.26297852,0
|
| 176 |
+
IMG_001731.png,0.56691897,1
|
| 177 |
+
IMG_001743.png,0.9594168,1
|
model_training/notebooks/results/ensemble_manifest.csv
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
seed,best_val_auc,checkpoint
|
| 2 |
+
8,0.9496841803043353,results/model_seed8.pth
|
| 3 |
+
7,0.9425782371518806,results/model_seed7.pth
|
| 4 |
+
2024,0.94717197817973,results/model_seed2024.pth
|
model_training/notebooks/results/model_seed2024.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f97f75f36fc98bb7a585e9167b2494fbb0e39b59c57faccf818021bace913189
|
| 3 |
+
size 28464092
|
model_training/notebooks/results/model_seed42.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a10caab6adc4408c708209adebe0293a060943c3156b06ed1c19f794a7d8504c
|
| 3 |
+
size 17015941
|
model_training/notebooks/results/model_seed7.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dd24fddc1efd99010271a20b31d316f9459213cd5fbc5142654737ca825b6577
|
| 3 |
+
size 28461893
|
model_training/notebooks/results/model_seed8.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:94f02549afb37082890d35231782b2a61117f535c752c8acd6881a69ebe880bb
|
| 3 |
+
size 28461893
|
model_training/notebooks/results/test_metrics_final.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"threshold": 0.504486,
|
| 3 |
+
"auc": 0.9391329313809934,
|
| 4 |
+
"sensitivity": 0.8703703703703703,
|
| 5 |
+
"specificity": 0.8527131782945736,
|
| 6 |
+
"accuracy": 0.8607594936708861,
|
| 7 |
+
"youden": 0.7230835486649441,
|
| 8 |
+
"composite": 0.9003373528567327,
|
| 9 |
+
"precision": 0.831858407079646,
|
| 10 |
+
"tp": 94,
|
| 11 |
+
"tn": 110,
|
| 12 |
+
"fp": 19,
|
| 13 |
+
"fn": 14
|
| 14 |
+
}
|
model_training/notebooks/results/test_predictions.csv
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
filename,y_true,prob,pred,correct
|
| 2 |
+
IMG_000824.png,0,0.021219134,0,1
|
| 3 |
+
IMG_000990.png,1,0.7771948,1,1
|
| 4 |
+
IMG_001690.png,1,0.85049516,1,1
|
| 5 |
+
IMG_000879.png,1,0.8450817,1,1
|
| 6 |
+
IMG_000884.png,0,0.11847936,0,1
|
| 7 |
+
IMG_001538.png,1,0.8167436,1,1
|
| 8 |
+
IMG_001404.png,1,0.799158,1,1
|
| 9 |
+
IMG_001043.png,0,0.030076103,0,1
|
| 10 |
+
IMG_001194.png,0,0.20013258,0,1
|
| 11 |
+
IMG_000345.png,0,0.003733192,0,1
|
| 12 |
+
IMG_001019.png,0,0.4092049,0,1
|
| 13 |
+
IMG_000233.png,1,0.754443,1,1
|
| 14 |
+
IMG_000994.png,1,0.8932385,1,1
|
| 15 |
+
IMG_000219.png,0,0.74411124,1,0
|
| 16 |
+
IMG_000758.png,0,0.88436824,1,0
|
| 17 |
+
IMG_001127.png,0,0.0010003505,0,1
|
| 18 |
+
IMG_000185.png,1,0.43144453,0,0
|
| 19 |
+
IMG_000344.png,1,0.7200385,1,1
|
| 20 |
+
IMG_000462.png,1,0.6866032,1,1
|
| 21 |
+
IMG_001243.png,0,0.005522695,0,1
|
| 22 |
+
IMG_000110.png,0,0.00616389,0,1
|
| 23 |
+
IMG_000178.png,1,0.38014814,0,0
|
| 24 |
+
IMG_001282.png,0,0.047400977,0,1
|
| 25 |
+
IMG_000451.png,1,0.95899737,1,1
|
| 26 |
+
IMG_001312.png,1,0.87483233,1,1
|
| 27 |
+
IMG_001298.png,1,0.89103687,1,1
|
| 28 |
+
IMG_000514.png,1,0.9204423,1,1
|
| 29 |
+
IMG_001297.png,0,0.026098764,0,1
|
| 30 |
+
IMG_000941.png,1,0.84286815,1,1
|
| 31 |
+
IMG_000218.png,1,0.32278225,0,0
|
| 32 |
+
IMG_001117.png,0,0.42146167,0,1
|
| 33 |
+
IMG_000001.png,0,0.0846162,0,1
|
| 34 |
+
IMG_001105.png,0,0.34927747,0,1
|
| 35 |
+
IMG_000079.png,0,0.0015609249,0,1
|
| 36 |
+
IMG_000400.png,0,0.06923158,0,1
|
| 37 |
+
IMG_000260.png,1,0.9263501,1,1
|
| 38 |
+
IMG_000998.png,1,0.6549944,1,1
|
| 39 |
+
IMG_001281.png,1,0.71612954,1,1
|
| 40 |
+
IMG_000690.png,0,0.030174967,0,1
|
| 41 |
+
IMG_001627.png,1,0.95443106,1,1
|
| 42 |
+
IMG_001107.png,1,0.7405456,1,1
|
| 43 |
+
IMG_000763.png,0,0.14266923,0,1
|
| 44 |
+
IMG_000545.png,0,0.022861702,0,1
|
| 45 |
+
IMG_001343.png,1,0.78821915,1,1
|
| 46 |
+
IMG_001175.png,1,0.8847534,1,1
|
| 47 |
+
IMG_000682.png,1,0.18611808,0,0
|
| 48 |
+
IMG_001398.png,0,0.010194484,0,1
|
| 49 |
+
IMG_001556.png,1,0.80163234,1,1
|
| 50 |
+
IMG_001225.png,1,0.95753545,1,1
|
| 51 |
+
IMG_001460.png,0,0.39417872,0,1
|
| 52 |
+
IMG_001746.png,1,0.822455,1,1
|
| 53 |
+
IMG_001406.png,1,0.9840471,1,1
|
| 54 |
+
IMG_001052.png,0,0.24067384,0,1
|
| 55 |
+
IMG_001640.png,0,0.022804145,0,1
|
| 56 |
+
IMG_001728.png,1,0.9836304,1,1
|
| 57 |
+
IMG_001592.png,0,0.36900538,0,1
|
| 58 |
+
IMG_001473.png,1,0.9562594,1,1
|
| 59 |
+
IMG_000376.png,0,0.007365662,0,1
|
| 60 |
+
IMG_001390.png,0,0.019877983,0,1
|
| 61 |
+
IMG_000169.png,0,0.030604592,0,1
|
| 62 |
+
IMG_001310.png,0,0.011923165,0,1
|
| 63 |
+
IMG_000151.png,0,0.6607677,1,0
|
| 64 |
+
IMG_001240.png,0,0.013350104,0,1
|
| 65 |
+
IMG_000864.png,1,0.7523066,1,1
|
| 66 |
+
IMG_001741.png,1,0.9743217,1,1
|
| 67 |
+
IMG_001364.png,1,0.35030937,0,0
|
| 68 |
+
IMG_000121.png,0,0.0022226572,0,1
|
| 69 |
+
IMG_000359.png,0,0.004687663,0,1
|
| 70 |
+
IMG_001075.png,1,0.8264759,1,1
|
| 71 |
+
IMG_000187.png,0,0.17870605,0,1
|
| 72 |
+
IMG_000237.png,1,0.9433076,1,1
|
| 73 |
+
IMG_001562.png,1,0.96882564,1,1
|
| 74 |
+
IMG_001114.png,1,0.98810416,1,1
|
| 75 |
+
IMG_000223.png,0,0.021592166,0,1
|
| 76 |
+
IMG_001745.png,0,0.95945805,1,0
|
| 77 |
+
IMG_000826.png,0,0.029931583,0,1
|
| 78 |
+
IMG_001704.png,0,0.23410255,0,1
|
| 79 |
+
IMG_000406.png,0,0.14752638,0,1
|
| 80 |
+
IMG_000634.png,1,0.6518697,1,1
|
| 81 |
+
IMG_001238.png,0,0.08548438,0,1
|
| 82 |
+
IMG_000993.png,1,0.7866735,1,1
|
| 83 |
+
IMG_001736.png,0,0.40755022,0,1
|
| 84 |
+
IMG_000481.png,1,0.001395701,0,0
|
| 85 |
+
IMG_001395.png,1,0.7820107,1,1
|
| 86 |
+
IMG_000553.png,0,0.02335425,0,1
|
| 87 |
+
IMG_000754.png,1,0.7037294,1,1
|
| 88 |
+
IMG_001018.png,0,0.95035726,1,0
|
| 89 |
+
IMG_001492.png,1,0.88814825,1,1
|
| 90 |
+
IMG_000737.png,0,0.00022573896,0,1
|
| 91 |
+
IMG_001617.png,0,0.017841075,0,1
|
| 92 |
+
IMG_000411.png,1,0.46054697,0,0
|
| 93 |
+
IMG_000764.png,1,0.9863303,1,1
|
| 94 |
+
IMG_001198.png,0,0.53324735,1,0
|
| 95 |
+
IMG_001237.png,1,0.40191436,0,0
|
| 96 |
+
IMG_001639.png,1,0.6333021,1,1
|
| 97 |
+
IMG_000829.png,1,0.9704466,1,1
|
| 98 |
+
IMG_001192.png,1,0.8689759,1,1
|
| 99 |
+
IMG_001143.png,0,0.12318405,0,1
|
| 100 |
+
IMG_000860.png,0,0.22503993,0,1
|
| 101 |
+
IMG_001308.png,1,0.9235162,1,1
|
| 102 |
+
IMG_000067.png,1,0.8142528,1,1
|
| 103 |
+
IMG_000899.png,0,0.0091903545,0,1
|
| 104 |
+
IMG_000602.png,1,0.7867315,1,1
|
| 105 |
+
IMG_000022.png,0,0.12565576,0,1
|
| 106 |
+
IMG_001553.png,0,0.5937602,1,0
|
| 107 |
+
IMG_000689.png,0,0.0035498547,0,1
|
| 108 |
+
IMG_001719.png,1,0.97812426,1,1
|
| 109 |
+
IMG_000074.png,0,0.017582327,0,1
|
| 110 |
+
IMG_001584.png,0,0.29081395,0,1
|
| 111 |
+
IMG_000127.png,0,0.13165514,0,1
|
| 112 |
+
IMG_000204.png,0,0.044497356,0,1
|
| 113 |
+
IMG_001663.png,1,0.7866542,1,1
|
| 114 |
+
IMG_001116.png,0,0.031116765,0,1
|
| 115 |
+
IMG_000964.png,1,0.9310901,1,1
|
| 116 |
+
IMG_001677.png,1,0.927338,1,1
|
| 117 |
+
IMG_001418.png,1,0.8989984,1,1
|
| 118 |
+
IMG_000974.png,1,0.9450868,1,1
|
| 119 |
+
IMG_000477.png,1,0.962653,1,1
|
| 120 |
+
IMG_001126.png,0,0.00042581963,0,1
|
| 121 |
+
IMG_001373.png,1,0.84743226,1,1
|
| 122 |
+
IMG_000589.png,1,0.8598813,1,1
|
| 123 |
+
IMG_001389.png,1,0.7277787,1,1
|
| 124 |
+
IMG_001292.png,0,0.30514425,0,1
|
| 125 |
+
IMG_000240.png,1,0.9254172,1,1
|
| 126 |
+
IMG_000437.png,0,0.07736599,0,1
|
| 127 |
+
IMG_000751.png,0,0.21614793,0,1
|
| 128 |
+
IMG_001278.png,0,0.00081374723,0,1
|
| 129 |
+
IMG_000724.png,1,0.8973737,1,1
|
| 130 |
+
IMG_000944.png,1,0.9696045,1,1
|
| 131 |
+
IMG_000930.png,1,0.06275673,0,0
|
| 132 |
+
IMG_001586.png,1,0.85567933,1,1
|
| 133 |
+
IMG_001415.png,0,0.42133018,0,1
|
| 134 |
+
IMG_000235.png,0,0.00044268242,0,1
|
| 135 |
+
IMG_000789.png,0,0.51844907,1,0
|
| 136 |
+
IMG_001722.png,0,0.02285177,0,1
|
| 137 |
+
IMG_000457.png,1,0.9203535,1,1
|
| 138 |
+
IMG_001385.png,1,0.33224016,0,0
|
| 139 |
+
IMG_001545.png,0,0.33094287,0,1
|
| 140 |
+
IMG_001327.png,0,0.0026560111,0,1
|
| 141 |
+
IMG_000693.png,1,0.7572429,1,1
|
| 142 |
+
IMG_001427.png,0,0.18905798,0,1
|
| 143 |
+
IMG_001660.png,0,0.059407346,0,1
|
| 144 |
+
IMG_001189.png,0,0.047472887,0,1
|
| 145 |
+
IMG_000257.png,0,0.8905588,1,0
|
| 146 |
+
IMG_001369.png,0,0.36061272,0,1
|
| 147 |
+
IMG_001691.png,1,0.964177,1,1
|
| 148 |
+
IMG_001732.png,1,0.9945857,1,1
|
| 149 |
+
IMG_001520.png,0,0.22373733,0,1
|
| 150 |
+
IMG_001376.png,1,0.7934236,1,1
|
| 151 |
+
IMG_000436.png,0,0.71353745,1,0
|
| 152 |
+
IMG_000741.png,1,0.98955476,1,1
|
| 153 |
+
IMG_000446.png,0,0.76797557,1,0
|
| 154 |
+
IMG_000876.png,0,0.43126056,0,1
|
| 155 |
+
IMG_001462.png,0,0.013134937,0,1
|
| 156 |
+
IMG_001066.png,0,0.00046867327,0,1
|
| 157 |
+
IMG_000114.png,0,0.3608086,0,1
|
| 158 |
+
IMG_001015.png,0,0.04901976,0,1
|
| 159 |
+
IMG_001372.png,0,0.018457958,0,1
|
| 160 |
+
IMG_000949.png,1,0.99760324,1,1
|
| 161 |
+
IMG_000919.png,0,0.016946174,0,1
|
| 162 |
+
IMG_001172.png,1,0.8392962,1,1
|
| 163 |
+
IMG_001068.png,0,0.8429959,1,0
|
| 164 |
+
IMG_000087.png,0,0.6321575,1,0
|
| 165 |
+
IMG_000888.png,0,0.84551364,1,0
|
| 166 |
+
IMG_000286.png,1,0.9955705,1,1
|
| 167 |
+
IMG_000214.png,1,0.73011214,1,1
|
| 168 |
+
IMG_001065.png,0,0.047320206,0,1
|
| 169 |
+
IMG_000510.png,0,0.00021625705,0,1
|
| 170 |
+
IMG_000315.png,0,0.040586166,0,1
|
| 171 |
+
IMG_000422.png,1,0.86234355,1,1
|
| 172 |
+
IMG_000468.png,1,0.9084235,1,1
|
| 173 |
+
IMG_000749.png,0,0.048509527,0,1
|
| 174 |
+
IMG_000822.png,0,0.00459977,0,1
|
| 175 |
+
IMG_001025.png,0,0.1423657,0,1
|
| 176 |
+
IMG_000777.png,1,0.94110733,1,1
|
| 177 |
+
IMG_001235.png,1,0.955069,1,1
|
| 178 |
+
IMG_001265.png,1,0.4247596,0,0
|
| 179 |
+
IMG_000476.png,0,0.0074078897,0,1
|
| 180 |
+
IMG_000922.png,1,0.4249936,0,0
|
| 181 |
+
IMG_001294.png,0,0.5523263,1,0
|
| 182 |
+
IMG_001181.png,0,0.021967547,0,1
|
| 183 |
+
IMG_000179.png,1,0.6198387,1,1
|
| 184 |
+
IMG_000844.png,0,0.0009595261,0,1
|
| 185 |
+
IMG_001062.png,0,0.02210043,0,1
|
| 186 |
+
IMG_000647.png,0,0.048218727,0,1
|
| 187 |
+
IMG_000629.png,1,0.7496184,1,1
|
| 188 |
+
IMG_000743.png,1,0.94583535,1,1
|
| 189 |
+
IMG_000639.png,0,0.31978914,0,1
|
| 190 |
+
IMG_000207.png,0,0.0010147754,0,1
|
| 191 |
+
IMG_000116.png,0,0.4486534,0,1
|
| 192 |
+
IMG_000412.png,0,0.00072938023,0,1
|
| 193 |
+
IMG_001076.png,0,0.00037972903,0,1
|
| 194 |
+
IMG_000911.png,0,0.21701539,0,1
|
| 195 |
+
IMG_000226.png,1,0.9124851,1,1
|
| 196 |
+
IMG_001434.png,0,0.1508207,0,1
|
| 197 |
+
IMG_001534.png,0,0.007111399,0,1
|
| 198 |
+
IMG_000239.png,1,0.97189856,1,1
|
| 199 |
+
IMG_000729.png,0,0.024547786,0,1
|
| 200 |
+
IMG_000362.png,0,0.0368409,0,1
|
| 201 |
+
IMG_001563.png,1,0.92587185,1,1
|
| 202 |
+
IMG_000850.png,0,0.031102972,0,1
|
| 203 |
+
IMG_001283.png,0,0.0041644424,0,1
|
| 204 |
+
IMG_001077.png,0,0.0043652514,0,1
|
| 205 |
+
IMG_000109.png,1,0.94027257,1,1
|
| 206 |
+
IMG_001039.png,0,0.31002668,0,1
|
| 207 |
+
IMG_000855.png,1,0.9954782,1,1
|
| 208 |
+
IMG_000581.png,0,0.58953196,1,0
|
| 209 |
+
IMG_001670.png,1,0.41889447,0,0
|
| 210 |
+
IMG_001489.png,0,0.0039128903,0,1
|
| 211 |
+
IMG_001488.png,1,0.977498,1,1
|
| 212 |
+
IMG_001247.png,1,0.9217244,1,1
|
| 213 |
+
IMG_001340.png,1,0.9252595,1,1
|
| 214 |
+
IMG_000562.png,0,0.00028323507,0,1
|
| 215 |
+
IMG_001714.png,1,0.9137611,1,1
|
| 216 |
+
IMG_000113.png,1,0.83810467,1,1
|
| 217 |
+
IMG_000519.png,0,0.24291025,0,1
|
| 218 |
+
IMG_000352.png,1,0.89757895,1,1
|
| 219 |
+
IMG_000060.png,0,0.12018548,0,1
|
| 220 |
+
IMG_001689.png,1,0.88260335,1,1
|
| 221 |
+
IMG_001332.png,0,0.021517713,0,1
|
| 222 |
+
IMG_001227.png,0,0.5806698,1,0
|
| 223 |
+
IMG_000952.png,1,0.91132796,1,1
|
| 224 |
+
IMG_001223.png,0,0.0054034726,0,1
|
| 225 |
+
IMG_001662.png,1,0.9796165,1,1
|
| 226 |
+
IMG_001242.png,1,0.89553756,1,1
|
| 227 |
+
IMG_001542.png,1,0.2224111,0,0
|
| 228 |
+
IMG_000150.png,0,0.038930204,0,1
|
| 229 |
+
IMG_001021.png,0,0.00035879802,0,1
|
| 230 |
+
IMG_001516.png,0,0.62845725,1,0
|
| 231 |
+
IMG_001074.png,0,0.7126184,1,0
|
| 232 |
+
IMG_001648.png,1,0.9681084,1,1
|
| 233 |
+
IMG_000300.png,1,0.66993195,1,1
|
| 234 |
+
IMG_000194.png,0,0.16441628,0,1
|
| 235 |
+
IMG_000728.png,0,0.005496844,0,1
|
| 236 |
+
IMG_000818.png,1,0.92287654,1,1
|
| 237 |
+
IMG_001712.png,0,0.29255372,0,1
|
| 238 |
+
IMG_000719.png,0,0.075776406,0,1
|
model_training/notebooks/results/training_history.csv
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
seed,stage,epoch,train_loss,train_auc,train_acc,train_composite,val_loss,val_auc,val_acc,val_sens,val_spec,val_composite,lr
|
| 2 |
+
8,frozen,1,0.9582490903990609,0.7691765873015873,0.6177536231884058,0.6771676587301587,0.7091568447649479,0.85099052540913,0.6708860759493671,0.37037037037037035,0.9224806201550387,0.7487080103359173,0.001
|
| 3 |
+
8,frozen,2,0.6246700976576124,0.783531746031746,0.6838768115942029,0.726031746031746,0.48649174720048904,0.8657766293425209,0.7552742616033755,0.6574074074074074,0.8372093023255814,0.8065424921045077,0.001
|
| 4 |
+
8,frozen,3,0.5868643377508436,0.7798181216931217,0.6920289855072463,0.7326074735449736,0.45043402910232544,0.8748923341946597,0.7721518987341772,0.7037037037037037,0.8294573643410853,0.8207364341085271,0.001
|
| 5 |
+
8,frozen,4,0.5721473864146641,0.7899371693121694,0.7038043478260869,0.7456233465608466,0.43230215832591057,0.878409417169107,0.7805907172995781,0.7314814814814815,0.8217054263565892,0.8275014355440712,0.001
|
| 6 |
+
8,frozen,5,0.54250961627279,0.8040046296296297,0.7246376811594203,0.7627959656084657,0.4336203671991825,0.8802038472581108,0.7763713080168776,0.7129629629629629,0.8294573643410853,0.8257070054550675,0.001
|
| 7 |
+
8,frozen,6,0.534449862582343,0.8101421957671958,0.7336956521739131,0.7705076058201058,0.40960250422358513,0.8915446454206145,0.7890295358649789,0.7407407407407407,0.8294573643410853,0.8383218489807638,0.001
|
| 8 |
+
8,frozen,7,0.530328972850527,0.814781746031746,0.7454710144927537,0.7780853174603175,0.40932638198137283,0.8911139821992536,0.7932489451476793,0.7407407407407407,0.8372093023255814,0.8400445018662074,0.001
|
| 9 |
+
8,frozen,8,0.5051169097423553,0.8304497354497355,0.7373188405797102,0.7829629629629631,0.4084223732352257,0.8912575366063739,0.7974683544303798,0.75,0.8372093023255814,0.8424310938845824,0.001
|
| 10 |
+
8,frozen,9,0.5022006383963994,0.8348148148148148,0.7672101449275363,0.7998478835978836,0.4024202488362789,0.895205282802182,0.7890295358649789,0.7314814814814815,0.8372093023255814,0.8397753373528568,0.001
|
| 11 |
+
8,frozen,10,0.49878278289522443,0.8352347883597884,0.7518115942028986,0.7921808862433863,0.39458586275577545,0.8993683606086705,0.7890295358649789,0.7314814814814815,0.8372093023255814,0.8418568762561011,0.001
|
| 12 |
+
8,finetune,1,0.4544305384159088,0.8672321428571428,0.7835144927536232,0.8254613095238095,0.38248663768172264,0.919250645994832,0.8270042194092827,0.75,0.8914728682170543,0.8699935400516796,0.0009891830623632338
|
| 13 |
+
8,finetune,2,0.38789899774960107,0.9041567460317461,0.8342391304347826,0.8695585317460318,0.2955472208559513,0.9496841803043353,0.8860759493670886,0.9444444444444444,0.8372093023255814,0.9202555268446742,0.0009572050015330873
|
| 14 |
+
8,finetune,3,0.3507444177355085,0.9221329365079365,0.8469202898550725,0.8848561507936508,0.3256516754627228,0.937123169681309,0.8818565400843882,0.8703703703703703,0.8914728682170543,0.9090223944875107,0.0009054634122155989
|
| 15 |
+
8,finetune,4,0.33837775332587106,0.9301190476190476,0.8641304347826086,0.8962896825396826,0.33792560175061226,0.9344674131495836,0.8607594936708861,0.8425925925925926,0.875968992248062,0.8968741027849554,0.0008362196501476348
|
| 16 |
+
8,finetune,5,0.30863836577960424,0.9406580687830688,0.8586956521739131,0.90088458994709,0.35501030273735523,0.9326729830605799,0.8565400843881856,0.7962962962962963,0.9069767441860465,0.8921547516508757,0.0007524999999999999
|
| 17 |
+
8,finetune,6,0.32643192623342787,0.9347982804232804,0.8523550724637681,0.892974537037037,0.31595958210527897,0.9378409417169108,0.8860759493670886,0.8796296296296297,0.8914728682170543,0.9116960953201263,0.0006579634122155989
|
| 18 |
+
7,frozen,1,0.9370916630540576,0.7815476190476189,0.6204710144927537,0.6849999999999999,0.6874790899455547,0.8518518518518519,0.6708860759493671,0.3888888888888889,0.9069767441860465,0.7498923341946597,0.001
|
| 19 |
+
7,frozen,2,0.6154342787606376,0.801253306878307,0.6983695652173914,0.7406068121693122,0.5005612336099148,0.863982199253517,0.7510548523206751,0.6481481481481481,0.8372093023255814,0.803330462245191,0.001
|
| 20 |
+
7,frozen,3,0.5728206719670977,0.7850958994708995,0.7028985507246377,0.7411193783068782,0.4355204217135906,0.8814958369221936,0.7848101265822784,0.7407407407407407,0.8217054263565892,0.8313594602354293,0.001
|
| 21 |
+
7,frozen,4,0.5409952683108193,0.803399470899471,0.7119565217391305,0.7561838624338624,0.4371259845793247,0.8789836347975883,0.7805907172995781,0.7222222222222222,0.8294573643410853,0.8274117140396211,0.001
|
| 22 |
+
7,frozen,5,0.5573380930083138,0.7994576719576719,0.7346014492753623,0.7652645502645502,0.41805607080459595,0.8866637955785241,0.7805907172995781,0.7407407407407407,0.813953488372093,0.8320054550674706,0.001
|
| 23 |
+
7,frozen,6,0.5511363804340362,0.8038194444444444,0.7355072463768116,0.7692113095238096,0.4179871119558811,0.886017800746483,0.7974683544303798,0.75,0.8372093023255814,0.8398112259546369,0.001
|
| 24 |
+
7,frozen,7,0.5211189142295293,0.8182903439153438,0.7336956521739131,0.7744229497354497,0.40560774877667427,0.8934108527131782,0.7974683544303798,0.75,0.8372093023255814,0.8435077519379844,0.001
|
| 25 |
+
7,frozen,8,0.5251553058624268,0.8196527777777778,0.7391304347826086,0.7777628968253969,0.40271764248609543,0.8944157335630204,0.8059071729957806,0.7685185185185185,0.8372093023255814,0.8486398219925352,0.001
|
| 26 |
+
7,frozen,9,0.5052971592971257,0.8322817460317459,0.7445652173913043,0.7871329365079364,0.3970758095383644,0.8983634797588287,0.8016877637130801,0.7592592592592593,0.8372093023255814,0.8482988802756245,0.001
|
| 27 |
+
7,frozen,10,0.5330768167972565,0.8135284391534392,0.7373188405797102,0.7737086640211641,0.3852919824421406,0.904392764857881,0.810126582278481,0.7870370370370371,0.8294573643410853,0.8563199827734711,0.001
|
| 28 |
+
7,finetune,1,0.47272565024239677,0.8557539682539683,0.7817028985507246,0.8182539682539682,0.44458993524312973,0.8919753086419753,0.7763713080168776,0.6481481481481481,0.8837209302325582,0.8289549239161642,0.0009891830623632338
|
| 29 |
+
7,finetune,2,0.38450789025851656,0.9061078042328042,0.8170289855072463,0.8615062830687831,0.3831586390733719,0.9171691070915876,0.8354430379746836,0.7685185185185185,0.8914728682170543,0.873582400229687,0.0009572050015330873
|
| 30 |
+
7,finetune,3,0.35653283638613564,0.922037037037037,0.8505434782608695,0.8862367724867725,0.31522940658032894,0.9425782371518806,0.8818565400843882,0.9074074074074074,0.8604651162790697,0.9132572494975595,0.0009054634122155989
|
| 31 |
+
7,finetune,4,0.33924023721899305,0.9283333333333333,0.8505434782608695,0.8892261904761904,0.3659691587090492,0.9249210450760839,0.8312236286919831,0.7592592592592593,0.8914728682170543,0.8751435544071202,0.0008362196501476348
|
| 32 |
+
7,finetune,5,0.33330591917037966,0.9319179894179894,0.8623188405797102,0.8970701058201058,0.32207747735083103,0.939420040195234,0.8649789029535865,0.8981481481481481,0.8372093023255814,0.9035493827160493,0.0007524999999999999
|
| 33 |
+
7,finetune,6,0.31548327335289544,0.9377645502645502,0.8596014492753623,0.8990608465608465,0.35484062135219574,0.9318116566178581,0.8438818565400844,0.7962962962962963,0.8837209302325582,0.8859101349411427,0.0006579634122155989
|
| 34 |
+
7,finetune,7,0.287742834006037,0.9492956349206351,0.8731884057971014,0.9116319444444445,0.3438141830265522,0.930304335343095,0.8481012658227848,0.7962962962962963,0.8914728682170543,0.887094458799885,0.0005567415893174885
|
| 35 |
+
2024,frozen,1,0.9445781009537834,0.7710317460317461,0.6304347826086957,0.6849603174603175,0.7147812955081463,0.8519236290554119,0.6666666666666666,0.37037037037037035,0.9147286821705426,0.7472365776629342,0.001
|
| 36 |
+
2024,frozen,2,0.6475456161158425,0.7795568783068784,0.6902173913043478,0.7264847883597884,0.4829690493643284,0.8683606086706861,0.7510548523206751,0.6481481481481481,0.8372093023255814,0.8055196669537755,0.001
|
| 37 |
+
2024,frozen,3,0.557839333159583,0.7974801587301588,0.7119565217391305,0.7520337301587301,0.45181746035814285,0.8722365776629342,0.7721518987341772,0.7037037037037037,0.8294573643410853,0.8194085558426644,0.001
|
| 38 |
+
2024,frozen,4,0.553470607314791,0.7997123015873016,0.717391304347826,0.7562053571428571,0.43246644735336304,0.8798449612403101,0.7805907172995781,0.7314814814814815,0.8217054263565892,0.8282192075796727,0.001
|
| 39 |
+
2024,frozen,5,0.5317457369395665,0.8138822751322751,0.7318840579710145,0.772417328042328,0.41545065492391586,0.8885300028710881,0.7890295358649789,0.75,0.8217054263565892,0.8371913580246914,0.001
|
| 40 |
+
2024,frozen,6,0.5310433038643428,0.8142757936507937,0.7273550724637681,0.7686259920634921,0.41416966915130615,0.8903962101636521,0.7848101265822784,0.7222222222222222,0.8372093023255814,0.835055986218777,0.001
|
| 41 |
+
2024,frozen,7,0.535319973741259,0.810681216931217,0.7336956521739131,0.7706183862433862,0.41076211631298065,0.88946310651737,0.810126582278481,0.7870370370370371,0.8294573643410853,0.8488551536032156,0.001
|
| 42 |
+
2024,frozen,8,0.5011444449424743,0.8314996693121693,0.7472826086956522,0.7886268187830687,0.4013783372938633,0.8952770600057421,0.7932489451476793,0.7407407407407407,0.8372093023255814,0.8421260407694516,0.001
|
| 43 |
+
2024,frozen,9,0.5317847013473511,0.8112169312169313,0.7336956521739131,0.7711243386243387,0.38976578786969185,0.9020241171403963,0.8227848101265823,0.8148148148148148,0.8294573643410853,0.8620801033591732,0.001
|
| 44 |
+
2024,frozen,10,0.5140501362936837,0.8266732804232804,0.7336956521739131,0.7782175925925926,0.3864467516541481,0.9039621016365202,0.8143459915611815,0.7962962962962963,0.8294573643410853,0.8584194659776055,0.001
|
| 45 |
+
2024,finetune,1,0.4775453839983259,0.853396164021164,0.7608695652173914,0.8070949074074074,0.3899293653666973,0.9338931955211025,0.8354430379746836,0.9444444444444444,0.7441860465116279,0.8891042204995693,0.0009891830623632338
|
| 46 |
+
2024,finetune,2,0.38173683881759646,0.9084457671957672,0.8251811594202898,0.8663458994708995,0.3108560163527727,0.9433677863910422,0.890295358649789,0.8888888888888888,0.8914728682170543,0.9167743324720068,0.0009572050015330873
|
| 47 |
+
2024,finetune,3,0.36059877617018565,0.918994708994709,0.8460144927536232,0.8829497354497355,0.30220645666122437,0.94717197817973,0.890295358649789,0.8981481481481481,0.8837209302325582,0.9190532586850415,0.0009054634122155989
|
| 48 |
+
2024,finetune,4,0.3435957955462592,0.9252678571428571,0.8505434782608695,0.8876934523809523,0.32907535694539547,0.9417886879127189,0.8734177215189873,0.9351851851851852,0.8217054263565892,0.910116996841803,0.0008362196501476348
|
| 49 |
+
2024,finetune,5,0.3113635071686336,0.9391468253968255,0.8731884057971014,0.9062400793650794,0.35001079365611076,0.9295147861039335,0.8523206751054853,0.8240740740740741,0.875968992248062,0.8897681596325007,0.0007524999999999999
|
| 50 |
+
2024,finetune,6,0.3115046965224402,0.9414550264550263,0.8605072463768116,0.9017989417989417,0.3371146433055401,0.9364771748492678,0.8312236286919831,0.7870370370370371,0.8682170542635659,0.8820521102497847,0.0006579634122155989
|
| 51 |
+
2024,finetune,7,0.267705562710762,0.9570701058201058,0.8786231884057971,0.9179398148148148,0.3176867663860321,0.9399942578237152,0.8649789029535865,0.8425925925925926,0.8837209302325582,0.9015755096181453,0.0005567415893174885
|
model_training/notebooks/results/val_metrics_final.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"threshold": 0.504486,
|
| 3 |
+
"auc": 0.9499712891185758,
|
| 4 |
+
"sensitivity": 0.9166666666666666,
|
| 5 |
+
"specificity": 0.8837209302325582,
|
| 6 |
+
"accuracy": 0.8987341772151899,
|
| 7 |
+
"youden": 0.8003875968992249,
|
| 8 |
+
"composite": 0.9250825437840942,
|
| 9 |
+
"precision": 0.868421052631579,
|
| 10 |
+
"tp": 99,
|
| 11 |
+
"tn": 114,
|
| 12 |
+
"fp": 15,
|
| 13 |
+
"fn": 9
|
| 14 |
+
}
|
model_training/notebooks/results/val_predictions.csv
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
filename,y_true,prob,pred,correct
|
| 2 |
+
IMG_000833.png,0,0.08514188,0,1
|
| 3 |
+
IMG_000868.png,0,0.07656998,0,1
|
| 4 |
+
IMG_000471.png,1,0.926362,1,1
|
| 5 |
+
IMG_001514.png,0,0.042604674,0,1
|
| 6 |
+
IMG_001123.png,0,0.010878541,0,1
|
| 7 |
+
IMG_000311.png,0,0.0019978306,0,1
|
| 8 |
+
IMG_000768.png,0,0.1100046,0,1
|
| 9 |
+
IMG_001354.png,0,0.38976616,0,1
|
| 10 |
+
IMG_000677.png,1,0.9424959,1,1
|
| 11 |
+
IMG_001694.png,1,0.82901657,1,1
|
| 12 |
+
IMG_001249.png,1,0.9801762,1,1
|
| 13 |
+
IMG_001543.png,1,0.5918354,1,1
|
| 14 |
+
IMG_001138.png,1,0.9144192,1,1
|
| 15 |
+
IMG_001279.png,0,0.0023716255,0,1
|
| 16 |
+
IMG_001425.png,1,0.9752167,1,1
|
| 17 |
+
IMG_001664.png,0,0.0030602103,0,1
|
| 18 |
+
IMG_000517.png,0,0.26428953,0,1
|
| 19 |
+
IMG_000238.png,0,0.07020499,0,1
|
| 20 |
+
IMG_001419.png,1,0.9230154,1,1
|
| 21 |
+
IMG_000357.png,0,0.023781383,0,1
|
| 22 |
+
IMG_001413.png,0,0.53677505,1,0
|
| 23 |
+
IMG_001438.png,0,0.003928518,0,1
|
| 24 |
+
IMG_000730.png,1,0.9570253,1,1
|
| 25 |
+
IMG_000043.png,1,0.9750734,1,1
|
| 26 |
+
IMG_000034.png,0,0.006953893,0,1
|
| 27 |
+
IMG_000657.png,0,0.07155775,0,1
|
| 28 |
+
IMG_001544.png,0,0.90743506,1,0
|
| 29 |
+
IMG_000216.png,0,0.029633842,0,1
|
| 30 |
+
IMG_001193.png,0,0.015726507,0,1
|
| 31 |
+
IMG_001046.png,0,0.039458692,0,1
|
| 32 |
+
IMG_001555.png,0,0.3316432,0,1
|
| 33 |
+
IMG_001622.png,1,0.9557772,1,1
|
| 34 |
+
IMG_000170.png,1,0.8454006,1,1
|
| 35 |
+
IMG_000584.png,1,0.7091627,1,1
|
| 36 |
+
IMG_000750.png,0,0.6704473,1,0
|
| 37 |
+
IMG_000168.png,1,0.90156275,1,1
|
| 38 |
+
IMG_000230.png,0,0.66135454,1,0
|
| 39 |
+
IMG_000534.png,1,0.9762868,1,1
|
| 40 |
+
IMG_001523.png,0,0.5044856,0,1
|
| 41 |
+
IMG_000926.png,1,0.96041846,1,1
|
| 42 |
+
IMG_000776.png,0,0.22449876,0,1
|
| 43 |
+
IMG_000842.png,1,0.9752196,1,1
|
| 44 |
+
IMG_001338.png,1,0.29997367,0,0
|
| 45 |
+
IMG_000126.png,1,0.8138403,1,1
|
| 46 |
+
IMG_000377.png,1,0.8908096,1,1
|
| 47 |
+
IMG_001713.png,1,0.92572457,1,1
|
| 48 |
+
IMG_000996.png,0,0.3828967,0,1
|
| 49 |
+
IMG_000278.png,0,0.009509609,0,1
|
| 50 |
+
IMG_001392.png,1,0.96734214,1,1
|
| 51 |
+
IMG_000666.png,0,0.005913499,0,1
|
| 52 |
+
IMG_001012.png,0,0.0072901817,0,1
|
| 53 |
+
IMG_000515.png,0,0.18308273,0,1
|
| 54 |
+
IMG_000177.png,1,0.0014123927,0,0
|
| 55 |
+
IMG_000099.png,1,0.71664655,1,1
|
| 56 |
+
IMG_001485.png,0,0.07025461,0,1
|
| 57 |
+
IMG_000599.png,1,0.7633101,1,1
|
| 58 |
+
IMG_001377.png,0,0.00014443095,0,1
|
| 59 |
+
IMG_001490.png,1,0.9397476,1,1
|
| 60 |
+
IMG_000612.png,0,0.04453209,0,1
|
| 61 |
+
IMG_001293.png,1,0.92839843,1,1
|
| 62 |
+
IMG_001679.png,1,0.75982463,1,1
|
| 63 |
+
IMG_000783.png,0,0.102033496,0,1
|
| 64 |
+
IMG_001396.png,0,0.06896914,0,1
|
| 65 |
+
IMG_001047.png,0,0.0027958185,0,1
|
| 66 |
+
IMG_000090.png,0,0.3977522,0,1
|
| 67 |
+
IMG_000601.png,1,0.37769222,0,0
|
| 68 |
+
IMG_001475.png,1,0.89420277,1,1
|
| 69 |
+
IMG_000494.png,1,0.63474864,1,1
|
| 70 |
+
IMG_001216.png,0,0.025922079,0,1
|
| 71 |
+
IMG_001564.png,1,0.97837263,1,1
|
| 72 |
+
IMG_000928.png,0,0.057339218,0,1
|
| 73 |
+
IMG_001161.png,0,0.016286755,0,1
|
| 74 |
+
IMG_000024.png,0,0.8985922,1,0
|
| 75 |
+
IMG_000319.png,0,0.648694,1,0
|
| 76 |
+
IMG_001416.png,0,0.04064005,0,1
|
| 77 |
+
IMG_000478.png,1,0.9229935,1,1
|
| 78 |
+
IMG_001122.png,0,0.049890943,0,1
|
| 79 |
+
IMG_001740.png,1,0.93487585,1,1
|
| 80 |
+
IMG_001500.png,0,0.09520903,0,1
|
| 81 |
+
IMG_000140.png,1,0.98620355,1,1
|
| 82 |
+
IMG_001289.png,1,0.9263547,1,1
|
| 83 |
+
IMG_000937.png,1,0.60724884,1,1
|
| 84 |
+
IMG_000184.png,1,0.56130826,1,1
|
| 85 |
+
IMG_001060.png,0,0.0008340298,0,1
|
| 86 |
+
IMG_001379.png,1,0.72570825,1,1
|
| 87 |
+
IMG_000051.png,1,0.9593886,1,1
|
| 88 |
+
IMG_000684.png,0,0.031647827,0,1
|
| 89 |
+
IMG_000173.png,1,0.54984164,1,1
|
| 90 |
+
IMG_001529.png,0,0.37173158,0,1
|
| 91 |
+
IMG_001540.png,0,0.0022377113,0,1
|
| 92 |
+
IMG_000469.png,0,0.011063429,0,1
|
| 93 |
+
IMG_000452.png,0,0.5156396,1,0
|
| 94 |
+
IMG_000732.png,0,0.40668148,0,1
|
| 95 |
+
IMG_000546.png,0,0.8375038,1,0
|
| 96 |
+
IMG_000506.png,1,0.9919884,1,1
|
| 97 |
+
IMG_001633.png,0,0.33360916,0,1
|
| 98 |
+
IMG_000082.png,0,0.371273,0,1
|
| 99 |
+
IMG_000206.png,0,0.16295624,0,1
|
| 100 |
+
IMG_000975.png,0,0.00403203,0,1
|
| 101 |
+
IMG_000426.png,1,0.98051214,1,1
|
| 102 |
+
IMG_000441.png,0,0.07469518,0,1
|
| 103 |
+
IMG_000117.png,0,0.031310383,0,1
|
| 104 |
+
IMG_000973.png,0,0.00051067735,0,1
|
| 105 |
+
IMG_000881.png,1,0.8859954,1,1
|
| 106 |
+
IMG_000247.png,1,0.94318277,1,1
|
| 107 |
+
IMG_000574.png,0,0.014506694,0,1
|
| 108 |
+
IMG_000381.png,1,0.9848424,1,1
|
| 109 |
+
IMG_001618.png,0,0.012155103,0,1
|
| 110 |
+
IMG_001053.png,1,0.24243784,0,0
|
| 111 |
+
IMG_001400.png,0,0.25488326,0,1
|
| 112 |
+
IMG_000590.png,1,0.95325905,1,1
|
| 113 |
+
IMG_000342.png,1,0.8786827,1,1
|
| 114 |
+
IMG_001213.png,0,0.0063149747,0,1
|
| 115 |
+
IMG_000075.png,0,0.18751769,0,1
|
| 116 |
+
IMG_000608.png,0,0.09344582,0,1
|
| 117 |
+
IMG_000493.png,1,0.74505854,1,1
|
| 118 |
+
IMG_000630.png,0,0.37733248,0,1
|
| 119 |
+
IMG_001676.png,1,0.50721115,1,1
|
| 120 |
+
IMG_001102.png,1,0.90345204,1,1
|
| 121 |
+
IMG_000095.png,0,0.0039746515,0,1
|
| 122 |
+
IMG_001303.png,0,0.02584564,0,1
|
| 123 |
+
IMG_001266.png,0,0.0027153727,0,1
|
| 124 |
+
IMG_000755.png,1,0.5398157,1,1
|
| 125 |
+
IMG_001295.png,1,0.89617956,1,1
|
| 126 |
+
IMG_000832.png,1,0.985641,1,1
|
| 127 |
+
IMG_001705.png,0,0.017828463,0,1
|
| 128 |
+
IMG_001306.png,0,0.0074857906,0,1
|
| 129 |
+
IMG_000246.png,0,0.33929592,0,1
|
| 130 |
+
IMG_001530.png,1,0.72100955,1,1
|
| 131 |
+
IMG_000970.png,0,0.32502642,0,1
|
| 132 |
+
IMG_000059.png,1,0.9853226,1,1
|
| 133 |
+
IMG_001149.png,1,0.92740005,1,1
|
| 134 |
+
IMG_001200.png,1,0.95928955,1,1
|
| 135 |
+
IMG_001055.png,1,0.76670074,1,1
|
| 136 |
+
IMG_000995.png,0,0.8288552,1,0
|
| 137 |
+
IMG_000038.png,0,0.20308095,0,1
|
| 138 |
+
IMG_000524.png,0,0.005141989,0,1
|
| 139 |
+
IMG_001568.png,0,0.14497586,0,1
|
| 140 |
+
IMG_000122.png,0,0.051262774,0,1
|
| 141 |
+
IMG_000333.png,1,0.8129171,1,1
|
| 142 |
+
IMG_001109.png,0,0.18362695,0,1
|
| 143 |
+
IMG_000698.png,0,0.01703381,0,1
|
| 144 |
+
IMG_001573.png,1,0.9943408,1,1
|
| 145 |
+
IMG_001625.png,1,0.91411126,1,1
|
| 146 |
+
IMG_000509.png,0,0.0028896236,0,1
|
| 147 |
+
IMG_001448.png,1,0.9428977,1,1
|
| 148 |
+
IMG_000331.png,1,0.9803683,1,1
|
| 149 |
+
IMG_000008.png,0,0.038054988,0,1
|
| 150 |
+
IMG_000521.png,0,0.0010172316,0,1
|
| 151 |
+
IMG_000012.png,1,0.9812463,1,1
|
| 152 |
+
IMG_000780.png,1,0.6871019,1,1
|
| 153 |
+
IMG_000428.png,1,0.49723867,0,0
|
| 154 |
+
IMG_000463.png,0,0.016233142,0,1
|
| 155 |
+
IMG_001423.png,1,0.8972441,1,1
|
| 156 |
+
IMG_000287.png,1,0.61366016,1,1
|
| 157 |
+
IMG_001173.png,1,0.17359304,0,0
|
| 158 |
+
IMG_001357.png,1,0.6145142,1,1
|
| 159 |
+
IMG_001329.png,1,0.9680075,1,1
|
| 160 |
+
IMG_001119.png,0,0.013847093,0,1
|
| 161 |
+
IMG_001675.png,0,0.01776683,0,1
|
| 162 |
+
IMG_000877.png,0,0.9286344,1,0
|
| 163 |
+
IMG_000402.png,0,0.25376308,0,1
|
| 164 |
+
IMG_001176.png,1,0.9541638,1,1
|
| 165 |
+
IMG_000068.png,0,0.0027813588,0,1
|
| 166 |
+
IMG_000091.png,0,0.06692752,0,1
|
| 167 |
+
IMG_000600.png,0,0.68098706,1,0
|
| 168 |
+
IMG_000765.png,1,0.9324583,1,1
|
| 169 |
+
IMG_000965.png,0,0.0155419195,0,1
|
| 170 |
+
IMG_000361.png,1,0.4985782,0,0
|
| 171 |
+
IMG_000414.png,1,0.95493746,1,1
|
| 172 |
+
IMG_000969.png,0,0.06441665,0,1
|
| 173 |
+
IMG_001087.png,0,0.0050179944,0,1
|
| 174 |
+
IMG_000124.png,1,0.6126341,1,1
|
| 175 |
+
IMG_000013.png,0,0.14956698,0,1
|
| 176 |
+
IMG_001140.png,1,0.98292345,1,1
|
| 177 |
+
IMG_001088.png,1,0.9412705,1,1
|
| 178 |
+
IMG_001655.png,0,0.10415509,0,1
|
| 179 |
+
IMG_000720.png,1,0.5776679,1,1
|
| 180 |
+
IMG_001687.png,0,0.054884706,0,1
|
| 181 |
+
IMG_000654.png,1,0.80313796,1,1
|
| 182 |
+
IMG_000200.png,0,0.016887484,0,1
|
| 183 |
+
IMG_001131.png,0,0.59535867,1,0
|
| 184 |
+
IMG_000236.png,1,0.9608375,1,1
|
| 185 |
+
IMG_001004.png,1,0.91360044,1,1
|
| 186 |
+
IMG_000869.png,1,0.9747172,1,1
|
| 187 |
+
IMG_001630.png,0,0.0015267566,0,1
|
| 188 |
+
IMG_000374.png,1,0.32337222,0,0
|
| 189 |
+
IMG_001203.png,0,0.00011447017,0,1
|
| 190 |
+
IMG_000853.png,0,0.06780861,0,1
|
| 191 |
+
IMG_000638.png,0,0.02331572,0,1
|
| 192 |
+
IMG_000620.png,1,0.9309563,1,1
|
| 193 |
+
IMG_000167.png,1,0.9887361,1,1
|
| 194 |
+
IMG_000811.png,0,0.26473084,0,1
|
| 195 |
+
IMG_001061.png,1,0.905426,1,1
|
| 196 |
+
IMG_000456.png,0,0.00062740117,0,1
|
| 197 |
+
IMG_000438.png,1,0.7869287,1,1
|
| 198 |
+
IMG_000292.png,1,0.9418937,1,1
|
| 199 |
+
IMG_000644.png,0,0.5747625,1,0
|
| 200 |
+
IMG_000491.png,0,0.014190746,0,1
|
| 201 |
+
IMG_000564.png,0,0.0036469798,0,1
|
| 202 |
+
IMG_000738.png,0,0.8187067,1,0
|
| 203 |
+
IMG_000886.png,1,0.7505375,1,1
|
| 204 |
+
IMG_000387.png,0,0.24593168,0,1
|
| 205 |
+
IMG_000492.png,0,0.33065897,0,1
|
| 206 |
+
IMG_001057.png,1,0.9879611,1,1
|
| 207 |
+
IMG_001195.png,0,0.0053326027,0,1
|
| 208 |
+
IMG_000088.png,1,0.91581327,1,1
|
| 209 |
+
IMG_000014.png,0,0.08517947,0,1
|
| 210 |
+
IMG_000830.png,0,0.02413497,0,1
|
| 211 |
+
IMG_000779.png,0,0.021322682,0,1
|
| 212 |
+
IMG_001510.png,0,0.06587305,0,1
|
| 213 |
+
IMG_000556.png,0,0.019615667,0,1
|
| 214 |
+
IMG_000744.png,1,0.8163302,1,1
|
| 215 |
+
IMG_001059.png,0,0.012555552,0,1
|
| 216 |
+
IMG_001013.png,1,0.12594076,0,0
|
| 217 |
+
IMG_000901.png,1,0.9905481,1,1
|
| 218 |
+
IMG_000983.png,0,0.13618241,0,1
|
| 219 |
+
IMG_001359.png,1,0.90301836,1,1
|
| 220 |
+
IMG_001604.png,0,0.0092756525,0,1
|
| 221 |
+
IMG_001050.png,0,0.41582817,0,1
|
| 222 |
+
IMG_000662.png,0,0.0009655981,0,1
|
| 223 |
+
IMG_000225.png,0,0.059564617,0,1
|
| 224 |
+
IMG_001261.png,0,0.00012916393,0,1
|
| 225 |
+
IMG_001401.png,1,0.8134201,1,1
|
| 226 |
+
IMG_001136.png,0,0.03423228,0,1
|
| 227 |
+
IMG_000031.png,1,0.9828577,1,1
|
| 228 |
+
IMG_001547.png,0,0.028171018,0,1
|
| 229 |
+
IMG_000023.png,1,0.96932787,1,1
|
| 230 |
+
IMG_000261.png,1,0.89618427,1,1
|
| 231 |
+
IMG_001125.png,1,0.9938845,1,1
|
| 232 |
+
IMG_000210.png,1,0.9844753,1,1
|
| 233 |
+
IMG_000305.png,0,0.96544015,1,0
|
| 234 |
+
IMG_001426.png,0,0.23876508,0,1
|
| 235 |
+
IMG_001204.png,1,0.9547577,1,1
|
| 236 |
+
IMG_000081.png,1,0.88376915,1,1
|
| 237 |
+
IMG_001683.png,1,0.7803761,1,1
|
| 238 |
+
IMG_000069.png,1,0.9368211,1,1
|
model_training/src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# src package
|
model_training/src/config.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from typing import List
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
|
| 11 |
+
class Config:
|
| 12 |
+
# ββ Data paths ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 13 |
+
csv_path: str = "../../data/train_val.csv"
|
| 14 |
+
image_dir: str = "../../data/images"
|
| 15 |
+
submission_test_dir: str = "../../data/test_images"
|
| 16 |
+
output_dir: str = "results"
|
| 17 |
+
results_log_path: str = "results_log.csv" # global run log (one row per training run)
|
| 18 |
+
|
| 19 |
+
# ββ Reproducibility ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
+
seed: int = 42
|
| 21 |
+
|
| 22 |
+
# ββ Image / DataLoader βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 23 |
+
img_size: int = 224 # torchxrayvision DenseNet-121 native resolution
|
| 24 |
+
batch_size: int = 32
|
| 25 |
+
num_workers: int = 4
|
| 26 |
+
|
| 27 |
+
# ββ Train / val / test split βββββββββββββββββββββββββββββββββββββββββ
|
| 28 |
+
val_size: float = 0.15
|
| 29 |
+
test_size: float = 0.15
|
| 30 |
+
|
| 31 |
+
# ββ Training schedule (two-stage) ββββββββββββββββββββββββββββββββββββ
|
| 32 |
+
frozen_epochs: int = 3 # stage 1: head-only warmup
|
| 33 |
+
finetune_epochs: int = 22 # stage 2: full unfreeze with cosine LR
|
| 34 |
+
early_stop_patience: int = 6 # early stop on val AUC during stage 2
|
| 35 |
+
|
| 36 |
+
# ββ Optimiser ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 37 |
+
head_lr: float = 3e-4 # classifier LR (both stages)
|
| 38 |
+
backbone_lr: float = 1e-4 # features LR (stage 2 only)
|
| 39 |
+
weight_decay: float = 1e-4
|
| 40 |
+
grad_clip: float = 1.0
|
| 41 |
+
|
| 42 |
+
# ββ Data augmentation ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 43 |
+
# Mixup: interpolates two samples and their labels in every training batch.
|
| 44 |
+
# mixup_alpha > 0 enables it; Ξ» ~ Beta(Ξ±, Ξ±). 0 = disabled.
|
| 45 |
+
# Typical range: 0.2 β 0.4.
|
| 46 |
+
mixup_alpha: float = 0.0
|
| 47 |
+
# Label smoothing: prevents overconfidence by softening hard {0,1} targets.
|
| 48 |
+
# y_smooth = y*(1-Ξ΅) + 0.5*Ξ΅. 0 = disabled. Typical range: 0.05 β 0.15.
|
| 49 |
+
label_smoothing: float = 0.0
|
| 50 |
+
|
| 51 |
+
# ββ Architecture βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 52 |
+
# Options: "densenet121" | "rad-dino" | "mobilenet_v3_large" | "efficientnet_b0" | "efficientnet_b3"
|
| 53 |
+
# densenet121 β torchxrayvision DenseNet-121, pretrained on ~1M chest X-rays (recommended)
|
| 54 |
+
# rad-dino β microsoft/rad-dino, DINOv2 ViT-B/14 pretrained on ~1M chest X-rays;
|
| 55 |
+
# use img_size=518 (native: 37Γ37 patches at 14 px); 12 frozen_blocks max
|
| 56 |
+
# mobilenet_v3_large β torchvision MobileNetV3-Large, pretrained on ImageNet (faster, lighter)
|
| 57 |
+
# efficientnet_b0 β torchvision EfficientNet-B0, pretrained on ImageNet (good accuracy/size trade-off)
|
| 58 |
+
# efficientnet_b3 β torchvision EfficientNet-B3, pretrained on ImageNet (higher accuracy, more params)
|
| 59 |
+
backbone: str = "efficientnet_b0"
|
| 60 |
+
|
| 61 |
+
# ββ Ensemble βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 62 |
+
# True: train one model per entry in `seeds` and average predictions
|
| 63 |
+
# False: train a single model using only `seed` (faster experimentation)
|
| 64 |
+
use_ensemble: bool = True
|
| 65 |
+
|
| 66 |
+
# ββ Multi-seed ensemble ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 67 |
+
seeds: List[int] = field(default_factory=lambda: [42, 7, 2024])
|
| 68 |
+
|
| 69 |
+
# ββ Loss function βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 70 |
+
# False: standard BCE | True: 0.5*BCE + 0.5*(1 - soft_composite)
|
| 71 |
+
use_composite_loss: bool = False
|
| 72 |
+
# Blend weight Ξ±: Ξ±Β·BCE + (1-Ξ±)Β·(1-soft_composite). 0 = pure composite, 1 = pure BCE.
|
| 73 |
+
composite_loss_alpha: float = 0.5
|
| 74 |
+
# Temperature for the pairwise-sigmoid soft-AUC term (higher β sharper ranking signal)
|
| 75 |
+
composite_loss_gamma: float = 1.0
|
| 76 |
+
|
| 77 |
+
# ββ Inference ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 78 |
+
tta_passes: int = 6 # number of deterministic TTA transforms (max 6)
|
| 79 |
+
n_bootstrap: int = 1000 # bootstrap iterations for threshold stabilisation
|
| 80 |
+
|
| 81 |
+
# ββ Device (auto-detected) βββββββββββββββββββββββββββββββββββββββββββ
|
| 82 |
+
device: str = (
|
| 83 |
+
"cuda" if torch.cuda.is_available()
|
| 84 |
+
else "mps" if torch.backends.mps.is_available()
|
| 85 |
+
else "cpu"
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
def setup(self) -> "Config":
|
| 89 |
+
"""Create output directory and return self (for chaining)."""
|
| 90 |
+
os.makedirs(self.output_dir, exist_ok=True)
|
| 91 |
+
return self
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# Global singleton β import and use directly, or override fields before training
|
| 95 |
+
CFG = Config().setup()
|
model_training/src/data.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from sklearn.model_selection import train_test_split
|
| 9 |
+
|
| 10 |
+
from src.config import CFG
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# ---------------------------------------------------------------------------
|
| 14 |
+
# Column auto-detection
|
| 15 |
+
# ---------------------------------------------------------------------------
|
| 16 |
+
FILENAME_CANDIDATES = [
|
| 17 |
+
"image_name", "filename", "file", "image", "image_id", "img", "name",
|
| 18 |
+
"image index", "image_index", # NIH ChestX-ray14
|
| 19 |
+
]
|
| 20 |
+
LABEL_CANDIDATES = [
|
| 21 |
+
"label", "cardiomegaly", "class", "target", "y",
|
| 22 |
+
"finding_labels", "finding labels", "finding", # NIH ChestX-ray14
|
| 23 |
+
"labels",
|
| 24 |
+
]
|
| 25 |
+
POSITIVE_KEYWORD = "cardiomegaly"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _autodetect(df: pd.DataFrame, candidates: list[str]) -> str:
|
| 29 |
+
"""Return the first column in *df* whose lowercase name is in *candidates*."""
|
| 30 |
+
lower = {c.lower(): c for c in df.columns}
|
| 31 |
+
for cand in candidates:
|
| 32 |
+
if cand in lower:
|
| 33 |
+
return lower[cand]
|
| 34 |
+
raise ValueError(f"None of {candidates} found in columns: {list(df.columns)}")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _coerce_to_binary(series: pd.Series) -> pd.Series:
|
| 38 |
+
"""Map mixed label encodings (0/1, 'cardiomegaly', 'no finding', bool, ...) to 0/1."""
|
| 39 |
+
def to_int(v):
|
| 40 |
+
if pd.isna(v):
|
| 41 |
+
return 0
|
| 42 |
+
if isinstance(v, (int, np.integer)):
|
| 43 |
+
return int(v != 0)
|
| 44 |
+
if isinstance(v, (float, np.floating)):
|
| 45 |
+
return int(v != 0)
|
| 46 |
+
if isinstance(v, bool):
|
| 47 |
+
return int(v)
|
| 48 |
+
s = str(v).strip().lower()
|
| 49 |
+
if s in {"1", "true", "yes", "y", "positive", "pos"}:
|
| 50 |
+
return 1
|
| 51 |
+
if s in {"0", "false", "no", "n", "negative", "neg", "no finding", ""}:
|
| 52 |
+
return 0
|
| 53 |
+
return int(POSITIVE_KEYWORD in s)
|
| 54 |
+
return series.apply(to_int).astype(int)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _resolve_filenames(df: pd.DataFrame, filename_col: str, image_dir: str) -> pd.DataFrame:
|
| 58 |
+
"""Add an `image_path` column. Drops rows whose file cannot be found.
|
| 59 |
+
|
| 60 |
+
Tolerates different case, trailing spaces, and missing/wrong extensions.
|
| 61 |
+
"""
|
| 62 |
+
disk: dict[str, str] = {}
|
| 63 |
+
for entry in os.scandir(image_dir):
|
| 64 |
+
if not entry.is_file():
|
| 65 |
+
continue
|
| 66 |
+
name = entry.name
|
| 67 |
+
disk[name.lower()] = name
|
| 68 |
+
stem = os.path.splitext(name)[0].lower()
|
| 69 |
+
disk.setdefault(stem, name)
|
| 70 |
+
|
| 71 |
+
resolved, missing = [], []
|
| 72 |
+
for fn in df[filename_col].astype(str):
|
| 73 |
+
raw = fn.strip()
|
| 74 |
+
raw_l = raw.lower()
|
| 75 |
+
hit = disk.get(raw_l) or disk.get(os.path.splitext(raw_l)[0])
|
| 76 |
+
if hit is None:
|
| 77 |
+
for ext in (".png", ".jpg", ".jpeg"):
|
| 78 |
+
if raw_l + ext in disk:
|
| 79 |
+
hit = disk[raw_l + ext]
|
| 80 |
+
break
|
| 81 |
+
if hit is None:
|
| 82 |
+
missing.append(raw)
|
| 83 |
+
resolved.append(None)
|
| 84 |
+
else:
|
| 85 |
+
resolved.append(os.path.join(image_dir, hit))
|
| 86 |
+
|
| 87 |
+
df = df.copy()
|
| 88 |
+
df["image_path"] = resolved
|
| 89 |
+
keep = df["image_path"].notna()
|
| 90 |
+
if (~keep).any():
|
| 91 |
+
print(f"Warning: {(~keep).sum()} rows dropped (file not found). Examples: {missing[:5]}")
|
| 92 |
+
return df[keep].reset_index(drop=True)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# ---------------------------------------------------------------------------
|
| 96 |
+
# Public API
|
| 97 |
+
# ---------------------------------------------------------------------------
|
| 98 |
+
def load_labels(csv_path: str, image_dir: str) -> pd.DataFrame:
|
| 99 |
+
"""Read CSV, auto-detect filename + label columns, coerce labels, resolve paths.
|
| 100 |
+
|
| 101 |
+
Returned DataFrame columns: filename, label, image_path
|
| 102 |
+
"""
|
| 103 |
+
df = pd.read_csv(csv_path)
|
| 104 |
+
fn_col = _autodetect(df, FILENAME_CANDIDATES)
|
| 105 |
+
lb_col = _autodetect(df, LABEL_CANDIDATES)
|
| 106 |
+
print(f"Detected filename column: {fn_col!r} label column: {lb_col!r}")
|
| 107 |
+
|
| 108 |
+
df = df[[fn_col, lb_col]].rename(columns={fn_col: "filename", lb_col: "label"})
|
| 109 |
+
df["label"] = _coerce_to_binary(df["label"])
|
| 110 |
+
df = _resolve_filenames(df, "filename", image_dir)
|
| 111 |
+
df = df.drop_duplicates(subset=["filename"]).reset_index(drop=True)
|
| 112 |
+
|
| 113 |
+
if len(df) == 0:
|
| 114 |
+
raise ValueError("No valid labelled images found.")
|
| 115 |
+
|
| 116 |
+
n_pos = int(df["label"].sum())
|
| 117 |
+
n_neg = int((df["label"] == 0).sum())
|
| 118 |
+
print(f"Loaded {len(df)} labelled images pos={n_pos} neg={n_neg}")
|
| 119 |
+
return df
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def split_dataframe(
|
| 123 |
+
df: pd.DataFrame,
|
| 124 |
+
val_size: float | None = None,
|
| 125 |
+
test_size: float | None = None,
|
| 126 |
+
seed: int | None = None,
|
| 127 |
+
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
| 128 |
+
"""Stratified train / val / test split.
|
| 129 |
+
|
| 130 |
+
Falls back to CFG values when parameters are not supplied.
|
| 131 |
+
"""
|
| 132 |
+
val_size = val_size if val_size is not None else CFG.val_size
|
| 133 |
+
test_size = test_size if test_size is not None else CFG.test_size
|
| 134 |
+
seed = seed if seed is not None else CFG.seed
|
| 135 |
+
|
| 136 |
+
train_tmp_df, test_df = train_test_split(
|
| 137 |
+
df, test_size=test_size, stratify=df["label"], random_state=seed,
|
| 138 |
+
)
|
| 139 |
+
rel_val = val_size / (1.0 - test_size)
|
| 140 |
+
train_df, val_df = train_test_split(
|
| 141 |
+
train_tmp_df, test_size=rel_val,
|
| 142 |
+
stratify=train_tmp_df["label"], random_state=seed,
|
| 143 |
+
)
|
| 144 |
+
return (
|
| 145 |
+
train_df.reset_index(drop=True),
|
| 146 |
+
val_df.reset_index(drop=True),
|
| 147 |
+
test_df.reset_index(drop=True),
|
| 148 |
+
)
|
model_training/src/dataset.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import torch
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from torch.utils.data import Dataset
|
| 11 |
+
import torchvision.transforms as T
|
| 12 |
+
|
| 13 |
+
import torchxrayvision as xrv
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# ---------------------------------------------------------------------------
|
| 17 |
+
# Normalisation functions (one per backbone family)
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
def xrv_normalize_np(pil_img: Image.Image) -> torch.Tensor:
|
| 20 |
+
"""PIL grayscale β (1, H, W) float tensor in [-1024, 1024] (torchxrayvision)."""
|
| 21 |
+
arr = np.array(pil_img, dtype=np.float32) # (H, W) in [0, 255]
|
| 22 |
+
arr = xrv.datasets.normalize(arr, 255) # β [-1024, 1024]
|
| 23 |
+
arr = arr[None, ...] # (1, H, W)
|
| 24 |
+
return torch.from_numpy(arr).float()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
|
| 28 |
+
_IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def imagenet_normalize_np(pil_img: Image.Image) -> torch.Tensor:
|
| 32 |
+
"""PIL grayscale β (3, H, W) float tensor normalized with ImageNet stats.
|
| 33 |
+
|
| 34 |
+
The single grayscale channel is replicated to 3 channels so that ImageNet-
|
| 35 |
+
pretrained backbones (MobileNet, EfficientNet) receive the expected input shape.
|
| 36 |
+
"""
|
| 37 |
+
arr = np.array(pil_img, dtype=np.float32) / 255.0 # [0, 1]
|
| 38 |
+
arr = np.stack([arr, arr, arr], axis=0) # (3, H, W)
|
| 39 |
+
arr = (arr - _IMAGENET_MEAN) / _IMAGENET_STD
|
| 40 |
+
return torch.from_numpy(arr).float()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_normalize_fn(backbone: str):
|
| 44 |
+
"""Return the correct normalization callable for the given backbone name.
|
| 45 |
+
|
| 46 |
+
"densenet121" / "densenet121-res224-all"
|
| 47 |
+
β xrv_normalize_np (grayscale, [-1024, 1024])
|
| 48 |
+
"rad-dino"
|
| 49 |
+
β imagenet_normalize_np (3-ch RGB replicated, ImageNet stats)
|
| 50 |
+
RAD-DINO is a ViT-B/14; feed at 518Γ518 for best accuracy.
|
| 51 |
+
all other torchvision backbones
|
| 52 |
+
β imagenet_normalize_np
|
| 53 |
+
"""
|
| 54 |
+
if backbone in ("densenet121", "densenet121-res224-all"):
|
| 55 |
+
return xrv_normalize_np
|
| 56 |
+
return imagenet_normalize_np
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ---------------------------------------------------------------------------
|
| 60 |
+
# Labelled dataset (train / val / test)
|
| 61 |
+
# ---------------------------------------------------------------------------
|
| 62 |
+
class ChestXrayDataset(Dataset):
|
| 63 |
+
"""Returns (image_tensor, label, filename) triples.
|
| 64 |
+
|
| 65 |
+
backbone controls the normalization applied after PIL transforms:
|
| 66 |
+
"densenet121" β single-channel tensor in [-1024, 1024] (xrv)
|
| 67 |
+
any torchvision model β 3-channel tensor with ImageNet normalization
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
def __init__(
|
| 71 |
+
self,
|
| 72 |
+
df: pd.DataFrame,
|
| 73 |
+
pil_transform=None,
|
| 74 |
+
use_erasing: bool = False,
|
| 75 |
+
backbone: str | None = None,
|
| 76 |
+
) -> None:
|
| 77 |
+
from src.config import CFG
|
| 78 |
+
self.df = df.reset_index(drop=True)
|
| 79 |
+
self.pil_transform = pil_transform
|
| 80 |
+
self.use_erasing = use_erasing
|
| 81 |
+
self._normalize = get_normalize_fn(backbone or CFG.backbone)
|
| 82 |
+
self._erasing = T.RandomErasing(
|
| 83 |
+
p=0.5, scale=(0.02, 0.08), ratio=(0.3, 3.3), value=0
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
def __len__(self) -> int:
|
| 87 |
+
return len(self.df)
|
| 88 |
+
|
| 89 |
+
def __getitem__(self, idx: int):
|
| 90 |
+
row = self.df.iloc[idx]
|
| 91 |
+
img = Image.open(row["image_path"]).convert("L")
|
| 92 |
+
if self.pil_transform is not None:
|
| 93 |
+
img = self.pil_transform(img)
|
| 94 |
+
normalize = getattr(self, "_normalize", xrv_normalize_np)
|
| 95 |
+
tensor = normalize(img)
|
| 96 |
+
if self.use_erasing:
|
| 97 |
+
tensor = self._erasing(tensor)
|
| 98 |
+
label = torch.tensor(float(row["label"]), dtype=torch.float32)
|
| 99 |
+
return tensor, label, row["filename"]
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ---------------------------------------------------------------------------
|
| 103 |
+
# TTA dataset β flexible source (DataFrame with image_path, or external dir)
|
| 104 |
+
# ---------------------------------------------------------------------------
|
| 105 |
+
class TTADataset(Dataset):
|
| 106 |
+
"""Used by inference passes (one TTA transform per pass)."""
|
| 107 |
+
|
| 108 |
+
def __init__(
|
| 109 |
+
self,
|
| 110 |
+
df: pd.DataFrame,
|
| 111 |
+
pil_transform,
|
| 112 |
+
image_dir: Optional[str] = None,
|
| 113 |
+
backbone: str | None = None,
|
| 114 |
+
) -> None:
|
| 115 |
+
from src.config import CFG
|
| 116 |
+
self.df = df.reset_index(drop=True)
|
| 117 |
+
self.pil_transform = pil_transform
|
| 118 |
+
self.image_dir = image_dir
|
| 119 |
+
self._normalize = get_normalize_fn(backbone or CFG.backbone)
|
| 120 |
+
|
| 121 |
+
def __len__(self) -> int:
|
| 122 |
+
return len(self.df)
|
| 123 |
+
|
| 124 |
+
def __getitem__(self, idx: int):
|
| 125 |
+
row = self.df.iloc[idx]
|
| 126 |
+
if "image_path" in row and pd.notna(row.get("image_path")):
|
| 127 |
+
path = row["image_path"]
|
| 128 |
+
else:
|
| 129 |
+
path = os.path.join(self.image_dir, row["filename"])
|
| 130 |
+
img = Image.open(path).convert("L")
|
| 131 |
+
img = self.pil_transform(img)
|
| 132 |
+
normalize = getattr(self, "_normalize", xrv_normalize_np)
|
| 133 |
+
tensor = normalize(img)
|
| 134 |
+
label = float(row["label"]) if "label" in row and not pd.isna(row.get("label", np.nan)) else 0.0
|
| 135 |
+
name = row["filename"] if "filename" in row else os.path.basename(path)
|
| 136 |
+
return tensor, torch.tensor(label, dtype=torch.float32), name
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# ---------------------------------------------------------------------------
|
| 140 |
+
# Submission dataset (unlabelled images in a flat directory)
|
| 141 |
+
# ---------------------------------------------------------------------------
|
| 142 |
+
class SubmissionDataset(Dataset):
|
| 143 |
+
"""Unlabelled test images for final inference.
|
| 144 |
+
|
| 145 |
+
Returns (image_tensor, filename).
|
| 146 |
+
"""
|
| 147 |
+
|
| 148 |
+
def __init__(
|
| 149 |
+
self,
|
| 150 |
+
image_dir: str,
|
| 151 |
+
pil_transform=None,
|
| 152 |
+
backbone: str | None = None,
|
| 153 |
+
) -> None:
|
| 154 |
+
from src.config import CFG
|
| 155 |
+
self.image_dir = image_dir
|
| 156 |
+
self.pil_transform = pil_transform
|
| 157 |
+
self._normalize = get_normalize_fn(backbone or CFG.backbone)
|
| 158 |
+
self.image_files = sorted(
|
| 159 |
+
f for f in os.listdir(image_dir)
|
| 160 |
+
if f.lower().endswith((".png", ".jpg", ".jpeg"))
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
def __len__(self) -> int:
|
| 164 |
+
return len(self.image_files)
|
| 165 |
+
|
| 166 |
+
def __getitem__(self, idx: int):
|
| 167 |
+
fname = self.image_files[idx]
|
| 168 |
+
img = Image.open(os.path.join(self.image_dir, fname)).convert("L")
|
| 169 |
+
if self.pil_transform is not None:
|
| 170 |
+
img = self.pil_transform(img)
|
| 171 |
+
normalize = getattr(self, "_normalize", xrv_normalize_np)
|
| 172 |
+
tensor = normalize(img)
|
| 173 |
+
return tensor, fname
|
model_training/src/model.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torchxrayvision as xrv
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# ---------------------------------------------------------------------------
|
| 11 |
+
# RAD-DINO wrapper
|
| 12 |
+
# ---------------------------------------------------------------------------
|
| 13 |
+
class RadDinoWrapper(nn.Module):
|
| 14 |
+
"""microsoft/rad-dino β DINOv2 ViT-B/14 pretrained on ~1 M chest X-rays.
|
| 15 |
+
|
| 16 |
+
Wraps the HuggingFace model to expose the same ``.features`` / ``.classifier``
|
| 17 |
+
contract used by every other backbone, so freeze helpers and the two-stage
|
| 18 |
+
optimiser work without modification.
|
| 19 |
+
|
| 20 |
+
Architecture
|
| 21 |
+
ββββββββββββ
|
| 22 |
+
.features β the full Dinov2Model (embeddings + 12 transformer blocks + layernorm)
|
| 23 |
+
.classifier β nn.Linear(hidden_size=768, out_features=1)
|
| 24 |
+
|
| 25 |
+
Forward pass
|
| 26 |
+
ββββββββββββ
|
| 27 |
+
x : (B, 3, H, W) float tensor β ImageNet-normalised, any multiple of 14 px.
|
| 28 |
+
Recommended resolution: 518 Γ 518 (native: 37 Γ 37 patches at 14 px).
|
| 29 |
+
Returns (B,) logit tensor.
|
| 30 |
+
|
| 31 |
+
Freeze / unfreeze
|
| 32 |
+
βββββββββββββββββ
|
| 33 |
+
freeze_backbone() β freezes .features (all 12 blocks + embeddings)
|
| 34 |
+
partial_unfreeze(N) β unfreeze last (12 β N) blocks + layernorm;
|
| 35 |
+
keep embeddings + first N blocks frozen.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self) -> None:
|
| 39 |
+
super().__init__()
|
| 40 |
+
from transformers import AutoModel # lazy β only loaded when this backbone is used
|
| 41 |
+
dinov2 = AutoModel.from_pretrained("microsoft/rad-dino")
|
| 42 |
+
self.features = dinov2
|
| 43 |
+
self.classifier = nn.Linear(dinov2.config.hidden_size, 1)
|
| 44 |
+
nn.init.trunc_normal_(self.classifier.weight, std=0.02)
|
| 45 |
+
nn.init.zeros_(self.classifier.bias)
|
| 46 |
+
|
| 47 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 48 |
+
out = self.features(pixel_values=x) # Dinov2ModelOutput
|
| 49 |
+
cls = out.last_hidden_state[:, 0] # CLS token (B, 768)
|
| 50 |
+
return self.classifier(cls) # (B, 1)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
# Backbone factory
|
| 55 |
+
# ---------------------------------------------------------------------------
|
| 56 |
+
def build_model(backbone: str | None = None) -> nn.Module:
|
| 57 |
+
"""Build a backbone model for Cardiomegaly classification.
|
| 58 |
+
|
| 59 |
+
backbone options (also set via CFG.backbone):
|
| 60 |
+
"densenet121" β torchxrayvision DenseNet-121, pretrained on ~1M chest
|
| 61 |
+
X-rays; outputs raw Cardiomegaly logit via pathology index.
|
| 62 |
+
"rad-dino" β microsoft/rad-dino, DINOv2 ViT-B/14 pretrained on ~1M
|
| 63 |
+
chest X-rays (HuggingFace); 518Γ518 recommended input.
|
| 64 |
+
"mobilenet_v3_large" β torchvision MobileNetV3-Large (ImageNet); final linear
|
| 65 |
+
replaced with a single-output head.
|
| 66 |
+
"efficientnet_b0" β torchvision EfficientNet-B0 (ImageNet); same replacement.
|
| 67 |
+
"efficientnet_b3" β torchvision EfficientNet-B3 (ImageNet); same replacement.
|
| 68 |
+
|
| 69 |
+
All returned models expose .features and .classifier so that freeze_backbone()
|
| 70 |
+
and the two-stage optimizer in train_one_seed() work unchanged.
|
| 71 |
+
Input tensor format differs by backbone β use dataset.get_normalize_fn(backbone).
|
| 72 |
+
"""
|
| 73 |
+
from src.config import CFG # lazy to avoid circular import at module load
|
| 74 |
+
backbone = backbone or CFG.backbone
|
| 75 |
+
|
| 76 |
+
if backbone in ("densenet121", "densenet121-res224-all"):
|
| 77 |
+
model = xrv.models.DenseNet(weights="densenet121-res224-all")
|
| 78 |
+
model.op_threshs = None # raw logits at every output
|
| 79 |
+
model.apply_sigmoid = False # belt + suspenders
|
| 80 |
+
return model
|
| 81 |
+
|
| 82 |
+
if backbone == "rad-dino":
|
| 83 |
+
return RadDinoWrapper()
|
| 84 |
+
|
| 85 |
+
import torchvision.models as tvm
|
| 86 |
+
|
| 87 |
+
if backbone == "mobilenet_v3_large":
|
| 88 |
+
model = tvm.mobilenet_v3_large(weights=tvm.MobileNet_V3_Large_Weights.IMAGENET1K_V2)
|
| 89 |
+
in_features = model.classifier[-1].in_features
|
| 90 |
+
model.classifier[-1] = nn.Linear(in_features, 1)
|
| 91 |
+
return model
|
| 92 |
+
|
| 93 |
+
if backbone in ("efficientnet_b0", "efficientnet_b3"):
|
| 94 |
+
if backbone == "efficientnet_b0":
|
| 95 |
+
model = tvm.efficientnet_b0(weights=tvm.EfficientNet_B0_Weights.IMAGENET1K_V1)
|
| 96 |
+
else:
|
| 97 |
+
model = tvm.efficientnet_b3(weights=tvm.EfficientNet_B3_Weights.IMAGENET1K_V1)
|
| 98 |
+
in_features = model.classifier[-1].in_features
|
| 99 |
+
model.classifier[-1] = nn.Linear(in_features, 1)
|
| 100 |
+
return model
|
| 101 |
+
|
| 102 |
+
raise ValueError(
|
| 103 |
+
f"Unknown backbone: {backbone!r}. "
|
| 104 |
+
"Choose from: densenet121, rad-dino, mobilenet_v3_large, efficientnet_b0, efficientnet_b3"
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def cardio_logit(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
|
| 109 |
+
"""Forward pass returning a (B,) tensor of raw logits for Cardiomegaly.
|
| 110 |
+
|
| 111 |
+
For torchxrayvision DenseNet the logit is extracted from the pathology head.
|
| 112 |
+
For all other backbones (MobileNet, EfficientNet, RadDinoWrapper) the model
|
| 113 |
+
outputs (B, 1) which is squeezed to (B,).
|
| 114 |
+
"""
|
| 115 |
+
if isinstance(model, xrv.models.DenseNet):
|
| 116 |
+
out = model(x) # (B, num_pathologies)
|
| 117 |
+
idx = model.pathologies.index("Cardiomegaly")
|
| 118 |
+
return out[:, idx]
|
| 119 |
+
return model(x).squeeze(1) # (B, 1) β (B,)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# ---------------------------------------------------------------------------
|
| 123 |
+
# Backbone management helpers
|
| 124 |
+
# ---------------------------------------------------------------------------
|
| 125 |
+
def freeze_backbone(model: nn.Module) -> nn.Module:
|
| 126 |
+
"""Freeze all params in .features; keep .classifier trainable."""
|
| 127 |
+
for p in model.features.parameters():
|
| 128 |
+
p.requires_grad = False
|
| 129 |
+
for p in model.classifier.parameters():
|
| 130 |
+
p.requires_grad = True
|
| 131 |
+
return model
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def unfreeze_all(model: nn.Module) -> nn.Module:
|
| 135 |
+
"""Unfreeze every parameter. Kept for backwards compatibility; prefer partial_unfreeze."""
|
| 136 |
+
for p in model.parameters():
|
| 137 |
+
p.requires_grad = True
|
| 138 |
+
return model
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# DenseNet-121 block groups: (block_name, transition_name) for blocks 1β4
|
| 142 |
+
_DENSENET_BLOCK_GROUPS = [
|
| 143 |
+
("denseblock1", "transition1"),
|
| 144 |
+
("denseblock2", "transition2"),
|
| 145 |
+
("denseblock3", "transition3"),
|
| 146 |
+
("denseblock4", "norm5"),
|
| 147 |
+
]
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def partial_unfreeze(model: nn.Module, frozen_blocks: int = 0) -> nn.Module:
|
| 151 |
+
"""Selectively unfreeze the model for stage-2 fine-tuning.
|
| 152 |
+
|
| 153 |
+
frozen_blocks β how many feature blocks to keep frozen:
|
| 154 |
+
0 β unfreeze everything (same as unfreeze_all)
|
| 155 |
+
|
| 156 |
+
DenseNet-121 (4 dense block groups):
|
| 157 |
+
1 β keep denseblock1 (+transition1) frozen
|
| 158 |
+
2 β keep denseblock1β2 frozen
|
| 159 |
+
3 β keep denseblock1β3 frozen
|
| 160 |
+
4 β keep all dense blocks frozen (only classifier trains)
|
| 161 |
+
|
| 162 |
+
RAD-DINO / ViT-B (12 transformer blocks):
|
| 163 |
+
1β12 β keep embeddings + first N transformer blocks frozen
|
| 164 |
+
(last 12βN blocks + layernorm are unfrozen)
|
| 165 |
+
β₯12 β keep all transformer blocks frozen (only classifier trains)
|
| 166 |
+
|
| 167 |
+
torchvision models (MobileNet, EfficientNet):
|
| 168 |
+
N β freeze first N indexed children of model.features.
|
| 169 |
+
"""
|
| 170 |
+
for p in model.parameters():
|
| 171 |
+
p.requires_grad = True
|
| 172 |
+
|
| 173 |
+
if frozen_blocks <= 0:
|
| 174 |
+
return model
|
| 175 |
+
|
| 176 |
+
if isinstance(model, xrv.models.DenseNet):
|
| 177 |
+
frozen_names: set[str] = set()
|
| 178 |
+
for i in range(min(frozen_blocks, len(_DENSENET_BLOCK_GROUPS))):
|
| 179 |
+
frozen_names.update(_DENSENET_BLOCK_GROUPS[i])
|
| 180 |
+
for name, module in model.features.named_children():
|
| 181 |
+
if name in frozen_names:
|
| 182 |
+
for p in module.parameters():
|
| 183 |
+
p.requires_grad = False
|
| 184 |
+
|
| 185 |
+
elif isinstance(model, RadDinoWrapper):
|
| 186 |
+
# Always freeze the patch/position embeddings
|
| 187 |
+
for p in model.features.embeddings.parameters():
|
| 188 |
+
p.requires_grad = False
|
| 189 |
+
# Freeze the first `frozen_blocks` transformer blocks
|
| 190 |
+
encoder_layers = model.features.encoder.layer
|
| 191 |
+
for block in encoder_layers[:frozen_blocks]:
|
| 192 |
+
for p in block.parameters():
|
| 193 |
+
p.requires_grad = False
|
| 194 |
+
|
| 195 |
+
else:
|
| 196 |
+
for module in list(model.features.children())[:frozen_blocks]:
|
| 197 |
+
for p in module.parameters():
|
| 198 |
+
p.requires_grad = False
|
| 199 |
+
|
| 200 |
+
return model
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def trainable_params(model: nn.Module) -> List[nn.Parameter]:
|
| 204 |
+
"""List of parameters with `requires_grad=True` (for optimiser construction)."""
|
| 205 |
+
return [p for p in model.parameters() if p.requires_grad]
|
model_training/src/train.py
ADDED
|
@@ -0,0 +1,697 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from typing import List, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.optim as optim
|
| 13 |
+
from sklearn.metrics import confusion_matrix, roc_auc_score
|
| 14 |
+
from torch.utils.data import DataLoader
|
| 15 |
+
|
| 16 |
+
from src.config import CFG
|
| 17 |
+
from src.dataset import ChestXrayDataset, SubmissionDataset, TTADataset
|
| 18 |
+
from src.model import (
|
| 19 |
+
build_model,
|
| 20 |
+
cardio_logit,
|
| 21 |
+
freeze_backbone,
|
| 22 |
+
partial_unfreeze,
|
| 23 |
+
trainable_params,
|
| 24 |
+
unfreeze_all,
|
| 25 |
+
)
|
| 26 |
+
from src.transforms import make_tta_transforms
|
| 27 |
+
from src.utils import free_device_cache, log_run, set_seed
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
# Mixup helper
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
def mixup_data(
|
| 34 |
+
x: torch.Tensor,
|
| 35 |
+
y: torch.Tensor,
|
| 36 |
+
alpha: float = 0.4,
|
| 37 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 38 |
+
"""Return a randomly mixed batch and the corresponding soft labels.
|
| 39 |
+
|
| 40 |
+
Ξ» ~ Beta(Ξ±, Ξ±). When Ξ± β€ 0 the original batch is returned unchanged.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
x: Image tensor (B, C, H, W) on the training device.
|
| 44 |
+
y: Label tensor (B,) β may already be soft (e.g. after smoothing).
|
| 45 |
+
alpha: Beta distribution parameter. Typical: 0.2 β 0.4.
|
| 46 |
+
"""
|
| 47 |
+
if alpha <= 0:
|
| 48 |
+
return x, y
|
| 49 |
+
lam = float(np.random.beta(alpha, alpha))
|
| 50 |
+
idx = torch.randperm(x.size(0), device=x.device)
|
| 51 |
+
mixed_x = lam * x + (1.0 - lam) * x[idx]
|
| 52 |
+
mixed_y = lam * y + (1.0 - lam) * y[idx]
|
| 53 |
+
return mixed_x, mixed_y
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# ---------------------------------------------------------------------------
|
| 57 |
+
# Differentiable composite loss
|
| 58 |
+
# ---------------------------------------------------------------------------
|
| 59 |
+
class SoftCompositeLoss(nn.Module):
|
| 60 |
+
"""Differentiable approximation of composite = 0.5Β·AUC + 0.25Β·sens + 0.25Β·spec.
|
| 61 |
+
|
| 62 |
+
Minimises ``1 - soft_composite``, blended with standard BCE for stability.
|
| 63 |
+
|
| 64 |
+
Soft-AUC
|
| 65 |
+
Pairwise sigmoid over all (positive, negative) logit pairs in the batch:
|
| 66 |
+
``soft_auc = mean( Ο(Ξ³ Β· (logitβΊ β logitβ»)) )``
|
| 67 |
+
where Ξ³ (``auc_gamma``) is a sharpness temperature.
|
| 68 |
+
|
| 69 |
+
Soft-sens / soft-spec
|
| 70 |
+
``soft_sens = mean( Ο(logit) | y=1 )``
|
| 71 |
+
``soft_spec = mean( 1 β Ο(logit) | y=0 )``
|
| 72 |
+
|
| 73 |
+
Total loss
|
| 74 |
+
``Ξ± Β· BCE + (1 β Ξ±) Β· (1 β soft_composite)``
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
alpha: Weight of BCE in the blend (0 = pure composite, 1 = pure BCE).
|
| 78 |
+
auc_gamma: Temperature for the pairwise sigmoid (higher β sharper AUC signal).
|
| 79 |
+
eps: Numerical stability floor.
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
def __init__(self, alpha: float = 0.5, auc_gamma: float = 1.0, eps: float = 1e-7):
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.alpha = alpha
|
| 85 |
+
self.auc_gamma = auc_gamma
|
| 86 |
+
self.eps = eps
|
| 87 |
+
self._bce = nn.BCEWithLogitsLoss()
|
| 88 |
+
|
| 89 |
+
def forward(self, logit: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
| 90 |
+
bce_loss = self._bce(logit, target)
|
| 91 |
+
|
| 92 |
+
prob = torch.sigmoid(logit)
|
| 93 |
+
# Use > 0.5 so the masks work correctly for both hard labels {0,1}
|
| 94 |
+
# and soft targets produced by mixup or label smoothing.
|
| 95 |
+
pos_mask = (target > 0.5)
|
| 96 |
+
neg_mask = ~pos_mask
|
| 97 |
+
n_pos = pos_mask.sum()
|
| 98 |
+
n_neg = neg_mask.sum()
|
| 99 |
+
|
| 100 |
+
# ββ Soft AUC (pairwise) ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 101 |
+
if n_pos > 0 and n_neg > 0:
|
| 102 |
+
pos_logits = logit[pos_mask] # (n_pos,)
|
| 103 |
+
neg_logits = logit[neg_mask] # (n_neg,)
|
| 104 |
+
diff = pos_logits.unsqueeze(1) - neg_logits.unsqueeze(0) # (n_pos, n_neg)
|
| 105 |
+
soft_auc = torch.sigmoid(self.auc_gamma * diff).mean()
|
| 106 |
+
else:
|
| 107 |
+
soft_auc = torch.tensor(0.5, device=logit.device, dtype=logit.dtype)
|
| 108 |
+
|
| 109 |
+
# ββ Soft sensitivity / specificity ββββββββββββββββββββββββββββββββββ
|
| 110 |
+
soft_sens = prob[pos_mask].mean() if n_pos > 0 else torch.tensor(
|
| 111 |
+
0.0, device=logit.device, dtype=logit.dtype)
|
| 112 |
+
soft_spec = (1.0 - prob[neg_mask]).mean() if n_neg > 0 else torch.tensor(
|
| 113 |
+
0.0, device=logit.device, dtype=logit.dtype)
|
| 114 |
+
|
| 115 |
+
soft_composite = 0.5 * soft_auc + 0.25 * soft_sens + 0.25 * soft_spec
|
| 116 |
+
composite_loss = 1.0 - soft_composite
|
| 117 |
+
|
| 118 |
+
return self.alpha * bce_loss + (1.0 - self.alpha) * composite_loss
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# ---------------------------------------------------------------------------
|
| 122 |
+
# Epoch runner
|
| 123 |
+
# ---------------------------------------------------------------------------
|
| 124 |
+
def run_one_epoch(
|
| 125 |
+
model: nn.Module,
|
| 126 |
+
loader: DataLoader,
|
| 127 |
+
criterion: Optional[nn.Module] = None,
|
| 128 |
+
optimizer: Optional[optim.Optimizer] = None,
|
| 129 |
+
scaler: Optional[torch.cuda.amp.GradScaler] = None,
|
| 130 |
+
grad_clip: Optional[float] = None,
|
| 131 |
+
mixup_alpha: float = 0.0,
|
| 132 |
+
label_smoothing: float = 0.0,
|
| 133 |
+
) -> dict:
|
| 134 |
+
"""Single forward pass over *loader*.
|
| 135 |
+
|
| 136 |
+
Pass ``optimizer=None`` for evaluation (mixup and smoothing are skipped).
|
| 137 |
+
Expects (image, label, name) batches. Uses CUDA AMP when available.
|
| 138 |
+
|
| 139 |
+
Hard original labels are always accumulated for metric computation;
|
| 140 |
+
the (potentially mixed + smoothed) soft labels are only used for the loss.
|
| 141 |
+
"""
|
| 142 |
+
is_train = optimizer is not None
|
| 143 |
+
model.train(is_train)
|
| 144 |
+
|
| 145 |
+
losses, logits_all, labels_all, names_all = [], [], [], []
|
| 146 |
+
pin = (CFG.device == "cuda")
|
| 147 |
+
grad_clip = grad_clip if grad_clip is not None else CFG.grad_clip
|
| 148 |
+
|
| 149 |
+
amp_ctx = torch.cuda.amp.autocast(enabled=(CFG.device == "cuda"))
|
| 150 |
+
for x, y, names in loader:
|
| 151 |
+
x = x.to(CFG.device, non_blocking=pin)
|
| 152 |
+
y = y.to(CFG.device, non_blocking=pin)
|
| 153 |
+
|
| 154 |
+
# Keep hard labels for metric accumulation (before any augmentation)
|
| 155 |
+
y_hard = y.detach().clone()
|
| 156 |
+
|
| 157 |
+
if is_train:
|
| 158 |
+
# Mixup: interpolate two samples + their labels in-place
|
| 159 |
+
if mixup_alpha > 0.0:
|
| 160 |
+
x, y = mixup_data(x, y, alpha=mixup_alpha)
|
| 161 |
+
# Label smoothing: y_smooth = y*(1-Ξ΅) + 0.5*Ξ΅
|
| 162 |
+
if label_smoothing > 0.0:
|
| 163 |
+
y = y * (1.0 - label_smoothing) + 0.5 * label_smoothing
|
| 164 |
+
|
| 165 |
+
with torch.set_grad_enabled(is_train):
|
| 166 |
+
with amp_ctx:
|
| 167 |
+
logit = cardio_logit(model, x)
|
| 168 |
+
loss = criterion(logit, y) if criterion is not None else None
|
| 169 |
+
|
| 170 |
+
if is_train:
|
| 171 |
+
optimizer.zero_grad(set_to_none=True)
|
| 172 |
+
if scaler is not None and scaler.is_enabled():
|
| 173 |
+
scaler.scale(loss).backward()
|
| 174 |
+
scaler.unscale_(optimizer)
|
| 175 |
+
torch.nn.utils.clip_grad_norm_(trainable_params(model), grad_clip)
|
| 176 |
+
scaler.step(optimizer)
|
| 177 |
+
scaler.update()
|
| 178 |
+
else:
|
| 179 |
+
loss.backward()
|
| 180 |
+
torch.nn.utils.clip_grad_norm_(trainable_params(model), grad_clip)
|
| 181 |
+
optimizer.step()
|
| 182 |
+
|
| 183 |
+
if loss is not None:
|
| 184 |
+
losses.append(loss.item())
|
| 185 |
+
logits_all.append(logit.detach().float().cpu().numpy())
|
| 186 |
+
labels_all.append(y_hard.float().cpu().numpy()) # always hard labels
|
| 187 |
+
names_all.extend(list(names))
|
| 188 |
+
|
| 189 |
+
y_true = np.concatenate(labels_all)
|
| 190 |
+
y_logit = np.concatenate(logits_all)
|
| 191 |
+
y_prob = 1.0 / (1.0 + np.exp(-y_logit))
|
| 192 |
+
auc = roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else float("nan")
|
| 193 |
+
|
| 194 |
+
# Per-epoch metrics at threshold=0.5 (used for progress logging)
|
| 195 |
+
y_pred = (y_prob >= 0.5).astype(int)
|
| 196 |
+
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
|
| 197 |
+
acc = float((tp + tn) / (tp + tn + fp + fn)) if (tp + tn + fp + fn) > 0 else float("nan")
|
| 198 |
+
sens = float(tp / (tp + fn)) if (tp + fn) > 0 else 0.0
|
| 199 |
+
spec = float(tn / (tn + fp)) if (tn + fp) > 0 else 0.0
|
| 200 |
+
composite = 0.5 * (auc if not np.isnan(auc) else 0.0) + 0.25 * sens + 0.25 * spec
|
| 201 |
+
|
| 202 |
+
return {
|
| 203 |
+
"loss": float(np.mean(losses)) if losses else float("nan"),
|
| 204 |
+
"auc": float(auc),
|
| 205 |
+
"acc": acc,
|
| 206 |
+
"sens": sens,
|
| 207 |
+
"spec": spec,
|
| 208 |
+
"composite": float(composite),
|
| 209 |
+
"y_true": y_true,
|
| 210 |
+
"y_prob": y_prob,
|
| 211 |
+
"names": names_all,
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# ---------------------------------------------------------------------------
|
| 216 |
+
# Single-seed two-stage training
|
| 217 |
+
# ---------------------------------------------------------------------------
|
| 218 |
+
def train_one_seed(
|
| 219 |
+
seed: int,
|
| 220 |
+
train_loader: DataLoader,
|
| 221 |
+
val_loader: DataLoader,
|
| 222 |
+
output_dir: Optional[str] = None,
|
| 223 |
+
config=None,
|
| 224 |
+
) -> Tuple[nn.Module, float, str, list[dict]]:
|
| 225 |
+
"""Train ONE model end-to-end (frozen warmup β full fine-tune).
|
| 226 |
+
|
| 227 |
+
Returns (best_model, best_val_auc, checkpoint_path, history).
|
| 228 |
+
"""
|
| 229 |
+
cfg = config or CFG
|
| 230 |
+
output_dir = output_dir or cfg.output_dir
|
| 231 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 232 |
+
|
| 233 |
+
set_seed(seed)
|
| 234 |
+
|
| 235 |
+
model = build_model(cfg.backbone).to(cfg.device)
|
| 236 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 237 |
+
|
| 238 |
+
print("\n" + "=" * 80)
|
| 239 |
+
print(f" Seed : {seed}")
|
| 240 |
+
print(f" Model : {model.__class__.__name__} ({total_params:,} total params)")
|
| 241 |
+
print(f" Backbone: {cfg.backbone}")
|
| 242 |
+
print(f" Device : {cfg.device}")
|
| 243 |
+
print("=" * 80)
|
| 244 |
+
|
| 245 |
+
if cfg.use_composite_loss:
|
| 246 |
+
criterion = SoftCompositeLoss(
|
| 247 |
+
alpha=cfg.composite_loss_alpha,
|
| 248 |
+
auc_gamma=cfg.composite_loss_gamma,
|
| 249 |
+
)
|
| 250 |
+
print(
|
| 251 |
+
f" Loss : SoftCompositeLoss "
|
| 252 |
+
f"(Ξ±={cfg.composite_loss_alpha}, Ξ³={cfg.composite_loss_gamma})"
|
| 253 |
+
)
|
| 254 |
+
else:
|
| 255 |
+
criterion = nn.BCEWithLogitsLoss()
|
| 256 |
+
print(" Loss : BCEWithLogitsLoss")
|
| 257 |
+
|
| 258 |
+
mixup_alpha = getattr(cfg, "mixup_alpha", 0.0)
|
| 259 |
+
label_smoothing = getattr(cfg, "label_smoothing", 0.0)
|
| 260 |
+
if mixup_alpha > 0:
|
| 261 |
+
print(f" Mixup : Ξ±={mixup_alpha}")
|
| 262 |
+
if label_smoothing > 0:
|
| 263 |
+
print(f" Smoothing: Ξ΅={label_smoothing}")
|
| 264 |
+
|
| 265 |
+
scaler = torch.cuda.amp.GradScaler(enabled=(cfg.device == "cuda"))
|
| 266 |
+
history: list[dict] = []
|
| 267 |
+
|
| 268 |
+
_aug_kw = dict(mixup_alpha=mixup_alpha, label_smoothing=label_smoothing)
|
| 269 |
+
|
| 270 |
+
# ββ Stage 1: frozen backbone, head-only warmup βββββββββββββββββββββββββ
|
| 271 |
+
freeze_backbone(model)
|
| 272 |
+
n_trainable = sum(p.numel() for p in trainable_params(model))
|
| 273 |
+
print(f"\n Stage 1 β all backbone blocks FROZEN | {n_trainable:,} trainable params")
|
| 274 |
+
opt_frozen = optim.AdamW(
|
| 275 |
+
trainable_params(model), lr=cfg.head_lr, weight_decay=cfg.weight_decay,
|
| 276 |
+
)
|
| 277 |
+
for ep in range(1, cfg.frozen_epochs + 1):
|
| 278 |
+
t = run_one_epoch(model, train_loader, criterion, opt_frozen, scaler, **_aug_kw)
|
| 279 |
+
v = run_one_epoch(model, val_loader, criterion)
|
| 280 |
+
history.append({
|
| 281 |
+
"seed": seed, "stage": "frozen", "epoch": ep,
|
| 282 |
+
"train_loss": t["loss"], "train_auc": t["auc"],
|
| 283 |
+
"train_acc": t["acc"], "train_composite": t["composite"],
|
| 284 |
+
"val_loss": v["loss"], "val_auc": v["auc"],
|
| 285 |
+
"val_acc": v["acc"], "val_sens": v["sens"],
|
| 286 |
+
"val_spec": v["spec"], "val_composite": v["composite"],
|
| 287 |
+
"lr": opt_frozen.param_groups[0]["lr"],
|
| 288 |
+
})
|
| 289 |
+
print(
|
| 290 |
+
f" [frozen] {ep}/{cfg.frozen_epochs} "
|
| 291 |
+
f"loss={t['loss']:.4f} train_acc={t['acc']*100:.1f}% | "
|
| 292 |
+
f"val_auc={v['auc']:.4f} val_acc={v['acc']*100:.1f}% "
|
| 293 |
+
f"sens={v['sens']:.3f} spec={v['spec']:.3f} comp={v['composite']:.4f}"
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
# ββ Stage 2: partial or full fine-tune βββββββββββββββββββββββββββββββ
|
| 297 |
+
frozen_blocks = getattr(cfg, "frozen_blocks", 0)
|
| 298 |
+
partial_unfreeze(model, frozen_blocks)
|
| 299 |
+
n_trainable = sum(p.numel() for p in trainable_params(model))
|
| 300 |
+
if frozen_blocks == 0:
|
| 301 |
+
stage2_label = "all blocks UNFROZEN"
|
| 302 |
+
else:
|
| 303 |
+
stage2_label = f"{frozen_blocks} block(s) still FROZEN"
|
| 304 |
+
print(f"\n Stage 2 β {stage2_label} | {n_trainable:,} trainable params")
|
| 305 |
+
|
| 306 |
+
opt_ft = optim.AdamW(
|
| 307 |
+
[
|
| 308 |
+
{"params": model.features.parameters(), "lr": cfg.backbone_lr},
|
| 309 |
+
{"params": model.classifier.parameters(), "lr": cfg.head_lr},
|
| 310 |
+
],
|
| 311 |
+
weight_decay=cfg.weight_decay,
|
| 312 |
+
)
|
| 313 |
+
sched = optim.lr_scheduler.CosineAnnealingLR(
|
| 314 |
+
opt_ft, T_max=cfg.finetune_epochs, eta_min=cfg.backbone_lr * 0.01,
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
best_auc, best_state, patience_ctr = -1.0, None, 0
|
| 318 |
+
for ep in range(1, cfg.finetune_epochs + 1):
|
| 319 |
+
t = run_one_epoch(model, train_loader, criterion, opt_ft, scaler, **_aug_kw)
|
| 320 |
+
v = run_one_epoch(model, val_loader, criterion)
|
| 321 |
+
sched.step()
|
| 322 |
+
history.append({
|
| 323 |
+
"seed": seed, "stage": "finetune", "epoch": ep,
|
| 324 |
+
"train_loss": t["loss"], "train_auc": t["auc"],
|
| 325 |
+
"train_acc": t["acc"], "train_composite": t["composite"],
|
| 326 |
+
"val_loss": v["loss"], "val_auc": v["auc"],
|
| 327 |
+
"val_acc": v["acc"], "val_sens": v["sens"],
|
| 328 |
+
"val_spec": v["spec"], "val_composite": v["composite"],
|
| 329 |
+
"lr": opt_ft.param_groups[0]["lr"],
|
| 330 |
+
})
|
| 331 |
+
print(
|
| 332 |
+
f" [ft] {ep}/{cfg.finetune_epochs} "
|
| 333 |
+
f"loss={t['loss']:.4f} train_acc={t['acc']*100:.1f}% | "
|
| 334 |
+
f"val_auc={v['auc']:.4f} val_acc={v['acc']*100:.1f}% "
|
| 335 |
+
f"sens={v['sens']:.3f} spec={v['spec']:.3f} comp={v['composite']:.4f} "
|
| 336 |
+
f"lr={opt_ft.param_groups[0]['lr']:.2e}"
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
if v["auc"] > best_auc:
|
| 340 |
+
best_auc, best_state, patience_ctr = (
|
| 341 |
+
v["auc"], copy.deepcopy(model.state_dict()), 0
|
| 342 |
+
)
|
| 343 |
+
else:
|
| 344 |
+
patience_ctr += 1
|
| 345 |
+
if patience_ctr >= cfg.early_stop_patience:
|
| 346 |
+
print(f" [ft] early stop at epoch {ep} (best val AUC = {best_auc:.4f})")
|
| 347 |
+
break
|
| 348 |
+
|
| 349 |
+
if best_state is not None:
|
| 350 |
+
model.load_state_dict(best_state)
|
| 351 |
+
|
| 352 |
+
ckpt_path = os.path.join(output_dir, f"model_seed{seed}.pth")
|
| 353 |
+
torch.save(best_state if best_state is not None else model.state_dict(), ckpt_path)
|
| 354 |
+
print(f"[seed={seed}] Best val AUC = {best_auc:.4f} checkpoint β {ckpt_path}")
|
| 355 |
+
|
| 356 |
+
return model, best_auc, ckpt_path, history
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
# ---------------------------------------------------------------------------
|
| 360 |
+
# Multi-seed ensemble training
|
| 361 |
+
# ---------------------------------------------------------------------------
|
| 362 |
+
def train_ensemble(
|
| 363 |
+
train_loader: DataLoader,
|
| 364 |
+
val_loader: DataLoader,
|
| 365 |
+
seeds: Optional[List[int]] = None,
|
| 366 |
+
output_dir: Optional[str] = None,
|
| 367 |
+
config=None,
|
| 368 |
+
) -> Tuple[List[Tuple[int, nn.Module, float, str]], pd.DataFrame]:
|
| 369 |
+
"""Train one model per seed and return (models_list, full_history_df).
|
| 370 |
+
|
| 371 |
+
`models_list` items: (seed, trained_model, best_val_auc, checkpoint_path).
|
| 372 |
+
"""
|
| 373 |
+
cfg = config or CFG
|
| 374 |
+
seeds = seeds if seeds is not None else cfg.seeds
|
| 375 |
+
output_dir = output_dir or cfg.output_dir
|
| 376 |
+
|
| 377 |
+
print(f" ENSEMBLE TRAINING STARTED")
|
| 378 |
+
|
| 379 |
+
models, all_history = [], []
|
| 380 |
+
for seed in seeds:
|
| 381 |
+
m, auc, ckpt, hist = train_one_seed(
|
| 382 |
+
seed, train_loader, val_loader,
|
| 383 |
+
output_dir=output_dir, config=cfg,
|
| 384 |
+
)
|
| 385 |
+
models.append((seed, m, auc, ckpt))
|
| 386 |
+
all_history.extend(hist)
|
| 387 |
+
free_device_cache(cfg.device)
|
| 388 |
+
|
| 389 |
+
history_df = pd.DataFrame(all_history)
|
| 390 |
+
history_df.to_csv(os.path.join(output_dir, "training_history.csv"), index=False)
|
| 391 |
+
|
| 392 |
+
frozen_blocks = getattr(cfg, "frozen_blocks", 0)
|
| 393 |
+
stage2_label = "all blocks unfrozen" if frozen_blocks == 0 else f"{frozen_blocks} block(s) frozen"
|
| 394 |
+
|
| 395 |
+
print("\n" + "β" * 80)
|
| 396 |
+
print(f" ENSEMBLE COMPLETE")
|
| 397 |
+
print(f" Backbone : {cfg.backbone}")
|
| 398 |
+
print(f" Frozen blocks : {frozen_blocks} ({stage2_label} in Stage 2)")
|
| 399 |
+
print(f" Seeds trained : {len(models)}")
|
| 400 |
+
print(" Per-seed best val AUC:")
|
| 401 |
+
for seed, _, auc, _ in models:
|
| 402 |
+
print(f" seed {seed:>5} : {auc:.4f}")
|
| 403 |
+
best_seed = max(models, key=lambda x: x[2])
|
| 404 |
+
print(f" Best seed : {best_seed[0]} (AUC={best_seed[2]:.4f})")
|
| 405 |
+
print("β" * 80)
|
| 406 |
+
|
| 407 |
+
return models, history_df
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
# ---------------------------------------------------------------------------
|
| 411 |
+
# Unified entry point (respects CFG.use_ensemble)
|
| 412 |
+
# ---------------------------------------------------------------------------
|
| 413 |
+
def train(
|
| 414 |
+
train_loader: DataLoader,
|
| 415 |
+
val_loader: DataLoader,
|
| 416 |
+
output_dir: Optional[str] = None,
|
| 417 |
+
config=None,
|
| 418 |
+
) -> Tuple[List[Tuple[int, nn.Module, float, str]], pd.DataFrame]:
|
| 419 |
+
"""Train and return (models_list, history_df) β same format as train_ensemble.
|
| 420 |
+
|
| 421 |
+
Behaviour is controlled by CFG.use_ensemble:
|
| 422 |
+
True β delegates to train_ensemble (one model per seed in CFG.seeds)
|
| 423 |
+
False β trains a single model with CFG.seed and wraps result in the
|
| 424 |
+
same list format so the rest of the pipeline works unchanged.
|
| 425 |
+
"""
|
| 426 |
+
cfg = config or CFG
|
| 427 |
+
if cfg.use_ensemble:
|
| 428 |
+
return train_ensemble(train_loader, val_loader, output_dir=output_dir, config=cfg)
|
| 429 |
+
|
| 430 |
+
m, auc, ckpt, hist = train_one_seed(
|
| 431 |
+
cfg.seed, train_loader, val_loader, output_dir=output_dir, config=cfg,
|
| 432 |
+
)
|
| 433 |
+
history_df = pd.DataFrame(hist)
|
| 434 |
+
history_df.to_csv(
|
| 435 |
+
os.path.join(output_dir or cfg.output_dir, "training_history.csv"), index=False,
|
| 436 |
+
)
|
| 437 |
+
return [(cfg.seed, m, auc, ckpt)], history_df
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
# ---------------------------------------------------------------------------
|
| 441 |
+
# TTA inference
|
| 442 |
+
# ---------------------------------------------------------------------------
|
| 443 |
+
def tta_predict(
|
| 444 |
+
model: nn.Module,
|
| 445 |
+
df: pd.DataFrame,
|
| 446 |
+
image_dir: Optional[str] = None,
|
| 447 |
+
has_labels: bool = True,
|
| 448 |
+
tta_transforms: Optional[List] = None,
|
| 449 |
+
config=None,
|
| 450 |
+
) -> dict:
|
| 451 |
+
"""Run TTA inference for ONE model on a DataFrame.
|
| 452 |
+
|
| 453 |
+
Predictions are averaged in **logit space** across all TTA passes.
|
| 454 |
+
"""
|
| 455 |
+
cfg = config or CFG
|
| 456 |
+
tta_transforms = tta_transforms or make_tta_transforms(cfg.img_size)
|
| 457 |
+
tta_transforms = tta_transforms[:cfg.tta_passes]
|
| 458 |
+
|
| 459 |
+
all_logits: list[np.ndarray] = []
|
| 460 |
+
names_ref, labels_ref = None, None
|
| 461 |
+
|
| 462 |
+
pin = (cfg.device == "cuda")
|
| 463 |
+
amp_ctx = torch.cuda.amp.autocast(enabled=(cfg.device == "cuda"))
|
| 464 |
+
|
| 465 |
+
for tf in tta_transforms:
|
| 466 |
+
ds = TTADataset(df, tf, image_dir)
|
| 467 |
+
loader = DataLoader(
|
| 468 |
+
ds, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
|
| 469 |
+
pin_memory=pin, shuffle=False,
|
| 470 |
+
)
|
| 471 |
+
pass_logits, pass_names, pass_labels = [], [], []
|
| 472 |
+
model.eval()
|
| 473 |
+
with torch.no_grad(), amp_ctx:
|
| 474 |
+
for x, y, names in loader:
|
| 475 |
+
x = x.to(cfg.device, non_blocking=pin)
|
| 476 |
+
logit = cardio_logit(model, x).float().cpu().numpy()
|
| 477 |
+
pass_logits.append(logit)
|
| 478 |
+
pass_names.extend(list(names))
|
| 479 |
+
if has_labels:
|
| 480 |
+
pass_labels.append(y.numpy())
|
| 481 |
+
all_logits.append(np.concatenate(pass_logits))
|
| 482 |
+
if names_ref is None:
|
| 483 |
+
names_ref = pass_names
|
| 484 |
+
labels_ref = np.concatenate(pass_labels) if has_labels else None
|
| 485 |
+
|
| 486 |
+
mean_logit = np.stack(all_logits, axis=0).mean(axis=0)
|
| 487 |
+
mean_prob = (1.0 / (1.0 + np.exp(-mean_logit))).astype(np.float32)
|
| 488 |
+
return {
|
| 489 |
+
"names": names_ref,
|
| 490 |
+
"y_prob": mean_prob,
|
| 491 |
+
"y_true": labels_ref,
|
| 492 |
+
"mean_logit": mean_logit,
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def tta_predict_ensemble(
|
| 497 |
+
models_list: List[Tuple[int, nn.Module, float, str]],
|
| 498 |
+
df: pd.DataFrame,
|
| 499 |
+
image_dir: Optional[str] = None,
|
| 500 |
+
has_labels: bool = True,
|
| 501 |
+
tta_transforms: Optional[List] = None,
|
| 502 |
+
config=None,
|
| 503 |
+
) -> dict:
|
| 504 |
+
"""Run TTA for every model in `models_list` and average in logit space."""
|
| 505 |
+
cfg = config or CFG
|
| 506 |
+
all_logits: list[np.ndarray] = []
|
| 507 |
+
names_ref, labels_ref = None, None
|
| 508 |
+
|
| 509 |
+
for (seed, model, _, _) in models_list:
|
| 510 |
+
print(f" TTA with seed={seed}...")
|
| 511 |
+
pred = tta_predict(
|
| 512 |
+
model, df, image_dir=image_dir, has_labels=has_labels,
|
| 513 |
+
tta_transforms=tta_transforms, config=cfg,
|
| 514 |
+
)
|
| 515 |
+
all_logits.append(pred["mean_logit"])
|
| 516 |
+
if names_ref is None:
|
| 517 |
+
names_ref = pred["names"]
|
| 518 |
+
labels_ref = pred["y_true"]
|
| 519 |
+
|
| 520 |
+
mean_logit = np.stack(all_logits, axis=0).mean(axis=0)
|
| 521 |
+
mean_prob = (1.0 / (1.0 + np.exp(-mean_logit))).astype(np.float32)
|
| 522 |
+
return {"names": names_ref, "y_prob": mean_prob, "y_true": labels_ref}
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
# ---------------------------------------------------------------------------
|
| 526 |
+
# Submission inference
|
| 527 |
+
# ---------------------------------------------------------------------------
|
| 528 |
+
def predict_submission(
|
| 529 |
+
models_list: List[Tuple[int, nn.Module, float, str]],
|
| 530 |
+
submission_dir: str,
|
| 531 |
+
tta_transforms: Optional[List] = None,
|
| 532 |
+
config=None,
|
| 533 |
+
) -> dict:
|
| 534 |
+
"""TTA + ensemble inference on an unlabelled submission directory.
|
| 535 |
+
|
| 536 |
+
Wraps the directory in a DataFrame so we can reuse `tta_predict_ensemble`.
|
| 537 |
+
"""
|
| 538 |
+
cfg = config or CFG
|
| 539 |
+
files = sorted(
|
| 540 |
+
f for f in os.listdir(submission_dir)
|
| 541 |
+
if os.path.isfile(os.path.join(submission_dir, f))
|
| 542 |
+
and f.lower().endswith((".png", ".jpg", ".jpeg"))
|
| 543 |
+
)
|
| 544 |
+
sub_df = pd.DataFrame({"filename": files})
|
| 545 |
+
return tta_predict_ensemble(
|
| 546 |
+
models_list, sub_df,
|
| 547 |
+
image_dir=submission_dir, has_labels=False,
|
| 548 |
+
tta_transforms=tta_transforms, config=cfg,
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
# ---------------------------------------------------------------------------
|
| 553 |
+
# Metrics
|
| 554 |
+
# ---------------------------------------------------------------------------
|
| 555 |
+
def metrics_at_threshold(y_true, y_prob, threshold: float) -> dict:
|
| 556 |
+
"""Composite-grading-aware metric set at a given threshold.
|
| 557 |
+
|
| 558 |
+
composite = 0.5Β·AUC + 0.25Β·sensitivity + 0.25Β·specificity
|
| 559 |
+
"""
|
| 560 |
+
y_pred = (y_prob >= threshold).astype(int)
|
| 561 |
+
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
|
| 562 |
+
|
| 563 |
+
sens = tp / (tp + fn) if (tp + fn) > 0 else 0.0
|
| 564 |
+
spec = tn / (tn + fp) if (tn + fp) > 0 else 0.0
|
| 565 |
+
acc = (tp + tn) / (tp + tn + fp + fn)
|
| 566 |
+
auc = roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else float("nan")
|
| 567 |
+
composite = 0.5 * auc + 0.25 * sens + 0.25 * spec
|
| 568 |
+
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
|
| 569 |
+
|
| 570 |
+
return {
|
| 571 |
+
"threshold": float(threshold),
|
| 572 |
+
"auc": float(auc),
|
| 573 |
+
"sensitivity": float(sens),
|
| 574 |
+
"specificity": float(spec),
|
| 575 |
+
"accuracy": float(acc),
|
| 576 |
+
"youden": float(sens + spec - 1.0),
|
| 577 |
+
"composite": float(composite),
|
| 578 |
+
"precision": float(precision),
|
| 579 |
+
"tp": int(tp), "tn": int(tn), "fp": int(fp), "fn": int(fn),
|
| 580 |
+
}
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
# Backwards-compatible alias (used by older notebook cells)
|
| 584 |
+
compute_basic_metrics = metrics_at_threshold
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
def find_best_threshold(y_true, y_prob) -> Tuple[float, dict]:
|
| 588 |
+
"""Pick the threshold that maximises sensitivity + specificity (Youden's J)."""
|
| 589 |
+
candidates = np.unique(np.round(np.concatenate([[0.0], y_prob, [1.0]]), 6))
|
| 590 |
+
best_score, best_row = -np.inf, None
|
| 591 |
+
for thr in candidates:
|
| 592 |
+
m = metrics_at_threshold(y_true, y_prob, thr)
|
| 593 |
+
score = m["sensitivity"] + m["specificity"]
|
| 594 |
+
if score > best_score:
|
| 595 |
+
best_score, best_row = score, m
|
| 596 |
+
return float(best_row["threshold"]), best_row
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
def bootstrap_threshold(
|
| 600 |
+
y_true, y_prob,
|
| 601 |
+
n_boot: Optional[int] = None,
|
| 602 |
+
seed: Optional[int] = None,
|
| 603 |
+
) -> float:
|
| 604 |
+
"""Bootstrap-stabilised threshold (median across resamples).
|
| 605 |
+
|
| 606 |
+
Generalises better than a single-shot pick on the raw val set.
|
| 607 |
+
"""
|
| 608 |
+
n_boot = n_boot if n_boot is not None else CFG.n_bootstrap
|
| 609 |
+
seed = seed if seed is not None else CFG.seed
|
| 610 |
+
|
| 611 |
+
rng = np.random.RandomState(seed)
|
| 612 |
+
thrs: list[float] = []
|
| 613 |
+
n = len(y_true)
|
| 614 |
+
for _ in range(n_boot):
|
| 615 |
+
idx = rng.randint(0, n, size=n)
|
| 616 |
+
if len(np.unique(y_true[idx])) < 2:
|
| 617 |
+
continue
|
| 618 |
+
thr, _ = find_best_threshold(y_true[idx], y_prob[idx])
|
| 619 |
+
thrs.append(thr)
|
| 620 |
+
return float(np.median(thrs)) if thrs else 0.5
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
def select_threshold(y_true, y_prob, config=None) -> Tuple[float, dict, dict]:
|
| 624 |
+
"""Pick the better of (single-shot) vs (bootstrap) thresholds on composite.
|
| 625 |
+
|
| 626 |
+
Bootstrap is preferred unless its composite is clearly worse (margin 0.005).
|
| 627 |
+
Returns (chosen_threshold, single_metrics, bootstrap_metrics).
|
| 628 |
+
"""
|
| 629 |
+
cfg = config or CFG
|
| 630 |
+
thr_single, _ = find_best_threshold(y_true, y_prob)
|
| 631 |
+
thr_boot = bootstrap_threshold(y_true, y_prob, n_boot=cfg.n_bootstrap, seed=cfg.seed)
|
| 632 |
+
m_single = metrics_at_threshold(y_true, y_prob, thr_single)
|
| 633 |
+
m_boot = metrics_at_threshold(y_true, y_prob, thr_boot)
|
| 634 |
+
chosen = thr_boot if m_boot["composite"] >= m_single["composite"] - 0.005 else thr_single
|
| 635 |
+
return float(chosen), m_single, m_boot
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
# ---------------------------------------------------------------------------
|
| 639 |
+
# Saving results
|
| 640 |
+
# ---------------------------------------------------------------------------
|
| 641 |
+
def save_results(
|
| 642 |
+
models_list: List[Tuple[int, nn.Module, float, str]],
|
| 643 |
+
history: pd.DataFrame,
|
| 644 |
+
val_out: dict,
|
| 645 |
+
test_out: dict,
|
| 646 |
+
best_threshold: float,
|
| 647 |
+
output_dir: str,
|
| 648 |
+
model_name: str = "model",
|
| 649 |
+
config=None,
|
| 650 |
+
) -> None:
|
| 651 |
+
"""Persist per-seed checkpoints, history, metrics, predictions, and global log.
|
| 652 |
+
|
| 653 |
+
Per-seed `.pth` files are already written by `train_one_seed`; here we
|
| 654 |
+
only re-save them under the conventional name and write the metrics +
|
| 655 |
+
per-image prediction CSVs.
|
| 656 |
+
"""
|
| 657 |
+
cfg = config or CFG
|
| 658 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 659 |
+
|
| 660 |
+
# ββ Metric files + per-image predictions βββββββββββββββββββββββββββββ
|
| 661 |
+
val_metrics = metrics_at_threshold(val_out["y_true"], val_out["y_prob"], best_threshold)
|
| 662 |
+
test_metrics = metrics_at_threshold(test_out["y_true"], test_out["y_prob"], best_threshold)
|
| 663 |
+
|
| 664 |
+
for split_name, metrics in [("val", val_metrics), ("test", test_metrics)]:
|
| 665 |
+
with open(os.path.join(output_dir, f"{split_name}_metrics_final.json"), "w") as f:
|
| 666 |
+
json.dump(metrics, f, indent=2)
|
| 667 |
+
|
| 668 |
+
history.to_csv(os.path.join(output_dir, "training_history.csv"), index=False)
|
| 669 |
+
|
| 670 |
+
for split_name, out in [("val", val_out), ("test", test_out)]:
|
| 671 |
+
y_true = out["y_true"].astype(int)
|
| 672 |
+
y_pred = (out["y_prob"] >= best_threshold).astype(int)
|
| 673 |
+
pd.DataFrame({
|
| 674 |
+
"filename": out["names"],
|
| 675 |
+
"y_true": y_true,
|
| 676 |
+
"prob": out["y_prob"],
|
| 677 |
+
"pred": y_pred,
|
| 678 |
+
"correct": (y_pred == y_true).astype(int),
|
| 679 |
+
}).to_csv(os.path.join(output_dir, f"{split_name}_predictions.csv"), index=False)
|
| 680 |
+
|
| 681 |
+
# ββ Ensemble manifest (which seeds + which checkpoints) ββββββββββββββ
|
| 682 |
+
pd.DataFrame([
|
| 683 |
+
{"seed": s, "best_val_auc": auc, "checkpoint": ckpt}
|
| 684 |
+
for (s, _, auc, ckpt) in models_list
|
| 685 |
+
]).to_csv(os.path.join(output_dir, "ensemble_manifest.csv"), index=False)
|
| 686 |
+
|
| 687 |
+
print(f"Results saved β {output_dir}")
|
| 688 |
+
|
| 689 |
+
# ββ Append to global results log βββββββββββββββββββββββββββββββββββββ
|
| 690 |
+
log_run(
|
| 691 |
+
model_name=model_name,
|
| 692 |
+
val_metrics=val_metrics,
|
| 693 |
+
test_metrics=test_metrics,
|
| 694 |
+
config=cfg,
|
| 695 |
+
n_seeds=len(models_list),
|
| 696 |
+
log_path=cfg.results_log_path,
|
| 697 |
+
)
|
model_training/src/transforms.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import List, Tuple
|
| 4 |
+
|
| 5 |
+
import torchvision.transforms as T
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
from src.config import CFG
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# ---------------------------------------------------------------------------
|
| 12 |
+
# PIL helpers (TTA expects PIL β PIL transforms; xrv normalisation is applied
|
| 13 |
+
# downstream inside the Dataset).
|
| 14 |
+
# ---------------------------------------------------------------------------
|
| 15 |
+
def _pil_hflip(img: Image.Image) -> Image.Image:
|
| 16 |
+
return img.transpose(Image.FLIP_LEFT_RIGHT)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
# Training and evaluation transforms
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
def make_transforms(img_size: int | None = None) -> Tuple[T.Compose, T.Compose]:
|
| 23 |
+
"""Return (train_transform, eval_transform) PIL-space pipelines.
|
| 24 |
+
|
| 25 |
+
All transforms produce a PIL grayscale image of size (img_size, img_size).
|
| 26 |
+
The downstream Dataset converts it to a single-channel xrv-normalised
|
| 27 |
+
tensor in [-1024, 1024].
|
| 28 |
+
|
| 29 |
+
Train pipeline: small affine, mild jitter, light hflip; random erasing
|
| 30 |
+
happens after xrv normalisation inside the Dataset.
|
| 31 |
+
Eval pipeline: deterministic resize.
|
| 32 |
+
"""
|
| 33 |
+
img_size = img_size if img_size is not None else CFG.img_size
|
| 34 |
+
|
| 35 |
+
train_tf = T.Compose([
|
| 36 |
+
T.Resize((img_size + 16, img_size + 16)),
|
| 37 |
+
T.RandomCrop((img_size, img_size)),
|
| 38 |
+
T.RandomHorizontalFlip(p=0.5),
|
| 39 |
+
T.RandomAffine(
|
| 40 |
+
degrees=8,
|
| 41 |
+
translate=(0.04, 0.04),
|
| 42 |
+
scale=(0.95, 1.05),
|
| 43 |
+
fill=0,
|
| 44 |
+
),
|
| 45 |
+
T.ColorJitter(brightness=0.15, contrast=0.15),
|
| 46 |
+
])
|
| 47 |
+
|
| 48 |
+
eval_tf = T.Compose([
|
| 49 |
+
T.Resize((img_size, img_size)),
|
| 50 |
+
])
|
| 51 |
+
|
| 52 |
+
return train_tf, eval_tf
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# ---------------------------------------------------------------------------
|
| 56 |
+
# Test-time augmentation (TTA) transforms
|
| 57 |
+
# ---------------------------------------------------------------------------
|
| 58 |
+
def make_tta_transforms(img_size: int | None = None) -> List[T.Compose]:
|
| 59 |
+
"""Six deterministic PIL-space transforms.
|
| 60 |
+
|
| 61 |
+
All end with a resized PIL image ready for xrv_normalize_np().
|
| 62 |
+
Predictions are averaged across all passes (in logit space) inside
|
| 63 |
+
`tta_predict` / `tta_predict_ensemble`.
|
| 64 |
+
"""
|
| 65 |
+
img_size = img_size if img_size is not None else CFG.img_size
|
| 66 |
+
size = (img_size, img_size)
|
| 67 |
+
|
| 68 |
+
return [
|
| 69 |
+
T.Compose([T.Resize(size)]),
|
| 70 |
+
T.Compose([T.Resize(size), T.Lambda(_pil_hflip)]),
|
| 71 |
+
T.Compose([T.Resize((img_size + 20, img_size + 20)), T.CenterCrop(size)]),
|
| 72 |
+
T.Compose([T.Resize((img_size - 20, img_size - 20)),
|
| 73 |
+
T.Pad(10, fill=0), T.CenterCrop(size)]),
|
| 74 |
+
T.Compose([T.Resize(size),
|
| 75 |
+
T.RandomAffine(degrees=(6, 6), fill=0)]),
|
| 76 |
+
T.Compose([T.Resize(size),
|
| 77 |
+
T.RandomAffine(degrees=(-6, -6), fill=0)]),
|
| 78 |
+
]
|
model_training/src/utils.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import dataclasses
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from typing import TYPE_CHECKING
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import pandas as pd
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
if TYPE_CHECKING:
|
| 14 |
+
from src.config import Config
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def set_seed(seed: int) -> None:
|
| 18 |
+
"""Set all relevant random seeds for reproducibility."""
|
| 19 |
+
random.seed(seed)
|
| 20 |
+
np.random.seed(seed)
|
| 21 |
+
torch.manual_seed(seed)
|
| 22 |
+
if torch.cuda.is_available():
|
| 23 |
+
torch.cuda.manual_seed_all(seed)
|
| 24 |
+
elif torch.backends.mps.is_available():
|
| 25 |
+
torch.mps.manual_seed(seed)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def free_device_cache(device: str) -> None:
|
| 29 |
+
"""Release unused memory on GPU / MPS (useful between seeds / Optuna trials)."""
|
| 30 |
+
if device == "mps":
|
| 31 |
+
torch.mps.empty_cache()
|
| 32 |
+
elif device == "cuda":
|
| 33 |
+
torch.cuda.empty_cache()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def log_run(
|
| 37 |
+
model_name: str,
|
| 38 |
+
val_metrics: dict,
|
| 39 |
+
test_metrics: dict,
|
| 40 |
+
config: "Config",
|
| 41 |
+
n_seeds: int,
|
| 42 |
+
log_path: str = "results_log.csv",
|
| 43 |
+
) -> pd.DataFrame:
|
| 44 |
+
"""Append one training run to the global results log CSV.
|
| 45 |
+
|
| 46 |
+
Creates the file with a header if it does not exist yet, otherwise appends.
|
| 47 |
+
|
| 48 |
+
Columns
|
| 49 |
+
βββββββ
|
| 50 |
+
run_id, model_name, created_at,
|
| 51 |
+
<all Config fields except device/csv_path/image_dir/submission_test_dir/output_dir>,
|
| 52 |
+
n_seeds,
|
| 53 |
+
val_* (auc, sensitivity, specificity, accuracy, youden, composite,
|
| 54 |
+
precision, tp, tn, fp, fn, threshold)
|
| 55 |
+
test_* (same set)
|
| 56 |
+
|
| 57 |
+
Parameters
|
| 58 |
+
ββββββββββ
|
| 59 |
+
model_name : human-readable name for this run (e.g. "xrv_densenet_ensemble")
|
| 60 |
+
val_metrics : dict returned by metrics_at_threshold on the validation split
|
| 61 |
+
test_metrics : dict returned by metrics_at_threshold on the test split
|
| 62 |
+
config : the Config instance used for this run
|
| 63 |
+
n_seeds : number of seeds in the ensemble
|
| 64 |
+
log_path : path to the CSV results log (created if missing)
|
| 65 |
+
"""
|
| 66 |
+
cfg_dict = dataclasses.asdict(config)
|
| 67 |
+
|
| 68 |
+
# exclude path/device fields β not meaningful for comparison
|
| 69 |
+
skip = {"csv_path", "image_dir", "submission_test_dir", "output_dir", "device"}
|
| 70 |
+
hyperparams = {k: v for k, v in cfg_dict.items() if k not in skip}
|
| 71 |
+
|
| 72 |
+
# `seeds` is a list β flatten to a string so the CSV stays readable
|
| 73 |
+
if "seeds" in hyperparams and isinstance(hyperparams["seeds"], list):
|
| 74 |
+
hyperparams["seeds"] = ",".join(str(s) for s in hyperparams["seeds"])
|
| 75 |
+
|
| 76 |
+
row: dict = {
|
| 77 |
+
"run_id": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
| 78 |
+
"model_name": model_name,
|
| 79 |
+
"created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
| 80 |
+
"n_seeds": n_seeds,
|
| 81 |
+
**hyperparams,
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
for prefix, metrics in [("val", val_metrics), ("test", test_metrics)]:
|
| 85 |
+
for key, value in metrics.items():
|
| 86 |
+
if key != "threshold":
|
| 87 |
+
row[f"{prefix}_{key}"] = value
|
| 88 |
+
row[f"{prefix}_threshold"] = metrics.get("threshold", float("nan"))
|
| 89 |
+
|
| 90 |
+
new_row_df = pd.DataFrame([row])
|
| 91 |
+
|
| 92 |
+
if os.path.exists(log_path):
|
| 93 |
+
log_df = pd.read_csv(log_path)
|
| 94 |
+
log_df = pd.concat([log_df, new_row_df], ignore_index=True)
|
| 95 |
+
else:
|
| 96 |
+
log_df = new_row_df
|
| 97 |
+
|
| 98 |
+
log_df.to_csv(log_path, index=False)
|
| 99 |
+
print(f"Run logged β {log_path} ({len(log_df)} total runs)")
|
| 100 |
+
return log_df
|