|
|
import os |
|
|
import sys |
|
|
import traceback |
|
|
|
|
|
from math import ceil |
|
|
|
|
|
import PIL.Image |
|
|
import torch |
|
|
import distinctipy |
|
|
import matplotlib.pyplot as plt |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import facer |
|
|
import tyro |
|
|
|
|
|
from pixel3dmm import env_paths |
|
|
|
|
|
# Module-level palette of 22 visually distinct RGB colors (values in [0, 1]);
# rng=0 makes the palette deterministic across runs so saved visualizations
# are comparable. Used by viz_results() below.
colors = distinctipy.get_colors(22, rng=0)
|
|
|
|
|
|
|
|
def viz_results(img, seq_classes, n_classes, suppress_plot=False):
    """Render a per-pixel class map as an RGB image using the module palette.

    Parameters
    ----------
    img :
        Only ``img.shape[-2]`` and ``img.shape[-1]`` are read; they give the
        output height and width.
    seq_classes :
        Integer class map; indexed as ``seq_classes[0, :, :]``, so the leading
        dimension is a batch axis of size >= 1 (only entry 0 is used).
    n_classes : int
        Class ids 0 .. n_classes-1 are colorized.
        NOTE(review): the module-level ``colors`` palette has 22 entries, so
        n_classes > 22 would raise IndexError — confirm the upstream bound.
    suppress_plot : bool, keyword
        When False, additionally displays the image via matplotlib.

    Returns
    -------
    PIL.Image.Image
        (H, W, 3) uint8 color image; pixels of unlisted classes stay black.
    """
    seg_img = np.zeros([img.shape[-2], img.shape[-1], 3])

    # Historically a set of class ids was excluded here; that list was later
    # emptied (the original code built a long list and immediately replaced it
    # with []), so every class id gets colorized. The dead list was removed.
    bad_indices = []

    for i in range(n_classes):
        if i not in bad_indices:
            seg_img[seq_classes[0, :, :] == i] = np.array(colors[i]) * 255

    if not suppress_plot:
        # BUG FIX: original used ``np.uint(8)`` — a numpy scalar with value 8,
        # not a dtype — so astype() raised whenever plotting was enabled.
        plt.imshow(seg_img.astype(np.uint8))
        plt.show()
    return Image.fromarray(seg_img.astype(np.uint8))
|
|
|
|
|
def get_color_seg(img, seq_classes, n_classes):
    """Build a color-coded segmentation image, leaving excluded classes black.

    Only ``img.shape[-2]`` / ``img.shape[-1]`` are read (output height/width).
    ``seq_classes`` is an integer class map indexed as ``seq_classes[0, :, :]``.
    A fresh deterministic palette of ``n_classes + 1`` colors is generated per
    call (rng=0). Returns a uint8 PIL image of shape (H, W, 3).
    """
    height, width = img.shape[-2], img.shape[-1]
    canvas = np.zeros([height, width, 3])
    palette = distinctipy.get_colors(n_classes + 1, rng=0)

    # Class ids that should remain black in the visualization.
    excluded = {0, 1, 3, 4, 5, 14, 16, 17, 18}

    class_map = seq_classes[0, :, :]
    for class_id in range(n_classes):
        if class_id in excluded:
            continue
        canvas[class_map == class_id] = np.array(palette[class_id]) * 255

    return Image.fromarray(canvas.astype(np.uint8))
|
|
|
|
|
|
|
|
def crop_gt_img(img, seq_classes, n_classes):
    """Black out pixels belonging to unwanted segmentation classes.

    Mutates ``img`` in place (pixels of excluded classes are set to 0) and
    returns a uint8 copy via ``astype``.

    Parameters
    ----------
    img :
        Image array indexable by a boolean (H, W) mask, e.g. shape (H, W, 3).
    seq_classes :
        Integer class map; indexed as ``seq_classes[0, :, :]``, shape (1, H, W).
    n_classes : int
        Class ids 0 .. n_classes-1 are considered.

    Returns
    -------
    numpy.ndarray
        ``img`` cast to uint8, with excluded-class pixels zeroed.
    """
    # Dead code removed: the original also built an unused ``seg_img`` buffer
    # and an unused distinctipy palette here; neither affected the result.

    # Class ids treated as irrelevant (zeroed out). Mirrors the exclusion set
    # used by get_color_seg().
    bad_indices = [
        0,
        1,
        3,
        4,
        5,
        14,
        16,
        17,
        18,
    ]

    for i in range(n_classes):
        if i in bad_indices:
            img[seq_classes[0, :, :] == i] = 0

    return img.astype(np.uint8)
|
|
|
|
|
|
|
|
# Run on GPU when available; all facer models below are placed on this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Models are constructed at module import time (side effect: weights are
# loaded/initialized as soon as this file is imported, before main() runs).
face_detector = facer.face_detector('retinaface/mobilenet', device=device)
face_parser = facer.face_parser('farl/celebm/448', device=device)
|
|
|
|
|
|
|
|
def main(video_name: str):
    """Segment every cropped frame of a preprocessed video with facer.

    Reads frames from ``<PREPROCESSED_DATA>/<video_name>/cropped/`` and, per
    frame, writes:
      * ``seg_og/<frame>.png`` — raw per-pixel class-id map (uint8), where
        untouched (all-zero-logit) pixels get the extra id n_classes + 1;
      * ``seg_non_crop_annotations/color_<frame>.png`` — color visualization
        (best-effort: failures here are swallowed, the raw map is still saved).

    The whole video is skipped when ``seg_og`` already holds one file per
    frame. Frames where detection/parsing raises are skipped after printing
    the traceback.
    """
    out = f'{env_paths.PREPROCESSED_DATA}/{video_name}'
    out_seg = f'{out}/seg_og/'
    out_seg_annot = f'{out}/seg_non_crop_annotations/'
    os.makedirs(out_seg, exist_ok=True)
    os.makedirs(out_seg_annot, exist_ok=True)
    folder = f'{out}/cropped/'

    frames = [f for f in os.listdir(folder) if f.endswith('.png') or f.endswith('.jpg')]
    frames.sort()

    if len(os.listdir(out_seg)) == len(frames):
        print(f'''
<<<<<<<< ALREADY COMPLETED SEGMENTATION FOR {video_name}, SKIPPING >>>>>>>>
''')
        return

    batch_size = 1

    # NOTE: with batch_size == 1 the outer loop visits every frame; the
    # integer division would drop trailing frames for batch_size > 1.
    for i in range(len(frames) // batch_size):
        image_stack = []
        frame_stack = []
        for j in range(batch_size):
            file = frames[i * batch_size + j]
            stem = file[:-4]  # file name without its 4-char extension

            # BUG FIX: the original checked f'color_{file}.png', i.e.
            # 'color_<frame>.png.png', which never matches the file written
            # below ('color_<frame>.png'), so finished frames were always
            # re-processed. Check against the name actually saved.
            if os.path.exists(f'{out_seg_annot}/color_{stem}.png'):
                print('DONE')
                continue
            img = Image.open(f'{folder}/{file}')

            # Drop any alpha channel, convert HWC -> BCHW on the device.
            image = facer.hwc2bchw(torch.from_numpy(np.array(img)[..., :3])).to(device=device)
            image_stack.append(image)
            frame_stack.append(stem)
        # (Dead code removed: the original also tracked og_size /
        # original_shapes / og_shape_batch, none of which were ever used.)

        for batch_idx in range(ceil(len(image_stack) / batch_size)):
            image_batch = torch.cat(image_stack[batch_idx * batch_size:(batch_idx + 1) * batch_size], dim=0)
            frame_idx_batch = frame_stack[batch_idx * batch_size:(batch_idx + 1) * batch_size]

            try:
                with torch.inference_mode():
                    faces = face_detector(image_batch)
                    torch.cuda.empty_cache()
                    faces = face_parser(image_batch, faces, bbox_scale_factor=1.25)
                    torch.cuda.empty_cache()

                    seg_logits = faces['seg']['logits']
                    # Pixels whose logits are all exactly zero were never
                    # written by the parser -> treat them as background.
                    back_ground = torch.all(seg_logits == 0, dim=1, keepdim=True).detach().squeeze(1).cpu().numpy()
                    seg_probs = seg_logits.softmax(dim=1)
                    seg_classes = seg_probs.argmax(dim=1).detach().cpu().numpy().astype(np.uint8)
                    # Background gets an id one past the real class range.
                    seg_classes[back_ground] = seg_probs.shape[1] + 1

                    for _iidx in range(seg_probs.shape[0]):
                        frame = frame_idx_batch[_iidx]
                        # Map detection index back to its source image slot.
                        iidx = faces['image_ids'][_iidx].item()
                        try:
                            I_color = viz_results(
                                image_batch[iidx:iidx + 1],
                                seq_classes=seg_classes[_iidx:_iidx + 1],
                                n_classes=seg_probs.shape[1] + 1,
                                suppress_plot=True,
                            )
                            I_color.save(f'{out_seg_annot}/color_{frame}.png')
                        except Exception:
                            # Best-effort visualization only; the raw class
                            # map below is still written if coloring fails.
                            pass
                        I = Image.fromarray(seg_classes[_iidx])
                        I.save(f'{out_seg}/{frame}.png')
                torch.cuda.empty_cache()
            except Exception:
                traceback.print_exc()
                continue
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # tyro builds the CLI from main()'s signature: --video-name is required.
    tyro.cli(main)
|
|
|
|
|
|