# NOTE(review): the two lines below were a Hugging Face Spaces status banner
# ("Spaces: Runtime error") captured when this file was exported; kept as a
# comment so the module remains importable.
import json
import os

import cv2
import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
from skimage import img_as_bool
from skimage.morphology import convex_hull_image
from skimage.transform import resize
from tqdm import tqdm
from ultralytics import YOLO

# import wandb
# wandb.init(mode='disabled')
def tableConvexHull(img, masks):
    """Return one boolean mask covering the convex hull of every table mask.

    Each element of ``masks`` is a per-instance mask tensor (moved to CPU
    before processing); the convex hulls of all instances are OR-ed together.
    """
    combined = np.zeros(masks[0].shape, dtype="bool")
    for instance in masks:
        hull = convex_hull_image(instance.cpu().detach().numpy())
        combined = np.bitwise_or(combined, hull)
    return combined
def cls_exists(clss, cls):
    """Return True when class id ``cls`` appears in the 1-D tensor ``clss``."""
    matches = torch.where(clss == cls)[0]
    return matches.numel() > 0
def empty_mask(img):
    """Return an all-False boolean mask with the spatial shape of ``img``."""
    height, width = img.shape[:2]
    return np.zeros((height, width), dtype=bool)
def extract_img_mask(img_model, img, config):
    """Run the image-specialised model and extract the class-0 mask.

    Returns a dict: ``status`` is -1 when prediction failed, otherwise 1,
    and on non-failure a boolean ``mask`` is included (empty when the model
    detected nothing).
    """
    pred = get_predictions(img_model, img, config)
    status = pred['status']
    if status == -1:
        return {'status': -1}
    if status == 0:
        return {'status': 1, 'mask': empty_mask(img)}
    boxes = pred['boxes']
    class_ids = boxes[:, 5]
    mask = extract_mask(img, pred['masks'], boxes, class_ids, 0)
    return {'status': 1, 'mask': mask}
def get_predictions(model, img2, config):
    """Run a YOLO model on one image and collect the first result's tensors.

    Args:
        model: Ultralytics YOLO model (anything exposing a ``predict`` that
            yields result objects with ``.masks.data`` / ``.boxes.data``).
        img2: Image source, forwarded unchanged to ``model.predict``.
        config: Dict with keys 'rm' (retina_masks), 'sz' (imgsz),
            'conf' (confidence threshold) and 'classes' (class-id filter).

    Returns:
        Dict with 'status': 1 on success (plus 'masks' and 'boxes' tensors),
        0 when the model produced no masks, -1 when prediction itself failed.
    """
    res_dict = {'status': 1}
    try:
        for result in model.predict(source=img2, verbose=False,
                                    retina_masks=config['rm'],
                                    imgsz=config['sz'], conf=config['conf'],
                                    stream=True, classes=config['classes']):
            try:
                # result.masks is None when nothing was detected, so the
                # attribute access below raises and we report status 0.
                res_dict['masks'] = result.masks.data
                res_dict['boxes'] = result.boxes.data
                del result  # drop the reference before returning
                return res_dict
            except Exception:
                res_dict['status'] = 0
                return res_dict
    except Exception:
        # BUG FIX: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit; narrowed to Exception.
        res_dict['status'] = -1
        return res_dict
    # BUG FIX: the original fell off the end (returning None) when the
    # predict stream yielded no results; report "no detections" instead.
    res_dict['status'] = 0
    return res_dict
def extract_mask(img, masks, boxes, clss, cls):
    """Merge every mask of class ``cls`` into one boolean numpy mask.

    When the class is absent from ``clss`` an all-False mask with the
    image's spatial shape is returned instead.
    """
    if not cls_exists(clss, cls):
        return empty_mask(img)
    selected = masks[torch.where(clss == cls)]
    merged = torch.any(selected, dim=0).bool()
    return merged.cpu().detach().numpy()
def get_masks(img, model, img_model, flags, configs):
    """Build the four layout masks (paragraph, text, image, table) for ``img``.

    Runs the general model twice (classes 0/1, then classes 2/3) and the
    image-specialised model when class 2 is present.  Returns a dict with
    'status' (-1 when any model call fails, else 1) and, on success,
    'masks': four boolean masks ordered paragraph, text, image, table.
    ``flags`` is currently unused here.
    """
    response = {
        'status': 1
    }
    ans_masks = []
    img2 = img
    # ***** Getting paragraph and text masks
    res = get_predictions(model, img2, configs['paratext'])
    if res['status']==-1:
        # Prediction itself failed -- propagate the failure to the caller.
        response['status'] = -1
        return response
    elif res['status']==0:
        # Nothing detected: placeholder empty masks for classes 0 and 1.
        for i in range(2): ans_masks.append(empty_mask(img))
    else:
        masks, boxes = res['masks'], res['boxes']
        # Column 5 of the boxes tensor holds the class id.
        clss = boxes[:, 5]
        for cls in range(2):
            mask = extract_mask(img, masks, boxes, clss, cls)
            ans_masks.append(mask)
    # ***** Getting image and table masks
    res2 = get_predictions(model, img2, configs['imgtab'])
    if res2['status']==-1:
        response['status'] = -1
        return response
    elif res2['status']==0:
        # Nothing detected: placeholder empty masks for classes 2 and 3.
        for i in range(2): ans_masks.append(empty_mask(img))
    else:
        masks, boxes = res2['masks'], res2['boxes']
        clss = boxes[:, 5]
        if cls_exists(clss, 2):
            # Refine the image mask with the dedicated image model.
            img_res = extract_img_mask(img_model, img, configs['image'])
            if img_res['status'] == 1:
                img_mask = img_res['mask']
            else:
                response['status'] = -1
                return response
        else:
            img_mask = empty_mask(img)
        ans_masks.append(img_mask)
        if cls_exists(clss, 3):
            # Tables are represented by the convex hull of their masks.
            indices = torch.where(clss==3)
            tbl_mask = tableConvexHull(img, masks[indices])
        else:
            tbl_mask = empty_mask(img)
        ans_masks.append(tbl_mask)
    if not configs['paratext']['rm']:
        # Without retina_masks the model outputs low-res masks; upscale
        # each one back to the image's height/width.
        h, w, c = img.shape
        for i in range(4):
            ans_masks[i] = img_as_bool(resize(ans_masks[i], (h, w)))
    response['masks'] = ans_masks
    return response
def overlay(image, mask, color, alpha, resize=None):
    """Combines image and its segmentation mask into a single image.

    Based on:
    https://www.kaggle.com/code/purplejester/showing-samples-with-segmentation-mask-overlay

    Params:
        image: Training image. np.ndarray,
        mask: Segmentation mask. np.ndarray,
        color: Color for segmentation mask rendering. tuple[int, int, int] = (255, 0, 0)
        alpha: Segmentation mask's transparency. float = 0.5,
        resize: If provided, both image and its mask are resized before
            blending them together. tuple[int, int] = (1024, 1024))

    Returns:
        image_combined: The combined image. np.ndarray
    """
    # Caller passes RGB; the image is BGR (cv2), so reverse the tuple.
    bgr_color = color[::-1]
    # Replicate the 2-D mask across three channels: (H, W) -> (H, W, 3).
    channel_mask = np.moveaxis(np.expand_dims(mask, 0).repeat(3, axis=0), 0, -1)
    painted = np.ma.MaskedArray(image, mask=channel_mask, fill_value=bgr_color).filled()
    if resize is not None:
        image = cv2.resize(image.transpose(1, 2, 0), resize)
        painted = cv2.resize(painted.transpose(1, 2, 0), resize)
    return cv2.addWeighted(image, 1 - alpha, painted, alpha, 0)
# --- Model weights -------------------------------------------------------
model_path = 'models'
general_model_name = 'e50_aug.pt'   # general layout model (all four classes)
image_model_name = 'e100_img.pt'    # specialised model for image regions
general_model = YOLO(os.path.join(model_path, general_model_name))
image_model = YOLO(os.path.join(model_path, image_model_name))
# --- Example images shown in the Gradio UI -------------------------------
image_path = 'examples'
sample_name = ['0040da34-25c8-4a5a-a6aa-36733ea3b8eb.png',
'0050a8ee-382b-447e-9c5b-8506d9507bef.png', '0064d3e2-3ba2-4332-a28f-3a165f2b84b1.png']
sample_path = [os.path.join(image_path, sample) for sample in sample_name]
# Feature flags -- both read nowhere visible in this file; TODO confirm use.
flags = {
'hist': False,
'bz': False
}
# Per-stage prediction configs: 'sz' = imgsz, 'conf' = confidence threshold,
# 'rm' = retina_masks, 'classes' = class-id filter passed to YOLO.
configs = {}
configs['paratext'] = {
'sz' : 640,
'conf': 0.25,
'rm': True,
'classes': [0, 1]   # 0 = paragraph, 1 = text
}
configs['imgtab'] = {
'sz' : 640,
'conf': 0.35,
'rm': True,
'classes': [2, 3]   # 2 = image, 3 = table
}
configs['image'] = {
'sz' : 640,
'conf': 0.35,
'rm': True,
'classes': [0]      # image model: its own class 0
}
def evaluate(img_path, model=general_model, img_model=image_model,\
             configs=configs, flags=flags):
    """Segment a document image and return it with colored mask overlays.

    Args:
        img_path: Path of the image file to segment (Gradio "filepath" input).
        model: General layout YOLO model.
        img_model: Image-region YOLO model.
        configs: Per-stage prediction configs (mutated in-place on retry).
        flags: Feature flags forwarded to ``get_masks``.

    Returns:
        BGR np.ndarray with the four class masks blended in.
    """
    img = cv2.imread(img_path)
    # BUG FIX: the original called get_masks with the *globals*
    # general_model/image_model, ignoring the parameters.
    res = get_masks(img, model, img_model, flags, configs)
    if res['status'] == -1:
        # Retry once without retina masks (cheaper, lower resolution).
        for idx in configs.keys():
            configs[idx]['rm'] = False
        # BUG FIX: the original recursed with `img` (an ndarray) in the
        # img_path slot -- cv2.imread cannot read that -- and also passed
        # `flags` positionally where `configs` is expected.  Pass the path
        # and use keywords.
        return evaluate(img_path, model, img_model, configs=configs, flags=flags)
    color_map = {
        0: (255, 0, 0),    # paragraph
        1: (0, 255, 0),    # text
        2: (0, 0, 255),    # image
        3: (255, 255, 0),  # table
    }
    for i, mask in enumerate(res['masks']):
        img = overlay(image=img, mask=mask, color=color_map[i], alpha=0.4)
    return img
# output = evaluate(img_path=sample_path, model=general_model, img_model=image_model,\
#                   configs=configs, flags=flags)
# --- Gradio wiring: one image in (as a file path), one image out ----------
inputs_image = [
gr.components.Image(type="filepath", label="Input Image"),
]
outputs_image = [
gr.components.Image(type="numpy", label="Output Image"),
]
# Building the Interface launches the app immediately (blocking call).
interface_image = gr.Interface(
fn=evaluate,
inputs=inputs_image,
outputs=outputs_image,
title="Document Layout Segmentor",
examples=sample_path,
cache_examples=True,
).launch()