| | import gradio as gr |
| | import requests |
| | import torch |
| | import os |
| | from tqdm import tqdm |
| | |
| | from ultralytics import YOLO |
| | import cv2 |
| | import numpy as np |
| | import pandas as pd |
| | from skimage.transform import resize |
| | from skimage import img_as_bool |
| | from skimage.morphology import convex_hull_image |
| | import json |
| |
|
| | |
| |
|
def tableConvexHull(img, masks):
    """Union of the convex hulls of the given table masks.

    Params:
        img: Unused; kept for signature compatibility with callers.
        masks: Indexable collection of per-table torch masks (H, W).

    Returns:
        Boolean np.ndarray that is True wherever any table's convex
        hull covers the pixel.
    """
    merged = np.zeros(masks[0].shape, dtype="bool")
    for table_mask in masks:
        hull = convex_hull_image(table_mask.cpu().detach().numpy())
        merged = np.bitwise_or(merged, hull)
    return merged
| |
|
def cls_exists(clss, cls):
    """Return True when class id *cls* occurs anywhere in tensor *clss*."""
    matches = torch.where(clss == cls)[0]
    return len(matches) > 0
| |
|
def empty_mask(img):
    """Return an all-False boolean mask matching *img*'s height and width.

    Params:
        img: Image array; only ``img.shape[:2]`` (H, W) is used.

    Returns:
        np.ndarray of dtype bool, shape (H, W), all False.
    """
    # Build the boolean array directly instead of the original
    # uint8 -> np.array(dtype=bool) round-trip; result is identical.
    return np.zeros(img.shape[:2], dtype=bool)
| |
|
def extract_img_mask(img_model, img, config):
    """Run the dedicated image model and pull out its class-0 mask.

    Params:
        img_model: YOLO-style model for image regions.
        img: Source image array.
        config: Prediction config dict forwarded to get_predictions.

    Returns:
        Dict with 'status' (1 ok, -1 failure) and, unless status is -1,
        'mask': boolean mask for class 0 (empty when nothing detected).
    """
    out = {'status': 1}
    pred = get_predictions(img_model, img, config)

    if pred['status'] == -1:
        out['status'] = -1
    elif pred['status'] == 0:
        # Model ran but produced no detections.
        out['mask'] = empty_mask(img)
    else:
        boxes = pred['boxes']
        out['mask'] = extract_mask(img, pred['masks'], boxes, boxes[:, 5], 0)
    return out
| |
|
def get_predictions(model, img2, config):
    """Run *model* on *img2* and collect the first result's masks/boxes.

    Params:
        model: YOLO-style model exposing ``predict``.
        img2: Image (array or path) forwarded to ``model.predict``.
        config: Dict with keys 'rm' (retina_masks), 'sz' (imgsz),
            'conf' (confidence threshold) and 'classes'.

    Returns:
        Dict with 'status' 1 plus 'masks'/'boxes' tensors on success;
        'status' 0 when the model produced no usable masks;
        'status' -1 when prediction itself raised.
    """
    res_dict = {'status': 1}
    try:
        for result in model.predict(source=img2, verbose=False,
                                    retina_masks=config['rm'],
                                    imgsz=config['sz'], conf=config['conf'],
                                    stream=True, classes=config['classes']):
            try:
                # result.masks is None when nothing was detected; the
                # attribute access raises and we report "no masks".
                res_dict['masks'] = result.masks.data
                res_dict['boxes'] = result.boxes.data
                return res_dict
            except Exception:
                res_dict['status'] = 0
                return res_dict
        # Fix: an empty prediction stream used to fall off the end of the
        # function and return None, crashing callers on res['status'].
        res_dict['status'] = 0
    except Exception:
        # Fix: narrowed from a bare ``except:`` so KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        res_dict['status'] = -1
    return res_dict
| |
|
def extract_mask(img, masks, boxes, clss, cls):
    """Union every mask whose class id equals *cls* into one bool array.

    Params:
        img: Source image (used only for the empty-mask fallback shape).
        masks: Tensor of per-detection masks.
        boxes: Detection boxes tensor (unused here; kept for interface).
        clss: Tensor of class ids, parallel to *masks*.
        cls: Class id to select.

    Returns:
        Boolean np.ndarray; all-False when *cls* has no detections.
    """
    if not cls_exists(clss, cls):
        return empty_mask(img)
    selected = masks[torch.where(clss == cls)]
    merged = torch.any(selected, dim=0).bool()
    return merged.cpu().detach().numpy()
| |
|
| |
|
def get_masks(img, model, img_model, flags, configs):
    """Produce the four layout masks for *img*.

    Params:
        img: BGR image array.
        model: General layout model (classes 0-3).
        img_model: Dedicated image-region model.
        flags: Feature flags (currently unused here).
        configs: Dict of per-stage prediction configs
            ('paratext', 'imgtab', 'image').

    Returns:
        Dict with 'status' (1 ok, -1 prediction failure) and, on
        success, 'masks': four boolean arrays in class order
        [cls 0, cls 1, image, table].
    """
    response = {'status': 1}
    collected = []

    # Pass 1: classes 0 and 1 via the 'paratext' config.
    para_res = get_predictions(model, img, configs['paratext'])
    if para_res['status'] == -1:
        response['status'] = -1
        return response
    if para_res['status'] == 0:
        collected.extend(empty_mask(img) for _ in range(2))
    else:
        masks, boxes = para_res['masks'], para_res['boxes']
        clss = boxes[:, 5]
        collected.extend(
            extract_mask(img, masks, boxes, clss, c) for c in range(2))

    # Pass 2: image (cls 2) and table (cls 3) via the 'imgtab' config.
    it_res = get_predictions(model, img, configs['imgtab'])
    if it_res['status'] == -1:
        response['status'] = -1
        return response
    if it_res['status'] == 0:
        collected.extend(empty_mask(img) for _ in range(2))
    else:
        masks, boxes = it_res['masks'], it_res['boxes']
        clss = boxes[:, 5]

        # Image regions are refined with the dedicated image model.
        if cls_exists(clss, 2):
            img_res = extract_img_mask(img_model, img, configs['image'])
            if img_res['status'] != 1:
                response['status'] = -1
                return response
            collected.append(img_res['mask'])
        else:
            collected.append(empty_mask(img))

        # Tables are replaced by the convex hull of their masks.
        if cls_exists(clss, 3):
            collected.append(
                tableConvexHull(img, masks[torch.where(clss == 3)]))
        else:
            collected.append(empty_mask(img))

    # Without retina masks the predictions come back at model resolution;
    # rescale every mask to the source image size.
    if not configs['paratext']['rm']:
        h, w, c = img.shape
        collected = [img_as_bool(resize(m, (h, w))) for m in collected]

    response['masks'] = collected
    return response
| |
|
def overlay(image, mask, color, alpha, resize=None):
    """Blend a segmentation mask onto an image.

    Adapted from
    https://www.kaggle.com/code/purplejester/showing-samples-with-segmentation-mask-overlay

    Params:
        image: Source image. np.ndarray.
        mask: Segmentation mask. np.ndarray.
        color: (R, G, B) tuple used to paint masked pixels.
        alpha: Blend weight of the painted layer. float.
        resize: Optional target size; when given, both layers are
            resized before blending.

    Returns:
        np.ndarray with the mask rendered over the image.
    """
    bgr = color[::-1]  # incoming color is RGB; cv2 images are BGR
    # Replicate the 2-D mask across three channels (H, W, 3).
    mask3 = np.moveaxis(np.expand_dims(mask, 0).repeat(3, axis=0), 0, -1)
    # Masked pixels take the fill color; the rest keep the image values.
    painted = np.ma.MaskedArray(image, mask=mask3, fill_value=bgr).filled()

    if resize is not None:
        image = cv2.resize(image.transpose(1, 2, 0), resize)
        painted = cv2.resize(painted.transpose(1, 2, 0), resize)

    return cv2.addWeighted(image, 1 - alpha, painted, alpha, 0)
| | |
| |
|
| |
|
# --- Model and demo configuration (module-level side effects) ---

# Weight files live under ./models; both YOLO checkpoints are loaded
# eagerly at import time.
model_path = 'models'
general_model_name = 'e50_aug.pt'   # general layout model (classes 0-3)
image_model_name = 'e100_img.pt'    # dedicated image-region model

general_model = YOLO(os.path.join(model_path, general_model_name))
image_model = YOLO(os.path.join(model_path, image_model_name))

# Bundled demo images shown in the Gradio examples gallery.
image_path = 'examples'
sample_name = ['0040da34-25c8-4a5a-a6aa-36733ea3b8eb.png',
    '0050a8ee-382b-447e-9c5b-8506d9507bef.png', '0064d3e2-3ba2-4332-a28f-3a165f2b84b1.png']

sample_path = [os.path.join(image_path, sample) for sample in sample_name]

# Feature flags passed through to get_masks (not read there at present).
flags = {
    'hist': False,
    'bz': False
}

# Per-stage YOLO prediction settings: 'sz' = imgsz, 'conf' = confidence
# threshold, 'rm' = retina_masks, 'classes' = class ids to keep.
# NOTE(review): class names are inferred from the config keys — 0/1 are
# presumably paragraph/text, 2 image, 3 table; confirm against training.
configs = {}
configs['paratext'] = {
    'sz' : 640,
    'conf': 0.25,
    'rm': True,
    'classes': [0, 1]
}
configs['imgtab'] = {
    'sz' : 640,
    'conf': 0.35,
    'rm': True,
    'classes': [2, 3]
}
configs['image'] = {
    'sz' : 640,
    'conf': 0.35,
    'rm': True,
    'classes': [0]
}
| |
|
def evaluate(img_path, model=general_model, img_model=image_model,
             configs=configs, flags=flags):
    """Segment the document image at *img_path* and return it with
    one colored overlay per detected layout class.

    Params:
        img_path: Path to the input image file.
        model: General layout model.
        img_model: Dedicated image-region model.
        configs: Per-stage prediction configs; mutated ('rm' -> False)
            and retried once when the retina-masks run fails.
        flags: Extra flags forwarded to get_masks.

    Returns:
        BGR np.ndarray with the four masks blended in.

    Raises:
        RuntimeError: If prediction fails even without retina masks.
    """
    img = cv2.imread(img_path)
    # Fix: use the model/img_model parameters instead of the module-level
    # globals so non-default models are actually honored.
    res = get_masks(img, model, img_model, flags, configs)
    if res['status'] == -1:
        if not configs['paratext']['rm']:
            # Already retried without retina masks; fail loudly instead of
            # recursing forever.
            raise RuntimeError('layout prediction failed for %r' % img_path)
        for key in configs:
            configs[key]['rm'] = False
        # Fix: retry with the *path* (cv2.imread needs a path, not the
        # ndarray) and pass configs/flags as keywords — the original
        # passed them positionally in swapped order.
        return evaluate(img_path, model, img_model, configs=configs, flags=flags)

    masks = res['masks']
    color_map = {
        0: (255, 0, 0),
        1: (0, 255, 0),
        2: (0, 0, 255),
        3: (255, 255, 0),
    }
    for i, mask in enumerate(masks):
        img = overlay(image=img, mask=mask, color=color_map[i], alpha=0.4)
    return img
| |
|
| | |
| | |
| |
|
| |
|
# --- Gradio UI (runs at import: building the interface launches the app) ---

# The input component hands evaluate() a file path (cv2.imread reads it).
inputs_image = [
    gr.components.Image(type="filepath", label="Input Image"),
]
# The output component renders the returned numpy overlay image.
outputs_image = [
    gr.components.Image(type="numpy", label="Output Image"),
]
interface_image = gr.Interface(
    fn=evaluate,
    inputs=inputs_image,
    outputs=outputs_image,
    title="Document Layout Segmentor",
    examples=sample_path,
    cache_examples=True,  # pre-computes results for the bundled examples
).launch()