Balajim57 commited on
Commit
e8ac7f1
·
verified ·
1 Parent(s): 92c7b0a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +431 -0
app.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ from __future__ import print_function, division, absolute_import
4
+ import streamlit as st
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import transforms
8
+ from PIL import Image, ImageDraw
9
+ from ultralytics import YOLO
10
+ from streamlit_drawable_canvas import st_canvas
11
+ import os
12
+
13
+ # --- Define Basic Components for InceptionResNetV2 ---
14
+ class BasicConv2d(nn.Module):
15
+ def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
16
+ super(BasicConv2d, self).__init__()
17
+ self.conv = nn.Conv2d(in_planes, out_planes,
18
+ kernel_size=kernel_size, stride=stride,
19
+ padding=padding, bias=False)
20
+ self.bn = nn.BatchNorm2d(out_planes)
21
+ self.relu = nn.ReLU(inplace=False)
22
+
23
+ def forward(self, x):
24
+ x = self.conv(x)
25
+ x = self.bn(x)
26
+ x = self.relu(x)
27
+ return x
28
+
29
+ # --- Define InceptionResNetV2 Architecture ---
30
+ class Mixed_5b(nn.Module):
31
+ def __init__(self):
32
+ super(Mixed_5b, self).__init__()
33
+ self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)
34
+
35
+ self.branch1 = nn.Sequential(
36
+ BasicConv2d(192, 48, kernel_size=1, stride=1),
37
+ BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2)
38
+ )
39
+
40
+ self.branch2 = nn.Sequential(
41
+ BasicConv2d(192, 64, kernel_size=1, stride=1),
42
+ BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
43
+ BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
44
+ )
45
+
46
+ self.branch3 = nn.Sequential(
47
+ nn.AvgPool2d(3, stride=1, padding=1),
48
+ BasicConv2d(192, 64, kernel_size=1, stride=1)
49
+ )
50
+
51
+ def forward(self, x):
52
+ x0 = self.branch0(x)
53
+ x1 = self.branch1(x)
54
+ x2 = self.branch2(x)
55
+ x3 = self.branch3(x)
56
+ out = torch.cat((x0, x1, x2, x3), 1)
57
+ return out
58
+
59
+ class Block35(nn.Module):
60
+ def __init__(self, scale=1.0):
61
+ super(Block35, self).__init__()
62
+ self.scale = scale
63
+
64
+ self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)
65
+
66
+ self.branch1 = nn.Sequential(
67
+ BasicConv2d(320, 32, kernel_size=1, stride=1),
68
+ BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
69
+ )
70
+
71
+ self.branch2 = nn.Sequential(
72
+ BasicConv2d(320, 32, kernel_size=1, stride=1),
73
+ BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
74
+ BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1)
75
+ )
76
+
77
+ self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
78
+ self.relu = nn.ReLU(inplace=False)
79
+
80
+ def forward(self, x):
81
+ x0 = self.branch0(x)
82
+ x1 = self.branch1(x)
83
+ x2 = self.branch2(x)
84
+ out = torch.cat((x0, x1, x2), 1)
85
+ out = self.conv2d(out)
86
+ out = out * self.scale + x
87
+ out = self.relu(out)
88
+ return out
89
+
90
+ class Mixed_6a(nn.Module):
91
+ def __init__(self):
92
+ super(Mixed_6a, self).__init__()
93
+
94
+ self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)
95
+
96
+ self.branch1 = nn.Sequential(
97
+ BasicConv2d(320, 256, kernel_size=1, stride=1),
98
+ BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
99
+ BasicConv2d(256, 384, kernel_size=3, stride=2)
100
+ )
101
+
102
+ self.branch2 = nn.MaxPool2d(3, stride=2)
103
+
104
+ def forward(self, x):
105
+ x0 = self.branch0(x)
106
+ x1 = self.branch1(x)
107
+ x2 = self.branch2(x)
108
+ out = torch.cat((x0, x1, x2), 1)
109
+ return out
110
+
111
+ class Block17(nn.Module):
112
+ def __init__(self, scale=1.0):
113
+ super(Block17, self).__init__()
114
+ self.scale = scale
115
+
116
+ self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)
117
+
118
+ self.branch1 = nn.Sequential(
119
+ BasicConv2d(1088, 128, kernel_size=1, stride=1),
120
+ BasicConv2d(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)),
121
+ BasicConv2d(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0))
122
+ )
123
+
124
+ self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
125
+ self.relu = nn.ReLU(inplace=False)
126
+
127
+ def forward(self, x):
128
+ x0 = self.branch0(x)
129
+ x1 = self.branch1(x)
130
+ out = torch.cat((x0, x1), 1)
131
+ out = self.conv2d(out)
132
+ out = out * self.scale + x
133
+ out = self.relu(out)
134
+ return out
135
+
136
+ class Mixed_7a(nn.Module):
137
+ def __init__(self):
138
+ super(Mixed_7a, self).__init__()
139
+
140
+ self.branch0 = nn.Sequential(
141
+ BasicConv2d(1088, 256, kernel_size=1, stride=1),
142
+ BasicConv2d(256, 384, kernel_size=3, stride=2)
143
+ )
144
+
145
+ self.branch1 = nn.Sequential(
146
+ BasicConv2d(1088, 256, kernel_size=1, stride=1),
147
+ BasicConv2d(256, 288, kernel_size=3, stride=2)
148
+ )
149
+
150
+ self.branch2 = nn.Sequential(
151
+ BasicConv2d(1088, 256, kernel_size=1, stride=1),
152
+ BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
153
+ BasicConv2d(288, 320, kernel_size=3, stride=2)
154
+ )
155
+
156
+ self.branch3 = nn.MaxPool2d(3, stride=2)
157
+
158
+ def forward(self, x):
159
+ x0 = self.branch0(x)
160
+ x1 = self.branch1(x)
161
+ x2 = self.branch2(x)
162
+ x3 = self.branch3(x)
163
+ out = torch.cat((x0, x1, x2, x3), 1)
164
+ return out
165
+
166
+ class Block8(nn.Module):
167
+ def __init__(self, scale=1.0, noReLU=False):
168
+ super(Block8, self).__init__()
169
+ self.scale = scale
170
+ self.noReLU = noReLU
171
+
172
+ self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)
173
+
174
+ self.branch1 = nn.Sequential(
175
+ BasicConv2d(2080, 192, kernel_size=1, stride=1),
176
+ BasicConv2d(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)),
177
+ BasicConv2d(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
178
+ )
179
+
180
+ self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
181
+ if not self.noReLU:
182
+ self.relu = nn.ReLU(inplace=False)
183
+
184
+ def forward(self, x):
185
+ x0 = self.branch0(x)
186
+ x1 = self.branch1(x)
187
+ out = torch.cat((x0, x1), 1)
188
+ out = self.conv2d(out)
189
+ out = out * self.scale + x
190
+ if not self.noReLU:
191
+ out = self.relu(out)
192
+ return out
193
+
194
+ class InceptionResNetV2(nn.Module):
195
+ def __init__(self, num_classes=1001):
196
+ super(InceptionResNetV2, self).__init__()
197
+ # Define all your layers here
198
+ self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
199
+ self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
200
+ self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
201
+ self.maxpool_3a = nn.MaxPool2d(3, stride=2)
202
+ self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
203
+ self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
204
+ self.maxpool_5a = nn.MaxPool2d(3, stride=2)
205
+ self.mixed_5b = Mixed_5b()
206
+ self.repeat = nn.Sequential(
207
+ Block35(scale=0.17),
208
+ Block35(scale=0.17),
209
+ Block35(scale=0.17),
210
+ Block35(scale=0.17),
211
+ Block35(scale=0.17),
212
+ Block35(scale=0.17),
213
+ Block35(scale=0.17),
214
+ Block35(scale=0.17),
215
+ Block35(scale=0.17),
216
+ Block35(scale=0.17)
217
+ )
218
+ self.mixed_6a = Mixed_6a()
219
+ self.repeat_1 = nn.Sequential(
220
+ Block17(scale=0.10),
221
+ Block17(scale=0.10),
222
+ Block17(scale=0.10),
223
+ Block17(scale=0.10),
224
+ Block17(scale=0.10),
225
+ Block17(scale=0.10),
226
+ Block17(scale=0.10),
227
+ Block17(scale=0.10),
228
+ Block17(scale=0.10),
229
+ Block17(scale=0.10),
230
+ Block17(scale=0.10),
231
+ Block17(scale=0.10),
232
+ Block17(scale=0.10),
233
+ Block17(scale=0.10),
234
+ Block17(scale=0.10),
235
+ Block17(scale=0.10),
236
+ Block17(scale=0.10),
237
+ Block17(scale=0.10),
238
+ Block17(scale=0.10),
239
+ Block17(scale=0.10)
240
+ )
241
+ self.mixed_7a = Mixed_7a()
242
+ self.repeat_2 = nn.Sequential(
243
+ Block8(scale=0.20),
244
+ Block8(scale=0.20),
245
+ Block8(scale=0.20),
246
+ Block8(scale=0.20),
247
+ Block8(scale=0.20),
248
+ Block8(scale=0.20),
249
+ Block8(scale=0.20),
250
+ Block8(scale=0.20),
251
+ Block8(scale=0.20)
252
+ )
253
+ self.block8 = Block8(noReLU=True)
254
+ self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
255
+ self.avgpool_1a = nn.AvgPool2d(8, stride=1, padding=0)
256
+ self.last_linear = nn.Linear(1536, num_classes)
257
+
258
+ def features(self, input):
259
+ x = self.conv2d_1a(input)
260
+ x = self.conv2d_2a(x)
261
+ x = self.conv2d_2b(x)
262
+ x = self.maxpool_3a(x)
263
+ x = self.conv2d_3b(x)
264
+ x = self.conv2d_4a(x)
265
+ x = self.maxpool_5a(x)
266
+ x = self.mixed_5b(x)
267
+ x = self.repeat(x)
268
+ x = self.mixed_6a(x)
269
+ x = self.repeat_1(x)
270
+ x = self.mixed_7a(x)
271
+ x = self.repeat_2(x)
272
+ x = self.block8(x)
273
+ x = self.conv2d_7b(x)
274
+ return x
275
+
276
+ def logits(self, features):
277
+ x = self.avgpool_1a(features)
278
+ x = x.view(x.size(0), -1)
279
+ x = self.last_linear(x)
280
+ return x
281
+
282
+ def forward(self, input):
283
+ x = self.features(input)
284
+ x = self.logits(x)
285
+ return x
286
+
287
+ def inceptionresnetv2(num_classes=1000):
288
+ return InceptionResNetV2(num_classes=num_classes)
289
+
290
+ # --- Load Models ---
291
+ @st.cache_resource
292
+ def load_inception_model(model_path):
293
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
294
+ model = inceptionresnetv2(num_classes=2).to(device) # Adjust num_classes as needed
295
+ model.load_state_dict(torch.load(model_path, map_location=device))
296
+ model.eval()
297
+ return model, device
298
+
299
+ @st.cache_resource
300
+ def load_yolo_model(yolo_model_path="yolov8n.pt"):
301
+ model = YOLO(yolo_model_path) # You can specify a custom YOLOv8 model path if needed
302
+ return model
303
+
304
+ # --- Image Preprocessing ---
305
+ data_transforms = transforms.Compose([
306
+ transforms.Resize(342),
307
+ transforms.CenterCrop(299),
308
+ transforms.ToTensor(),
309
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
310
+ ])
311
+
312
+ # --- Streamlit App ---
313
+ def main():
314
+ st.title("Image Anomaly Detection and Object Detection")
315
+ st.write("Upload an image to analyze for anomalies.")
316
+
317
+ # Load models
318
+ inception_model, device = load_inception_model(r'X:\mowito\Inception-ResNetV2-Weights\anamoly30.pth') # Ensure 'anamoly30.pth' is in the same directory
319
+ yolo_model = load_yolo_model(r'X:\mowito\mowito.pt') # Ensure 'yolov8n.pt' is in the same directory or specify the path
320
+
321
+ # Upload the image
322
+ uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
323
+
324
+ # User input for confidence threshold
325
+ threshold = st.slider("Set Confidence Threshold", 0.0, 1.0, 0.5, 0.01)
326
+
327
+ if uploaded_file is not None:
328
+ # Display the uploaded image
329
+ image = Image.open(uploaded_file).convert("RGB")
330
+ st.image(image, caption="Uploaded Image", width=400)
331
+
332
+ # Preprocess the image
333
+ transformed_image = data_transforms(image).unsqueeze(0).to(device)
334
+
335
+ # InceptionResNetV2 Prediction
336
+ with torch.no_grad():
337
+ outputs = inception_model(transformed_image)
338
+ _, predicted = torch.max(outputs, 1)
339
+ predicted_class = ['bad', 'good'][predicted.item()]
340
+ confidence = torch.nn.functional.softmax(outputs, dim=1)[0][predicted.item()].item()
341
+
342
+ st.write(f"**Prediction:** {predicted_class}")
343
+ st.write(f"**Confidence:** {confidence:.4f}")
344
+
345
+ # Check if confidence is above the threshold
346
+ if confidence >= threshold:
347
+ if predicted_class == "bad":
348
+ st.warning("Anomalies detected in the image. Processing further analysis...")
349
+
350
+ # Automatically run YOLOv8 on the uploaded image
351
+ st.write("Analyzing anomalies using YOLOv8...")
352
+ yolo_results = yolo_model.predict(source=image, conf=0.25, show=False)
353
+
354
+ # Display YOLOv8 predictions
355
+ st.write("### YOLOv8 Predictions:")
356
+ for result in yolo_results:
357
+ # Plot the results on the image
358
+ annotated_yolo_image = result.plot()
359
+ st.image(annotated_yolo_image, caption="YOLOv8 Detection", width=400)
360
+
361
+ # Optionally, display detailed results
362
+ st.write("### Detection Details:")
363
+ for result in yolo_results:
364
+ for box in result.boxes:
365
+ cls = int(box.cls)
366
+ conf = box.conf
367
+ label = yolo_model.names[cls] if cls < len(yolo_model.names) else "Unknown"
368
+ st.write(f"- **Label**: {label}, **Confidence**: {conf.item():.2f}")
369
+
370
+ # Provide interactive feedback option
371
+ st.info("You can annotate the image to refine analysis.")
372
+
373
+ # Initialize canvas for manual annotation
374
+ canvas_result = st_canvas(
375
+ fill_color="rgba(255, 165, 0, 0.3)", # Semi-transparent orange
376
+ stroke_width=2,
377
+ stroke_color="#FF0000", # Red
378
+ background_color="#FFFFFF",
379
+ background_image=image,
380
+ update_streamlit=True,
381
+ height=image.height,
382
+ width=image.width,
383
+ drawing_mode="rect", # Allow drawing rectangles
384
+ key="canvas",
385
+ )
386
+
387
+ if canvas_result.json_data is not None:
388
+ objects = canvas_result.json_data["objects"]
389
+ if len(objects) > 0:
390
+ st.success("Bounding boxes drawn. Click the button below to analyze with YOLOv8.")
391
+ if st.button("Analyze Manual Annotations"):
392
+ # Draw the bounding boxes on the image
393
+ annotated_image = image.copy()
394
+ draw = ImageDraw.Draw(annotated_image)
395
+ for obj in objects:
396
+ if obj["type"] == "rect":
397
+ left = obj["left"]
398
+ top = obj["top"]
399
+ width = obj["width"]
400
+ height = obj["height"]
401
+ draw.rectangle([left, top, left + width, top + height], outline="red", width=3)
402
+
403
+ st.image(annotated_image, caption="Annotated Image", width=400)
404
+
405
+ # Pass the manually annotated image to YOLOv8
406
+ yolo_results_manual = yolo_model.predict(source=annotated_image, conf=0.25, show=False)
407
+
408
+ # Display YOLOv8 predictions for annotated image
409
+ st.write("### YOLOv8 Predictions (Manual Annotations):")
410
+ for result in yolo_results_manual:
411
+ # Plot the results on the image
412
+ annotated_yolo_image_manual = result.plot()
413
+ st.image(annotated_yolo_image_manual, caption="YOLOv8 Detection (Manual)", width=400)
414
+
415
+ # Display detection details
416
+ st.write("### Detection Details (Manual Annotations):")
417
+ for result in yolo_results_manual:
418
+ for box in result.boxes:
419
+ cls = int(box.cls)
420
+ conf = box.conf
421
+ label = yolo_model.names[cls] if cls < len(yolo_model.names) else "Unknown"
422
+ st.write(f"- **Label**: {label}, **Confidence**: {conf.item():.2f}")
423
+ else:
424
+ st.info("Draw bounding boxes around the anomalies and press the button to analyze.")
425
+ else:
426
+ st.warning(f"The confidence level ({confidence:.4f}) is below the threshold of {threshold}. No further analysis will be performed.")
427
+ else:
428
+ st.info("Please upload an image to get started.")
429
+
430
+ if __name__ == "__main__":
431
+ main()