Spaces:
Build error
Build error
import pickle
import sys

import gradio as gr
import numpy as np
from PIL import Image

# Set the working cloned directory so the local copy of the Mask_RCNN library
# is importable (must run before the mrcnn imports below).
sys.path.append('/home/user/app/Mask_RCNN')
import mrcnn.model as modellib
from mrcnn.config import Config
def get_predictions_from_mask_rcnn(mask_model, image):
    """Run Mask R-CNN object detection on a single image.

    Args:
        mask_model: An inference-mode ``mrcnn.model.MaskRCNN`` (or any object
            with a compatible ``detect(images, verbose=...)`` method).
        image: One image, as a PIL image or array-like. A 2-D (grayscale)
            image is replicated to three channels, and a fourth (alpha)
            channel is dropped.

    Returns:
        A one-element list holding the ``(rois, class_ids, scores, masks)``
        tuple taken from the first entry of ``detect``'s results.
    """
    # Mask R-CNN works on H x W x 3 arrays; mirror the channel handling of
    # mrcnn.utils.Dataset.load_image for grayscale / RGBA inputs.
    pixels = np.asarray(image)
    if pixels.ndim == 2:
        pixels = np.stack((pixels,) * 3, axis=-1)
    elif pixels.ndim == 3 and pixels.shape[-1] == 4:
        pixels = pixels[..., :3]

    # detect() expects a *batch* (list) of images, not a bare image.
    results = mask_model.detect([pixels], verbose=0)
    r = results[0]

    # Predicted ROIs, class IDs, scores and masks for the single image.
    pred_boxes = r['rois']
    pred_class_ids = r['class_ids']
    pred_scores = r['scores']
    pred_masks = r['masks']
    print(f"predictions - rois shape: {pred_boxes.shape}, scores: {pred_scores}, masks shape: {pred_masks.shape}")

    return [(pred_boxes, pred_class_ids, pred_scores, pred_masks)]
def load_config(filepath):
    """Unpickle and return the object stored in the file at *filepath*.

    NOTE(review): ``pickle`` can execute arbitrary code while loading, so only
    call this on files from a trusted source.
    """
    with open(filepath, mode='rb') as handle:
        loaded = pickle.load(handle)
    return loaded
class RSNA_Detect_Config(Config):
    """Mask R-CNN configuration for pneumonia detection on the RSNA dataset.

    Overrides values in the base ``mrcnn.config.Config`` class. This config is
    used to build the inference model, which is fed one image per request.
    """

    # Give the configuration a recognizable name.
    NAME = 'RSNA_pneumonia'

    # Inference runs on one GPU with one image per GPU. The base Config derives
    # BATCH_SIZE = GPU_COUNT * IMAGES_PER_GPU, and MaskRCNN.detect requires
    # exactly BATCH_SIZE images per call; leaving IMAGES_PER_GPU at the
    # upstream default (2) would make every single-image request fail.
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

    BACKBONE = 'resnet50'
    NUM_CLASSES = 2  # background + 1 pneumonia class
    IMAGE_MIN_DIM = 256  # default 800
    IMAGE_MAX_DIM = 256  # default 1024
    RPN_ANCHOR_SCALES = (16, 32, 64, 128, 256)  # default (32, 64, 128, 256, 512)
    TRAIN_ROIS_PER_IMAGE = 32  # default 200
    MAX_GT_INSTANCES = 4  # default 100
    DETECTION_MAX_INSTANCES = 4  # default 100, to match distribution
    DETECTION_MIN_CONFIDENCE = 0.78  # match the target distribution
    DETECTION_NMS_THRESHOLD = 0.01  # default 0.3
# Build the RSNA detection config and an inference-mode Mask R-CNN once, at
# module load, so every request reuses the same model object.
config = RSNA_Detect_Config()
mask_rcnn_inference_model = modellib.MaskRCNN(
    mode="inference",
    config=config,
    model_dir="assests",
)
# NOTE(review): no load_weights(...) call is visible here, so the model appears
# to run with freshly initialised weights — confirm trained weights are loaded.
def predict(inp):
    """Gradio callback: run pneumonia detection on one uploaded image.

    Args:
        inp: The uploaded image (a PIL image, per the ``gr.Image(type="pil")``
            input component).

    Returns:
        ``(inp, confidences)``: the input image unchanged, and a
        label -> confidence dict for the ``gr.Label`` output.
    """
    predictions = get_predictions_from_mask_rcnn(mask_rcnn_inference_model, inp)
    print(predictions)

    # A single image was processed, so take its (boxes, class_ids, scores,
    # masks) tuple and summarise it by the most confident detection. No
    # detections means nothing passed DETECTION_MIN_CONFIDENCE -> 0.0.
    # NOTE: this is a display heuristic, not a calibrated probability.
    _, _, scores, _ = predictions[0]
    pneumonia = float(max(scores, default=0.0))
    confidences = {"pneumonia": pneumonia, "no pneumonia": 1.0 - pneumonia}
    return inp, confidences
# Gradio UI: one image in; the same image plus a label/confidence panel out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil"), gr.Label(num_top_classes=3)],
    examples=[],
)

# Start the (blocking) server only when run as a script, e.g. `python app.py`,
# so importing this module for tests or tooling does not launch it.
if __name__ == "__main__":
    demo.launch()