Spaces:
Build error
Build error
| # Modified from: | |
| # https://github.com/anibali/pytorch-stacked-hourglass | |
| # https://github.com/bearpaw/pytorch-pose | |
| import matplotlib as mpl | |
| mpl.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| # import stacked_hourglass.datasets.utils_stanext as utils_stanext | |
| # COLORS, labels = utils_stanext.load_keypoint_labels_and_colours() | |
# Default per-keypoint colours (hex strings, 24 entries) used as the `colors`
# argument of the save_* helpers; runs of identical colours are written with
# list repetition.
COLORS = (
    ['#d82400'] * 3
    + ['#fcfc00'] * 3
    + ['#48b455'] * 3
    + ['#0090aa'] * 3
    + ['#d848ff'] * 2
    + ['#fc90aa', '#006caa']
    + ['#d89000'] * 2
    + ['#fc90aa', '#006caa']
    + ['#ededed'] * 2
    + ['#a9d08e'] * 2
)
# Per-channel RGB statistics (presumably of the training images). Only the mean
# is used arithmetically in this file: it is added back to undo normalisation.
RGB_MEAN = [0.4404, 0.4440, 0.4327]
# Passed through as `rgb_std`, but never applied in this file.
RGB_STD = [0.2458, 0.2410, 0.2468]
def get_img_from_fig(fig, dpi=180):
    """Rasterise a matplotlib figure into an RGB image array.

    Args:
        fig: matplotlib figure to render.
        dpi: resolution passed to ``fig.savefig``.

    Returns:
        uint8 numpy array of shape (H, W, 3) in RGB channel order.
    """
    # Imported here so the function does not depend on a module-level
    # `import io`, which the file header does not provide.
    import io

    # The context manager releases the buffer even if savefig raises.
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png", dpi=dpi)
        # getvalue() returns an independent bytes object, so the array stays
        # valid after the buffer is closed.
        img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    img = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)  # OpenCV decodes to BGR
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def save_input_image_with_keypoints(img, tpts, out_path='./test_input.png', colors=COLORS, rgb_mean=RGB_MEAN, rgb_std=RGB_STD, ratio_in_out=4., threshold=0.3, print_scores=False):
    """Save the (mean-restored) input image with its keypoints drawn on top.

    The caller's tensor is left untouched: the per-channel mean is added back
    on a detached CPU copy.

    Args:
        img: torch tensor of shape (C, H, W), e.g. (3, 256, 256), presumably
            with the per-channel mean subtracted (see ``rgb_mean``).
        tpts: torch tensor of shape (K, 3); each row is (x, y, score). x and y
            are multiplied by ``ratio_in_out`` to get image pixel coordinates.
        out_path: destination file for the rendered figure.
        colors: one matplotlib colour per keypoint index.
        rgb_mean: per-channel mean added back to ``img`` before plotting.
        rgb_std: not applied (only the mean is restored); kept for interface
            compatibility.
        ratio_in_out: scale factor from keypoint coordinates to image pixels.
        threshold: keypoints with score <= threshold are not drawn.
        print_scores: if True, annotate every drawn keypoint with its score.
    """
    # Undo the mean subtraction on a copy. zip() stops at the shortest input,
    # matching the previous per-channel behaviour.
    img_denorm = img.detach().cpu().clone()
    for channel, mean, _std in zip(img_denorm, rgb_mean, rgb_std):
        channel.add_(mean)
    img_np = img_denorm.numpy().transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)

    fig, ax = plt.subplots()
    try:
        ax.imshow(img_np)
        # Borderless figure so the saved image contains only the picture.
        ax.set_axis_off()
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        ax.margins(0, 0)
        # Plot all keypoints whose score exceeds the threshold.
        for idx, (x, y, v) in enumerate(tpts):
            if v > threshold:
                x = int(x * ratio_in_out)
                y = int(y * ratio_in_out)
                ax.scatter([x], [y], c=[colors[idx]], marker="x", s=50)
                if print_scores:
                    txt = '{:2.2f}'.format(v.item())
                    ax.annotate(txt, (x, y))
        fig.savefig(out_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    return
def save_input_image(img, out_path, colors=COLORS, rgb_mean=RGB_MEAN, rgb_std=RGB_STD):
    """Save the input image tensor, with its per-channel mean added back.

    The caller's tensor is not modified: denormalisation happens on a
    detached CPU copy.

    Args:
        img: torch tensor of shape (C, H, W), presumably with the per-channel
            mean subtracted.
        out_path: destination image file.
        colors: unused; kept for interface compatibility.
        rgb_mean: per-channel mean added back before saving.
        rgb_std: not applied (only the mean is restored); kept for interface
            compatibility.
    """
    img_denorm = img.detach().cpu().clone()
    # zip() stops at the shortest input, matching the previous behaviour.
    for channel, mean, _std in zip(img_denorm, rgb_mean, rgb_std):
        channel.add_(mean)
    plt.imsave(out_path, img_denorm.numpy().transpose(1, 2, 0))  # (C, H, W) -> (H, W, C)
    return
| ###################################################################### | |
def _sample_colormap(cmap, vmin, vmax, n_samples):
    """Return RGBA tuples of ``cmap`` at the values 1..n_samples, normalised to [vmin, vmax]."""
    mappable = mpl.cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=vmin, vmax=vmax), cmap=cmap)
    return [mappable.to_rgba(value) for value in range(1, n_samples + 1)]


def get_bodypart_colors():
    """Build the RGBA palettes used to colour part segmentations.

    Returns:
        (body_cols, head_cols, torso_cols): three lists of 8 RGBA tuples each.
        - body_cols samples ``gist_rainbow`` evenly over the 8 body parts.
        - head_cols samples ``Blues`` for 5 head parts.
        - torso_cols samples ``Greens`` for 2 torso parts.
        head_cols and torso_cols are also 8 entries long; the entries beyond
        the part count map to values >= vmax, so they all equal the top colour
        of the colormap.
    """
    n_body = 8
    n_head = 5
    n_torso = 2
    body_cols = _sample_colormap(mpl.cm.gist_rainbow, 1, n_body, n_body)
    # The sequential maps are normalised to [0, n_parts + 1], so the real
    # parts (values 1..n_parts) avoid both extreme ends of the colormap.
    head_cols = _sample_colormap(mpl.cm.Blues, 0, n_head + 1, n_body)
    torso_cols = _sample_colormap(mpl.cm.Greens, 0, n_torso + 1, n_body)
    return body_cols, head_cols, torso_cols
# Module-level palettes shared by the part-segmentation renderers.
(body_cols, head_cols, torso_cols) = get_bodypart_colors()

# Channel ranges [start, end) of each part group in the part-segmentation
# tensor's channel axis: 8 full-body, 5 head and 2 torso channels.
tbp_dict = dict(
    full_body=[0, 8],
    head=[8, 13],
    torso=[13, 15],
)
def save_image_with_part_segmentation(partseg_big, seg_big, input_image_np, ind_img, out_path_seg=None, out_path_seg_overlay=None, thr=0.3):
    """Render the predicted part segmentation of one batch element and save it.

    Colouring rules:
    - Each pixel gets the colour of its arg-max full-body part.
    - Pixels whose full-body label is 1 are coloured by the arg-max over the
      head channels instead.
    - Pixels whose full-body label is 2 are coloured by the arg-max over the
      torso channels instead.
    - Pixels whose ``seg_big`` softmax channel-1 probability is below ``thr``
      are set to black.

    Args:
        partseg_big: tensor of part-segmentation logits indexed as
            ``[batch, channel, H, W]``; channel groups are given by ``tbp_dict``.
        seg_big: tensor of segmentation logits indexed as
            ``[batch, channel, H, W]``; softmax channel 1 (presumably
            foreground) is thresholded with ``thr``.
        input_image_np: numpy image of shape (H, W, 3) used for the overlay
            (assumed float in [0, 1] -- TODO confirm).
        ind_img: index of the batch element to render.
        out_path_seg: if not None, where to save the coloured segmentation.
        out_path_seg_overlay: if not None, where to save a 50/50 blend of the
            segmentation and ``input_image_np``; thresholded-out pixels show
            the input image unchanged.
        thr: probability threshold on ``seg_big`` channel 1.

    The tensors may live on any device and may require grad: every mask is
    converted to a CPU numpy array before it indexes a numpy image.
    """
    soft_max = torch.nn.Softmax(dim=0)

    def _part_labels(part):
        # Per-pixel arg-max over the channels of one part group, as CPU numpy.
        start, end = tbp_dict[part]
        probs = soft_max(partseg_big[ind_img, start:end, :, :])
        return probs.max(dim=0)[1].detach().cpu().numpy()

    body_labels = _part_labels('full_body')
    head_labels = _part_labels('head')
    torso_labels = _part_labels('torso')

    # Computed once and reused for both outputs.
    fg_prob = soft_max(seg_big[ind_img, :, :, :])[1, :, :]
    background = (fg_prob < thr).detach().cpu().numpy()

    partseg_image = np.zeros(body_labels.shape + (3,))
    # Head sub-parts, restricted to pixels whose full-body label is 1.
    for ind_sp in range(0, 5):
        mask = (body_labels == 1) & (head_labels == ind_sp)
        partseg_image[mask, :] = head_cols[ind_sp][0:3]
    # Torso sub-parts, restricted to pixels whose full-body label is 2.
    for ind_sp in range(0, 2):
        mask = (body_labels == 2) & (torso_labels == ind_sp)
        partseg_image[mask, :] = torso_cols[ind_sp][0:3]
    # Remaining full-body parts; labels 1 (head) and 2 (torso) were drawn above.
    for ind_sp in range(0, 8):
        if ind_sp not in (1, 2):
            partseg_image[body_labels == ind_sp, :] = body_cols[ind_sp][0:3]
    partseg_image[background, :] = 0

    if out_path_seg is not None:
        plt.imsave(out_path_seg, partseg_image)
    if out_path_seg_overlay is not None:
        # Thresholded-out pixels take the input colour, so the 50/50 blend
        # leaves them equal to the input image.
        partseg_image[background, :] = input_image_np[background, :]
        im_masked_partseg = cv2.addWeighted(input_image_np.astype(np.float32), 0.5, partseg_image.astype(np.float32), 0.5, 0)
        plt.imsave(out_path_seg_overlay, im_masked_partseg)
    return
def save_image_with_part_segmentation_from_gt_annotation(partseg_annots, out_path, ind_img=0):
    """Render the ground-truth part-segmentation labels of one batch element and save them.

    Args:
        partseg_annots: label maps indexed as ``[batch, group, H, W]``,
            presumably shaped (bs, 3, H, W). Group 0 is coloured with labels
            0-7 (body_cols), group 1 with labels 0-4 (head_cols) and group 2
            with labels 0-1 (torso_cols). May be a torch tensor on any device
            or a numpy array.
        out_path: destination image file.
        ind_img: index of the batch element to render.

    Groups are painted in the order body, head, torso, so later groups
    overwrite earlier ones wherever their labels match.
    NOTE(review): label 0 is included for head and torso, so if those maps
    use 0 for "not this part" they repaint every such pixel -- confirm the
    annotation encoding.
    """
    annots = partseg_annots[ind_img, :, :, :]
    if torch.is_tensor(annots):
        # Numpy boolean indexing needs CPU data.
        annots = annots.detach().cpu().numpy()
    # Output size follows the annotation's spatial size instead of a fixed 256x256.
    partseg_image = np.zeros(tuple(annots.shape[-2:]) + (3,))
    for ind_sp in range(0, 8):
        partseg_image[annots[0, :, :] == ind_sp, :] = body_cols[ind_sp][0:3]
    for ind_sp in range(0, 5):
        partseg_image[annots[1, :, :] == ind_sp, :] = head_cols[ind_sp][0:3]
    for ind_sp in range(0, 2):
        partseg_image[annots[2, :, :] == ind_sp, :] = torso_cols[ind_sp][0:3]
    plt.imsave(out_path, partseg_image.astype(np.float32))
    return
def save_image_from_prepared_partseg(partseg_init, out_path):
    """Colour a per-channel part score map by its per-pixel arg-max and save it.

    Channel-to-colour mapping:
        channel 0      -> head_cols[0]
        channels 1..6  -> body_cols[2..7]
        channels 7..10 -> head_cols[1..4]
    Pixels whose scores sum to zero across all channels are left black.

    Args:
        partseg_init: numpy array of shape (H, W, C), presumably
            (256, 256, 11).
        out_path: destination image file.
    """
    partseg = np.argmax(partseg_init, axis=2)
    # Output size follows the input's spatial size instead of a fixed 256x256.
    partseg_image = np.zeros(partseg.shape + (3,))
    for ind in range(partseg_init.shape[2]):
        if ind == 0:
            color = head_cols[0]
        elif ind < 7:
            color = body_cols[ind + 1]
        else:  # 7 to 10
            color = head_cols[ind - 6]
        partseg_image[partseg == ind, :] = np.asarray(color[0:3])
    # Pixels with no score in any channel are treated as background.
    partseg_image[partseg_init.sum(axis=2) == 0, :] = 0
    plt.imsave(out_path, partseg_image)
    return