Spaces: Build error
| import gradio as gr | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights | |
| from torchvision.utils import draw_bounding_boxes | |
| import matplotlib.pyplot as plt | |
| from io import BytesIO | |
| weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT | |
| categories = weights.meta["categories"] | |
| img_preprocess = weights.transforms() | |
| def load_model(): | |
| model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.5) | |
| model.eval() | |
| return model | |
| model = load_model() | |
def make_prediction(img):
    """Run the detector on a single PIL image and return its prediction dict.

    The returned dict carries "boxes", "scores", and "labels"; the numeric
    label ids are replaced in place with human-readable COCO category names.
    """
    img_processed = img_preprocess(img)
    # Inference only: no_grad skips building the autograd graph, which the
    # original omitted — it cuts memory use and speeds up the forward pass.
    with torch.no_grad():
        prediction = model(img_processed.unsqueeze(0))[0]
    prediction["labels"] = [categories[label] for label in prediction["labels"]]
    return prediction
def create_image_with_bboxes(img, prediction):
    """Draw predicted boxes onto *img* (CHW uint8 array); return an HWC numpy image.

    People are outlined in red, every other class in green.
    """
    # One color per box, keyed off the (already stringified) label names.
    box_colors = ["red" if name == "person" else "green" for name in prediction["labels"]]
    annotated = draw_bounding_boxes(
        torch.tensor(img),
        boxes=prediction["boxes"],
        labels=prediction["labels"],
        colors=box_colors,
        width=2,
    )
    # CHW tensor -> HWC numpy array for display.
    return annotated.detach().numpy().transpose(1, 2, 0)
def process_image(image):
    """Detect objects in *image* (HWC uint8 numpy array) and build the UI outputs.

    Returns a tuple of:
      * a PIL image of the input with bounding boxes drawn on it,
      * a newline-separated "label: score" summary string,
      * the raw prediction dict with tensors converted to plain lists (JSON-safe).
    """
    img = Image.fromarray(image.astype('uint8'), 'RGB')
    prediction = make_prediction(img)
    # draw_bounding_boxes expects a CHW uint8 tensor, hence the transpose.
    img_with_bbox = create_image_with_bboxes(np.array(img).transpose(2, 0, 1), prediction)

    fig = plt.figure(figsize=(12, 12))
    ax = fig.add_subplot(111)
    ax.imshow(img_with_bbox)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.spines[["top", "bottom", "right", "left"]].set_visible(False)
    fig.tight_layout()

    # Render the figure to an in-memory PNG *before* closing it.  The original
    # closed the figure first and relied on the Figure object surviving
    # plt.close(); saving first, then closing, releases the figure safely so
    # repeated Gradio calls do not accumulate live figures.
    img_bytes = BytesIO()
    fig.savefig(img_bytes, format='png')
    plt.close(fig)
    img_bytes.seek(0)

    # One detection per line — renders cleanly in a Textbox, unlike the
    # Python repr of a list.
    detected_objects = "\n".join(
        f"{label}: {score:.2f}"
        for label, score in zip(prediction["labels"], prediction["scores"])
    )

    # Tensors -> lists so the dict is serializable by the JSON component.
    prediction_data = {
        k: (v.tolist() if isinstance(v, torch.Tensor) else v)
        for k, v in prediction.items()
    }
    return Image.open(img_bytes), detected_objects, prediction_data
# Assemble and serve the Gradio UI: numpy image in; annotated image,
# detection summary, and raw prediction JSON out.
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="numpy"),
    outputs=[gr.Image(type="pil"), gr.Textbox(), gr.JSON()],
    title="OBJECT_DETECTOR_254",
    description="Upload an image to detect objects and display bounding boxes along with a summary of detected objects.",
)
demo.launch()