Commit af59988 committed by GitHub Actions
1 Parent(s): e642110
Auto-deploy from GitHub: 495db78a06be79166200269bb14d9e9b1e8906d6
Files changed:
- Dockerfile +3 -2
- api/Dockerfile +24 -0
- api/__init__.py +9 -0
- api/main.py +245 -0
- api/schemas.py +72 -0
- requirements.txt +43 -5
- src/__init__.py +6 -0
- src/config.py +112 -0
- src/dataset.py +201 -0
- src/evaluate.py +107 -0
- src/export.py +190 -0
- src/gradcam.py +137 -0
- src/model.py +87 -0
- src/predict.py +47 -0
- src/train.py +250 -0
- src/utils.py +74 -0
Dockerfile CHANGED
@@ -13,11 +13,12 @@ COPY requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt
 
 # Copy app files
-COPY
+COPY src/ src/
+COPY api/ api/
 COPY models/ models/
 
 # Expose port 7860 (HF Spaces default)
 EXPOSE 7860
 
 # Run the API
-CMD ["uvicorn", "
+CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "7860"]
api/Dockerfile ADDED
@@ -0,0 +1,24 @@
+FROM python:3.11-slim
+
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    libgl1 \
+    libglib2.0-0 \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy requirements first for caching
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy app files
+COPY src/ src/
+COPY api/ api/
+COPY models/ models/
+
+# Expose port 7860 (HF Spaces default)
+EXPOSE 7860
+
+# Run the API
+CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "7860"]
api/__init__.py ADDED
@@ -0,0 +1,9 @@
+"""FastAPI application for Pneumonia Detection."""
+
+from .main import app
+from .schemas import (
+    HealthResponse,
+    PredictionResponse,
+    GradCAMResponse,
+    ErrorResponse
+)
api/main.py ADDED
@@ -0,0 +1,245 @@
+"""
+FastAPI application for Pneumonia Detection API.
+
+Run with: uvicorn api.main:app --reload
+"""
+
+import io
+import time
+import base64
+from pathlib import Path
+
+import torch
+from PIL import Image
+from fastapi import FastAPI, UploadFile, File, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+
+from .schemas import (
+    HealthResponse,
+    PredictionResponse,
+    GradCAMResponse,
+    ErrorResponse
+)
+
+import sys
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from src.config import CHECKPOINT_PATH, CLASS_NAMES, CONFIDENCE_THRESHOLD
+from src.model import create_model, get_device
+from src.predict import load_model, predict_image
+from src.gradcam import generate_gradcam
+
+# =============================================================================
+# App Configuration
+# =============================================================================
+
+app = FastAPI(
+    title="Pneumonia Detection API",
+    description="Deep learning API for detecting pneumonia from chest X-ray images using EfficientNet-B0",
+    version="1.0.0",
+    docs_url="/docs",
+    redoc_url="/redoc"
+)
+
+# CORS middleware for frontend access
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Configure appropriately for production
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# =============================================================================
+# Model Loading (on startup)
+# =============================================================================
+
+model = None
+device = None
+
+
+@app.on_event("startup")
+async def load_model_on_startup():
+    """Load model when the API starts."""
+    global model, device
+
+    device = get_device()
+    print(f"Using device: {device}")
+
+    if not CHECKPOINT_PATH.exists():
+        print(f"Warning: Model checkpoint not found at {CHECKPOINT_PATH}")
+        return
+
+    model = create_model(pretrained=False, freeze_backbone=False, device=device)
+    model = load_model(model, CHECKPOINT_PATH, device)
+    print(f"Model loaded from {CHECKPOINT_PATH}")
+
+
+# =============================================================================
+# Helper Functions
+# =============================================================================
+
+ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png"}
+
+
+def validate_image(file: UploadFile) -> None:
+    """Validate uploaded image file."""
+    if not file.content_type.startswith("image/"):
+        raise HTTPException(
+            status_code=400,
+            detail=f"Invalid content type: {file.content_type}. Expected image/*"
+        )
+
+    ext = Path(file.filename).suffix.lower() if file.filename else ""
+    if ext not in ALLOWED_EXTENSIONS:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Invalid file extension: {ext}. Allowed: {ALLOWED_EXTENSIONS}"
+        )
+
+
+async def read_image(file: UploadFile) -> Image.Image:
+    """Read uploaded file as PIL Image."""
+    try:
+        contents = await file.read()
+        image = Image.open(io.BytesIO(contents)).convert("RGB")
+        return image
+    except Exception as e:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Failed to read image: {str(e)}"
+        )
+
+
+# =============================================================================
+# API Endpoints
+# =============================================================================
+
+@app.get("/", include_in_schema=False)
+async def root():
+    """Redirect to docs."""
+    return {"message": "Pneumonia Detection API", "docs": "/docs"}
+
+
+@app.get("/health", response_model=HealthResponse, tags=["Health"])
+async def health_check():
+    """
+    Health check endpoint.
+
+    Returns the API status and model loading state.
+    """
+    return HealthResponse(
+        status="healthy" if model is not None else "model_not_loaded",
+        model_loaded=model is not None,
+        model_path=str(CHECKPOINT_PATH)
+    )
+
+
+@app.post(
+    "/predict",
+    response_model=PredictionResponse,
+    responses={400: {"model": ErrorResponse}, 503: {"model": ErrorResponse}},
+    tags=["Prediction"]
+)
+async def predict(file: UploadFile = File(..., description="Chest X-ray image (JPEG/PNG)")):
+    """
+    Predict pneumonia from chest X-ray image.
+
+    Upload a chest X-ray image and get the prediction (NORMAL or PNEUMONIA)
+    with confidence score.
+    """
+    if model is None:
+        raise HTTPException(status_code=503, detail="Model not loaded")
+
+    validate_image(file)
+    image = await read_image(file)
+
+    # Run inference
+    start_time = time.time()
+    pred_class, confidence = predict_image(model, image, device)
+    processing_time = (time.time() - start_time) * 1000  # Convert to ms
+
+    # Calculate raw probability
+    probability = confidence if pred_class == "PNEUMONIA" else 1 - confidence
+
+    return PredictionResponse(
+        prediction=pred_class,
+        confidence=confidence,
+        probability=probability,
+        processing_time_ms=round(processing_time, 2)
+    )
+
+
+@app.post(
+    "/predict/gradcam",
+    response_model=GradCAMResponse,
+    responses={400: {"model": ErrorResponse}, 503: {"model": ErrorResponse}},
+    tags=["Prediction"]
+)
+async def predict_with_gradcam(file: UploadFile = File(..., description="Chest X-ray image (JPEG/PNG)")):
+    """
+    Predict with Grad-CAM visualization.
+
+    Returns prediction along with a Grad-CAM heatmap overlay showing
+    which regions of the image influenced the prediction.
+    """
+    if model is None:
+        raise HTTPException(status_code=503, detail="Model not loaded")
+
+    validate_image(file)
+    image = await read_image(file)
+
+    # Run inference with Grad-CAM
+    start_time = time.time()
+    cam_image, pred_class, confidence, _ = generate_gradcam(model, image, device)
+    processing_time = (time.time() - start_time) * 1000
+
+    # Convert Grad-CAM image to base64
+    cam_pil = Image.fromarray(cam_image)
+    buffer = io.BytesIO()
+    cam_pil.save(buffer, format="PNG")
+    buffer.seek(0)
+    img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+    # Calculate raw probability
+    probability = confidence if pred_class == "PNEUMONIA" else 1 - confidence
+
+    return GradCAMResponse(
+        prediction=pred_class,
+        confidence=confidence,
+        probability=probability,
+        processing_time_ms=round(processing_time, 2),
+        gradcam_image=f"data:image/png;base64,{img_base64}"
+    )
+
+
+# =============================================================================
+# Error Handlers
+# =============================================================================
+
+@app.exception_handler(HTTPException)
+async def http_exception_handler(request, exc):
+    """Handle HTTP exceptions."""
+    return JSONResponse(
+        status_code=exc.status_code,
+        content={"error": exc.detail, "detail": None}
+    )
+
+
+@app.exception_handler(Exception)
+async def general_exception_handler(request, exc):
+    """Handle unexpected exceptions."""
+    return JSONResponse(
+        status_code=500,
+        content={"error": "Internal server error", "detail": str(exc)}
+    )
+
+
+# =============================================================================
+# Run with: uvicorn api.main:app --reload --host 0.0.0.0 --port 8000
+# =============================================================================
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
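
As a quick smoke test of the new endpoints, a short client script works; this is a minimal sketch, assuming the API is running locally on port 8000 and that "xray.jpeg" is a stand-in for a real chest X-ray file (both are illustrative, not part of the commit):

    # Minimal client sketch; URL and file name are assumptions.
    import requests

    with open("xray.jpeg", "rb") as f:
        resp = requests.post(
            "http://localhost:8000/predict",
            files={"file": ("xray.jpeg", f, "image/jpeg")},
        )
    # Illustrative output shape, per PredictionResponse:
    # {"prediction": "PNEUMONIA", "confidence": 0.92, ...}
    print(resp.json())
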
api/schemas.py ADDED
@@ -0,0 +1,72 @@
+"""
+Pydantic models for API request/response validation.
+"""
+
+from pydantic import BaseModel, Field
+from typing import Optional
+from enum import Enum
+
+
+class ClassLabel(str, Enum):
+    """Prediction class labels."""
+    NORMAL = "NORMAL"
+    PNEUMONIA = "PNEUMONIA"
+
+
+class HealthResponse(BaseModel):
+    """Health check response."""
+    status: str = Field(..., example="healthy")
+    model_loaded: bool = Field(..., example=True)
+    model_path: str = Field(..., example="models/best_model.pt")
+
+
+class PredictionResponse(BaseModel):
+    """Prediction response."""
+    prediction: ClassLabel = Field(..., description="Predicted class")
+    confidence: float = Field(..., ge=0, le=1, description="Confidence score")
+    probability: float = Field(..., ge=0, le=1, description="Raw probability for PNEUMONIA")
+    processing_time_ms: float = Field(..., description="Inference time in milliseconds")
+
+    class Config:
+        json_schema_extra = {
+            "example": {
+                "prediction": "PNEUMONIA",
+                "confidence": 0.92,
+                "probability": 0.92,
+                "processing_time_ms": 45.2
+            }
+        }
+
+
+class GradCAMResponse(BaseModel):
+    """Prediction with Grad-CAM visualization."""
+    prediction: ClassLabel = Field(..., description="Predicted class")
+    confidence: float = Field(..., ge=0, le=1, description="Confidence score")
+    probability: float = Field(..., ge=0, le=1, description="Raw probability for PNEUMONIA")
+    processing_time_ms: float = Field(..., description="Inference time in milliseconds")
+    gradcam_image: str = Field(..., description="Base64 encoded Grad-CAM overlay image")
+
+    class Config:
+        json_schema_extra = {
+            "example": {
+                "prediction": "PNEUMONIA",
+                "confidence": 0.92,
+                "probability": 0.92,
+                "processing_time_ms": 150.5,
+                "gradcam_image": "data:image/png;base64,..."
+            }
+        }
+
+
+class ErrorResponse(BaseModel):
+    """Error response."""
+    error: str = Field(..., description="Error message")
+    detail: Optional[str] = Field(None, description="Detailed error information")
+
+    class Config:
+        json_schema_extra = {
+            "example": {
+                "error": "Invalid image format",
+                "detail": "Supported formats: JPEG, PNG"
+            }
+        }
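
For reference, these response models validate like any Pydantic model; a quick sketch with invented values (model_dump assumes Pydantic v2, which matches the json_schema_extra keys used above):

    # Sketch of the response models in action; all values are invented.
    from api.schemas import PredictionResponse, ClassLabel

    resp = PredictionResponse(
        prediction=ClassLabel.PNEUMONIA,
        confidence=0.92,
        probability=0.92,
        processing_time_ms=45.2,
    )
    print(resp.model_dump())  # Pydantic v2; use .dict() on v1
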
requirements.txt CHANGED
@@ -1,7 +1,45 @@
-
-
-
-uvicorn>=0.23.0
-python-multipart>=0.0.6
+# Core Deep Learning
+torch>=2.1.0
+torchvision>=0.16.0
 pillow>=10.0.0
 numpy>=1.24.0
+
+# Data Analysis & Visualization
+pandas>=2.0.0
+matplotlib>=3.7.0
+seaborn>=0.12.0
+
+# Experiment Tracking
+wandb>=0.15.0
+
+# Model Interpretability
+grad-cam>=1.4.0
+
+# API
+fastapi>=0.104.0
+uvicorn>=0.24.0
+python-multipart>=0.0.6
+
+# Web UI
+streamlit>=1.28.0
+
+# Testing
+pytest>=7.4.0
+
+# Code Quality
+black>=23.0.0
+ruff>=0.1.0
+
+# Jupyter
+jupyterlab>=4.0.0
+ipywidgets>=8.0.0
+
+# Utilities
+python-dotenv>=1.0.0
+tqdm>=4.66.0
+scikit-learn>=1.3.0
+
+# ONNX Export
+onnx>=1.15.0
+onnxruntime>=1.16.0
+onnxscript>=0.1.0
src/__init__.py ADDED
@@ -0,0 +1,6 @@
+"""
+Pneumonia Detection from Chest X-Rays
+Medical Image Classification using Deep Learning
+"""
+
+__version__ = "0.1.0"
src/config.py ADDED
@@ -0,0 +1,112 @@
+"""
+Configuration constants for the Pneumonia Detection project.
+All hyperparameters and paths are defined here for easy modification.
+"""
+
+from pathlib import Path
+
+# =============================================================================
+# Project Paths
+# =============================================================================
+PROJECT_ROOT = Path(__file__).parent.parent
+DATA_DIR = PROJECT_ROOT / "data" / "raw"
+PROCESSED_DIR = PROJECT_ROOT / "data" / "processed"
+MODEL_DIR = PROJECT_ROOT / "models"
+OUTPUT_DIR = PROJECT_ROOT / "outputs"
+FIGURES_DIR = OUTPUT_DIR / "figures"
+LOGS_DIR = OUTPUT_DIR / "logs"
+
+# =============================================================================
+# Data Configuration
+# =============================================================================
+IMAGE_SIZE = 224  # EfficientNet-B0 input size
+BATCH_SIZE = 32
+NUM_WORKERS = 4  # DataLoader workers
+
+# ImageNet normalization (required for pretrained models)
+IMAGENET_MEAN = [0.485, 0.456, 0.406]
+IMAGENET_STD = [0.229, 0.224, 0.225]
+
+# Class labels
+CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
+NUM_CLASSES = 1  # Binary classification with sigmoid
+
+# =============================================================================
+# Model Configuration
+# =============================================================================
+MODEL_NAME = "efficientnet_b0"
+DROPOUT_RATE = 0.3
+PRETRAINED = True
+
+# =============================================================================
+# Training Configuration - Stage 1 (Frozen Backbone)
+# =============================================================================
+STAGE1_EPOCHS = 5
+STAGE1_LR = 1e-4
+STAGE1_FREEZE_BACKBONE = True
+
+# =============================================================================
+# Training Configuration - Stage 2 (Fine-tuning)
+# =============================================================================
+STAGE2_EPOCHS = 15
+STAGE2_LR = 1e-5
+STAGE2_FREEZE_BACKBONE = False
+
+# =============================================================================
+# Optimizer Configuration
+# =============================================================================
+WEIGHT_DECAY = 1e-4
+BETAS = (0.9, 0.999)
+
+# =============================================================================
+# Scheduler Configuration
+# =============================================================================
+SCHEDULER_PATIENCE = 3
+SCHEDULER_FACTOR = 0.5
+SCHEDULER_MIN_LR = 1e-7
+
+# =============================================================================
+# Early Stopping Configuration
+# =============================================================================
+EARLY_STOP_PATIENCE = 7
+EARLY_STOP_MIN_DELTA = 0.001
+
+# =============================================================================
+# Model Checkpointing
+# =============================================================================
+CHECKPOINT_PATH = MODEL_DIR / "best_model.pt"
+SAVE_BEST_ONLY = True
+MONITOR_METRIC = "val_loss"
+
+# =============================================================================
+# Weights & Biases Configuration
+# =============================================================================
+WANDB_PROJECT = "pneumonia-detection"
+WANDB_ENTITY = None  # Set to your W&B username if needed
+
+# =============================================================================
+# Inference Configuration
+# =============================================================================
+CONFIDENCE_THRESHOLD = 0.5  # For binary classification
+GRADCAM_TARGET_LAYER = "features"  # EfficientNet feature extractor
+
+# =============================================================================
+# Random Seed (for reproducibility)
+# =============================================================================
+SEED = 42
+
+
+def create_directories():
+    """Create all necessary directories if they don't exist."""
+    for directory in [DATA_DIR, PROCESSED_DIR, MODEL_DIR, FIGURES_DIR, LOGS_DIR]:
+        directory.mkdir(parents=True, exist_ok=True)
+
+
+if __name__ == "__main__":
+    # Print configuration for verification
+    print(f"Project Root: {PROJECT_ROOT}")
+    print(f"Data Directory: {DATA_DIR}")
+    print(f"Model Directory: {MODEL_DIR}")
+    print(f"Image Size: {IMAGE_SIZE}")
+    print(f"Batch Size: {BATCH_SIZE}")
+    print(f"Model: {MODEL_NAME}")
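
Since everything here is a module-level constant, downstream modules just import what they need; a minimal sketch:

    # Sketch: direct imports, plus the idempotent directory setup.
    from src.config import CHECKPOINT_PATH, IMAGE_SIZE, create_directories

    create_directories()                 # builds data/model/output dirs if missing
    print(CHECKPOINT_PATH, IMAGE_SIZE)   # .../models/best_model.pt 224
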
src/dataset.py ADDED
@@ -0,0 +1,201 @@
+"""
+PyTorch Dataset and DataLoader utilities for Chest X-Ray classification.
+"""
+
+from pathlib import Path
+from typing import Tuple, Optional, List
+import random
+
+import torch
+from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
+from torchvision import transforms
+from PIL import Image
+from sklearn.model_selection import train_test_split
+
+from .config import (
+    DATA_DIR, IMAGE_SIZE, BATCH_SIZE, NUM_WORKERS,
+    IMAGENET_MEAN, IMAGENET_STD, CLASS_NAMES, SEED
+)
+
+
+class ChestXRayDataset(Dataset):
+    """Dataset for Chest X-Ray images."""
+
+    def __init__(
+        self,
+        image_paths: List[Path],
+        labels: List[int],
+        transform: Optional[transforms.Compose] = None
+    ):
+        self.image_paths = image_paths
+        self.labels = labels
+        self.transform = transform
+
+    def __len__(self) -> int:
+        return len(self.image_paths)
+
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
+        img_path = self.image_paths[idx]
+        label = self.labels[idx]
+
+        # Load image and convert to RGB
+        image = Image.open(img_path).convert('RGB')
+
+        if self.transform:
+            image = self.transform(image)
+
+        return image, label
+
+
+def get_transforms(is_training: bool = True) -> transforms.Compose:
+    """Get image transforms for training or validation/test."""
+    if is_training:
+        return transforms.Compose([
+            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
+            transforms.RandomHorizontalFlip(p=0.5),
+            transforms.RandomRotation(10),
+            transforms.ColorJitter(brightness=0.2, contrast=0.2),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
+        ])
+    else:
+        return transforms.Compose([
+            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
+        ])
+
+
+def load_image_paths_and_labels(
+    data_dir: Path,
+    split: str
+) -> Tuple[List[Path], List[int]]:
+    """Load image paths and labels from a data split directory."""
+    image_paths = []
+    labels = []
+
+    for class_idx, class_name in enumerate(CLASS_NAMES):
+        class_dir = data_dir / split / class_name
+        if class_dir.exists():
+            for img_path in class_dir.glob('*.jpeg'):
+                image_paths.append(img_path)
+                labels.append(class_idx)
+
+    return image_paths, labels
+
+
+def create_train_val_split(
+    data_dir: Path = DATA_DIR,
+    val_ratio: float = 0.15,
+    seed: int = SEED
+) -> Tuple[List[Path], List[int], List[Path], List[int]]:
+    """Create stratified train/val split from training data."""
+    # Load all training images
+    train_paths, train_labels = load_image_paths_and_labels(data_dir, 'train')
+
+    # Stratified split
+    train_paths, val_paths, train_labels, val_labels = train_test_split(
+        train_paths, train_labels,
+        test_size=val_ratio,
+        stratify=train_labels,
+        random_state=seed
+    )
+
+    return train_paths, train_labels, val_paths, val_labels
+
+
+def get_class_weights(labels: List[int]) -> torch.Tensor:
+    """Calculate class weights for imbalanced dataset."""
+    class_counts = torch.bincount(torch.tensor(labels))
+    total = len(labels)
+    weights = total / (len(class_counts) * class_counts.float())
+    return weights
+
+
+def get_sampler(labels: List[int]) -> WeightedRandomSampler:
+    """Create weighted sampler for balanced batches."""
+    class_weights = get_class_weights(labels)
+    sample_weights = [class_weights[label] for label in labels]
+    sampler = WeightedRandomSampler(
+        weights=sample_weights,
+        num_samples=len(labels),
+        replacement=True
+    )
+    return sampler
+
+
+def get_dataloaders(
+    data_dir: Path = DATA_DIR,
+    batch_size: int = BATCH_SIZE,
+    num_workers: int = NUM_WORKERS,
+    val_ratio: float = 0.15,
+    use_weighted_sampling: bool = True
+) -> Tuple[DataLoader, DataLoader, DataLoader]:
+    """Create train, validation, and test DataLoaders."""
+
+    # Create train/val split
+    train_paths, train_labels, val_paths, val_labels = create_train_val_split(
+        data_dir, val_ratio
+    )
+
+    # Load test data
+    test_paths, test_labels = load_image_paths_and_labels(data_dir, 'test')
+
+    # Create datasets
+    train_dataset = ChestXRayDataset(
+        train_paths, train_labels, transform=get_transforms(is_training=True)
+    )
+    val_dataset = ChestXRayDataset(
+        val_paths, val_labels, transform=get_transforms(is_training=False)
+    )
+    test_dataset = ChestXRayDataset(
+        test_paths, test_labels, transform=get_transforms(is_training=False)
+    )
+
+    # Create sampler for training if using weighted sampling
+    train_sampler = get_sampler(train_labels) if use_weighted_sampling else None
+
+    # Only use pin_memory for CUDA (not supported on MPS)
+    pin_memory = torch.cuda.is_available()
+
+    # Create dataloaders
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        sampler=train_sampler,
+        shuffle=(train_sampler is None),
+        num_workers=num_workers,
+        pin_memory=pin_memory
+    )
+
+    val_loader = DataLoader(
+        val_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=num_workers,
+        pin_memory=pin_memory
+    )
+
+    test_loader = DataLoader(
+        test_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=num_workers,
+        pin_memory=pin_memory
+    )
+
+    # Print dataset info
+    print(f"Train: {len(train_dataset)} images")
+    print(f"Val: {len(val_dataset)} images")
+    print(f"Test: {len(test_dataset)} images")
+
+    return train_loader, val_loader, test_loader
+
+
+def get_pos_weight(labels: List[int]) -> torch.Tensor:
+    """Calculate pos_weight for BCEWithLogitsLoss to handle class imbalance."""
+    labels_tensor = torch.tensor(labels)
+    neg_count = (labels_tensor == 0).sum().float()  # NORMAL
+    pos_count = (labels_tensor == 1).sum().float()  # PNEUMONIA
+    pos_weight = neg_count / pos_count
+    return pos_weight
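
To show how these pieces compose, here is a minimal sketch, assuming the data/raw/{train,test}/{NORMAL,PNEUMONIA} layout the loaders expect is already in place:

    # Sketch: loaders plus the imbalance-aware loss the helpers support.
    import torch.nn as nn
    from src.config import DATA_DIR
    from src.dataset import get_dataloaders, load_image_paths_and_labels, get_pos_weight

    train_loader, val_loader, test_loader = get_dataloaders()

    # pos_weight up-weights the minority class in BCEWithLogitsLoss
    # (move it to the training device before use).
    _, train_labels = load_image_paths_and_labels(DATA_DIR, "train")
    criterion = nn.BCEWithLogitsLoss(pos_weight=get_pos_weight(train_labels))
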
src/evaluate.py ADDED
@@ -0,0 +1,107 @@
+"""
+Evaluation functions for Pneumonia classification.
+"""
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+import numpy as np
+from typing import Dict, Tuple
+from sklearn.metrics import (
+    accuracy_score, precision_score, recall_score, f1_score,
+    roc_auc_score, confusion_matrix, classification_report
+)
+
+from .config import CLASS_NAMES
+
+
+def predict_proba(
+    model: nn.Module,
+    loader: DataLoader,
+    device: torch.device
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """Get predictions, probabilities, and true labels."""
+    model.eval()
+    all_probs, all_preds, all_labels = [], [], []
+
+    with torch.no_grad():
+        for images, labels in loader:
+            images = images.to(device)
+            outputs = model(images)
+            probs = torch.sigmoid(outputs).cpu().numpy()
+            preds = (probs > 0.5).astype(int)
+
+            all_probs.extend(probs.flatten())
+            all_preds.extend(preds.flatten())
+            all_labels.extend(labels.numpy())
+
+    return np.array(all_probs), np.array(all_preds), np.array(all_labels)
+
+
+def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_proba: np.ndarray) -> Dict:
+    """Compute all evaluation metrics."""
+    return {
+        'accuracy': accuracy_score(y_true, y_pred),
+        'precision': precision_score(y_true, y_pred),
+        'recall': recall_score(y_true, y_pred),
+        'f1': f1_score(y_true, y_pred),
+        'roc_auc': roc_auc_score(y_true, y_proba),
+        'confusion_matrix': confusion_matrix(y_true, y_pred)
+    }
+
+
+def evaluate_model(
+    model: nn.Module,
+    loader: DataLoader,
+    device: torch.device
+) -> Dict:
+    """Full evaluation on a dataset."""
+    probs, preds, labels = predict_proba(model, loader, device)
+    metrics = compute_metrics(labels, preds, probs)
+
+    print("=" * 50)
+    print("EVALUATION RESULTS")
+    print("=" * 50)
+    print(f"Accuracy: {metrics['accuracy']:.4f}")
+    print(f"Precision: {metrics['precision']:.4f}")
+    print(f"Recall: {metrics['recall']:.4f}")
+    print(f"F1 Score: {metrics['f1']:.4f}")
+    print(f"ROC-AUC: {metrics['roc_auc']:.4f}")
+    print("\nConfusion Matrix:")
+    print(f" {CLASS_NAMES[0]:>10} {CLASS_NAMES[1]:>10}")
+    for i, row in enumerate(metrics['confusion_matrix']):
+        print(f" {CLASS_NAMES[i]:>10} {row[0]:>10} {row[1]:>10}")
+
+    print("\nClassification Report:")
+    print(classification_report(labels, preds, target_names=CLASS_NAMES))
+
+    return metrics
+
+
+def get_predictions_with_paths(
+    model: nn.Module,
+    dataset,
+    device: torch.device
+) -> list:
+    """Get predictions with image paths for error analysis."""
+    model.eval()
+    results = []
+
+    with torch.no_grad():
+        for idx in range(len(dataset)):
+            image, label = dataset[idx]
+            image = image.unsqueeze(0).to(device)
+
+            output = model(image)
+            prob = torch.sigmoid(output).item()
+            pred = 1 if prob > 0.5 else 0
+
+            results.append({
+                'path': dataset.image_paths[idx],
+                'true_label': label,
+                'pred_label': pred,
+                'probability': prob,
+                'correct': pred == label
+            })
+
+    return results
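
A minimal end-to-end evaluation sketch, assuming a trained checkpoint already exists at models/best_model.pt:

    # Sketch: score the held-out test split with the helpers above.
    from src.dataset import get_dataloaders
    from src.evaluate import evaluate_model
    from src.model import create_model, get_device
    from src.predict import load_model

    device = get_device()
    model = create_model(pretrained=False, freeze_backbone=False, device=device)
    model = load_model(model, device=device)  # loads models/best_model.pt
    _, _, test_loader = get_dataloaders()
    metrics = evaluate_model(model, test_loader, device)
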
src/export.py ADDED
@@ -0,0 +1,190 @@
+"""
+ONNX export utilities for model deployment.
+
+ONNX (Open Neural Network Exchange) is a universal format that allows
+models to run on different frameworks and platforms:
+- TensorFlow, PyTorch, etc.
+- Mobile devices (iOS, Android)
+- Web browsers (ONNX.js)
+- C++, Java, and other languages
+- Optimized inference servers
+"""
+
+import torch
+import numpy as np
+from pathlib import Path
+from typing import Tuple, Optional
+
+from .config import CHECKPOINT_PATH, MODEL_DIR, IMAGE_SIZE
+from .model import create_model, get_device
+
+
+def export_to_onnx(
+    checkpoint_path: Path = CHECKPOINT_PATH,
+    output_path: Optional[Path] = None,
+    opset_version: int = 18
+) -> Path:
+    """
+    Export PyTorch model to ONNX format.
+
+    Args:
+        checkpoint_path: Path to the PyTorch checkpoint
+        output_path: Path for the ONNX model (default: models/best_model.onnx)
+        opset_version: ONNX opset version (14 is widely compatible)
+
+    Returns:
+        Path to the exported ONNX model
+    """
+    if output_path is None:
+        output_path = MODEL_DIR / "best_model.onnx"
+
+    # Load model
+    device = torch.device("cpu")  # Export on CPU for compatibility
+    model = create_model(pretrained=False, freeze_backbone=False, device=device)
+
+    checkpoint = torch.load(checkpoint_path, map_location=device)
+    model.load_state_dict(checkpoint['model_state_dict'])
+    model.eval()
+
+    # Create dummy input (batch_size=1, channels=3, height=224, width=224)
+    dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE)
+
+    # Export to ONNX
+    torch.onnx.export(
+        model,
+        dummy_input,
+        output_path,
+        export_params=True,
+        opset_version=opset_version,
+        do_constant_folding=True,  # Optimize constants
+        input_names=['image'],
+        output_names=['logits'],
+        dynamic_axes={
+            'image': {0: 'batch_size'},  # Variable batch size
+            'logits': {0: 'batch_size'}
+        }
+    )
+
+    print(f"Model exported to: {output_path}")
+    print(f"File size: {output_path.stat().st_size / 1024 / 1024:.2f} MB")
+
+    return output_path
+
+
+def validate_onnx_model(
+    onnx_path: Path,
+    checkpoint_path: Path = CHECKPOINT_PATH,
+    rtol: float = 1e-3,
+    atol: float = 1e-5
+) -> bool:
+    """
+    Validate that ONNX model produces same outputs as PyTorch model.
+
+    Args:
+        onnx_path: Path to ONNX model
+        checkpoint_path: Path to PyTorch checkpoint
+        rtol: Relative tolerance for comparison
+        atol: Absolute tolerance for comparison
+
+    Returns:
+        True if outputs match, False otherwise
+    """
+    import onnx
+    import onnxruntime as ort
+
+    # Check ONNX model is valid
+    onnx_model = onnx.load(onnx_path)
+    onnx.checker.check_model(onnx_model)
+    print("ONNX model structure is valid")
+
+    # Load PyTorch model
+    device = torch.device("cpu")
+    model = create_model(pretrained=False, freeze_backbone=False, device=device)
+    checkpoint = torch.load(checkpoint_path, map_location=device)
+    model.load_state_dict(checkpoint['model_state_dict'])
+    model.eval()
+
+    # Create test input
+    test_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE)
+
+    # Get PyTorch output
+    with torch.no_grad():
+        pytorch_output = model(test_input).numpy()
+
+    # Get ONNX output
+    ort_session = ort.InferenceSession(str(onnx_path))
+    onnx_output = ort_session.run(
+        None,
+        {'image': test_input.numpy()}
+    )[0]
+
+    # Compare outputs
+    is_close = np.allclose(pytorch_output, onnx_output, rtol=rtol, atol=atol)
+
+    if is_close:
+        print("Validation PASSED: ONNX outputs match PyTorch outputs")
+        print(f" PyTorch output: {pytorch_output.flatten()[:5]}...")
+        print(f" ONNX output: {onnx_output.flatten()[:5]}...")
+    else:
+        print("Validation FAILED: Outputs do not match!")
+        print(f" Max difference: {np.max(np.abs(pytorch_output - onnx_output))}")
+
+    return is_close
+
+
+def predict_with_onnx(
+    onnx_path: Path,
+    image_tensor: np.ndarray
+) -> Tuple[str, float]:
+    """
+    Run inference using ONNX Runtime.
+
+    Args:
+        onnx_path: Path to ONNX model
+        image_tensor: Preprocessed image as numpy array (1, 3, 224, 224)
+
+    Returns:
+        Tuple of (predicted_class, confidence)
+    """
+    import onnxruntime as ort
+    from .config import CLASS_NAMES
+
+    # Create session
+    ort_session = ort.InferenceSession(str(onnx_path))
+
+    # Run inference
+    logits = ort_session.run(
+        None,
+        {'image': image_tensor.astype(np.float32)}
+    )[0]
+
+    # Apply sigmoid and get prediction
+    prob = 1 / (1 + np.exp(-logits[0, 0]))  # Sigmoid
+    pred_class = CLASS_NAMES[1] if prob > 0.5 else CLASS_NAMES[0]
+    confidence = float(prob if prob > 0.5 else 1 - prob)
+
+    return pred_class, confidence
+
+
+if __name__ == "__main__":
+    # Export model
+    print("=" * 50)
+    print("EXPORTING MODEL TO ONNX")
+    print("=" * 50)
+
+    onnx_path = export_to_onnx()
+
+    print("\n" + "=" * 50)
+    print("VALIDATING ONNX MODEL")
+    print("=" * 50)
+
+    validate_onnx_model(onnx_path)
+
+    print("\n" + "=" * 50)
+    print("TESTING ONNX INFERENCE")
+    print("=" * 50)
+
+    # Test with random input
+    test_input = np.random.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE).astype(np.float32)
+    pred_class, confidence = predict_with_onnx(onnx_path, test_input)
+    print(f"Test prediction: {pred_class} ({confidence:.1%})")
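
The __main__ block above already chains these steps (run as `python -m src.export`); the programmatic equivalent is short, again assuming the PyTorch checkpoint exists:

    # Sketch: export, then check ONNX Runtime agrees with PyTorch.
    from src.export import export_to_onnx, validate_onnx_model

    onnx_path = export_to_onnx()    # writes models/best_model.onnx
    validate_onnx_model(onnx_path)  # prints PASSED/FAILED, returns bool
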
src/gradcam.py ADDED
@@ -0,0 +1,137 @@
+"""
+Grad-CAM visualization for model interpretability.
+"""
+
+import torch
+import numpy as np
+from PIL import Image
+from pathlib import Path
+from typing import Union
+import matplotlib.pyplot as plt
+
+from pytorch_grad_cam import GradCAM
+from pytorch_grad_cam.utils.image import show_cam_on_image
+
+from .dataset import get_transforms
+from .config import IMAGENET_MEAN, IMAGENET_STD, CLASS_NAMES
+
+
+def get_gradcam(model, target_layer=None):
+    """Create GradCAM object for the model."""
+    if target_layer is None:
+        # Use the last conv layer of EfficientNet
+        target_layer = model.backbone.features[-1]
+    return GradCAM(model=model, target_layers=[target_layer])
+
+
+def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
+    """Denormalize tensor to numpy image [0,1]."""
+    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
+    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
+    img = tensor.cpu() * std + mean
+    img = img.permute(1, 2, 0).numpy()
+    return np.clip(img, 0, 1)
+
+
+def generate_gradcam(
+    model,
+    image: Union[str, Path, Image.Image],
+    device: torch.device
+) -> tuple:
+    """Generate Grad-CAM heatmap for an image."""
+    model.eval()
+
+    # Load and transform image
+    if isinstance(image, (str, Path)):
+        image = Image.open(image).convert('RGB')
+
+    transform = get_transforms(is_training=False)
+    img_tensor = transform(image).unsqueeze(0).to(device)
+
+    # Get prediction
+    with torch.no_grad():
+        output = model(img_tensor)
+        prob = torch.sigmoid(output).item()
+
+    pred_class = CLASS_NAMES[1] if prob > 0.5 else CLASS_NAMES[0]
+    confidence = prob if prob > 0.5 else 1 - prob
+
+    # Generate Grad-CAM
+    cam = get_gradcam(model)
+    grayscale_cam = cam(input_tensor=img_tensor, targets=None)[0]
+
+    # Create visualization
+    rgb_img = denormalize_image(img_tensor[0])
+    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
+
+    return cam_image, pred_class, confidence, rgb_img
+
+
+def plot_gradcam(
+    model,
+    image_path: Union[str, Path],
+    true_label: str,
+    device: torch.device,
+    save_path: str = None
+):
+    """Plot original image with Grad-CAM overlay."""
+    cam_image, pred_class, confidence, original = generate_gradcam(model, image_path, device)
+
+    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
+
+    # Original
+    axes[0].imshow(original)
+    axes[0].set_title(f"Original\nTrue: {true_label}")
+    axes[0].axis('off')
+
+    # Grad-CAM
+    color = 'green' if pred_class == true_label else 'red'
+    axes[1].imshow(cam_image)
+    axes[1].set_title(f"Grad-CAM\nPred: {pred_class} ({confidence:.1%})", color=color)
+    axes[1].axis('off')
+
+    plt.tight_layout()
+
+    if save_path:
+        plt.savefig(save_path, dpi=150, bbox_inches='tight')
+
+    plt.show()
+    return pred_class, confidence
+
+
+def plot_gradcam_grid(
+    model,
+    image_paths: list,
+    true_labels: list,
+    device: torch.device,
+    save_path: str = None,
+    title: str = "Grad-CAM Visualizations"
+):
+    """Plot grid of Grad-CAM visualizations."""
+    n = len(image_paths)
+    fig, axes = plt.subplots(n, 2, figsize=(8, 3 * n))
+
+    if n == 1:
+        axes = axes.reshape(1, -1)
+
+    for i, (path, true_label) in enumerate(zip(image_paths, true_labels)):
+        cam_image, pred_class, confidence, original = generate_gradcam(model, path, device)
+
+        # Original
+        axes[i, 0].imshow(original)
+        axes[i, 0].set_title(f"True: {true_label}")
+        axes[i, 0].axis('off')
+
+        # Grad-CAM
+        color = 'green' if pred_class == true_label else 'red'
+        axes[i, 1].imshow(cam_image)
+        axes[i, 1].set_title(f"Pred: {pred_class} ({confidence:.1%})", color=color)
+        axes[i, 1].axis('off')
+
+    plt.suptitle(title, fontsize=14, fontweight='bold')
+    plt.tight_layout()
+
+    if save_path:
+        plt.savefig(save_path, dpi=150, bbox_inches='tight')
+
+    plt.show()
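
A minimal usage sketch for a single image; the image and save paths are illustrative, and a trained checkpoint is assumed:

    # Sketch: side-by-side original + Grad-CAM overlay for one test image.
    from src.gradcam import plot_gradcam
    from src.model import create_model, get_device
    from src.predict import load_model

    device = get_device()
    model = create_model(pretrained=False, freeze_backbone=False, device=device)
    model = load_model(model, device=device)
    plot_gradcam(
        model,
        "data/raw/test/PNEUMONIA/example.jpeg",  # hypothetical file
        true_label="PNEUMONIA",
        device=device,
        save_path="outputs/figures/gradcam_example.png",
    )
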
src/model.py ADDED
@@ -0,0 +1,87 @@
+"""
+EfficientNet-B0 model for Pneumonia classification.
+"""
+
+import torch
+import torch.nn as nn
+from torchvision import models
+from typing import Tuple
+
+from .config import DROPOUT_RATE, NUM_CLASSES
+
+
+class PneumoniaClassifier(nn.Module):
+    """EfficientNet-B0 based classifier for chest X-ray pneumonia detection."""
+
+    def __init__(
+        self,
+        pretrained: bool = True,
+        dropout_rate: float = DROPOUT_RATE,
+        freeze_backbone: bool = True
+    ):
+        super().__init__()
+
+        # Load pretrained EfficientNet-B0
+        weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
+        self.backbone = models.efficientnet_b0(weights=weights)
+
+        # Get the number of features from the classifier
+        in_features = self.backbone.classifier[1].in_features  # 1280
+
+        # Replace classifier head
+        self.backbone.classifier = nn.Sequential(
+            nn.Dropout(p=dropout_rate, inplace=True),
+            nn.Linear(in_features, NUM_CLASSES)
+        )
+
+        # Freeze backbone if specified
+        if freeze_backbone:
+            self.freeze_backbone()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.backbone(x)
+
+    def freeze_backbone(self):
+        """Freeze all layers except the classifier."""
+        for param in self.backbone.features.parameters():
+            param.requires_grad = False
+
+    def unfreeze_backbone(self):
+        """Unfreeze all layers for fine-tuning."""
+        for param in self.backbone.features.parameters():
+            param.requires_grad = True
+
+    def get_param_counts(self) -> Tuple[int, int]:
+        """Return (trainable_params, total_params)."""
+        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
+        total = sum(p.numel() for p in self.parameters())
+        return trainable, total
+
+
+def create_model(
+    pretrained: bool = True,
+    dropout_rate: float = DROPOUT_RATE,
+    freeze_backbone: bool = True,
+    device: str = None
+) -> PneumoniaClassifier:
+    """Factory function to create the model."""
+    if device is None:
+        device = "mps" if torch.backends.mps.is_available() else \
+            "cuda" if torch.cuda.is_available() else "cpu"
+
+    model = PneumoniaClassifier(
+        pretrained=pretrained,
+        dropout_rate=dropout_rate,
+        freeze_backbone=freeze_backbone
+    )
+
+    return model.to(device)
+
+
+def get_device() -> torch.device:
+    """Get the best available device."""
+    if torch.backends.mps.is_available():
+        return torch.device("mps")
+    elif torch.cuda.is_available():
+        return torch.device("cuda")
+    return torch.device("cpu")
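
The freeze/unfreeze pair is what the two training stages in src/config.py toggle; a quick sketch of how the trainable-parameter count changes (downloads ImageNet weights on first run):

    # Sketch: stage 1 trains only the head, stage 2 fine-tunes everything.
    from src.model import create_model

    model = create_model(pretrained=True, freeze_backbone=True)
    print(model.get_param_counts())  # (trainable, total) with frozen features

    model.unfreeze_backbone()
    print(model.get_param_counts())  # all parameters trainable for stage 2
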
src/predict.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Inference functions for Pneumonia classification.
"""

import torch
import torch.nn as nn
from PIL import Image
from pathlib import Path
from typing import Union, Tuple

from .dataset import get_transforms
from .config import CLASS_NAMES, CHECKPOINT_PATH


def load_model(model: nn.Module, checkpoint_path: Path = CHECKPOINT_PATH, device: str = "cpu") -> nn.Module:
    """Load model from checkpoint."""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model


def predict_image(
    model: nn.Module,
    image: Union[str, Path, Image.Image],
    device: torch.device
) -> Tuple[str, float]:
    """Predict class for a single image."""
    model.eval()

    # Load image if path
    if isinstance(image, (str, Path)):
        image = Image.open(image).convert('RGB')

    # Transform
    transform = get_transforms(is_training=False)
    img_tensor = transform(image).unsqueeze(0).to(device)

    # Predict
    with torch.no_grad():
        output = model(img_tensor)
        prob = torch.sigmoid(output).item()

    pred_class = CLASS_NAMES[1] if prob > 0.5 else CLASS_NAMES[0]
    confidence = prob if prob > 0.5 else 1 - prob

    return pred_class, confidence
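Editor's note: a hedged inference sketch, not part of the commit. The image path is a placeholder, and the checkpoint defaults to CHECKPOint_PATH from src/config.py. Because prob > 0.5 maps to CLASS_NAMES[1], the returned confidence is the probability of the predicted class and is therefore always at least 0.5.

# Sketch: single-image prediction against a saved checkpoint.
import torch
from src.model import create_model
from src.predict import load_model, predict_image

device = torch.device("cpu")
# pretrained=False: skip the ImageNet download, weights come from the checkpoint.
model = create_model(pretrained=False, freeze_backbone=False, device="cpu")
model = load_model(model, device="cpu")   # reads CHECKPOINT_PATH by default

# "chest_xray.jpg" is a placeholder path.
label, confidence = predict_image(model, "chest_xray.jpg", device)
print(f"{label} ({confidence:.1%})")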
src/train.py
ADDED
@@ -0,0 +1,250 @@
"""
Training pipeline for Pneumonia classification.
"""

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from pathlib import Path
from typing import Dict, Optional, Tuple
import time

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

from .config import (
    STAGE1_EPOCHS, STAGE1_LR, STAGE2_EPOCHS, STAGE2_LR,
    WEIGHT_DECAY, SCHEDULER_PATIENCE, SCHEDULER_FACTOR,
    EARLY_STOP_PATIENCE, CHECKPOINT_PATH, MODEL_DIR
)
from .model import PneumoniaClassifier, get_device


class EarlyStopping:
    """Early stopping to prevent overfitting."""

    def __init__(self, patience: int = 7, min_delta: float = 0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.should_stop = False

    def __call__(self, val_loss: float) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        return self.should_stop


def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device
) -> Tuple[float, float]:
    """Train for one epoch."""
    model.train()
    total_loss = 0
    all_preds, all_labels = [], []

    for images, labels in loader:
        images = images.to(device)
        labels = labels.float().unsqueeze(1).to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * images.size(0)
        preds = (torch.sigmoid(outputs) > 0.5).int()
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    avg_loss = total_loss / len(loader.dataset)
    accuracy = accuracy_score(all_labels, all_preds)
    return avg_loss, accuracy


def validate(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: torch.device
) -> Dict[str, float]:
    """Validate the model."""
    model.eval()
    total_loss = 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.float().unsqueeze(1).to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item() * images.size(0)
            preds = (torch.sigmoid(outputs) > 0.5).int()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    avg_loss = total_loss / len(loader.dataset)

    return {
        'loss': avg_loss,
        'accuracy': accuracy_score(all_labels, all_preds),
        'precision': precision_score(all_labels, all_preds, zero_division=0),
        'recall': recall_score(all_labels, all_preds, zero_division=0),
        'f1': f1_score(all_labels, all_preds, zero_division=0)
    }


def train(
    model: PneumoniaClassifier,
    train_loader: DataLoader,
    val_loader: DataLoader,
    pos_weight: torch.Tensor,
    epochs: int,
    lr: float,
    device: torch.device,
    stage: str = "stage1",
    use_wandb: bool = True,
    wandb_run=None
) -> Dict[str, list]:
    """Training loop with validation."""

    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))
    optimizer = AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=lr,
        weight_decay=WEIGHT_DECAY
    )
    scheduler = ReduceLROnPlateau(
        optimizer, mode='min',
        patience=SCHEDULER_PATIENCE,
        factor=SCHEDULER_FACTOR
    )
    early_stopping = EarlyStopping(patience=EARLY_STOP_PATIENCE)

    history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'val_f1': [], 'lr': []}
    best_val_loss = float('inf')

    MODEL_DIR.mkdir(parents=True, exist_ok=True)

    for epoch in range(epochs):
        start = time.time()

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_metrics = validate(model, val_loader, criterion, device)

        # Get current LR
        current_lr = optimizer.param_groups[0]['lr']

        # Update scheduler
        scheduler.step(val_metrics['loss'])

        # Log
        elapsed = time.time() - start
        print(f"[{stage}] Epoch {epoch+1}/{epochs} ({elapsed:.1f}s) | "
              f"Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_metrics['loss']:.4f} | "
              f"Val Acc: {val_metrics['accuracy']:.3f} | "
              f"Val F1: {val_metrics['f1']:.3f} | "
              f"LR: {current_lr:.2e}")

        # W&B logging
        if use_wandb and wandb_run:
            wandb_run.log({
                f"{stage}/train_loss": train_loss,
                f"{stage}/train_acc": train_acc,
                f"{stage}/val_loss": val_metrics['loss'],
                f"{stage}/val_acc": val_metrics['accuracy'],
                f"{stage}/val_precision": val_metrics['precision'],
                f"{stage}/val_recall": val_metrics['recall'],
                f"{stage}/val_f1": val_metrics['f1'],
                f"{stage}/lr": current_lr,
                "epoch": epoch + 1
            })

        # Save history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_metrics['loss'])
        history['val_acc'].append(val_metrics['accuracy'])
        history['val_f1'].append(val_metrics['f1'])
        history['lr'].append(current_lr)

        # Save best model
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': best_val_loss,
                'val_metrics': val_metrics
            }, CHECKPOINT_PATH)
            print(f" -> Saved best model (val_loss: {best_val_loss:.4f})")

        # Early stopping
        if early_stopping(val_metrics['loss']):
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

    return history


def train_two_stage(
    model: PneumoniaClassifier,
    train_loader: DataLoader,
    val_loader: DataLoader,
    pos_weight: torch.Tensor,
    device: torch.device,
    use_wandb: bool = True,
    wandb_run=None
) -> Dict[str, list]:
    """Two-stage training: frozen backbone then fine-tuning."""

    # Stage 1: Train classifier only
    print("\n" + "=" * 60)
    print("STAGE 1: Training classifier (backbone frozen)")
    print("=" * 60)
    model.freeze_backbone()
    trainable, total = model.get_param_counts()
    print(f"Trainable params: {trainable:,} / {total:,}")

    history1 = train(
        model, train_loader, val_loader, pos_weight,
        epochs=STAGE1_EPOCHS, lr=STAGE1_LR, device=device,
        stage="stage1", use_wandb=use_wandb, wandb_run=wandb_run
    )

    # Stage 2: Fine-tune entire network
    print("\n" + "=" * 60)
    print("STAGE 2: Fine-tuning entire network")
    print("=" * 60)
    model.unfreeze_backbone()
    trainable, total = model.get_param_counts()
    print(f"Trainable params: {trainable:,} / {total:,}")

    history2 = train(
        model, train_loader, val_loader, pos_weight,
        epochs=STAGE2_EPOCHS, lr=STAGE2_LR, device=device,
        stage="stage2", use_wandb=use_wandb, wandb_run=wandb_run
    )

    # Combine histories
    history = {k: history1[k] + history2[k] for k in history1}
    return history
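Editor's note: a sketch of how train_two_stage might be driven end to end, not part of the commit. The random TensorDataset stands in for the real chest X-ray loaders built in src/dataset.py, whose API is outside this hunk; the pos_weight formula (negatives over positives) is the standard class-imbalance weighting for nn.BCEWithLogitsLoss.

# Sketch: wire up two-stage training with an imbalance-aware loss.
import torch
from torch.utils.data import DataLoader, TensorDataset

from src.model import create_model, get_device
from src.train import train_two_stage

device = get_device()
model = create_model(freeze_backbone=True, device=str(device))

# Dummy stand-in data: 32 RGB images at 224x224 with binary labels.
images = torch.randn(32, 3, 224, 224)
labels = torch.randint(0, 2, (32,))
train_loader = DataLoader(TensorDataset(images, labels), batch_size=8)
val_loader = DataLoader(TensorDataset(images, labels), batch_size=8)

# pos_weight = n_negative / n_positive; values > 1 upweight the positive
# (pneumonia) class whenever it is the minority.
n_neg = (labels == 0).sum().item()
n_pos = (labels == 1).sum().item()
pos_weight = torch.tensor([n_neg / max(n_pos, 1)])

history = train_two_stage(
    model, train_loader, val_loader, pos_weight,
    device=device, use_wandb=False
)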
src/utils.py
ADDED
@@ -0,0 +1,74 @@
"""
Utility functions for visualization and helpers.
"""

import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional

from .config import IMAGENET_MEAN, IMAGENET_STD, CLASS_NAMES


def denormalize(tensor: torch.Tensor) -> torch.Tensor:
    """Denormalize image tensor from ImageNet normalization."""
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    return tensor * std + mean


def show_batch(
    images: torch.Tensor,
    labels: torch.Tensor,
    predictions: Optional[torch.Tensor] = None,
    n_images: int = 8,
    save_path: Optional[str] = None
):
    """Display a batch of images with labels."""
    n_images = min(n_images, len(images))
    cols = 4
    rows = (n_images + cols - 1) // cols

    fig, axes = plt.subplots(rows, cols, figsize=(12, 3 * rows))
    axes = axes.flatten() if rows > 1 else [axes] if cols == 1 else axes

    for idx in range(n_images):
        img = denormalize(images[idx]).permute(1, 2, 0).numpy()
        img = np.clip(img, 0, 1)

        axes[idx].imshow(img)
        axes[idx].axis('off')

        label = CLASS_NAMES[labels[idx]]
        title = f"True: {label}"

        if predictions is not None:
            pred = CLASS_NAMES[predictions[idx]]
            color = 'green' if pred == label else 'red'
            title += f"\nPred: {pred}"
            axes[idx].set_title(title, color=color, fontsize=10)
        else:
            axes[idx].set_title(title, fontsize=10)

    # Hide empty subplots
    for idx in range(n_images, len(axes)):
        axes[idx].axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    plt.show()


def set_seed(seed: int = 42):
    """Set random seed for reproducibility."""
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
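Editor's note: a usage sketch for src/utils.py, not part of the commit. The random tensors stand in for one real normalized DataLoader batch, and batch_preview.png is a placeholder output path.

# Sketch: seed everything, then preview a batch with predictions.
import torch
from src.utils import set_seed, show_batch

set_seed(42)  # seeds python, numpy, torch (and cuda/mps if available)

# Dummy batch standing in for (images, labels) from a DataLoader.
images = torch.randn(8, 3, 224, 224)
labels = torch.randint(0, 2, (8,))
preds = torch.randint(0, 2, (8,))

# Titles are green where pred == label, red otherwise.
show_batch(images, labels, predictions=preds, n_images=8,
           save_path="batch_preview.png")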