Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| from sahi.predict import get_sliced_prediction | |
| from sahi import AutoDetectionModel | |
| from PIL import Image | |
| import plotly.graph_objects as go | |
| import torch | |
| import spaces | |
# Preferred inference device: first CUDA GPU when available, otherwise CPU.
# NOTE(review): this is computed but Detection.__init__ below creates the
# model with device='cpu' hard-coded — confirm which one is intended.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
class Detection:
    """SAHI-sliced YOLOv8 defect detector with OpenCV drawing and Plotly helpers.

    Detected defect categories: Nicks, Dents, Scratches, Pittings.
    """

    def __init__(self):
        # Path to the trained YOLOv8 weights (update to your model path).
        yolov8_model_path = "./model/best.pt"
        # Initialize the SAHI wrapper around the YOLOv8 model.
        self.model = AutoDetectionModel.from_pretrained(
            model_type='yolov8',
            model_path=yolov8_model_path,
            confidence_threshold=0.3,
            # Use the module-level device ('cuda:0' when available, else 'cpu')
            # instead of hard-coding 'cpu' — that was the variable's purpose.
            device=device,
        )

    def detect_from_image(self, image):
        """Run sliced (SAHI) inference on *image* and return COCO annotations.

        Args:
            image: numpy array (H, W, C) or any input accepted by SAHI.

        Returns:
            list[dict]: COCO annotation dicts; each has 'bbox' ([x, y, w, h]),
            'category_name' and 'score' keys.
        """
        results = get_sliced_prediction(
            image=image,
            detection_model=self.model,
            slice_height=256,
            slice_width=256,
            overlap_height_ratio=0.2,
            overlap_width_ratio=0.2,
            postprocess_type='NMS',
            postprocess_match_metric='IOU',
            postprocess_match_threshold=0.1,
            postprocess_class_agnostic=True,
        )
        return results.to_coco_annotations()

    def draw_annotations(self, image, annotations):
        """Draw labelled bounding boxes on *image* (modified in place) and return it.

        Args:
            image: numpy array the boxes are drawn onto.
            annotations: COCO annotation dicts from detect_from_image().

        Returns:
            The same numpy array with rectangles and "<category>: <score>" labels.
        """
        # Per-category box styles. NOTE(review): apply_detection passes an RGB
        # array (np.array of a PIL image), so despite OpenCV's usual BGR
        # convention these tuples are effectively RGB — confirm intended colors.
        category_styles = {
            'Nicks': {'color': (255, 60, 60), 'thickness': 2},       # Nicks (Red)
            'Dents': {'color': (255, 148, 156), 'thickness': 2},     # Dents (Light Red)
            'Scratches': {'color': (255, 116, 28), 'thickness': 2},  # Scratches (Orange)
            'Pittings': {'color': (255, 180, 28), 'thickness': 2}    # Pittings (Yellow)
        }
        for annotation in annotations:
            bbox = annotation['bbox']  # [x, y, width, height]
            category_name = annotation['category_name']
            score = annotation.get('score', 0)  # default to 0 when absent
            # Fall back to red for unknown categories.
            style = category_styles.get(category_name, {'color': (255, 0, 0), 'thickness': 2})
            x, y = int(bbox[0]), int(bbox[1])
            w, h = int(bbox[2]), int(bbox[3])
            cv2.rectangle(image, (x, y), (x + w, y + h),
                          style['color'], style['thickness'])
            # Label with category and confidence, placed above the rectangle.
            text = f"{category_name}: {score:.2f}"
            cv2.putText(image, text, (x, y - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, style['color'], 2)
        return image

    def generate_individual_graphs(self, annotations):
        """Generate one area-distribution chart per defect category.

        Bounding-box areas (width * height) are binned into a 10-bin histogram
        and drawn as a horizontal bar chart: frequency on the x-axis, area on
        the y-axis.

        Returns:
            Tuple of four Plotly figures: (Nicks, Dents, Scratches, Pittings).
        """
        categories = ('Nicks', 'Dents', 'Scratches', 'Pittings')
        category_areas = {name: [] for name in categories}
        for annotation in annotations:
            category_name = annotation['category_name']
            if category_name in category_areas:
                bbox = annotation['bbox']
                category_areas[category_name].append(bbox[2] * bbox[3])  # width * height
        individual_graphs = {}
        for category in categories:
            areas = category_areas[category]
            fig = go.Figure()
            if areas:
                # BUG FIX: the old code read `histogram_data.y` off a
                # go.Histogram trace, but Plotly never exposes computed bin
                # counts on the trace object (`.y` was always None), so the
                # swapped-axis bar chart rendered empty. Compute the histogram
                # explicitly instead.
                frequencies, edges = np.histogram(areas, bins=10)
                bin_centers = (edges[:-1] + edges[1:]) / 2.0
                fig.add_trace(go.Bar(
                    x=frequencies,          # frequency on x-axis
                    y=bin_centers,          # bin-center area on y-axis
                    orientation='h',        # horizontal bars: length = frequency
                    name=category,
                    marker_color=self.get_color(category),
                    opacity=1,
                ))
            else:  # no detections for this category: empty placeholder chart
                fig.add_trace(go.Bar(x=[], y=[], name=category))
            fig.update_layout(
                title=f'Area Distribution of {category}',
                xaxis_title='Frequency',
                yaxis_title='Area',
                showlegend=True,
            )
            individual_graphs[category] = fig
        return (individual_graphs['Nicks'], individual_graphs['Dents'],
                individual_graphs['Scratches'], individual_graphs['Pittings'])

    def generate_frequency_graph(self, annotations):
        """Generate a bar chart counting detections per defect category.

        Args:
            annotations: COCO annotation dicts from detect_from_image().

        Returns:
            A Plotly figure with one bar per category.
        """
        category_counts = {
            'Nicks': 0,
            'Dents': 0,
            'Scratches': 0,
            'Pittings': 0
        }
        # Count occurrences of each defect category.
        for annotation in annotations:
            category_name = annotation['category_name']
            if category_name in category_counts:
                category_counts[category_name] += 1
        freq_chart = go.Figure()
        for category, count in category_counts.items():
            freq_chart.add_trace(go.Bar(
                x=[category],
                y=[count],
                name=category,
                # Use the shared palette instead of a duplicated color dict so
                # every chart stays consistent.
                marker_color=self.get_color(category),
            ))
        freq_chart.update_layout(
            title='Frequency of Defects',
            xaxis_title='Defect Category',
            yaxis_title='Count',
            barmode='group',
        )
        return freq_chart

    def get_color(self, category_name):
        """Return the rgba color string for *category_name* (red fallback)."""
        category_styles = {
            'Nicks': 'rgba(255, 60, 60, 0.7)',       # Red
            'Dents': 'rgba(255, 148, 156, 0.7)',     # Light Red
            'Scratches': 'rgba(255, 116, 28, 0.7)',  # Orange
            'Pittings': 'rgba(255, 180, 28, 0.7)'    # Yellow
        }
        # BUG FIX: the fallback used to be an RGB tuple while every other value
        # is an rgba() string; Plotly's marker_color rejects the tuple form.
        return category_styles.get(category_name, 'rgba(255, 0, 0, 0.7)')
# Module-level detector instance shared by all Gradio callbacks below
# (the model is loaded once at import time).
detection = Detection()
def upload_image(image):
    """Identity pass-through: hand the uploaded image straight back for display."""
    return image
def apply_detection(image):
    """Run defect detection on a PIL image.

    Returns:
        Tuple of (annotated PIL image, list of COCO annotation dicts).
    """
    frame = np.array(image)  # PIL -> numpy array for SAHI/OpenCV
    coco_annotations = detection.detect_from_image(frame)
    # Boxes are drawn directly onto `frame`.
    annotated = detection.draw_annotations(frame, coco_annotations)
    # Back to PIL so Gradio's Image component can display it.
    return Image.fromarray(annotated), coco_annotations
def generate_graphs_btn(annotations):
    """Build the four per-category area-distribution figures for the UI.

    Args:
        annotations: COCO annotation dicts stored in the gr.State by
            apply_detection.

    Returns:
        Tuple of four Plotly figures (Nicks, Dents, Scratches, Pittings),
        matching the four Plot outputs wired to this callback.
    """
    # BUG FIX: the old code also called generate_frequency_graph() here and
    # then discarded the result — the frequency chart has its own button and
    # callback, so that was a wasted pass over the annotations.
    return detection.generate_individual_graphs(annotations)
# Custom CSS injected into gr.Blocks: page-wide font/reset plus the landing
# header and "OFFLINE DETECTION" section used by the gr.HTML block below.
css = """
@import url('https://fonts.googleapis.com/css2?family=Ubuntu:wght@300;400;500;700&family=Montserrat:wght@700&family=Open+Sans&family=Poppins:wght@300;400;500;600;700;800&display=swap');
*{
margin: 0;
padding: 0;
box-sizing: border-box;
font-family: 'Ubuntu',sans-serif;
}
a{
text-decoration: none;
color: #000;
}
body{
background-color: #fff;
}
header{
padding: 0 80px;
height: calc(100vh-80px);
display: flex;
align-items: center;
justify-content: space-between;
}
header .left h1 {
font-size: 80px;
display: flex;
justify-content: center;
margin-top: 17rem;
}
header .left span{
font-size: 80px;
color: #083484;
display: flex;
justify-content: center;
}
header .left .second-line{
font-size: 80px;
color: #083484;
display: flex;
justify-content: center;
font-weight: 400;
}
header .left p{
margin-top: 35px;
font-stretch: ultra-condensed;
color: #777;
display: flex;
justify-content: center;
text-align: center;
margin-bottom: 10px;
}
header .left a{
display: flex;
align-items: center;
background: #083484;
width: 150px;
padding: 8px;
border-radius: 60px;
}
header .left a i{
background-color: #fff;
font-size: 24px;
border-radius: 50%;
padding: 8px;
}
header .left a span{
color: #fff;
margin-left: 22px;
}
.container {
padding:30px;
text-align: center;
overflow: auto;
margin-top: 500px;
}
.sub-header {
font-size: 4em;
text-align: center;
color: #083484;
font-family: 'Montserrat',sans-serif;
}
"""
# JS run by Gradio on load: force the light theme by rewriting the __theme
# query parameter and reloading when it is not already 'light'.
js_func = """
function refresh() {
    const url = new URL(window.location);
    if (url.searchParams.get('__theme') !== 'light') {
        url.searchParams.set('__theme', 'light');
        window.location.href = url.href;
    }
}
"""
# Gradio interface components: landing header, upload/annotate columns,
# per-category area-distribution plots, and a defect-frequency plot.
with gr.Blocks(css = css,js=js_func) as demo:
    # Static landing-page header styled by the `css` string above.
    # NOTE(review): this copy says "YOLOv11" but Detection loads
    # model_type='yolov8' — confirm which is correct.
    gr.HTML("""
    <header>
        <div class="left">
            <h1><span>OIS</span><br></h1>
            <span class="second-line">AI Detection Model</span>
            <p>
                The OIS AI Detection Model enhances manufacturing by using the powerful YOLOv11 algorithm on
                a Raspberry Pi for real-time, on-device defect detection. It automates quality control,
                reduces human error, and minimizes downtime. With a user-friendly web interface,
                the model enables offline swift defect identification, seamless integration into
                production, and improving both efficiency and product quality.
            </p>
        </div>
    </header>
    <section class="container">
        <p class="sub-header">OFFLINE DETECTION</p>
    </section>
    """)
    with gr.Row():
        # Image Upload and Display in two columns
        with gr.Column():
            gr.Markdown("### Input")
            upload_image_component = gr.Image(type="pil", label="Select Image")
        with gr.Column():
            gr.Markdown("### Output")
            output_image_component = gr.Image(type="pil", label="Annotated Image")
    apply_detection_btn = gr.Button("Apply Detection")
    # gr.State keeps the COCO annotations server-side between button clicks
    # so the graph callbacks can reuse them without re-running detection.
    output_annotations = gr.State()  # Store annotations
    apply_detection_btn.click(apply_detection, inputs=upload_image_component, outputs=[output_image_component, output_annotations])
    # Row for the four per-category area-distribution graphs
    with gr.Row():
        # Individual graphs for each defect category
        nicks_graph_component = gr.Plot(label="Nicks Area Distribution")
        dents_graph_component = gr.Plot(label="Dents Area Distribution")
        scratches_graph_component = gr.Plot(label="Scratches Area Distribution")
        pittings_graph_component = gr.Plot(label="Pittings Area Distribution")
    # Button to generate graphs; output order must match
    # generate_individual_graphs' return order (Nicks, Dents, Scratches, Pittings).
    with gr.Row():
        graph_btn = gr.Button("Generate Area Distribution Graphs")
        graph_btn.click(generate_graphs_btn, inputs=output_annotations, outputs=[
            nicks_graph_component, dents_graph_component,
            scratches_graph_component, pittings_graph_component
        ])
    # Row for frequency graph
    with gr.Row():
        frequency_graph_component = gr.Plot(label="Defect Frequency Distribution")  # Frequency Graph
    # Row for frequency graph btn; wired directly to the Detection method.
    with gr.Row():
        freq_graph_btn = gr.Button("Generate Frequency Graph")
        freq_graph_btn.click(detection.generate_frequency_graph,
                             inputs=output_annotations,
                             outputs=frequency_graph_component)
# Launch the Gradio interface (share=True creates a public tunnel URL)
demo.launch(share=True)