Spaces:
Runtime error
Runtime error
Commit
·
98fc92a
1
Parent(s):
6a8f88e
Add mmpose
Browse files
model.py
CHANGED
|
@@ -41,6 +41,73 @@ sys.path.append('T2I-Adapter')
|
|
| 41 |
config_path = 'https://github.com/TencentARC/T2I-Adapter/raw/main/configs/stable-diffusion/'
|
| 42 |
model_path = 'https://github.com/TencentARC/T2I-Adapter/raw/main/models/'
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
def load_model_from_config(config, ckpt, verbose=False):
|
| 45 |
print(f"Loading model from {ckpt}")
|
| 46 |
pl_sd = torch.load(ckpt, map_location="cpu")
|
|
@@ -71,10 +138,36 @@ class Model:
|
|
| 71 |
self.device = torch.device(
|
| 72 |
'cuda:0' if torch.cuda.is_available() else 'cpu')
|
| 73 |
self.model_dir = pathlib.Path(model_dir)
|
| 74 |
-
|
| 75 |
self.download_models()
|
| 76 |
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
def download_models(self) -> None:
|
| 80 |
self.model_dir.mkdir(exist_ok=True, parents=True)
|
|
@@ -206,16 +299,49 @@ class Model:
|
|
| 206 |
seed_everything(42)
|
| 207 |
|
| 208 |
im = cv2.resize(input_img,(512,512))
|
| 209 |
-
pose = img2tensor(im, bgr2rgb=True, float32=True)/255.
|
| 210 |
-
pose = pose.unsqueeze(0)
|
| 211 |
|
| 212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
c = model.get_learned_conditioning([prompt])
|
| 215 |
nc = model.get_learned_conditioning([neg_prompt])
|
| 216 |
|
| 217 |
with torch.no_grad():
|
| 218 |
# extract condition features
|
|
|
|
|
|
|
| 219 |
features_adapter = self.model_ad_pose(pose.to(device))
|
| 220 |
|
| 221 |
shape = [4, 64, 64]
|
|
|
|
| 41 |
# Base URL of the configs/stable-diffusion/ directory in the TencentARC/T2I-Adapter GitHub repository.
config_path = 'https://github.com/TencentARC/T2I-Adapter/raw/main/configs/stable-diffusion/'
|
| 42 |
# Base URL of the models/ directory in the TencentARC/T2I-Adapter GitHub repository;
# file names are appended to it to build download URLs.
model_path = 'https://github.com/TencentARC/T2I-Adapter/raw/main/models/'
|
| 43 |
|
| 44 |
+
|
| 45 |
+
def imshow_keypoints(img,
                     pose_result,
                     skeleton=None,
                     kpt_score_thr=0.1,
                     pose_kpt_color=None,
                     pose_link_color=None,
                     radius=4,
                     thickness=1):
    """Draw pose keypoints and skeleton links on a blank canvas.

    Only the shape of ``img`` is used: drawing happens on a new all-zero
    array of that shape, so the input pixels never appear in the result and
    ``img`` itself is left unmodified.

    Args:
        img (ndarray): Reference image; must have exactly three dimensions
            (height, width, channels).
        pose_result (list[dict]): The poses to draw. Each element must
            provide a ``'keypoints'`` entry holding K keypoints as a Kx3
            array-like whose rows are ``x, y, score``. Only the first two
            poses are drawn; any further ones are ignored.
        skeleton (list, optional): Pairs of keypoint indices to connect
            with a line. If None, no links are drawn.
        kpt_score_thr (float, optional): Minimum score of keypoints
            to be shown. Default: 0.1.
        pose_kpt_color (np.array[Nx3]): Color of N keypoints. If None,
            the keypoints will not be drawn. A single entry may be None
            to skip just that keypoint.
        pose_link_color (np.array[Mx3]): Color of M links. If None, the
            links will not be drawn. A single entry may be None to skip
            just that link.
        radius (int): Radius of the keypoint circles. Default: 4.
        thickness (int): Thickness of lines. Default: 1.

    Returns:
        ndarray: The canvas with the same shape as ``img`` (NumPy's default
        float dtype), black except for the drawn keypoints and links.
    """
    img_h, img_w, _ = img.shape
    # Rebind to a fresh black canvas; the caller's array is not touched.
    img = np.zeros(img.shape)

    for idx, pose in enumerate(pose_result):
        if idx > 1:
            # At most two poses are rendered.
            break

        # np.asarray avoids a copy when possible but, unlike
        # np.array(..., copy=False) on NumPy >= 2, does not raise when a
        # copy is unavoidable (e.g. keypoints given as nested lists).
        kpts = np.asarray(pose['keypoints'])

        # draw each point on image
        if pose_kpt_color is not None:
            assert len(pose_kpt_color) == len(kpts)

            for kpt, kpt_color in zip(kpts, pose_kpt_color):
                x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]

                if kpt_score < kpt_score_thr or kpt_color is None:
                    # skip the point that should not be drawn
                    continue

                color = tuple(int(c) for c in kpt_color)
                cv2.circle(img, (x_coord, y_coord), radius, color, -1)

        # draw links
        if skeleton is not None and pose_link_color is not None:
            assert len(pose_link_color) == len(skeleton)

            for sk, link_color in zip(skeleton, pose_link_color):
                pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
                pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))

                # Both endpoints must lie strictly inside the image.
                inside = (0 < pos1[0] < img_w and 0 < pos1[1] < img_h
                          and 0 < pos2[0] < img_w and 0 < pos2[1] < img_h)
                low_score = (kpts[sk[0], 2] < kpt_score_thr
                             or kpts[sk[1], 2] < kpt_score_thr)

                if not inside or low_score or link_color is None:
                    # skip the link that should not be drawn
                    continue

                color = tuple(int(c) for c in link_color)
                cv2.line(img, pos1, pos2, color, thickness=thickness)

    return img
|
| 109 |
+
|
| 110 |
+
|
| 111 |
def load_model_from_config(config, ckpt, verbose=False):
|
| 112 |
print(f"Loading model from {ckpt}")
|
| 113 |
pl_sd = torch.load(ckpt, map_location="cpu")
|
|
|
|
| 138 |
self.device = torch.device(
|
| 139 |
'cuda:0' if torch.cuda.is_available() else 'cpu')
|
| 140 |
self.model_dir = pathlib.Path(model_dir)
|
| 141 |
+
self.download_pose_models()
|
| 142 |
self.download_models()
|
| 143 |
|
| 144 |
|
| 145 |
+
def download_pose_models(self) -> None:
    """Download the detector/pose configs and checkpoints and build both models.

    Four files are fetched with ``wget`` into the relative ``models/``
    directory: the Faster R-CNN config and checkpoint, and the HRNet-W48
    config and checkpoint. The detector and pose estimator are then
    initialised on ``self.device`` and stored as ``self.det_model`` and
    ``self.pose_model``. The detection category id and bounding-box score
    threshold are stored as ``self.det_cat_id`` and ``self.bbox_thr``.

    Raises:
        subprocess.CalledProcessError: If a ``wget`` call exits with a
            non-zero status.
    """
    ## mmpose
    # Follow the device chosen in __init__ (falls back to CPU when CUDA is
    # unavailable) instead of hard-coding "cuda".
    device = str(self.device)

    det_config = 'models/faster_rcnn_r50_fpn_coco.py'
    det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
    # The config is saved under the same name it is loaded from below.
    # NOTE(review): the remote file name is assumed to be
    # 'hrnet_w48_coco_256x192.py' (previously requested as 'rnet_...') --
    # confirm against the T2I-Adapter repository.
    pose_config = 'models/hrnet_w48_coco_256x192.py'
    pose_checkpoint = 'models/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'

    downloads = (
        (model_path + 'faster_rcnn_r50_fpn_coco.py', det_config),
        ('https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth',
         det_checkpoint),
        (model_path + 'hrnet_w48_coco_256x192.py', pose_config),
        ('https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth',
         pose_checkpoint),
    )

    # wget -O does not create missing parent directories.
    pathlib.Path('models').mkdir(parents=True, exist_ok=True)
    for url, destination in downloads:
        # check=True surfaces a failed download here rather than as a
        # confusing config/checkpoint loading error further down.
        subprocess.run(['wget', url, '-O', destination], check=True)

    # Kept on the instance: as plain locals these values were discarded
    # when the method returned.
    # NOTE(review): the pose inference code refers to `det_cat_id` and
    # `bbox_thr`; it should read these attributes.
    self.det_cat_id = 1
    self.bbox_thr = 0.2

    ## detector
    det_config_mmcv = mmcv.Config.fromfile(det_config)
    self.det_model = init_detector(det_config_mmcv, det_checkpoint, device=device)
    pose_config_mmcv = mmcv.Config.fromfile(pose_config)
    self.pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=device)
|
| 171 |
|
| 172 |
def download_models(self) -> None:
|
| 173 |
self.model_dir.mkdir(exist_ok=True, parents=True)
|
|
|
|
| 299 |
seed_everything(42)
|
| 300 |
|
| 301 |
im = cv2.resize(input_img,(512,512))
|
|
|
|
|
|
|
| 302 |
|
| 303 |
+
image = im.copy()
|
| 304 |
+
im = img2tensor(im).unsqueeze(0)/255.
|
| 305 |
+
mmdet_results = inference_detector(det_model, image)
|
| 306 |
+
# keep the person class bounding boxes.
|
| 307 |
+
person_results = process_mmdet_results(mmdet_results, det_cat_id)
|
| 308 |
+
|
| 309 |
+
# optional
|
| 310 |
+
return_heatmap = False
|
| 311 |
+
dataset = pose_model.cfg.data['test']['type']
|
| 312 |
+
|
| 313 |
+
# e.g. use ('backbone', ) to return backbone feature
|
| 314 |
+
output_layer_names = None
|
| 315 |
+
pose_results, returned_outputs = inference_top_down_pose_model(
|
| 316 |
+
pose_model,
|
| 317 |
+
image,
|
| 318 |
+
person_results,
|
| 319 |
+
bbox_thr=bbox_thr,
|
| 320 |
+
format='xyxy',
|
| 321 |
+
dataset=dataset,
|
| 322 |
+
dataset_info=None,
|
| 323 |
+
return_heatmap=return_heatmap,
|
| 324 |
+
outputs=output_layer_names)
|
| 325 |
+
|
| 326 |
+
# show the results
|
| 327 |
+
im_pose = imshow_keypoints(
|
| 328 |
+
image,
|
| 329 |
+
pose_results,
|
| 330 |
+
skeleton=skeleton,
|
| 331 |
+
pose_kpt_color=pose_kpt_color,
|
| 332 |
+
pose_link_color=pose_link_color,
|
| 333 |
+
radius=2,
|
| 334 |
+
thickness=2)
|
| 335 |
+
|
| 336 |
+
im_pose = cv2.resize(im_pose,(512,512))
|
| 337 |
|
| 338 |
c = model.get_learned_conditioning([prompt])
|
| 339 |
nc = model.get_learned_conditioning([neg_prompt])
|
| 340 |
|
| 341 |
with torch.no_grad():
|
| 342 |
# extract condition features
|
| 343 |
+
pose = img2tensor(im_pose, bgr2rgb=True, float32=True)/255.
|
| 344 |
+
pose = pose.unsqueeze(0)
|
| 345 |
features_adapter = self.model_ad_pose(pose.to(device))
|
| 346 |
|
| 347 |
shape = [4, 64, 64]
|