Files changed (2) hide show
  1. hf_requirements.txt +17 -0
  2. streamlit_app.py +278 -0
hf_requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hugging Face Spaces Requirements
2
+ # Minimal dependencies for deployment
3
+
4
+ # PyTorch (CPU version for HF Spaces)
5
+ --extra-index-url https://download.pytorch.org/whl/cpu
6
+ torch==2.1.0
7
+ torchvision==0.16.0
8
+
9
+ # Core
10
+ pillow>=10.0.0
11
+ numpy>=1.24.0
12
+
13
+ # Model Interpretability
14
+ grad-cam>=1.4.0
15
+
16
+ # Web UI
17
+ streamlit>=1.28.0
streamlit_app.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Streamlit Web UI for Pneumonia Detection.
3
+
4
+ Run with: streamlit run app/streamlit_app.py
5
+ """
6
+
7
+ import sys
8
+ from pathlib import Path
9
+
10
+ # Add project root to path
11
+ sys.path.insert(0, str(Path(__file__).parent.parent))
12
+
13
+ import streamlit as st
14
+ import torch
15
+ from PIL import Image
16
+ import time
17
+
18
+ from src.config import CHECKPOINT_PATH, CLASS_NAMES
19
+ from src.model import create_model, get_device
20
+ from src.predict import load_model, predict_image
21
+ from src.gradcam import generate_gradcam
22
+
23
+ # =============================================================================
24
+ # Page Configuration
25
+ # =============================================================================
26
+
27
+ st.set_page_config(
28
+ page_title="Pneumonia Detection",
29
+ page_icon="🫁",
30
+ layout="wide",
31
+ initial_sidebar_state="expanded"
32
+ )
33
+
34
+ # =============================================================================
35
+ # Custom CSS
36
+ # =============================================================================
37
+
38
+ st.markdown("""
39
+ <style>
40
+ .main-header {
41
+ font-size: 2.5rem;
42
+ font-weight: bold;
43
+ color: #1E88E5;
44
+ text-align: center;
45
+ margin-bottom: 0.5rem;
46
+ }
47
+ .sub-header {
48
+ font-size: 1.1rem;
49
+ color: #666;
50
+ text-align: center;
51
+ margin-bottom: 2rem;
52
+ }
53
+ .prediction-box {
54
+ padding: 1.5rem;
55
+ border-radius: 10px;
56
+ text-align: center;
57
+ margin: 1rem 0;
58
+ }
59
+ .prediction-normal {
60
+ background-color: #E8F5E9;
61
+ border: 2px solid #4CAF50;
62
+ }
63
+ .prediction-pneumonia {
64
+ background-color: #FFEBEE;
65
+ border: 2px solid #F44336;
66
+ }
67
+ .confidence-text {
68
+ font-size: 1.2rem;
69
+ font-weight: bold;
70
+ }
71
+ .metric-card {
72
+ background-color: #f8f9fa;
73
+ padding: 1rem;
74
+ border-radius: 8px;
75
+ text-align: center;
76
+ }
77
+ </style>
78
+ """, unsafe_allow_html=True)
79
+
80
+ # =============================================================================
81
+ # Model Loading (Cached)
82
+ # =============================================================================
83
+
84
+ @st.cache_resource
85
+ def load_model_cached():
86
+ """Load model once and cache it."""
87
+ device = get_device()
88
+ model = create_model(pretrained=False, freeze_backbone=False, device=device)
89
+ model = load_model(model, CHECKPOINT_PATH, device)
90
+ return model, device
91
+
92
+
93
+ # =============================================================================
94
+ # Sidebar
95
+ # =============================================================================
96
+
97
+ with st.sidebar:
98
+ st.image("https://img.icons8.com/fluency/96/lungs.png", width=80)
99
+ st.title("About")
100
+
101
+ st.markdown("""
102
+ This application uses deep learning to detect **pneumonia** from chest X-ray images.
103
+
104
+ **Model:** EfficientNet-B0
105
+ **Accuracy:** 90.5%
106
+ **Recall:** 98.2%
107
+ """)
108
+
109
+ st.divider()
110
+
111
+ st.subheader("How to Use")
112
+ st.markdown("""
113
+ 1. Upload a chest X-ray image
114
+ 2. Click **Analyze Image**
115
+ 3. View prediction and Grad-CAM
116
+ """)
117
+
118
+ st.divider()
119
+
120
+ st.subheader("Model Metrics")
121
+ col1, col2 = st.columns(2)
122
+ with col1:
123
+ st.metric("Accuracy", "90.5%")
124
+ st.metric("Precision", "88.0%")
125
+ with col2:
126
+ st.metric("Recall", "98.2%")
127
+ st.metric("F1 Score", "92.8%")
128
+
129
+ st.divider()
130
+
131
+ st.markdown("""
132
+ **Links:**
133
+ [GitHub Repository](#) | [Live Demo](#)
134
+
135
+ ---
136
+ *Built with PyTorch & Streamlit*
137
+ """)
138
+
139
+ # =============================================================================
140
+ # Main Content
141
+ # =============================================================================
142
+
143
+ # Header
144
+ st.markdown('<p class="main-header">🫁 Pneumonia Detection from Chest X-Rays</p>', unsafe_allow_html=True)
145
+ st.markdown('<p class="sub-header">Upload a chest X-ray image to detect pneumonia using AI</p>', unsafe_allow_html=True)
146
+
147
+ # Load model
148
+ try:
149
+ model, device = load_model_cached()
150
+ model_loaded = True
151
+ except Exception as e:
152
+ st.error(f"Failed to load model: {e}")
153
+ model_loaded = False
154
+
155
+ if model_loaded:
156
+ # Create columns for layout
157
+ col1, col2 = st.columns([1, 1])
158
+
159
+ with col1:
160
+ st.subheader("πŸ“€ Upload Image")
161
+
162
+ uploaded_file = st.file_uploader(
163
+ "Choose a chest X-ray image",
164
+ type=["jpg", "jpeg", "png"],
165
+ help="Supported formats: JPG, JPEG, PNG"
166
+ )
167
+
168
+ # Sample images section
169
+ st.markdown("---")
170
+ st.markdown("**Or try a sample image:**")
171
+
172
+ sample_col1, sample_col2 = st.columns(2)
173
+
174
+ use_sample = None
175
+ with sample_col1:
176
+ if st.button("🟒 Normal Sample", width="stretch"):
177
+ use_sample = "normal"
178
+ with sample_col2:
179
+ if st.button("πŸ”΄ Pneumonia Sample", width="stretch"):
180
+ use_sample = "pneumonia"
181
+
182
+ # Load sample image if selected
183
+ if use_sample == "normal":
184
+ sample_path = Path("data/raw/test/NORMAL/IM-0001-0001.jpeg")
185
+ if sample_path.exists():
186
+ uploaded_file = sample_path
187
+ elif use_sample == "pneumonia":
188
+ sample_path = Path("data/raw/test/PNEUMONIA/person1_virus_6.jpeg")
189
+ if sample_path.exists():
190
+ uploaded_file = sample_path
191
+
192
+ with col2:
193
+ st.subheader("πŸ” Analysis Results")
194
+ results_placeholder = st.empty()
195
+
196
+ # Process image if uploaded
197
+ if uploaded_file is not None:
198
+ # Load image
199
+ if isinstance(uploaded_file, Path):
200
+ image = Image.open(uploaded_file).convert("RGB")
201
+ st.session_state['image_source'] = str(uploaded_file)
202
+ else:
203
+ image = Image.open(uploaded_file).convert("RGB")
204
+ st.session_state['image_source'] = uploaded_file.name
205
+
206
+ # Display uploaded image
207
+ with col1:
208
+ st.image(image, caption="Uploaded X-Ray", width="stretch")
209
+
210
+ # Analyze button
211
+ with col1:
212
+ analyze_button = st.button("πŸ”¬ Analyze Image", type="primary", width="stretch")
213
+
214
+ if analyze_button:
215
+ with col2:
216
+ with st.spinner("Analyzing image..."):
217
+ # Run prediction
218
+ start_time = time.time()
219
+ pred_class, confidence = predict_image(model, image, device)
220
+ inference_time = (time.time() - start_time) * 1000
221
+
222
+ # Generate Grad-CAM
223
+ cam_image, _, _, original = generate_gradcam(model, image, device)
224
+
225
+ # Display results
226
+ if pred_class == "PNEUMONIA":
227
+ st.markdown(f"""
228
+ <div class="prediction-box prediction-pneumonia">
229
+ <h2 style="color: #F44336; margin: 0;">⚠️ PNEUMONIA DETECTED</h2>
230
+ <p class="confidence-text">Confidence: {confidence:.1%}</p>
231
+ </div>
232
+ """, unsafe_allow_html=True)
233
+ else:
234
+ st.markdown(f"""
235
+ <div class="prediction-box prediction-normal">
236
+ <h2 style="color: #4CAF50; margin: 0;">βœ… NORMAL</h2>
237
+ <p class="confidence-text">Confidence: {confidence:.1%}</p>
238
+ </div>
239
+ """, unsafe_allow_html=True)
240
+
241
+ # Metrics row
242
+ m1, m2, m3 = st.columns(3)
243
+ with m1:
244
+ st.metric("Prediction", pred_class)
245
+ with m2:
246
+ st.metric("Confidence", f"{confidence:.1%}")
247
+ with m3:
248
+ st.metric("Time", f"{inference_time:.0f}ms")
249
+
250
+ # Grad-CAM visualization
251
+ st.markdown("---")
252
+ st.subheader("πŸ”₯ Grad-CAM Visualization")
253
+ st.caption("Highlighted regions show areas that influenced the prediction")
254
+
255
+ gcol1, gcol2 = st.columns(2)
256
+ with gcol1:
257
+ st.image(original, caption="Original", width="stretch")
258
+ with gcol2:
259
+ st.image(cam_image, caption="Grad-CAM Heatmap", width="stretch")
260
+
261
+ # Disclaimer
262
+ st.warning("""
263
+ **Disclaimer:** This tool is for educational purposes only and should not be used
264
+ for medical diagnosis. Always consult a qualified healthcare professional.
265
+ """)
266
+
267
+ else:
268
+ st.error("Model could not be loaded. Please check the model file exists.")
269
+
270
+ # =============================================================================
271
+ # Footer
272
+ # =============================================================================
273
+
274
+ st.markdown("---")
275
+ st.markdown(
276
+ "<p style='text-align: center; color: #888;'>Built with ❀️ using PyTorch, EfficientNet-B0, and Streamlit</p>",
277
+ unsafe_allow_html=True
278
+ )