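"""Facial landmark extraction for videos.

Detects faces with RetinaFace (facexlib) and predicts 68-point landmarks with an
AWing FAN alignment model, saving one flattened landmark file per input video.
Intended to be run as a script over a folder of videos, with work distributed
across GPUs via torch.multiprocessing.
"""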
import os
import glob
import time
import argparse
from itertools import cycle

import cv2
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from torch.multiprocessing import Pool, set_start_method

from facexlib.alignment import landmark_98_to_68
from facexlib.detection import init_detection_model
from facexlib.utils import load_file_from_url

from src.face3d.util.my_awing_arch import FAN
from app.config import settings

def init_alignment_model(model_name, half=False, device="cuda", model_rootpath=None):
    """Initialize the AWing FAN alignment model and load its pretrained weights."""
    if model_name == "awing_fan":
        model = FAN(num_modules=4, num_landmarks=98, device=device)
        model_url = "https://huggingface.co/duyv/MC-AI/resolve/main/gfpgan/weights/alignment_WFLW_4HG.pth"
    else:
        raise NotImplementedError(f"{model_name} is not implemented.")

    model_path = load_file_from_url(url=model_url, model_dir="facexlib/weights", progress=True, file_name=None, save_dir=model_rootpath)
    model.load_state_dict(torch.load(model_path, map_location=device)["state_dict"], strict=True)
    model.eval()
    model = model.to(device)
    return model


class KeypointExtractor:
    def __init__(self, device="cuda"):
        # When running inside the webui extension, weights live under the extension folder;
        # otherwise fall back to the configured SadTalker weights directory.
        try:
            import webui  # noqa: F401

            root_path = "extensions/SadTalker/gfpgan/weights"
        except ImportError:
            root_path = os.path.join(settings.DIR_ROOT, "SadTalker", "gfpgan", "weights")

        self.detector = init_alignment_model("awing_fan", device=device, model_rootpath=root_path)
        self.det_net = init_detection_model("retinaface_resnet50", half=False, device=device, model_rootpath=root_path)

    def extract_keypoint(self, images, name=None, info=True):
        if isinstance(images, list):
            keypoints = []
            i_range = tqdm(images, desc="landmark Det:") if info else images

            for image in i_range:
                current_kp = self.extract_keypoint(image)
                # If no face was found in this frame, reuse the previous frame's landmarks.
                if np.mean(current_kp) == -1 and keypoints:
                    keypoints.append(keypoints[-1])
                else:
                    keypoints.append(current_kp[None])

            keypoints = np.concatenate(keypoints, 0)
            if name is not None:
                np.savetxt(os.path.splitext(name)[0] + ".txt", keypoints.reshape(-1))
            return keypoints
        else:
            while True:
                try:
                    with torch.no_grad():
                        # Detect the face first, then run landmark alignment on the cropped region.
                        img = np.array(images)
                        bboxes = self.det_net.detect_faces(images, 0.97)
                        bboxes = bboxes[0]
                        img = img[int(bboxes[1]) : int(bboxes[3]), int(bboxes[0]) : int(bboxes[2]), :]

                        keypoints = landmark_98_to_68(self.detector.get_landmarks(img))

                        # Map the landmarks from crop coordinates back to full-image coordinates.
                        keypoints[:, 0] += int(bboxes[0])
                        keypoints[:, 1] += int(bboxes[1])
                    break
                except RuntimeError as e:
                    if str(e).startswith("CUDA"):
                        print("Warning: out of memory, sleep for 1s")
                        time.sleep(1)
                    else:
                        print(e)
                        break
                except TypeError:
                    print("No face detected in this image")
                    shape = [68, 2]
                    keypoints = -1.0 * np.ones(shape)
                    break
            if name is not None:
                np.savetxt(os.path.splitext(name)[0] + ".txt", keypoints.reshape(-1))
            return keypoints


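# Illustrative use of KeypointExtractor outside the CLI below (the video path is a
# placeholder, not part of this module):
#
#     extractor = KeypointExtractor(device="cuda")
#     frames = read_video("example.mp4")
#     landmarks = extractor.extract_keypoint(frames, name="example.mp4")
#     # -> array of shape (num_frames, 68, 2), also written flattened to example.txt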
def read_video(filename):
    """Read a video file and return its frames as a list of RGB PIL Images."""
    frames = []
    cap = cv2.VideoCapture(filename)
    while cap.isOpened():
        ret, frame = cap.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            frames.append(frame)
        else:
            break
    cap.release()
    return frames


def run(data):
    filename, opt, device = data
    # Pin this worker process to its assigned GPU before the models are built.
    os.environ["CUDA_VISIBLE_DEVICES"] = device
    kp_extractor = KeypointExtractor()
    images = read_video(filename)
    # Keep the parent-folder/file-name pair so the output mirrors the input layout.
    name = filename.split("/")[-2:]
    os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
    kp_extractor.extract_keypoint(images, name=os.path.join(opt.output_dir, name[-2], name[-1]))

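# Example invocation (the script path, directory names, and GPU ids are placeholders):
#
#     python extract_kp_videos.py --input_dir ./videos --output_dir ./landmarks --device_ids 0,1 --workers 4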
if __name__ == "__main__":
    set_start_method("spawn")
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--input_dir", type=str, help="the folder of the input files")
    parser.add_argument("--output_dir", type=str, help="the folder of the output files")
    parser.add_argument("--device_ids", type=str, default="0,1")
    parser.add_argument("--workers", type=int, default=4)

    opt = parser.parse_args()

    # Collect videos with any of the supported extensions (lower- and upper-case).
    VIDEO_EXTENSIONS_LOWERCASE = {"mp4"}
    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})

    filenames = []
    for ext in VIDEO_EXTENSIONS:
        print(f"{opt.input_dir}/*.{ext}")
        filenames.extend(glob.glob(f"{opt.input_dir}/*.{ext}"))
    filenames = sorted(filenames)
    print("Total number of videos:", len(filenames))

    # Round-robin the available GPU ids over the worker processes.
    pool = Pool(opt.workers)
    args_list = cycle([opt])
    device_ids = cycle(opt.device_ids.split(","))
    for _ in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
        pass