Spaces: Runtime error
# Ensure detectron2 is importable. It is not published on PyPI, so on a fresh
# environment (e.g. a HF Space container) install it from source on first run.
try:
    import detectron2  # noqa: F401
except ImportError:
    import subprocess
    import sys

    # Install via the *current* interpreter's pip (sys.executable) and fail
    # loudly on a non-zero exit code — os.system would silently ignore errors.
    subprocess.check_call([
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/facebookresearch/detectron2.git",
    ])
    import detectron2  # noqa: F401
# Standard library
import base64
import io
import json
import os

# Third-party
import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image

from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from huggingface_hub import hf_hub_download
# Model artifacts live in a (possibly private) HF Hub repo; the access token
# is supplied through the "Token" environment variable (HF Spaces secret).
REPO_ID = "SalmanAboAraj/FinalModel"
HF_TOKEN = os.getenv('Token')

model_path = hf_hub_download(repo_id=REPO_ID, filename="model_final.pth", token=HF_TOKEN)
config_path = hf_hub_download(repo_id=REPO_ID, filename="config.yaml", token=HF_TOKEN)
metadata_path = hf_hub_download(repo_id=REPO_ID, filename="metadata.json", token=HF_TOKEN)

# Build the Detectron2 predictor from the downloaded config and weights.
cfg = get_cfg()
cfg.merge_from_file(config_path)
cfg.MODEL.WEIGHTS = model_path
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7  # drop low-confidence detections
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
predictor = DefaultPredictor(cfg)

# Class ids of interest. NOTE(review): the filtering that used these is
# commented out in process_image_base64; kept for future re-enablement.
selected_classes = {1, 2, 3, 5, 7, 10, 15, 17, 18}
def process_image_base64(image_base64):
    """Run instance segmentation on a base64-encoded image.

    Decodes the image, resizes it to the model's 512x512 input size, runs the
    Detectron2 predictor, and flattens all predicted instance masks into one
    label map where each pixel holds the class id of the instance covering it
    (later instances overwrite earlier ones where masks overlap).

    Args:
        image_base64: Base64-encoded image bytes in any PIL-readable format.

    Returns:
        dict: ``{"image": mask}`` where ``mask`` is a nested list of class ids
        at the original image resolution, or ``{"image": []}`` when nothing is
        detected.
    """
    image_data = base64.b64decode(image_base64)
    image_rgb = Image.open(io.BytesIO(image_data)).convert("RGB")
    # PIL's .size is (width, height), which is exactly the dsize order that
    # cv2.resize expects below.
    original_size = image_rgb.size
    image_rgb = image_rgb.resize((512, 512))
    image_np = np.array(image_rgb)
    # Detectron2's DefaultPredictor expects BGR input by default.
    image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

    outputs = predictor(image_bgr)
    instances = outputs["instances"].to("cpu")
    if len(instances) == 0:
        return {"image": []}

    # Only masks and class ids are needed downstream; drop the rest.
    if instances.has("pred_boxes"):
        instances.remove("pred_boxes")
    if instances.has("scores"):
        instances.remove("scores")

    # Flatten per-instance boolean masks into a single class-id label map.
    # dtype is uint8, not int8: cv2.resize does not support signed 8-bit
    # (CV_8S) arrays and would raise, and class ids are small non-negative.
    mask_shape = instances[0].pred_masks.shape[1:]
    image_mask = np.zeros(mask_shape, dtype=np.uint8)
    for i in range(len(instances)):
        class_id = instances[i].pred_classes.item()
        mask_np = instances[i].pred_masks.numpy().squeeze()
        image_mask[mask_np] = class_id

    # Nearest-neighbor keeps label values intact (no blending between ids)
    # while scaling the 512x512 mask back to the original resolution.
    image_mask_resized = cv2.resize(image_mask, original_size, interpolation=cv2.INTER_NEAREST)
    return {"image": image_mask_resized.tolist()}
# Minimal Gradio front-end: one text box for the base64 payload in, the mask
# serialized as JSON out.
iface = gr.Interface(
    process_image_base64,
    inputs="text",
    outputs="json",
    title="Detectron2 Object Detection",
    description="Upload an image (Base64) to get a processed mask output.",
)
iface.launch()