# Source-page residue (not code): jovian / first / ed2b47e
import gradio as gr
import numpy as np
import cv2
from sahi.predict import get_sliced_prediction
from sahi import AutoDetectionModel
from PIL import Image
import plotly.graph_objects as go
import torch
import spaces
# Torch device string: the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
class Detection:
    """Defect detector built on a SAHI-wrapped YOLOv8 model, with drawing and charting helpers.

    The charts report on four defect categories: 'Nicks', 'Dents', 'Scratches'
    and 'Pittings'.  Annotations handled by the methods are the COCO-style
    dicts returned by ``detect_from_image``: each one is read for ``'bbox'``
    (``[x, y, width, height]``), ``'category_name'`` and, optionally,
    ``'score'``.
    """

    # Categories that get their own chart, in display order.
    CATEGORIES = ('Nicks', 'Dents', 'Scratches', 'Pittings')

    # One colour table shared by the box drawing and the charts.  The triples
    # are in RGB order (e.g. (255, 60, 60) is red), which matches arrays
    # obtained from PIL images; OpenCV applies the channel values as given.
    CATEGORY_RGB = {
        'Nicks': (255, 60, 60),       # Red
        'Dents': (255, 148, 156),     # Light red
        'Scratches': (255, 116, 28),  # Orange
        'Pittings': (255, 180, 28),   # Yellow
    }
    DEFAULT_RGB = (255, 0, 0)  # Fallback for categories not listed above
    BOX_THICKNESS = 2
    CHART_ALPHA = 0.7

    def __init__(self, model_path="./model/best.pt", confidence_threshold=0.3,
                 device='cpu'):
        """Load the YOLOv8 weights through SAHI's ``AutoDetectionModel``.

        Args:
            model_path: Path to the YOLOv8 weights file.
            confidence_threshold: Confidence threshold passed to the model.
            device: Device string handed to the model, e.g. ``'cpu'`` or
                ``'cuda:0'``.  Defaults to ``'cpu'``.
        """
        self.model = AutoDetectionModel.from_pretrained(
            model_type='yolov8',
            model_path=model_path,
            confidence_threshold=confidence_threshold,
            device=device,
        )

    def detect_from_image(self, image):
        """Run SAHI sliced prediction on ``image`` and return COCO annotations.

        The image is processed in 256x256 slices with 20% overlap, and the
        per-slice detections are merged with class-agnostic NMS (IOU
        threshold 0.1).

        Args:
            image: The image to analyse, passed straight to
                ``get_sliced_prediction``.

        Returns:
            The value of ``results.to_coco_annotations()``.
        """
        results = get_sliced_prediction(
            image=image,
            detection_model=self.model,
            slice_height=256,
            slice_width=256,
            overlap_height_ratio=0.2,
            overlap_width_ratio=0.2,
            postprocess_type='NMS',
            postprocess_match_metric='IOU',
            postprocess_match_threshold=0.1,
            postprocess_class_agnostic=True,
        )
        return results.to_coco_annotations()

    def draw_annotations(self, image, annotations):
        """Draw a labelled bounding box on ``image`` for every annotation.

        Args:
            image: Array to draw on with OpenCV; it is modified in place.
            annotations: Iterable of annotation dicts (``None`` is treated as
                empty).

        Returns:
            The same ``image`` object, with boxes and labels drawn on it.
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        text_thickness = 2

        for annotation in annotations or []:
            bbox = annotation['bbox']
            category_name = annotation['category_name']
            score = annotation.get('score', 0)  # 0 when no score is present
            color = self.CATEGORY_RGB.get(category_name, self.DEFAULT_RGB)

            left, top = int(bbox[0]), int(bbox[1])
            right, bottom = int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3])
            cv2.rectangle(image, (left, top), (right, bottom), color,
                          self.BOX_THICKNESS)

            text = f"{category_name}: {score:.2f}"
            # The label normally sits 10 px above the box.  For boxes touching
            # the top edge that would put it off-image, so keep the baseline
            # at least one text height below the top of the image.
            (_, text_height), _ = cv2.getTextSize(text, font, font_scale,
                                                  text_thickness)
            label_y = max(int(bbox[1] - 10), text_height + 2)
            cv2.putText(image, text, (left, label_y), font, font_scale, color,
                        text_thickness)
        return image

    def generate_individual_graphs(self, annotations):
        """Build one area-distribution histogram per defect category.

        Each box area is ``width * height`` of its ``'bbox'``.  Area is on the
        y-axis and frequency on the x-axis, so the histogram is drawn
        horizontally.  Annotations whose category is not in ``CATEGORIES``
        are ignored.

        Args:
            annotations: Iterable of annotation dicts (``None`` is treated as
                empty).

        Returns:
            A 4-tuple of Plotly figures in the order Nicks, Dents, Scratches,
            Pittings.  A category with no detections gets a figure holding an
            empty bar trace.
        """
        category_areas = {category: [] for category in self.CATEGORIES}
        for annotation in annotations or []:
            areas = category_areas.get(annotation['category_name'])
            if areas is not None:
                bbox = annotation['bbox']
                areas.append(bbox[2] * bbox[3])  # Width * Height

        figures = []
        for category in self.CATEGORIES:
            areas = category_areas[category]
            fig = go.Figure()
            if areas:
                # Plotly bins the data in the browser, so the counts are not
                # available on the trace object in Python.  Passing the areas
                # as ``y`` makes Plotly draw the histogram horizontally.
                fig.add_trace(go.Histogram(
                    y=areas,
                    name=category,
                    marker_color=self.get_color(category),
                    opacity=1,
                    nbinsy=10,
                ))
            else:
                fig.add_trace(go.Bar(x=[], y=[], name=category))
            fig.update_layout(
                title=f'Area Distribution of {category}',
                xaxis_title='Frequency',
                yaxis_title='Area',
                showlegend=True,
            )
            figures.append(fig)
        return tuple(figures)

    def generate_frequency_graph(self, annotations):
        """Build a bar chart counting detections per defect category.

        Args:
            annotations: Iterable of annotation dicts (``None`` is treated as
                empty).  Categories outside ``CATEGORIES`` are not counted.

        Returns:
            A Plotly figure with one bar trace per category.
        """
        category_counts = dict.fromkeys(self.CATEGORIES, 0)
        for annotation in annotations or []:
            category_name = annotation['category_name']
            if category_name in category_counts:
                category_counts[category_name] += 1

        freq_chart = go.Figure()
        for category, count in category_counts.items():
            freq_chart.add_trace(go.Bar(
                x=[category],
                y=[count],
                name=category,
                marker_color=self.get_color(category),
            ))
        freq_chart.update_layout(
            title='Frequency of Defects',
            xaxis_title='Defect Category',
            yaxis_title='Count',
            barmode='group',
        )
        return freq_chart

    def get_color(self, category_name):
        """Return the chart colour for ``category_name`` as an ``rgba(...)`` string.

        Unknown categories fall back to red.  The fallback is a string like
        every other value, so it can be used anywhere the known colours can.
        """
        red, green, blue = self.CATEGORY_RGB.get(category_name, self.DEFAULT_RGB)
        return f'rgba({red}, {green}, {blue}, {self.CHART_ALPHA})'
# Module-level detector used by the Gradio callbacks in this file. Constructing
# it runs Detection.__init__, which loads the model via
# AutoDetectionModel.from_pretrained.
detection = Detection()
def upload_image(image):
    """Hand the uploaded image back unchanged (identity pass-through)."""
    unchanged = image
    return unchanged
@spaces.GPU
def apply_detection(image):
    """Run defect detection on the uploaded image and draw the results.

    Args:
        image: The uploaded PIL image, or ``None`` when nothing has been
            uploaded yet.

    Returns:
        A ``(annotated_image, annotations)`` pair: the annotated PIL image and
        the annotations returned by ``detection.detect_from_image``.  When
        ``image`` is ``None`` the pair is ``(None, [])`` instead of an error.
    """
    # Clicking the button before uploading passes None; return an empty
    # result rather than failing inside the detector.
    if image is None:
        return None, []

    # The boxes are drawn with 3-value colours, so make sure a PIL input has
    # exactly three RGB channels before converting it to an array.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    # Convert image from PIL to NumPy array
    img = np.array(image)
    # Perform detection and get COCO annotations
    annotations = detection.detect_from_image(img)
    # Draw the annotations on the image using OpenCV
    annotated_image = detection.draw_annotations(img, annotations)
    # Convert back to PIL format for Gradio output
    return Image.fromarray(annotated_image), annotations
def generate_graphs_btn(annotations):
    """Build the per-category area-distribution graphs from the annotations.

    Args:
        annotations: Annotation dicts stored by the detection step, or
            ``None`` when no detection has been run yet.

    Returns:
        Whatever ``detection.generate_individual_graphs`` returns: one figure
        per defect category (Nicks, Dents, Scratches, Pittings).
    """
    # The stored state is None until a detection has run; treat that as
    # "no detections".  The frequency graph has its own button, so it is not
    # built here.
    return detection.generate_individual_graphs(annotations or [])
# Custom stylesheet passed to gr.Blocks(css=...). It imports web fonts and styles
# the `header`, `.left`, `.container` and `.sub-header` elements used by the
# gr.HTML markup in this file.
# NOTE(review): `height: calc(100vh-80px)` has no spaces around `-`; CSS calc()
# needs whitespace around `+`/`-`, so browsers presumably ignore this
# declaration — confirm the intended header height before changing it.
css = """
@import url('https://fonts.googleapis.com/css2?family=Ubuntu:wght@300;400;500;700&family=Montserrat:wght@700&family=Open+Sans&family=Poppins:wght@300;400;500;600;700;800&display=swap');
*{
margin: 0;
padding: 0;
box-sizing: border-box;
font-family: 'Ubuntu',sans-serif;
}
a{
text-decoration: none;
color: #000;
}
body{
background-color: #fff;
}
header{
padding: 0 80px;
height: calc(100vh-80px);
display: flex;
align-items: center;
justify-content: space-between;
}
header .left h1 {
font-size: 80px;
display: flex;
justify-content: center;
margin-top: 17rem;
}
header .left span{
font-size: 80px;
color: #083484;
display: flex;
justify-content: center;
}
header .left .second-line{
font-size: 80px;
color: #083484;
display: flex;
justify-content: center;
font-weight: 400;
}
header .left p{
margin-top: 35px;
font-stretch: ultra-condensed;
color: #777;
display: flex;
justify-content: center;
text-align: center;
margin-bottom: 10px;
}
header .left a{
display: flex;
align-items: center;
background: #083484;
width: 150px;
padding: 8px;
border-radius: 60px;
}
header .left a i{
background-color: #fff;
font-size: 24px;
border-radius: 50%;
padding: 8px;
}
header .left a span{
color: #fff;
margin-left: 22px;
}
.container {
padding:30px;
text-align: center;
overflow: auto;
margin-top: 500px;
}
.sub-header {
font-size: 4em;
text-align: center;
color: #083484;
font-family: 'Montserrat',sans-serif;
}
"""
# JavaScript passed to gr.Blocks(js=...). If the page URL's `__theme` query
# parameter is not already 'light', it sets it to 'light' and reloads the page
# at the new URL — presumably to force Gradio's light theme.
js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'light') {
url.searchParams.set('__theme', 'light');
window.location.href = url.href;
}
}
"""
# Build the Gradio page: header markup, input/output images, detection button,
# per-category area charts and the defect-frequency chart.
with gr.Blocks(css=css, js=js_func) as demo:
    # Static header and section title (styled by the `css` string).
    gr.HTML("""
<header>
<div class="left">
<h1><span>OIS</span><br></h1>
<span class="second-line">AI Detection Model</span>
<p>
The OIS AI Detection Model enhances manufacturing by using the powerful YOLOv11 algorithm on
a Raspberry Pi for real-time, on-device defect detection. It automates quality control,
reduces human error, and minimizes downtime. With a user-friendly web interface,
the model enables offline swift defect identification, seamless integration into
production, and improving both efficiency and product quality.
</p>
</div>
</header>
<section class="container">
<p class="sub-header">OFFLINE DETECTION</p>
</section>
""")

    # Uploaded image on the left, annotated result on the right.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Input")
            upload_image_component = gr.Image(type="pil", label="Select Image")
        with gr.Column():
            gr.Markdown("### Output")
            output_image_component = gr.Image(type="pil", label="Annotated Image")

    apply_detection_btn = gr.Button("Apply Detection")
    # Holds the annotations from the latest detection for the chart buttons.
    output_annotations = gr.State()
    apply_detection_btn.click(
        fn=apply_detection,
        inputs=upload_image_component,
        outputs=[output_image_component, output_annotations],
    )

    # One area-distribution plot per defect category.
    with gr.Row():
        nicks_graph_component = gr.Plot(label="Nicks Area Distribution")
        dents_graph_component = gr.Plot(label="Dents Area Distribution")
        scratches_graph_component = gr.Plot(label="Scratches Area Distribution")
        pittings_graph_component = gr.Plot(label="Pittings Area Distribution")

    # Button that fills the four area-distribution plots.
    with gr.Row():
        graph_btn = gr.Button("Generate Area Distribution Graphs")
        graph_btn.click(
            fn=generate_graphs_btn,
            inputs=output_annotations,
            outputs=[
                nicks_graph_component,
                dents_graph_component,
                scratches_graph_component,
                pittings_graph_component,
            ],
        )

    # Defect-frequency plot.
    with gr.Row():
        frequency_graph_component = gr.Plot(label="Defect Frequency Distribution")

    # Button that fills the frequency plot.
    with gr.Row():
        freq_graph_btn = gr.Button("Generate Frequency Graph")
        freq_graph_btn.click(
            fn=detection.generate_frequency_graph,
            inputs=output_annotations,
            outputs=frequency_graph_component,
        )

# Start the Gradio server with a public share link requested.
demo.launch(share=True)