GitHub Actions committed on
Commit
af59988
·
1 Parent(s): e642110

Auto-deploy from GitHub: 495db78a06be79166200269bb14d9e9b1e8906d6

Browse files
Files changed (16) hide show
  1. Dockerfile +3 -2
  2. api/Dockerfile +24 -0
  3. api/__init__.py +9 -0
  4. api/main.py +245 -0
  5. api/schemas.py +72 -0
  6. requirements.txt +43 -5
  7. src/__init__.py +6 -0
  8. src/config.py +112 -0
  9. src/dataset.py +201 -0
  10. src/evaluate.py +107 -0
  11. src/export.py +190 -0
  12. src/gradcam.py +137 -0
  13. src/model.py +87 -0
  14. src/predict.py +47 -0
  15. src/train.py +250 -0
  16. src/utils.py +74 -0
Dockerfile CHANGED
@@ -13,11 +13,12 @@ COPY requirements.txt .
13
  RUN pip install --no-cache-dir -r requirements.txt
14
 
15
  # Copy app files
16
- COPY app.py .
 
17
  COPY models/ models/
18
 
19
  # Expose port 7860 (HF Spaces default)
20
  EXPOSE 7860
21
 
22
  # Run the API
23
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
 
13
  RUN pip install --no-cache-dir -r requirements.txt
14
 
15
  # Copy app files
16
+ COPY src/ src/
17
+ COPY api/ api/
18
  COPY models/ models/
19
 
20
  # Expose port 7860 (HF Spaces default)
21
  EXPOSE 7860
22
 
23
  # Run the API
24
+ CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "7860"]
api/Dockerfile ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system dependencies
6
+ RUN apt-get update && apt-get install -y \
7
+ libgl1 \
8
+ libglib2.0-0 \
9
+ && rm -rf /var/lib/apt/lists/*
10
+
11
+ # Copy requirements first for caching
12
+ COPY requirements.txt .
13
+ RUN pip install --no-cache-dir -r requirements.txt
14
+
15
+ # Copy app files
16
+ COPY src/ src/
17
+ COPY api/ api/
18
+ COPY models/ models/
19
+
20
+ # Expose port 7860 (HF Spaces default)
21
+ EXPOSE 7860
22
+
23
+ # Run the API
24
+ CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "7860"]
api/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI application for Pneumonia Detection."""
2
+
3
+ from .main import app
4
+ from .schemas import (
5
+ HealthResponse,
6
+ PredictionResponse,
7
+ GradCAMResponse,
8
+ ErrorResponse
9
+ )
api/main.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI application for Pneumonia Detection API.
3
+
4
+ Run with: uvicorn api.main:app --reload
5
+ """
6
+
7
+ import io
8
+ import time
9
+ import base64
10
+ from pathlib import Path
11
+
12
+ import torch
13
+ from PIL import Image
14
+ from fastapi import FastAPI, UploadFile, File, HTTPException
15
+ from fastapi.middleware.cors import CORSMiddleware
16
+ from fastapi.responses import JSONResponse
17
+
18
+ from .schemas import (
19
+ HealthResponse,
20
+ PredictionResponse,
21
+ GradCAMResponse,
22
+ ErrorResponse
23
+ )
24
+
25
+ import sys
26
+ sys.path.insert(0, str(Path(__file__).parent.parent))
27
+
28
+ from src.config import CHECKPOINT_PATH, CLASS_NAMES, CONFIDENCE_THRESHOLD
29
+ from src.model import create_model, get_device
30
+ from src.predict import load_model, predict_image
31
+ from src.gradcam import generate_gradcam
32
+
33
+ # =============================================================================
34
+ # App Configuration
35
+ # =============================================================================
36
+
37
+ app = FastAPI(
38
+ title="Pneumonia Detection API",
39
+ description="Deep learning API for detecting pneumonia from chest X-ray images using EfficientNet-B0",
40
+ version="1.0.0",
41
+ docs_url="/docs",
42
+ redoc_url="/redoc"
43
+ )
44
+
45
+ # CORS middleware for frontend access
46
+ app.add_middleware(
47
+ CORSMiddleware,
48
+ allow_origins=["*"], # Configure appropriately for production
49
+ allow_credentials=True,
50
+ allow_methods=["*"],
51
+ allow_headers=["*"],
52
+ )
53
+
54
+ # =============================================================================
55
+ # Model Loading (on startup)
56
+ # =============================================================================
57
+
58
# Globals populated by the startup hook; they stay None when no checkpoint
# is present so endpoints can answer 503 instead of crashing.
model = None
device = None


@app.on_event("startup")
async def load_model_on_startup():
    """Load model when the API starts."""
    global model, device

    device = get_device()
    print(f"Using device: {device}")

    if CHECKPOINT_PATH.exists():
        net = create_model(pretrained=False, freeze_backbone=False, device=device)
        model = load_model(net, CHECKPOINT_PATH, device)
        print(f"Model loaded from {CHECKPOINT_PATH}")
    else:
        # Keep serving (health endpoint reports the missing checkpoint).
        print(f"Warning: Model checkpoint not found at {CHECKPOINT_PATH}")
77
+
78
+
79
+ # =============================================================================
80
+ # Helper Functions
81
+ # =============================================================================
82
+
83
ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def validate_image(file: UploadFile) -> None:
    """Validate an uploaded image file.

    Raises:
        HTTPException(400): if the content type is not image/* or the
            filename extension is not one of ALLOWED_EXTENSIONS.
    """
    # content_type can be None (a multipart part may omit it entirely);
    # calling .startswith on it would raise AttributeError -> opaque 500.
    content_type = file.content_type or ""
    if not content_type.startswith("image/"):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid content type: {file.content_type}. Expected image/*"
        )

    ext = Path(file.filename).suffix.lower() if file.filename else ""
    if ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file extension: {ext}. Allowed: {ALLOWED_EXTENSIONS}"
        )
100
+
101
+
102
async def read_image(file: UploadFile) -> Image.Image:
    """Decode an uploaded file into an RGB PIL image.

    Raises:
        HTTPException(400): when the payload cannot be parsed as an image.
    """
    try:
        raw = await file.read()
        # Force RGB so grayscale X-rays match the 3-channel model input.
        return Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to read image: {str(e)}"
        )
113
+
114
+
115
+ # =============================================================================
116
+ # API Endpoints
117
+ # =============================================================================
118
+
119
+ @app.get("/", include_in_schema=False)
120
+ async def root():
121
+ """Redirect to docs."""
122
+ return {"message": "Pneumonia Detection API", "docs": "/docs"}
123
+
124
+
125
+ @app.get("/health", response_model=HealthResponse, tags=["Health"])
126
+ async def health_check():
127
+ """
128
+ Health check endpoint.
129
+
130
+ Returns the API status and model loading state.
131
+ """
132
+ return HealthResponse(
133
+ status="healthy" if model is not None else "model_not_loaded",
134
+ model_loaded=model is not None,
135
+ model_path=str(CHECKPOINT_PATH)
136
+ )
137
+
138
+
139
@app.post(
    "/predict",
    response_model=PredictionResponse,
    responses={400: {"model": ErrorResponse}, 503: {"model": ErrorResponse}},
    tags=["Prediction"]
)
async def predict(file: UploadFile = File(..., description="Chest X-ray image (JPEG/PNG)")):
    """
    Predict pneumonia from chest X-ray image.

    Upload a chest X-ray image and get the prediction (NORMAL or PNEUMONIA)
    with confidence score.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    validate_image(file)
    image = await read_image(file)

    # Time only the model call, not upload parsing/validation.
    t0 = time.time()
    pred_class, confidence = predict_image(model, image, device)
    elapsed_ms = (time.time() - t0) * 1000

    # Map class-conditional confidence back onto P(PNEUMONIA).
    probability = confidence if pred_class == "PNEUMONIA" else 1 - confidence

    return PredictionResponse(
        prediction=pred_class,
        confidence=confidence,
        probability=probability,
        processing_time_ms=round(elapsed_ms, 2)
    )
172
+
173
+
174
@app.post(
    "/predict/gradcam",
    response_model=GradCAMResponse,
    responses={400: {"model": ErrorResponse}, 503: {"model": ErrorResponse}},
    tags=["Prediction"]
)
async def predict_with_gradcam(file: UploadFile = File(..., description="Chest X-ray image (JPEG/PNG)")):
    """
    Predict with Grad-CAM visualization.

    Returns prediction along with a Grad-CAM heatmap overlay showing
    which regions of the image influenced the prediction.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    validate_image(file)
    image = await read_image(file)

    t0 = time.time()
    cam_image, pred_class, confidence, _ = generate_gradcam(model, image, device)
    elapsed_ms = (time.time() - t0) * 1000

    # Encode the overlay as a data-URI PNG so the frontend can drop it
    # straight into an <img> tag.
    buffer = io.BytesIO()
    Image.fromarray(cam_image).save(buffer, format="PNG")
    img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

    probability = confidence if pred_class == "PNEUMONIA" else 1 - confidence

    return GradCAMResponse(
        prediction=pred_class,
        confidence=confidence,
        probability=probability,
        processing_time_ms=round(elapsed_ms, 2),
        gradcam_image=f"data:image/png;base64,{img_base64}"
    )
215
+
216
+
217
+ # =============================================================================
218
+ # Error Handlers
219
+ # =============================================================================
220
+
221
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Serialize HTTPExceptions into the ErrorResponse JSON shape."""
    payload = {"error": exc.detail, "detail": None}
    return JSONResponse(status_code=exc.status_code, content=payload)
228
+
229
+
230
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Last-resort handler: surface unexpected failures as a 500 payload."""
    payload = {"error": "Internal server error", "detail": str(exc)}
    return JSONResponse(status_code=500, content=payload)
237
+
238
+
239
+ # =============================================================================
240
+ # Run with: uvicorn api.main:app --reload --host 0.0.0.0 --port 8000
241
+ # =============================================================================
242
+
243
+ if __name__ == "__main__":
244
+ import uvicorn
245
+ uvicorn.run(app, host="0.0.0.0", port=8000)
api/schemas.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic models for API request/response validation.
3
+ """
4
+
5
+ from pydantic import BaseModel, Field
6
+ from typing import Optional
7
+ from enum import Enum
8
+
9
+
10
class ClassLabel(str, Enum):
    """Closed set of prediction class labels (string-valued for JSON)."""

    NORMAL = "NORMAL"
    PNEUMONIA = "PNEUMONIA"
14
+
15
+
16
class HealthResponse(BaseModel):
    """Health check response."""

    # Fields starting with "model_" collide with Pydantic v2's protected
    # "model_" namespace and emit warnings; clear it for this schema.
    model_config = {"protected_namespaces": ()}

    # The bare `example=` Field kwarg is deprecated in Pydantic v2; use
    # json_schema_extra (consistent with the Config blocks below).
    status: str = Field(..., json_schema_extra={"example": "healthy"})
    model_loaded: bool = Field(..., json_schema_extra={"example": True})
    model_path: str = Field(..., json_schema_extra={"example": "models/best_model.pt"})
21
+
22
+
23
class PredictionResponse(BaseModel):
    """Prediction response returned by POST /predict."""
    prediction: ClassLabel = Field(..., description="Predicted class")
    confidence: float = Field(..., ge=0, le=1, description="Confidence score")
    probability: float = Field(..., ge=0, le=1, description="Raw probability for PNEUMONIA")
    processing_time_ms: float = Field(..., description="Inference time in milliseconds")

    class Config:
        # Example payload surfaced in the OpenAPI schema / Swagger UI.
        json_schema_extra = {
            "example": {
                "prediction": "PNEUMONIA",
                "confidence": 0.92,
                "probability": 0.92,
                "processing_time_ms": 45.2
            }
        }
39
+
40
+
41
class GradCAMResponse(BaseModel):
    """Prediction plus Grad-CAM visualization, returned by POST /predict/gradcam."""
    prediction: ClassLabel = Field(..., description="Predicted class")
    confidence: float = Field(..., ge=0, le=1, description="Confidence score")
    probability: float = Field(..., ge=0, le=1, description="Raw probability for PNEUMONIA")
    processing_time_ms: float = Field(..., description="Inference time in milliseconds")
    # Delivered as a "data:image/png;base64,..." URI for direct embedding.
    gradcam_image: str = Field(..., description="Base64 encoded Grad-CAM overlay image")

    class Config:
        # Example payload surfaced in the OpenAPI schema / Swagger UI.
        json_schema_extra = {
            "example": {
                "prediction": "PNEUMONIA",
                "confidence": 0.92,
                "probability": 0.92,
                "processing_time_ms": 150.5,
                "gradcam_image": "data:image/png;base64,..."
            }
        }
59
+
60
+
61
class ErrorResponse(BaseModel):
    """Error response shape used by the API's exception handlers."""
    error: str = Field(..., description="Error message")
    detail: Optional[str] = Field(None, description="Detailed error information")

    class Config:
        # Example payload surfaced in the OpenAPI schema / Swagger UI.
        json_schema_extra = {
            "example": {
                "error": "Invalid image format",
                "detail": "Supported formats: JPEG, PNG"
            }
        }
requirements.txt CHANGED
@@ -1,7 +1,45 @@
1
- torch>=2.0.0
2
- torchvision>=0.15.0
3
- fastapi>=0.100.0
4
- uvicorn>=0.23.0
5
- python-multipart>=0.0.6
6
  pillow>=10.0.0
7
  numpy>=1.24.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core Deep Learning
2
+ torch>=2.1.0
3
+ torchvision>=0.16.0
 
 
4
  pillow>=10.0.0
5
  numpy>=1.24.0
6
+
7
+ # Data Analysis & Visualization
8
+ pandas>=2.0.0
9
+ matplotlib>=3.7.0
10
+ seaborn>=0.12.0
11
+
12
+ # Experiment Tracking
13
+ wandb>=0.15.0
14
+
15
+ # Model Interpretability
16
+ grad-cam>=1.4.0
17
+
18
+ # API
19
+ fastapi>=0.104.0
20
+ uvicorn>=0.24.0
21
+ python-multipart>=0.0.6
22
+
23
+ # Web UI
24
+ streamlit>=1.28.0
25
+
26
+ # Testing
27
+ pytest>=7.4.0
28
+
29
+ # Code Quality
30
+ black>=23.0.0
31
+ ruff>=0.1.0
32
+
33
+ # Jupyter
34
+ jupyterlab>=4.0.0
35
+ ipywidgets>=8.0.0
36
+
37
+ # Utilities
38
+ python-dotenv>=1.0.0
39
+ tqdm>=4.66.0
40
+ scikit-learn>=1.3.0
41
+
42
+ # ONNX Export
43
+ onnx>=1.15.0
44
+ onnxruntime>=1.16.0
45
+ onnxscript>=0.1.0
src/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """
2
+ Pneumonia Detection from Chest X-Rays
3
+ Medical Image Classification using Deep Learning
4
+ """
5
+
6
+ __version__ = "0.1.0"
src/config.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration constants for the Pneumonia Detection project.
3
+ All hyperparameters and paths are defined here for easy modification.
4
+ """
5
+
6
+ from pathlib import Path
7
+
8
+ # =============================================================================
9
+ # Project Paths
10
+ # =============================================================================
11
+ PROJECT_ROOT = Path(__file__).parent.parent
12
+ DATA_DIR = PROJECT_ROOT / "data" / "raw"
13
+ PROCESSED_DIR = PROJECT_ROOT / "data" / "processed"
14
+ MODEL_DIR = PROJECT_ROOT / "models"
15
+ OUTPUT_DIR = PROJECT_ROOT / "outputs"
16
+ FIGURES_DIR = OUTPUT_DIR / "figures"
17
+ LOGS_DIR = OUTPUT_DIR / "logs"
18
+
19
+ # =============================================================================
20
+ # Data Configuration
21
+ # =============================================================================
22
+ IMAGE_SIZE = 224 # EfficientNet-B0 input size
23
+ BATCH_SIZE = 32
24
+ NUM_WORKERS = 4 # DataLoader workers
25
+
26
+ # ImageNet normalization (required for pretrained models)
27
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
28
+ IMAGENET_STD = [0.229, 0.224, 0.225]
29
+
30
+ # Class labels
31
+ CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
32
+ NUM_CLASSES = 1 # Binary classification with sigmoid
33
+
34
+ # =============================================================================
35
+ # Model Configuration
36
+ # =============================================================================
37
+ MODEL_NAME = "efficientnet_b0"
38
+ DROPOUT_RATE = 0.3
39
+ PRETRAINED = True
40
+
41
+ # =============================================================================
42
+ # Training Configuration - Stage 1 (Frozen Backbone)
43
+ # =============================================================================
44
+ STAGE1_EPOCHS = 5
45
+ STAGE1_LR = 1e-4
46
+ STAGE1_FREEZE_BACKBONE = True
47
+
48
+ # =============================================================================
49
+ # Training Configuration - Stage 2 (Fine-tuning)
50
+ # =============================================================================
51
+ STAGE2_EPOCHS = 15
52
+ STAGE2_LR = 1e-5
53
+ STAGE2_FREEZE_BACKBONE = False
54
+
55
+ # =============================================================================
56
+ # Optimizer Configuration
57
+ # =============================================================================
58
+ WEIGHT_DECAY = 1e-4
59
+ BETAS = (0.9, 0.999)
60
+
61
+ # =============================================================================
62
+ # Scheduler Configuration
63
+ # =============================================================================
64
+ SCHEDULER_PATIENCE = 3
65
+ SCHEDULER_FACTOR = 0.5
66
+ SCHEDULER_MIN_LR = 1e-7
67
+
68
+ # =============================================================================
69
+ # Early Stopping Configuration
70
+ # =============================================================================
71
+ EARLY_STOP_PATIENCE = 7
72
+ EARLY_STOP_MIN_DELTA = 0.001
73
+
74
+ # =============================================================================
75
+ # Model Checkpointing
76
+ # =============================================================================
77
+ CHECKPOINT_PATH = MODEL_DIR / "best_model.pt"
78
+ SAVE_BEST_ONLY = True
79
+ MONITOR_METRIC = "val_loss"
80
+
81
+ # =============================================================================
82
+ # Weights & Biases Configuration
83
+ # =============================================================================
84
+ WANDB_PROJECT = "pneumonia-detection"
85
+ WANDB_ENTITY = None # Set to your W&B username if needed
86
+
87
+ # =============================================================================
88
+ # Inference Configuration
89
+ # =============================================================================
90
+ CONFIDENCE_THRESHOLD = 0.5 # For binary classification
91
+ GRADCAM_TARGET_LAYER = "features" # EfficientNet feature extractor
92
+
93
+ # =============================================================================
94
+ # Random Seed (for reproducibility)
95
+ # =============================================================================
96
+ SEED = 42
97
+
98
+
99
def create_directories():
    """Create all necessary directories if they don't exist."""
    required = (DATA_DIR, PROCESSED_DIR, MODEL_DIR, FIGURES_DIR, LOGS_DIR)
    for path in required:
        # parents=True also creates intermediate dirs (e.g. outputs/ for figures/).
        path.mkdir(parents=True, exist_ok=True)
103
+
104
+
105
+ if __name__ == "__main__":
106
+ # Print configuration for verification
107
+ print(f"Project Root: {PROJECT_ROOT}")
108
+ print(f"Data Directory: {DATA_DIR}")
109
+ print(f"Model Directory: {MODEL_DIR}")
110
+ print(f"Image Size: {IMAGE_SIZE}")
111
+ print(f"Batch Size: {BATCH_SIZE}")
112
+ print(f"Model: {MODEL_NAME}")
src/dataset.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch Dataset and DataLoader utilities for Chest X-Ray classification.
3
+ """
4
+
5
+ from pathlib import Path
6
+ from typing import Tuple, Optional, List
7
+ import random
8
+
9
+ import torch
10
+ from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
11
+ from torchvision import transforms
12
+ from PIL import Image
13
+ from sklearn.model_selection import train_test_split
14
+
15
+ from .config import (
16
+ DATA_DIR, IMAGE_SIZE, BATCH_SIZE, NUM_WORKERS,
17
+ IMAGENET_MEAN, IMAGENET_STD, CLASS_NAMES, SEED
18
+ )
19
+
20
+
21
class ChestXRayDataset(Dataset):
    """Dataset for Chest X-Ray images.

    Holds parallel lists of image paths and integer labels; images are
    decoded lazily in __getitem__ and optionally transformed.
    """

    def __init__(
        self,
        image_paths: List[Path],
        labels: List[int],
        transform: Optional[transforms.Compose] = None
    ):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        # X-rays are often single-channel; force RGB for the 3-channel
        # pretrained backbone.
        image = Image.open(self.image_paths[idx]).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx]
48
+
49
+
50
def get_transforms(is_training: bool = True) -> transforms.Compose:
    """Get image transforms for training or validation/test."""
    resize = transforms.Resize((IMAGE_SIZE, IMAGE_SIZE))
    finalize = [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    if not is_training:
        return transforms.Compose([resize] + finalize)

    # Light augmentation that keeps the image plausibly a chest X-ray.
    # NOTE(review): horizontal flips mirror anatomy (heart side) — confirm
    # this is intended for this dataset.
    augment = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
    ]
    return transforms.Compose([resize] + augment + finalize)
67
+
68
+
69
def load_image_paths_and_labels(
    data_dir: Path,
    split: str
) -> Tuple[List[Path], List[int]]:
    """Load image paths and labels from a data split directory.

    Scans data_dir/<split>/<class_name> for each class in CLASS_NAMES; the
    label is the class index. Accepts .jpeg/.jpg/.png (case-insensitive)
    instead of only '*.jpeg', and sorts paths so dataset order is
    deterministic across filesystems (glob order is not guaranteed).
    """
    allowed_exts = {".jpeg", ".jpg", ".png"}
    image_paths: List[Path] = []
    labels: List[int] = []

    for class_idx, class_name in enumerate(CLASS_NAMES):
        class_dir = data_dir / split / class_name
        if not class_dir.exists():
            # Missing class directory -> simply contributes no samples.
            continue
        found = sorted(
            p for p in class_dir.iterdir()
            if p.is_file() and p.suffix.lower() in allowed_exts
        )
        image_paths.extend(found)
        labels.extend([class_idx] * len(found))

    return image_paths, labels
85
+
86
+
87
def create_train_val_split(
    data_dir: Path = DATA_DIR,
    val_ratio: float = 0.15,
    seed: int = SEED
) -> Tuple[List[Path], List[int], List[Path], List[int]]:
    """Create stratified train/val split from training data."""
    all_paths, all_labels = load_image_paths_and_labels(data_dir, 'train')

    # Stratify on labels so NORMAL/PNEUMONIA ratios match across the splits.
    tr_paths, va_paths, tr_labels, va_labels = train_test_split(
        all_paths, all_labels,
        test_size=val_ratio,
        stratify=all_labels,
        random_state=seed
    )

    return tr_paths, tr_labels, va_paths, va_labels
105
+
106
+
107
def get_class_weights(labels: List[int]) -> torch.Tensor:
    """Calculate class weights for imbalanced dataset.

    Weight for class c is N / (K * count_c): inversely proportional to
    class frequency, so a perfectly balanced dataset yields all-1.0 weights.
    """
    counts = torch.bincount(torch.tensor(labels)).float()
    return len(labels) / (counts.numel() * counts)
113
+
114
+
115
def get_sampler(labels: List[int]) -> WeightedRandomSampler:
    """Create weighted sampler for balanced batches.

    Each sample is drawn with its class weight, so minority-class images
    are oversampled (with replacement) toward balanced batches.
    """
    per_class = get_class_weights(labels)
    per_sample = [per_class[y] for y in labels]
    return WeightedRandomSampler(
        weights=per_sample,
        num_samples=len(labels),
        replacement=True
    )
125
+
126
+
127
def get_dataloaders(
    data_dir: Path = DATA_DIR,
    batch_size: int = BATCH_SIZE,
    num_workers: int = NUM_WORKERS,
    val_ratio: float = 0.15,
    use_weighted_sampling: bool = True
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Create train, validation, and test DataLoaders.

    Args:
        data_dir: Root directory containing train/ and test/ split folders.
        batch_size: Batch size shared by all three loaders.
        num_workers: DataLoader worker processes.
        val_ratio: Fraction of training data held out for validation.
        use_weighted_sampling: Oversample the minority class in training batches.

    Returns:
        (train_loader, val_loader, test_loader)
    """
    # Stratified split of the train folder; test has its own folder.
    train_paths, train_labels, val_paths, val_labels = create_train_val_split(
        data_dir, val_ratio
    )
    test_paths, test_labels = load_image_paths_and_labels(data_dir, 'test')

    # Augmentation only on the training set.
    train_dataset = ChestXRayDataset(
        train_paths, train_labels, transform=get_transforms(is_training=True)
    )
    val_dataset = ChestXRayDataset(
        val_paths, val_labels, transform=get_transforms(is_training=False)
    )
    test_dataset = ChestXRayDataset(
        test_paths, test_labels, transform=get_transforms(is_training=False)
    )

    train_sampler = get_sampler(train_labels) if use_weighted_sampling else None

    # Only use pin_memory for CUDA (not supported on MPS)
    pin_memory = torch.cuda.is_available()

    def _loader(dataset, *, sampler=None, shuffle=False) -> DataLoader:
        """Build a DataLoader with the shared batch/worker/pin settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory
        )

    # shuffle and sampler are mutually exclusive in DataLoader.
    train_loader = _loader(
        train_dataset,
        sampler=train_sampler,
        shuffle=(train_sampler is None)
    )
    val_loader = _loader(val_dataset)
    test_loader = _loader(test_dataset)

    print(f"Train: {len(train_dataset)} images")
    print(f"Val: {len(val_dataset)} images")
    print(f"Test: {len(test_dataset)} images")

    return train_loader, val_loader, test_loader
193
+
194
+
195
def get_pos_weight(labels: List[int]) -> torch.Tensor:
    """Calculate pos_weight for BCEWithLogitsLoss to handle class imbalance.

    pos_weight = (#negatives / #positives) up-weights the positive
    (PNEUMONIA) term of the loss when positives are the minority class.
    """
    y = torch.tensor(labels)
    negatives = (y == 0).sum().float()  # NORMAL
    positives = (y == 1).sum().float()  # PNEUMONIA
    return negatives / positives
src/evaluate.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluation functions for Pneumonia classification.
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.utils.data import DataLoader
8
+ import numpy as np
9
+ from typing import Dict, Tuple
10
+ from sklearn.metrics import (
11
+ accuracy_score, precision_score, recall_score, f1_score,
12
+ roc_auc_score, confusion_matrix, classification_report
13
+ )
14
+
15
+ from .config import CLASS_NAMES
16
+
17
+
18
def predict_proba(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    threshold: float = 0.5
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Get predictions, probabilities, and true labels.

    Args:
        model: Binary classifier emitting one logit per sample.
        loader: Yields (images, labels) batches.
        device: Device to run inference on.
        threshold: Decision threshold on the sigmoid probability; default
            0.5 keeps the previous hard-coded behavior.

    Returns:
        (probabilities, hard predictions, true labels) as 1-D numpy arrays.
    """
    model.eval()
    all_probs, all_preds, all_labels = [], [], []

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            outputs = model(images)
            probs = torch.sigmoid(outputs).cpu().numpy()
            preds = (probs > threshold).astype(int)

            all_probs.extend(probs.flatten())
            all_preds.extend(preds.flatten())
            all_labels.extend(labels.numpy())

    return np.array(all_probs), np.array(all_preds), np.array(all_labels)
39
+
40
+
41
def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_proba: np.ndarray) -> Dict:
    """Compute all evaluation metrics."""
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred),
        'recall': recall_score(y_true, y_pred),
        'f1': f1_score(y_true, y_pred),
        # ROC-AUC uses the continuous scores, not the thresholded labels.
        'roc_auc': roc_auc_score(y_true, y_proba),
        'confusion_matrix': confusion_matrix(y_true, y_pred),
    }
    return metrics
51
+
52
+
53
def evaluate_model(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device
) -> Dict:
    """Full evaluation on a dataset.

    Runs inference over `loader`, prints a metrics summary (plus confusion
    matrix and sklearn classification report) and returns the dict produced
    by compute_metrics.
    """
    probs, preds, labels = predict_proba(model, loader, device)
    metrics = compute_metrics(labels, preds, probs)

    print("=" * 50)
    print("EVALUATION RESULTS")
    print("=" * 50)
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1 Score: {metrics['f1']:.4f}")
    print(f"ROC-AUC: {metrics['roc_auc']:.4f}")
    print("\nConfusion Matrix:")
    # Header row: predicted classes; each row below: true class counts.
    print(f" {CLASS_NAMES[0]:>10} {CLASS_NAMES[1]:>10}")
    for i, row in enumerate(metrics['confusion_matrix']):
        print(f" {CLASS_NAMES[i]:>10} {row[0]:>10} {row[1]:>10}")

    print("\nClassification Report:")
    print(classification_report(labels, preds, target_names=CLASS_NAMES))

    return metrics
79
+
80
+
81
def get_predictions_with_paths(
    model: nn.Module,
    dataset,
    device: torch.device
) -> list:
    """Get predictions with image paths for error analysis.

    Iterates the dataset one sample at a time (no batching) and records
    path, labels, probability and correctness for each image.
    """
    model.eval()
    records = []

    with torch.no_grad():
        for idx in range(len(dataset)):
            image, true_label = dataset[idx]
            logits = model(image.unsqueeze(0).to(device))
            prob = torch.sigmoid(logits).item()
            predicted = int(prob > 0.5)

            records.append({
                'path': dataset.image_paths[idx],
                'true_label': true_label,
                'pred_label': predicted,
                'probability': prob,
                'correct': predicted == true_label
            })

    return records
src/export.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ONNX export utilities for model deployment.
3
+
4
+ ONNX (Open Neural Network Exchange) is a universal format that allows
5
+ models to run on different frameworks and platforms:
6
+ - TensorFlow, PyTorch, etc.
7
+ - Mobile devices (iOS, Android)
8
+ - Web browsers (ONNX.js)
9
+ - C++, Java, and other languages
10
+ - Optimized inference servers
11
+ """
12
+
13
+ import torch
14
+ import numpy as np
15
+ from pathlib import Path
16
+ from typing import Tuple, Optional
17
+
18
+ from .config import CHECKPOINT_PATH, MODEL_DIR, IMAGE_SIZE
19
+ from .model import create_model, get_device
20
+
21
+
22
def export_to_onnx(
    checkpoint_path: Path = CHECKPOINT_PATH,
    output_path: Optional[Path] = None,
    opset_version: int = 18
) -> Path:
    """
    Export PyTorch model to ONNX format.

    Args:
        checkpoint_path: Path to the PyTorch checkpoint
        output_path: Path for the ONNX model (default: models/best_model.onnx)
        opset_version: ONNX opset version (default 18; lower it for older
            runtimes — the previous docstring incorrectly described 14)

    Returns:
        Path to the exported ONNX model
    """
    if output_path is None:
        output_path = MODEL_DIR / "best_model.onnx"

    # Export on CPU so the graph contains no device-specific placements.
    device = torch.device("cpu")
    model = create_model(pretrained=False, freeze_backbone=False, device=device)

    # NOTE(review): torch.load's weights_only default changes in newer torch;
    # this checkpoint is a dict holding 'model_state_dict' — confirm the other
    # keys are plain types before tightening to weights_only=True.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Dummy input fixes the signature: (batch_size=1, channels=3, H, W).
    dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE)

    torch.onnx.export(
        model,
        dummy_input,
        str(output_path),  # pass str: broadest compatibility across torch versions
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,  # Optimize constants
        input_names=['image'],
        output_names=['logits'],
        dynamic_axes={
            'image': {0: 'batch_size'},  # Variable batch size
            'logits': {0: 'batch_size'}
        }
    )

    print(f"Model exported to: {output_path}")
    print(f"File size: {output_path.stat().st_size / 1024 / 1024:.2f} MB")

    return output_path
72
+
73
+
74
def validate_onnx_model(
    onnx_path: Path,
    checkpoint_path: Path = CHECKPOINT_PATH,
    rtol: float = 1e-3,
    atol: float = 1e-5
) -> bool:
    """
    Check that the exported ONNX model reproduces the PyTorch model's outputs.

    Args:
        onnx_path: Path to ONNX model
        checkpoint_path: Path to PyTorch checkpoint
        rtol: Relative tolerance for comparison
        atol: Absolute tolerance for comparison

    Returns:
        True if outputs match, False otherwise
    """
    import onnx
    import onnxruntime as ort

    # Structural validity check of the ONNX graph.
    onnx.checker.check_model(onnx.load(onnx_path))
    print("ONNX model structure is valid")

    # Rebuild the PyTorch reference model on CPU.
    cpu = torch.device("cpu")
    ref_model = create_model(pretrained=False, freeze_backbone=False, device=cpu)
    state = torch.load(checkpoint_path, map_location=cpu)
    ref_model.load_state_dict(state['model_state_dict'])
    ref_model.eval()

    # Same random sample fed through both backends.
    sample = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE)

    with torch.no_grad():
        reference = ref_model(sample).numpy()

    session = ort.InferenceSession(str(onnx_path))
    candidate = session.run(None, {'image': sample.numpy()})[0]

    matches = np.allclose(reference, candidate, rtol=rtol, atol=atol)

    if matches:
        print("Validation PASSED: ONNX outputs match PyTorch outputs")
        print(f"  PyTorch output: {reference.flatten()[:5]}...")
        print(f"  ONNX output: {candidate.flatten()[:5]}...")
    else:
        print("Validation FAILED: Outputs do not match!")
        print(f"  Max difference: {np.max(np.abs(reference - candidate))}")

    return matches
133
+
134
+
135
def predict_with_onnx(
    onnx_path: Path,
    image_tensor: np.ndarray
) -> Tuple[str, float]:
    """
    Run inference using ONNX Runtime.

    Args:
        onnx_path: Path to ONNX model
        image_tensor: Preprocessed image as numpy array (1, 3, 224, 224)

    Returns:
        Tuple of (predicted_class, confidence)
    """
    import onnxruntime as ort
    from .config import CLASS_NAMES

    # Create session (NOTE: for repeated calls, consider caching the session).
    ort_session = ort.InferenceSession(str(onnx_path))

    # Run inference
    logits = ort_session.run(
        None,
        {'image': image_tensor.astype(np.float32)}
    )[0]

    # Numerically stable sigmoid: the naive 1/(1+exp(-x)) overflows in
    # exp() for strongly negative logits and raises a RuntimeWarning.
    x = float(logits[0, 0])
    if x >= 0:
        prob = 1.0 / (1.0 + np.exp(-x))
    else:
        e = np.exp(x)
        prob = e / (1.0 + e)

    # Single-logit binary head: prob > 0.5 selects the positive class.
    pred_class = CLASS_NAMES[1] if prob > 0.5 else CLASS_NAMES[0]
    confidence = float(prob if prob > 0.5 else 1 - prob)

    return pred_class, confidence
167
+
168
+
169
if __name__ == "__main__":
    # CLI entry point: export, then validate, then smoke-test inference.
    bar = "=" * 50

    print(bar)
    print("EXPORTING MODEL TO ONNX")
    print(bar)

    onnx_path = export_to_onnx()

    print("\n" + bar)
    print("VALIDATING ONNX MODEL")
    print(bar)

    validate_onnx_model(onnx_path)

    print("\n" + bar)
    print("TESTING ONNX INFERENCE")
    print(bar)

    # Random (non-image) input is enough to prove the runtime path works.
    test_input = np.random.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE).astype(np.float32)
    pred_class, confidence = predict_with_onnx(onnx_path, test_input)
    print(f"Test prediction: {pred_class} ({confidence:.1%})")
src/gradcam.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grad-CAM visualization for model interpretability.
3
+ """
4
+
5
+ import torch
6
+ import numpy as np
7
+ from PIL import Image
8
+ from pathlib import Path
9
+ from typing import Union
10
+ import matplotlib.pyplot as plt
11
+
12
+ from pytorch_grad_cam import GradCAM
13
+ from pytorch_grad_cam.utils.image import show_cam_on_image
14
+
15
+ from .dataset import get_transforms
16
+ from .config import IMAGENET_MEAN, IMAGENET_STD, CLASS_NAMES
17
+
18
+
19
def get_gradcam(model, target_layer=None):
    """Build a GradCAM instance for the model.

    Defaults to the final convolutional block of the EfficientNet backbone,
    which gives the coarsest spatially-localized activations.
    """
    layer = target_layer if target_layer is not None else model.backbone.features[-1]
    return GradCAM(model=model, target_layers=[layer])
25
+
26
+
27
def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
    """Undo ImageNet normalization; return an HWC float image clipped to [0, 1]."""
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    restored = tensor.cpu().mul(std).add(mean)
    return np.clip(restored.permute(1, 2, 0).numpy(), 0, 1)
34
+
35
+
36
def generate_gradcam(
    model,
    image: Union[str, Path, Image.Image],
    device: torch.device
) -> tuple:
    """Generate Grad-CAM heatmap for an image.

    Returns (cam_image, pred_class, confidence, rgb_img): the heatmap
    overlay, predicted class name, confidence in (0.5, 1], and the
    denormalized input image.
    """
    model.eval()

    # Accept either a filesystem path or an already-loaded PIL image.
    if isinstance(image, (str, Path)):
        image = Image.open(image).convert('RGB')

    transform = get_transforms(is_training=False)
    img_tensor = transform(image).unsqueeze(0).to(device)

    # Forward pass without gradients to obtain the sigmoid probability.
    with torch.no_grad():
        output = model(img_tensor)
        prob = torch.sigmoid(output).item()

    # Single-logit binary head: prob > 0.5 selects the positive class.
    pred_class = CLASS_NAMES[1] if prob > 0.5 else CLASS_NAMES[0]
    confidence = prob if prob > 0.5 else 1 - prob

    # targets=None lets pytorch_grad_cam pick the highest-scoring output.
    # NOTE(review): with a one-logit head this always explains the positive
    # logit, even when the prediction is the negative class — confirm intended.
    cam = get_gradcam(model)
    grayscale_cam = cam(input_tensor=img_tensor, targets=None)[0]

    # Overlay the CAM heatmap on the denormalized input image.
    rgb_img = denormalize_image(img_tensor[0])
    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

    return cam_image, pred_class, confidence, rgb_img
68
+
69
+
70
def plot_gradcam(
    model,
    image_path: Union[str, Path],
    true_label: str,
    device: torch.device,
    save_path: str = None
):
    """Show the original image beside its Grad-CAM overlay.

    The prediction title is green when it matches true_label, red otherwise.
    Returns (pred_class, confidence).
    """
    cam_image, pred_class, confidence, original = generate_gradcam(model, image_path, device)

    fig, (ax_orig, ax_cam) = plt.subplots(1, 2, figsize=(10, 4))

    # Left panel: the denormalized input image.
    ax_orig.imshow(original)
    ax_orig.set_title(f"Original\nTrue: {true_label}")
    ax_orig.axis('off')

    # Right panel: heatmap overlay, color-coded by correctness.
    title_color = 'green' if pred_class == true_label else 'red'
    ax_cam.imshow(cam_image)
    ax_cam.set_title(f"Grad-CAM\nPred: {pred_class} ({confidence:.1%})", color=title_color)
    ax_cam.axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    plt.show()
    return pred_class, confidence
+
101
+
102
def plot_gradcam_grid(
    model,
    image_paths: list,
    true_labels: list,
    device: torch.device,
    save_path: str = None,
    title: str = "Grad-CAM Visualizations"
):
    """Render one row per image: original on the left, Grad-CAM overlay on the right."""
    n_rows = len(image_paths)
    fig, axes = plt.subplots(n_rows, 2, figsize=(8, 3 * n_rows))

    # With a single row, subplots() returns a 1-D array; normalize to 2-D.
    if n_rows == 1:
        axes = axes.reshape(1, -1)

    for row, (img_path, truth) in enumerate(zip(image_paths, true_labels)):
        cam_image, pred_class, confidence, original = generate_gradcam(model, img_path, device)

        ax_left, ax_right = axes[row]

        ax_left.imshow(original)
        ax_left.set_title(f"True: {truth}")
        ax_left.axis('off')

        # A red title flags a misclassification at a glance.
        outcome_color = 'green' if pred_class == truth else 'red'
        ax_right.imshow(cam_image)
        ax_right.set_title(f"Pred: {pred_class} ({confidence:.1%})", color=outcome_color)
        ax_right.axis('off')

    plt.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    plt.show()
src/model.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ EfficientNet-B0 model for Pneumonia classification.
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import models
8
+ from typing import Tuple
9
+
10
+ from .config import DROPOUT_RATE, NUM_CLASSES
11
+
12
+
13
class PneumoniaClassifier(nn.Module):
    """EfficientNet-B0 based classifier for chest X-ray pneumonia detection."""

    def __init__(
        self,
        pretrained: bool = True,
        dropout_rate: float = DROPOUT_RATE,
        freeze_backbone: bool = True
    ):
        super().__init__()

        # ImageNet weights give a strong starting point for transfer learning.
        weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = models.efficientnet_b0(weights=weights)

        # Replace the stock Sequential(Dropout, Linear) head with one sized
        # for this task; in_features comes from the original Linear (1280).
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(p=dropout_rate, inplace=True),
            nn.Linear(in_features, NUM_CLASSES)
        )

        if freeze_backbone:
            self.freeze_backbone()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the full backbone (features + new classifier head)."""
        return self.backbone(x)

    def freeze_backbone(self):
        """Freeze all layers except the classifier."""
        self._set_feature_grad(False)

    def unfreeze_backbone(self):
        """Unfreeze all layers for fine-tuning."""
        self._set_feature_grad(True)

    def _set_feature_grad(self, enabled: bool):
        # Toggle requires_grad on the convolutional feature extractor only;
        # the classifier head always stays trainable.
        for param in self.backbone.features.parameters():
            param.requires_grad = enabled

    def get_param_counts(self) -> Tuple[int, int]:
        """Return (trainable_params, total_params)."""
        trainable = 0
        total = 0
        for p in self.parameters():
            n = p.numel()
            total += n
            if p.requires_grad:
                trainable += n
        return trainable, total
59
+
60
+
61
def create_model(
    pretrained: bool = True,
    dropout_rate: float = DROPOUT_RATE,
    freeze_backbone: bool = True,
    device: str = None
) -> PneumoniaClassifier:
    """Factory function to create the model.

    Args:
        pretrained: Load ImageNet weights for the backbone.
        dropout_rate: Dropout probability in the classifier head.
        freeze_backbone: Freeze the feature extractor after construction.
        device: Target device spec; auto-detected when None.

    Returns:
        The model moved to the selected device.
    """
    if device is None:
        # Reuse the shared detection logic instead of duplicating the
        # MPS/CUDA/CPU cascade here; model.to() accepts str or torch.device.
        device = get_device()

    model = PneumoniaClassifier(
        pretrained=pretrained,
        dropout_rate=dropout_rate,
        freeze_backbone=freeze_backbone
    )

    return model.to(device)
79
+
80
+
81
def get_device() -> torch.device:
    """Get the best available device, preferring MPS, then CUDA, then CPU."""
    for name, available in (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    ):
        if available():
            return torch.device(name)
    return torch.device("cpu")
src/predict.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference functions for Pneumonia classification.
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from PIL import Image
8
+ from pathlib import Path
9
+ from typing import Union, Tuple
10
+
11
+ from .dataset import get_transforms
12
+ from .config import CLASS_NAMES, CHECKPOINT_PATH
13
+
14
+
15
def load_model(model: nn.Module, checkpoint_path: Path = CHECKPOINT_PATH, device: str = "cpu") -> nn.Module:
    """Restore weights from a training checkpoint and switch to eval mode.

    NOTE(review): torch.load without weights_only=True unpickles arbitrary
    objects — only load checkpoints from trusted sources.
    """
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    model.eval()
    return model
21
+
22
+
23
def predict_image(
    model: nn.Module,
    image: Union[str, Path, Image.Image],
    device: torch.device
) -> Tuple[str, float]:
    """Classify a single image; returns (class_name, confidence in (0.5, 1])."""
    model.eval()

    # Accept a filesystem path or an in-memory PIL image.
    if isinstance(image, (str, Path)):
        image = Image.open(image).convert('RGB')

    # Preprocess with the evaluation transform and add a batch dimension.
    batch = get_transforms(is_training=False)(image).unsqueeze(0).to(device)

    with torch.no_grad():
        prob = torch.sigmoid(model(batch)).item()

    # Single sigmoid output: > 0.5 means the positive class (index 1).
    if prob > 0.5:
        return CLASS_NAMES[1], prob
    return CLASS_NAMES[0], 1 - prob
src/train.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training pipeline for Pneumonia classification.
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.optim import AdamW
8
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
9
+ from torch.utils.data import DataLoader
10
+ from pathlib import Path
11
+ from typing import Dict, Optional, Tuple
12
+ import time
13
+
14
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
15
+
16
+ from .config import (
17
+ STAGE1_EPOCHS, STAGE1_LR, STAGE2_EPOCHS, STAGE2_LR,
18
+ WEIGHT_DECAY, SCHEDULER_PATIENCE, SCHEDULER_FACTOR,
19
+ EARLY_STOP_PATIENCE, CHECKPOINT_PATH, MODEL_DIR
20
+ )
21
+ from .model import PneumoniaClassifier, get_device
22
+
23
+
24
+ class EarlyStopping:
25
+ """Early stopping to prevent overfitting."""
26
+
27
+ def __init__(self, patience: int = 7, min_delta: float = 0.001):
28
+ self.patience = patience
29
+ self.min_delta = min_delta
30
+ self.counter = 0
31
+ self.best_loss = float('inf')
32
+ self.should_stop = False
33
+
34
+ def __call__(self, val_loss: float) -> bool:
35
+ if val_loss < self.best_loss - self.min_delta:
36
+ self.best_loss = val_loss
37
+ self.counter = 0
38
+ else:
39
+ self.counter += 1
40
+ if self.counter >= self.patience:
41
+ self.should_stop = True
42
+ return self.should_stop
43
+
44
+
45
+ def train_epoch(
46
+ model: nn.Module,
47
+ loader: DataLoader,
48
+ criterion: nn.Module,
49
+ optimizer: torch.optim.Optimizer,
50
+ device: torch.device
51
+ ) -> Tuple[float, float]:
52
+ """Train for one epoch."""
53
+ model.train()
54
+ total_loss = 0
55
+ all_preds, all_labels = [], []
56
+
57
+ for images, labels in loader:
58
+ images = images.to(device)
59
+ labels = labels.float().unsqueeze(1).to(device)
60
+
61
+ optimizer.zero_grad()
62
+ outputs = model(images)
63
+ loss = criterion(outputs, labels)
64
+ loss.backward()
65
+ optimizer.step()
66
+
67
+ total_loss += loss.item() * images.size(0)
68
+ preds = (torch.sigmoid(outputs) > 0.5).int()
69
+ all_preds.extend(preds.cpu().numpy())
70
+ all_labels.extend(labels.cpu().numpy())
71
+
72
+ avg_loss = total_loss / len(loader.dataset)
73
+ accuracy = accuracy_score(all_labels, all_preds)
74
+ return avg_loss, accuracy
75
+
76
+
77
+ def validate(
78
+ model: nn.Module,
79
+ loader: DataLoader,
80
+ criterion: nn.Module,
81
+ device: torch.device
82
+ ) -> Dict[str, float]:
83
+ """Validate the model."""
84
+ model.eval()
85
+ total_loss = 0
86
+ all_preds, all_labels = [], []
87
+
88
+ with torch.no_grad():
89
+ for images, labels in loader:
90
+ images = images.to(device)
91
+ labels = labels.float().unsqueeze(1).to(device)
92
+
93
+ outputs = model(images)
94
+ loss = criterion(outputs, labels)
95
+
96
+ total_loss += loss.item() * images.size(0)
97
+ preds = (torch.sigmoid(outputs) > 0.5).int()
98
+ all_preds.extend(preds.cpu().numpy())
99
+ all_labels.extend(labels.cpu().numpy())
100
+
101
+ avg_loss = total_loss / len(loader.dataset)
102
+
103
+ return {
104
+ 'loss': avg_loss,
105
+ 'accuracy': accuracy_score(all_labels, all_preds),
106
+ 'precision': precision_score(all_labels, all_preds, zero_division=0),
107
+ 'recall': recall_score(all_labels, all_preds, zero_division=0),
108
+ 'f1': f1_score(all_labels, all_preds, zero_division=0)
109
+ }
110
+
111
+
112
+ def train(
113
+ model: PneumoniaClassifier,
114
+ train_loader: DataLoader,
115
+ val_loader: DataLoader,
116
+ pos_weight: torch.Tensor,
117
+ epochs: int,
118
+ lr: float,
119
+ device: torch.device,
120
+ stage: str = "stage1",
121
+ use_wandb: bool = True,
122
+ wandb_run = None
123
+ ) -> Dict[str, list]:
124
+ """Training loop with validation."""
125
+
126
+ criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))
127
+ optimizer = AdamW(
128
+ filter(lambda p: p.requires_grad, model.parameters()),
129
+ lr=lr,
130
+ weight_decay=WEIGHT_DECAY
131
+ )
132
+ scheduler = ReduceLROnPlateau(
133
+ optimizer, mode='min',
134
+ patience=SCHEDULER_PATIENCE,
135
+ factor=SCHEDULER_FACTOR
136
+ )
137
+ early_stopping = EarlyStopping(patience=EARLY_STOP_PATIENCE)
138
+
139
+ history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'val_f1': [], 'lr': []}
140
+ best_val_loss = float('inf')
141
+
142
+ MODEL_DIR.mkdir(parents=True, exist_ok=True)
143
+
144
+ for epoch in range(epochs):
145
+ start = time.time()
146
+
147
+ # Train
148
+ train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
149
+
150
+ # Validate
151
+ val_metrics = validate(model, val_loader, criterion, device)
152
+
153
+ # Get current LR
154
+ current_lr = optimizer.param_groups[0]['lr']
155
+
156
+ # Update scheduler
157
+ scheduler.step(val_metrics['loss'])
158
+
159
+ # Log
160
+ elapsed = time.time() - start
161
+ print(f"[{stage}] Epoch {epoch+1}/{epochs} ({elapsed:.1f}s) | "
162
+ f"Train Loss: {train_loss:.4f} | "
163
+ f"Val Loss: {val_metrics['loss']:.4f} | "
164
+ f"Val Acc: {val_metrics['accuracy']:.3f} | "
165
+ f"Val F1: {val_metrics['f1']:.3f} | "
166
+ f"LR: {current_lr:.2e}")
167
+
168
+ # W&B logging
169
+ if use_wandb and wandb_run:
170
+ wandb_run.log({
171
+ f"{stage}/train_loss": train_loss,
172
+ f"{stage}/train_acc": train_acc,
173
+ f"{stage}/val_loss": val_metrics['loss'],
174
+ f"{stage}/val_acc": val_metrics['accuracy'],
175
+ f"{stage}/val_precision": val_metrics['precision'],
176
+ f"{stage}/val_recall": val_metrics['recall'],
177
+ f"{stage}/val_f1": val_metrics['f1'],
178
+ f"{stage}/lr": current_lr,
179
+ "epoch": epoch + 1
180
+ })
181
+
182
+ # Save history
183
+ history['train_loss'].append(train_loss)
184
+ history['val_loss'].append(val_metrics['loss'])
185
+ history['val_acc'].append(val_metrics['accuracy'])
186
+ history['val_f1'].append(val_metrics['f1'])
187
+ history['lr'].append(current_lr)
188
+
189
+ # Save best model
190
+ if val_metrics['loss'] < best_val_loss:
191
+ best_val_loss = val_metrics['loss']
192
+ torch.save({
193
+ 'epoch': epoch + 1,
194
+ 'model_state_dict': model.state_dict(),
195
+ 'optimizer_state_dict': optimizer.state_dict(),
196
+ 'val_loss': best_val_loss,
197
+ 'val_metrics': val_metrics
198
+ }, CHECKPOINT_PATH)
199
+ print(f" -> Saved best model (val_loss: {best_val_loss:.4f})")
200
+
201
+ # Early stopping
202
+ if early_stopping(val_metrics['loss']):
203
+ print(f"Early stopping triggered at epoch {epoch+1}")
204
+ break
205
+
206
+ return history
207
+
208
+
209
+ def train_two_stage(
210
+ model: PneumoniaClassifier,
211
+ train_loader: DataLoader,
212
+ val_loader: DataLoader,
213
+ pos_weight: torch.Tensor,
214
+ device: torch.device,
215
+ use_wandb: bool = True,
216
+ wandb_run = None
217
+ ) -> Dict[str, list]:
218
+ """Two-stage training: frozen backbone then fine-tuning."""
219
+
220
+ # Stage 1: Train classifier only
221
+ print("\n" + "=" * 60)
222
+ print("STAGE 1: Training classifier (backbone frozen)")
223
+ print("=" * 60)
224
+ model.freeze_backbone()
225
+ trainable, total = model.get_param_counts()
226
+ print(f"Trainable params: {trainable:,} / {total:,}")
227
+
228
+ history1 = train(
229
+ model, train_loader, val_loader, pos_weight,
230
+ epochs=STAGE1_EPOCHS, lr=STAGE1_LR, device=device,
231
+ stage="stage1", use_wandb=use_wandb, wandb_run=wandb_run
232
+ )
233
+
234
+ # Stage 2: Fine-tune entire network
235
+ print("\n" + "=" * 60)
236
+ print("STAGE 2: Fine-tuning entire network")
237
+ print("=" * 60)
238
+ model.unfreeze_backbone()
239
+ trainable, total = model.get_param_counts()
240
+ print(f"Trainable params: {trainable:,} / {total:,}")
241
+
242
+ history2 = train(
243
+ model, train_loader, val_loader, pos_weight,
244
+ epochs=STAGE2_EPOCHS, lr=STAGE2_LR, device=device,
245
+ stage="stage2", use_wandb=use_wandb, wandb_run=wandb_run
246
+ )
247
+
248
+ # Combine histories
249
+ history = {k: history1[k] + history2[k] for k in history1}
250
+ return history
src/utils.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for visualization and helpers.
3
+ """
4
+
5
+ import torch
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ from typing import Optional
9
+
10
+ from .config import IMAGENET_MEAN, IMAGENET_STD, CLASS_NAMES
11
+
12
+
13
def denormalize(tensor: torch.Tensor) -> torch.Tensor:
    """Invert the ImageNet normalization applied by the data transforms."""
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    return tensor.mul(std).add(mean)
18
+
19
+
20
def show_batch(
    images: torch.Tensor,
    labels: torch.Tensor,
    predictions: Optional[torch.Tensor] = None,
    n_images: int = 8,
    save_path: Optional[str] = None
):
    """Display a batch of images in a 4-column grid with labels.

    When predictions are supplied, titles are colored green (correct) or
    red (wrong).
    """
    n_images = min(n_images, len(images))
    cols = 4
    rows = (n_images + cols - 1) // cols  # ceil division

    fig, axes = plt.subplots(rows, cols, figsize=(12, 3 * rows))
    # Flatten to a 1-D panel regardless of grid shape.
    panel = np.asarray(axes).reshape(-1)

    for idx in range(n_images):
        img = np.clip(denormalize(images[idx]).permute(1, 2, 0).numpy(), 0, 1)

        ax = panel[idx]
        ax.imshow(img)
        ax.axis('off')

        truth = CLASS_NAMES[labels[idx]]
        caption = f"True: {truth}"

        if predictions is not None:
            guess = CLASS_NAMES[predictions[idx]]
            caption += f"\nPred: {guess}"
            ax.set_title(caption, color='green' if guess == truth else 'red', fontsize=10)
        else:
            ax.set_title(caption, fontsize=10)

    # Blank out any unused cells in the grid.
    for ax in panel[n_images:]:
        ax.axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    plt.show()
63
+
64
+
65
def set_seed(seed: int = 42):
    """Seed every RNG in use (random, numpy, torch, CUDA, MPS) for reproducibility."""
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Accelerator generators are seeded only when the backend is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)