| | import os |
| | os.environ["OMP_NUM_THREADS"] = "1" |
| | import sys |
| | import glob |
| | import cv2 |
| | import pickle |
| | import tqdm |
| | import numpy as np |
| | import mediapipe as mp |
| | from utils.commons.multiprocess_utils import multiprocess_run_tqdm |
| | from utils.commons.os_utils import multiprocess_glob |
| | from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker |
| | import warnings |
| | import traceback |
| |
|
# Silence all library warnings (mediapipe/cv2 are noisy during batch extraction).
warnings.filterwarnings('ignore')
| |
|
"""
The Face_alignment-based lm68 has been deprecated because:
1. Its prediction accuracy around the eye region is very poor.
2. It cannot accurately predict the occluded jawline at large yaw angles, so the
   3DMM GT labels themselves are wrong at large angles, which hurts performance.
We have therefore switched to the mediapipe-based lm68.
"""
| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
def save_file(name, content):
    """Pickle *content* and write it to the file at *name*."""
    with open(name, "wb") as fout:
        pickle.dump(content, fout)
| | |
def load_file(name):
    """Read the file at *name* and return the unpickled object."""
    with open(name, "rb") as fin:
        return pickle.load(fin)
| |
|
| |
|
# Per-process singleton: MediapipeLandmarker is heavy to construct, so each
# worker process builds it lazily on first use and reuses it afterwards.
face_landmarker = None

def extract_landmark_job(video_name, nerf=False):
    """Extract mediapipe lm478 landmarks for one video and save them as .npy.

    Args:
        video_name: path to the source ``.mp4`` file.
        nerf: if True, use the nerf-style layout
            (``.../raw/...mp4`` -> ``.../processed/.../lms_2d.npy``);
            otherwise ``/video/`` -> ``/lms_2d/`` with a ``_lms.npy`` suffix.

    Returns:
        True on success or if the output already exists, False on failure
        (the exception is printed, never propagated, so one bad video does
        not abort the whole batch).
    """
    try:
        if nerf:
            out_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "/lms_2d.npy")
        else:
            out_name = video_name.replace("/video/", "/lms_2d/").replace(".mp4", "_lms.npy")
        if os.path.exists(out_name):
            # Already processed: report success instead of an ambiguous None,
            # so the caller's `res is False` failure count stays correct.
            return True
        # Guard against an empty dirname (relative path with no directory part);
        # exist_ok=True already makes concurrent creation safe, so the old
        # bare `except: pass` is unnecessary.
        out_dir = os.path.dirname(out_name)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        global face_landmarker
        if face_landmarker is None:
            face_landmarker = MediapipeLandmarker()
        img_lm478, vid_lm478 = face_landmarker.extract_lm478_from_video_name(video_name)
        lm478 = face_landmarker.combine_vid_img_lm478_to_lm478(img_lm478, vid_lm478)
        np.save(out_name, lm478)
        return True
    except Exception:
        traceback.print_exc()
        return False
| | |
def out_exist_job(vid_name):
    """Return *vid_name* if its landmark output is still missing, else None."""
    out_name = vid_name.replace("/video/", "/lms_2d/").replace(".mp4", "_lms.npy")
    return None if os.path.exists(out_name) else vid_name
| | |
def get_todo_vid_names(vid_names):
    """Filter *vid_names* down to videos whose landmark file is missing.

    A single-element list (the nerf single-video case) is returned unchanged
    without spawning workers.
    """
    if len(vid_names) == 1:
        return vid_names
    return [
        res
        for _, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=128)
        if res is not None
    ]
| |
|
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--vid_dir", default='nerf')
    parser.add_argument("--ds_name", default='data/raw/videos/May.mp4')
    parser.add_argument("--num_workers", default=2, type=int)
    parser.add_argument("--process_id", default=0, type=int)
    parser.add_argument("--total_process", default=1, type=int)
    parser.add_argument("--reset", action="store_true")
    parser.add_argument("--load_names", action="store_true")

    args = parser.parse_args()
    vid_dir = args.vid_dir
    ds_name = args.ds_name
    load_names = args.load_names

    if ds_name.lower() == 'nerf':
        # Single-video (nerf) mode: --vid_dir is the video file itself.
        vid_names = [vid_dir]
        out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4", "/lms_2d.npy") for video_name in vid_names]
    else:
        # Dataset mode: glob depth depends on each dataset's directory layout.
        if ds_name in ['lrs3_trainval']:
            vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
        elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
            vid_name_pattern = os.path.join(vid_dir, "*.mp4")
        elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
            vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
        elif ds_name in ["RAVDESS", 'VFHQ']:
            vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
        else:
            raise NotImplementedError()

        # Globbing huge datasets is slow, so the sorted name list can be cached
        # on disk (--load_names reads it back instead of re-globbing).
        vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
        if os.path.exists(vid_names_path) and load_names:
            print(f"loading vid names from {vid_names_path}")
            vid_names = load_file(vid_names_path)
        else:
            vid_names = multiprocess_glob(vid_name_pattern)
            vid_names = sorted(vid_names)
            if not load_names:
                print(f"saving vid names to {vid_names_path}")
                save_file(vid_names_path, vid_names)
        out_names = [video_name.replace("/video/", "/lms_2d/").replace(".mp4", "_lms.npy") for video_name in vid_names]

    # Shard the work across --total_process independent launches of this script.
    process_id = args.process_id
    total_process = args.total_process
    if total_process > 1:
        assert process_id <= total_process - 1
        num_samples_per_process = len(vid_names) // total_process
        if process_id == total_process - 1:
            # BUGFIX: was `process_id == total_process`, which can never be true
            # (process_id is asserted <= total_process - 1), so the remainder
            # videos after the last even shard were silently never processed.
            # The last process now takes everything to the end of the list.
            vid_names = vid_names[process_id * num_samples_per_process:]
        else:
            vid_names = vid_names[process_id * num_samples_per_process: (process_id + 1) * num_samples_per_process]

    if not args.reset:
        # Skip videos whose landmark output already exists.
        vid_names = get_todo_vid_names(vid_names)
    print(f"todo videos number: {len(vid_names)}")

    fail_cnt = 0
    job_args = [(vid_name, ds_name == 'nerf') for vid_name in vid_names]
    for (i, res) in multiprocess_run_tqdm(extract_landmark_job, job_args, num_workers=args.num_workers, desc=f"Root {args.process_id}: extracting MP-based landmark2d"):
        if res is False:
            fail_cnt += 1
        print(f"finished {i + 1} / {len(vid_names)} = {(i + 1) / len(vid_names):.4f}, failed {fail_cnt} / {i + 1} = {fail_cnt / (i + 1):.4f}")
        sys.stdout.flush()