Ameya729 commited on
Commit
5ab3ef6
Β·
verified Β·
1 Parent(s): 56ec9ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +345 -345
app.py CHANGED
@@ -1,345 +1,345 @@
1
- """
2
- Streamlit Application for Automated Tablet Defect Detection
3
- """
4
-
5
- import streamlit as st
6
- import torch
7
- import numpy as np
8
- from PIL import Image
9
- import sys
10
- from pathlib import Path
11
- import io
12
-
13
- # Add parent directory to path
14
- sys.path.append(str(Path(__file__).parent.parent))
15
-
16
- import config
17
- from src.feature_extractor import FeatureExtractor, extract_embeddings
18
- from src.padim import PaDiM
19
- from src.visualize import apply_heatmap
20
-
21
-
22
- @st.cache_resource
23
- def load_model():
24
- """Load PaDiM model and feature extractor (cached)"""
25
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
-
27
- # Load PaDiM model
28
- model_path = config.MODEL_DIR / "padim_model.pkl"
29
-
30
- if not model_path.exists():
31
- st.error("❌ Model file not found. Please train the model first.")
32
- st.info("To train the model, run: `python train.py` in your terminal")
33
- st.stop()
34
-
35
- padim_model = PaDiM()
36
- padim_model.load(model_path)
37
-
38
- # Load feature extractor
39
- extractor = FeatureExtractor(
40
- backbone=config.BACKBONE,
41
- layers=config.FEATURE_LAYERS
42
- ).to(device)
43
-
44
- return padim_model, extractor, device
45
-
46
-
47
- def preprocess_image(image: Image.Image) -> torch.Tensor:
48
- """Preprocess uploaded image"""
49
- from torchvision import transforms
50
-
51
- transform = transforms.Compose([
52
- transforms.Resize(config.IMAGE_SIZE),
53
- transforms.ToTensor(),
54
- transforms.Normalize(mean=config.MEAN, std=config.STD)
55
- ])
56
-
57
- return transform(image).unsqueeze(0) # Add batch dimension
58
-
59
-
60
- def predict_defect(image: Image.Image, padim_model, extractor, device):
61
- """Run inference on uploaded image"""
62
-
63
- # Preprocess
64
- img_tensor = preprocess_image(image).to(device)
65
-
66
- # Extract embeddings
67
- with torch.no_grad():
68
- embeddings = extract_embeddings(extractor, img_tensor)
69
-
70
- # Predict
71
- embeddings_np = embeddings.cpu().numpy()
72
- anomaly_score, anomaly_map = padim_model.predict(embeddings_np)
73
-
74
- return anomaly_score, anomaly_map
75
-
76
-
77
- def main():
78
- """Main Streamlit app"""
79
-
80
- # Page configuration
81
- st.set_page_config(
82
- page_title="Tablet Defect Detection",
83
- page_icon="πŸ’Š",
84
- layout="wide",
85
- initial_sidebar_state="expanded"
86
- )
87
-
88
- # Custom CSS
89
- st.markdown("""
90
- <style>
91
- .main-header {
92
- font-size: 2.5rem;
93
- font-weight: 700;
94
- color: #1f77b4;
95
- text-align: center;
96
- margin-bottom: 1rem;
97
- }
98
- .subtitle {
99
- text-align: center;
100
- color: #666;
101
- margin-bottom: 2rem;
102
- }
103
- .metric-card {
104
- background-color: #f0f2f6;
105
- padding: 1rem;
106
- border-radius: 0.5rem;
107
- margin: 0.5rem 0;
108
- }
109
- .defect-alert {
110
- background-color: #ffebee;
111
- color: #c62828;
112
- padding: 1rem;
113
- border-radius: 0.5rem;
114
- border-left: 4px solid #c62828;
115
- font-weight: 600;
116
- }
117
- .normal-alert {
118
- background-color: #e8f5e9;
119
- color: #2e7d32;
120
- padding: 1rem;
121
- border-radius: 0.5rem;
122
- border-left: 4px solid #2e7d32;
123
- font-weight: 600;
124
- }
125
- </style>
126
- """, unsafe_allow_html=True)
127
-
128
- # Header
129
- st.markdown('<div class="main-header">πŸ’Š Automated Tablet Defect Detection</div>',
130
- unsafe_allow_html=True)
131
- st.markdown('<div class="subtitle">Unsupervised Computer Vision Quality Inspection System</div>',
132
- unsafe_allow_html=True)
133
-
134
- # Sidebar
135
- with st.sidebar:
136
- st.image("https://img.icons8.com/fluency/96/pill.png", width=80)
137
- st.title("βš™οΈ Settings")
138
-
139
- threshold = st.slider(
140
- "Anomaly Threshold",
141
- min_value=0.0,
142
- max_value=2.0,
143
- value=0.5,
144
- step=0.05,
145
- help="Adjust sensitivity: lower = more sensitive to defects"
146
- )
147
-
148
- show_heatmap = st.checkbox("Show Anomaly Heatmap", value=True)
149
- heatmap_alpha = st.slider("Heatmap Opacity", 0.0, 1.0, 0.4, 0.05)
150
-
151
- st.divider()
152
- st.subheader("πŸ“Š Model Info")
153
- st.markdown(f"""
154
- - **Method:** PaDiM
155
- - **Backbone:** ResNet-18
156
- - **Layers:** {', '.join(config.FEATURE_LAYERS)}
157
- - **Device:** {'GPU' if torch.cuda.is_available() else 'CPU'}
158
- """)
159
-
160
- st.divider()
161
- st.subheader("ℹ️ About")
162
- st.markdown("""
163
- This system uses **PaDiM** (Patch Distribution Modeling) for
164
- unsupervised anomaly detection in pharmaceutical tablets.
165
-
166
- **Features:**
167
- - βœ… Image-level defect classification
168
- - 🎯 Pixel-level defect localization
169
- - πŸ“ˆ Anomaly score quantification
170
- - πŸš€ CPU-friendly inference
171
- """)
172
-
173
-
174
- # Load model
175
- with st.spinner("Loading model..."):
176
- padim_model, extractor, device = load_model()
177
-
178
- # Main content
179
- st.divider()
180
-
181
- # File uploader
182
- uploaded_file = st.file_uploader(
183
- "Upload a tablet image for inspection",
184
- type=["png", "jpg", "jpeg"],
185
- help="Supported formats: PNG, JPG, JPEG"
186
- )
187
-
188
- # Demo images section
189
- col1, col2 = st.columns([3, 1])
190
- with col2:
191
- use_demo = st.button("🎲 Try Demo Image")
192
-
193
- if use_demo:
194
- # Load a random test image
195
- demo_dir = config.TEST_DIR / "good"
196
- demo_images = list(demo_dir.glob("*.png"))
197
- if demo_images:
198
- demo_path = np.random.choice(demo_images)
199
- uploaded_file = demo_path
200
-
201
- if uploaded_file is not None:
202
- # Load image
203
- if isinstance(uploaded_file, Path):
204
- image = Image.open(uploaded_file).convert("RGB")
205
- else:
206
- image = Image.open(uploaded_file).convert("RGB")
207
-
208
- # Display original image
209
- st.subheader("πŸ“Έ Uploaded Image")
210
- col1, col2, col3 = st.columns([1, 2, 1])
211
- with col2:
212
- st.image(image, use_container_width=True)
213
-
214
- # Run inference
215
- with st.spinner("πŸ” Analyzing image..."):
216
- anomaly_score, anomaly_map = predict_defect(
217
- image, padim_model, extractor, device
218
- )
219
-
220
- # Display results
221
- st.divider()
222
- st.subheader("🎯 Inspection Results")
223
-
224
- # Prediction
225
- is_defective = anomaly_score > threshold
226
-
227
- if is_defective:
228
- st.markdown(f"""
229
- <div class="defect-alert">
230
- ⚠️ DEFECTIVE TABLET DETECTED
231
- </div>
232
- """, unsafe_allow_html=True)
233
- else:
234
- st.markdown(f"""
235
- <div class="normal-alert">
236
- βœ… NORMAL TABLET (No Defects)
237
- </div>
238
- """, unsafe_allow_html=True)
239
-
240
- # Metrics
241
- col1, col2, col3 = st.columns(3)
242
-
243
- with col1:
244
- st.metric(
245
- label="Anomaly Score",
246
- value=f"{anomaly_score:.4f}",
247
- delta="Defect" if is_defective else "Normal",
248
- delta_color="inverse"
249
- )
250
-
251
- with col2:
252
- st.metric(
253
- label="Threshold",
254
- value=f"{threshold:.3f}",
255
- delta=f"{(anomaly_score/threshold - 1)*100:+.1f}%" if threshold > 0 else "N/A"
256
- )
257
-
258
- with col3:
259
- confidence = abs(anomaly_score - threshold) / threshold if threshold > 0 else 0
260
- st.metric(
261
- label="Confidence",
262
- value=f"{min(confidence * 100, 100):.1f}%"
263
- )
264
-
265
- # Heatmap visualization
266
- if show_heatmap:
267
- st.divider()
268
- st.subheader("πŸ”₯ Anomaly Heatmap")
269
- st.markdown("*Highlighted regions indicate potential defects*")
270
-
271
- # Create heatmap overlay
272
- img_np = np.array(image)
273
- heatmap_overlay = apply_heatmap(
274
- img_np,
275
- anomaly_map,
276
- alpha=heatmap_alpha,
277
- colormap=config.HEATMAP_COLORMAP
278
- )
279
-
280
- # Display side by side
281
- col1, col2 = st.columns(2)
282
-
283
- with col1:
284
- st.image(image, caption="Original", use_container_width=True)
285
-
286
- with col2:
287
- st.image(heatmap_overlay, caption="Defect Localization",
288
- use_container_width=True)
289
-
290
- # Download results
291
- st.divider()
292
-
293
- if st.button("πŸ’Ύ Download Results"):
294
- # Create annotated image
295
- img_np = np.array(image)
296
- result_img = apply_heatmap(img_np, anomaly_map, alpha=heatmap_alpha)
297
-
298
- # Add text annotation
299
- import cv2
300
- prediction_text = "DEFECTIVE" if is_defective else "NORMAL"
301
- color = (255, 0, 0) if is_defective else (0, 255, 0)
302
- cv2.putText(result_img, f"{prediction_text} ({anomaly_score:.3f})",
303
- (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
304
- 1, color, 2, cv2.LINE_AA)
305
-
306
- # Convert to bytes
307
- result_pil = Image.fromarray(result_img)
308
- buf = io.BytesIO()
309
- result_pil.save(buf, format="PNG")
310
-
311
- st.download_button(
312
- label="⬇️ Download Annotated Image",
313
- data=buf.getvalue(),
314
- file_name="defect_detection_result.png",
315
- mime="image/png"
316
- )
317
-
318
- else:
319
- # Instructions when no image uploaded
320
- st.info("πŸ‘† Please upload an image or click 'Try Demo Image' to start inspection.")
321
-
322
- # Example gallery
323
- st.divider()
324
- st.subheader("πŸ“š Example Defect Types")
325
-
326
- cols = st.columns(5)
327
- defect_examples = {
328
- "Normal": config.TEST_DIR / "good",
329
- "Crack": config.TEST_DIR / "crack",
330
- "Poke": config.TEST_DIR / "poke",
331
- "Scratch": config.TEST_DIR / "scratch",
332
- "Squeeze": config.TEST_DIR / "squeeze"
333
- }
334
-
335
- for idx, (defect_name, defect_dir) in enumerate(defect_examples.items()):
336
- if defect_dir.exists():
337
- images = list(defect_dir.glob("*.png"))
338
- if images:
339
- with cols[idx % 5]:
340
- example_img = Image.open(images[0])
341
- st.image(example_img, caption=defect_name, use_container_width=True)
342
-
343
-
344
- if __name__ == "__main__":
345
- main()
 
1
+ """
2
+ Streamlit Application for Automated Tablet Defect Detection
3
+ """
4
+
5
+ import streamlit as st
6
+ import torch
7
+ import numpy as np
8
+ from PIL import Image
9
+ import sys
10
+ from pathlib import Path
11
+ import io
12
+
13
+ # Add parent directory to path
14
+ sys.path.append(str(Path(__file__).parent.parent))
15
+
16
+ import config
17
+ from src.feature_extractor import FeatureExtractor, extract_embeddings
18
+ from src.padim import PaDiM
19
+ from src.visualize import apply_heatmap
20
+
21
+
22
+ @st.cache_resource
23
+ def load_model():
24
+ """Load PaDiM model and feature extractor (cached)"""
25
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+
27
+ # Load PaDiM model
28
+ model_path = config.MODEL_DIR / "padim_model.pkl"
29
+
30
+ if not model_path.exists():
31
+ st.error("❌ Model file not found. Please train the model first.")
32
+ st.info("To train the model, run: `python train.py` in your terminal")
33
+ st.stop()
34
+
35
+ padim_model = PaDiM()
36
+ padim_model.load(model_path)
37
+
38
+ # Load feature extractor
39
+ extractor = FeatureExtractor(
40
+ backbone=config.BACKBONE,
41
+ layers=config.FEATURE_LAYERS
42
+ ).to(device)
43
+
44
+ return padim_model, extractor, device
45
+
46
+
47
+ def preprocess_image(image: Image.Image) -> torch.Tensor:
48
+ """Preprocess uploaded image"""
49
+ from torchvision import transforms
50
+
51
+ transform = transforms.Compose([
52
+ transforms.Resize(config.IMAGE_SIZE),
53
+ transforms.ToTensor(),
54
+ transforms.Normalize(mean=config.MEAN, std=config.STD)
55
+ ])
56
+
57
+ return transform(image).unsqueeze(0) # Add batch dimension
58
+
59
+
60
+ def predict_defect(image: Image.Image, padim_model, extractor, device):
61
+ """Run inference on uploaded image"""
62
+
63
+ # Preprocess
64
+ img_tensor = preprocess_image(image).to(device)
65
+
66
+ # Extract embeddings
67
+ with torch.no_grad():
68
+ embeddings = extract_embeddings(extractor, img_tensor)
69
+
70
+ # Predict
71
+ embeddings_np = embeddings.cpu().numpy()
72
+ anomaly_score, anomaly_map = padim_model.predict(embeddings_np)
73
+
74
+ return anomaly_score, anomaly_map
75
+
76
+
77
+ def main():
78
+ """Main Streamlit app"""
79
+
80
+ # Page configuration
81
+ st.set_page_config(
82
+ page_title="Tablet Defect Detection",
83
+ page_icon="πŸ’Š",
84
+ layout="wide",
85
+ initial_sidebar_state="expanded"
86
+ )
87
+
88
+ # Custom CSS
89
+ st.markdown("""
90
+ <style>
91
+ .main-header {
92
+ font-size: 2.5rem;
93
+ font-weight: 700;
94
+ color: #1f77b4;
95
+ text-align: center;
96
+ margin-bottom: 1rem;
97
+ }
98
+ .subtitle {
99
+ text-align: center;
100
+ color: #666;
101
+ margin-bottom: 2rem;
102
+ }
103
+ .metric-card {
104
+ background-color: #f0f2f6;
105
+ padding: 1rem;
106
+ border-radius: 0.5rem;
107
+ margin: 0.5rem 0;
108
+ }
109
+ .defect-alert {
110
+ background-color: #ffebee;
111
+ color: #c62828;
112
+ padding: 1rem;
113
+ border-radius: 0.5rem;
114
+ border-left: 4px solid #c62828;
115
+ font-weight: 600;
116
+ }
117
+ .normal-alert {
118
+ background-color: #e8f5e9;
119
+ color: #2e7d32;
120
+ padding: 1rem;
121
+ border-radius: 0.5rem;
122
+ border-left: 4px solid #2e7d32;
123
+ font-weight: 600;
124
+ }
125
+ </style>
126
+ """, unsafe_allow_html=True)
127
+
128
+ # Header
129
+ st.markdown('<div class="main-header">πŸ’Š Automated Tablet Defect Detection</div>',
130
+ unsafe_allow_html=True)
131
+ st.markdown('<div class="subtitle">Unsupervised Computer Vision Quality Inspection System</div>',
132
+ unsafe_allow_html=True)
133
+
134
+ # Sidebar
135
+ with st.sidebar:
136
+ st.image("https://img.icons8.com/fluency/96/pill.png", width=80)
137
+ st.title("βš™οΈ Settings")
138
+
139
+ threshold = st.slider(
140
+ "Anomaly Threshold",
141
+ min_value=0.0,
142
+ max_value=2.0,
143
+ value=0.5,
144
+ step=0.05,
145
+ help="Adjust sensitivity: lower = more sensitive to defects"
146
+ )
147
+
148
+ show_heatmap = st.checkbox("Show Anomaly Heatmap", value=True)
149
+ heatmap_alpha = st.slider("Heatmap Opacity", 0.0, 1.0, 0.4, 0.05)
150
+
151
+ st.divider()
152
+ st.subheader("πŸ“Š Model Info")
153
+ st.markdown(f"""
154
+ - **Method:** PaDiM
155
+ - **Backbone:** ResNet-18
156
+ - **Layers:** {', '.join(config.FEATURE_LAYERS)}
157
+ - **Device:** {'GPU' if torch.cuda.is_available() else 'CPU'}
158
+ """)
159
+
160
+ st.divider()
161
+ st.subheader("ℹ️ About")
162
+ st.markdown("""
163
+ This system uses **PaDiM** (Patch Distribution Modeling) for
164
+ unsupervised anomaly detection in pharmaceutical tablets.
165
+
166
+ **Features:**
167
+ - βœ… Image-level defect classification
168
+ - 🎯 Pixel-level defect localization
169
+ - πŸ“ˆ Anomaly score quantification
170
+ - πŸš€ CPU-friendly inference
171
+ """)
172
+
173
+
174
+ # Load model
175
+ with st.spinner("Loading model..."):
176
+ padim_model, extractor, device = load_model()
177
+
178
+ # Main content
179
+ st.divider()
180
+
181
+ # File uploader
182
+ uploaded_file = st.file_uploader(
183
+ "Upload a tablet image for inspection",
184
+ type=["png", "jpg", "jpeg"],
185
+ help="Supported formats: PNG, JPG, JPEG"
186
+ )
187
+
188
+ # Demo images section
189
+ col1, col2 = st.columns([3, 1])
190
+ with col2:
191
+ use_demo = st.button("🎲 Try Demo Image")
192
+
193
+ if use_demo:
194
+ # Load a random test image
195
+ demo_dir = config.TEST_DIR / "good"
196
+ demo_images = list(demo_dir.glob("*.png"))
197
+ if demo_images:
198
+ demo_path = np.random.choice(demo_images)
199
+ uploaded_file = demo_path
200
+
201
+ if uploaded_file is not None:
202
+ # Load image
203
+ if isinstance(uploaded_file, Path):
204
+ image = Image.open(uploaded_file).convert("RGB")
205
+ else:
206
+ image = Image.open(uploaded_file).convert("RGB")
207
+
208
+ # Display original image
209
+ st.subheader("πŸ“Έ Uploaded Image")
210
+ col1, col2, col3 = st.columns([1, 2, 1])
211
+ with col2:
212
+ st.image(image, use_column_width=True)
213
+
214
+ # Run inference
215
+ with st.spinner("πŸ” Analyzing image..."):
216
+ anomaly_score, anomaly_map = predict_defect(
217
+ image, padim_model, extractor, device
218
+ )
219
+
220
+ # Display results
221
+ st.divider()
222
+ st.subheader("🎯 Inspection Results")
223
+
224
+ # Prediction
225
+ is_defective = anomaly_score > threshold
226
+
227
+ if is_defective:
228
+ st.markdown(f"""
229
+ <div class="defect-alert">
230
+ ⚠️ DEFECTIVE TABLET DETECTED
231
+ </div>
232
+ """, unsafe_allow_html=True)
233
+ else:
234
+ st.markdown(f"""
235
+ <div class="normal-alert">
236
+ βœ… NORMAL TABLET (No Defects)
237
+ </div>
238
+ """, unsafe_allow_html=True)
239
+
240
+ # Metrics
241
+ col1, col2, col3 = st.columns(3)
242
+
243
+ with col1:
244
+ st.metric(
245
+ label="Anomaly Score",
246
+ value=f"{anomaly_score:.4f}",
247
+ delta="Defect" if is_defective else "Normal",
248
+ delta_color="inverse"
249
+ )
250
+
251
+ with col2:
252
+ st.metric(
253
+ label="Threshold",
254
+ value=f"{threshold:.3f}",
255
+ delta=f"{(anomaly_score/threshold - 1)*100:+.1f}%" if threshold > 0 else "N/A"
256
+ )
257
+
258
+ with col3:
259
+ confidence = abs(anomaly_score - threshold) / threshold if threshold > 0 else 0
260
+ st.metric(
261
+ label="Confidence",
262
+ value=f"{min(confidence * 100, 100):.1f}%"
263
+ )
264
+
265
+ # Heatmap visualization
266
+ if show_heatmap:
267
+ st.divider()
268
+ st.subheader("πŸ”₯ Anomaly Heatmap")
269
+ st.markdown("*Highlighted regions indicate potential defects*")
270
+
271
+ # Create heatmap overlay
272
+ img_np = np.array(image)
273
+ heatmap_overlay = apply_heatmap(
274
+ img_np,
275
+ anomaly_map,
276
+ alpha=heatmap_alpha,
277
+ colormap=config.HEATMAP_COLORMAP
278
+ )
279
+
280
+ # Display side by side
281
+ col1, col2 = st.columns(2)
282
+
283
+ with col1:
284
+ st.image(image, caption="Original", use_column_width=True)
285
+
286
+ with col2:
287
+ st.image(heatmap_overlay, caption="Defect Localization",
288
+ use_column_width=True)
289
+
290
+ # Download results
291
+ st.divider()
292
+
293
+ if st.button("πŸ’Ύ Download Results"):
294
+ # Create annotated image
295
+ img_np = np.array(image)
296
+ result_img = apply_heatmap(img_np, anomaly_map, alpha=heatmap_alpha)
297
+
298
+ # Add text annotation
299
+ import cv2
300
+ prediction_text = "DEFECTIVE" if is_defective else "NORMAL"
301
+ color = (255, 0, 0) if is_defective else (0, 255, 0)
302
+ cv2.putText(result_img, f"{prediction_text} ({anomaly_score:.3f})",
303
+ (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
304
+ 1, color, 2, cv2.LINE_AA)
305
+
306
+ # Convert to bytes
307
+ result_pil = Image.fromarray(result_img)
308
+ buf = io.BytesIO()
309
+ result_pil.save(buf, format="PNG")
310
+
311
+ st.download_button(
312
+ label="⬇️ Download Annotated Image",
313
+ data=buf.getvalue(),
314
+ file_name="defect_detection_result.png",
315
+ mime="image/png"
316
+ )
317
+
318
+ else:
319
+ # Instructions when no image uploaded
320
+ st.info("πŸ‘† Please upload an image or click 'Try Demo Image' to start inspection.")
321
+
322
+ # Example gallery
323
+ st.divider()
324
+ st.subheader("πŸ“š Example Defect Types")
325
+
326
+ cols = st.columns(5)
327
+ defect_examples = {
328
+ "Normal": config.TEST_DIR / "good",
329
+ "Crack": config.TEST_DIR / "crack",
330
+ "Poke": config.TEST_DIR / "poke",
331
+ "Scratch": config.TEST_DIR / "scratch",
332
+ "Squeeze": config.TEST_DIR / "squeeze"
333
+ }
334
+
335
+ for idx, (defect_name, defect_dir) in enumerate(defect_examples.items()):
336
+ if defect_dir.exists():
337
+ images = list(defect_dir.glob("*.png"))
338
+ if images:
339
+ with cols[idx % 5]:
340
+ example_img = Image.open(images[0])
341
+ st.image(example_img, caption=defect_name, use_column_width=True)
342
+
343
+
344
+ if __name__ == "__main__":
345
+ main()