jovian commited on
Commit
ee2d685
·
1 Parent(s): e8923cc

Add application file

Browse files
Files changed (3) hide show
  1. app.py +512 -0
  2. model/best.pt +3 -0
  3. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,512 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import cv2
4
+ from sahi.predict import get_sliced_prediction
5
+ from sahi import AutoDetectionModel
6
+ from PIL import Image
7
+ import plotly.graph_objects as go
8
+ import spaces
9
+ import torch
10
+
11
+
12
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
13
+
14
+
15
+ class Detection:
16
+ def __init__(self):
17
+ # Set the model path and confidence threshold
18
+ yolov8_model_path = "./model/best.pt" # Update to your model path
19
+
20
+ # Initialize the AutoDetectionModel
21
+ self.model = AutoDetectionModel.from_pretrained(
22
+ model_type='yolov8',
23
+ model_path=yolov8_model_path,
24
+ confidence_threshold=0.3,
25
+ device=device # Change to 'cuda:0' if you are using a GPU
26
+ )
27
+
28
+ @spaces.GPU
29
+ def detect_from_image(self, image):
30
+ # Perform sliced prediction with SAHI
31
+ results = get_sliced_prediction(
32
+ image=image,
33
+ detection_model=self.model,
34
+ slice_height=256,
35
+ slice_width=256,
36
+ overlap_height_ratio=0.2,
37
+ overlap_width_ratio=0.2,
38
+ postprocess_type='NMS',
39
+ postprocess_match_metric='IOU',
40
+ postprocess_match_threshold=0.1,
41
+ postprocess_class_agnostic=True
42
+ )
43
+
44
+ # Retrieve COCO annotations
45
+ coco_annotations = results.to_coco_annotations()
46
+ return coco_annotations
47
+
48
+ def draw_annotations(self, image, annotations):
49
+ """Draw bounding boxes on the image based on COCO annotations using OpenCV."""
50
+ # Define colors for each category in BGR (OpenCV uses BGR format)
51
+ category_styles = {
52
+ 'Nicks': {'color': (255, 60, 60), 'thickness': 2}, # Nicks (Red)
53
+ 'Dents': {'color': (255, 148, 156), 'thickness': 2}, # Dents (Light Red)
54
+ 'Scratches': {'color': (255, 116, 28), 'thickness': 2}, # Scratches (Orange)
55
+ 'Pittings': {'color': (255, 180, 28), 'thickness': 2} # Pittings (Yellow)
56
+ }
57
+
58
+ for annotation in annotations:
59
+ bbox = annotation['bbox'] # Extract the bounding box
60
+ category_name = annotation['category_name']
61
+ score = annotation.get('score', 0) # Extract confidence score, default to 0 if not present
62
+
63
+ # Get color and thickness for the current category
64
+ style = category_styles.get(category_name, {'color': (255, 0, 0), 'thickness': 2}) # Default to red if not found
65
+
66
+ # Draw rectangle
67
+ cv2.rectangle(image,
68
+ (int(bbox[0]), int(bbox[1])),
69
+ (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3])),
70
+ style['color'],
71
+ style['thickness'])
72
+
73
+ # Prepare text with category and confidence score
74
+ text = f"{category_name}: {score:.2f}" # Format the score to two decimal places
75
+
76
+ # Put category text with score
77
+ cv2.putText(image,
78
+ text,
79
+ (int(bbox[0]), int(bbox[1] - 10)), # Position above the rectangle
80
+ cv2.FONT_HERSHEY_SIMPLEX,
81
+ 0.5,
82
+ style['color'],
83
+ 2)
84
+
85
+ return image
86
+
87
+
88
+ def generate_individual_graphs(self, annotations):
89
+ """Generate individual area distribution histograms for each defect category."""
90
+ # Dictionary to hold areas for each category
91
+ category_areas = {
92
+ 'Nicks': [],
93
+ 'Dents': [],
94
+ 'Scratches': [],
95
+ 'Pittings': []
96
+ }
97
+
98
+ # Populate the category_areas dictionary
99
+ for annotation in annotations:
100
+ category_name = annotation['category_name']
101
+ area = annotation['bbox'][2] * annotation['bbox'][3] # Width * Height
102
+ if category_name in category_areas:
103
+ category_areas[category_name].append(area)
104
+
105
+ # Create individual area distribution histograms for each category
106
+ individual_graphs = {}
107
+ for category in ['Nicks', 'Dents', 'Scratches', 'Pittings']:
108
+ areas = category_areas[category]
109
+ fig = go.Figure()
110
+ if areas: # Check if there are areas to plot
111
+ # Create a histogram and store the frequencies
112
+ histogram_data = go.Histogram(
113
+ x=areas,
114
+ name=category,
115
+ marker_color=self.get_color(category), # Use associated color
116
+ opacity=1,
117
+ nbinsx=10 # Number of bins
118
+ )
119
+ fig.add_trace(histogram_data)
120
+
121
+ # Get the frequencies and edges for swapping axes
122
+ frequencies = histogram_data.y
123
+ edges = histogram_data.x
124
+
125
+ # Create a bar chart to swap the axes
126
+ fig = go.Figure(data=[
127
+ go.Bar(
128
+ x=frequencies, # Frequencies on x-axis
129
+ y=edges, # Edges on y-axis
130
+ name=category,
131
+ marker_color=self.get_color(category), # Use associated color
132
+ opacity=1
133
+ )
134
+ ])
135
+ else: # Generate an empty graph if no areas
136
+ fig.add_trace(go.Bar(x=[], y=[], name=category)) # Empty graph
137
+
138
+ # Update layout with swapped axes
139
+ fig.update_layout(
140
+ title=f'Area Distribution of {category}',
141
+ xaxis_title='Frequency', # Frequency on x-axis
142
+ yaxis_title='Area', # Area on y-axis
143
+ showlegend=True
144
+ )
145
+ individual_graphs[category] = fig
146
+
147
+ return individual_graphs['Nicks'], individual_graphs['Dents'], individual_graphs['Scratches'], individual_graphs['Pittings']
148
+
149
+
150
+
151
+ def generate_frequency_graph(self, annotations):
152
+ """Generate a frequency bar chart for defect categories."""
153
+ category_counts = {
154
+ 'Nicks': 0,
155
+ 'Dents': 0,
156
+ 'Scratches': 0,
157
+ 'Pittings': 0
158
+ }
159
+
160
+ # Count occurrences of each defect category
161
+ for annotation in annotations:
162
+ category_name = annotation['category_name']
163
+ if category_name in category_counts:
164
+ category_counts[category_name] += 1
165
+
166
+ # Create a bar chart for frequency
167
+ freq_chart = go.Figure()
168
+ category_colors = {
169
+ 'Nicks': 'rgba(255, 60, 60, 0.7)', # Red
170
+ 'Dents': 'rgba(255, 148, 156, 0.7)', # Light Red
171
+ 'Scratches': 'rgba(255, 116, 28, 0.7)', # Orange
172
+ 'Pittings': 'rgba(255, 180, 28, 0.7)' # Yellow
173
+ }
174
+
175
+ for category, count in category_counts.items():
176
+ freq_chart.add_trace(go.Bar(
177
+ x=[category],
178
+ y=[count],
179
+ name=category,
180
+ marker_color=category_colors.get(category, 'blue') # Default to blue if not found
181
+ ))
182
+
183
+ freq_chart.update_layout(
184
+ title='Frequency of Defects',
185
+ xaxis_title='Defect Category',
186
+ yaxis_title='Count',
187
+ barmode='group'
188
+ )
189
+
190
+ return freq_chart
191
+
192
+
193
+ def get_color(self, category_name):
194
+ """Get the color associated with a category name."""
195
+ category_styles = {
196
+ 'Nicks': 'rgba(255, 60, 60, 0.7)', # Red
197
+ 'Dents': 'rgba(255, 148, 156, 0.7)', # Light Red
198
+ 'Scratches': 'rgba(255, 116, 28, 0.7)', # Orange
199
+ 'Pittings': 'rgba(255, 180, 28, 0.7)' # Yellow
200
+ }
201
+ return category_styles.get(category_name, (255, 0, 0)) # Default to red if not found
202
+
203
+
204
+
205
+ detection = Detection()
206
+
207
+ def upload_image(image):
208
+ """Process the uploaded image (if needed) and display it."""
209
+ return image
210
+
211
+ def apply_detection(image):
212
+ """Run object detection on the uploaded image and return the annotated image."""
213
+ # Convert image from PIL to NumPy array
214
+ img = np.array(image)
215
+
216
+ # Perform detection and get COCO annotations
217
+ annotations = detection.detect_from_image(img)
218
+
219
+ # Draw the annotations on the image using OpenCV
220
+ annotated_image = detection.draw_annotations(img, annotations)
221
+
222
+ # Convert back to PIL format for Gradio output
223
+ return Image.fromarray(annotated_image), annotations
224
+
225
+ def generate_graphs_btn(annotations):
226
+ """Generate interactive graphs from the annotations."""
227
+ # Generate individual graphs for each defect category
228
+ individual_graphs = detection.generate_individual_graphs(annotations)
229
+ frequency_graph = detection.generate_frequency_graph(annotations)
230
+ return individual_graphs
231
+
232
+ css = """
233
+
234
+ @import url('https://fonts.googleapis.com/css2?family=Ubuntu:wght@300;400;500;700&family=Montserrat:wght@700&family=Open+Sans&family=Poppins:wght@300;400;500;600;700;800&display=swap');
235
+
236
+ *{
237
+ margin: 0;
238
+ padding: 0;
239
+ box-sizing: border-box;
240
+ font-family: 'Ubuntu',sans-serif;
241
+ }
242
+
243
+ a{
244
+ text-decoration: none;
245
+ color: #000;
246
+ }
247
+
248
+
249
+ body{
250
+ background-color: #fff;
251
+ }
252
+
253
+ nav{
254
+ padding: 0 80px;
255
+ display: flex;
256
+ align-items: center;
257
+ justify-content: space-between;
258
+ }
259
+
260
+
261
+ .nav-logo{
262
+ margin-top: 20px;
263
+ }
264
+
265
+ .astarlogo{
266
+ width: 230px;
267
+ display: flex;
268
+ border-style: none;
269
+ display: none;
270
+ }
271
+
272
+
273
+ .nav-links{
274
+ list-style: none;
275
+ display: flex;
276
+ align-items: center;
277
+ gap: 3rem ;
278
+ }
279
+
280
+ .link a{
281
+ position: relative;
282
+ padding-bottom: 0.75rem;
283
+ color:#083484;
284
+ font-size: 1rem;
285
+ font-weight: 600;
286
+ font-family: 'Poppins',sans-serif;
287
+ }
288
+
289
+ .link a::after {
290
+ content: "";
291
+ position: absolute;
292
+ height: 2px;
293
+ width: 0;
294
+ bottom: 0;
295
+ left: 0;
296
+ background-color: #083484;
297
+ transition: all 0.3s ease;
298
+
299
+ }
300
+
301
+ .link a:hover::after{
302
+ width: 70%;
303
+ }
304
+
305
+ nav .login button{
306
+ padding: 8px 14px;
307
+ border: none;
308
+ cursor: pointer;
309
+ background-color: transparent;
310
+ }
311
+
312
+ nav .login button#signup{
313
+ background-color: #083484;
314
+ color: #fff;
315
+ border-radius: 4px;
316
+ margin-right: 14px;
317
+ padding: 15px 20px;
318
+ margin-top: 25px;
319
+ display: none;
320
+
321
+ }
322
+
323
+ header{
324
+ padding: 0 80px;
325
+ height: calc(100vh-80px);
326
+ display: flex;
327
+ align-items: center;
328
+ justify-content: space-between;
329
+ }
330
+
331
+ header .left h1 {
332
+ font-size: 80px;
333
+ display: flex;
334
+ justify-content: center;
335
+ margin-top: 17rem;
336
+
337
+ }
338
+
339
+ header .left span{
340
+ font-size: 80px;
341
+ color: #083484;
342
+ display: flex;
343
+ justify-content: center;
344
+
345
+ }
346
+ header .left .second-line{
347
+ font-size: 80px;
348
+ color: #083484;
349
+ display: flex;
350
+ justify-content: center;
351
+ font-weight: 400;
352
+
353
+ }
354
+
355
+ header .left p{
356
+ margin-top: 35px;
357
+ font-stretch: ultra-condensed;
358
+ color: #777;
359
+ display: flex;
360
+ justify-content: center;
361
+ text-align: center;
362
+ margin-bottom: 10px;
363
+ }
364
+
365
+ header .left a{
366
+ display: flex;
367
+ align-items: center;
368
+ background: #083484;
369
+ width: 150px;
370
+ padding: 8px;
371
+ border-radius: 60px;
372
+ }
373
+
374
+ header .left a i{
375
+ background-color: #fff;
376
+ font-size: 24px;
377
+ border-radius: 50%;
378
+ padding: 8px;
379
+ }
380
+
381
+ header .left a span{
382
+ color: #fff;
383
+ margin-left: 22px;
384
+ }
385
+
386
+ .container {
387
+ padding:30px;
388
+ text-align: center;
389
+ overflow: auto;
390
+ margin-top: 500px;
391
+ }
392
+
393
+ .sub-header {
394
+ font-size: 4em;
395
+ text-align: center;
396
+ color: #083484;
397
+ font-family: 'Montserrat',sans-serif;
398
+ }
399
+
400
+
401
+
402
+
403
+ """
404
+
405
+
406
+
407
+ js_func = """
408
+ function refresh() {
409
+ const url = new URL(window.location);
410
+
411
+ if (url.searchParams.get('__theme') !== 'light') {
412
+ url.searchParams.set('__theme', 'light');
413
+ window.location.href = url.href;
414
+ }
415
+ }
416
+
417
+ """
418
+
419
+
420
+
421
+ # Gradio interface components
422
+ with gr.Blocks(css = css,js=js_func) as demo:
423
+
424
+ gr.HTML("""
425
+ <nav>
426
+ <div class = "nav-logo" >
427
+ <a href="#">
428
+ <img class="astarlogo" src="" >
429
+ </a>
430
+ </div>
431
+ <ul class="nav-links">
432
+ <li class = "link"><a href="#">HOME</a></li>
433
+ <li id = "link1" class = "link"><a href="#">OFFLINE DETECTION</a></li>
434
+ <li id = "link2" class = "link"><a href="#">CONTACT US</a></li>
435
+ </ul>
436
+ <div class="login">
437
+ <button id="signup">Get Started</button>
438
+ </div>
439
+ </nav>
440
+
441
+
442
+ <header>
443
+ <div class="left">
444
+ <h1><span>OIS</span><br></h1>
445
+ <span class="second-line">AI Detection Model</span>
446
+ <p>
447
+ The OIS AI Detection Model enhances manufacturing by using the powerful YOLOv11 algorithm on
448
+ a Raspberry Pi for real-time, on-device defect detection. It automates quality control,
449
+ reduces human error, and minimizes downtime. With a user-friendly web interface,
450
+ the model enables offline swift defect identification, seamless integration into
451
+ production, and improving both efficiency and product quality.
452
+ </p>
453
+ </div>
454
+
455
+ </header>
456
+
457
+ <section class="container">
458
+
459
+ <p class="sub-header">OFFLINE DETECTION</p>
460
+
461
+ </section>
462
+
463
+ """)
464
+
465
+
466
+ with gr.Row():
467
+ # Image Upload and Display in two columns
468
+ with gr.Column():
469
+ gr.Markdown("### Input")
470
+ upload_image_component = gr.Image(type="pil", label="Select Image")
471
+
472
+ with gr.Column():
473
+ gr.Markdown("### Output")
474
+ output_image_component = gr.Image(type="pil", label="Annotated Image")
475
+
476
+ # Button for Object Detection below the columns
477
+ with gr.Row(): # Create a new row for the button
478
+ apply_detection_btn = gr.Button("Apply Detection")
479
+ output_annotations = gr.State() # Store annotations
480
+ apply_detection_btn.click(apply_detection, inputs=upload_image_component, outputs=[output_image_component, output_annotations])
481
+
482
+ # Row for the graphs
483
+ with gr.Row():
484
+ # Individual graphs for each defect category
485
+ nicks_graph_component = gr.Plot(label="Nicks Area Distribution")
486
+ dents_graph_component = gr.Plot(label="Dents Area Distribution")
487
+ scratches_graph_component = gr.Plot(label="Scratches Area Distribution")
488
+ pittings_graph_component = gr.Plot(label="Pittings Area Distribution")
489
+
490
+ # Button to generate graphs
491
+ with gr.Row():
492
+ graph_btn = gr.Button("Generate Graphs")
493
+ graph_btn.click(generate_graphs_btn, inputs=output_annotations, outputs=[
494
+ nicks_graph_component, dents_graph_component,
495
+ scratches_graph_component, pittings_graph_component
496
+ ])
497
+
498
+ # Row for frequency graph
499
+ with gr.Row():
500
+ frequency_graph_component = gr.Plot(label="Defect Frequency Distribution") # Frequency Graph
501
+
502
+ # Additional row for frequency graph button (if needed)
503
+ with gr.Row():
504
+ freq_graph_btn = gr.Button("Refresh Frequency Graph")
505
+ freq_graph_btn.click(detection.generate_frequency_graph,
506
+ inputs=output_annotations,
507
+ outputs=frequency_graph_component)
508
+
509
+ # Launch the Gradio interface
510
+ demo.launch(share=True)
511
+
512
+
model/best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67424dbaf2d9c3f07f356a59c37187ef1a7b9f59ebabf77c5cb7f9cb9507f107
3
+ size 38138560
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu124
2
+ torch
3
+ torchvision
4
+ opencv-python
5
+ gradio==5.4.0
6
+ sahi==0.11.18
7
+ pillow
8
+ plotly==5.24.1
9
+ ultralytics==8.3.24