Filipstrozik
Enhance Gradio interface instructions for better user guidance on image uploads and parameter adjustments
9a734e7
| import io | |
| from ast import mod | |
| import gradio as gr | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| import torchvision.transforms as transforms | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from ellipse_rcnn import EllipseRCNN | |
# Load the trained weights (model.pth) from the Filipstrozik/sat-tree-detection-v0
# repository on the Hugging Face Hub and build the detector around them.
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# the CPU-only setup this app uses (inference below moves inputs to "cpu").
# Note: despite its name, `load_state_dict` holds the loaded state dict itself.
load_state_dict = torch.load(
    hf_hub_download("Filipstrozik/sat-tree-detection-v0", "model.pth"),
    map_location="cpu",
    weights_only=True,
)
model = EllipseRCNN()
model.load_state_dict(load_state_dict)
# Put the module in inference mode before serving predictions.
model.eval()
def conic_center(conic_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the (x, y) center of the ellipse(s) described by 3x3 conic matrices.

    For a conic matrix ``[[A, B/2, D/2], [B/2, C, E/2], [D/2, E/2, F]]`` the
    center solves ``[[A, B/2], [B/2, C]] @ [x, y] = -[D/2, E/2]``.

    Args:
        conic_matrix: tensor of shape ``(..., 3, 3)``.

    Returns:
        Tuple ``(cx, cy)``; each has the leading batch shape ``(...)`` of the
        input (a 0-d tensor for a single unbatched 3x3 matrix).
    """
    # Quadratic part: the top-left 2x2 submatrix.
    quad = conic_matrix[..., :2, :2]
    # Pseudoinverse with a machine-epsilon relative cutoff: singular values below
    # eps * largest are treated as zero instead of producing huge entries.
    quad_pinv = torch.linalg.pinv(quad, rcond=torch.finfo(quad.dtype).eps)
    # Linear part: first two entries of the last column, negated, as a column vector.
    rhs = -conic_matrix[..., :2, 2][..., None]
    # Drop only the trailing singleton column; a bare squeeze() would also
    # collapse a batch dimension of size 1 and change the output shape.
    centers = torch.matmul(quad_pinv, rhs).squeeze(-1)
    return centers[..., 0], centers[..., 1]
def ellipse_axes(conic_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the semi-axis lengths of the ellipse(s) given by 3x3 conic matrices.

    Uses ``semi_axis_i = sqrt(-det(M) / (det(Q) * lambda_i))``, where ``Q`` is the
    top-left 2x2 block of ``M`` and ``lambda_i`` are its eigenvalues.
    ``eigvalsh`` returns eigenvalues in ascending order, so for a real ellipse
    the first value returned is the semi-major axis and the second the
    semi-minor axis.
    """
    quad = conic_matrix[..., :2, :2]
    eigenvalues = torch.linalg.eigvalsh(quad)
    scale = -torch.det(conic_matrix) / torch.det(quad)
    semi_axes = torch.sqrt(1 / (eigenvalues / scale[..., None]))
    return semi_axes[..., 0], semi_axes[..., 1]
def ellipse_angle(conic_matrix: torch.Tensor) -> torch.Tensor:
    """Return the ellipse orientation in radians, measured from the x-axis.

    Writing the conic as ``A x^2 + B xy + C y^2 + ...``, the matrix stores
    ``B / 2`` off the diagonal, and the angle is ``-atan2(B, C - A) / 2``.
    """
    cross_term = 2 * conic_matrix[..., 1, 0]
    diag_diff = conic_matrix[..., 1, 1] - conic_matrix[..., 0, 0]
    return -torch.atan2(cross_term, diag_diff) / 2
def get_ellipse_params_from_matrices(ellipse_matrices):
    """Convert conic matrices into a flat table of ellipse parameters.

    Args:
        ellipse_matrices: tensor of conic matrices whose first dimension indexes
            detections.

    Returns:
        ``None`` when there are no detections, otherwise a tensor of shape
        ``(N, 5)`` whose columns are ``(a, b, cx, cy, theta)``: the two
        semi-axes, the center, and the orientation in radians.
    """
    if ellipse_matrices.shape[0] == 0:
        return None
    semi_major, semi_minor = ellipse_axes(ellipse_matrices)
    center_x, center_y = conic_center(ellipse_matrices)
    angle = ellipse_angle(ellipse_matrices)
    # Flatten each quantity to 1-D so they can be stacked as columns.
    columns = [
        part.view(-1)
        for part in (semi_major, semi_minor, center_x, center_y, angle)
    ]
    return torch.stack(columns, dim=1).reshape(-1, 5)
def plot_ellipses(
    ellipse_params: torch.Tensor,
    image: torch.Tensor,
    plot_centers: bool = False,
    rim_color: str = "r",
    alpha: float = 0.25,
    center_scale: float = 4.0,
) -> None:
    """Draw filled ellipses (and optionally their centers) on the current pyplot axes.

    Args:
        ellipse_params: tensor of shape (N, 5) with columns (a, b, cx, cy, theta)
            as produced by get_ellipse_params_from_matrices, with theta in
            radians. ``None`` means nothing was detected; the function then
            returns without drawing anything (including the image).
        image: image passed to ``plt.imshow`` after the ellipses are added.
        plot_centers: also draw a dot at each ellipse center.
        rim_color: matplotlib color for the ellipse patches and center dots.
        alpha: transparency of the ellipse fill.
        center_scale: factor applied to the centers before plotting
            (defaults to the previously hard-coded 4).
    """
    if ellipse_params is None:
        return
    a, b, cx, cy, theta = ellipse_params.unbind(-1)
    # NOTE(review): only the centers are rescaled, not the axis lengths, and
    # width/height receive semi-axes although matplotlib's Ellipse expects full
    # diameters -- confirm the intended scaling against the model's output frame.
    cx = cx * center_scale
    cy = cy * center_scale
    # matplotlib's Ellipse takes its rotation in degrees; ellipse_angle returns radians.
    angles_deg = torch.rad2deg(theta)
    axes = plt.gca()
    for x, y, width, height, angle in zip(
        cx.tolist(), cy.tolist(), a.tolist(), b.tolist(), angles_deg.tolist()
    ):
        patch = mpatches.Ellipse(
            (x, y),
            width=width,
            height=height,
            angle=angle,
            fill=True,
            alpha=alpha,
            color=rim_color,
        )
        axes.add_patch(patch)
    if plot_centers:
        plt.scatter(cx.tolist(), cy.tolist(), c=rim_color, s=10)
    plt.imshow(image)
# Inverse of the Normalize step in process_image, used to turn the model input back into a displayable image
def invert_normalization(image, mean, std):
    """Undo a per-channel ``(x - mean) / std`` normalization and clamp to [0, 1].

    Args:
        image: tensor shaped (C, H, W) or (N, C, H, W).
        mean: per-channel means, one entry per channel.
        std: per-channel standard deviations, one entry per channel.

    Returns:
        A new tensor with the same shape as ``image`` and values in [0, 1].
        The input tensor is not modified.
    """
    # Shape the statistics as (C, 1, 1) so they broadcast over the channel axis
    # for both (C, H, W) and (N, C, H, W) inputs. Iterating ``image`` directly
    # would walk the batch axis of a 4-D input instead of its channels.
    channel_mean = torch.as_tensor(mean, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    channel_std = torch.as_tensor(std, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    return torch.clamp(image * channel_std + channel_mean, 0, 1)
def process_image(image):
    """Convert a PIL image into the normalized, batched tensor fed to the model.

    Args:
        image: PIL image of any mode. It is converted to RGB first because the
            normalization below uses three-channel statistics.

    Returns:
        Tuple ``(image_tensor, original_size)``. ``image_tensor`` has shape
        (1, 3, 1024, 1024); ``original_size`` is the input's PIL
        ``(width, height)`` before resizing.
    """
    original_size = image.size
    # Grayscale, palette or RGBA inputs would otherwise produce a channel count
    # that does not match the 3-entry mean/std used by Normalize.
    rgb_image = image.convert("RGB")
    transform = transforms.Compose(
        [
            transforms.Resize((1024, 1024)),
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    image_tensor = transform(rgb_image).unsqueeze(0)  # Add batch dimension
    return image_tensor, original_size
def _figure_to_pil(fig):
    """Render a matplotlib figure to PNG in memory and return it as a detached PIL image."""
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png")
        buf.seek(0)
        with Image.open(buf) as rendered:
            # copy() loads the pixel data so the image outlives the buffer.
            return rendered.copy()


def generate_prediction(image, rpn_nms_thresh, score_thresh, nms_thresh):
    """Run tree detection on a PIL image and return a rendered image of the detections.

    Args:
        image: input PIL image from the Gradio component (``None`` if nothing
            was uploaded).
        rpn_nms_thresh: value assigned to ``model.rpn.nms_thresh``.
        score_thresh: value assigned to ``model.roi_heads.score_thresh``.
        nms_thresh: value assigned to ``model.roi_heads.nms_thresh``.

    Returns:
        PIL image of the matplotlib figure with the predicted ellipses drawn.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Preprocess image
    image_tensor, original_size = process_image(image)
    image_tensor = image_tensor.to("cpu")

    # Apply the user-selected thresholds. These mutate the shared module-level
    # model, so they persist until the next request overrides them.
    model.rpn.nms_thresh = rpn_nms_thresh
    model.roi_heads.score_thresh = score_thresh
    model.roi_heads.nms_thresh = nms_thresh
    with torch.no_grad():
        prediction = model(image_tensor)[0]

    # Undo the normalization applied in process_image for display.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    inverted_image = (
        invert_normalization(image_tensor, mean, std)
        .squeeze(0)
        .permute(1, 2, 0)
        .cpu()
        .numpy()
    )

    fig = plt.figure(figsize=(10, 10))
    try:
        plt.imshow(inverted_image)
        plot_ellipses(
            get_ellipse_params_from_matrices(prediction["ellipse_matrices"]),
            inverted_image,
            plot_centers=True,
            rim_color="red",
            alpha=0.25,
        )
        red_patch = mpatches.Patch(color="red", label="Predicted")
        plt.legend(handles=[red_patch], loc="upper right")
        # set_aspect takes the y/x display scale of one data unit. The tensor is
        # a square 1024x1024 resize, so height/width restores the original
        # proportions (PIL sizes are (width, height)).
        width, height = original_size
        plt.gca().set_aspect(height / width)
        plt.axis("off")
        plt.tight_layout()
        return _figure_to_pil(fig)
    finally:
        # pyplot keeps every figure alive until it is closed; close it so
        # repeated requests do not accumulate open figures.
        plt.close(fig)
# Define Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Tree Detection from Satellite Images")
    gr.Markdown(
        "Upload an image and see the detected trees with ellipses. For better predictions, upload a high-resolution orthophoto map image at zoom level 18."
    )
    gr.Markdown(
        "Try different values for RPN NMS Threshold, ROI Heads Score Threshold, and ROI Heads NMS Threshold to see how they affect the predictions."
    )
    with gr.Row():
        image_input = gr.Image(label="Input Image", type="pil")
        image_output = gr.Image(label="Detected Trees")
    examples = [
        ["examples/image1.jpg"],
        ["examples/image2.jpg"],
        ["examples/image3.jpg"],
    ]
    # Each slider starts at the threshold currently set on the loaded model.
    with gr.Row():
        rpn_nms_slider = gr.Slider(
            0.0, 1.0, value=model.rpn.nms_thresh, label="RPN NMS Threshold"
        )
        score_thresh_slider = gr.Slider(
            0.0,
            1.0,
            value=model.roi_heads.score_thresh,
            label="ROI Heads Score Threshold",
        )
        nms_thresh_slider = gr.Slider(
            0.0, 1.0, value=model.roi_heads.nms_thresh, label="ROI Heads NMS Threshold"
        )
    submit_button = gr.Button("Detect Trees")
    submit_button.click(
        fn=generate_prediction,
        inputs=[image_input, rpn_nms_slider, score_thresh_slider, nms_thresh_slider],
        outputs=image_output,
    )
    # No fn is passed to Examples, so selecting one fills image_input; detection
    # runs through the "Detect Trees" button.
    gr.Examples(examples=examples, inputs=image_input, outputs=image_output)
demo.launch()