# RicePOC / app.py
# Uploaded by aniruddh1907 ("Upload 18 files"), commit e89bdf1.
import os
import tempfile
from ultralytics import YOLO
from clusterv2 import RiceClustering
from whitenessv2 import RiceWhitenessAnalyzer
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
from io import BytesIO
from PIL import Image
import cv2
from matplotlib.patches import Polygon
def _obb_corners(x_center, y_center, width, height, angle):
    """Return the four corners of one oriented box as a (4, 2) NumPy array.

    Each local corner ``(+-width/2, +-height/2)`` is rotated by ``angle`` with the
    standard 2-D rotation matrix ``[[cos, -sin], [sin, cos]]`` and translated to
    ``(x_center, y_center)``. ``angle`` is assumed to be in radians, as for
    Ultralytics ``obb.xywhr`` -- confirm if boxes come from another source.
    """
    half_w, half_h = width / 2, height / 2
    local = np.array([
        [-half_w, -half_h],
        [half_w, -half_h],
        [half_w, half_h],
        [-half_w, half_h],
    ])
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rotation = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
    # Row-vector form of ``R @ p`` for every corner ``p``.
    return local @ rotation.T + np.array([x_center, y_center])


def visualize_obbs(image, obbs, labels):
    """Draw oriented bounding boxes over an image, coloured by cluster label.

    Args:
        image: Path to the image file on disk (read with ``cv2.imread``).
        obbs: Iterable of ``(x_center, y_center, width, height, angle)`` rows,
            e.g. the detector's ``obb.xywhr`` array (angle in radians).
        labels: Integer cluster label per box, aligned with ``obbs``. Labels
            outside the colour palette wrap around instead of raising.

    Returns:
        A ``PIL.Image`` holding the rendered PNG.

    Raises:
        ValueError: If the image at ``image`` cannot be read.
    """
    img = cv2.imread(image)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise ValueError(f"Could not read image: {image}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    colors = ['r', 'g', 'b', 'y', 'c']  # one colour per cluster label
    # Optional caption per label; an empty string means "no caption".
    labels_text = ['', '', '', '', '']

    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(img)
    for obb, label in zip(obbs, labels):
        x_center, y_center, width, height, angle = obb
        # Modulo keeps label -1 mapped to the last colour (as before) and
        # stops labels >= len(colors) from raising IndexError.
        slot = int(label) % len(colors)
        color = colors[slot]
        corners = _obb_corners(x_center, y_center, width, height, angle)
        ax.add_patch(Polygon(corners, closed=True, fill=False, edgecolor=color, linewidth=2))
        if labels_text[slot]:
            ax.text(corners[0, 0], corners[0, 1], labels_text[slot],
                    color=color, fontweight='bold')
    ax.axis('off')
    fig.tight_layout()

    # Render into memory; the returned PIL image reads from this buffer.
    buf = BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def get_number_of_rice_grains(model, image_path):
    """Run the detector on one image and return its oriented boxes and count.

    Args:
        model: Callable detector (the app passes an Ultralytics ``YOLO`` model);
            it is invoked once with ``image_path``.
        image_path: Path of the image to run detection on.

    Returns:
        ``(boxes, count)``: ``boxes`` is the first result's ``obb.xywhr`` moved
        to the CPU as a NumPy array, and ``count`` is ``len(boxes)``.
    """
    first_result = model(image_path)[0]
    xywhr = first_result.obb.xywhr.cpu().numpy()
    return xywhr, len(xywhr)
def get_broken_grain_percentage(clusterer, boxes):
    """Cluster the boxes and report what percentage landed in cluster ``1``.

    Args:
        clusterer: Object exposing ``cluster_rice(boxes)``, which returns one
            cluster label per box.
        boxes: Detected boxes, passed straight to ``clusterer.cluster_rice``.

    Returns:
        ``(percentage, labels)``: the share of labels equal to ``1`` as a
        percentage (``0`` when that label never occurs or there are no labels),
        and the raw labels from the clusterer.
    """
    labels = clusterer.cluster_rice(boxes)
    total = len(labels)
    values, counts = np.unique(labels, return_counts=True)
    # Label 1 is treated as the "broken" cluster.
    broken_share = next(
        (count / total * 100 for value, count in zip(values, counts) if value == 1),
        0,
    )
    return broken_share, labels
def plot_height_distribution_scatter_with_clusters(heights, cluster_centers):
    """Scatter-plot box heights, coloured by their nearest cluster centre.

    Each height is assigned to whichever of ``cluster_centers[0]`` (plotted as
    "Full Rice") or ``cluster_centers[1]`` ("Broken Rice") is closer; exact ties
    go to the broken-rice group. Both centres are drawn as dashed horizontal
    lines. The x axis is the position of the height in ``heights``.

    Returns:
        The rendered PNG as a ``PIL.Image``.
    """
    height_arr = np.array(heights)
    centers = np.array(cluster_centers)
    full_center, broken_center = centers[0], centers[1]
    closer_to_full = np.abs(height_arr - full_center) < np.abs(height_arr - broken_center)
    assignment = np.where(closer_to_full, 0, 1)

    groups = (
        (0, 'green', 'Full Rice'),
        (1, 'red', 'Broken Rice'),
    )
    fig, ax = plt.subplots(figsize=(8, 6))
    for group_id, colour, legend_name in groups:
        idx = np.nonzero(assignment == group_id)[0]
        ax.scatter(idx, height_arr[idx], color=colour, alpha=0.7, label=legend_name)
    ax.axhline(y=full_center, color='g', linestyle='--', label='Full Rice Center')
    ax.axhline(y=broken_center, color='r', linestyle='--', label='Broken Rice Center')
    ax.set_xlabel('Index')
    ax.set_ylabel('Bounding Box Heights')
    ax.legend()
    ax.set_title('Scatter Plot of Rice Heights and Cluster Centers')

    png = BytesIO()
    fig.savefig(png, format='png')
    plt.close(fig)
    png.seek(0)
    return Image.open(png)
def broken_grain_pie_chart(broken_grain_percentage, total_grains):
    """Render a two-slice pie chart of broken vs intact grains.

    Args:
        broken_grain_percentage: Broken share in percent; the intact slice is
            ``100 - broken_grain_percentage``.
        total_grains: Grain count shown in the chart title.

    Returns:
        The rendered PNG as a ``PIL.Image``.
    """
    slices = {
        'Broken Grains': (broken_grain_percentage, '#ff9999'),
        'Intact Grains': (100 - broken_grain_percentage, '#66b3ff'),
    }
    fig, ax = plt.subplots(figsize=(7, 7))
    ax.pie(
        [size for size, _ in slices.values()],
        labels=list(slices),
        autopct='%1.1f%%',
        startangle=90,
        colors=[colour for _, colour in slices.values()],
    )
    ax.axis('equal')  # keep the pie circular
    ax.set_title(f'Broken vs Intact Rice Grains (Total: {total_grains})')

    png = BytesIO()
    fig.savefig(png, format='png')
    plt.close(fig)
    png.seek(0)
    return Image.open(png)
def process_rice_image(image):
    """Run the full rice-analysis pipeline on one uploaded image.

    The downstream steps (the YOLO model, ``cv2.imread`` in the visualiser and
    the whiteness analyzer) all take a file path. The upload is therefore
    written into a private temporary directory. That directory is removed once
    processing finishes, even if a step raises. The previous
    ``NamedTemporaryFile(delete=False)`` approach leaked one file per request.

    Args:
        image: ``PIL.Image`` from the Gradio ``Image(type='pil')`` input, or
            ``None`` if the user submitted without uploading.

    Returns:
        ``(obb_img, scatter_plot_img, pie_chart_img, mean_whiteness_percentage)``,
        in the order of the interface outputs.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image of rice grains.")

    model_path = 'best.pt'
    with tempfile.TemporaryDirectory() as tmp_dir:
        image_path = os.path.join(tmp_dir, 'upload.jpg')
        # JPEG cannot store alpha/palette modes (e.g. RGBA PNG uploads).
        image.convert('RGB').save(image_path)

        # Load YOLO model and clustering model
        model = YOLO(model_path)
        clusterer = RiceClustering()

        # Get rice grain information and cluster labels
        boxes, num_rice_grains = get_number_of_rice_grains(model, image_path)
        broken_grain_percentage, labels = get_broken_grain_percentage(clusterer, boxes)

        obb_img = visualize_obbs(image_path, boxes, labels)
        pie_chart_img = broken_grain_pie_chart(broken_grain_percentage, num_rice_grains)

        # NOTE(review): presumably `heights` is populated by cluster_rice()
        # above -- verify in clusterv2.
        heights = clusterer.heights
        cluster_centers = clusterer.get_cluster_centers()
        scatter_plot_img = plot_height_distribution_scatter_with_clusters(heights, cluster_centers)

        analyzer = RiceWhitenessAnalyzer(model_path)
        _obbs, _whiteness_values, mean_whiteness = analyzer.analyze_image(image_path)

    # Assumes mean_whiteness is on a 0-255 intensity scale -- confirm in whitenessv2.
    mean_whiteness_percentage = mean_whiteness * 100 / 255
    return obb_img, scatter_plot_img, pie_chart_img, mean_whiteness_percentage
# Example inputs listed under the upload widget:
# - Sorted, because os.listdir() returns entries in arbitrary order.
# - Filtered to image files, so stray entries (e.g. .DS_Store or a README)
#   do not break the Examples component.
# - A missing folder yields no examples instead of crashing at startup.
_EXAMPLES_DIR = 'assets'
_EXAMPLE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
if os.path.isdir(_EXAMPLES_DIR):
    examples = [
        os.path.join(_EXAMPLES_DIR, name)
        for name in sorted(os.listdir(_EXAMPLES_DIR))
        if name.lower().endswith(_EXAMPLE_EXTENSIONS)
    ]
else:
    examples = []

# Gradio interface: one PIL image in; OBB overlay, cluster scatter plot,
# broken-vs-intact pie chart and mean whiteness out, matching the tuple
# returned by process_rice_image.
interface = gr.Interface(
    fn=process_rice_image,
    inputs=gr.Image(type='pil'),
    outputs=[
        gr.Image(label="Oriented Bounding Boxes (OBBs)"),
        gr.Image(label="Cluster Scatter Plot"),
        gr.Image(label="Broken vs Intact Rice Pie Chart"),
        gr.Text(label="Mean Whiteness Percentage")
    ],
    title="Rice Grain Detection and Clustering with OBBs",
    description="Upload an image to detect rice grains, cluster them, visualize OBBs, and broken vs intact rice grains.",
    # None is Interface's default "no examples"; avoid passing an empty list.
    examples=examples or None,
)
interface.launch()