OIS-AI-Defect / app.py
jovian
update
de18998
import gradio as gr
import numpy as np
import cv2
from sahi.predict import get_sliced_prediction
from sahi import AutoDetectionModel
from PIL import Image
import plotly.graph_objects as go
import torch
#import spaces
import os
import shutil
import subprocess
import os
import shutil
import subprocess
# Torch device string: the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
from torchvision.ops import box_iou
#testing
class Detection:
    """Defect detector that merges the output of two SAHI-wrapped YOLOv8 models.

    Each image is run through both models with SAHI sliced inference, and the two
    lists of COCO-style annotations are merged by ``combine_results``. The other
    methods draw those annotations onto an image or summarise them as Plotly
    figures.
    """

    # Defect categories shown in the charts, in display order.
    CATEGORIES = ("Nicks", "Dents", "Scratches", "Pittings")

    # cv2 colours are interpreted in the channel order of the array drawn on.
    # The app draws on np.array(<PIL image>), which is RGB, so these are RGB.
    DRAW_COLORS = {
        "Nicks": (255, 60, 60),  # Red
        "Dents": (255, 148, 156),  # Light red
        "Scratches": (255, 116, 28),  # Orange
        "Pittings": (255, 180, 28),  # Yellow
    }
    DEFAULT_DRAW_COLOR = (255, 0, 0)  # Red, for categories not listed above
    BOX_THICKNESS = 2

    # Plotly colours for the same categories (semi-transparent variants).
    CHART_COLORS = {
        "Nicks": "rgba(255, 60, 60, 0.7)",  # Red
        "Dents": "rgba(255, 148, 156, 0.7)",  # Light red
        "Scratches": "rgba(255, 116, 28, 0.7)",  # Orange
        "Pittings": "rgba(255, 180, 28, 0.7)",  # Yellow
    }
    DEFAULT_CHART_COLOR = "rgba(255, 0, 0, 0.7)"

    def __init__(
        self,
        model_path1="./model/train_model.pt",
        model_path2="./model/best_100epochs_latest.pt",
        confidence_threshold=0.3,
        device=None,
    ):
        """Load both YOLOv8 checkpoints as SAHI detection models.

        :param model_path1: path to the first YOLOv8 weights file
        :param model_path2: path to the second YOLOv8 weights file
        :param confidence_threshold: confidence threshold passed to SAHI for both models
        :param device: torch device string; when None, "cuda:0" is used if CUDA is
            available, otherwise "cpu"
        """
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model1 = self._load_model(model_path1, confidence_threshold, device)
        self.model2 = self._load_model(model_path2, confidence_threshold, device)

    @staticmethod
    def _load_model(model_path, confidence_threshold, device):
        """Wrap one YOLOv8 checkpoint in a SAHI AutoDetectionModel."""
        return AutoDetectionModel.from_pretrained(
            model_type='yolov8',
            model_path=model_path,
            confidence_threshold=confidence_threshold,
            device=device,
        )

    def detect_from_image(self, image, slice_width_input, slice_height_input, overlap_width_input, overlap_height_input):
        """Run sliced inference with both models and return the merged annotations.

        :param image: image accepted by ``sahi.predict.get_sliced_prediction``
            (the app passes an RGB numpy array)
        :param slice_width_input: slice width in pixels; coerced to int, must be >= 1
        :param slice_height_input: slice height in pixels; coerced to int, must be >= 1
        :param overlap_width_input: horizontal overlap ratio, must be in [0, 1)
        :param overlap_height_input: vertical overlap ratio, must be in [0, 1)
        :return: merged list of COCO annotation dicts (see ``combine_results``)
        :raises ValueError: if a slice size or overlap ratio is out of range
        """
        # UI number widgets may deliver floats; slice sizes are pixel counts.
        slice_width = int(slice_width_input)
        slice_height = int(slice_height_input)
        if slice_width < 1 or slice_height < 1:
            raise ValueError("Slice width and height must be at least 1 pixel.")
        overlap_width = float(overlap_width_input)
        overlap_height = float(overlap_height_input)
        # A ratio of 1 makes the overlap equal to the slice size, leaving a zero
        # stride between consecutive slices.
        if not (0 <= overlap_width < 1 and 0 <= overlap_height < 1):
            raise ValueError("Overlap ratios must be in the range [0, 1).")

        annotations1 = self._sliced_annotations(
            self.model1, image, slice_width, slice_height, overlap_width, overlap_height
        )
        annotations2 = self._sliced_annotations(
            self.model2, image, slice_width, slice_height, overlap_width, overlap_height
        )
        return self.combine_results(annotations1, annotations2)

    @staticmethod
    def _sliced_annotations(model, image, slice_width, slice_height, overlap_width, overlap_height):
        """Run SAHI sliced prediction with one model and return COCO annotations."""
        result = get_sliced_prediction(
            image=image,
            detection_model=model,
            slice_height=slice_height,
            slice_width=slice_width,
            overlap_height_ratio=overlap_height,
            overlap_width_ratio=overlap_width,
            postprocess_type='NMS',
            postprocess_match_metric='IOU',
            postprocess_match_threshold=0.1,
            postprocess_class_agnostic=True,
        )
        return result.to_coco_annotations()

    @staticmethod
    def _xywh_iou(box_a, box_b):
        """Return the IoU of two COCO ``[x, y, width, height]`` boxes.

        Returns 0.0 when the union area is zero (e.g. two degenerate boxes).
        """
        ax1, ay1, aw, ah = box_a[:4]
        bx1, by1, bw, bh = box_b[:4]
        inter_w = max(0.0, min(ax1 + aw, bx1 + bw) - max(ax1, bx1))
        inter_h = max(0.0, min(ay1 + ah, by1 + bh) - max(ay1, by1))
        inter = inter_w * inter_h
        union = aw * ah + bw * bh - inter
        return inter / union if union > 0 else 0.0

    def combine_results(self, annotations1, annotations2, iou_threshold=0.1):
        """Merge two annotation lists, resolving overlapping boxes by confidence.

        Every annotation from ``annotations1`` starts in the result. Each annotation
        from ``annotations2`` is then compared with everything currently kept:

        * if any kept box with IoU > ``iou_threshold`` has an equal or higher
          score, the candidate is discarded;
        * otherwise every kept box it overlaps (all weaker) is replaced by it.

        Neither input list is modified.

        :param annotations1: COCO annotations from model 1 (need ``bbox`` and ``score``)
        :param annotations2: COCO annotations from model 2 (need ``bbox`` and ``score``)
        :param iou_threshold: IoU above which two boxes are considered the same defect
        :return: combined list of annotation dicts
        """
        combined = list(annotations1)
        for candidate in annotations2:
            overlapping = [
                index
                for index, kept in enumerate(combined)
                if self._xywh_iou(kept['bbox'], candidate['bbox']) > iou_threshold
            ]
            # A stronger (or equally strong) existing box suppresses the candidate.
            if any(combined[index]['score'] >= candidate['score'] for index in overlapping):
                continue
            # The candidate beats all boxes it overlaps: drop them and keep it.
            dropped = set(overlapping)
            combined = [kept for index, kept in enumerate(combined) if index not in dropped]
            combined.append(candidate)
        return combined

    def draw_annotations(self, image, annotations):
        """Draw a labelled bounding box for each annotation onto ``image`` with OpenCV.

        The array is modified in place and also returned.

        :param image: numpy image array writable by cv2 (the app passes RGB)
        :param annotations: COCO annotation dicts with ``bbox`` and ``category_name``;
            ``score`` defaults to 0 when absent
        :return: the same ``image`` array, annotated
        """
        for annotation in annotations:
            x, y, w, h = annotation['bbox'][:4]
            category_name = annotation['category_name']
            score = annotation.get('score', 0)
            color = self.DRAW_COLORS.get(category_name, self.DEFAULT_DRAW_COLOR)
            cv2.rectangle(
                image,
                (int(x), int(y)),
                (int(x + w), int(y + h)),
                color,
                self.BOX_THICKNESS,
            )
            # Label 10 px above the box, but never closer than 10 px to the top
            # edge so labels of boxes touching the border stay visible.
            label_origin = (int(x), max(int(y - 10), 10))
            cv2.putText(
                image,
                f"{category_name}: {score:.2f}",
                label_origin,
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                color,
                2,
            )
        return image

    def generate_individual_graphs(self, annotations):
        """Build one horizontal size-distribution histogram per defect category.

        Annotation ``area`` values are binned along the y axis (size) and bar
        lengths on the x axis are the counts (frequency).

        :param annotations: COCO annotation dicts with ``category_name`` and ``area``;
            None is treated as an empty list
        :return: tuple of four Plotly figures for Nicks, Dents, Scratches, Pittings
        """
        category_areas = {category: [] for category in self.CATEGORIES}
        for annotation in annotations or []:
            areas = category_areas.get(annotation['category_name'])
            if areas is not None:
                areas.append(annotation['area'])

        figures = []
        for category in self.CATEGORIES:
            areas = category_areas[category]
            fig = go.Figure()
            if areas:
                # Supplying the values as `y` makes Plotly bin along the y axis and
                # draw horizontal bars whose length is the bin count.
                fig.add_trace(go.Histogram(
                    y=areas,
                    name=category,
                    marker_color=self.get_color(category),
                    opacity=1,
                    nbinsy=50,
                ))
            else:
                fig.add_trace(go.Bar(x=[], y=[], name=category))  # Empty graph
            fig.update_layout(
                title=f'Size Distribution of {category}',
                xaxis_title='Frequency',
                yaxis_title='Size',
                showlegend=True,
            )
            figures.append(fig)
        return tuple(figures)

    def generate_frequency_graph(self, annotations):
        """Return a bar chart with the number of detections per defect category.

        Annotations whose ``category_name`` is not one of the four known categories
        are ignored; None is treated as an empty list.
        """
        category_counts = dict.fromkeys(self.CATEGORIES, 0)
        for annotation in annotations or []:
            category_name = annotation['category_name']
            if category_name in category_counts:
                category_counts[category_name] += 1

        freq_chart = go.Figure()
        for category, count in category_counts.items():
            freq_chart.add_trace(go.Bar(
                x=[category],
                y=[count],
                name=category,
                marker_color=self.get_color(category),
            ))
        freq_chart.update_layout(
            title='Frequency of Defects',
            xaxis_title='Defect Category',
            yaxis_title='Count',
            barmode='group',
        )
        return freq_chart

    def get_color(self, category_name):
        """Return the Plotly colour string for ``category_name`` (red if unknown)."""
        return self.CHART_COLORS.get(category_name, self.DEFAULT_CHART_COLOR)
# Shared detector used by the detection and graph callbacks below. Constructing
# it loads both YOLOv8 models (Detection.__init__), so this runs once at import.
detection = Detection()
def upload_image(image):
    """Echo the uploaded image back unchanged so it can be shown as-is."""
    passthrough = image
    return passthrough
#@spaces.GPU
def apply_detection(image, slice_width_input, slice_height_input, overlap_width_input, overlap_height_input):
    """Run defect detection on the uploaded image and return the annotated copy.

    :param image: PIL image from the upload component, or None if nothing was uploaded
    :param slice_width_input: SAHI slice width in pixels
    :param slice_height_input: SAHI slice height in pixels
    :param overlap_width_input: horizontal slice overlap ratio
    :param overlap_height_input: vertical slice overlap ratio
    :return: (annotated PIL image, list of COCO annotation dicts)
    :raises gr.Error: if no image was uploaded or the slicing parameters are invalid
    """
    if image is None:
        raise gr.Error("Please upload an image before applying detection.")
    # Normalise to 3-channel RGB so RGBA / greyscale / palette uploads match the
    # 3-value colours used for drawing; np.array also gives a writable copy.
    img = np.array(image.convert("RGB"))
    try:
        annotations = detection.detect_from_image(
            img, slice_width_input, slice_height_input, overlap_width_input, overlap_height_input
        )
    except ValueError as err:
        # Show parameter problems to the user instead of a generic error.
        raise gr.Error(str(err)) from err
    annotated_image = detection.draw_annotations(img, annotations)
    return Image.fromarray(annotated_image), annotations
def generate_graphs_btn(annotations):
    """Build the four per-category size-distribution figures for the UI.

    :param annotations: COCO annotation dicts from the last detection run, or None
        if detection has not been run yet in this session
    :return: tuple of four Plotly figures (Nicks, Dents, Scratches, Pittings)
    """
    # The frequency chart has its own button, so only the size graphs are built here.
    return detection.generate_individual_graphs(annotations or [])
# Function to handle login authentication
def login_auth(username, password):
    """Validate the login form.

    NOTE(review): this is a placeholder check, not real authentication — any
    non-blank username is accepted when the password is identical to it.
    Replace with a proper credential check before relying on it.

    :param username: value of the "User Name" textbox
    :param password: value of the "Password" textbox
    :return: True when the credentials are accepted
    :raises gr.Error: when either field is blank or the two do not match
    """
    # Without this, an empty form ("" == "") would log in.
    if not (username or "").strip() or not (password or "").strip():
        raise gr.Error("Username and password are required")
    if username != password:
        raise gr.Error("Username or Password is wrong")  # Raise an error on failed login
    return True
# Function to create individual bar charts for each defect type
def generate_confidence_bar_chart(annotations):
    """Build one confidence-score bar chart per defect category.

    Scores (assumed to be 0-1, converted to percent) are bucketed as <25 %,
    25-75 % (inclusive) and >75 %. Annotations whose category is not one of
    Nicks/Dents/Scratches/Pittings are skipped; None is treated as an empty list.

    :param annotations: COCO annotation dicts with ``category_name`` and ``score``
    :return: list of four Plotly figures in the order Nicks, Dents, Scratches, Pittings
    """
    # Colours per defect; the key order also fixes the order of the returned charts.
    category_styles = {
        'Nicks': 'rgba(255, 60, 60, 0.7)',  # Red
        'Dents': 'rgba(255, 148, 156, 0.7)',  # Light Red
        'Scratches': 'rgba(255, 116, 28, 0.7)',  # Orange
        'Pittings': 'rgba(255, 180, 28, 0.7)'  # Yellow
    }
    defect_bins = {
        defect: {'<25%': 0, '25%-75%': 0, '>75%': 0} for defect in category_styles
    }

    for annotation in annotations or []:
        bins = defect_bins.get(annotation["category_name"])
        if bins is None:
            continue  # Category not charted
        score = annotation["score"] * 100  # Convert to percentage
        if score < 25:
            bins['<25%'] += 1
        elif score <= 75:
            bins['25%-75%'] += 1
        else:
            bins['>75%'] += 1

    charts = []
    for defect, bins in defect_bins.items():
        fig = go.Figure()
        fig.add_trace(go.Bar(
            name=defect,
            x=list(bins.keys()),  # Confidence ranges
            y=list(bins.values()),  # Counts
            text=[f"{v} defects" for v in bins.values()],  # Shown on hover
            hoverinfo="text",
            marker_color=category_styles.get(defect, 'rgba(255, 0, 0, 0.7)')
        ))
        fig.update_layout(
            title=f"{defect} Confidence Score Distribution",
            xaxis_title="Confidence Range",
            yaxis_title="Defect Count",
            template="plotly_white"
        )
        charts.append(fig)
    return charts
# Directory to save images
# Folder that receives the uploaded images to be stitched.
img_dir = "./stitching/img_dir/"
# Folder in which the stitched result.jpg is looked up.
output_dir = "./"
# Make sure both folders exist before any callback touches them.
for _stitch_dir in (img_dir, output_dir):
    os.makedirs(_stitch_dir, exist_ok=True)
# Function to handle the stitching process
def save_and_stitch(first_image, second_image, third_image, fourth_image):
    """Save the uploaded images, run the ``stitch`` CLI and return its result.

    Each non-None image is written as ``img_dir/Image_<n>.jpg`` (n = 1..4). Slots
    left empty keep whatever file an earlier run saved there. Every
    ``Image_*.jpg`` in ``img_dir`` is passed to ``stitch``, and
    ``output_dir/result.jpg`` is loaded afterwards.

    :return: the stitched PIL image, or None when there is nothing to stitch, the
        command fails or cannot be started, or no result.jpg is produced
    """
    images = [first_image, second_image, third_image, fourth_image]
    for idx, img in enumerate(images, start=1):
        if img is not None:
            file_path = os.path.join(img_dir, f"Image_{idx}.jpg")
            # JPEG cannot store alpha or palette modes, so normalise to RGB.
            img.convert("RGB").save(file_path, format="JPEG")

    # Same file set the shell glob `Image_*.jpg` matched, in sorted order.
    image_paths = sorted(
        os.path.join(img_dir, name)
        for name in os.listdir(img_dir)
        if name.startswith("Image_") and name.endswith(".jpg")
    )
    if not image_paths:
        print("No images to stitch")
        return None

    result_image_path = os.path.join(output_dir, "result.jpg")
    # Remove a result left by a previous run so it cannot be returned as if the
    # current run had produced it.
    if os.path.exists(result_image_path):
        os.remove(result_image_path)

    try:
        # Argument list instead of a shell string: no shell parsing of paths.
        subprocess.run(["stitch", *image_paths], check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers the `stitch` executable not being installed.
        print(f"Error executing command: {str(e)}")
        return None

    if not os.path.exists(result_image_path):
        print("not found")
        return None
    print("found")
    # Load the pixels and close the file so result.jpg can be replaced later.
    with Image.open(result_image_path) as result:
        return result.copy()
# Function to clear the img_dir
def clear_img_dir():
    """Delete every file and sub-directory inside ``img_dir``.

    A missing ``img_dir`` is recreated (save_and_stitch writes into it).

    :return: status text for the UI; names the entries that could not be removed
    """
    if not os.path.isdir(img_dir):
        os.makedirs(img_dir, exist_ok=True)
        return "Images cleared from img_dir!"

    failed = []
    for file_name in os.listdir(img_dir):
        file_path = os.path.join(img_dir, file_name)
        try:
            if os.path.isdir(file_path) and not os.path.islink(file_path):
                # os.rmdir only removes empty directories; rmtree removes contents too.
                shutil.rmtree(file_path)
            else:
                os.remove(file_path)
        except OSError as e:
            print(f"Error deleting file {file_name}: {str(e)}")
            failed.append(file_name)

    if failed:
        return f"Could not delete from img_dir: {', '.join(failed)}"
    return "Images cleared from img_dir!"
# Gradio interface components
with gr.Blocks() as demo:
    # Per-session login flag; login_auth's return value is stored here on submit.
    login_successful = gr.State(value=False)

    # Landing header, hidden until login.
    with gr.Row(visible=False) as header_row:
        # NOTE(review): the CSS below targets `.gradio-container-5-4-0`, which ties it
        # to Gradio 5.4.0, and `calc(100vh-80px)` is invalid CSS (calc needs spaces
        # around "-"), so browsers drop that height rule. Left unchanged because
        # fixing it would alter the current page layout.
        gr.HTML("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Ubuntu:wght@300;400;500;700&family=Montserrat:wght@700&family=Open+Sans&family=Poppins:wght@300;400;500;600;700;800&display=swap');
*{
margin: 0;
padding: 0;
box-sizing: border-box;
font-family: 'Ubuntu',sans-serif;
}
a{
text-decoration: none;
color: #000;
}
body{
background-color: #fff;
}
.gradio-container-5-4-0 .prose * {
color: #083484;
}
.gradio-container-5-4-0 .prose :first-child {
margin-top: 85px
}
header{
padding: 0 80px;
height: calc(100vh-80px);
display: flex;
align-items: center;
justify-content: space-between;
}
header .left h1 {
font-size: 80px;
display: flex;
justify-content: center;
margin-top: 100px;
}
header .left span{
font-size: 80px;
color: #083484;
display: flex;
justify-content: center;
}
header .left .second-line{
font-size: 80px;
color: #083484;
display: flex;
justify-content: center;
font-weight: 400;
}
header .left p{
margin-top: 35px;
font-stretch: ultra-condensed;
color: #777;
display: flex;
justify-content: center;
text-align: center;
margin-bottom: 10px;
}
header .left a{
display: flex;
align-items: center;
background: #083484;
width: 150px;
padding: 8px;
border-radius: 60px;
}
header .left a i{
background-color: #fff;
font-size: 24px;
border-radius: 50%;
padding: 8px;
}
header .left a span{
color: #fff;
margin-left: 22px;
}
.place {
padding:30px;
text-align: center;
overflow: auto;
margin-top: 500px;
}
.sub-header {
font-size: 4em;
text-align: center;
color: #083484;
font-family: 'Montserrat',sans-serif;
}
.gradio-container-5-4-0 .prose h1, .gradio-container-5-4-0 .prose h2, .gradio-container-5-4-0 .prose h3, .gradio-container-5-4-0 .prose h4, .gradio-container-5-4-0 .prose h5 {
margin: var(--spacing-xxl) 0 var(--spacing-lg);
font-weight: var(--prose-header-text-weight);
line-height: 1.3;
color: #083484;
text-align: center;}
@media screen and (max-width: 1024px) {
header {
margin-top: 5em;
display: flex;
flex-direction: column;
align-items: center; /* Centers items horizontally */
text-align: center; /* Centers text inside elements */
}
header .left {
display: flex;
flex-direction: column;
align-items: center; /* Ensures all child elements are centered */
}
header .left h1 {
font-size: 60px;
}
header .left .second-line {
font-size: 60px;
text-align: center;
}
header .left p {
font-size: 15px;
}
}
@media screen and (max-width: 576px) {
header{
margin-top: 5em;
}
header .left h1 {
font-size: 50px;
}
header .left .second-line {
font-size: 40px;
text-align: center
}
header .left p{
font-size: 15px;
}
}
</style>
<header>
<div class="left">
<h1><span>OIS</span><br></h1>
<span class="second-line">AI Detection Model</span>
<p>
The OIS AI Detection Model enhances manufacturing by using the powerful YOLOv11 algorithm on
a Raspberry Pi for real-time, on-device defect detection. It automates quality control,
reduces human error, and minimizes downtime. With a user-friendly web interface,
the model enables offline swift defect identification, seamless integration into
production, and improving both efficiency and product quality.
</p>
</div>
</header>
<section class="place">
<p class="sub-header">OFFLINE DETECTION</p>
</section>
""")

    # SAHI slicing controls.
    with gr.Row(visible=False) as slicing_text:
        gr.Markdown("### Choose the width and height dimension and the overlapping ratio of the slice to determine how small the model can detect")
    with gr.Row(visible=False) as slicing_dim_input:
        # precision=0 delivers whole-pixel integers; minimum=1 rejects sizes below 1.
        slice_width_input = gr.Number(label="Slice Width (pixels)", value=256, precision=0, minimum=1)
        slice_height_input = gr.Number(label="Slice Height (pixels)", value=256, precision=0, minimum=1)
    with gr.Row(visible=False) as slicing_overlap_input:
        # Capped below 1: an overlap ratio of 1 makes the overlap equal to the slice
        # size, so the slice window could not advance.
        overlap_width_input = gr.Slider(0, 0.9, step=0.01, label="Overlap Width Ratio", value=0.5)
        overlap_height_input = gr.Slider(0, 0.9, step=0.01, label="Overlap Height Ratio", value=0.5)

    # Image upload (left) and annotated result (right).
    with gr.Row(visible=False) as input_row:
        with gr.Column():
            gr.Markdown("### Input (Supported Image: bmp,jpg,png,jpeg,gif)")
            upload_image_component = gr.Image(type="pil", label="Select Image")
        with gr.Column():
            gr.Markdown("### Output")
            output_image_component = gr.Image(type="pil", label="Annotated Image")
            apply_detection_btn = gr.Button("Apply Detection", variant='primary')
    # Annotations from the latest detection run; [] until detection is applied.
    output_annotations = gr.State(value=[])
    apply_detection_btn.click(
        apply_detection,
        inputs=[upload_image_component, slice_width_input, slice_height_input, overlap_width_input, overlap_height_input],
        outputs=[output_image_component, output_annotations],
    )

    # Size-distribution graphs, one per defect category.
    with gr.Row(visible=False) as area_graph_row:
        nicks_graph_component = gr.Plot(label="Nicks Size Distribution")
        dents_graph_component = gr.Plot(label="Dents Size Distribution")
        scratches_graph_component = gr.Plot(label="Scratches Size Distribution")
        pittings_graph_component = gr.Plot(label="Pittings Size Distribution")
    with gr.Row(visible=False) as area_btn_row:
        graph_btn = gr.Button("Generate Size Distribution Graphs", variant='primary')
    graph_btn.click(
        generate_graphs_btn,
        inputs=output_annotations,
        outputs=[nicks_graph_component, dents_graph_component, scratches_graph_component, pittings_graph_component],
    )

    # Frequency graph.
    with gr.Row(visible=False) as frequency_graph_row:
        frequency_graph_component = gr.Plot(label="Defect Frequency Distribution")
    with gr.Row(visible=False) as frequency_btn_row:
        freq_graph_btn = gr.Button("Generate Frequency Graph", variant='primary')
    freq_graph_btn.click(
        detection.generate_frequency_graph,
        inputs=output_annotations,
        outputs=frequency_graph_component,
    )

    # Confidence-score charts.
    with gr.Row(visible=False) as confidence_bar_chart_row:
        nicks_confidence_bar_chart = gr.Plot(label="Nicks Confidence Score Distribution")
        dents_confidence_bar_chart = gr.Plot(label="Dents Confidence Score Distribution")
        scratches_confidence_bar_chart = gr.Plot(label="Scratches Confidence Score Distribution")
        pittings_confidence_bar_chart = gr.Plot(label="Pittings Confidence Score Distribution")
    with gr.Row(visible=False) as confidence_btn_row:
        confidence_chart_btn = gr.Button("Generate Confidence Chart", variant="primary")
    confidence_chart_btn.click(
        generate_confidence_bar_chart,
        inputs=output_annotations,
        outputs=[nicks_confidence_bar_chart, dents_confidence_bar_chart, scratches_confidence_bar_chart, pittings_confidence_bar_chart],
    )

    # Image stitching.
    with gr.Row(visible=False) as upload_image_stitching:
        first_image_stitching = gr.Image(type="pil", label="Select Image 1")
        second_image_stitching = gr.Image(type="pil", label="Select Image 2")
        third_image_stitching = gr.Image(type="pil", label="Select Image 3")
        fourth_image_stitching = gr.Image(type="pil", label="Select Image 4")
    with gr.Row(visible=False) as result_output_block:
        result_output = gr.Image(type="pil", label="Stitched Output Image")
    with gr.Row(visible=False) as stitching_btn:
        apply_stitching_btn = gr.Button("Apply Stitching", variant="primary")
    apply_stitching_btn.click(
        save_and_stitch,
        inputs=[first_image_stitching, second_image_stitching, third_image_stitching, fourth_image_stitching],
        outputs=result_output,
    )
    with gr.Row(visible=False) as display_img_dir_status:
        status_text = gr.Textbox(label="Status")
    with gr.Row(visible=False) as clear_img_btn:
        clear_img_dir_btn = gr.Button("Clear Images from img_dir", variant="primary")
    clear_img_dir_btn.click(clear_img_dir, inputs=[], outputs=status_text)

    # Rows revealed after a successful login; order matches the updates below.
    protected_rows = [
        header_row,
        slicing_text,
        slicing_dim_input,
        slicing_overlap_input,
        input_row,
        area_graph_row,
        area_btn_row,
        frequency_graph_row,
        frequency_btn_row,
        confidence_bar_chart_row,
        confidence_btn_row,
        upload_image_stitching,
        result_output_block,
        stitching_btn,
        display_img_dir_status,
        clear_img_btn,
    ]

    def _toggle_rows_after_login(login_state):
        """Show every protected row and hide the login row when login_state is True."""
        row_updates = [gr.update(visible=login_state) for _ in protected_rows]
        return (*row_updates, gr.update(visible=not login_state))

    # Login row, initially the only visible content.
    with gr.Row(visible=True) as login_row:
        with gr.Column():
            gr.Markdown(value="<div style='text-align: center;'><h2>Login Page</h2></div>")
            with gr.Row():
                with gr.Column(scale=2):
                    gr.Markdown("")
                with gr.Column(scale=1, variant='panel'):
                    username_tbox = gr.Textbox(label="User Name", interactive=True)
                    password_tbox = gr.Textbox(label="Password", interactive=True, type='password')
                    submit_btn = gr.Button(value='Submit', variant='primary', size='sm')
                    # `.success` runs the visibility toggle only when login_auth did
                    # not raise, i.e. only after an accepted login.
                    submit_btn.click(
                        login_auth,
                        inputs=[username_tbox, password_tbox],
                        outputs=login_successful,
                    ).success(
                        _toggle_rows_after_login,
                        inputs=login_successful,
                        outputs=[*protected_rows, login_row],
                    )
                with gr.Column(scale=2):
                    gr.Markdown("")

# Launch the Gradio interface.
# NOTE(review): share=True publishes a public share link while the only gate is
# login_auth's placeholder username == password check — confirm this is intended.
demo.launch(share=True, show_api=False)