Spaces:
Build error
Build error
| import os | |
| import tempfile | |
| from ultralytics import YOLO | |
| from clusterv2 import RiceClustering | |
| from whitenessv2 import RiceWhitenessAnalyzer | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import gradio as gr | |
| from io import BytesIO | |
| from PIL import Image | |
| import cv2 | |
| from matplotlib.patches import Polygon | |
| # Function to visualize Oriented Bounding Boxes (OBBs) instead of YOLO bboxes | |
def _obb_corners(obb):
    """Return the four (x, y) corners of one oriented box as a (4, 2) array.

    ``obb`` is ``(x_center, y_center, width, height, angle)`` in the layout of
    Ultralytics ``Results.obb.xywhr``, where ``angle`` is in radians. The local
    corner (dx, dy) is mapped to (dx*cos - dy*sin, dx*sin + dy*cos) before being
    shifted to the centre, matching Ultralytics' ``xywhr2xyxyxyxy``.
    """
    x_center, y_center, width, height, angle = obb
    half_w, half_h = width / 2, height / 2
    corners = np.array([
        [-half_w, -half_h],
        [half_w, -half_h],
        [half_w, half_h],
        [-half_w, half_h],
    ])
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rotation_matrix = np.array([
        [cos_a, -sin_a],
        [sin_a, cos_a],
    ])
    # Corners are row vectors, so R is applied via its transpose: (R @ c.T).T == c @ R.T.
    return corners @ rotation_matrix.T + [x_center, y_center]


def visualize_obbs(image, obbs, labels):
    """Draw oriented bounding boxes over an image, coloured by cluster label.

    Args:
        image: Path to the image file on disk (read with ``cv2.imread``).
        obbs: Iterable of ``(x_center, y_center, width, height, angle)`` boxes,
            e.g. the rows of ``Results.obb.xywhr`` (angle in radians).
        labels: Integer cluster label for each box, in the same order as ``obbs``.

    Returns:
        A ``PIL.Image`` of the rendered figure, PNG-encoded in memory.

    Raises:
        ValueError: If ``cv2.imread`` cannot read ``image``.
    """
    img = cv2.imread(image)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"Could not read image file: {image}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    fig = plt.figure(figsize=(12, 8))
    plt.imshow(img)
    ax = plt.gca()

    # One colour / caption per cluster; indices wrap so extra clusters don't raise IndexError.
    colors = ['r', 'g', 'b', 'y', 'c']
    labels_text = ['', '', '', '', '']

    for obb, label in zip(obbs, labels):
        slot = int(label) % len(colors)
        color = colors[slot]
        corners = _obb_corners(obb)
        ax.add_patch(Polygon(corners, closed=True, fill=False, edgecolor=color, linewidth=2))
        ax.text(corners[0, 0], corners[0, 1], labels_text[slot],
                color=color, fontweight='bold')

    plt.axis('off')
    plt.tight_layout()

    # Render to an in-memory PNG so Gradio can display it without touching disk.
    buf = BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    plt.close(fig)
    return Image.open(buf)
| # Function to get number of rice grains and bounding boxes | |
def get_number_of_rice_grains(model, image_path):
    """Run the detector on one image and return its oriented boxes.

    Only the first result returned by ``model(image_path)`` is used.

    Returns:
        ``(boxes, count)``: ``boxes`` is the NumPy array taken from
        ``results[0].obb.xywhr`` (one row per detection) and ``count`` is its
        number of rows.
    """
    first_result = model(image_path)[0]
    xywhr = first_result.obb.xywhr.cpu().numpy()
    count = len(xywhr)
    return xywhr, count
| # Function to get broken grain percentage and cluster labels | |
def get_broken_grain_percentage(clusterer, boxes):
    """Cluster the boxes and report the share of grains labelled ``1``.

    Label ``1`` is treated as "broken" by this app.

    Returns:
        ``(percentage, labels)``: ``percentage`` is ``100 * (#label==1) / #labels``,
        or ``0`` when no grain carries label 1 (including an empty label
        array). ``labels`` is whatever ``clusterer.cluster_rice`` returned.
    """
    labels = clusterer.cluster_rice(boxes)
    distinct_values, tallies = np.unique(labels, return_counts=True)
    n_labels = len(labels)

    broken_share = 0
    for value, tally in zip(distinct_values, tallies):
        if value == 1:
            broken_share = (tally / n_labels) * 100
    return broken_share, labels
| # Function to create scatter plot for clusters | |
def plot_height_distribution_scatter_with_clusters(heights, cluster_centers):
    """Scatter each grain height, coloured by its nearer cluster centre.

    Grain ``i`` is drawn as "Full Rice" (green) when it is strictly closer to
    ``cluster_centers[0]`` than to ``cluster_centers[1]``. Otherwise, ties
    included, it is drawn as "Broken Rice" (red). Both centres are drawn as
    dashed horizontal lines.

    Returns:
        The rendered figure as a ``PIL.Image`` (PNG-encoded in memory).
    """
    h = np.array(heights)
    centres = np.array(cluster_centers)
    full_centre, broken_centre = centres[0], centres[1]

    gap_full = np.abs(h - full_centre)
    gap_broken = np.abs(h - broken_centre)
    assignment = np.where(gap_full < gap_broken, 0, 1)

    plt.figure(figsize=(8, 6))
    series = (
        (0, 'green', 'Full Rice'),
        (1, 'red', 'Broken Rice'),
    )
    for cluster_id, colour, name in series:
        idx = np.nonzero(assignment == cluster_id)[0]
        plt.scatter(idx, h[idx], color=colour, alpha=0.7, label=name)

    plt.axhline(y=full_centre, color='g', linestyle='--', label='Full Rice Center')
    plt.axhline(y=broken_centre, color='r', linestyle='--', label='Broken Rice Center')

    plt.xlabel('Index')
    plt.ylabel('Bounding Box Heights')
    plt.legend()
    plt.title('Scatter Plot of Rice Heights and Cluster Centers')

    png = BytesIO()
    plt.savefig(png, format='png')
    png.seek(0)
    plt.close()
    return Image.open(png)
| # Function to create pie chart for broken grain percentage | |
def broken_grain_pie_chart(broken_grain_percentage, total_grains):
    """Render a two-slice pie chart of broken vs intact grains.

    Args:
        broken_grain_percentage: Broken share on a 0-100 scale; the intact
            slice is its complement (``100 - broken_grain_percentage``).
        total_grains: Grain count; it is only shown in the chart title.

    Returns:
        The rendered figure as a ``PIL.Image`` (PNG-encoded in memory).
    """
    # Slice name -> (size, colour); insertion order fixes the drawing order.
    slices = {
        'Broken Grains': (broken_grain_percentage, '#ff9999'),
        'Intact Grains': (100 - broken_grain_percentage, '#66b3ff'),
    }

    plt.figure(figsize=(7, 7))
    plt.pie(
        [size for size, _ in slices.values()],
        labels=list(slices),
        autopct='%1.1f%%',
        startangle=90,
        colors=[colour for _, colour in slices.values()],
    )
    plt.axis('equal')  # keep the pie circular
    plt.title(f'Broken vs Intact Rice Grains (Total: {total_grains})')

    png = BytesIO()
    plt.savefig(png, format='png')
    png.seek(0)
    plt.close()
    return Image.open(png)
| # Main processing function for Gradio interface | |
def process_rice_image(image):
    """Run the full rice-analysis pipeline on one uploaded image.

    Pipeline:
      1. Detect grains with the YOLO OBB model in ``best.pt``.
      2. Cluster them with ``RiceClustering`` to get the broken-grain share.
      3. Render the OBB overlay, the height scatter plot and the pie chart.
      4. Measure mean whiteness with ``RiceWhitenessAnalyzer``.

    Args:
        image: The uploaded ``PIL.Image`` (Gradio ``type='pil'``), or ``None``
            if the user submitted without an image.

    Returns:
        ``(obb_img, scatter_plot_img, pie_chart_img, mean_whiteness_percentage)``.
        The last value is the analyzer's mean whiteness rescaled from 0-255
        to 0-100.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image of rice grains.")

    # Reserve a unique path, then close the handle before writing to it, so the
    # save also works on platforms that lock files that are open.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp:
        image_path = temp.name

    try:
        # JPEG cannot store alpha or palette images (e.g. RGBA / P-mode PNG
        # uploads), so normalise to RGB before saving.
        image.convert("RGB").save(image_path)

        model_path = 'best.pt'
        # Load YOLO model and clustering model
        model = YOLO(model_path)
        clusterer = RiceClustering()

        # Get rice grain information and cluster labels
        boxes, num_rice_grains = get_number_of_rice_grains(model, image_path)
        broken_grain_percentage, labels = get_broken_grain_percentage(clusterer, boxes)

        obb_img = visualize_obbs(image_path, boxes, labels)
        pie_chart_img = broken_grain_pie_chart(broken_grain_percentage, num_rice_grains)

        # NOTE(review): heights / cluster centres are read after cluster_rice()
        # ran inside get_broken_grain_percentage; presumably that call populates
        # them - confirm in clusterv2.
        heights = clusterer.heights
        cluster_centers = clusterer.get_cluster_centers()
        scatter_plot_img = plot_height_distribution_scatter_with_clusters(heights, cluster_centers)

        analyzer = RiceWhitenessAnalyzer(model_path)
        _obbs, _whiteness_values, mean_whiteness = analyzer.analyze_image(image_path)
        mean_whiteness_percentage = mean_whiteness * 100 / 255
    finally:
        # delete=False means nothing else removes this file; without this every
        # request left one image behind in the temp directory.
        if os.path.exists(image_path):
            os.remove(image_path)

    return obb_img, scatter_plot_img, pie_chart_img, mean_whiteness_percentage
# File extensions accepted as example images in the UI.
_EXAMPLE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')


def _list_example_images(directory='assets'):
    """Return sorted paths of image files in ``directory``.

    Returns an empty list if the directory is missing. Non-image files
    (e.g. ``.gitkeep``, ``README``) are skipped, so Gradio never tries to load
    them as examples. Sorting gives a stable order, which ``os.listdir`` does
    not guarantee.
    """
    if not os.path.isdir(directory):
        return []
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.lower().endswith(_EXAMPLE_EXTENSIONS)
    )


# None (rather than []) tells Gradio there are no examples to show.
examples = _list_example_images() or None

# Gradio interface
interface = gr.Interface(
    fn=process_rice_image,
    inputs=gr.Image(type='pil'),
    outputs=[
        gr.Image(label="Oriented Bounding Boxes (OBBs)"),
        gr.Image(label="Cluster Scatter Plot"),
        gr.Image(label="Broken vs Intact Rice Pie Chart"),
        gr.Text(label="Mean Whiteness Percentage")
    ],
    title="Rice Grain Detection and Clustering with OBBs",
    description="Upload an image to detect rice grains, cluster them, visualize OBBs, and broken vs intact rice grains.",
    examples=examples
)
interface.launch()