# Page-header residue captured with this file (Spaces status): Runtime error
| import os | |
| import os.path as osp | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from basicsr.utils import img2tensor, tensor2img | |
| from pytorch_lightning import seed_everything | |
| from ldm.models.diffusion.plms import PLMSSampler | |
| from ldm.modules.encoders.adapter import Adapter | |
| from ldm.util import instantiate_from_config | |
| from model_edge import pidinet | |
| import gradio as gr | |
| from omegaconf import OmegaConf | |
| import pathlib | |
| import random | |
| import shlex | |
| import subprocess | |
| import sys | |
| import mmcv | |
| from mmdet.apis import inference_detector, init_detector | |
| from mmpose.apis import (inference_top_down_pose_model, init_pose_model, process_mmdet_results, vis_pose_result) | |
# Pairs of keypoint indices; imshow_keypoints draws one line per pair.
skeleton = [list(pair) for pair in (
    (15, 13), (13, 11), (16, 14), (14, 12), (11, 12), (5, 11), (6, 12),
    (5, 6), (5, 7), (6, 8), (7, 9), (8, 10),
    (1, 2), (0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6),
)]

# The three colour triples used by both palettes below.
_COLOR_A = (51, 153, 255)
_COLOR_B = (0, 255, 0)
_COLOR_C = (255, 128, 0)

# One colour per keypoint (17 entries): five of colour A, then six B/C pairs.
# Every entry is a fresh list so that no two entries alias each other.
pose_kpt_color = [
    list(color)
    for color in (_COLOR_A,) * 5 + (_COLOR_B, _COLOR_C) * 6
]

# One colour per skeleton link (19 entries, same order as ``skeleton``).
pose_link_color = [
    list(color)
    for color in (
        (_COLOR_B, _COLOR_B, _COLOR_C, _COLOR_C)
        + (_COLOR_A,) * 4
        + (_COLOR_B, _COLOR_C) * 2
        + (_COLOR_A,) * 7
    )
]

# Make the checked-out T2I-Adapter directory importable.
sys.path.append('T2I-Adapter')

# Remote locations of the T2I-Adapter config and model files.
_T2I_ADAPTER_RAW = 'https://github.com/TencentARC/T2I-Adapter/raw/main/'
config_path = _T2I_ADAPTER_RAW + 'configs/stable-diffusion/'
model_path = _T2I_ADAPTER_RAW + 'models/'
def imshow_keypoints(img,
                     pose_result,
                     skeleton=None,
                     kpt_score_thr=0.1,
                     pose_kpt_color=None,
                     pose_link_color=None,
                     radius=4,
                     thickness=1):
    """Draw keypoints and links on a black canvas shaped like ``img``.

    ``img`` itself is never modified: only its shape is used. Drawing happens
    on a new zero-filled array (NumPy's default float dtype), which is
    returned. At most the first two entries of ``pose_result`` are drawn.

    Args:
        img (ndarray): Reference image; must be 3-dimensional (H, W, C).
        pose_result (list[dict]): The poses to draw. Each element must have a
            ``'keypoints'`` entry holding K keypoints as a Kx3 array-like,
            where each keypoint is ``x, y, score``.
        skeleton (list[list[int]], optional): Pairs of keypoint indices to
            connect. If None, no links are drawn.
        kpt_score_thr (float, optional): Minimum score of keypoints
            to be shown. Default: 0.1.
        pose_kpt_color (np.array[Nx3], optional): Color of N keypoints. If
            None, the keypoints will not be drawn. An individual entry may
            also be None to skip that keypoint.
        pose_link_color (np.array[Mx3], optional): Color of M links. If None,
            the links will not be drawn. An individual entry may also be None
            to skip that link.
        radius (int): Radius of the filled keypoint circles.
        thickness (int): Thickness of link lines.

    Returns:
        ndarray: The canvas with the poses drawn on it.

    Raises:
        AssertionError: If a color table's length does not match the number
            of keypoints / skeleton links.
    """
    img_h, img_w, _ = img.shape
    canvas = np.zeros(img.shape)

    for idx, pose in enumerate(pose_result):
        if idx > 1:
            # Only the first two poses are rendered.
            break

        # np.asarray never forces a copy and, unlike np.array(copy=False),
        # does not raise on NumPy >= 2.0 when a conversion is required.
        kpts = np.asarray(pose['keypoints'])

        # draw each point on the canvas
        if pose_kpt_color is not None:
            assert len(pose_kpt_color) == len(kpts)
            for kid, kpt in enumerate(kpts):
                x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
                if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None:
                    # skip the point that should not be drawn
                    continue
                color = tuple(int(c) for c in pose_kpt_color[kid])
                cv2.circle(canvas, (x_coord, y_coord), radius, color, -1)

        # draw links
        if skeleton is not None and pose_link_color is not None:
            assert len(pose_link_color) == len(skeleton)
            for sk_id, (start, end) in enumerate(skeleton):
                pos1 = (int(kpts[start, 0]), int(kpts[start, 1]))
                pos2 = (int(kpts[end, 0]), int(kpts[end, 1]))
                # A link is drawn only when both ends lie strictly inside the
                # image, both scores pass the threshold and a color is set.
                inside = (0 < pos1[0] < img_w and 0 < pos1[1] < img_h
                          and 0 < pos2[0] < img_w and 0 < pos2[1] < img_h)
                confident = (kpts[start, 2] >= kpt_score_thr
                             and kpts[end, 2] >= kpt_score_thr)
                if not inside or not confident or pose_link_color[sk_id] is None:
                    # skip the link that should not be drawn
                    continue
                color = tuple(int(c) for c in pose_link_color[sk_id])
                cv2.line(canvas, pos1, pos2, color, thickness=thickness)

    return canvas
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate ``config.model`` and load weights from a checkpoint file.

    Args:
        config: Configuration object exposing a ``model`` entry accepted by
            ``instantiate_from_config``.
        ckpt (str): Path of the checkpoint to read with ``torch.load``.
        verbose (bool): When True, print the missing / unexpected keys
            reported by the non-strict ``load_state_dict`` call.

    Returns:
        The instantiated model, moved to CUDA and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(security): torch.load unpickles the file; only use checkpoints
    # obtained from a trusted source.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    # The weights are either nested under "state_dict" or are the top-level
    # mapping itself.
    sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd

    model = instantiate_from_config(config.model)
    # strict=False: mismatching keys are tolerated and only reported.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose:
        if len(missing) > 0:
            print("missing keys:")
            print(missing)
        if len(unexpected) > 0:
            print("unexpected keys:")
            print(unexpected)

    model.cuda()
    model.eval()
    return model
class Model:
    """Holds the networks of the demo and runs the three generation modes.

    Construction downloads every required file with ``wget`` into
    ``model_dir`` and loads: a detector and a pose model (mmdet / mmpose),
    two diffusion base models with one PLMS sampler each, the edge network
    returned by ``pidinet()`` and three ``Adapter`` networks (sketch, keypose,
    seg).

    NOTE(review): ``self.device`` is computed in ``__init__`` but all loading
    and inference code uses the literal ``'cuda'`` (and
    ``load_model_from_config`` calls ``model.cuda()``), so a CUDA device is
    effectively required.
    """

    # Device string used by every loader and inference call.
    _CUDA = 'cuda'

    def __init__(self,
                 model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
                 model_dir: str = 'models',
                 use_lightweight: bool = True):
        """Create the model directory, then download and load all networks.

        Args:
            model_config_path: Accepted for interface compatibility; not read
                by this class.
            model_dir: Directory that receives the downloaded files.
            use_lightweight: Accepted for interface compatibility; not read
                by this class.
        """
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model_dir = pathlib.Path(model_dir)
        self.model_dir.mkdir(exist_ok=True, parents=True)
        self.download_pose_models()
        self.download_models()

    # ------------------------------------------------------------------
    # downloading / loading
    # ------------------------------------------------------------------
    def _fetch(self, url: str, filename: str) -> str:
        """Download ``url`` to ``<model_dir>/<filename>`` and return that path.

        Raises:
            subprocess.CalledProcessError: If ``wget`` exits with a non-zero
                status, so a failed download is reported here instead of as
                a confusing load error later on.
        """
        target = str(self.model_dir / filename)
        subprocess.run(['wget', url, '-O', target], check=True)
        return target

    @staticmethod
    def _load_adapter(weights: str, device: str, **adapter_kwargs):
        """Build an ``Adapter`` with the shared settings and load ``weights``."""
        adapter = Adapter(channels=[320, 640, 1280, 1280],
                          nums_rb=2,
                          ksize=1,
                          sk=True,
                          use_conv=False,
                          **adapter_kwargs).to(device)
        adapter.load_state_dict(torch.load(weights))
        return adapter

    def download_pose_models(self) -> None:
        """Download and initialise the detector and the pose model.

        Sets ``self.det_model`` and ``self.pose_model``.
        """
        device = self._CUDA

        ## mmpose
        det_config = self._fetch(
            model_path + "faster_rcnn_r50_fpn_coco.py",
            'faster_rcnn_r50_fpn_coco.py')
        det_checkpoint = self._fetch(
            "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth",
            'faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth')
        pose_config = self._fetch(
            model_path + "hrnet_w48_coco_256x192.py",
            'hrnet_w48_coco_256x192.py')
        pose_checkpoint = self._fetch(
            "https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth",
            'hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth')

        ## detector
        det_config_mmcv = mmcv.Config.fromfile(det_config)
        self.det_model = init_detector(det_config_mmcv, det_checkpoint, device=device)
        pose_config_mmcv = mmcv.Config.fromfile(pose_config)
        self.pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=device)

    def download_models(self) -> None:
        """Download and load the diffusion models, adapters and edge network.

        Sets ``self.model``, ``self.model_anything``, ``self.sampler``,
        ``self.sampler_anything``, ``self.model_ad_sketch``,
        ``self.model_ad_pose``, ``self.model_ad_seg`` and ``self.net_G``.
        """
        device = self._CUDA
        config = OmegaConf.load("configs/stable-diffusion/test_sketch.yaml")
        config.model.params.cond_stage_config.params.device = device

        hf_adapter = "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/"
        sd_ckpt = self._fetch(
            "https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt",
            'sd-v1-4.ckpt')
        anything_ckpt = self._fetch(
            "https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.0-pruned.ckpt",
            'anything-v4.0-pruned.ckpt')
        sketch_weights = self._fetch(
            hf_adapter + "t2iadapter_sketch_sd14v1.pth", 't2iadapter_sketch_sd14v1.pth')
        pose_weights = self._fetch(
            hf_adapter + "t2iadapter_keypose_sd14v1.pth", 't2iadapter_keypose_sd14v1.pth')
        seg_weights = self._fetch(
            hf_adapter + "t2iadapter_seg_sd14v1.pth", 't2iadapter_seg_sd14v1.pth')
        pidinet_weights = self._fetch(
            model_path + "table5_pidinet.pth", 'table5_pidinet.pth')

        self.model = load_model_from_config(config, sd_ckpt).to(device)
        self.model_anything = load_model_from_config(config, anything_ckpt).to(device)

        self.model_ad_sketch = self._load_adapter(sketch_weights, device)

        # Edge network used by process_sketch in 'Image' mode. It has to live
        # on the instance: as a local it was unreachable from process_sketch.
        net_G = pidinet()
        ckp = torch.load(pidinet_weights, map_location='cpu')['state_dict']
        net_G.load_state_dict({k.replace('module.', ''): v for k, v in ckp.items()})
        net_G.to(device)
        self.net_G = net_G

        self.sampler = PLMSSampler(self.model)
        self.sampler_anything = PLMSSampler(self.model_anything)

        self.model_ad_pose = self._load_adapter(pose_weights, device, cin=int(3 * 64))
        self.model_ad_seg = self._load_adapter(seg_weights, device, cin=int(3 * 64))

    # ------------------------------------------------------------------
    # shared inference helpers
    # ------------------------------------------------------------------
    def _pick_base(self, base_model):
        """Return ``(model, sampler)`` for the requested base checkpoint name."""
        if base_model == 'sd-v1-4.ckpt':
            return self.model, self.sampler
        return self.model_anything, self.sampler_anything

    @staticmethod
    def _begin(input_img, fix_sample, con_strength):
        """Seed if requested; return the 512x512 input and the step cutoff.

        ``con_strength`` is converted to ``int((1 - con_strength) * 50)``,
        the value handed to the sampler. ``fix_sample`` is compared against
        the string ``'True'``.
        """
        cutoff = int((1 - con_strength) * 50)
        if fix_sample == 'True':
            seed_everything(42)
        return cv2.resize(input_img, (512, 512)), cutoff

    def _generate(self, model, sampler, adapter, cond, prompt, neg_prompt,
                  scale, con_strength, mode):
        """Run adapter-guided PLMS sampling and return an HxWx3 uint8 image.

        Everything runs under ``torch.no_grad()`` so no autograd graph is
        kept for the text encoder, the adapter or the decoder.
        """
        device = self._CUDA
        with torch.no_grad():
            c = model.get_learned_conditioning([prompt])
            nc = model.get_learned_conditioning([neg_prompt])
            # extract condition features
            features_adapter = adapter(cond.to(device))
            # sampling
            samples_ddim, _ = sampler.sample(S=50,
                                             conditioning=c,
                                             batch_size=1,
                                             shape=[4, 64, 64],
                                             verbose=False,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=nc,
                                             eta=0.0,
                                             x_T=None,
                                             features_adapter1=features_adapter,
                                             mode=mode,
                                             con_strength=con_strength)
            decoded = model.decode_first_stage(samples_ddim)
            # map from [-1, 1] to [0, 1]
            decoded = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
            frame = decoded.permute(0, 2, 3, 1).cpu().numpy()[0]
        return (255. * frame).astype(np.uint8)

    # ------------------------------------------------------------------
    # public entry points
    # ------------------------------------------------------------------
    def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
        """Generate an image guided by a sketch or by edges of a photo.

        Args:
            input_img: Input image array; resized to 512x512.
            type_in: ``'Sketch'`` (input is already a sketch) or ``'Image'``
                (edges are extracted with the edge network first).
            color_back: ``'White'`` inverts a sketch (``255 - im``) before use.
            prompt / neg_prompt: Positive and negative text prompts.
            fix_sample: The string ``'True'`` fixes the random seed to 42.
            scale: Unconditional guidance scale passed to the sampler.
            con_strength: Control strength; see ``_begin`` for the conversion.
            base_model: ``'sd-v1-4.ckpt'`` selects that model; any other
                value selects the second ("anything") model.

        Returns:
            list: ``[edge_image, generated_image]``.

        Raises:
            ValueError: If ``type_in`` is neither ``'Sketch'`` nor ``'Image'``.
        """
        if type_in not in ('Sketch', 'Image'):
            raise ValueError(
                f"type_in must be 'Sketch' or 'Image', got {type_in!r}")

        model, sampler = self._pick_base(base_model)
        im, cutoff = self._begin(input_img, fix_sample, con_strength)

        if type_in == 'Sketch':
            if color_back == 'White':
                im = 255 - im
            im_edge = im.copy()
            # keep only the first channel, as a 1x1xHxW tensor in [0, 1]
            im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
            im = (im > 0.5).float()
        else:
            im = img2tensor(im).unsqueeze(0) / 255.
            with torch.no_grad():
                im = self.net_G(im.to(self._CUDA))[-1]
            im = (im > 0.5).float()
            im_edge = tensor2img(im)

        result = self._generate(model, sampler, self.model_ad_sketch, im,
                                prompt, neg_prompt, scale, cutoff, 'sketch')
        return [im_edge, result]

    def process_pose(self, input_img, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
        """Generate an image guided by the poses detected in ``input_img``.

        The resized input is run through the detector and the pose model,
        the resulting keypoints are rendered with ``imshow_keypoints`` and
        that rendering conditions the keypose adapter. Parameters have the
        same meaning as in ``process_sketch``.

        Returns:
            list: ``[pose_image, generated_image]``, where ``pose_image`` is
            the rendering with its channel order reversed, as uint8.
        """
        det_cat_id = 1
        bbox_thr = 0.2

        model, sampler = self._pick_base(base_model)
        image, cutoff = self._begin(input_img, fix_sample, con_strength)

        mmdet_results = inference_detector(self.det_model, image)
        # keep the person class bounding boxes.
        person_results = process_mmdet_results(mmdet_results, det_cat_id)
        dataset = self.pose_model.cfg.data['test']['type']
        pose_results, _ = inference_top_down_pose_model(
            self.pose_model,
            image,
            person_results,
            bbox_thr=bbox_thr,
            format='xyxy',
            dataset=dataset,
            dataset_info=None,
            return_heatmap=False,
            outputs=None)

        # render the detected poses
        im_pose = imshow_keypoints(
            image,
            pose_results,
            skeleton=skeleton,
            pose_kpt_color=pose_kpt_color,
            pose_link_color=pose_link_color,
            radius=2,
            thickness=2)
        im_pose = cv2.resize(im_pose, (512, 512))

        pose = img2tensor(im_pose, bgr2rgb=True, float32=True) / 255.
        pose = pose.unsqueeze(0)

        result = self._generate(model, sampler, self.model_ad_pose, pose,
                                prompt, neg_prompt, scale, cutoff, 'sketch')
        return [im_pose[:, :, ::-1].astype(np.uint8), result]

    def process_seg(self, input_img, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
        """Generate an image guided by a segmentation map.

        Parameters have the same meaning as in ``process_sketch``;
        ``input_img`` is used directly as the conditioning mask.

        Returns:
            list: ``[mask_image, generated_image]``.
        """
        model, sampler = self._pick_base(base_model)
        im, cutoff = self._begin(input_img, fix_sample, con_strength)

        mask = img2tensor(im, bgr2rgb=True, float32=True) / 255.
        mask = mask.unsqueeze(0)
        im_mask = tensor2img(mask)

        result = self._generate(model, sampler, self.model_ad_seg, mask,
                                prompt, neg_prompt, scale, cutoff, 'mask')
        return [im_mask, result]