Spaces:
Build error
Build error
| # -*- coding: utf-8 -*- | |
| # Copyright (c) Alibaba, Inc. and its affiliates. | |
| import cv2 | |
| import numpy as np | |
| def resize_image(input_image, resolution): | |
| H, W, C = input_image.shape | |
| H = float(H) | |
| W = float(W) | |
| k = float(resolution) / min(H, W) | |
| H *= k | |
| W *= k | |
| H = int(np.round(H / 64.0)) * 64 | |
| W = int(np.round(W / 64.0)) * 64 | |
| img = cv2.resize( | |
| input_image, (W, H), | |
| interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) | |
| return img, k | |
| def resize_image_ori(h, w, image, k): | |
| img = cv2.resize( | |
| image, (w, h), | |
| interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) | |
| return img | |
| class AnnotatorProcessor(): | |
| canny_cfg = { | |
| 'NAME': 'CannyAnnotator', | |
| 'LOW_THRESHOLD': 100, | |
| 'HIGH_THRESHOLD': 200, | |
| 'INPUT_KEYS': ['img'], | |
| 'OUTPUT_KEYS': ['canny'] | |
| } | |
| hed_cfg = { | |
| 'NAME': 'HedAnnotator', | |
| 'PRETRAINED_MODEL': | |
| 'ms://damo/scepter_scedit@annotator/ckpts/ControlNetHED.pth', | |
| 'INPUT_KEYS': ['img'], | |
| 'OUTPUT_KEYS': ['hed'] | |
| } | |
| openpose_cfg = { | |
| 'NAME': 'OpenposeAnnotator', | |
| 'BODY_MODEL_PATH': | |
| 'ms://damo/scepter_scedit@annotator/ckpts/body_pose_model.pth', | |
| 'HAND_MODEL_PATH': | |
| 'ms://damo/scepter_scedit@annotator/ckpts/hand_pose_model.pth', | |
| 'INPUT_KEYS': ['img'], | |
| 'OUTPUT_KEYS': ['openpose'] | |
| } | |
| midas_cfg = { | |
| 'NAME': 'MidasDetector', | |
| 'PRETRAINED_MODEL': | |
| 'ms://damo/scepter_scedit@annotator/ckpts/dpt_hybrid-midas-501f0c75.pt', | |
| 'INPUT_KEYS': ['img'], | |
| 'OUTPUT_KEYS': ['depth'] | |
| } | |
| mlsd_cfg = { | |
| 'NAME': 'MLSDdetector', | |
| 'PRETRAINED_MODEL': | |
| 'ms://damo/scepter_scedit@annotator/ckpts/mlsd_large_512_fp32.pth', | |
| 'INPUT_KEYS': ['img'], | |
| 'OUTPUT_KEYS': ['mlsd'] | |
| } | |
| color_cfg = { | |
| 'NAME': 'ColorAnnotator', | |
| 'RATIO': 64, | |
| 'INPUT_KEYS': ['img'], | |
| 'OUTPUT_KEYS': ['color'] | |
| } | |
| anno_type_map = { | |
| 'canny': canny_cfg, | |
| 'hed': hed_cfg, | |
| 'pose': openpose_cfg, | |
| 'depth': midas_cfg, | |
| 'mlsd': mlsd_cfg, | |
| 'color': color_cfg | |
| } | |
| def __init__(self, anno_type): | |
| from scepter.modules.annotator.registry import ANNOTATORS | |
| from scepter.modules.utils.config import Config | |
| from scepter.modules.utils.distribute import we | |
| if isinstance(anno_type, str): | |
| assert anno_type in self.anno_type_map.keys() | |
| anno_type = [anno_type] | |
| elif isinstance(anno_type, (list, tuple)): | |
| assert all(tp in self.anno_type_map.keys() for tp in anno_type) | |
| else: | |
| raise Exception(f'Error anno_type: {anno_type}') | |
| general_dict = { | |
| 'NAME': 'GeneralAnnotator', | |
| 'ANNOTATORS': [self.anno_type_map[tp] for tp in anno_type] | |
| } | |
| general_anno = Config(cfg_dict=general_dict, load=False) | |
| self.general_ins = ANNOTATORS.build(general_anno).to(we.device_id) | |
| def run(self, image, anno_type=None): | |
| output_image = self.general_ins({'img': image}) | |
| if anno_type is not None: | |
| if isinstance(anno_type, str) and anno_type in output_image: | |
| return output_image[anno_type] | |
| else: | |
| return { | |
| tp: output_image[tp] | |
| for tp in anno_type if tp in output_image | |
| } | |
| else: | |
| return output_image | |