Balajim57 commited on
Commit
bf57be6
·
verified ·
1 Parent(s): e8ac7f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +430 -430
app.py CHANGED
@@ -1,431 +1,431 @@
1
- # app.py
2
-
3
- from __future__ import print_function, division, absolute_import
4
- import streamlit as st
5
- import torch
6
- import torch.nn as nn
7
- from torchvision import transforms
8
- from PIL import Image, ImageDraw
9
- from ultralytics import YOLO
10
- from streamlit_drawable_canvas import st_canvas
11
- import os
12
-
13
- # --- Define Basic Components for InceptionResNetV2 ---
14
- class BasicConv2d(nn.Module):
15
- def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
16
- super(BasicConv2d, self).__init__()
17
- self.conv = nn.Conv2d(in_planes, out_planes,
18
- kernel_size=kernel_size, stride=stride,
19
- padding=padding, bias=False)
20
- self.bn = nn.BatchNorm2d(out_planes)
21
- self.relu = nn.ReLU(inplace=False)
22
-
23
- def forward(self, x):
24
- x = self.conv(x)
25
- x = self.bn(x)
26
- x = self.relu(x)
27
- return x
28
-
29
- # --- Define InceptionResNetV2 Architecture ---
30
- class Mixed_5b(nn.Module):
31
- def __init__(self):
32
- super(Mixed_5b, self).__init__()
33
- self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)
34
-
35
- self.branch1 = nn.Sequential(
36
- BasicConv2d(192, 48, kernel_size=1, stride=1),
37
- BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2)
38
- )
39
-
40
- self.branch2 = nn.Sequential(
41
- BasicConv2d(192, 64, kernel_size=1, stride=1),
42
- BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
43
- BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
44
- )
45
-
46
- self.branch3 = nn.Sequential(
47
- nn.AvgPool2d(3, stride=1, padding=1),
48
- BasicConv2d(192, 64, kernel_size=1, stride=1)
49
- )
50
-
51
- def forward(self, x):
52
- x0 = self.branch0(x)
53
- x1 = self.branch1(x)
54
- x2 = self.branch2(x)
55
- x3 = self.branch3(x)
56
- out = torch.cat((x0, x1, x2, x3), 1)
57
- return out
58
-
59
- class Block35(nn.Module):
60
- def __init__(self, scale=1.0):
61
- super(Block35, self).__init__()
62
- self.scale = scale
63
-
64
- self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)
65
-
66
- self.branch1 = nn.Sequential(
67
- BasicConv2d(320, 32, kernel_size=1, stride=1),
68
- BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
69
- )
70
-
71
- self.branch2 = nn.Sequential(
72
- BasicConv2d(320, 32, kernel_size=1, stride=1),
73
- BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
74
- BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1)
75
- )
76
-
77
- self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
78
- self.relu = nn.ReLU(inplace=False)
79
-
80
- def forward(self, x):
81
- x0 = self.branch0(x)
82
- x1 = self.branch1(x)
83
- x2 = self.branch2(x)
84
- out = torch.cat((x0, x1, x2), 1)
85
- out = self.conv2d(out)
86
- out = out * self.scale + x
87
- out = self.relu(out)
88
- return out
89
-
90
- class Mixed_6a(nn.Module):
91
- def __init__(self):
92
- super(Mixed_6a, self).__init__()
93
-
94
- self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)
95
-
96
- self.branch1 = nn.Sequential(
97
- BasicConv2d(320, 256, kernel_size=1, stride=1),
98
- BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
99
- BasicConv2d(256, 384, kernel_size=3, stride=2)
100
- )
101
-
102
- self.branch2 = nn.MaxPool2d(3, stride=2)
103
-
104
- def forward(self, x):
105
- x0 = self.branch0(x)
106
- x1 = self.branch1(x)
107
- x2 = self.branch2(x)
108
- out = torch.cat((x0, x1, x2), 1)
109
- return out
110
-
111
- class Block17(nn.Module):
112
- def __init__(self, scale=1.0):
113
- super(Block17, self).__init__()
114
- self.scale = scale
115
-
116
- self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)
117
-
118
- self.branch1 = nn.Sequential(
119
- BasicConv2d(1088, 128, kernel_size=1, stride=1),
120
- BasicConv2d(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)),
121
- BasicConv2d(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0))
122
- )
123
-
124
- self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
125
- self.relu = nn.ReLU(inplace=False)
126
-
127
- def forward(self, x):
128
- x0 = self.branch0(x)
129
- x1 = self.branch1(x)
130
- out = torch.cat((x0, x1), 1)
131
- out = self.conv2d(out)
132
- out = out * self.scale + x
133
- out = self.relu(out)
134
- return out
135
-
136
- class Mixed_7a(nn.Module):
137
- def __init__(self):
138
- super(Mixed_7a, self).__init__()
139
-
140
- self.branch0 = nn.Sequential(
141
- BasicConv2d(1088, 256, kernel_size=1, stride=1),
142
- BasicConv2d(256, 384, kernel_size=3, stride=2)
143
- )
144
-
145
- self.branch1 = nn.Sequential(
146
- BasicConv2d(1088, 256, kernel_size=1, stride=1),
147
- BasicConv2d(256, 288, kernel_size=3, stride=2)
148
- )
149
-
150
- self.branch2 = nn.Sequential(
151
- BasicConv2d(1088, 256, kernel_size=1, stride=1),
152
- BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
153
- BasicConv2d(288, 320, kernel_size=3, stride=2)
154
- )
155
-
156
- self.branch3 = nn.MaxPool2d(3, stride=2)
157
-
158
- def forward(self, x):
159
- x0 = self.branch0(x)
160
- x1 = self.branch1(x)
161
- x2 = self.branch2(x)
162
- x3 = self.branch3(x)
163
- out = torch.cat((x0, x1, x2, x3), 1)
164
- return out
165
-
166
- class Block8(nn.Module):
167
- def __init__(self, scale=1.0, noReLU=False):
168
- super(Block8, self).__init__()
169
- self.scale = scale
170
- self.noReLU = noReLU
171
-
172
- self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)
173
-
174
- self.branch1 = nn.Sequential(
175
- BasicConv2d(2080, 192, kernel_size=1, stride=1),
176
- BasicConv2d(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)),
177
- BasicConv2d(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
178
- )
179
-
180
- self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
181
- if not self.noReLU:
182
- self.relu = nn.ReLU(inplace=False)
183
-
184
- def forward(self, x):
185
- x0 = self.branch0(x)
186
- x1 = self.branch1(x)
187
- out = torch.cat((x0, x1), 1)
188
- out = self.conv2d(out)
189
- out = out * self.scale + x
190
- if not self.noReLU:
191
- out = self.relu(out)
192
- return out
193
-
194
- class InceptionResNetV2(nn.Module):
195
- def __init__(self, num_classes=1001):
196
- super(InceptionResNetV2, self).__init__()
197
- # Define all your layers here
198
- self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
199
- self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
200
- self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
201
- self.maxpool_3a = nn.MaxPool2d(3, stride=2)
202
- self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
203
- self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
204
- self.maxpool_5a = nn.MaxPool2d(3, stride=2)
205
- self.mixed_5b = Mixed_5b()
206
- self.repeat = nn.Sequential(
207
- Block35(scale=0.17),
208
- Block35(scale=0.17),
209
- Block35(scale=0.17),
210
- Block35(scale=0.17),
211
- Block35(scale=0.17),
212
- Block35(scale=0.17),
213
- Block35(scale=0.17),
214
- Block35(scale=0.17),
215
- Block35(scale=0.17),
216
- Block35(scale=0.17)
217
- )
218
- self.mixed_6a = Mixed_6a()
219
- self.repeat_1 = nn.Sequential(
220
- Block17(scale=0.10),
221
- Block17(scale=0.10),
222
- Block17(scale=0.10),
223
- Block17(scale=0.10),
224
- Block17(scale=0.10),
225
- Block17(scale=0.10),
226
- Block17(scale=0.10),
227
- Block17(scale=0.10),
228
- Block17(scale=0.10),
229
- Block17(scale=0.10),
230
- Block17(scale=0.10),
231
- Block17(scale=0.10),
232
- Block17(scale=0.10),
233
- Block17(scale=0.10),
234
- Block17(scale=0.10),
235
- Block17(scale=0.10),
236
- Block17(scale=0.10),
237
- Block17(scale=0.10),
238
- Block17(scale=0.10),
239
- Block17(scale=0.10)
240
- )
241
- self.mixed_7a = Mixed_7a()
242
- self.repeat_2 = nn.Sequential(
243
- Block8(scale=0.20),
244
- Block8(scale=0.20),
245
- Block8(scale=0.20),
246
- Block8(scale=0.20),
247
- Block8(scale=0.20),
248
- Block8(scale=0.20),
249
- Block8(scale=0.20),
250
- Block8(scale=0.20),
251
- Block8(scale=0.20)
252
- )
253
- self.block8 = Block8(noReLU=True)
254
- self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
255
- self.avgpool_1a = nn.AvgPool2d(8, stride=1, padding=0)
256
- self.last_linear = nn.Linear(1536, num_classes)
257
-
258
- def features(self, input):
259
- x = self.conv2d_1a(input)
260
- x = self.conv2d_2a(x)
261
- x = self.conv2d_2b(x)
262
- x = self.maxpool_3a(x)
263
- x = self.conv2d_3b(x)
264
- x = self.conv2d_4a(x)
265
- x = self.maxpool_5a(x)
266
- x = self.mixed_5b(x)
267
- x = self.repeat(x)
268
- x = self.mixed_6a(x)
269
- x = self.repeat_1(x)
270
- x = self.mixed_7a(x)
271
- x = self.repeat_2(x)
272
- x = self.block8(x)
273
- x = self.conv2d_7b(x)
274
- return x
275
-
276
- def logits(self, features):
277
- x = self.avgpool_1a(features)
278
- x = x.view(x.size(0), -1)
279
- x = self.last_linear(x)
280
- return x
281
-
282
- def forward(self, input):
283
- x = self.features(input)
284
- x = self.logits(x)
285
- return x
286
-
287
- def inceptionresnetv2(num_classes=1000):
288
- return InceptionResNetV2(num_classes=num_classes)
289
-
290
- # --- Load Models ---
291
- @st.cache_resource
292
- def load_inception_model(model_path):
293
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
294
- model = inceptionresnetv2(num_classes=2).to(device) # Adjust num_classes as needed
295
- model.load_state_dict(torch.load(model_path, map_location=device))
296
- model.eval()
297
- return model, device
298
-
299
- @st.cache_resource
300
- def load_yolo_model(yolo_model_path="yolov8n.pt"):
301
- model = YOLO(yolo_model_path) # You can specify a custom YOLOv8 model path if needed
302
- return model
303
-
304
- # --- Image Preprocessing ---
305
- data_transforms = transforms.Compose([
306
- transforms.Resize(342),
307
- transforms.CenterCrop(299),
308
- transforms.ToTensor(),
309
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
310
- ])
311
-
312
- # --- Streamlit App ---
313
- def main():
314
- st.title("Image Anomaly Detection and Object Detection")
315
- st.write("Upload an image to analyze for anomalies.")
316
-
317
- # Load models
318
- inception_model, device = load_inception_model(r'X:\mowito\Inception-ResNetV2-Weights\anamoly30.pth') # Ensure 'anamoly30.pth' is in the same directory
319
- yolo_model = load_yolo_model(r'X:\mowito\mowito.pt') # Ensure 'yolov8n.pt' is in the same directory or specify the path
320
-
321
- # Upload the image
322
- uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
323
-
324
- # User input for confidence threshold
325
- threshold = st.slider("Set Confidence Threshold", 0.0, 1.0, 0.5, 0.01)
326
-
327
- if uploaded_file is not None:
328
- # Display the uploaded image
329
- image = Image.open(uploaded_file).convert("RGB")
330
- st.image(image, caption="Uploaded Image", width=400)
331
-
332
- # Preprocess the image
333
- transformed_image = data_transforms(image).unsqueeze(0).to(device)
334
-
335
- # InceptionResNetV2 Prediction
336
- with torch.no_grad():
337
- outputs = inception_model(transformed_image)
338
- _, predicted = torch.max(outputs, 1)
339
- predicted_class = ['bad', 'good'][predicted.item()]
340
- confidence = torch.nn.functional.softmax(outputs, dim=1)[0][predicted.item()].item()
341
-
342
- st.write(f"**Prediction:** {predicted_class}")
343
- st.write(f"**Confidence:** {confidence:.4f}")
344
-
345
- # Check if confidence is above the threshold
346
- if confidence >= threshold:
347
- if predicted_class == "bad":
348
- st.warning("Anomalies detected in the image. Processing further analysis...")
349
-
350
- # Automatically run YOLOv8 on the uploaded image
351
- st.write("Analyzing anomalies using YOLOv8...")
352
- yolo_results = yolo_model.predict(source=image, conf=0.25, show=False)
353
-
354
- # Display YOLOv8 predictions
355
- st.write("### YOLOv8 Predictions:")
356
- for result in yolo_results:
357
- # Plot the results on the image
358
- annotated_yolo_image = result.plot()
359
- st.image(annotated_yolo_image, caption="YOLOv8 Detection", width=400)
360
-
361
- # Optionally, display detailed results
362
- st.write("### Detection Details:")
363
- for result in yolo_results:
364
- for box in result.boxes:
365
- cls = int(box.cls)
366
- conf = box.conf
367
- label = yolo_model.names[cls] if cls < len(yolo_model.names) else "Unknown"
368
- st.write(f"- **Label**: {label}, **Confidence**: {conf.item():.2f}")
369
-
370
- # Provide interactive feedback option
371
- st.info("You can annotate the image to refine analysis.")
372
-
373
- # Initialize canvas for manual annotation
374
- canvas_result = st_canvas(
375
- fill_color="rgba(255, 165, 0, 0.3)", # Semi-transparent orange
376
- stroke_width=2,
377
- stroke_color="#FF0000", # Red
378
- background_color="#FFFFFF",
379
- background_image=image,
380
- update_streamlit=True,
381
- height=image.height,
382
- width=image.width,
383
- drawing_mode="rect", # Allow drawing rectangles
384
- key="canvas",
385
- )
386
-
387
- if canvas_result.json_data is not None:
388
- objects = canvas_result.json_data["objects"]
389
- if len(objects) > 0:
390
- st.success("Bounding boxes drawn. Click the button below to analyze with YOLOv8.")
391
- if st.button("Analyze Manual Annotations"):
392
- # Draw the bounding boxes on the image
393
- annotated_image = image.copy()
394
- draw = ImageDraw.Draw(annotated_image)
395
- for obj in objects:
396
- if obj["type"] == "rect":
397
- left = obj["left"]
398
- top = obj["top"]
399
- width = obj["width"]
400
- height = obj["height"]
401
- draw.rectangle([left, top, left + width, top + height], outline="red", width=3)
402
-
403
- st.image(annotated_image, caption="Annotated Image", width=400)
404
-
405
- # Pass the manually annotated image to YOLOv8
406
- yolo_results_manual = yolo_model.predict(source=annotated_image, conf=0.25, show=False)
407
-
408
- # Display YOLOv8 predictions for annotated image
409
- st.write("### YOLOv8 Predictions (Manual Annotations):")
410
- for result in yolo_results_manual:
411
- # Plot the results on the image
412
- annotated_yolo_image_manual = result.plot()
413
- st.image(annotated_yolo_image_manual, caption="YOLOv8 Detection (Manual)", width=400)
414
-
415
- # Display detection details
416
- st.write("### Detection Details (Manual Annotations):")
417
- for result in yolo_results_manual:
418
- for box in result.boxes:
419
- cls = int(box.cls)
420
- conf = box.conf
421
- label = yolo_model.names[cls] if cls < len(yolo_model.names) else "Unknown"
422
- st.write(f"- **Label**: {label}, **Confidence**: {conf.item():.2f}")
423
- else:
424
- st.info("Draw bounding boxes around the anomalies and press the button to analyze.")
425
- else:
426
- st.warning(f"The confidence level ({confidence:.4f}) is below the threshold of {threshold}. No further analysis will be performed.")
427
- else:
428
- st.info("Please upload an image to get started.")
429
-
430
- if __name__ == "__main__":
431
  main()
 
1
+ # app.py
2
+
3
+ from __future__ import print_function, division, absolute_import
4
+ import streamlit as st
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import transforms
8
+ from PIL import Image, ImageDraw
9
+ from ultralytics import YOLO
10
+ from streamlit_drawable_canvas import st_canvas
11
+ import os
12
+
13
+ # --- Define Basic Components for InceptionResNetV2 ---
14
+ class BasicConv2d(nn.Module):
15
+ def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
16
+ super(BasicConv2d, self).__init__()
17
+ self.conv = nn.Conv2d(in_planes, out_planes,
18
+ kernel_size=kernel_size, stride=stride,
19
+ padding=padding, bias=False)
20
+ self.bn = nn.BatchNorm2d(out_planes)
21
+ self.relu = nn.ReLU(inplace=False)
22
+
23
+ def forward(self, x):
24
+ x = self.conv(x)
25
+ x = self.bn(x)
26
+ x = self.relu(x)
27
+ return x
28
+
29
+ # --- Define InceptionResNetV2 Architecture ---
30
+ class Mixed_5b(nn.Module):
31
+ def __init__(self):
32
+ super(Mixed_5b, self).__init__()
33
+ self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)
34
+
35
+ self.branch1 = nn.Sequential(
36
+ BasicConv2d(192, 48, kernel_size=1, stride=1),
37
+ BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2)
38
+ )
39
+
40
+ self.branch2 = nn.Sequential(
41
+ BasicConv2d(192, 64, kernel_size=1, stride=1),
42
+ BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
43
+ BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
44
+ )
45
+
46
+ self.branch3 = nn.Sequential(
47
+ nn.AvgPool2d(3, stride=1, padding=1),
48
+ BasicConv2d(192, 64, kernel_size=1, stride=1)
49
+ )
50
+
51
+ def forward(self, x):
52
+ x0 = self.branch0(x)
53
+ x1 = self.branch1(x)
54
+ x2 = self.branch2(x)
55
+ x3 = self.branch3(x)
56
+ out = torch.cat((x0, x1, x2, x3), 1)
57
+ return out
58
+
59
+ class Block35(nn.Module):
60
+ def __init__(self, scale=1.0):
61
+ super(Block35, self).__init__()
62
+ self.scale = scale
63
+
64
+ self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)
65
+
66
+ self.branch1 = nn.Sequential(
67
+ BasicConv2d(320, 32, kernel_size=1, stride=1),
68
+ BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
69
+ )
70
+
71
+ self.branch2 = nn.Sequential(
72
+ BasicConv2d(320, 32, kernel_size=1, stride=1),
73
+ BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
74
+ BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1)
75
+ )
76
+
77
+ self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
78
+ self.relu = nn.ReLU(inplace=False)
79
+
80
+ def forward(self, x):
81
+ x0 = self.branch0(x)
82
+ x1 = self.branch1(x)
83
+ x2 = self.branch2(x)
84
+ out = torch.cat((x0, x1, x2), 1)
85
+ out = self.conv2d(out)
86
+ out = out * self.scale + x
87
+ out = self.relu(out)
88
+ return out
89
+
90
+ class Mixed_6a(nn.Module):
91
+ def __init__(self):
92
+ super(Mixed_6a, self).__init__()
93
+
94
+ self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)
95
+
96
+ self.branch1 = nn.Sequential(
97
+ BasicConv2d(320, 256, kernel_size=1, stride=1),
98
+ BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
99
+ BasicConv2d(256, 384, kernel_size=3, stride=2)
100
+ )
101
+
102
+ self.branch2 = nn.MaxPool2d(3, stride=2)
103
+
104
+ def forward(self, x):
105
+ x0 = self.branch0(x)
106
+ x1 = self.branch1(x)
107
+ x2 = self.branch2(x)
108
+ out = torch.cat((x0, x1, x2), 1)
109
+ return out
110
+
111
+ class Block17(nn.Module):
112
+ def __init__(self, scale=1.0):
113
+ super(Block17, self).__init__()
114
+ self.scale = scale
115
+
116
+ self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)
117
+
118
+ self.branch1 = nn.Sequential(
119
+ BasicConv2d(1088, 128, kernel_size=1, stride=1),
120
+ BasicConv2d(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)),
121
+ BasicConv2d(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0))
122
+ )
123
+
124
+ self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
125
+ self.relu = nn.ReLU(inplace=False)
126
+
127
+ def forward(self, x):
128
+ x0 = self.branch0(x)
129
+ x1 = self.branch1(x)
130
+ out = torch.cat((x0, x1), 1)
131
+ out = self.conv2d(out)
132
+ out = out * self.scale + x
133
+ out = self.relu(out)
134
+ return out
135
+
136
+ class Mixed_7a(nn.Module):
137
+ def __init__(self):
138
+ super(Mixed_7a, self).__init__()
139
+
140
+ self.branch0 = nn.Sequential(
141
+ BasicConv2d(1088, 256, kernel_size=1, stride=1),
142
+ BasicConv2d(256, 384, kernel_size=3, stride=2)
143
+ )
144
+
145
+ self.branch1 = nn.Sequential(
146
+ BasicConv2d(1088, 256, kernel_size=1, stride=1),
147
+ BasicConv2d(256, 288, kernel_size=3, stride=2)
148
+ )
149
+
150
+ self.branch2 = nn.Sequential(
151
+ BasicConv2d(1088, 256, kernel_size=1, stride=1),
152
+ BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
153
+ BasicConv2d(288, 320, kernel_size=3, stride=2)
154
+ )
155
+
156
+ self.branch3 = nn.MaxPool2d(3, stride=2)
157
+
158
+ def forward(self, x):
159
+ x0 = self.branch0(x)
160
+ x1 = self.branch1(x)
161
+ x2 = self.branch2(x)
162
+ x3 = self.branch3(x)
163
+ out = torch.cat((x0, x1, x2, x3), 1)
164
+ return out
165
+
166
+ class Block8(nn.Module):
167
+ def __init__(self, scale=1.0, noReLU=False):
168
+ super(Block8, self).__init__()
169
+ self.scale = scale
170
+ self.noReLU = noReLU
171
+
172
+ self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)
173
+
174
+ self.branch1 = nn.Sequential(
175
+ BasicConv2d(2080, 192, kernel_size=1, stride=1),
176
+ BasicConv2d(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)),
177
+ BasicConv2d(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
178
+ )
179
+
180
+ self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
181
+ if not self.noReLU:
182
+ self.relu = nn.ReLU(inplace=False)
183
+
184
+ def forward(self, x):
185
+ x0 = self.branch0(x)
186
+ x1 = self.branch1(x)
187
+ out = torch.cat((x0, x1), 1)
188
+ out = self.conv2d(out)
189
+ out = out * self.scale + x
190
+ if not self.noReLU:
191
+ out = self.relu(out)
192
+ return out
193
+
194
+ class InceptionResNetV2(nn.Module):
195
+ def __init__(self, num_classes=1001):
196
+ super(InceptionResNetV2, self).__init__()
197
+ # Define all your layers here
198
+ self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
199
+ self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
200
+ self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
201
+ self.maxpool_3a = nn.MaxPool2d(3, stride=2)
202
+ self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
203
+ self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
204
+ self.maxpool_5a = nn.MaxPool2d(3, stride=2)
205
+ self.mixed_5b = Mixed_5b()
206
+ self.repeat = nn.Sequential(
207
+ Block35(scale=0.17),
208
+ Block35(scale=0.17),
209
+ Block35(scale=0.17),
210
+ Block35(scale=0.17),
211
+ Block35(scale=0.17),
212
+ Block35(scale=0.17),
213
+ Block35(scale=0.17),
214
+ Block35(scale=0.17),
215
+ Block35(scale=0.17),
216
+ Block35(scale=0.17)
217
+ )
218
+ self.mixed_6a = Mixed_6a()
219
+ self.repeat_1 = nn.Sequential(
220
+ Block17(scale=0.10),
221
+ Block17(scale=0.10),
222
+ Block17(scale=0.10),
223
+ Block17(scale=0.10),
224
+ Block17(scale=0.10),
225
+ Block17(scale=0.10),
226
+ Block17(scale=0.10),
227
+ Block17(scale=0.10),
228
+ Block17(scale=0.10),
229
+ Block17(scale=0.10),
230
+ Block17(scale=0.10),
231
+ Block17(scale=0.10),
232
+ Block17(scale=0.10),
233
+ Block17(scale=0.10),
234
+ Block17(scale=0.10),
235
+ Block17(scale=0.10),
236
+ Block17(scale=0.10),
237
+ Block17(scale=0.10),
238
+ Block17(scale=0.10),
239
+ Block17(scale=0.10)
240
+ )
241
+ self.mixed_7a = Mixed_7a()
242
+ self.repeat_2 = nn.Sequential(
243
+ Block8(scale=0.20),
244
+ Block8(scale=0.20),
245
+ Block8(scale=0.20),
246
+ Block8(scale=0.20),
247
+ Block8(scale=0.20),
248
+ Block8(scale=0.20),
249
+ Block8(scale=0.20),
250
+ Block8(scale=0.20),
251
+ Block8(scale=0.20)
252
+ )
253
+ self.block8 = Block8(noReLU=True)
254
+ self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
255
+ self.avgpool_1a = nn.AvgPool2d(8, stride=1, padding=0)
256
+ self.last_linear = nn.Linear(1536, num_classes)
257
+
258
+ def features(self, input):
259
+ x = self.conv2d_1a(input)
260
+ x = self.conv2d_2a(x)
261
+ x = self.conv2d_2b(x)
262
+ x = self.maxpool_3a(x)
263
+ x = self.conv2d_3b(x)
264
+ x = self.conv2d_4a(x)
265
+ x = self.maxpool_5a(x)
266
+ x = self.mixed_5b(x)
267
+ x = self.repeat(x)
268
+ x = self.mixed_6a(x)
269
+ x = self.repeat_1(x)
270
+ x = self.mixed_7a(x)
271
+ x = self.repeat_2(x)
272
+ x = self.block8(x)
273
+ x = self.conv2d_7b(x)
274
+ return x
275
+
276
+ def logits(self, features):
277
+ x = self.avgpool_1a(features)
278
+ x = x.view(x.size(0), -1)
279
+ x = self.last_linear(x)
280
+ return x
281
+
282
+ def forward(self, input):
283
+ x = self.features(input)
284
+ x = self.logits(x)
285
+ return x
286
+
287
+ def inceptionresnetv2(num_classes=1000):
288
+ return InceptionResNetV2(num_classes=num_classes)
289
+
290
+ # --- Load Models ---
291
+ @st.cache_resource
292
+ def load_inception_model(model_path):
293
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
294
+ model = inceptionresnetv2(num_classes=2).to(device) # Adjust num_classes as needed
295
+ model.load_state_dict(torch.load(model_path, map_location=device))
296
+ model.eval()
297
+ return model, device
298
+
299
+ @st.cache_resource
300
+ def load_yolo_model(yolo_model_path="yolov8n.pt"):
301
+ model = YOLO(yolo_model_path) # You can specify a custom YOLOv8 model path if needed
302
+ return model
303
+
304
+ # --- Image Preprocessing ---
305
+ data_transforms = transforms.Compose([
306
+ transforms.Resize(342),
307
+ transforms.CenterCrop(299),
308
+ transforms.ToTensor(),
309
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
310
+ ])
311
+
312
+ # --- Streamlit App ---
313
+ def main():
314
+ st.title("Image Anomaly Detection and Object Detection")
315
+ st.write("Upload an image to analyze for anomalies.")
316
+
317
+ # Load models
318
+ inception_model, device = load_inception_model(r'anamoly30.pth') # Ensure 'anamoly30.pth' is in the same directory
319
+ yolo_model = load_yolo_model(r'X:\mowito\mowito.pt') # Ensure 'yolov8n.pt' is in the same directory or specify the path
320
+
321
+ # Upload the image
322
+ uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
323
+
324
+ # User input for confidence threshold
325
+ threshold = st.slider("Set Confidence Threshold", 0.0, 1.0, 0.5, 0.01)
326
+
327
+ if uploaded_file is not None:
328
+ # Display the uploaded image
329
+ image = Image.open(uploaded_file).convert("RGB")
330
+ st.image(image, caption="Uploaded Image", width=400)
331
+
332
+ # Preprocess the image
333
+ transformed_image = data_transforms(image).unsqueeze(0).to(device)
334
+
335
+ # InceptionResNetV2 Prediction
336
+ with torch.no_grad():
337
+ outputs = inception_model(transformed_image)
338
+ _, predicted = torch.max(outputs, 1)
339
+ predicted_class = ['bad', 'good'][predicted.item()]
340
+ confidence = torch.nn.functional.softmax(outputs, dim=1)[0][predicted.item()].item()
341
+
342
+ st.write(f"**Prediction:** {predicted_class}")
343
+ st.write(f"**Confidence:** {confidence:.4f}")
344
+
345
+ # Check if confidence is above the threshold
346
+ if confidence >= threshold:
347
+ if predicted_class == "bad":
348
+ st.warning("Anomalies detected in the image. Processing further analysis...")
349
+
350
+ # Automatically run YOLOv8 on the uploaded image
351
+ st.write("Analyzing anomalies using YOLOv8...")
352
+ yolo_results = yolo_model.predict(source=image, conf=0.25, show=False)
353
+
354
+ # Display YOLOv8 predictions
355
+ st.write("### YOLOv8 Predictions:")
356
+ for result in yolo_results:
357
+ # Plot the results on the image
358
+ annotated_yolo_image = result.plot()
359
+ st.image(annotated_yolo_image, caption="YOLOv8 Detection", width=400)
360
+
361
+ # Optionally, display detailed results
362
+ st.write("### Detection Details:")
363
+ for result in yolo_results:
364
+ for box in result.boxes:
365
+ cls = int(box.cls)
366
+ conf = box.conf
367
+ label = yolo_model.names[cls] if cls < len(yolo_model.names) else "Unknown"
368
+ st.write(f"- **Label**: {label}, **Confidence**: {conf.item():.2f}")
369
+
370
+ # Provide interactive feedback option
371
+ st.info("You can annotate the image to refine analysis.")
372
+
373
+ # Initialize canvas for manual annotation
374
+ canvas_result = st_canvas(
375
+ fill_color="rgba(255, 165, 0, 0.3)", # Semi-transparent orange
376
+ stroke_width=2,
377
+ stroke_color="#FF0000", # Red
378
+ background_color="#FFFFFF",
379
+ background_image=image,
380
+ update_streamlit=True,
381
+ height=image.height,
382
+ width=image.width,
383
+ drawing_mode="rect", # Allow drawing rectangles
384
+ key="canvas",
385
+ )
386
+
387
+ if canvas_result.json_data is not None:
388
+ objects = canvas_result.json_data["objects"]
389
+ if len(objects) > 0:
390
+ st.success("Bounding boxes drawn. Click the button below to analyze with YOLOv8.")
391
+ if st.button("Analyze Manual Annotations"):
392
+ # Draw the bounding boxes on the image
393
+ annotated_image = image.copy()
394
+ draw = ImageDraw.Draw(annotated_image)
395
+ for obj in objects:
396
+ if obj["type"] == "rect":
397
+ left = obj["left"]
398
+ top = obj["top"]
399
+ width = obj["width"]
400
+ height = obj["height"]
401
+ draw.rectangle([left, top, left + width, top + height], outline="red", width=3)
402
+
403
+ st.image(annotated_image, caption="Annotated Image", width=400)
404
+
405
+ # Pass the manually annotated image to YOLOv8
406
+ yolo_results_manual = yolo_model.predict(source=annotated_image, conf=0.25, show=False)
407
+
408
+ # Display YOLOv8 predictions for annotated image
409
+ st.write("### YOLOv8 Predictions (Manual Annotations):")
410
+ for result in yolo_results_manual:
411
+ # Plot the results on the image
412
+ annotated_yolo_image_manual = result.plot()
413
+ st.image(annotated_yolo_image_manual, caption="YOLOv8 Detection (Manual)", width=400)
414
+
415
+ # Display detection details
416
+ st.write("### Detection Details (Manual Annotations):")
417
+ for result in yolo_results_manual:
418
+ for box in result.boxes:
419
+ cls = int(box.cls)
420
+ conf = box.conf
421
+ label = yolo_model.names[cls] if cls < len(yolo_model.names) else "Unknown"
422
+ st.write(f"- **Label**: {label}, **Confidence**: {conf.item():.2f}")
423
+ else:
424
+ st.info("Draw bounding boxes around the anomalies and press the button to analyze.")
425
+ else:
426
+ st.warning(f"The confidence level ({confidence:.4f}) is below the threshold of {threshold}. No further analysis will be performed.")
427
+ else:
428
+ st.info("Please upload an image to get started.")
429
+
430
+ if __name__ == "__main__":
431
  main()