GitHub Actions committed on
Commit
65c5202
·
1 Parent(s): 6529f5c

Auto-deploy from GitHub: 495db78a06be79166200269bb14d9e9b1e8906d6

Browse files
Files changed (12) hide show
  1. app.py +156 -133
  2. requirements.txt +37 -12
  3. src/__init__.py +6 -0
  4. src/config.py +112 -0
  5. src/dataset.py +201 -0
  6. src/evaluate.py +107 -0
  7. src/export.py +190 -0
  8. src/gradcam.py +137 -0
  9. src/model.py +87 -0
  10. src/predict.py +47 -0
  11. src/train.py +250 -0
  12. src/utils.py +74 -0
app.py CHANGED
@@ -1,28 +1,24 @@
1
  """
2
- Hugging Face Spaces - Pneumonia Detection App
3
- This is a self-contained version for Hugging Face Spaces deployment.
4
- Copy this file as 'app.py' to your HF Spaces repository.
5
  """
6
 
 
 
 
 
 
 
7
  import streamlit as st
8
  import torch
9
- import torch.nn as nn
10
- from torchvision import models, transforms
11
  from PIL import Image
12
- import numpy as np
13
- from pytorch_grad_cam import GradCAM
14
- from pytorch_grad_cam.utils.image import show_cam_on_image
15
  import time
16
 
17
- # =============================================================================
18
- # Configuration
19
- # =============================================================================
20
-
21
- IMAGE_SIZE = 224
22
- IMAGENET_MEAN = [0.485, 0.456, 0.406]
23
- IMAGENET_STD = [0.229, 0.224, 0.225]
24
- CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
25
- MODEL_PATH = "models/best_model.pt"
26
 
27
  # =============================================================================
28
  # Page Configuration
@@ -31,79 +27,10 @@ MODEL_PATH = "models/best_model.pt"
31
  st.set_page_config(
32
  page_title="Pneumonia Detection",
33
  page_icon="🫁",
34
- layout="wide"
 
35
  )
36
 
37
- # =============================================================================
38
- # Model Definition
39
- # =============================================================================
40
-
41
- class PneumoniaClassifier(nn.Module):
42
- def __init__(self):
43
- super().__init__()
44
- self.backbone = models.efficientnet_b0(weights=None)
45
- in_features = self.backbone.classifier[1].in_features
46
- self.backbone.classifier = nn.Sequential(
47
- nn.Dropout(p=0.3, inplace=True),
48
- nn.Linear(in_features, 1)
49
- )
50
-
51
- def forward(self, x):
52
- return self.backbone(x)
53
-
54
-
55
- # =============================================================================
56
- # Helper Functions
57
- # =============================================================================
58
-
59
- def get_transforms():
60
- return transforms.Compose([
61
- transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
62
- transforms.ToTensor(),
63
- transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
64
- ])
65
-
66
-
67
- def denormalize(tensor):
68
- mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
69
- std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
70
- img = tensor.cpu() * std + mean
71
- return img.permute(1, 2, 0).numpy().clip(0, 1)
72
-
73
-
74
- @st.cache_resource
75
- def load_model():
76
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
77
- model = PneumoniaClassifier()
78
- checkpoint = torch.load(MODEL_PATH, map_location=device)
79
- model.load_state_dict(checkpoint['model_state_dict'])
80
- model.eval()
81
- return model.to(device), device
82
-
83
-
84
- def predict(model, image, device):
85
- transform = get_transforms()
86
- img_tensor = transform(image).unsqueeze(0).to(device)
87
-
88
- with torch.no_grad():
89
- output = model(img_tensor)
90
- prob = torch.sigmoid(output).item()
91
-
92
- pred_class = CLASS_NAMES[1] if prob > 0.5 else CLASS_NAMES[0]
93
- confidence = prob if prob > 0.5 else 1 - prob
94
- return pred_class, confidence, img_tensor
95
-
96
-
97
- def generate_gradcam(model, img_tensor, device):
98
- target_layer = model.backbone.features[-1]
99
- cam = GradCAM(model=model, target_layers=[target_layer])
100
- grayscale_cam = cam(input_tensor=img_tensor, targets=None)[0]
101
-
102
- rgb_img = denormalize(img_tensor[0])
103
- cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
104
- return cam_image, rgb_img
105
-
106
-
107
  # =============================================================================
108
  # Custom CSS
109
  # =============================================================================
@@ -137,17 +64,42 @@ st.markdown("""
137
  background-color: #FFEBEE;
138
  border: 2px solid #F44336;
139
  }
 
 
 
 
 
 
 
 
 
 
140
  </style>
141
  """, unsafe_allow_html=True)
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  # =============================================================================
144
  # Sidebar
145
  # =============================================================================
146
 
147
  with st.sidebar:
148
- st.title("🫁 About")
 
 
149
  st.markdown("""
150
- Deep learning model for detecting **pneumonia** from chest X-rays.
151
 
152
  **Model:** EfficientNet-B0
153
  **Accuracy:** 90.5%
@@ -155,101 +107,172 @@ with st.sidebar:
155
  """)
156
 
157
  st.divider()
 
 
 
 
 
 
 
 
 
 
158
  st.subheader("Model Metrics")
159
  col1, col2 = st.columns(2)
160
- col1.metric("Accuracy", "90.5%")
161
- col2.metric("Recall", "98.2%")
 
 
 
 
162
 
163
  st.divider()
164
- st.markdown("*Built with PyTorch & Streamlit*")
 
 
 
 
 
 
 
165
 
166
  # =============================================================================
167
  # Main Content
168
  # =============================================================================
169
 
170
- st.markdown('<p class="main-header">🫁 Pneumonia Detection</p>', unsafe_allow_html=True)
171
- st.markdown('<p class="sub-header">Upload a chest X-ray image to detect pneumonia</p>', unsafe_allow_html=True)
 
172
 
173
  # Load model
174
  try:
175
- model, device = load_model()
176
  model_loaded = True
177
  except Exception as e:
178
  st.error(f"Failed to load model: {e}")
179
  model_loaded = False
180
 
181
  if model_loaded:
 
182
  col1, col2 = st.columns([1, 1])
183
 
184
  with col1:
185
  st.subheader("📤 Upload Image")
 
186
  uploaded_file = st.file_uploader(
187
  "Choose a chest X-ray image",
188
- type=["jpg", "jpeg", "png"]
 
189
  )
190
 
191
- # Sample images
192
- st.subheader("🖼️ Or Try Sample Images")
 
 
193
  sample_col1, sample_col2 = st.columns(2)
 
 
194
  with sample_col1:
195
- if st.button("Normal Sample", use_container_width=True):
196
- st.session_state.sample = "samples/normal_sample.jpeg"
197
  with sample_col2:
198
- if st.button("Pneumonia Sample", use_container_width=True):
199
- st.session_state.sample = "samples/pneumonia_sample.jpeg"
200
-
201
- # Determine which image to use
202
- image = None
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  if uploaded_file is not None:
204
- image = Image.open(uploaded_file).convert("RGB")
205
- st.session_state.sample = None # Clear sample when uploading
206
- elif "sample" in st.session_state and st.session_state.sample:
207
- image = Image.open(st.session_state.sample).convert("RGB")
208
-
209
- if image is not None:
 
 
 
 
 
210
 
 
211
  with col1:
212
- st.image(image, caption="Uploaded X-Ray", use_container_width=True)
213
- analyze = st.button("🔬 Analyze Image", type="primary", use_container_width=True)
214
 
215
- if analyze:
216
  with col2:
217
- with st.spinner("Analyzing..."):
 
218
  start_time = time.time()
219
- pred_class, confidence, img_tensor = predict(model, image, device)
220
- cam_image, original = generate_gradcam(model, img_tensor, device)
221
  inference_time = (time.time() - start_time) * 1000
222
 
223
- # Results
 
 
 
224
  if pred_class == "PNEUMONIA":
225
  st.markdown(f"""
226
  <div class="prediction-box prediction-pneumonia">
227
- <h2 style="color: #F44336;">⚠️ PNEUMONIA DETECTED</h2>
228
- <p>Confidence: {confidence:.1%}</p>
229
  </div>
230
  """, unsafe_allow_html=True)
231
  else:
232
  st.markdown(f"""
233
  <div class="prediction-box prediction-normal">
234
- <h2 style="color: #4CAF50;">✅ NORMAL</h2>
235
- <p>Confidence: {confidence:.1%}</p>
236
  </div>
237
  """, unsafe_allow_html=True)
238
 
239
  # Metrics row
240
  m1, m2, m3 = st.columns(3)
241
- m1.metric("Prediction", pred_class)
242
- m2.metric("Confidence", f"{confidence:.1%}")
243
- m3.metric("Time", f"{inference_time:.0f}ms")
 
 
 
 
 
 
 
 
244
 
245
- # Grad-CAM
246
- st.subheader("🔥 Grad-CAM")
247
  gcol1, gcol2 = st.columns(2)
248
- gcol1.image(original, caption="Original", use_container_width=True)
249
- gcol2.image(cam_image, caption="Heatmap", use_container_width=True)
 
 
 
 
 
 
 
 
250
 
251
- st.warning("**Disclaimer:** For educational purposes only. Consult a healthcare professional.")
 
 
 
 
 
252
 
253
- # Footers
254
  st.markdown("---")
255
- st.markdown("<p style='text-align:center;color:#888;'>Built with PyTorch & Streamlit</p>", unsafe_allow_html=True)
 
 
 
 
1
  """
2
+ Streamlit Web UI for Pneumonia Detection.
3
+
4
+ Run with: streamlit run app/app.py
5
  """
6
 
7
+ import sys
8
+ from pathlib import Path
9
+
10
+ # Add project root to path
11
+ sys.path.insert(0, str(Path(__file__).parent.parent))
12
+
13
  import streamlit as st
14
  import torch
 
 
15
  from PIL import Image
 
 
 
16
  import time
17
 
18
+ from src.config import CHECKPOINT_PATH, CLASS_NAMES
19
+ from src.model import create_model, get_device
20
+ from src.predict import load_model, predict_image
21
+ from src.gradcam import generate_gradcam
 
 
 
 
 
22
 
23
  # =============================================================================
24
  # Page Configuration
 
27
  st.set_page_config(
28
  page_title="Pneumonia Detection",
29
  page_icon="🫁",
30
+ layout="wide",
31
+ initial_sidebar_state="expanded"
32
  )
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # =============================================================================
35
  # Custom CSS
36
  # =============================================================================
 
64
  background-color: #FFEBEE;
65
  border: 2px solid #F44336;
66
  }
67
+ .confidence-text {
68
+ font-size: 1.2rem;
69
+ font-weight: bold;
70
+ }
71
+ .metric-card {
72
+ background-color: #f8f9fa;
73
+ padding: 1rem;
74
+ border-radius: 8px;
75
+ text-align: center;
76
+ }
77
  </style>
78
  """, unsafe_allow_html=True)
79
 
80
+ # =============================================================================
81
+ # Model Loading (Cached)
82
+ # =============================================================================
83
+
84
+ @st.cache_resource
85
+ def load_model_cached():
86
+ """Load model once and cache it."""
87
+ device = get_device()
88
+ model = create_model(pretrained=False, freeze_backbone=False, device=device)
89
+ model = load_model(model, CHECKPOINT_PATH, device)
90
+ return model, device
91
+
92
+
93
  # =============================================================================
94
  # Sidebar
95
  # =============================================================================
96
 
97
  with st.sidebar:
98
+ st.image("https://img.icons8.com/fluency/96/lungs.png", width=80)
99
+ st.title("About")
100
+
101
  st.markdown("""
102
+ This application uses deep learning to detect **pneumonia** from chest X-ray images.
103
 
104
  **Model:** EfficientNet-B0
105
  **Accuracy:** 90.5%
 
107
  """)
108
 
109
  st.divider()
110
+
111
+ st.subheader("How to Use")
112
+ st.markdown("""
113
+ 1. Upload a chest X-ray image
114
+ 2. Click **Analyze Image**
115
+ 3. View prediction and Grad-CAM
116
+ """)
117
+
118
+ st.divider()
119
+
120
  st.subheader("Model Metrics")
121
  col1, col2 = st.columns(2)
122
+ with col1:
123
+ st.metric("Accuracy", "90.5%")
124
+ st.metric("Precision", "88.0%")
125
+ with col2:
126
+ st.metric("Recall", "98.2%")
127
+ st.metric("F1 Score", "92.8%")
128
 
129
  st.divider()
130
+
131
+ st.markdown("""
132
+ **Links:**
133
+ [GitHub Repository](#) | [Live Demo](#)
134
+
135
+ ---
136
+ *Built with PyTorch & Streamlit*
137
+ """)
138
 
139
  # =============================================================================
140
  # Main Content
141
  # =============================================================================
142
 
143
+ # Header
144
+ st.markdown('<p class="main-header">🫁 Pneumonia Detection from Chest X-Rays</p>', unsafe_allow_html=True)
145
+ st.markdown('<p class="sub-header">Upload a chest X-ray image to detect pneumonia using AI</p>', unsafe_allow_html=True)
146
 
147
  # Load model
148
  try:
149
+ model, device = load_model_cached()
150
  model_loaded = True
151
  except Exception as e:
152
  st.error(f"Failed to load model: {e}")
153
  model_loaded = False
154
 
155
  if model_loaded:
156
+ # Create columns for layout
157
  col1, col2 = st.columns([1, 1])
158
 
159
  with col1:
160
  st.subheader("📤 Upload Image")
161
+
162
  uploaded_file = st.file_uploader(
163
  "Choose a chest X-ray image",
164
+ type=["jpg", "jpeg", "png"],
165
+ help="Supported formats: JPG, JPEG, PNG"
166
  )
167
 
168
+ # Sample images section
169
+ st.markdown("---")
170
+ st.markdown("**Or try a sample image:**")
171
+
172
  sample_col1, sample_col2 = st.columns(2)
173
+
174
+ use_sample = None
175
  with sample_col1:
176
+ if st.button("🟢 Normal Sample", width="stretch"):
177
+ use_sample = "normal"
178
  with sample_col2:
179
+ if st.button("🔴 Pneumonia Sample", width="stretch"):
180
+ use_sample = "pneumonia"
181
+
182
+ # Load sample image if selected
183
+ if use_sample == "normal":
184
+ sample_path = Path(__file__).parent / "samples" / "normal_sample.jpeg"
185
+ if sample_path.exists():
186
+ uploaded_file = sample_path
187
+ elif use_sample == "pneumonia":
188
+ sample_path = Path(__file__).parent / "samples" / "pneumonia_sample.jpeg"
189
+ if sample_path.exists():
190
+ uploaded_file = sample_path
191
+
192
+ with col2:
193
+ st.subheader("🔍 Analysis Results")
194
+ results_placeholder = st.empty()
195
+
196
+ # Process image if uploaded
197
  if uploaded_file is not None:
198
+ # Load image
199
+ if isinstance(uploaded_file, Path):
200
+ image = Image.open(uploaded_file).convert("RGB")
201
+ st.session_state['image_source'] = str(uploaded_file)
202
+ else:
203
+ image = Image.open(uploaded_file).convert("RGB")
204
+ st.session_state['image_source'] = uploaded_file.name
205
+
206
+ # Display uploaded image
207
+ with col1:
208
+ st.image(image, caption="Uploaded X-Ray", width="stretch")
209
 
210
+ # Analyze button
211
  with col1:
212
+ analyze_button = st.button("🔬 Analyze Image", type="primary", width="stretch")
 
213
 
214
+ if analyze_button:
215
  with col2:
216
+ with st.spinner("Analyzing image..."):
217
+ # Run prediction
218
  start_time = time.time()
219
+ pred_class, confidence = predict_image(model, image, device)
 
220
  inference_time = (time.time() - start_time) * 1000
221
 
222
+ # Generate Grad-CAM
223
+ cam_image, _, _, original = generate_gradcam(model, image, device)
224
+
225
+ # Display results
226
  if pred_class == "PNEUMONIA":
227
  st.markdown(f"""
228
  <div class="prediction-box prediction-pneumonia">
229
+ <h2 style="color: #F44336; margin: 0;">⚠️ PNEUMONIA DETECTED</h2>
230
+ <p class="confidence-text">Confidence: {confidence:.1%}</p>
231
  </div>
232
  """, unsafe_allow_html=True)
233
  else:
234
  st.markdown(f"""
235
  <div class="prediction-box prediction-normal">
236
+ <h2 style="color: #4CAF50; margin: 0;">✅ NORMAL</h2>
237
+ <p class="confidence-text">Confidence: {confidence:.1%}</p>
238
  </div>
239
  """, unsafe_allow_html=True)
240
 
241
  # Metrics row
242
  m1, m2, m3 = st.columns(3)
243
+ with m1:
244
+ st.metric("Prediction", pred_class)
245
+ with m2:
246
+ st.metric("Confidence", f"{confidence:.1%}")
247
+ with m3:
248
+ st.metric("Time", f"{inference_time:.0f}ms")
249
+
250
+ # Grad-CAM visualization
251
+ st.markdown("---")
252
+ st.subheader("🔥 Grad-CAM Visualization")
253
+ st.caption("Highlighted regions show areas that influenced the prediction")
254
 
 
 
255
  gcol1, gcol2 = st.columns(2)
256
+ with gcol1:
257
+ st.image(original, caption="Original", width="stretch")
258
+ with gcol2:
259
+ st.image(cam_image, caption="Grad-CAM Heatmap", width="stretch")
260
+
261
+ # Disclaimer
262
+ st.warning("""
263
+ **Disclaimer:** This tool is for educational purposes only and should not be used
264
+ for medical diagnosis. Always consult a qualified healthcare professional.
265
+ """)
266
 
267
+ else:
268
+ st.error("Model could not be loaded. Please check the model file exists.")
269
+
270
+ # =============================================================================
271
+ # Footer
272
+ # =============================================================================
273
 
 
274
  st.markdown("---")
275
+ st.markdown(
276
+ "<p style='text-align: center; color: #888;'>Built with ❤️ using PyTorch, EfficientNet-B0, and Streamlit</p>",
277
+ unsafe_allow_html=True
278
+ )
requirements.txt CHANGED
@@ -1,20 +1,45 @@
1
- # Hugging Face Spaces Requirements
2
- # Minimal dependencies for deployment
3
-
4
- # PyTorch (CPU version for HF Spaces)
5
- --extra-index-url https://download.pytorch.org/whl/cpu
6
- torch>=2.5.0
7
- torchvision>=0.20.0
8
-
9
- # Core
10
  pillow>=10.0.0
11
  numpy>=1.24.0
12
 
13
- # OpenCV headless (must be before grad-cam to avoid libGL issues)
14
- opencv-python-headless>=4.8.0
 
 
 
 
 
15
 
16
  # Model Interpretability
17
  grad-cam>=1.4.0
18
 
 
 
 
 
 
19
  # Web UI
20
- streamlit>=1.28.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core Deep Learning
2
+ torch>=2.1.0
3
+ torchvision>=0.16.0
 
 
 
 
 
 
4
  pillow>=10.0.0
5
  numpy>=1.24.0
6
 
7
+ # Data Analysis & Visualization
8
+ pandas>=2.0.0
9
+ matplotlib>=3.7.0
10
+ seaborn>=0.12.0
11
+
12
+ # Experiment Tracking
13
+ wandb>=0.15.0
14
 
15
  # Model Interpretability
16
  grad-cam>=1.4.0
17
 
18
+ # API
19
+ fastapi>=0.104.0
20
+ uvicorn>=0.24.0
21
+ python-multipart>=0.0.6
22
+
23
  # Web UI
24
+ streamlit>=1.28.0
25
+
26
+ # Testing
27
+ pytest>=7.4.0
28
+
29
+ # Code Quality
30
+ black>=23.0.0
31
+ ruff>=0.1.0
32
+
33
+ # Jupyter
34
+ jupyterlab>=4.0.0
35
+ ipywidgets>=8.0.0
36
+
37
+ # Utilities
38
+ python-dotenv>=1.0.0
39
+ tqdm>=4.66.0
40
+ scikit-learn>=1.3.0
41
+
42
+ # ONNX Export
43
+ onnx>=1.15.0
44
+ onnxruntime>=1.16.0
45
+ onnxscript>=0.1.0
src/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """
2
+ Pneumonia Detection from Chest X-Rays
3
+ Medical Image Classification using Deep Learning
4
+ """
5
+
6
+ __version__ = "0.1.0"
src/config.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration constants for the Pneumonia Detection project.
3
+ All hyperparameters and paths are defined here for easy modification.
4
+ """
5
+
6
+ from pathlib import Path
7
+
8
+ # =============================================================================
9
+ # Project Paths
10
+ # =============================================================================
11
+ PROJECT_ROOT = Path(__file__).parent.parent
12
+ DATA_DIR = PROJECT_ROOT / "data" / "raw"
13
+ PROCESSED_DIR = PROJECT_ROOT / "data" / "processed"
14
+ MODEL_DIR = PROJECT_ROOT / "models"
15
+ OUTPUT_DIR = PROJECT_ROOT / "outputs"
16
+ FIGURES_DIR = OUTPUT_DIR / "figures"
17
+ LOGS_DIR = OUTPUT_DIR / "logs"
18
+
19
+ # =============================================================================
20
+ # Data Configuration
21
+ # =============================================================================
22
+ IMAGE_SIZE = 224 # EfficientNet-B0 input size
23
+ BATCH_SIZE = 32
24
+ NUM_WORKERS = 4 # DataLoader workers
25
+
26
+ # ImageNet normalization (required for pretrained models)
27
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
28
+ IMAGENET_STD = [0.229, 0.224, 0.225]
29
+
30
+ # Class labels
31
+ CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
32
+ NUM_CLASSES = 1 # Binary classification with sigmoid
33
+
34
+ # =============================================================================
35
+ # Model Configuration
36
+ # =============================================================================
37
+ MODEL_NAME = "efficientnet_b0"
38
+ DROPOUT_RATE = 0.3
39
+ PRETRAINED = True
40
+
41
+ # =============================================================================
42
+ # Training Configuration - Stage 1 (Frozen Backbone)
43
+ # =============================================================================
44
+ STAGE1_EPOCHS = 5
45
+ STAGE1_LR = 1e-4
46
+ STAGE1_FREEZE_BACKBONE = True
47
+
48
+ # =============================================================================
49
+ # Training Configuration - Stage 2 (Fine-tuning)
50
+ # =============================================================================
51
+ STAGE2_EPOCHS = 15
52
+ STAGE2_LR = 1e-5
53
+ STAGE2_FREEZE_BACKBONE = False
54
+
55
+ # =============================================================================
56
+ # Optimizer Configuration
57
+ # =============================================================================
58
+ WEIGHT_DECAY = 1e-4
59
+ BETAS = (0.9, 0.999)
60
+
61
+ # =============================================================================
62
+ # Scheduler Configuration
63
+ # =============================================================================
64
+ SCHEDULER_PATIENCE = 3
65
+ SCHEDULER_FACTOR = 0.5
66
+ SCHEDULER_MIN_LR = 1e-7
67
+
68
+ # =============================================================================
69
+ # Early Stopping Configuration
70
+ # =============================================================================
71
+ EARLY_STOP_PATIENCE = 7
72
+ EARLY_STOP_MIN_DELTA = 0.001
73
+
74
+ # =============================================================================
75
+ # Model Checkpointing
76
+ # =============================================================================
77
+ CHECKPOINT_PATH = MODEL_DIR / "best_model.pt"
78
+ SAVE_BEST_ONLY = True
79
+ MONITOR_METRIC = "val_loss"
80
+
81
+ # =============================================================================
82
+ # Weights & Biases Configuration
83
+ # =============================================================================
84
+ WANDB_PROJECT = "pneumonia-detection"
85
+ WANDB_ENTITY = None # Set to your W&B username if needed
86
+
87
+ # =============================================================================
88
+ # Inference Configuration
89
+ # =============================================================================
90
+ CONFIDENCE_THRESHOLD = 0.5 # For binary classification
91
+ GRADCAM_TARGET_LAYER = "features" # EfficientNet feature extractor
92
+
93
+ # =============================================================================
94
+ # Random Seed (for reproducibility)
95
+ # =============================================================================
96
+ SEED = 42
97
+
98
+
99
+ def create_directories():
100
+ """Create all necessary directories if they don't exist."""
101
+ for directory in [DATA_DIR, PROCESSED_DIR, MODEL_DIR, FIGURES_DIR, LOGS_DIR]:
102
+ directory.mkdir(parents=True, exist_ok=True)
103
+
104
+
105
+ if __name__ == "__main__":
106
+ # Print configuration for verification
107
+ print(f"Project Root: {PROJECT_ROOT}")
108
+ print(f"Data Directory: {DATA_DIR}")
109
+ print(f"Model Directory: {MODEL_DIR}")
110
+ print(f"Image Size: {IMAGE_SIZE}")
111
+ print(f"Batch Size: {BATCH_SIZE}")
112
+ print(f"Model: {MODEL_NAME}")
src/dataset.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch Dataset and DataLoader utilities for Chest X-Ray classification.
3
+ """
4
+
5
+ from pathlib import Path
6
+ from typing import Tuple, Optional, List
7
+ import random
8
+
9
+ import torch
10
+ from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
11
+ from torchvision import transforms
12
+ from PIL import Image
13
+ from sklearn.model_selection import train_test_split
14
+
15
+ from .config import (
16
+ DATA_DIR, IMAGE_SIZE, BATCH_SIZE, NUM_WORKERS,
17
+ IMAGENET_MEAN, IMAGENET_STD, CLASS_NAMES, SEED
18
+ )
19
+
20
+
21
+ class ChestXRayDataset(Dataset):
22
+ """Dataset for Chest X-Ray images."""
23
+
24
+ def __init__(
25
+ self,
26
+ image_paths: List[Path],
27
+ labels: List[int],
28
+ transform: Optional[transforms.Compose] = None
29
+ ):
30
+ self.image_paths = image_paths
31
+ self.labels = labels
32
+ self.transform = transform
33
+
34
+ def __len__(self) -> int:
35
+ return len(self.image_paths)
36
+
37
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
38
+ img_path = self.image_paths[idx]
39
+ label = self.labels[idx]
40
+
41
+ # Load image and convert to RGB
42
+ image = Image.open(img_path).convert('RGB')
43
+
44
+ if self.transform:
45
+ image = self.transform(image)
46
+
47
+ return image, label
48
+
49
+
50
+ def get_transforms(is_training: bool = True) -> transforms.Compose:
51
+ """Get image transforms for training or validation/test."""
52
+ if is_training:
53
+ return transforms.Compose([
54
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
55
+ transforms.RandomHorizontalFlip(p=0.5),
56
+ transforms.RandomRotation(10),
57
+ transforms.ColorJitter(brightness=0.2, contrast=0.2),
58
+ transforms.ToTensor(),
59
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
60
+ ])
61
+ else:
62
+ return transforms.Compose([
63
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
64
+ transforms.ToTensor(),
65
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
66
+ ])
67
+
68
+
69
+ def load_image_paths_and_labels(
70
+ data_dir: Path,
71
+ split: str
72
+ ) -> Tuple[List[Path], List[int]]:
73
+ """Load image paths and labels from a data split directory."""
74
+ image_paths = []
75
+ labels = []
76
+
77
+ for class_idx, class_name in enumerate(CLASS_NAMES):
78
+ class_dir = data_dir / split / class_name
79
+ if class_dir.exists():
80
+ for img_path in class_dir.glob('*.jpeg'):
81
+ image_paths.append(img_path)
82
+ labels.append(class_idx)
83
+
84
+ return image_paths, labels
85
+
86
+
87
+ def create_train_val_split(
88
+ data_dir: Path = DATA_DIR,
89
+ val_ratio: float = 0.15,
90
+ seed: int = SEED
91
+ ) -> Tuple[List[Path], List[int], List[Path], List[int]]:
92
+ """Create stratified train/val split from training data."""
93
+ # Load all training images
94
+ train_paths, train_labels = load_image_paths_and_labels(data_dir, 'train')
95
+
96
+ # Stratified split
97
+ train_paths, val_paths, train_labels, val_labels = train_test_split(
98
+ train_paths, train_labels,
99
+ test_size=val_ratio,
100
+ stratify=train_labels,
101
+ random_state=seed
102
+ )
103
+
104
+ return train_paths, train_labels, val_paths, val_labels
105
+
106
+
107
+ def get_class_weights(labels: List[int]) -> torch.Tensor:
108
+ """Calculate class weights for imbalanced dataset."""
109
+ class_counts = torch.bincount(torch.tensor(labels))
110
+ total = len(labels)
111
+ weights = total / (len(class_counts) * class_counts.float())
112
+ return weights
113
+
114
+
115
+ def get_sampler(labels: List[int]) -> WeightedRandomSampler:
116
+ """Create weighted sampler for balanced batches."""
117
+ class_weights = get_class_weights(labels)
118
+ sample_weights = [class_weights[label] for label in labels]
119
+ sampler = WeightedRandomSampler(
120
+ weights=sample_weights,
121
+ num_samples=len(labels),
122
+ replacement=True
123
+ )
124
+ return sampler
125
+
126
+
127
+ def get_dataloaders(
128
+ data_dir: Path = DATA_DIR,
129
+ batch_size: int = BATCH_SIZE,
130
+ num_workers: int = NUM_WORKERS,
131
+ val_ratio: float = 0.15,
132
+ use_weighted_sampling: bool = True
133
+ ) -> Tuple[DataLoader, DataLoader, DataLoader]:
134
+ """Create train, validation, and test DataLoaders."""
135
+
136
+ # Create train/val split
137
+ train_paths, train_labels, val_paths, val_labels = create_train_val_split(
138
+ data_dir, val_ratio
139
+ )
140
+
141
+ # Load test data
142
+ test_paths, test_labels = load_image_paths_and_labels(data_dir, 'test')
143
+
144
+ # Create datasets
145
+ train_dataset = ChestXRayDataset(
146
+ train_paths, train_labels, transform=get_transforms(is_training=True)
147
+ )
148
+ val_dataset = ChestXRayDataset(
149
+ val_paths, val_labels, transform=get_transforms(is_training=False)
150
+ )
151
+ test_dataset = ChestXRayDataset(
152
+ test_paths, test_labels, transform=get_transforms(is_training=False)
153
+ )
154
+
155
+ # Create sampler for training if using weighted sampling
156
+ train_sampler = get_sampler(train_labels) if use_weighted_sampling else None
157
+
158
+ # Only use pin_memory for CUDA (not supported on MPS)
159
+ pin_memory = torch.cuda.is_available()
160
+
161
+ # Create dataloaders
162
+ train_loader = DataLoader(
163
+ train_dataset,
164
+ batch_size=batch_size,
165
+ sampler=train_sampler,
166
+ shuffle=(train_sampler is None),
167
+ num_workers=num_workers,
168
+ pin_memory=pin_memory
169
+ )
170
+
171
+ val_loader = DataLoader(
172
+ val_dataset,
173
+ batch_size=batch_size,
174
+ shuffle=False,
175
+ num_workers=num_workers,
176
+ pin_memory=pin_memory
177
+ )
178
+
179
+ test_loader = DataLoader(
180
+ test_dataset,
181
+ batch_size=batch_size,
182
+ shuffle=False,
183
+ num_workers=num_workers,
184
+ pin_memory=pin_memory
185
+ )
186
+
187
+ # Print dataset info
188
+ print(f"Train: {len(train_dataset)} images")
189
+ print(f"Val: {len(val_dataset)} images")
190
+ print(f"Test: {len(test_dataset)} images")
191
+
192
+ return train_loader, val_loader, test_loader
193
+
194
+
195
+ def get_pos_weight(labels: List[int]) -> torch.Tensor:
196
+ """Calculate pos_weight for BCEWithLogitsLoss to handle class imbalance."""
197
+ labels_tensor = torch.tensor(labels)
198
+ neg_count = (labels_tensor == 0).sum().float() # NORMAL
199
+ pos_count = (labels_tensor == 1).sum().float() # PNEUMONIA
200
+ pos_weight = neg_count / pos_count
201
+ return pos_weight
src/evaluate.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluation functions for Pneumonia classification.
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.utils.data import DataLoader
8
+ import numpy as np
9
+ from typing import Dict, Tuple
10
+ from sklearn.metrics import (
11
+ accuracy_score, precision_score, recall_score, f1_score,
12
+ roc_auc_score, confusion_matrix, classification_report
13
+ )
14
+
15
+ from .config import CLASS_NAMES
16
+
17
+
18
+ def predict_proba(
19
+ model: nn.Module,
20
+ loader: DataLoader,
21
+ device: torch.device
22
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
23
+ """Get predictions, probabilities, and true labels."""
24
+ model.eval()
25
+ all_probs, all_preds, all_labels = [], [], []
26
+
27
+ with torch.no_grad():
28
+ for images, labels in loader:
29
+ images = images.to(device)
30
+ outputs = model(images)
31
+ probs = torch.sigmoid(outputs).cpu().numpy()
32
+ preds = (probs > 0.5).astype(int)
33
+
34
+ all_probs.extend(probs.flatten())
35
+ all_preds.extend(preds.flatten())
36
+ all_labels.extend(labels.numpy())
37
+
38
+ return np.array(all_probs), np.array(all_preds), np.array(all_labels)
39
+
40
+
41
+ def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_proba: np.ndarray) -> Dict:
42
+ """Compute all evaluation metrics."""
43
+ return {
44
+ 'accuracy': accuracy_score(y_true, y_pred),
45
+ 'precision': precision_score(y_true, y_pred),
46
+ 'recall': recall_score(y_true, y_pred),
47
+ 'f1': f1_score(y_true, y_pred),
48
+ 'roc_auc': roc_auc_score(y_true, y_proba),
49
+ 'confusion_matrix': confusion_matrix(y_true, y_pred)
50
+ }
51
+
52
+
53
+ def evaluate_model(
54
+ model: nn.Module,
55
+ loader: DataLoader,
56
+ device: torch.device
57
+ ) -> Dict:
58
+ """Full evaluation on a dataset."""
59
+ probs, preds, labels = predict_proba(model, loader, device)
60
+ metrics = compute_metrics(labels, preds, probs)
61
+
62
+ print("=" * 50)
63
+ print("EVALUATION RESULTS")
64
+ print("=" * 50)
65
+ print(f"Accuracy: {metrics['accuracy']:.4f}")
66
+ print(f"Precision: {metrics['precision']:.4f}")
67
+ print(f"Recall: {metrics['recall']:.4f}")
68
+ print(f"F1 Score: {metrics['f1']:.4f}")
69
+ print(f"ROC-AUC: {metrics['roc_auc']:.4f}")
70
+ print("\nConfusion Matrix:")
71
+ print(f" {CLASS_NAMES[0]:>10} {CLASS_NAMES[1]:>10}")
72
+ for i, row in enumerate(metrics['confusion_matrix']):
73
+ print(f" {CLASS_NAMES[i]:>10} {row[0]:>10} {row[1]:>10}")
74
+
75
+ print("\nClassification Report:")
76
+ print(classification_report(labels, preds, target_names=CLASS_NAMES))
77
+
78
+ return metrics
79
+
80
+
81
+ def get_predictions_with_paths(
82
+ model: nn.Module,
83
+ dataset,
84
+ device: torch.device
85
+ ) -> list:
86
+ """Get predictions with image paths for error analysis."""
87
+ model.eval()
88
+ results = []
89
+
90
+ with torch.no_grad():
91
+ for idx in range(len(dataset)):
92
+ image, label = dataset[idx]
93
+ image = image.unsqueeze(0).to(device)
94
+
95
+ output = model(image)
96
+ prob = torch.sigmoid(output).item()
97
+ pred = 1 if prob > 0.5 else 0
98
+
99
+ results.append({
100
+ 'path': dataset.image_paths[idx],
101
+ 'true_label': label,
102
+ 'pred_label': pred,
103
+ 'probability': prob,
104
+ 'correct': pred == label
105
+ })
106
+
107
+ return results
src/export.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ONNX export utilities for model deployment.
3
+
4
+ ONNX (Open Neural Network Exchange) is a universal format that allows
5
+ models to run on different frameworks and platforms:
6
+ - TensorFlow, PyTorch, etc.
7
+ - Mobile devices (iOS, Android)
8
+ - Web browsers (ONNX.js)
9
+ - C++, Java, and other languages
10
+ - Optimized inference servers
11
+ """
12
+
13
+ import torch
14
+ import numpy as np
15
+ from pathlib import Path
16
+ from typing import Tuple, Optional
17
+
18
+ from .config import CHECKPOINT_PATH, MODEL_DIR, IMAGE_SIZE
19
+ from .model import create_model, get_device
20
+
21
+
22
def export_to_onnx(
    checkpoint_path: Path = CHECKPOINT_PATH,
    output_path: Optional[Path] = None,
    opset_version: int = 18
) -> Path:
    """
    Export PyTorch model to ONNX format.

    Args:
        checkpoint_path: Path to the PyTorch checkpoint
        output_path: Path for the ONNX model (default: models/best_model.onnx)
        opset_version: ONNX opset version (default 18; lower it if the
            target runtime only supports older opsets)

    Returns:
        Path to the exported ONNX model
    """
    # Coerce to Path so .stat() below works even when callers pass strings.
    if output_path is None:
        output_path = MODEL_DIR / "best_model.onnx"
    else:
        output_path = Path(output_path)
    checkpoint_path = Path(checkpoint_path)

    # Export on CPU: no accelerator is needed and the resulting graph
    # stays device-agnostic.
    device = torch.device("cpu")
    model = create_model(pretrained=False, freeze_backbone=False, device=device)

    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Dummy input fixes the graph's input rank: (batch, 3, H, W).
    dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE)

    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,  # fold constant subgraphs at export time
        input_names=['image'],
        output_names=['logits'],
        dynamic_axes={
            'image': {0: 'batch_size'},  # Variable batch size
            'logits': {0: 'batch_size'}
        }
    )

    print(f"Model exported to: {output_path}")
    print(f"File size: {output_path.stat().st_size / 1024 / 1024:.2f} MB")

    return output_path
72
+
73
+
74
def validate_onnx_model(
    onnx_path: Path,
    checkpoint_path: Path = CHECKPOINT_PATH,
    rtol: float = 1e-3,
    atol: float = 1e-5
) -> bool:
    """
    Validate that ONNX model produces same outputs as PyTorch model.

    Args:
        onnx_path: Path to ONNX model
        checkpoint_path: Path to PyTorch checkpoint
        rtol: Relative tolerance for comparison
        atol: Absolute tolerance for comparison

    Returns:
        True if outputs match, False otherwise
    """
    # Imported lazily so the rest of the module works without onnx installed.
    import onnx
    import onnxruntime as ort

    # Structural check: graph is well-formed and internally consistent.
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print("ONNX model structure is valid")

    # Rebuild the PyTorch reference model on CPU with the same weights.
    device = torch.device("cpu")
    model = create_model(pretrained=False, freeze_backbone=False, device=device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # A single random input is enough for a numerical parity spot-check.
    test_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE)

    # Reference output from the PyTorch model.
    with torch.no_grad():
        pytorch_output = model(test_input).numpy()

    # Same input through ONNX Runtime; 'image' matches the exported input name.
    ort_session = ort.InferenceSession(str(onnx_path))
    onnx_output = ort_session.run(
        None,
        {'image': test_input.numpy()}
    )[0]

    # Element-wise comparison within the given tolerances.
    is_close = np.allclose(pytorch_output, onnx_output, rtol=rtol, atol=atol)

    if is_close:
        print("Validation PASSED: ONNX outputs match PyTorch outputs")
        print(f" PyTorch output: {pytorch_output.flatten()[:5]}...")
        print(f" ONNX output: {onnx_output.flatten()[:5]}...")
    else:
        print("Validation FAILED: Outputs do not match!")
        print(f" Max difference: {np.max(np.abs(pytorch_output - onnx_output))}")

    return is_close
133
+
134
+
135
def predict_with_onnx(
    onnx_path: Path,
    image_tensor: np.ndarray
) -> Tuple[str, float]:
    """
    Run inference using ONNX Runtime.

    Args:
        onnx_path: Path to ONNX model
        image_tensor: Preprocessed image as numpy array (1, 3, 224, 224)

    Returns:
        Tuple of (predicted_class, confidence)
    """
    import onnxruntime as ort
    from .config import CLASS_NAMES

    # One-off session; 'image' matches the name used at export time.
    session = ort.InferenceSession(str(onnx_path))
    outputs = session.run(
        None,
        {'image': image_tensor.astype(np.float32)}
    )
    raw_logit = outputs[0][0, 0]

    # Sigmoid maps the single logit to a pneumonia probability.
    pneumonia_prob = 1 / (1 + np.exp(-raw_logit))

    if pneumonia_prob > 0.5:
        label, score = CLASS_NAMES[1], pneumonia_prob
    else:
        label, score = CLASS_NAMES[0], 1 - pneumonia_prob

    return label, float(score)
167
+
168
+
169
if __name__ == "__main__":
    # Smoke-test the full pipeline: export -> validate -> run inference.

    # Export model
    print("=" * 50)
    print("EXPORTING MODEL TO ONNX")
    print("=" * 50)

    onnx_path = export_to_onnx()

    # Numerical parity check against the PyTorch reference model.
    print("\n" + "=" * 50)
    print("VALIDATING ONNX MODEL")
    print("=" * 50)

    validate_onnx_model(onnx_path)

    print("\n" + "=" * 50)
    print("TESTING ONNX INFERENCE")
    print("=" * 50)

    # Test with random input (noise, not a real X-ray — checks plumbing only).
    test_input = np.random.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE).astype(np.float32)
    pred_class, confidence = predict_with_onnx(onnx_path, test_input)
    print(f"Test prediction: {pred_class} ({confidence:.1%})")
src/gradcam.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grad-CAM visualization for model interpretability.
3
+ """
4
+
5
+ import torch
6
+ import numpy as np
7
+ from PIL import Image
8
+ from pathlib import Path
9
+ from typing import Union
10
+ import matplotlib.pyplot as plt
11
+
12
+ from pytorch_grad_cam import GradCAM
13
+ from pytorch_grad_cam.utils.image import show_cam_on_image
14
+
15
+ from .dataset import get_transforms
16
+ from .config import IMAGENET_MEAN, IMAGENET_STD, CLASS_NAMES
17
+
18
+
19
def get_gradcam(model, target_layer=None):
    """Build a GradCAM instance for ``model``.

    Defaults to the final feature block of the EfficientNet backbone when no
    target layer is supplied — the deepest conv activations available.
    """
    chosen_layer = model.backbone.features[-1] if target_layer is None else target_layer
    return GradCAM(model=model, target_layers=[chosen_layer])
25
+
26
+
27
def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
    """Undo ImageNet normalization; return an HWC float image clipped to [0, 1]."""
    channel_mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    channel_std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    restored = tensor.cpu() * channel_std + channel_mean
    # CHW -> HWC for matplotlib/PIL consumption.
    return np.clip(restored.permute(1, 2, 0).numpy(), 0, 1)
34
+
35
+
36
def generate_gradcam(
    model,
    image: Union[str, Path, Image.Image],
    device: torch.device
) -> tuple:
    """Produce a Grad-CAM overlay plus the model's prediction for one image.

    Returns (cam_image, predicted_class, confidence, denormalized_rgb).
    """
    model.eval()

    # Accept either a path or an already-loaded PIL image.
    if isinstance(image, (str, Path)):
        image = Image.open(image).convert('RGB')

    preprocess = get_transforms(is_training=False)
    input_batch = preprocess(image).unsqueeze(0).to(device)

    # Forward pass; no gradients needed for the probability itself.
    with torch.no_grad():
        logit = model(input_batch)
        pneumonia_prob = torch.sigmoid(logit).item()

    if pneumonia_prob > 0.5:
        pred_class, confidence = CLASS_NAMES[1], pneumonia_prob
    else:
        pred_class, confidence = CLASS_NAMES[0], 1 - pneumonia_prob

    # Grad-CAM performs its own backward pass internally.
    cam = get_gradcam(model)
    heatmap = cam(input_tensor=input_batch, targets=None)[0]

    rgb_img = denormalize_image(input_batch[0])
    cam_image = show_cam_on_image(rgb_img, heatmap, use_rgb=True)

    return cam_image, pred_class, confidence, rgb_img
68
+
69
+
70
def plot_gradcam(
    model,
    image_path: Union[str, Path],
    true_label: str,
    device: torch.device,
    save_path: str = None
):
    """Show an X-ray next to its Grad-CAM overlay; returns (pred, confidence)."""
    overlay, predicted, conf, raw_img = generate_gradcam(model, image_path, device)

    _, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(10, 4))

    # Left panel: untouched input with the ground-truth label.
    ax_left.imshow(raw_img)
    ax_left.set_title(f"Original\nTrue: {true_label}")
    ax_left.axis('off')

    # Right panel: heatmap overlay; green title = correct, red = wrong.
    title_color = 'green' if predicted == true_label else 'red'
    ax_right.imshow(overlay)
    ax_right.set_title(f"Grad-CAM\nPred: {predicted} ({conf:.1%})", color=title_color)
    ax_right.axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    plt.show()
    return predicted, conf
100
+
101
+
102
def plot_gradcam_grid(
    model,
    image_paths: list,
    true_labels: list,
    device: torch.device,
    save_path: str = None,
    title: str = "Grad-CAM Visualizations"
):
    """Stack (original, Grad-CAM) pairs for several images into one figure."""
    n_rows = len(image_paths)
    fig, axes = plt.subplots(n_rows, 2, figsize=(8, 3 * n_rows))

    # plt.subplots returns a 1-D array for a single row; force 2-D so the
    # indexing below is uniform.
    if n_rows == 1:
        axes = axes.reshape(1, -1)

    for row, (img_path, truth) in enumerate(zip(image_paths, true_labels)):
        overlay, predicted, conf, raw_img = generate_gradcam(model, img_path, device)

        # Left column: the denormalized input.
        axes[row, 0].imshow(raw_img)
        axes[row, 0].set_title(f"True: {truth}")
        axes[row, 0].axis('off')

        # Right column: heatmap; green title = correct, red = wrong.
        outcome_color = 'green' if predicted == truth else 'red'
        axes[row, 1].imshow(overlay)
        axes[row, 1].set_title(f"Pred: {predicted} ({conf:.1%})", color=outcome_color)
        axes[row, 1].axis('off')

    plt.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    plt.show()
src/model.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ EfficientNet-B0 model for Pneumonia classification.
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import models
8
+ from typing import Tuple
9
+
10
+ from .config import DROPOUT_RATE, NUM_CLASSES
11
+
12
+
13
class PneumoniaClassifier(nn.Module):
    """EfficientNet-B0 based classifier for chest X-ray pneumonia detection."""

    def __init__(
        self,
        pretrained: bool = True,
        dropout_rate: float = DROPOUT_RATE,
        freeze_backbone: bool = True
    ):
        super().__init__()

        # ImageNet weights when requested, random init otherwise.
        self.backbone = models.efficientnet_b0(
            weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        )

        # Swap the stock 1000-way head for a NUM_CLASSES-output head,
        # keeping the backbone's feature width (1280 for B0).
        feature_dim = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(p=dropout_rate, inplace=True),
            nn.Linear(feature_dim, NUM_CLASSES)
        )

        if freeze_backbone:
            self.freeze_backbone()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.backbone(x)

    def freeze_backbone(self):
        """Freeze all layers except the classifier."""
        for p in self.backbone.features.parameters():
            p.requires_grad = False

    def unfreeze_backbone(self):
        """Unfreeze all layers for fine-tuning."""
        for p in self.backbone.features.parameters():
            p.requires_grad = True

    def get_param_counts(self) -> Tuple[int, int]:
        """Return (trainable_params, total_params)."""
        trainable = total = 0
        for p in self.parameters():
            n = p.numel()
            total += n
            if p.requires_grad:
                trainable += n
        return trainable, total
59
+
60
+
61
def create_model(
    pretrained: bool = True,
    dropout_rate: float = DROPOUT_RATE,
    freeze_backbone: bool = True,
    device: str = None
) -> PneumoniaClassifier:
    """Factory function to create the model, moved to the chosen device.

    Device preference when unspecified: MPS, then CUDA, then CPU.
    """
    if device is None:
        if torch.backends.mps.is_available():
            device = "mps"
        elif torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"

    classifier = PneumoniaClassifier(
        pretrained=pretrained,
        dropout_rate=dropout_rate,
        freeze_backbone=freeze_backbone
    )
    return classifier.to(device)
79
+
80
+
81
def get_device() -> torch.device:
    """Return the best available device: MPS, then CUDA, then CPU."""
    for name, available in (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    ):
        if available():
            return torch.device(name)
    return torch.device("cpu")
src/predict.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference functions for Pneumonia classification.
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from PIL import Image
8
+ from pathlib import Path
9
+ from typing import Union, Tuple
10
+
11
+ from .dataset import get_transforms
12
+ from .config import CLASS_NAMES, CHECKPOINT_PATH
13
+
14
+
15
def load_model(model: nn.Module, checkpoint_path: Path = CHECKPOINT_PATH, device: str = "cpu") -> nn.Module:
    """Restore weights from a training checkpoint and switch to eval mode.

    NOTE(review): ``torch.load`` unpickles the full checkpoint dict — fine
    for checkpoints produced by this project's training loop, but do not
    point it at untrusted files.
    """
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    model.eval()
    return model
21
+
22
+
23
def predict_image(
    model: nn.Module,
    image: Union[str, Path, Image.Image],
    device: torch.device
) -> Tuple[str, float]:
    """Classify a single chest X-ray; returns (class_name, confidence)."""
    model.eval()

    # Accept a path or an already-loaded PIL image.
    if isinstance(image, (str, Path)):
        image = Image.open(image).convert('RGB')

    # Validation-time preprocessing, with a singleton batch dimension.
    batch = get_transforms(is_training=False)(image).unsqueeze(0).to(device)

    with torch.no_grad():
        pneumonia_prob = torch.sigmoid(model(batch)).item()

    # Confidence is the probability of whichever class was chosen.
    if pneumonia_prob > 0.5:
        return CLASS_NAMES[1], pneumonia_prob
    return CLASS_NAMES[0], 1 - pneumonia_prob
src/train.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training pipeline for Pneumonia classification.
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.optim import AdamW
8
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
9
+ from torch.utils.data import DataLoader
10
+ from pathlib import Path
11
+ from typing import Dict, Optional, Tuple
12
+ import time
13
+
14
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
15
+
16
+ from .config import (
17
+ STAGE1_EPOCHS, STAGE1_LR, STAGE2_EPOCHS, STAGE2_LR,
18
+ WEIGHT_DECAY, SCHEDULER_PATIENCE, SCHEDULER_FACTOR,
19
+ EARLY_STOP_PATIENCE, CHECKPOINT_PATH, MODEL_DIR
20
+ )
21
+ from .model import PneumoniaClassifier, get_device
22
+
23
+
24
class EarlyStopping:
    """Stop training once validation loss has plateaued.

    A loss counts as an improvement only when it beats the best seen so far
    by more than ``min_delta``; after ``patience`` consecutive
    non-improvements the stop flag latches on.
    """

    def __init__(self, patience: int = 7, min_delta: float = 0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0              # consecutive epochs without improvement
        self.best_loss = float('inf')
        self.should_stop = False

    def __call__(self, val_loss: float) -> bool:
        improved = val_loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        return self.should_stop
43
+
44
+
45
def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device
) -> Tuple[float, float]:
    """Run one optimization epoch; returns (mean_loss, accuracy)."""
    model.train()
    running_loss = 0.0
    epoch_preds, epoch_labels = [], []

    for batch_imgs, batch_labels in loader:
        batch_imgs = batch_imgs.to(device)
        # BCEWithLogitsLoss expects float targets shaped (N, 1).
        targets = batch_labels.float().unsqueeze(1).to(device)

        optimizer.zero_grad()
        logits = model(batch_imgs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the mean is exact even for a ragged last batch.
        running_loss += loss.item() * batch_imgs.size(0)
        epoch_preds.extend((torch.sigmoid(logits) > 0.5).int().cpu().numpy())
        epoch_labels.extend(targets.cpu().numpy())

    mean_loss = running_loss / len(loader.dataset)
    return mean_loss, accuracy_score(epoch_labels, epoch_preds)
75
+
76
+
77
def validate(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: torch.device
) -> Dict[str, float]:
    """Evaluate on a loader; returns loss/accuracy/precision/recall/f1."""
    model.eval()
    running_loss = 0.0
    collected_preds, collected_labels = [], []

    with torch.no_grad():
        for batch_imgs, batch_labels in loader:
            batch_imgs = batch_imgs.to(device)
            # Float (N, 1) targets to match the logits' shape.
            targets = batch_labels.float().unsqueeze(1).to(device)

            logits = model(batch_imgs)
            running_loss += criterion(logits, targets).item() * batch_imgs.size(0)

            collected_preds.extend((torch.sigmoid(logits) > 0.5).int().cpu().numpy())
            collected_labels.extend(targets.cpu().numpy())

    return {
        'loss': running_loss / len(loader.dataset),
        'accuracy': accuracy_score(collected_labels, collected_preds),
        # zero_division=0 avoids warnings when a class is never predicted.
        'precision': precision_score(collected_labels, collected_preds, zero_division=0),
        'recall': recall_score(collected_labels, collected_preds, zero_division=0),
        'f1': f1_score(collected_labels, collected_preds, zero_division=0)
    }
110
+
111
+
112
def train(
    model: PneumoniaClassifier,
    train_loader: DataLoader,
    val_loader: DataLoader,
    pos_weight: torch.Tensor,
    epochs: int,
    lr: float,
    device: torch.device,
    stage: str = "stage1",
    use_wandb: bool = True,
    wandb_run = None
) -> Dict[str, list]:
    """Training loop with validation.

    Trains for up to ``epochs`` epochs, checkpointing the best model by
    validation loss and stopping early when the loss plateaus.

    Args:
        model: Classifier to train (only requires_grad params are optimized).
        train_loader: Loader for the training split.
        val_loader: Loader for the validation split.
        pos_weight: Positive-class weight for BCEWithLogitsLoss (imbalance).
        epochs: Maximum number of epochs.
        lr: Initial learning rate for AdamW.
        device: Device to train on.
        stage: Label prefixed to console/W&B metric names.
        use_wandb: Whether to log metrics to Weights & Biases.
        wandb_run: Active W&B run object (required when use_wandb is True).

    Returns:
        Dict of per-epoch lists: train_loss, val_loss, val_acc, val_f1, lr.
    """

    # pos_weight must live on the same device as the logits.
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))
    # Optimize only unfrozen parameters (stage 1 freezes the backbone).
    optimizer = AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=lr,
        weight_decay=WEIGHT_DECAY
    )
    # Cut the LR when validation loss stops improving.
    scheduler = ReduceLROnPlateau(
        optimizer, mode='min',
        patience=SCHEDULER_PATIENCE,
        factor=SCHEDULER_FACTOR
    )
    early_stopping = EarlyStopping(patience=EARLY_STOP_PATIENCE)

    history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'val_f1': [], 'lr': []}
    best_val_loss = float('inf')

    MODEL_DIR.mkdir(parents=True, exist_ok=True)

    for epoch in range(epochs):
        start = time.time()

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_metrics = validate(model, val_loader, criterion, device)

        # Get current LR (read before scheduler.step so the log reflects the
        # LR this epoch actually trained with).
        current_lr = optimizer.param_groups[0]['lr']

        # Update scheduler
        scheduler.step(val_metrics['loss'])

        # Log
        elapsed = time.time() - start
        print(f"[{stage}] Epoch {epoch+1}/{epochs} ({elapsed:.1f}s) | "
              f"Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_metrics['loss']:.4f} | "
              f"Val Acc: {val_metrics['accuracy']:.3f} | "
              f"Val F1: {val_metrics['f1']:.3f} | "
              f"LR: {current_lr:.2e}")

        # W&B logging
        if use_wandb and wandb_run:
            wandb_run.log({
                f"{stage}/train_loss": train_loss,
                f"{stage}/train_acc": train_acc,
                f"{stage}/val_loss": val_metrics['loss'],
                f"{stage}/val_acc": val_metrics['accuracy'],
                f"{stage}/val_precision": val_metrics['precision'],
                f"{stage}/val_recall": val_metrics['recall'],
                f"{stage}/val_f1": val_metrics['f1'],
                f"{stage}/lr": current_lr,
                "epoch": epoch + 1
            })

        # Save history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_metrics['loss'])
        history['val_acc'].append(val_metrics['accuracy'])
        history['val_f1'].append(val_metrics['f1'])
        history['lr'].append(current_lr)

        # Save best model (selection criterion: lowest validation loss).
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': best_val_loss,
                'val_metrics': val_metrics
            }, CHECKPOINT_PATH)
            print(f" -> Saved best model (val_loss: {best_val_loss:.4f})")

        # Early stopping
        if early_stopping(val_metrics['loss']):
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

    return history
207
+
208
+
209
def train_two_stage(
    model: PneumoniaClassifier,
    train_loader: DataLoader,
    val_loader: DataLoader,
    pos_weight: torch.Tensor,
    device: torch.device,
    use_wandb: bool = True,
    wandb_run = None
) -> Dict[str, list]:
    """Two-stage training: frozen backbone then fine-tuning."""

    def _banner(message: str):
        # Shared section header used by both stages.
        print("\n" + "=" * 60)
        print(message)
        print("=" * 60)

    def _report_params():
        trainable, total = model.get_param_counts()
        print(f"Trainable params: {trainable:,} / {total:,}")

    # Stage 1: only the new classifier head learns.
    _banner("STAGE 1: Training classifier (backbone frozen)")
    model.freeze_backbone()
    _report_params()

    stage1_history = train(
        model, train_loader, val_loader, pos_weight,
        epochs=STAGE1_EPOCHS, lr=STAGE1_LR, device=device,
        stage="stage1", use_wandb=use_wandb, wandb_run=wandb_run
    )

    # Stage 2: unfreeze everything and fine-tune at a lower LR.
    _banner("STAGE 2: Fine-tuning entire network")
    model.unfreeze_backbone()
    _report_params()

    stage2_history = train(
        model, train_loader, val_loader, pos_weight,
        epochs=STAGE2_EPOCHS, lr=STAGE2_LR, device=device,
        stage="stage2", use_wandb=use_wandb, wandb_run=wandb_run
    )

    # Concatenate the per-stage curves into one continuous history.
    return {key: stage1_history[key] + stage2_history[key] for key in stage1_history}
src/utils.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for visualization and helpers.
3
+ """
4
+
5
+ import torch
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ from typing import Optional
9
+
10
+ from .config import IMAGENET_MEAN, IMAGENET_STD, CLASS_NAMES
11
+
12
+
13
def denormalize(tensor: torch.Tensor) -> torch.Tensor:
    """Invert ImageNet channel normalization (CHW tensor in, CHW tensor out)."""
    channel_mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    channel_std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    return tensor * channel_std + channel_mean
18
+
19
+
20
def show_batch(
    images: torch.Tensor,
    labels: torch.Tensor,
    predictions: Optional[torch.Tensor] = None,
    n_images: int = 8,
    save_path: Optional[str] = None
):
    """Display a batch of images in a grid with true (and predicted) labels.

    Args:
        images: Batch of normalized image tensors, shape (N, 3, H, W).
        labels: Integer class indices, shape (N,).
        predictions: Optional predicted class indices; when given, titles
            are colored green (correct) or red (wrong).
        n_images: Maximum number of images to show.
        save_path: Optional path to save the figure.
    """
    n_images = min(n_images, len(images))
    cols = 4
    rows = (n_images + cols - 1) // cols

    fig, axes = plt.subplots(rows, cols, figsize=(12, 3 * rows))
    # plt.subplots returns a scalar Axes, a 1-D array, or a 2-D array
    # depending on (rows, cols); np.ravel normalizes all three to a flat
    # 1-D array so the indexing below never breaks.
    axes = np.ravel(axes)

    for idx in range(n_images):
        img = denormalize(images[idx]).permute(1, 2, 0).numpy()
        img = np.clip(img, 0, 1)

        axes[idx].imshow(img)
        axes[idx].axis('off')

        label = CLASS_NAMES[labels[idx]]
        title = f"True: {label}"

        if predictions is not None:
            pred = CLASS_NAMES[predictions[idx]]
            color = 'green' if pred == label else 'red'
            title += f"\nPred: {pred}"
            axes[idx].set_title(title, color=color, fontsize=10)
        else:
            axes[idx].set_title(title, fontsize=10)

    # Hide empty subplots
    for idx in range(n_images, len(axes)):
        axes[idx].axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    plt.show()
63
+
64
+
65
def set_seed(seed: int = 42):
    """Make runs reproducible by seeding python, numpy, and torch RNGs."""
    import random

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Accelerator-specific generators are seeded only when the backend exists.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)