diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..3bc438f6e5bd183da128057ba19773ef2fa597fc 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,28 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +models/shape_predictor_68_face_landmarks.dat filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/animations/image3--d6_concat.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/animations/image3--intro_concat.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/animations/IMG_20240410_151931_Bokeh~2--d14_concat.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/animations/IMG_20240410_151931_Bokeh~2--d14.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/animations/input--d10_concat.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/animations/input--d10.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/animations/input--d13_concat.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/animations/input--d13.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/animations/input--d3_concat.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/animations/input--d9_concat.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/animations/input--d9.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/animations/s1--d13_concat.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/animations/s1--d13.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/assets/docs/showcase.gif filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/assets/docs/showcase2.gif filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/assets/examples/driving/d0.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/assets/examples/driving/d10.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/assets/examples/driving/d13.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/assets/examples/driving/d3.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/assets/examples/driving/d6.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/assets/examples/driving/d9.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/pretrained_weights/docs/showcase2.gif filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/uploads/d6.mp4 filter=lfs diff=lfs merge=lfs -text +subprocess/LivePortrait/uploads/intro.mp4 filter=lfs diff=lfs merge=lfs -text diff --git a/__pycache__/main.cpython-310.pyc b/__pycache__/main.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c10b4c77653b22023c878241b13480d4856cc06 Binary files /dev/null and b/__pycache__/main.cpython-310.pyc differ diff --git a/api.db b/api.db new file mode 100644 index 0000000000000000000000000000000000000000..42b702b28d7e96e651e645ead71d08ea63ad6585 Binary files /dev/null and b/api.db differ diff --git a/assests/index.html b/assests/index.html new file mode 100644 index 0000000000000000000000000000000000000000..e63ca4f8b26f2b256eca8313d14581d738e0645d --- /dev/null +++ b/assests/index.html @@ -0,0 +1,238 @@ + + + + + + Quantum Grove + + + + +
+    <!-- Landing page markup lost in extraction; recoverable text content: -->
+    Welcome to QuantumGrove
+    Your gateway to the future of technology.
+    Contact us:
+    Phone: +92 309 2193555
+    Email: info@quantumgrove.tech
+    Address: 95-X Gulberg Greens, Islamabad, Pakistan
+    Website: QUANTUMGROVE.TECH
+ + diff --git a/assests/logo.png b/assests/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..aca706e76d9271d9791fbb5a92f45fafff27be55 Binary files /dev/null and b/assests/logo.png differ diff --git a/gfpgan/weights/GFPGANv1.4.pth b/gfpgan/weights/GFPGANv1.4.pth new file mode 100644 index 0000000000000000000000000000000000000000..afedb5c7e826056840c9cc183f2c6f0186fd17ba --- /dev/null +++ b/gfpgan/weights/GFPGANv1.4.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2cd4703ab14f4d01fd1383a8a8b266f9a5833dacee8e6a79d3bf21a1b6be5ad +size 348632874 diff --git a/gfpgan/weights/alignment_WFLW_4HG.pth b/gfpgan/weights/alignment_WFLW_4HG.pth new file mode 100644 index 0000000000000000000000000000000000000000..3cfeef20123eb2e74b35a4319c2111ef65783c34 --- /dev/null +++ b/gfpgan/weights/alignment_WFLW_4HG.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bbfd137307a4c7debd5c283b9b0ce539466cee417ac0a155e184d857f9f2899c +size 193670248 diff --git a/gfpgan/weights/detection_Resnet50_Final.pth b/gfpgan/weights/detection_Resnet50_Final.pth new file mode 100644 index 0000000000000000000000000000000000000000..16546738ce0a00a9fd47585e0fc52744d31cc117 --- /dev/null +++ b/gfpgan/weights/detection_Resnet50_Final.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d +size 109497761 diff --git a/gfpgan/weights/parsing_parsenet.pth b/gfpgan/weights/parsing_parsenet.pth new file mode 100644 index 0000000000000000000000000000000000000000..1ac2efc50360a79c9905dbac57d9d99cbfbe863c --- /dev/null +++ b/gfpgan/weights/parsing_parsenet.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3d558d8d0e42c20224f13cf5a29c79eba2d59913419f945545d8cf7b72920de2 +size 85331193 diff --git a/main copy.py b/main copy.py new file mode 100644 index 0000000000000000000000000000000000000000..79fd75ca21ab7a47804ab1d49af7630ae709dede --- /dev/null +++ b/main copy.py @@ -0,0 +1,1754 @@ +import os +os.environ['CUDA_LAUNCH_BLOCKING'] = '1' + +web_url ="https://api.quantumgrove.tech" + +from fastapi import FastAPI, UploadFile, File, BackgroundTasks + +import sqlite3 + +def create_database(): + conn = sqlite3.connect('api.db') + cursor = conn.cursor() + cursor.execute('''CREATE TABLE IF NOT EXISTS ban_list ( + id INTEGER PRIMARY KEY, + ip TEXT NOT NULL, + reason TEXT NOT NULL)''') + cursor.execute('''CREATE TABLE IF NOT EXISTS request_log ( + id INTEGER PRIMARY KEY, + ip TEXT NOT NULL, + url TEXT NOT NULL, + method TEXT NOT NULL, + endpoint TEXT NOT NULL, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)''') + conn.commit() + conn.close() + +def insert_ban(ip, reason): + conn = sqlite3.connect('api.db') + cursor = conn.cursor() + cursor.execute('''INSERT INTO ban_list (ip, reason) VALUES (?, ?)''', (ip, reason)) + conn.commit() + conn.close() + +def search_ban(ip): + conn = sqlite3.connect('api.db') + cursor = conn.cursor() + cursor.execute('''SELECT * FROM ban_list WHERE ip = ?''', (ip,)) + rows = cursor.fetchall() + conn.close() + return rows + +def log_request(ip, url, method, endpoint): + conn = sqlite3.connect('api.db') + cursor = conn.cursor() + cursor.execute('''INSERT INTO request_log (ip, url, method, endpoint) VALUES (?, ?, ?, ?)''', (ip, url, method, endpoint)) + conn.commit() + conn.close() + +if not os.path.exists('api.db'): + create_database() + + +def generate_response(response_message , response_status ,uuid_code , age , gender , metadata): + + 
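+    # Assemble the uniform JSON envelope (message, status, UUID, age/gender info, metadata list) returned by every endpoint.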
response_dict = { + "response_message" : response_message, + "response_status" : response_status, + "data" :{ + "UUID" : uuid_code, + "info" : {"age" : age , "gender" : gender}, + "metadata" : metadata + } + } + + return response_dict + + +import numpy as np +import cv2 +import time +import io +import gc +import shutil +import uuid +import torch +import subprocess +import torchvision.transforms as transforms +from scripts.psp import pSp +from argparse import Namespace +import dlib +from scripts.align_all_parallel import align_face +from scripts.augmentations import AgeTransformer +from scripts.common import tensor2im +from torchvision.transforms.functional import normalize +from scripts.basicsr.utils import img2tensor, tensor2img +from scripts.basicsr.utils.misc import get_device +from scripts.facelib.utils.face_restoration_helper import FaceRestoreHelper +from scripts.basicsr.utils.registry import ARCH_REGISTRY +from scripts.basicsr.archs.rrdbnet_arch import RRDBNet +from scripts.basicsr.utils.realesrgan_utils import RealESRGANer +from scripts.facelib.utils.misc import is_gray +import PIL.Image +from scripts.erasescratches.models import Pix2PixHDModel_Mapping +from scripts.erasescratches.options import Options +from scripts.maskscratches import ScratchesDetector +from scripts.util import irregular_hole_synthesize, tensor_to_ndarray +import insightface +from insightface.app import FaceAnalysis +from pydub import AudioSegment +import scripts.RRDBNet_arch as arch +from RealESRGAN import RealESRGAN +from PIL import Image + +device = ("cuda" if torch.cuda.is_available() else "cpu") + +#age model +EXPERIMENT_TYPE = 'ffhq_aging' +model_path = "./models/sam_ffhq_aging.pt" +model_age_slider = None + +EXPERIMENT_DATA_ARGS = { + "ffhq_aging": { + "transform": transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) + } +} +EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE] + + +def load_model_age(): + try: + if not get_model_state_age(): + ckpt = torch.load(model_path, map_location='cpu') + opts = ckpt['opts'] + del ckpt + torch.cuda.empty_cache() + opts['checkpoint_path'] = model_path + opts = Namespace(**opts) + global model_age_slider + model_age_slider = pSp(opts) + del opts + torch.cuda.empty_cache() + model_age_slider.eval() + model_age_slider.cuda() + torch.cuda.empty_cache() + return True + except: + return False + +def unload_model_age(): + try: + global model_age_slider + if model_age_slider is not None: + model_age_slider = None + torch.cuda.empty_cache() + return True + except: + return False + +def get_model_state_age(): + global model_age_slider + if model_age_slider is None: + return False + return True +#model bg +from scripts.BG import BG +bg_model = None + +def load_model_bg(): + try: + global bg_model + if bg_model is None: + bg_model = BG() + return True + except: + return False + +def unload_model_bg(): + try: + global bg_model + if bg_model is not None: + bg_model = None + return True + except: + return False + +def get_model_state_bg(): + global bg_model + if bg_model is None: + return False + return True +#model code_former +model_upsampler = None +model_code_former = None + + +def set_realesrgan_cf(): + use_half = True if torch.cuda.is_available() else False + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=2, + ) + upsampler = RealESRGANer( + scale=2, + model_path="./models/realesrgan/RealESRGAN_x2plus.pth", + model=model, + 
tile=400, + tile_pad=40, + pre_pad=0, + half=use_half + ) + return upsampler + + + +def load_model_cf(): + try: + device = get_device() + global model_code_former + model_code_former = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, + connect_list=['32', '64', '128', '256']).to(device) + ckpt_path = 'models/CodeFormer/codeformer.pth' + checkpoint = torch.load(ckpt_path)['params_ema'] + model_code_former.load_state_dict(checkpoint) + model_code_former.eval() + + global model_upsampler + model_upsampler = set_realesrgan_cf() + return True + except: + return False + + +def get_model_state_cf(): + global model_code_former + if model_code_former is None: + return False + else: + return True + + +def unload_model_cf(): + if get_model_state_cf() == False: + return True + try: + global model_code_former + model_code_former = None + torch.cuda.empty_cache() + + global model_upsampler + model_upsampler = None + torch.cuda.empty_cache() + return True + except: + return False +#model swap +swap_model = None +app_model = None + +def load_model_swap(): + try: + global swap_model + global app_model + if swap_model is None: + swap_model = insightface.model_zoo.get_model("./models/inswapper_128.onnx",download=False,download_zip=False) + app_model = FaceAnalysis(name="buffalo_l") + app_model.prepare(ctx_id=0,det_size=(640,640)) + return True + except: + return False + +def unload_model_swap(): + try: + global swap_model + global app_model + if swap_model is not None: + swap_model = None + if app_model is not None: + app_model = None + return True + except: + return False + +def get_model_state_swap(): + global swap_model + + if swap_model is None: + return False + return True +#model rest +model_scratches_remove = None +model_scratches_remove_detector = None +model_scratches_remove_options = None + + +def load_model_rest(): + try: + + model_path = "./models/zeroscratches/restoration" + global model_scratches_remove_detector + model_scratches_remove_detector = ScratchesDetector('./models/zeroscratches') + gpu_ids = [] + if torch.cuda.is_available(): + gpu_ids = [d for d in range(torch.cuda.device_count())] + global model_scratches_remove_options + model_scratches_remove_options = Options(model_path, gpu_ids) + model_scratches = Pix2PixHDModel_Mapping() + global model_scratches_remove + model_scratches_remove = Pix2PixHDModel_Mapping() + model_scratches_remove.initialize(model_scratches_remove_options) + model_scratches_remove.eval() + return True + except: + return False + + +def get_model_state_rest(): + global model_scratches_remove + if model_scratches_remove is None: + return False + else: + return True + + +def unload_model_rest(): + if get_model_state_rest() == False: + return True + try: + global model_scratches_remove + model_scratches_remove = None + global model_scratches_remove_detector + model_scratches_remove_detector = None + global model_scratches_remove_options + model_scratches_remove_options = None + torch.cuda.empty_cache() + return True + except: + return False +#model talk +from scripts.talker import sad_talker +sad_talker_model = None + +def load_model_talk(): + try: + global sad_talker_model + if sad_talker_model is None: + sad_talker_model = sad_talker() + return True + except: + return False + +def unload_model_talk(): + try: + global sad_talker_model + if sad_talker_model is not None: + sad_talker_model = None + return True + except: + return False + +def get_model_state_talk(): + global sad_talker_model + if sad_talker_model is None: + return 
False + return True +#model sky +from scripts.SKY import SKY +sky_model = None + +def load_model_sky(): + try: + global sky_model + if sky_model is None: + sky_model = SKY() + return True + except: + return False + +def unload_model_sky(): + try: + global sky_model + if sky_model is not None: + sky_model = None + return True + except: + return False + +def get_model_state_sky(): + global sky_model + if sky_model is None: + return False + return True +#model up +model_upscale = None + +def load_model_up(): + try: + model_path_upscale = './models/RRDB_ESRGAN_x4.pth' + global model_upscale + model_upscale = arch.RRDBNet(3, 3, 64, 23, gc=32) + model_upscale.load_state_dict(torch.load(model_path_upscale), strict=True) + + model_upscale = model_upscale.half() + model_upscale.eval() + model_upscale = model_upscale.to(device) + return True + except: + return False + + +def get_model_state_up(): + global model_upscale + if model_upscale is None: + return False + else: + return True + + +def unload_model_up(): + if get_model_state_up() == False: + return True + try: + global model_upscale + model_upscale = None + torch.cuda.empty_cache() + return True + except: + return False +#agr slider helper +def run_alignment(image_path): + predictor = dlib.shape_predictor("./models/shape_predictor_68_face_landmarks.dat") + aligned_image = align_face(filepath=image_path, predictor=predictor) + return aligned_image + +def run_on_batch(inputs, model_age_slider): + result_batch = model_age_slider(inputs.to("cuda").float(), randomize_noise=False, resize=False) + return result_batch + + +#model model_esrgan +model_esrgan = None + +def load_model_esrgan(): + try: + global model_esrgan + model_esrgan = RealESRGAN( torch.device(device), scale=2) + model_esrgan.load_weights('./models/realesrgan/RealESRGAN_x2plus.pth', download=False) + + return True + except: + return False + + +def get_model_state_esrgan(): + global model_esrgan + if model_esrgan is None: + return False + else: + return True + + +def unload_model_esrgan(): + if get_model_state_esrgan() == False: + return True + try: + global model_esrgan + model_esrgan = None + torch.cuda.empty_cache() + return True + except: + return False + + + +def unload_model_all(): + try: + unload_model_age() + unload_model_bg() + unload_model_cf() + unload_model_up() + unload_model_swap() + unload_model_rest() + unload_model_talk() + unload_model_sky() + unload_model_esrgan() + return True + except: + return False + + +def start_http_file_server(): + directory = './outputs' + port = 8010 + if not os.path.exists(directory): + os.makedirs(directory) + cmd = f"python -m http.server {port} --directory {directory}" + return subprocess.Popen(cmd, shell=True) + +import asyncio +from datetime import datetime + + +time_last_call = None + +async def continuous_function(): + while True: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + global time_last_call + if time_last_call is not None: + current_time = datetime.now() + time_difference = current_time - time_last_call + time_elapsed_minutes = time_difference.total_seconds() / 60 + + if time_elapsed_minutes > 1: + gc.collect() + + if time_elapsed_minutes > 5: + model_delete_status = unload_model_all() + + print(f"model delete status = {model_delete_status}") + + time_last_call = None + + await asyncio.sleep(60*1) + +async def start_continuous_function(): + await continuous_function() + + +app = FastAPI() + +from fastapi.responses import HTMLResponse + + +from fastapi import Request +from fastapi.middleware.cors import 
CORSMiddleware
+from fastapi.responses import JSONResponse
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+async def custom_middleware(request: Request, call_next):
+    client_ip = request.client.host
+    request_url = str(request.url)
+    request_method = request.method
+    endpoint = request.url.path
+
+    log_request(client_ip, request_url, request_method, endpoint)
+
+    if len(search_ban(client_ip)) > 0:
+        return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
+
+    # user_agent = request.headers.get("user-agent", "").lower()
+    # print(user_agent)
+    # if "android" not in user_agent:
+    #     insert_ban(client_ip, "Access using unapproved device type")
+    #     return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
+
+    response = await call_next(request)
+
+    # Ban clients probing unknown endpoints or methods.
+    if (response.status_code == 404) or (response.status_code == 405):
+        insert_ban(client_ip, "Access using unapproved method/endpoint")
+        return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
+
+    return response
+
+app.middleware('http')(custom_middleware)
+
+
+@app.get("/", response_class=HTMLResponse)
+def read_index():
+    index_path = './assests/index.html'
+    with open(index_path, 'r') as file:
+        content = file.read()
+    return HTMLResponse(content=content)
+
+from fastapi.responses import FileResponse
+
+@app.get("/favicon.ico")
+async def get_image():
+    image_path = "./assests/logo.png"
+    return FileResponse(image_path, media_type="image/png")
+
+@app.on_event("startup")
+async def startup_event():
+    asyncio.create_task(start_continuous_function())
+
+@app.on_event("shutdown")
+async def shutdown_event():
+    print("Clearing Models")
+    model_status = unload_model_all()
+    print("Model clear status = " + str(model_status))
+    asyncio.get_event_loop().stop()
+    print("Stopping Server")
+
+ALLOWED_IMAGE_FORMATS = ["jpg", "jpeg", "png"]
+ALLOWED_AUDIO_FORMATS = ["wav", "mp3"]
+
+def delete_uuid(uuid_to_delete):
+    path = "./outputs"
+    directories = []
+    for root, dirs, files in os.walk(path):
+        for dir in dirs:
+            directories.append(os.path.join(root, dir))
+    for i in directories:
+        temp = i.split('/')
+        if temp[-1] == uuid_to_delete:
+            shutil.rmtree(i)
+            return True
+    return False
+
+@app.delete("/Delete/{uuid}")
+async def delete_item(uuid: str):
+    gc.collect()
+    try:
+        result = delete_uuid(uuid)
+        # delete_uuid returns True when the directory was found and removed.
+        if not result:
+            return {"error": f"UUID '{uuid}' does not exist"}
+
+        return {"message": f"UUID '{uuid}' deleted"}
+
+    except:
+        return {"error": "unexpected error"}
+
+
+
+@app.put("/ClearModel_ALL")
+async def clearModel_ALL():
+    try:
+        model_delete_status = unload_model_all()
+
+        if model_delete_status:
+            global time_last_call
+            time_last_call = None
+            return {"message": "vram cleared"}
+        else:
+            return {"message": "vram not cleared"}
+
+    except:
+        return {"error": "unexpected error"}
+
+@app.post("/AgeSlider")
+async def age_slider(background_tasks: BackgroundTasks, image: UploadFile = File(...)):
+
+#################
+#################
+    #get time
+    global time_last_call
+    time_last_call = datetime.now()
+##################
+    try:
+
+################################################################################################
+###################################### Test Input Image #########################################
+################################################################################################
+        image_format = image.filename.split(".")[-1].lower()
+        if image_format not in ALLOWED_IMAGE_FORMATS:
+            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        contents = await image.read()
+        nparr = np.frombuffer(contents, np.uint8)
+        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+        if image is None:
+            response_message = "Image is empty"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        app_face = FaceAnalysis(name="buffalo_l")
+        app_face.prepare(ctx_id=0, det_size=(640, 640))
+
+        faces = app_face.get(image)
+        if len(faces) == 0:
+            response_message = "Error no face detected"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+        elif len(faces) != 1:
+            response_message = "Error more than 1 face detected"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        info_age = faces[0]['age']
+        info_gender = faces[0]['gender']
+################################################################################################
+###################################### Image Processing #########################################
+################################################################################################
+
+        unique_id = str(uuid.uuid4())
+        output_path = f"./outputs/AGE/{unique_id}"
+        os.makedirs(output_path)
+
+        # loading model
+        if not get_model_state_age():
+            unload_model_all()
+            model_status = load_model_age()
+            print(f"model status =========={model_status}")
+
+        cv2.imwrite(output_path + '/input.png', image)
+        aligned_image = run_alignment(output_path + '/input.png')
+        #aligned_image.save(output_path+"/input_aligned_cropped.png")
+        copy_aligned_image = aligned_image.copy()
+
+        # PIL's resize returns a new image rather than resizing in place.
+        aligned_image = aligned_image.resize((256, 256))
+        img_transforms = EXPERIMENT_ARGS['transform']
+        input_image = img_transforms(aligned_image)
+
+        target_ages = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
+        age_transformers = [AgeTransformer(target_age=age) for age in target_ages]
+
+        images = []
+
+        for age_transformer in age_transformers:
+            with torch.no_grad():
+                input_image_age = [age_transformer(input_image.cpu()).to('cuda')]
+                input_image_age = torch.stack(input_image_age)
+                result_tensor = run_on_batch(input_image_age, model_age_slider)[0]
+                result_image = tensor2im(result_tensor)
+                images.append(result_image)
+
+        image_paths_temp = []
+
+        result_paths_list = []
+        for idx, i in enumerate(images):
+            image_temp = np.array(i)
+            image_temp = cv2.cvtColor(np.array(image_temp), cv2.COLOR_RGB2BGR)
+            result_path_temp = output_path + "/output_" + str(idx) + ".png"
+            cv2.imwrite(result_path_temp, image_temp)
+
+            image_url_temp = '/'.join(result_path_temp.split('/')[-3:])
+            image_url_temp = f"{web_url}:8001/{image_url_temp}"
+            result_paths_list.append({"age": str(target_ages[idx]), "image_url": image_url_temp})
+
+            image_paths_temp.append(result_path_temp)
+
+        # Replace the generated image closest to the detected age with the aligned original.
+        closest_age = min(target_ages, key=lambda x: abs(x - info_age))
+        closest_age = target_ages.index(closest_age)
+
+        os.remove(f"{output_path}/output_{closest_age}.png")
+        copy_aligned_image = copy_aligned_image.resize((1024, 1024))
+        copy_aligned_image.save(f"{output_path}/output_{closest_age}.png")
+
+        background_tasks.add_task(Age_Slider_Video, unique_id, image_paths_temp)
+
+        os.remove(output_path + '/input.png')
+
+        response_message = "AgeSlider Ran Successfully"
+        response_status = "Successful"
+        return generate_response(response_message, response_status, unique_id, info_age, ["Female", "Male"][info_gender], result_paths_list)
+
+    except Exception as e:
+        shutil.rmtree(output_path)
+        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
+
+def Age_Slider_Video(uuid_input, image_files):
+    temp_file_uuid = str(uuid.uuid4())
+    directory_path = f'./outputs/{temp_file_uuid}'
+    if os.path.exists(directory_path):
+        shutil.rmtree(directory_path)
+    os.mkdir(directory_path)
+
+    from scripts.morph_video import doMorphing
+    doMorphing(image_files, 0.3, 20, f"{directory_path}/output")
+    video_path = f"{directory_path}/output_combined.mp4"
+    gif_path = f"{directory_path}/output.gif"
+
+    shutil.copy(video_path, f"./outputs/AGE/{uuid_input}/output_combined.mp4")
+    shutil.copy(gif_path, f"./outputs/AGE/{uuid_input}/output.gif")
+
+    shutil.rmtree(directory_path)
+
+@app.post("/AgeSliderVideo/{uuid_input}")
+async def ageslidervideo(uuid_input: str):
+##################
+    try:
+        input_path = f"./outputs/AGE/{uuid_input}"
+################################################################################################
+###################################### Test Input UUID ##########################################
+################################################################################################
+        if not os.path.exists(input_path):
+            response_message = "Invalid UUID, file does not exist"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+################################################################################################
+######################################## Processing #############################################
+################################################################################################
+
+        output_video_path = ''
+        output_gif_path = ''
+        duration = 120
+        start_time = time.time()
+        # Poll once per second until both background-task outputs exist or the timeout hits.
+        while time.time() - start_time < duration:
+
+            if output_video_path == '':
+                if os.path.exists(input_path + "/output_combined.mp4"):
+                    output_video_path = f'{web_url}:8001/AGE/{uuid_input}' + "/output_combined.mp4"
+
+            if output_gif_path == '':
+                if os.path.exists(input_path + "/output.gif"):
+                    output_gif_path = f'{web_url}:8001/AGE/{uuid_input}' + "/output.gif"
+
+            if output_video_path != '' and output_gif_path != '':
+                response_message = "Video Generated Successfully"
+                response_status = "Successful"
+                return generate_response(response_message, response_status, uuid_input, '', '', [{"video": output_video_path, "gif": output_gif_path}])
+
+            time.sleep(1)
+
+        # Timeout: report whichever outputs were produced.
+        if output_video_path != '' and output_gif_path == '':
+            response_message = "Only Video Generated"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, uuid_input, '', '', [{"video": output_video_path, "gif": output_gif_path}])
+
+        if output_gif_path != '' and output_video_path == '':
+            response_message = "Only GIF Generated"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, uuid_input, '', '', [{"video": output_video_path, "gif": output_gif_path}])
+
+        response_message = "Timeout Reached, Retry Later"
+        response_status = "Failure"
+        return generate_response(response_message, response_status, uuid_input, '', '', [{"video": output_video_path, "gif": output_gif_path}])
+
+    except Exception as e:
+        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
+
+
+
+@app.post("/BG_remove")
+async def bg(image: UploadFile = File(...)):
+
+#################
+#################
+    #get time
+    global time_last_call
+    time_last_call = datetime.now()
+##################
+    try:
+
+################################################################################################
+###################################### Test Input Image #########################################
+################################################################################################
+        image_format = image.filename.split(".")[-1].lower()
+        if image_format not in ALLOWED_IMAGE_FORMATS:
+            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        contents = await image.read()
+        nparr = np.frombuffer(contents, np.uint8)
+        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+        if image is None:
+            response_message = "Image is empty"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+################################################################################################
+###################################### Image Processing #########################################
+################################################################################################
+
+        unique_id = str(uuid.uuid4())
+        output_path = f"./outputs/BG/{unique_id}"
+
+        # loading model
+        if not get_model_state_bg():
+            unload_model_all()
+            model_status = load_model_bg()
+            print(f"model status =========={model_status}")
+
+        output_image = bg_model.BG_remove(image)
+
+        os.makedirs(output_path)
+        output_image_path = output_path + '/output.png'
+        cv2.imwrite(output_image_path, output_image)
+
+        response_message = "BG Removed Successfully"
+        response_status = "Successful"
+        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/BG/{unique_id}/output.png"])
+
+    except Exception as e:
+        shutil.rmtree(output_path)
+        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
+
+
+@app.post("/CodeFormer")
+async def codeformer(image: UploadFile = File(...)):
+
+#################
+#################
+    #get time
+    global time_last_call
+    time_last_call = datetime.now()
+##################
+    try:
+
+################################################################################################
+###################################### Test Input Image #########################################
+################################################################################################
+        image_format = image.filename.split(".")[-1].lower()
+        if image_format not in ALLOWED_IMAGE_FORMATS:
+            response_message = f"Unsupported image format. 
Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}" + response_status = "Failure" + return generate_response(response_message , response_status ,'' , '' , '' , []) + + contents = await image.read() + nparr = np.frombuffer(contents, np.uint8) + image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + + if image is None: + response_message = "Image is empty" + response_status = "Failure" + return generate_response(response_message , response_status ,'' , '' , '' , []) + + height = image.shape[0] + width = image.shape[1] + + if (height*width)>(500*500) : + aspect_ratio = width / height + new_height = int(((500*500) / aspect_ratio) ** 0.5) + new_width = int(aspect_ratio * new_height) + image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA) +################################################################################################################################################################################### +############################################################################### Image Processing ################################################################################## +################################################################################################################################################################################### + + + + + + unique_id = str(uuid.uuid4()) + output_path = f"./outputs/CF/{unique_id}" + + + # loading model + if not get_model_state_cf(): + unload_model_all() + model_status = load_model_cf() + print(f"model status =========={model_status}") + + + #image = cv2.imread(output_path+'/input.png') + + + has_aligned = False + only_center_face = False + draw_box = False + detection_model = "retinaface_resnet50" + background_enhance = True + face_upsample = True + upscale = 2 + codeformer_fidelity = 0.5 + img = image#cv2.imread(str(img_path), cv2.IMREAD_COLOR) + upscale = int(upscale) # convert type to int + if upscale > 4: # avoid memory exceeded due to too large upscale + upscale = 4 + if upscale > 2 and max(img.shape[:2])>1000: # avoid memory exceeded due to too large img resolution + upscale = 2 + if max(img.shape[:2]) > 1500: # avoid memory exceeded due to too large img resolution + upscale = 1 + background_enhance = False + face_upsample = False + face_helper = FaceRestoreHelper( + upscale, + face_size=512, + crop_ratio=(1, 1), + det_model=detection_model, + save_ext="png", + use_parse=True, + device=device, + ) + bg_upsampler = model_upsampler if background_enhance else None + face_upsampler = model_upsampler if face_upsample else None + if has_aligned: + # the input faces are already cropped and aligned + img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR) + face_helper.is_gray = is_gray(img, threshold=5) + if face_helper.is_gray: + print('\tgrayscale input: True') + face_helper.cropped_faces = [img] + else: + face_helper.read_image(img) + # get face landmarks for each face + num_det_faces = face_helper.get_face_landmarks_5( + only_center_face=only_center_face, resize=640, eye_dist_threshold=5 + ) + print(f'\tdetect {num_det_faces} faces') + # align and warp each face + face_helper.align_warp_face() + + # face restoration for each cropped face + for idx, cropped_face in enumerate(face_helper.cropped_faces): + # prepare data + cropped_face_t = img2tensor( + cropped_face / 255.0, bgr2rgb=True, float32=True + ) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(device) + try: + with torch.no_grad(): + output = 
model_code_former(cropped_face_t, w=codeformer_fidelity, adain=True)[0]
+                    restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+                    del output
+                    torch.cuda.empty_cache()
+            except RuntimeError as error:
+                print(f"Failed inference for CodeFormer: {error}")
+                # Fall back to the unrestored crop if inference fails.
+                restored_face = tensor2img(
+                    cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
+                )
+            restored_face = restored_face.astype("uint8")
+            face_helper.add_restored_face(restored_face)
+
+        # paste_back
+        if not has_aligned:
+            # upsample the background
+            if bg_upsampler is not None:
+                # Now only support RealESRGAN for upsampling background
+                bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
+            else:
+                bg_img = None
+            face_helper.get_inverse_affine(None)
+            # paste each restored face to the input image
+            if face_upsample and face_upsampler is not None:
+                restored_img = face_helper.paste_faces_to_input_image(
+                    upsample_img=bg_img,
+                    draw_box=draw_box,
+                    face_upsampler=face_upsampler,
+                )
+            else:
+                restored_img = face_helper.paste_faces_to_input_image(
+                    upsample_img=bg_img, draw_box=draw_box
+                )
+
+        output_image_path = output_path + '/output.png'
+        os.makedirs(output_path)
+        cv2.imwrite(output_image_path, restored_img)
+
+        response_message = "CodeFormer Ran Successfully"
+        response_status = "Successful"
+        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/CF/{unique_id}/output.png"])
+
+    except Exception as e:
+        shutil.rmtree(output_path)
+        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
+
+@app.post("/FaceSwap_single_image")
+async def Swap_single_image(image: UploadFile = File(...)):
+
+#################
+#################
+    #get time
+    global time_last_call
+    time_last_call = datetime.now()
+##################
+    try:
+
+################################################################################################
+###################################### Test Input Image #########################################
+################################################################################################
+        image_format = image.filename.split(".")[-1].lower()
+        if image_format not in ALLOWED_IMAGE_FORMATS:
+            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        contents = await image.read()
+        nparr = np.frombuffer(contents, np.uint8)
+        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+        if image is None:
+            response_message = "Image is empty"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+################################################################################################
+###################################### Image Processing #########################################
+################################################################################################
+
+        unique_id = str(uuid.uuid4())
+        output_path = f"./outputs/SWAP/{unique_id}"
+
+        # loading model
+        if not get_model_state_swap():
+            unload_model_all()
+            model_status = load_model_swap()
+            print(f"model status =========={model_status}")
+
+        faces = app_model.get(image)
+
+        if len(faces) == 0:
+            response_message = "Error no face detected"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+        elif len(faces) != 2:
+            response_message = "Error image must contain exactly 2 faces"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        face1 = faces[0]
+        face2 = faces[1]
+
+        # Swap the two detected faces with each other.
+        image = swap_model.get(image, face1, face2, paste_back=True)
+        image = swap_model.get(image, face2, face1, paste_back=True)
+
+        os.makedirs(output_path)
+        output_image_path = output_path + '/output.png'
+        cv2.imwrite(output_image_path, image)
+
+        response_message = "Face Swapped Successfully"
+        response_status = "Successful"
+        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/SWAP/{unique_id}/output.png"])
+
+    except Exception as e:
+        shutil.rmtree(output_path)
+        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
+
+
+
+@app.post("/FaceSwap_two_images")
+async def swap_two_image(image1: UploadFile = File(...), image2: UploadFile = File(...)):
+
+#################
+#################
+    #get time
+    global time_last_call
+    time_last_call = datetime.now()
+##################
+    try:
+
+################################################################################################
+###################################### Test Input Image #########################################
+################################################################################################
+        images = []
+        for idx, image in enumerate([image1, image2]):
+            image_format = image.filename.split(".")[-1].lower()
+            print(image.filename, image.filename.split(".")[-1].lower())
+            if image_format not in ALLOWED_IMAGE_FORMATS:
+                response_message = f"Unsupported image{str(idx+1)} format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
+                response_status = "Failure"
+                return generate_response(response_message, response_status, '', '', '', [])
+
+            contents = await image.read()
+            nparr = np.frombuffer(contents, np.uint8)
+            image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+            images.append(image)
+
+            if image is None:
+                response_message = f"Image{str(idx+1)} is empty"
+                response_status = "Failure"
+                return generate_response(response_message, response_status, '', '', '', [])
+################################################################################################
+###################################### Image Processing #########################################
+################################################################################################
+
+        unique_id = str(uuid.uuid4())
+        output_path = f"./outputs/SWAP/{unique_id}"
+
+        # loading model
+        if not get_model_state_swap():
+            unload_model_all()
+            model_status = load_model_swap()
+            print(f"model status =========={model_status}")
+
+        faces = []
+        faces.append(app_model.get(images[0]))
+        faces.append(app_model.get(images[1]))
+
+        for idx, face in enumerate(faces):
+            # idx is 0-based; report the images as image1/image2.
+            if len(face) == 0:
+                response_message = f"Error no face detected in image{idx+1}"
+                response_status = "Failure"
+                return generate_response(response_message, response_status, '', '', '', [])
+            elif len(face) > 1:
+                response_message = f"Error more than 1 face detected in image{idx+1}"
+                response_status = "Failure"
+                return generate_response(response_message, response_status, '', '', '', [])
+
+        face1 = faces[0][0]
+        face2 = faces[1][0]
+
+        # Paste the face from image2 onto image1.
+        image = swap_model.get(images[0], face1, face2, paste_back=True)
+
+        os.makedirs(output_path)
+        output_image_path = output_path + '/output.png'
+        cv2.imwrite(output_image_path, image)
+
+        response_message = "Face Swapped Successfully"
+        response_status = "Successful"
+        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/SWAP/{unique_id}/output.png"])
+
+    except Exception as e:
+        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
+
+
+@app.post("/Restore_Images")
+async def restore(image: UploadFile = File(...)):
+
+#################
+#################
+    #get time
+    global time_last_call
+    time_last_call = datetime.now()
+##################
+    try:
+
+################################################################################################
+###################################### Test Input Image #########################################
+################################################################################################
+        image_format = image.filename.split(".")[-1].lower()
+        if image_format not in ALLOWED_IMAGE_FORMATS:
+            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        contents = await image.read()
+        nparr = np.frombuffer(contents, np.uint8)
+        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+        if image is None:
+            response_message = "Image is empty"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+################################################################################################
+###################################### Image Processing #########################################
+################################################################################################
+
+        unique_id = str(uuid.uuid4())
+        output_path = f"./outputs/Restore/{unique_id}"
+        os.makedirs(output_path)
+
+        cv2.imwrite(output_path + '/input.png', image)
+
+        # loading model
+        if not get_model_state_rest():
+            unload_model_all()
+            model_status = load_model_rest()
+            print(f"model status =========={model_status}")
+
+        image = PIL.Image.open(output_path + '/input.png')
+        os.remove(output_path + '/input.png')
+
+        transformed, mask = model_scratches_remove_detector.process(image)
+        img_transform = transforms.Compose(
+            [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+        )
+        mask_transform = transforms.ToTensor()
+        if model_scratches_remove_options.mask_dilation != 0:
+            kernel = np.ones((3, 3), np.uint8)
+            mask = np.array(mask)
+            mask = cv2.dilate(mask, kernel, iterations=model_scratches_remove_options.mask_dilation)
+            mask = PIL.Image.fromarray(mask.astype('uint8'))
+        transformed = irregular_hole_synthesize(transformed, mask)
+        mask = mask_transform(mask)
+        mask = mask[:1, :, :]
+        mask = mask.unsqueeze(0)
+        transformed = img_transform(transformed)
+        transformed = transformed.unsqueeze(0)
+        generated = model_scratches_remove.inference(transformed, mask)
+        tensor_restored = (generated.data.cpu() + 1.0) / 2.0
+        image_to_show = tensor_restored.squeeze().cpu().numpy().transpose((1, 2, 0))
+
+        # Convert to uint8 BGR for OpenCV.
+        image_to_show = (image_to_show * 255).astype(np.uint8)[:, :, ::-1]
+
+        from scripts.colorizer import colorize_image
+        image_to_show = colorize_image(image_to_show)
+
+        output_image_path = output_path + '/output.png'
+        cv2.imwrite(output_image_path, image_to_show)
+
+        response_message = "Restored Image Successfully"
+        response_status = "Successful"
+        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/Restore/{unique_id}/output.png"])
+
+    except Exception as e:
+        shutil.rmtree(output_path)
+        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
+
+
+@app.post("/Sad_Talker")
+async def sadtalker(image: UploadFile = File(...), audio: UploadFile = File(...)):
+
+#################
+#################
+    #get time
+    global time_last_call
+    time_last_call = datetime.now()
+##################
+    try:
+
+################################################################################################
+###################################### Test Input Image 
################################################################################## +################################################################################################################################################################################### + image_format = image.filename.split(".")[-1].lower() + if image_format not in ALLOWED_IMAGE_FORMATS: + response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}" + response_status = "Failure" + return generate_response(response_message , response_status ,'' , '' , '' , []) + + contents = await image.read() + nparr = np.frombuffer(contents, np.uint8) + image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + + if image is None: + response_message = "Image is empty" + response_status = "Failure" + return generate_response(response_message , response_status ,'' , '' , '' , []) +################################################################################################################################################################################### +############################################################################### Test Input Audio ################################################################################## +################################################################################################################################################################################### + audio_format = audio.filename.split(".")[-1].lower() + if audio_format not in ALLOWED_AUDIO_FORMATS: + response_message = f"Unsupported audio format. Supported formats: {', '.join(ALLOWED_AUDIO_FORMATS)}" + response_status = "Failure" + return generate_response(response_message, response_status, '', '', '', []) + + audio_contents = await audio.read() + if len(audio_contents) == 0: + response_message = "Audio file is empty" + response_status = "Failure" + return generate_response(response_message, response_status, '', '', '', []) +################################################################################################################################################################################### +############################################################################### Image Processing ################################################################################## +################################################################################################################################################################################### + + + + + + unique_id = str(uuid.uuid4()) + + + output_path = f"./outputs/Sad_Talker/{unique_id}" + + os.makedirs(output_path) + + image_input_path = output_path+'/input.png' + cv2.imwrite(image_input_path,image) + + audio_segment = AudioSegment.from_file(io.BytesIO(audio_contents), format=audio_format) + audio_input_path = os.path.join(output_path, f"audio.{audio_format}") + audio_segment.export(audio_input_path, format=audio_format) + + app_face = FaceAnalysis(name="buffalo_l") + app_face.prepare(ctx_id=0,det_size=(640,640)) + + faces = app_face.get(image) + if len(faces)==0: + response_message = "Error no face detected" + response_status = "Failure" + return generate_response(response_message , response_status ,'' , '' , '' , []) + + elif len(faces)!=1: + response_message = "Error more than 1 face detected" + response_status = "Failure" + + return generate_response(response_message , response_status ,'' , '' , '' , []) + + info_age = faces[0]['age'] + info_gender = faces[0]['gender'] + + + # loading model + if not 
get_model_state_talk():
+            unload_model_all()
+            model_status = load_model_talk()
+            print(f"model status =========={model_status}")
+
+        response = sad_talker_model.genrate_video(image_input_path, audio_input_path, output_path, True)
+
+        os.remove(image_input_path)
+        os.remove(audio_input_path)
+
+        if response:
+
+            input_file = f"outputs/Sad_Talker/{unique_id}/output.mp4"
+            output_file = f"outputs/Sad_Talker/{unique_id}/final.mp4"
+
+            # Re-encode to H.264/AAC so the clip plays in browsers.
+            subprocess.run(['ffmpeg', '-i', input_file, '-c:v', 'libx264', '-preset', 'slow', '-crf', '23', '-c:a', 'aac', '-b:a', '192k', output_file], check=True)
+
+            os.remove(input_file)
+
+            response_message = "Video Generated Successfully"
+            response_status = "Successful"
+            return generate_response(response_message, response_status, unique_id, info_age, ["Female", "Male"][info_gender], [f"{web_url}:8001/Sad_Talker/{unique_id}/final.mp4"])
+
+        else:
+            response_message = "Error Generating Video, try another image"
+            response_status = "Failure"
+            #shutil.rmtree(output_path)
+            return generate_response(response_message, response_status, '', '', '', [])
+
+    except Exception as e:
+        shutil.rmtree(output_path)
+        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
+
+
+@app.post("/SKY_remove")
+async def sky(image: UploadFile = File(...)):
+
+#################
+#################
+    #get time
+    global time_last_call
+    time_last_call = datetime.now()
+##################
+    try:
+
+################################################################################################
+###################################### Test Input Image #########################################
+################################################################################################
+        image_format = image.filename.split(".")[-1].lower()
+        if image_format not in ALLOWED_IMAGE_FORMATS:
+            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        contents = await image.read()
+        nparr = np.frombuffer(contents, np.uint8)
+        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+        if image is None:
+            response_message = "Image is empty"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+################################################################################################
+###################################### Image Processing #########################################
+################################################################################################
+
+        unique_id = str(uuid.uuid4())
+        output_path = f"./outputs/SKY/{unique_id}"
+
+        # loading model
+        if not get_model_state_sky():
+            unload_model_all()
+            model_status = load_model_sky()
+            print(f"model status =========={model_status}")
+
+        output_image = sky_model.SKY_remove(image)
+
+        os.makedirs(output_path)
+        output_image_path = output_path + '/output.png'
+        cv2.imwrite(output_image_path, output_image)
+
+        response_message = "SKY Removed Successfully"
+        response_status = "Successful"
+        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/SKY/{unique_id}/output.png"])
+
+    except Exception as e:
+        shutil.rmtree(output_path)
+        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
+
+
+@app.post("/ImageUpscale")
+async def upscale(image: UploadFile = File(...)):
+
+#################
+#################
+    #get time
+    global time_last_call
+    time_last_call = datetime.now()
+##################
+    try:
+
+################################################################################################
+###################################### Test Input Image #########################################
+################################################################################################
+        image_format = image.filename.split(".")[-1].lower()
+        if image_format not in ALLOWED_IMAGE_FORMATS:
+            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}" + response_status = "Failure" + return generate_response(response_message , response_status ,'' , '' , '' , []) + + contents = await image.read() + nparr = np.frombuffer(contents, np.uint8) + image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + + if image is None: + response_message = "Image is empty" + response_status = "Failure" + return generate_response(response_message , response_status ,'' , '' , '' , []) + + height = image.shape[0] + width = image.shape[1] + + if (height*width)>(200*200) : + print(11) + aspect_ratio = width / height + new_height = int(((200*200) / aspect_ratio) ** 0.5) + new_width = int(aspect_ratio * new_height) + image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA) + +################################################################################################################################################################################### +############################################################################### Image Processing ################################################################################## +################################################################################################################################################################################### + + + + + + unique_id = str(uuid.uuid4()) + output_path = f"./outputs/Upscale/{unique_id}" + + + + # loading model + if not get_model_state_up(): + unload_model_all() + model_status = load_model_up() + print(f"model status =========={model_status}") + + + #image = cv2.imread(output_path+'/input.png') + + image = image * 1.0 / 255 + image = torch.from_numpy(np.transpose(image[:, :, [2, 1, 0]], (2, 0, 1))).float() + img_LR = image.unsqueeze(0) + img_LR = img_LR.to(device).half() + + output = model_upscale(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy() + + output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) + output = (output * 255.0).round() + image = output.astype(np.uint8) + + output_image_path = output_path+'/output.png' + os.makedirs(output_path) + cv2.imwrite(output_image_path,image) + + + + response_message = "Upscaled Image Sucessfully" + response_status = "Successful" + return generate_response(response_message , response_status ,unique_id , '' , '', [f"{web_url}:8001/Upscale/{unique_id}/output.png"]) + + + except Exception as e: + shutil.rmtree(output_path) + return generate_response(f"Unknown Internal Error : {e}" , "Failure" ,'' , '' , '' , []) + +@app.post("/ImageDenoise") +async def deniose(image: UploadFile = File(...)): + +################# +################# + #get time + global time_last_call + time_last_call = datetime.now() +################## + try: + +################################################################################################################################################################################### +############################################################################### Test Input Image ################################################################################## +################################################################################################################################################################################### + image_format = image.filename.split(".")[-1].lower() + if image_format not in ALLOWED_IMAGE_FORMATS: + response_message = f"Unsupported image format. 
Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}" + response_status = "Failure" + return generate_response(response_message , response_status ,'' , '' , '' , []) + + contents = await image.read() + nparr = np.frombuffer(contents, np.uint8) + image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + + if image is None: + response_message = "Image is empty" + response_status = "Failure" + return generate_response(response_message , response_status ,'' , '' , '' , []) + + height = image.shape[0] + width = image.shape[1] + + if (height*width)>(500*500) : + print(11) + aspect_ratio = width / height + new_height = int(((500*500) / aspect_ratio) ** 0.5) + new_width = int(aspect_ratio * new_height) + image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA) + +################################################################################################################################################################################### +############################################################################### Image Processing ################################################################################## +################################################################################################################################################################################### + + + + + + unique_id = str(uuid.uuid4()) + output_path = f"./outputs/Denoise/{unique_id}" + + + + # loading model + if not get_model_state_esrgan(): + unload_model_all() + model_status = load_model_esrgan() + print(f"model status =========={model_status}") + + + #image = cv2.imread(output_path+'/input.png') + + image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image_pil = Image.fromarray(image_rgb) + global model_esrgan + result_pil = model_esrgan.predict(image_pil) + + result_rgb = result_pil.convert('RGB') + result_np = np.array(result_rgb) + image = cv2.cvtColor(result_np, cv2.COLOR_RGB2BGR) + + output_image_path = output_path+'/output.png' + os.makedirs(output_path) + cv2.imwrite(output_image_path,image) + + + + response_message = "Upscaled Image Sucessfully" + response_status = "Successful" + return generate_response(response_message , response_status ,unique_id , '' , '', [f"{web_url}:8001/Denoise/{unique_id}/output.png"]) + + + except Exception as e: + shutil.rmtree(output_path) + return generate_response(f"Unknown Internal Error : {e}" , "Failure" ,'' , '' , '' , []) + + + +# if __name__ == '__main__': + +# import uvicorn +# try: +# http_file_server_process = start_http_file_server() +# uvicorn.run("main:app", host='0.0.0.0', port=8080, workers=1 , limit_max_requests=20) +# except KeyboardInterrupt: +# print("Received KeyboardInterrupt, shutting down gracefully...") +# http_file_server_process.kill() +# shutdown_event() + +if __name__ == '__main__': + + import uvicorn + try: + http_file_server_process = start_http_file_server() + uvicorn.run("main:app", host='0.0.0.0', port=8080, workers=2 , limit_max_requests=200) + print("Received KeyboardInterrupt, shutting down gracefully...") + http_file_server_process.kill() + shutdown_event() + except KeyboardInterrupt: + print("Received KeyboardInterrupt, shutting down gracefully...") + http_file_server_process.kill() + shutdown_event() diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..833e52ef910a96f6adae73cee10eea36ca2ca0e7 --- /dev/null +++ b/main.py @@ -0,0 +1,1982 @@ +import os +os.environ['CUDA_LAUNCH_BLOCKING'] = '1' + +web_url ="https://api.quantumgrove.tech" + +from fastapi import 
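+# Every endpoint below replies with the JSON envelope built by generate_response():
+# {"response_message", "response_status", "data": {"UUID", "info", "metadata"}}.
+# A minimal client sketch (illustrative only -- the host, port, and file name
+# are assumptions, not part of this file):
+#
+#   import requests
+#   with open("face.png", "rb") as f:
+#       r = requests.post("http://localhost:8080/BG_remove", files={"image": f})
+#   result_urls = r.json()["data"]["metadata"]   # result files are served on port 8001
+
+from fastapi import 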
FastAPI, UploadFile, File, BackgroundTasks + +import sqlite3 + +def create_database(): + conn = sqlite3.connect('api.db') + cursor = conn.cursor() + cursor.execute('''CREATE TABLE IF NOT EXISTS ban_list ( + id INTEGER PRIMARY KEY, + ip TEXT NOT NULL, + reason TEXT NOT NULL)''') + cursor.execute('''CREATE TABLE IF NOT EXISTS request_log ( + id INTEGER PRIMARY KEY, + ip TEXT NOT NULL, + url TEXT NOT NULL, + method TEXT NOT NULL, + endpoint TEXT NOT NULL, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)''') + conn.commit() + conn.close() + + +def insert_ban(ip, reason): + conn = sqlite3.connect('api.db') + cursor = conn.cursor() + cursor.execute('''INSERT INTO ban_list (ip, reason) VALUES (?, ?)''', (ip, reason)) + conn.commit() + conn.close() + +def search_ban(ip): + conn = sqlite3.connect('api.db') + cursor = conn.cursor() + cursor.execute('''SELECT * FROM ban_list WHERE ip = ?''', (ip,)) + rows = cursor.fetchall() + conn.close() + return rows + +def log_request(ip, url, method, endpoint): + conn = sqlite3.connect('api.db') + cursor = conn.cursor() + cursor.execute('''INSERT INTO request_log (ip, url, method, endpoint) VALUES (?, ?, ?, ?)''', (ip, url, method, endpoint)) + conn.commit() + conn.close() + +if not os.path.exists('api.db'): + create_database() + + +def create_database_port(): + conn = sqlite3.connect('./subprocess/port.db', check_same_thread=False) + cursor = conn.cursor() + cursor.execute('''DROP TABLE IF EXISTS ports''') + conn.commit() + cursor.execute('''CREATE TABLE IF NOT EXISTS ports + (port INTEGER)''') + conn.commit() + conn.close() + +def search_port(port): + conn = sqlite3.connect('./subprocess/port.db') + cursor = conn.cursor() + cursor.execute('''SELECT * FROM ports WHERE port = ?''', (port,)) + rows = cursor.fetchall() + conn.close() + if len(rows) > 0: + return True + return False + +def remove_port(port): + if(search_port(port)): + conn = sqlite3.connect('./subprocess/port.db') + cursor = conn.cursor() + cursor.execute('''DELETE FROM ports WHERE port = ?''', (port,)) + conn.commit() + conn.close() + return True + return False + +def insert_port(port): + if( not search_port(port)): + conn = sqlite3.connect('./subprocess/port.db') + cursor = conn.cursor() + cursor.execute('''INSERT INTO ports (port) VALUES (?)''', (port,)) + conn.commit() + conn.close() + return True + else: + return False + + + +def generate_response(response_message , response_status ,uuid_code , age , gender , metadata): + + response_dict = { + "response_message" : response_message, + "response_status" : response_status, + "data" :{ + "UUID" : uuid_code, + "info" : {"age" : age , "gender" : gender}, + "metadata" : metadata + } + } + + return response_dict + + +import numpy as np +import cv2 +import time +import io +import gc +import shutil +import uuid +import torch +import subprocess +import torchvision.transforms as transforms +from scripts.psp import pSp +from argparse import Namespace +import dlib +from scripts.align_all_parallel import align_face +from scripts.augmentations import AgeTransformer +from scripts.common import tensor2im +from torchvision.transforms.functional import normalize +from scripts.basicsr.utils import img2tensor, tensor2img +from scripts.basicsr.utils.misc import get_device +from scripts.facelib.utils.face_restoration_helper import FaceRestoreHelper +from scripts.basicsr.utils.registry import ARCH_REGISTRY +from scripts.basicsr.archs.rrdbnet_arch import RRDBNet +from scripts.basicsr.utils.realesrgan_utils import RealESRGANer +from scripts.facelib.utils.misc import 
is_gray +import PIL.Image +from scripts.erasescratches.models import Pix2PixHDModel_Mapping +from scripts.erasescratches.options import Options +from scripts.maskscratches import ScratchesDetector +from scripts.util import irregular_hole_synthesize, tensor_to_ndarray +import insightface +from insightface.app import FaceAnalysis +from pydub import AudioSegment +import scripts.RRDBNet_arch as arch +from RealESRGAN import RealESRGAN +from PIL import Image +import socket +import subprocess +from gradio_client import Client, handle_file + +device = ("cuda" if torch.cuda.is_available() else "cpu") + +#age model +EXPERIMENT_TYPE = 'ffhq_aging' +model_path = "./models/sam_ffhq_aging.pt" +model_age_slider = None + +EXPERIMENT_DATA_ARGS = { + "ffhq_aging": { + "transform": transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) + } +} +EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[EXPERIMENT_TYPE] + + +def load_model_age(): + try: + if not get_model_state_age(): + ckpt = torch.load(model_path, map_location='cpu') + opts = ckpt['opts'] + del ckpt + torch.cuda.empty_cache() + opts['checkpoint_path'] = model_path + opts = Namespace(**opts) + global model_age_slider + model_age_slider = pSp(opts) + del opts + torch.cuda.empty_cache() + model_age_slider.eval() + model_age_slider.cuda() + torch.cuda.empty_cache() + return True + except: + return False + +def unload_model_age(): + try: + global model_age_slider + if model_age_slider is not None: + model_age_slider = None + torch.cuda.empty_cache() + return True + except: + return False + +def get_model_state_age(): + global model_age_slider + if model_age_slider is None: + return False + return True +#model bg +from scripts.BG import BG +bg_model = None + +def load_model_bg(): + try: + global bg_model + if bg_model is None: + bg_model = BG() + return True + except: + return False + +def unload_model_bg(): + try: + global bg_model + if bg_model is not None: + bg_model = None + return True + except: + return False + +def get_model_state_bg(): + global bg_model + if bg_model is None: + return False + return True +#model code_former +model_upsampler = None +model_code_former = None + + +def set_realesrgan_cf(): + use_half = True if torch.cuda.is_available() else False + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=2, + ) + upsampler = RealESRGANer( + scale=2, + model_path="./models/realesrgan/RealESRGAN_x2plus.pth", + model=model, + tile=400, + tile_pad=40, + pre_pad=0, + half=use_half + ) + return upsampler + + + +def load_model_cf(): + try: + device = get_device() + global model_code_former + model_code_former = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, + connect_list=['32', '64', '128', '256']).to(device) + ckpt_path = 'models/CodeFormer/codeformer.pth' + checkpoint = torch.load(ckpt_path)['params_ema'] + model_code_former.load_state_dict(checkpoint) + model_code_former.eval() + + global model_upsampler + model_upsampler = set_realesrgan_cf() + return True + except: + return False + + +def get_model_state_cf(): + global model_code_former + if model_code_former is None: + return False + else: + return True + + +def unload_model_cf(): + if get_model_state_cf() == False: + return True + try: + global model_code_former + model_code_former = None + torch.cuda.empty_cache() + + global model_upsampler + model_upsampler = None + torch.cuda.empty_cache() + return True + except: + return 
False +#model swap +swap_model = None +app_model = None + +def load_model_swap(): + try: + global swap_model + global app_model + if swap_model is None: + swap_model = insightface.model_zoo.get_model("./models/inswapper_128.onnx",download=False,download_zip=False) + app_model = FaceAnalysis(name="buffalo_l") + app_model.prepare(ctx_id=0,det_size=(640,640)) + return True + except: + return False + +def unload_model_swap(): + try: + global swap_model + global app_model + if swap_model is not None: + swap_model = None + if app_model is not None: + app_model = None + return True + except: + return False + +def get_model_state_swap(): + global swap_model + + if swap_model is None: + return False + return True +#model rest +model_scratches_remove = None +model_scratches_remove_detector = None +model_scratches_remove_options = None + + +def load_model_rest(): + try: + + model_path = "./models/zeroscratches/restoration" + global model_scratches_remove_detector + model_scratches_remove_detector = ScratchesDetector('./models/zeroscratches') + gpu_ids = [] + if torch.cuda.is_available(): + gpu_ids = [d for d in range(torch.cuda.device_count())] + global model_scratches_remove_options + model_scratches_remove_options = Options(model_path, gpu_ids) + model_scratches = Pix2PixHDModel_Mapping() + global model_scratches_remove + model_scratches_remove = Pix2PixHDModel_Mapping() + model_scratches_remove.initialize(model_scratches_remove_options) + model_scratches_remove.eval() + return True + except: + return False + + +def get_model_state_rest(): + global model_scratches_remove + if model_scratches_remove is None: + return False + else: + return True + + +def unload_model_rest(): + if get_model_state_rest() == False: + return True + try: + global model_scratches_remove + model_scratches_remove = None + global model_scratches_remove_detector + model_scratches_remove_detector = None + global model_scratches_remove_options + model_scratches_remove_options = None + torch.cuda.empty_cache() + return True + except: + return False +#model talk +from scripts.talker import sad_talker +sad_talker_model = None + +def load_model_talk(): + try: + global sad_talker_model + if sad_talker_model is None: + sad_talker_model = sad_talker() + return True + except: + return False + +def unload_model_talk(): + try: + global sad_talker_model + if sad_talker_model is not None: + sad_talker_model = None + return True + except: + return False + +def get_model_state_talk(): + global sad_talker_model + if sad_talker_model is None: + return False + return True +#model sky +from scripts.SKY import SKY +sky_model = None + +def load_model_sky(): + try: + global sky_model + if sky_model is None: + sky_model = SKY() + return True + except: + return False + +def unload_model_sky(): + try: + global sky_model + if sky_model is not None: + sky_model = None + return True + except: + return False + +def get_model_state_sky(): + global sky_model + if sky_model is None: + return False + return True +#model up +model_upscale = None + +def load_model_up(): + try: + model_path_upscale = './models/RRDB_ESRGAN_x4.pth' + global model_upscale + model_upscale = arch.RRDBNet(3, 3, 64, 23, gc=32) + model_upscale.load_state_dict(torch.load(model_path_upscale), strict=True) + + model_upscale = model_upscale.half() + model_upscale.eval() + model_upscale = model_upscale.to(device) + return True + except: + return False + + +def get_model_state_up(): + global model_upscale + if model_upscale is None: + return False + else: + return True + + +def 
unload_model_up(): + if get_model_state_up() == False: + return True + try: + global model_upscale + model_upscale = None + torch.cuda.empty_cache() + return True + except: + return False +#agr slider helper +def run_alignment(image_path): + predictor = dlib.shape_predictor("./models/shape_predictor_68_face_landmarks.dat") + aligned_image = align_face(filepath=image_path, predictor=predictor) + return aligned_image + +def run_on_batch(inputs, model_age_slider): + result_batch = model_age_slider(inputs.to("cuda").float(), randomize_noise=False, resize=False) + return result_batch + + +#model model_esrgan +model_esrgan = None + +def load_model_esrgan(): + try: + global model_esrgan + model_esrgan = RealESRGAN( torch.device(device), scale=2) + model_esrgan.load_weights('./models/realesrgan/RealESRGAN_x2plus.pth', download=False) + + return True + except: + return False + + +def get_model_state_esrgan(): + global model_esrgan + if model_esrgan is None: + return False + else: + return True + + +def unload_model_esrgan(): + if get_model_state_esrgan() == False: + return True + try: + global model_esrgan + model_esrgan = None + torch.cuda.empty_cache() + return True + except: + return False + +#model lp +def is_port_in_use(port, host='localhost'): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.settimeout(1) + try: + s.bind((host, port)) + except socket.error: + return True + return False + +import requests +def is_server_ready(url, timeout=30): + start_time = time.time() + while True: + try: + response = requests.get(url) + if response.status_code == 200: + return True + except requests.RequestException: + pass + + if time.time() - start_time > timeout: + return False + time.sleep(1) + +model_port_lp = None + +def load_port_lp(): + try: + if not get_model_port_lp(): + port_to_check = 60000 + con_count = 0 + while(True): + if con_count > 100: + break + if is_port_in_use(port_to_check) or (search_port(port_to_check)): + port_to_check = port_to_check + 1 + con_count = con_count + 1 + else: + break + port_to_check = str(port_to_check) + script_dir = "/home/qtech/Desktop/API_v1/subprocess/LivePortrait" + + command = ["/home/qtech/miniconda3/envs/LivePortrait/bin/python", script_dir+"/app.py", "--server_port", port_to_check] + process = subprocess.Popen(command, cwd=script_dir) + + global model_port_lp + model_port_lp = {"port":port_to_check , "pid":process} + + insert_port(port_to_check) + + is_server_ready(f'http://127.0.0.1:{port_to_check}/') + + + return True + except: + return False + +def unload_port_lp(): + try: + if get_model_port_lp(): + global model_port_lp + port = model_port_lp['port'] + remove_port(port) + pid = model_port_lp['pid'] + pid.terminate() + pid.kill() + model_port_lp = None + return True + else: + return True + except: + return False + +def get_model_port_lp(): + global model_port_lp + if model_port_lp is None: + return False + return True + +def unload_model_all(): + try: + unload_model_age() + unload_model_bg() + unload_model_cf() + unload_model_up() + unload_model_swap() + unload_model_rest() + unload_model_talk() + unload_model_sky() + unload_model_esrgan() + unload_port_lp() + return True + except: + return False + + +def start_http_file_server(): + directory = './outputs' + port = 8010 + if not os.path.exists(directory): + os.makedirs(directory) + cmd = f"python -m http.server {port} --directory {directory}" + return subprocess.Popen(cmd, shell=True) + +import asyncio +from datetime import datetime + + +time_last_call = None + +async def 
continuous_function():
+    global time_last_call
+    while True:
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        if time_last_call is not None:
+            current_time = datetime.now()
+            time_difference = current_time - time_last_call
+            time_elapsed_minutes = time_difference.total_seconds() / 60
+
+            # idle housekeeping: collect garbage after 1 minute, unload all
+            # models after 5 minutes without an API call
+            if time_elapsed_minutes > 1:
+                gc.collect()
+
+            if time_elapsed_minutes > 5:
+                model_delete_status = unload_model_all()
+
+                print(f"model delete status = {model_delete_status}")
+
+                time_last_call = None
+
+        await asyncio.sleep(60 * 1)
+
+async def start_continuous_function():
+    await continuous_function()
+
+
+from fastapi.responses import HTMLResponse
+from fastapi import Request
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+
+app = FastAPI()
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+async def custom_middleware(request: Request, call_next):
+    client_ip = request.client.host
+    request_url = str(request.url)
+    request_method = request.method
+    endpoint = request.url.path
+
+    log_request(client_ip, request_url, request_method, endpoint)
+
+    if len(search_ban(client_ip)) > 0:
+        return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
+
+    # user_agent = request.headers.get("user-agent", "").lower()
+    # print(user_agent)
+    # if "android" not in user_agent:
+    #     insert_data(client_ip, "Access using unapproved device type")
+    #     return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
+
+    response = await call_next(request)
+
+    # probing unknown endpoints or methods gets the caller banned
+    if (response.status_code == 404) or (response.status_code == 405):
+        insert_ban(client_ip, "Access using unapproved method/endpoint")
+        return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
+
+    return response
+
+app.middleware('http')(custom_middleware)
+
+
+@app.get("/", response_class=HTMLResponse)
+def read_index():
+    index_path = './assests/index.html'
+    with open(index_path, 'r') as file:
+        content = file.read()
+    return HTMLResponse(content=content)
+
+from fastapi.responses import FileResponse
+
+@app.get("/favicon.ico")
+async def get_image():
+    image_path = "./assests/logo.png"
+    return FileResponse(image_path, media_type="image/png")
+
+@app.on_event("startup")
+async def startup_event():
+    asyncio.create_task(start_continuous_function())
+
+@app.on_event("shutdown")
+async def shutdown_event():
+    print("Clearing Models")
+    model_status = unload_model_all()
+    print("Model clear status = " + str(model_status))
+    asyncio.get_event_loop().stop()
+    print("Stopping Server")
+
+ALLOWED_IMAGE_FORMATS = ["jpg", "jpeg", "png"]
+ALLOWED_AUDIO_FORMATS = ["wav", "mp3"]
+
+def delete_uuid(uuid_to_delete):
+    path = "./outputs"
+    directories = []
+    for root, dirs, files in os.walk(path):
+        for dir in dirs:
+            directories.append(os.path.join(root, dir))
+    for i in directories:
+        temp = i.split('/')
+        if temp[-1] == uuid_to_delete:
+            shutil.rmtree(i)
+            return True
+    return False
+
+@app.delete("/Delete/{uuid}")
+async def delete_item(uuid: str):
+    gc.collect()
+    try:
+        result = delete_uuid(uuid)
+        if not result:
+            return {"error": f"UUID '{uuid}' does not exist"}
+
+        return {"message": f"UUID '{uuid}' deleted"}
+
+    except Exception:
+        return {"error": "unexpected error"}
+
+
+@app.put("/ClearModel_ALL")
+async def clearModel_ALL():
+    try:
+        model_delete_status = unload_model_all()
+
+        if model_delete_status:
+            global time_last_call
+            time_last_call = None
+            return {"message": "vram cleared"}
+        else:
+            return {"message": "vram not cleared"}
+
+    except Exception:
+        return {"error": "unexpected error"}
+
+@app.post("/AgeSlider")
+async def age_slider(background_tasks: BackgroundTasks, image: UploadFile = File(...)):
+
+#################
+#################
+    # get time
+    global time_last_call
+    time_last_call = datetime.now()
+##################
+    try:
+
+###################################################################################################################################################################################
+############################################################################### Test Input Image ##################################################################################
+###################################################################################################################################################################################
+        image_format = image.filename.split(".")[-1].lower()
+        if image_format not in ALLOWED_IMAGE_FORMATS:
+            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        contents = await image.read()
+        nparr = np.frombuffer(contents, np.uint8)
+        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+        if image is None:
+            response_message = "Image is empty"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        app_face = FaceAnalysis(name="buffalo_l")
+        app_face.prepare(ctx_id=0, det_size=(640, 640))
+
+        faces = app_face.get(image)
+        if len(faces) == 0:
+            response_message = "Error no face detected"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+        elif len(faces) != 1:
+            response_message = "Error more than 1 face detected"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        info_age = faces[0]['age']
+        info_gender = faces[0]['gender']
+###################################################################################################################################################################################
+############################################################################### Image Processing ##################################################################################
+###################################################################################################################################################################################
+        unique_id = str(uuid.uuid4())
+        output_path = f"./outputs/AGE/{unique_id}"
+        os.makedirs(output_path)
+
+        # loading model
+        if not get_model_state_age():
+            unload_model_all()
+            model_status = load_model_age()
+            print(f"model status =========={model_status}")
+
+        cv2.imwrite(output_path + '/input.png', image)
+        aligned_image = run_alignment(output_path + '/input.png')
+        #aligned_image.save(output_path+"/input_aligned_cropped.png")
+        copy_aligned_image = aligned_image.copy()
+
+        # PIL's resize returns a new image, so the result must be reassigned
+        aligned_image = aligned_image.resize((256, 256))
+        img_transforms = EXPERIMENT_ARGS['transform']
+        input_image = img_transforms(aligned_image)
+
+        target_ages = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
+        age_transformers = [AgeTransformer(target_age=age) for age in target_ages]
+
+        images = []
+
+        for age_transformer in age_transformers:
+            with torch.no_grad():
+                input_image_age = [age_transformer(input_image.cpu()).to('cuda')]
+                input_image_age = torch.stack(input_image_age)
+                result_tensor = run_on_batch(input_image_age, model_age_slider)[0]
+                result_image = tensor2im(result_tensor)
+                images.append(result_image)
+
+        image_paths_temp = []
+
+        result_paths_list = []
+        for idx, i in enumerate(images):
+            image_temp = np.array(i)
+            image_temp = cv2.cvtColor(np.array(image_temp), cv2.COLOR_RGB2BGR)
+            result_path_temp = output_path + "/output_" + str(idx) + ".png"
+            cv2.imwrite(result_path_temp, image_temp)
+
+            image_url_temp = '/'.join(result_path_temp.split('/')[-3:])
+            image_url_temp = f"{web_url}:8001/{image_url_temp}"
+            result_paths_list.append({"age": str(target_ages[idx]), "image_url": image_url_temp})
+
+            image_paths_temp.append(result_path_temp)
+
+        # replace the frame closest to the detected age with the aligned original
+        closest_age = min(target_ages, key=lambda x: abs(x - info_age))
+        closest_age = target_ages.index(closest_age)
+
+        os.remove(f"{output_path}/output_{closest_age}.png")
+        copy_aligned_image = copy_aligned_image.resize((1024, 1024))
+        copy_aligned_image.save(f"{output_path}/output_{closest_age}.png")
+
+        background_tasks.add_task(Age_Slider_Video, unique_id, image_paths_temp)
+
+        os.remove(output_path + '/input.png')
+
+        response_message = "AgeSlider Ran Successfully"
+        response_status = "Successful"
+        return generate_response(response_message, response_status, unique_id, info_age, ["Female", "Male"][info_gender], result_paths_list)
+
+    except Exception as e:
+        if 'output_path' in locals():
+            shutil.rmtree(output_path, ignore_errors=True)
+        return generate_response(f"Unknown Internal Error: {e}", "Failure", '', '', '', [])
+
+def Age_Slider_Video(uuid_input, image_files):
+    temp_file_uuid = str(uuid.uuid4())
+    directory_path = f'./outputs/{temp_file_uuid}'
+    if os.path.exists(directory_path):
+        shutil.rmtree(directory_path)
+    os.mkdir(directory_path)
+
+    from scripts.morph_video import doMorphing
+    doMorphing(image_files, 0.3, 20, f"{directory_path}/output")
+    video_path = f"{directory_path}/output_combined.mp4"
+    gif_path = f"{directory_path}/output.gif"
+
+    shutil.copy(video_path, f"./outputs/AGE/{uuid_input}/output_combined.mp4")
+    shutil.copy(gif_path, f"./outputs/AGE/{uuid_input}/output.gif")
+
+    shutil.rmtree(directory_path)
+
+@app.post("/AgeSliderVideo/{uuid_input}")
+async def ageslidervideo(uuid_input: str):
+##################
+    try:
+        input_path = f"./outputs/AGE/{uuid_input}"
+###################################################################################################################################################################################
+############################################################################### Test Input UUID ##################################################################################
+###################################################################################################################################################################################
+        if not os.path.exists(input_path):
+            response_message = "Invalid UUID, file does not exist"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+###################################################################################################################################################################################
+############################################################################### Processing ##################################################################################
+###################################################################################################################################################################################
+        output_video_path = ''
+        output_gif_path = ''
+        duration = 120
+        start_time = time.time()
+        # poll for the background-generated artifacts until both exist or the
+        # timeout expires; partial results are reported after the loop
+        while time.time() - start_time < duration:
+            if output_video_path == '' and os.path.exists(input_path + "/output_combined.mp4"):
+                output_video_path = f'{web_url}:8001/AGE/{uuid_input}' + "/output_combined.mp4"
+
+            if output_gif_path == '' and os.path.exists(input_path + "/output.gif"):
+                output_gif_path = f'{web_url}:8001/AGE/{uuid_input}' + "/output.gif"
+
+            if output_video_path != '' and output_gif_path != '':
+                response_message = "Video Generated Successfully"
+                response_status = "Successful"
+                return generate_response(response_message, response_status, uuid_input, '', '', [{"video": output_video_path, "gif": output_gif_path}])
+
+            await asyncio.sleep(1)
+
+        if output_video_path != '' and output_gif_path == '':
+            response_message = "Only Video Generated"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, uuid_input, '', '', [{"video": output_video_path, "gif": output_gif_path}])
+
+        if output_gif_path != '' and output_video_path == '':
+            response_message = "Only GIF Generated"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, uuid_input, '', '', [{"video": output_video_path, "gif": output_gif_path}])
+
+        response_message = "Timeout Reached, Retry Later"
+        response_status = "Failure"
+        return generate_response(response_message, response_status, uuid_input, '', '', [{"video": output_video_path, "gif": output_gif_path}])
+
+    except Exception as e:
+        return generate_response(f"Unknown Internal Error: {e}", "Failure", '', '', '', [])
+
+
+@app.post("/BG_remove")
+async def bg(image: UploadFile = File(...)):
+
+#################
+#################
+    # get time
+    global time_last_call
+    time_last_call = datetime.now()
+##################
+    try:
+
+###################################################################################################################################################################################
+############################################################################### Test Input Image ##################################################################################
+###################################################################################################################################################################################
+        image_format = image.filename.split(".")[-1].lower()
+        if image_format not in ALLOWED_IMAGE_FORMATS:
+            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        contents = await image.read()
+        nparr = np.frombuffer(contents, np.uint8)
+        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+        if image is None:
+            response_message = "Image is empty"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+###################################################################################################################################################################################
+############################################################################### Image Processing ##################################################################################
+###################################################################################################################################################################################
+        unique_id = str(uuid.uuid4())
+        output_path = f"./outputs/BG/{unique_id}"
+
+        # loading model
+        if not get_model_state_bg():
+            unload_model_all()
+            model_status = load_model_bg()
+            print(f"model status =========={model_status}")
+
+        output_image = bg_model.BG_remove(image)
+
+        os.makedirs(output_path)
+        output_image_path = output_path + '/output.png'
+        cv2.imwrite(output_image_path, output_image)
+
+        response_message = "BG Removed Successfully"
+        response_status = "Successful"
+        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/BG/{unique_id}/output.png"])
+
+    except Exception as e:
+        if 'output_path' in locals():
+            shutil.rmtree(output_path, ignore_errors=True)
+        return generate_response(f"Unknown Internal Error: {e}", "Failure", '', '', '', [])
+
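+# CodeFormer face restoration: detect faces, restore each aligned crop with
+# the CodeFormer transformer, then paste the crops back, optionally enhancing
+# the background with RealESRGAN. Example call (a sketch; host, port, and
+# file name are assumptions):
+#
+#   import requests
+#   with open("portrait.jpg", "rb") as f:
+#       r = requests.post("http://localhost:8080/CodeFormer", files={"image": f})
+#   print(r.json()["data"]["metadata"])   # URL of the restored image
+@app.post("/CodeFormer")
+async def codeformer(image: UploadFile = File(...)):
+
+#################
+#################
+    # get time
+    global time_last_call
+    time_last_call = datetime.now()
+##################
+    try:
+
+###################################################################################################################################################################################
+############################################################################### Test Input Image ##################################################################################
+###################################################################################################################################################################################
+        image_format = image.filename.split(".")[-1].lower()
+        if image_format not in ALLOWED_IMAGE_FORMATS:
+            response_message = f"Unsupported image format. 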
Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}" + response_status = "Failure" + return generate_response(response_message , response_status ,'' , '' , '' , []) + + contents = await image.read() + nparr = np.frombuffer(contents, np.uint8) + image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + + if image is None: + response_message = "Image is empty" + response_status = "Failure" + return generate_response(response_message , response_status ,'' , '' , '' , []) + + height = image.shape[0] + width = image.shape[1] + + if (height*width)>(500*500) : + aspect_ratio = width / height + new_height = int(((500*500) / aspect_ratio) ** 0.5) + new_width = int(aspect_ratio * new_height) + image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA) +################################################################################################################################################################################### +############################################################################### Image Processing ################################################################################## +################################################################################################################################################################################### + + + + + + unique_id = str(uuid.uuid4()) + output_path = f"./outputs/CF/{unique_id}" + + + # loading model + if not get_model_state_cf(): + unload_model_all() + model_status = load_model_cf() + print(f"model status =========={model_status}") + + + #image = cv2.imread(output_path+'/input.png') + + + has_aligned = False + only_center_face = False + draw_box = False + detection_model = "retinaface_resnet50" + background_enhance = True + face_upsample = True + upscale = 2 + codeformer_fidelity = 0.5 + img = image#cv2.imread(str(img_path), cv2.IMREAD_COLOR) + upscale = int(upscale) # convert type to int + if upscale > 4: # avoid memory exceeded due to too large upscale + upscale = 4 + if upscale > 2 and max(img.shape[:2])>1000: # avoid memory exceeded due to too large img resolution + upscale = 2 + if max(img.shape[:2]) > 1500: # avoid memory exceeded due to too large img resolution + upscale = 1 + background_enhance = False + face_upsample = False + face_helper = FaceRestoreHelper( + upscale, + face_size=512, + crop_ratio=(1, 1), + det_model=detection_model, + save_ext="png", + use_parse=True, + device=device, + ) + bg_upsampler = model_upsampler if background_enhance else None + face_upsampler = model_upsampler if face_upsample else None + if has_aligned: + # the input faces are already cropped and aligned + img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR) + face_helper.is_gray = is_gray(img, threshold=5) + if face_helper.is_gray: + print('\tgrayscale input: True') + face_helper.cropped_faces = [img] + else: + face_helper.read_image(img) + # get face landmarks for each face + num_det_faces = face_helper.get_face_landmarks_5( + only_center_face=only_center_face, resize=640, eye_dist_threshold=5 + ) + print(f'\tdetect {num_det_faces} faces') + # align and warp each face + face_helper.align_warp_face() + + # face restoration for each cropped face + for idx, cropped_face in enumerate(face_helper.cropped_faces): + # prepare data + cropped_face_t = img2tensor( + cropped_face / 255.0, bgr2rgb=True, float32=True + ) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(device) + try: + with torch.no_grad(): + output = 
model_code_former(cropped_face_t, w=codeformer_fidelity, adain=True)[0]
+                restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+                del output
+                torch.cuda.empty_cache()
+            except RuntimeError as error:
+                print(f"Failed inference for CodeFormer: {error}")
+                restored_face = tensor2img(
+                    cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
+                )
+            restored_face = restored_face.astype("uint8")
+            face_helper.add_restored_face(restored_face)
+
+        # paste_back
+        if not has_aligned:
+            # upsample the background
+            if bg_upsampler is not None:
+                # Now only support RealESRGAN for upsampling background
+                bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
+            else:
+                bg_img = None
+            face_helper.get_inverse_affine(None)
+            # paste each restored face to the input image
+            if face_upsample and face_upsampler is not None:
+                restored_img = face_helper.paste_faces_to_input_image(
+                    upsample_img=bg_img,
+                    draw_box=draw_box,
+                    face_upsampler=face_upsampler,
+                )
+            else:
+                restored_img = face_helper.paste_faces_to_input_image(
+                    upsample_img=bg_img, draw_box=draw_box
+                )
+
+        output_image_path = output_path + '/output.png'
+        os.makedirs(output_path)
+        cv2.imwrite(output_image_path, restored_img)
+
+        response_message = "CodeFormer Ran Successfully"
+        response_status = "Successful"
+        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/CF/{unique_id}/output.png"])
+
+    except Exception as e:
+        if 'output_path' in locals():
+            shutil.rmtree(output_path, ignore_errors=True)
+        return generate_response(f"Unknown Internal Error: {e}", "Failure", '', '', '', [])
+
+@app.post("/FaceSwap_single_image")
+async def Swap_single_image(image: UploadFile = File(...)):
+
+#################
+#################
+    # get time
+    global time_last_call
+    time_last_call = datetime.now()
+##################
+    try:
+
+###################################################################################################################################################################################
+############################################################################### Test Input Image ##################################################################################
+###################################################################################################################################################################################
+        image_format = image.filename.split(".")[-1].lower()
+        if image_format not in ALLOWED_IMAGE_FORMATS:
+            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        contents = await image.read()
+        nparr = np.frombuffer(contents, np.uint8)
+        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+        if image is None:
+            response_message = "Image is empty"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+###################################################################################################################################################################################
+############################################################################### Image Processing ##################################################################################
+###################################################################################################################################################################################
+        unique_id = str(uuid.uuid4())
+        output_path = f"./outputs/SWAP/{unique_id}"
+
+        # loading model
+        if not get_model_state_swap():
+            unload_model_all()
+            model_status = load_model_swap()
+            print(f"model status =========={model_status}")
+
+        faces = app_model.get(image)
+
+        if len(faces) == 0:
+            response_message = "Error no face detected"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+        elif len(faces) != 2:
+            response_message = "Error expected exactly 2 faces in the image"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        face1 = faces[0]
+        face2 = faces[1]
+
+        # swap the two detected faces with each other
+        image = swap_model.get(image, face1, face2, paste_back=True)
+        image = swap_model.get(image, face2, face1, paste_back=True)
+
+        os.makedirs(output_path)
+        output_image_path = output_path + '/output.png'
+        cv2.imwrite(output_image_path, image)
+
+        response_message = "Face Swapped Successfully"
+        response_status = "Successful"
+        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/SWAP/{unique_id}/output.png"])
+
+    except Exception as e:
+        if 'output_path' in locals():
+            shutil.rmtree(output_path, ignore_errors=True)
+        return generate_response(f"Unknown Internal Error: {e}", "Failure", '', '', '', [])
+
+
+@app.post("/FaceSwap_two_images")
+async def swap_two_image(image1: UploadFile = File(...), image2: UploadFile = File(...)):
+
+#################
+#################
+    # get time
+    global time_last_call
+    time_last_call = datetime.now()
+##################
+    try:
+
+###################################################################################################################################################################################
+############################################################################### Test Input Image ##################################################################################
+###################################################################################################################################################################################
+        images = []
+        for idx, image in enumerate([image1, image2]):
+            image_format = image.filename.split(".")[-1].lower()
+            if image_format not in ALLOWED_IMAGE_FORMATS:
+                response_message = f"Unsupported image{str(idx+1)} format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
+                response_status = "Failure"
+                return generate_response(response_message, response_status, '', '', '', [])
+
+            contents = await image.read()
+            nparr = np.frombuffer(contents, np.uint8)
+            image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+            if image is None:
+                response_message = f"Image{str(idx+1)} is empty"
+                response_status = "Failure"
+                return generate_response(response_message, response_status, '', '', '', [])
+
+            images.append(image)
+###################################################################################################################################################################################
+############################################################################### Image Processing ##################################################################################
+###################################################################################################################################################################################
+        unique_id = str(uuid.uuid4())
+        output_path = f"./outputs/SWAP/{unique_id}"
+
+        # loading model
+        if not get_model_state_swap():
+            unload_model_all()
+            model_status = load_model_swap()
+            print(f"model status =========={model_status}")
+
+        faces = []
+        faces.append(app_model.get(images[0]))
+        faces.append(app_model.get(images[1]))
+
+        for idx, face in enumerate(faces):
+            if len(face) == 0:
+                response_message = f"Error no face detected in image{idx+1}"
+                response_status = "Failure"
+                return generate_response(response_message, response_status, '', '', '', [])
+            elif len(face) > 1:
+                response_message = f"Error more than 1 face detected in image{idx+1}"
+                response_status = "Failure"
+                return generate_response(response_message, response_status, '', '', '', [])
+
+        face1 = faces[0][0]
+        face2 = faces[1][0]
+
+        # put the face from image2 onto image1
+        image = swap_model.get(images[0], face1, face2, paste_back=True)
+
+        os.makedirs(output_path)
+        output_image_path = output_path + '/output.png'
+        cv2.imwrite(output_image_path, image)
+
+        response_message = "Face Swapped Successfully"
+        response_status = "Successful"
+        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/SWAP/{unique_id}/output.png"])
+
+    except Exception as e:
+        return generate_response(f"Unknown Internal Error: {e}", "Failure", '', '', '', [])
+
+@app.post("/Restore_Images")
+async def restore(image: UploadFile = File(...)):
+
+#################
+#################
+    # get time
+    global time_last_call
+    time_last_call = datetime.now()
+##################
+    try:
+
+###################################################################################################################################################################################
+############################################################################### Test Input Image ##################################################################################
+###################################################################################################################################################################################
+        image_format = image.filename.split(".")[-1].lower()
+        if image_format not in ALLOWED_IMAGE_FORMATS:
+            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        contents = await image.read()
+        nparr = np.frombuffer(contents, np.uint8)
+        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+        if image is None:
+            response_message = "Image is empty"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+###################################################################################################################################################################################
+############################################################################### Image Processing ##################################################################################
+###################################################################################################################################################################################
+        unique_id = str(uuid.uuid4())
+        output_path = f"./outputs/Restore/{unique_id}"
+        os.makedirs(output_path)
+
+        cv2.imwrite(output_path + '/input.png', image)
+
+        # loading model
+        if not get_model_state_rest():
+            unload_model_all()
+            model_status = load_model_rest()
+            print(f"model status =========={model_status}")
+
+        image = PIL.Image.open(output_path + '/input.png')
+        os.remove(output_path + '/input.png')
+
+        # detect scratches, dilate the mask, then inpaint with Pix2PixHD
+        transformed, mask = model_scratches_remove_detector.process(image)
+        img_transform = transforms.Compose(
+            [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+        )
+        mask_transform = transforms.ToTensor()
+        if model_scratches_remove_options.mask_dilation != 0:
+            kernel = np.ones((3, 3), np.uint8)
+            mask = np.array(mask)
+            mask = cv2.dilate(mask, kernel, iterations=model_scratches_remove_options.mask_dilation)
+            mask = PIL.Image.fromarray(mask.astype('uint8'))
+        transformed = irregular_hole_synthesize(transformed, mask)
+        mask = mask_transform(mask)
+        mask = mask[:1, :, :]
+        mask = mask.unsqueeze(0)
+        transformed = img_transform(transformed)
+        transformed = transformed.unsqueeze(0)
+        generated = model_scratches_remove.inference(transformed, mask)
+        tensor_restored = (generated.data.cpu() + 1.0) / 2.0
+        image_to_show = tensor_restored.squeeze().cpu().numpy().transpose((1, 2, 0))
+
+        image_to_show = (image_to_show * 255).astype(np.uint8)[:, :, ::-1]
+
+        from scripts.colorizer import colorize_image
+        image_to_show = colorize_image(image_to_show)
+
+        output_image_path = output_path + '/output.png'
+        cv2.imwrite(output_image_path, image_to_show)
+
+        response_message = "Restored Image Successfully"
+        response_status = "Successful"
+        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/Restore/{unique_id}/output.png"])
+
+    except Exception as e:
+        if 'output_path' in locals():
+            shutil.rmtree(output_path, ignore_errors=True)
+        return generate_response(f"Unknown Internal Error: {e}", "Failure", '', '', '', [])
+
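+# Sad Talker: animates the uploaded face image so it appears to speak the
+# uploaded audio clip, then re-encodes the result with ffmpeg. Example call
+# (a sketch; host, port, and file names are assumptions):
+#
+#   import requests
+#   files = {"image": open("face.png", "rb"), "audio": open("speech.wav", "rb")}
+#   r = requests.post("http://localhost:8080/Sad_Talker", files=files)
+@app.post("/Sad_Talker")
+async def sadtalker(image: UploadFile = File(...), audio: UploadFile = File(...)):
+
+#################
+#################
+    # get time
+    global time_last_call
+    time_last_call = datetime.now()
+##################
+    try:
+
+###################################################################################################################################################################################
+############################################################################### Test Input Image 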
################################################################################## +################################################################################################################################################################################### + image_format = image.filename.split(".")[-1].lower() + if image_format not in ALLOWED_IMAGE_FORMATS: + response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}" + response_status = "Failure" + return generate_response(response_message , response_status ,'' , '' , '' , []) + + contents = await image.read() + nparr = np.frombuffer(contents, np.uint8) + image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + + if image is None: + response_message = "Image is empty" + response_status = "Failure" + return generate_response(response_message , response_status ,'' , '' , '' , []) +################################################################################################################################################################################### +############################################################################### Test Input Audio ################################################################################## +################################################################################################################################################################################### + audio_format = audio.filename.split(".")[-1].lower() + if audio_format not in ALLOWED_AUDIO_FORMATS: + response_message = f"Unsupported audio format. Supported formats: {', '.join(ALLOWED_AUDIO_FORMATS)}" + response_status = "Failure" + return generate_response(response_message, response_status, '', '', '', []) + + audio_contents = await audio.read() + if len(audio_contents) == 0: + response_message = "Audio file is empty" + response_status = "Failure" + return generate_response(response_message, response_status, '', '', '', []) +################################################################################################################################################################################### +############################################################################### Image Processing ################################################################################## +################################################################################################################################################################################### + + + + + + unique_id = str(uuid.uuid4()) + + + output_path = f"./outputs/Sad_Talker/{unique_id}" + + os.makedirs(output_path) + + image_input_path = output_path+'/input.png' + cv2.imwrite(image_input_path,image) + + audio_segment = AudioSegment.from_file(io.BytesIO(audio_contents), format=audio_format) + audio_input_path = os.path.join(output_path, f"audio.{audio_format}") + audio_segment.export(audio_input_path, format=audio_format) + + app_face = FaceAnalysis(name="buffalo_l") + app_face.prepare(ctx_id=0,det_size=(640,640)) + + faces = app_face.get(image) + if len(faces)==0: + response_message = "Error no face detected" + response_status = "Failure" + return generate_response(response_message , response_status ,'' , '' , '' , []) + + elif len(faces)!=1: + response_message = "Error more than 1 face detected" + response_status = "Failure" + + return generate_response(response_message , response_status ,'' , '' , '' , []) + + info_age = faces[0]['age'] + info_gender = faces[0]['gender'] + + + # loading model + if not 
+        # loading model
+        if not get_model_state_talk():
+            unload_model_all()
+            model_status = load_model_talk()
+            print(f"model status =========={model_status}")
+
+        response = sad_talker_model.genrate_video(image_input_path, audio_input_path, output_path, True)
+
+        # input_file = f"{output_path}/output.mp4 "
+        # subprocess.run(['ffmpeg', '-i', input_file, '-c:v', 'libx264', '-preset', 'slow', '-crf', '23', '-c:a', 'aac', '-b:a', '192k', '-y', input_file], check=True)
+
+        os.remove(image_input_path)
+        os.remove(audio_input_path)
+
+        if response:
+            input_file = f"outputs/Sad_Talker/{unique_id}/output.mp4"
+            output_file = f"outputs/Sad_Talker/{unique_id}/final.mp4"
+
+            # re-encode to H.264/AAC so the result plays in browsers
+            subprocess.run(['ffmpeg', '-i', input_file, '-c:v', 'libx264', '-preset', 'slow', '-crf', '23', '-c:a', 'aac', '-b:a', '192k', output_file], check=True)
+
+            os.remove(input_file)
+
+            response_message = "Video Generated Successfully"
+            response_status = "Successful"
+            return generate_response(response_message, response_status, unique_id, info_age, ["Female", "Male"][info_gender], [f"{web_url}:8001/Sad_Talker/{unique_id}/final.mp4"])
+
+        else:
+            response_message = "Error Generating Video, try another image"
+            response_status = "Failure"
+            # shutil.rmtree(output_path)
+            return generate_response(response_message, response_status, '', '', '', [])
+
+    except Exception as e:
+        if 'output_path' in locals() and os.path.isdir(output_path):
+            shutil.rmtree(output_path)
+        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
+
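+# ------------------------------------------------------------------------------
+# Hedged client sketch for POST /Sad_Talker above: the request must carry two
+# multipart parts, `image` and `audio`; host and port are assumed from the
+# __main__ block below, and the file names are illustrative.
+#
+#   import requests
+#   files = {"image": open("face.png", "rb"), "audio": open("speech.wav", "rb")}
+#   r = requests.post("http://localhost:8080/Sad_Talker", files=files)
+#   print(r.json())  # includes the detected age, gender and the video URL
+# ------------------------------------------------------------------------------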
+@app.post("/SKY_remove")
+async def sky(image: UploadFile = File(...)):
+
+    # record the time of the last call
+    global time_last_call
+    time_last_call = datetime.now()
+
+    try:
+################################################################################################
+####################################### Test Input Image ########################################
+################################################################################################
+        image_format = image.filename.split(".")[-1].lower()
+        if image_format not in ALLOWED_IMAGE_FORMATS:
+            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        contents = await image.read()
+        nparr = np.frombuffer(contents, np.uint8)
+        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+        if image is None:
+            response_message = "Image is empty"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+################################################################################################
+####################################### Image Processing ########################################
+################################################################################################
+
+        unique_id = str(uuid.uuid4())
+        output_path = f"./outputs/SKY/{unique_id}"
+
+        # loading model
+        if not get_model_state_sky():
+            unload_model_all()
+            model_status = load_model_sky()
+            print(f"model status =========={model_status}")
+
+        output_image = sky_model.SKY_remove(image)
+
+        os.makedirs(output_path)
+        output_image_path = output_path + '/output.png'
+        cv2.imwrite(output_image_path, output_image)
+
+        response_message = "SKY Removed Successfully"
+        response_status = "Successful"
+        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/SKY/{unique_id}/output.png"])
+
+    except Exception as e:
+        if 'output_path' in locals() and os.path.isdir(output_path):
+            shutil.rmtree(output_path)
+        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
+
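+# ------------------------------------------------------------------------------
+# The endpoints above and below all use the same lazy model-swap pattern: check
+# whether their own model is resident, otherwise evict everything and load it,
+# so only one model family occupies the GPU at a time. A generic sketch of that
+# pattern (the helper name `ensure_loaded` is illustrative, not part of this
+# file):
+#
+#   def ensure_loaded(get_state, load):
+#       if not get_state():          # is this endpoint's model resident?
+#           unload_model_all()       # free memory held by the other models
+#           print(f"model status =========={load()}")
+# ------------------------------------------------------------------------------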
+@app.post("/ImageUpscale")
+async def upscale(image: UploadFile = File(...)):
+
+    # record the time of the last call
+    global time_last_call
+    time_last_call = datetime.now()
+
+    try:
+################################################################################################
+####################################### Test Input Image ########################################
+################################################################################################
+        image_format = image.filename.split(".")[-1].lower()
+        if image_format not in ALLOWED_IMAGE_FORMATS:
+            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        contents = await image.read()
+        nparr = np.frombuffer(contents, np.uint8)
+        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+        if image is None:
+            response_message = "Image is empty"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        # cap the input at a 200*200 pixel budget, preserving aspect ratio
+        height = image.shape[0]
+        width = image.shape[1]
+
+        if (height * width) > (200 * 200):
+            aspect_ratio = width / height
+            new_height = int(((200 * 200) / aspect_ratio) ** 0.5)
+            new_width = int(aspect_ratio * new_height)
+            image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
+################################################################################################
+####################################### Image Processing ########################################
+################################################################################################
+
+        unique_id = str(uuid.uuid4())
+        output_path = f"./outputs/Upscale/{unique_id}"
+
+        # loading model
+        if not get_model_state_up():
+            unload_model_all()
+            model_status = load_model_up()
+            print(f"model status =========={model_status}")
+
+        #image = cv2.imread(output_path+'/input.png')
+
+        # BGR -> RGB, HWC -> CHW, scale to [0, 1], run in half precision
+        image = image * 1.0 / 255
+        image = torch.from_numpy(np.transpose(image[:, :, [2, 1, 0]], (2, 0, 1))).float()
+        img_LR = image.unsqueeze(0)
+        img_LR = img_LR.to(device).half()
+
+        output = model_upscale(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
+
+        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
+        output = (output * 255.0).round()
+        image = output.astype(np.uint8)
+
+        output_image_path = output_path + '/output.png'
+        os.makedirs(output_path)
+        cv2.imwrite(output_image_path, image)
+
+        response_message = "Upscaled Image Successfully"
+        response_status = "Successful"
+        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/Upscale/{unique_id}/output.png"])
+
+    except Exception as e:
+        if 'output_path' in locals() and os.path.isdir(output_path):
+            shutil.rmtree(output_path)
+        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
+
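+# ------------------------------------------------------------------------------
+# The pre-resize above caps an image at a fixed pixel budget while keeping its
+# aspect ratio: with a = w / h and the constraint w * h <= B, solving gives
+# h = sqrt(B / a) and w = a * h. A standalone sketch of the same computation
+# (the helper name `fit_to_pixel_budget` is illustrative):
+#
+#   def fit_to_pixel_budget(img, budget):
+#       h, w = img.shape[:2]
+#       if h * w <= budget:
+#           return img
+#       aspect = w / h
+#       new_h = int((budget / aspect) ** 0.5)
+#       new_w = int(aspect * new_h)
+#       return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
+#
+# e.g. fit_to_pixel_budget(image, 200 * 200) for /ImageUpscale above and
+# fit_to_pixel_budget(image, 500 * 500) for /ImageDenoise below.
+# ------------------------------------------------------------------------------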
+@app.post("/ImageDenoise")
+async def denoise(image: UploadFile = File(...)):
+
+    # record the time of the last call
+    global time_last_call
+    time_last_call = datetime.now()
+
+    try:
+################################################################################################
+####################################### Test Input Image ########################################
+################################################################################################
+        image_format = image.filename.split(".")[-1].lower()
+        if image_format not in ALLOWED_IMAGE_FORMATS:
+            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        contents = await image.read()
+        nparr = np.frombuffer(contents, np.uint8)
+        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+        if image is None:
+            response_message = "Image is empty"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        # cap the input at a 500*500 pixel budget, preserving aspect ratio
+        height = image.shape[0]
+        width = image.shape[1]
+
+        if (height * width) > (500 * 500):
+            aspect_ratio = width / height
+            new_height = int(((500 * 500) / aspect_ratio) ** 0.5)
+            new_width = int(aspect_ratio * new_height)
+            image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
+################################################################################################
+####################################### Image Processing ########################################
+################################################################################################
+
+        unique_id = str(uuid.uuid4())
+        output_path = f"./outputs/Denoise/{unique_id}"
+
+        # loading model
+        if not get_model_state_esrgan():
+            unload_model_all()
+            model_status = load_model_esrgan()
+            print(f"model status =========={model_status}")
+
+        #image = cv2.imread(output_path+'/input.png')
+
+        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        image_pil = Image.fromarray(image_rgb)
+        global model_esrgan
+        result_pil = model_esrgan.predict(image_pil)
+
+        result_rgb = result_pil.convert('RGB')
+        result_np = np.array(result_rgb)
+        image = cv2.cvtColor(result_np, cv2.COLOR_RGB2BGR)
+
+        output_image_path = output_path + '/output.png'
+        os.makedirs(output_path)
+        cv2.imwrite(output_image_path, image)
+
+        response_message = "Denoised Image Successfully"
+        response_status = "Successful"
+        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/Denoise/{unique_id}/output.png"])
+
+    except Exception as e:
+        if 'output_path' in locals() and os.path.isdir(output_path):
+            shutil.rmtree(output_path)
+        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
+
+@app.post("/LivePortrait/{index}")
+async def lp(index: int, image: UploadFile = File(...)):
+
+    # record the time of the last call
+    global time_last_call
+    time_last_call = datetime.now()
+
+    try:
+################################################################################################
+####################################### Test Input Image ########################################
+################################################################################################
+        image_format = image.filename.split(".")[-1].lower()
+        if image_format not in ALLOWED_IMAGE_FORMATS:
+            response_message = f"Unsupported image format. Supported formats: {', '.join(ALLOWED_IMAGE_FORMATS)}"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        contents = await image.read()
+        nparr = np.frombuffer(contents, np.uint8)
+        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+        if image is None:
+            response_message = "Image is empty"
+            response_status = "Failure"
+            return generate_response(response_message, response_status, '', '', '', [])
+
+        # cap the input at a 1500*1500 pixel budget, preserving aspect ratio
+        height = image.shape[0]
+        width = image.shape[1]
+
+        if (height * width) > (1500 * 1500):
+            aspect_ratio = width / height
+            new_height = int(((1500 * 1500) / aspect_ratio) ** 0.5)
+            new_width = int(aspect_ratio * new_height)
+            image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
+################################################################################################
+####################################### Image Processing ########################################
+################################################################################################
+
+        unique_id = str(uuid.uuid4())
+        output_path = f"./outputs/LP/{unique_id}"
+        os.makedirs(output_path)
+
+        image_input_path = output_path + '/input.png'
+        cv2.imwrite(image_input_path, image)
+
+        import glob
+        directory_path = "./subprocess/LivePortrait/assets/examples/driving"
+        mp4_files = glob.glob(os.path.join(directory_path, "*.mp4"))
+
+        # loading model
+        if not get_model_port_lp():
+            unload_model_all()
+            model_status = load_port_lp()
+            print(f"model status =========={model_status}")
+
+        # the LivePortrait worker runs as a separate Gradio app; call it over HTTP
+        global model_port_lp
+        client = Client(f"http://127.0.0.1:{model_port_lp['port']}/")
+
+        result = client.predict(
+            param_0=handle_file(os.path.abspath(output_path + '/input.png')),
+            param_1={"video": handle_file(mp4_files[index])},
+            param_2=True,
+            param_3=True,
+            param_4=True,
+            param_5=False,
+            api_name="/gpu_wrapped_execute_video"
+        )
+
+        shutil.copy(result[0]['video'], output_path + '/output.mp4')
+
+        # clean up the temporary directories the Gradio client left behind
+        for i in result:
+            root = os.path.dirname(i['video'])
+            if os.path.isdir(root):
+                shutil.rmtree(root)
+
+        os.remove(output_path + '/input.png')
+
+        response_message = "LP Generated Successfully"
+        response_status = "Successful"
+        return generate_response(response_message, response_status, unique_id, '', '', [f"{web_url}:8001/LP/{unique_id}/output.mp4"])
+
+    except Exception as e:
+        if 'output_path' in locals() and os.path.isdir(output_path):
+            shutil.rmtree(output_path)
+        return generate_response(f"Unknown Internal Error : {e}", "Failure", '', '', '', [])
+
+# if __name__ == '__main__':
+
+#     import uvicorn
+#     try:
+#         http_file_server_process = start_http_file_server()
+#         uvicorn.run("main:app", host='0.0.0.0', port=8080, workers=1, limit_max_requests=20)
+#     except KeyboardInterrupt:
+#         print("Received KeyboardInterrupt, shutting down gracefully...")
+#         http_file_server_process.kill()
+#         shutdown_event()
+
+if __name__ == '__main__':
+
+    import uvicorn
+    import os
+    try:
+        # reset the port registry from any previous run, then recreate it
+        if os.path.exists('./subprocess/port.db'):
+            os.remove('./subprocess/port.db')
+        create_database_port()
+
+        http_file_server_process = start_http_file_server()
+        uvicorn.run("main:app", host='0.0.0.0', port=8080, workers=2, limit_max_requests=200)
+        # uvicorn.run() blocks; these lines run once the server has stopped
+        print("Server stopped, shutting down gracefully...")
+        http_file_server_process.kill()
+        shutdown_event()
+ except KeyboardInterrupt: + print("Received KeyboardInterrupt, shutting down gracefully...") + http_file_server_process.kill() + shutdown_event() diff --git a/models/CodeFormer/.gitkeep b/models/CodeFormer/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/CodeFormer/codeformer.pth b/models/CodeFormer/codeformer.pth new file mode 100644 index 0000000000000000000000000000000000000000..edd450da13c5ff890f70d726c992af569813f6af --- /dev/null +++ b/models/CodeFormer/codeformer.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1009e537e0c2a07d4cabce6355f53cb66767cd4b4297ec7a4a64ca4b8a5684b7 +size 376637898 diff --git a/models/RMBG-1.4/.gitattributes b/models/RMBG-1.4/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..85d164dd4039759357c6f4f14bb467c6b86bbedd --- /dev/null +++ b/models/RMBG-1.4/.gitattributes @@ -0,0 +1,41 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +example.png filter=lfs diff=lfs merge=lfs -text +results.png filter=lfs diff=lfs merge=lfs -text +Screenshot[[:space:]]2024-01-21[[:space:]]at[[:space:]]11.56.17.png filter=lfs diff=lfs merge=lfs -text +T1.png filter=lfs diff=lfs merge=lfs -text +T2.png filter=lfs diff=lfs merge=lfs -text +t4.png filter=lfs diff=lfs merge=lfs -text diff --git a/models/RMBG-1.4/MyConfig.py b/models/RMBG-1.4/MyConfig.py new file mode 100644 index 0000000000000000000000000000000000000000..d50006abdaa687edd18dd39c9fb5ac9a31023773 --- /dev/null +++ b/models/RMBG-1.4/MyConfig.py @@ -0,0 +1,13 @@ +from transformers import PretrainedConfig +from typing import List + +class RMBGConfig(PretrainedConfig): + model_type = "SegformerForSemanticSegmentation" + def __init__( + self, + in_ch=3, + out_ch=1, + **kwargs): + self.in_ch = in_ch + self.out_ch = out_ch + super().__init__(**kwargs) diff --git a/models/RMBG-1.4/MyPipe.py b/models/RMBG-1.4/MyPipe.py new file mode 100644 index 
0000000000000000000000000000000000000000..c3b127caa6e4e4ef7c84e1a92144480660547b6b
--- /dev/null
+++ b/models/RMBG-1.4/MyPipe.py
@@ -0,0 +1,76 @@
+import torch, os
+import torch.nn.functional as F
+from torchvision.transforms.functional import normalize
+import numpy as np
+from transformers import Pipeline
+from transformers.image_utils import load_image
+from skimage import io
+from PIL import Image
+
+class RMBGPipe(Pipeline):
+    def __init__(self, **kwargs):
+        Pipeline.__init__(self, **kwargs)
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.model.to(self.device)
+        self.model.eval()
+
+    def _sanitize_parameters(self, **kwargs):
+        # split incoming kwargs into pre- and post-processing parameters
+        preprocess_kwargs = {}
+        postprocess_kwargs = {}
+        if "model_input_size" in kwargs:
+            preprocess_kwargs["model_input_size"] = kwargs["model_input_size"]
+        if "return_mask" in kwargs:
+            postprocess_kwargs["return_mask"] = kwargs["return_mask"]
+        return preprocess_kwargs, {}, postprocess_kwargs
+
+    def preprocess(self, input_image, model_input_size: list = [1024, 1024]):
+        # preprocess the input
+        orig_im = load_image(input_image)
+        orig_im = np.array(orig_im)
+        orig_im_size = orig_im.shape[0:2]
+        preprocessed_image = self.preprocess_image(orig_im, model_input_size).to(self.device)
+        inputs = {
+            "preprocessed_image": preprocessed_image,
+            "orig_im_size": orig_im_size,
+            "input_image": input_image
+        }
+        return inputs
+
+    def _forward(self, inputs):
+        result = self.model(inputs.pop("preprocessed_image"))
+        inputs["result"] = result
+        return inputs
+
+    def postprocess(self, inputs, return_mask: bool = False):
+        result = inputs.pop("result")
+        orig_im_size = inputs.pop("orig_im_size")
+        input_image = inputs.pop("input_image")
+        result_image = self.postprocess_image(result[0][0], orig_im_size)
+        pil_im = Image.fromarray(result_image)
+        if return_mask:
+            return pil_im
+        no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
+        input_image = load_image(input_image)
+        no_bg_image.paste(input_image, mask=pil_im)
+        return no_bg_image
+
+    # utility functions
+    def preprocess_image(self, im: np.ndarray, model_input_size: list = [1024, 1024]) -> torch.Tensor:
+        # same as utilities.py with minor modification
+        if len(im.shape) < 3:
+            im = im[:, :, np.newaxis]
+        im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
+        im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), size=model_input_size, mode='bilinear')
+        image = torch.divide(im_tensor, 255.0)
+        image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
+        return image
+
+    def postprocess_image(self, result: torch.Tensor, im_size: list) -> np.ndarray:
+        result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
+        ma = torch.max(result)
+        mi = torch.min(result)
+        result = (result - mi) / (ma - mi)
+        im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
+        im_array = np.squeeze(im_array)
+        return im_array
diff --git a/models/RMBG-1.4/README.md b/models/RMBG-1.4/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b2cf762e319cc9cc96209c42b0ac79630f804c12
--- /dev/null
+++ b/models/RMBG-1.4/README.md
@@ -0,0 +1,157 @@
+---
+license: other
+license_name: bria-rmbg-1.4
+license_link: https://bria.ai/bria-huggingface-model-license-agreement/
+pipeline_tag: image-segmentation
+tags:
+- remove background
+- background
+- background-removal
+- Pytorch
+- vision
+- legal liability
+- transformers
+
+extra_gated_prompt: Weights for this model, provided by BRIA AI, can be obtained once a commercial license is agreed upon. Fill in the form below and we will reach out to you.
+extra_gated_fields:
+  Name: text
+  Company/Org name: text
+  Org Type (Early/Growth Startup, Enterprise, Academy): text
+  Role: text
+  Country: text
+  Email: text
+  By submitting this form, I agree to BRIA’s Privacy policy and Terms & conditions, see links below: checkbox
+---
+
+# BRIA Background Removal v1.4 Model Card
+
+RMBG v1.4 is our state-of-the-art background removal model, designed to effectively separate foreground from background across a range of
+categories and image types. It has been trained on a carefully selected dataset that includes
+general stock images, e-commerce, gaming, and advertising content, making it suitable for commercial use cases powering enterprise content creation at scale.
+Its accuracy, efficiency, and versatility currently rival leading source-available models.
+It is ideal where content safety, legally licensed datasets, and bias mitigation are paramount.
+
+Developed by BRIA AI, RMBG v1.4 is available as a source-available model for non-commercial use.
+
+[CLICK HERE FOR A DEMO](https://huggingface.co/spaces/briaai/BRIA-RMBG-1.4)
+![examples](t4.png)
+
+### Model Description
+
+- **Developed by:** [BRIA AI](https://bria.ai/)
+- **Model type:** Background Removal
+- **License:** [bria-rmbg-1.4](https://bria.ai/bria-huggingface-model-license-agreement/)
+  - The model is released under a Creative Commons license for non-commercial use.
+  - Commercial use is subject to a commercial agreement with BRIA. [Contact Us](https://bria.ai/contact-us) for more information.
+
+- **Model Description:** BRIA RMBG 1.4 is a saliency segmentation model trained exclusively on a professional-grade dataset.
+- **BRIA:** Resources for more information: [BRIA AI](https://bria.ai/)
+
+## Training data
+The BRIA-RMBG model was trained with over 12,000 high-quality, high-resolution, manually labeled (pixel-wise accuracy), fully licensed images.
+Our benchmark included balanced gender, balanced ethnicity, and people with different types of disabilities.
+For clarity, we provide our data distribution across several categories, demonstrating the model’s versatility.
+
+### Distribution of images:
+
+| Category                            | Distribution |
+| ----------------------------------- | ------------:|
+| Objects only                        | 45.11%       |
+| People with objects/animals         | 25.24%       |
+| People only                         | 17.35%       |
+| People/objects/animals with text    | 8.52%        |
+| Text only                           | 2.52%        |
+| Animals only                        | 1.89%        |
+
+| Category            | Distribution |
+| ------------------- | ------------:|
+| Photorealistic      | 87.70%       |
+| Non-Photorealistic  | 12.30%       |
+
+| Category             | Distribution |
+| -------------------- | ------------:|
+| Non Solid Background | 52.05%       |
+| Solid Background     | 47.95%       |
+
+| Category                            | Distribution |
+| ----------------------------------- | ------------:|
+| Single main foreground object       | 51.42%       |
+| Multiple objects in the foreground  | 48.58%       |
+
+## Qualitative Evaluation
+
+![examples](results.png)
+
+## Architecture
+
+RMBG v1.4 builds on [IS-Net](https://github.com/xuebinqin/DIS), enhanced with our unique training scheme and a proprietary dataset.
+These modifications significantly improve the model’s accuracy and effectiveness in diverse image-processing scenarios.
+
+## Installation
+```bash
+pip install -qr https://huggingface.co/briaai/RMBG-1.4/resolve/main/requirements.txt
+```
+
+## Usage
+
+Either load the pipeline
+```python
+from transformers import pipeline
+image_path = "https://farm5.staticflickr.com/4007/4322154488_997e69e4cf_z.jpg"
+pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
+pillow_mask = pipe(image_path, return_mask=True)  # outputs a pillow mask
+pillow_image = pipe(image_path)  # applies mask on input and returns a pillow image
+```
+
+Or load the model
+```python
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from skimage import io
+from torchvision.transforms.functional import normalize
+from transformers import AutoModelForImageSegmentation
+
+model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
+
+def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
+    if len(im.shape) < 3:
+        im = im[:, :, np.newaxis]
+    im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
+    im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear')
+    image = torch.divide(im_tensor,255.0)
+    image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
+    return image
+
+def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
+    result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
+    ma = torch.max(result)
+    mi = torch.min(result)
+    result = (result-mi)/(ma-mi)
+    im_array = (result*255).permute(1,2,0).cpu().data.numpy().astype(np.uint8)
+    im_array = np.squeeze(im_array)
+    return im_array
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+model.to(device)
+
+# prepare input
+model_input_size = [1024, 1024]
+image_path = "https://farm5.staticflickr.com/4007/4322154488_997e69e4cf_z.jpg"
+orig_im = io.imread(image_path)
+orig_im_size = orig_im.shape[0:2]
+image = preprocess_image(orig_im, model_input_size).to(device)
+
+# inference
+result = model(image)
+
+# post process
+result_image = postprocess_image(result[0][0], orig_im_size)
+
+# save result (build the original image from the already-downloaded array,
+# since PIL's Image.open cannot read a URL directly)
+pil_im = Image.fromarray(result_image)
+no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
+orig_image = Image.fromarray(orig_im)
+no_bg_image.paste(orig_image, mask=pil_im)
+```
diff --git a/models/RMBG-1.4/briarmbg.py b/models/RMBG-1.4/briarmbg.py
new file mode 100644
index 0000000000000000000000000000000000000000..da8591aafb1fd74c194a665302452a2aa591d161
--- /dev/null
+++ b/models/RMBG-1.4/briarmbg.py
@@ -0,0 +1,458 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import PreTrainedModel
+from .MyConfig import RMBGConfig
+
+class REBNCONV(nn.Module):
+    def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
+        super(REBNCONV,self).__init__()
+
+        self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
+        self.bn_s1 = nn.BatchNorm2d(out_ch)
+        self.relu_s1 = nn.ReLU(inplace=True)
+
+    def forward(self,x):
+
+        hx = x
+        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+
+        return xout
+
+## upsample tensor 'src' to have the same spatial size with tensor 'tar'
+def _upsample_like(src,tar):
+
+    src = F.interpolate(src,size=tar.shape[2:],mode='bilinear')
+
+    return src
+
+
+### RSU-7 ###
+class RSU7(nn.Module):
+
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
+        super(RSU7,self).__init__()
+
+        self.in_ch = in_ch
+        self.mid_ch = mid_ch
+        self.out_ch = out_ch
+
+        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
+
+        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
+        self.pool1 = 
nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + b, c, h, w = x.shape + + hx = x + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + hx = self.pool5(hx5) + + hx6 = self.rebnconv6(hx) + + hx7 = self.rebnconv7(hx6) + + hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1)) + hx6dup = _upsample_like(hx6d,hx5) + + hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1)) + hx5dup = _upsample_like(hx5d,hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + + +### RSU-6 ### +class RSU6(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU6,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + + hx6 = self.rebnconv6(hx5) + + + hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1)) + hx5dup = _upsample_like(hx5d,hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d 
= self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + +### RSU-5 ### +class RSU5(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU5,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + + hx5 = self.rebnconv5(hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + +### RSU-4 ### +class RSU4(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + +### RSU-4F ### +class RSU4F(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4F,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2) + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8) + + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx2 = self.rebnconv2(hx1) + hx3 = self.rebnconv3(hx2) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1)) + hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1)) + hx1d = 
self.rebnconv1d(torch.cat((hx2d,hx1),1)) + + return hx1d + hxin + + +class myrebnconv(nn.Module): + def __init__(self, in_ch=3, + out_ch=1, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=1): + super(myrebnconv,self).__init__() + + self.conv = nn.Conv2d(in_ch, + out_ch, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups) + self.bn = nn.BatchNorm2d(out_ch) + self.rl = nn.ReLU(inplace=True) + + def forward(self,x): + return self.rl(self.bn(self.conv(x))) + + +class BriaRMBG(PreTrainedModel): + config_class = RMBGConfig + def __init__(self,config:RMBGConfig = RMBGConfig()): + super().__init__(config) + in_ch = config.in_ch # 3 + out_ch = config.out_ch # 1 + self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1) + self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage1 = RSU7(64,32,64) + self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage2 = RSU6(64,32,128) + self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage3 = RSU5(128,64,256) + self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage4 = RSU4(256,128,512) + self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage5 = RSU4F(512,256,512) + self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage6 = RSU4F(512,256,512) + + # decoder + self.stage5d = RSU4F(1024,256,512) + self.stage4d = RSU4(1024,128,256) + self.stage3d = RSU5(512,64,128) + self.stage2d = RSU6(256,32,64) + self.stage1d = RSU7(128,16,64) + + self.side1 = nn.Conv2d(64,out_ch,3,padding=1) + self.side2 = nn.Conv2d(64,out_ch,3,padding=1) + self.side3 = nn.Conv2d(128,out_ch,3,padding=1) + self.side4 = nn.Conv2d(256,out_ch,3,padding=1) + self.side5 = nn.Conv2d(512,out_ch,3,padding=1) + self.side6 = nn.Conv2d(512,out_ch,3,padding=1) + + # self.outconv = nn.Conv2d(6*out_ch,out_ch,1) + + def forward(self,x): + + hx = x + + hxin = self.conv_in(hx) + #hx = self.pool_in(hxin) + + #stage 1 + hx1 = self.stage1(hxin) + hx = self.pool12(hx1) + + #stage 2 + hx2 = self.stage2(hx) + hx = self.pool23(hx2) + + #stage 3 + hx3 = self.stage3(hx) + hx = self.pool34(hx3) + + #stage 4 + hx4 = self.stage4(hx) + hx = self.pool45(hx4) + + #stage 5 + hx5 = self.stage5(hx) + hx = self.pool56(hx5) + + #stage 6 + hx6 = self.stage6(hx) + hx6up = _upsample_like(hx6,hx5) + + #-------------------- decoder -------------------- + hx5d = self.stage5d(torch.cat((hx6up,hx5),1)) + hx5dup = _upsample_like(hx5d,hx4) + + hx4d = self.stage4d(torch.cat((hx5dup,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.stage3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.stage2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.stage1d(torch.cat((hx2dup,hx1),1)) + + + #side output + d1 = self.side1(hx1d) + d1 = _upsample_like(d1,x) + + d2 = self.side2(hx2d) + d2 = _upsample_like(d2,x) + + d3 = self.side3(hx3d) + d3 = _upsample_like(d3,x) + + d4 = self.side4(hx4d) + d4 = _upsample_like(d4,x) + + d5 = self.side5(hx5d) + d5 = _upsample_like(d5,x) + + d6 = self.side6(hx6) + d6 = _upsample_like(d6,x) + + return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6] + diff --git a/models/RMBG-1.4/config.json b/models/RMBG-1.4/config.json new file mode 100644 index 0000000000000000000000000000000000000000..a587b1fea5d4c20cc7be9ce8790300e414ebed9d --- /dev/null +++ b/models/RMBG-1.4/config.json @@ -0,0 +1,25 @@ +{ + "_name_or_path": "briaai/RMBG-1.4", + 
"architectures": [ + "BriaRMBG" + ], + "auto_map": { + "AutoConfig": "MyConfig.RMBGConfig", + "AutoModelForImageSegmentation": "briarmbg.BriaRMBG" + }, + "custom_pipelines": { + "image-segmentation": { + "impl": "MyPipe.RMBGPipe", + "pt": [ + "AutoModelForImageSegmentation" + ], + "tf": [], + "type": "image" + } + }, + "in_ch": 3, + "model_type": "SegformerForSemanticSegmentation", + "out_ch": 1, + "torch_dtype": "float32", + "transformers_version": "4.38.0.dev0" +} diff --git a/models/RMBG-1.4/example_inference.py b/models/RMBG-1.4/example_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..d21c4d021ed462edd947ce4436de01f0d7b6d0c2 --- /dev/null +++ b/models/RMBG-1.4/example_inference.py @@ -0,0 +1,39 @@ +from skimage import io +import torch, os +from PIL import Image +from briarmbg import BriaRMBG +from utilities import preprocess_image, postprocess_image +from huggingface_hub import hf_hub_download + +def example_inference(): + + im_path = f"{os.path.dirname(os.path.abspath(__file__))}/example_input.jpg" + + net = BriaRMBG() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + net = BriaRMBG.from_pretrained("briaai/RMBG-1.4") + net.to(device) + net.eval() + + # prepare input + model_input_size = [1024,1024] + orig_im = io.imread(im_path) + orig_im_size = orig_im.shape[0:2] + image = preprocess_image(orig_im, model_input_size).to(device) + + # inference + result=net(image) + + # post process + result_image = postprocess_image(result[0][0], orig_im_size) + + # save result + pil_im = Image.fromarray(result_image) + no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0)) + orig_image = Image.open(im_path) + no_bg_image.paste(orig_image, mask=pil_im) + no_bg_image.save("example_image_no_bg.png") + + +if __name__ == "__main__": + example_inference() \ No newline at end of file diff --git a/models/RMBG-1.4/model.pth b/models/RMBG-1.4/model.pth new file mode 100644 index 0000000000000000000000000000000000000000..5a35fd28a4f04bc3a38135dec168c918632d6e8c --- /dev/null +++ b/models/RMBG-1.4/model.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:893c16c340b1ddafc93e78457a4d94190da9b7179149f8574284c83caebf5e8c +size 176718373 diff --git a/models/RMBG-1.4/model.safetensors b/models/RMBG-1.4/model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..cb830b5251e59a882ce09a8318101cec5aa51496 --- /dev/null +++ b/models/RMBG-1.4/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:46ef7fe46f2ae284d8f1aaa24bfa5fca5ef25a34e2c7caa890a0029eb100e87f +size 176381984 diff --git a/models/RMBG-1.4/onnx/model.onnx b/models/RMBG-1.4/onnx/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b163d4126e96faf2e2ccce840dcd3e84877f94a7 --- /dev/null +++ b/models/RMBG-1.4/onnx/model.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8cafcf770b06757c4eaced21b1a88e57fd2b66de01b8045f35f01535ba742e0f +size 176153355 diff --git a/models/RMBG-1.4/onnx/model_fp16.onnx b/models/RMBG-1.4/onnx/model_fp16.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f576b317aa978e06c2defef2c572b2449d46f247 --- /dev/null +++ b/models/RMBG-1.4/onnx/model_fp16.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9fdfdb41866d872e0acf4a010c35c1a8547bf0eebe0d1544406bbf1c824cb59d +size 88217533 diff --git a/models/RMBG-1.4/onnx/model_quantized.onnx b/models/RMBG-1.4/onnx/model_quantized.onnx new file mode 100644 
index 0000000000000000000000000000000000000000..b3f42878420c3951cb341d1c2e6d4d04458aefa5 --- /dev/null +++ b/models/RMBG-1.4/onnx/model_quantized.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a6648479275dfd0ede0f3a8abc20aa5c437b394681b05e5af6d268250aaf40f3 +size 44403226 diff --git a/models/RMBG-1.4/onnx/quantize_config.json b/models/RMBG-1.4/onnx/quantize_config.json new file mode 100644 index 0000000000000000000000000000000000000000..d0dd447b33c4757b5e90081b61a48cb2d65d4634 --- /dev/null +++ b/models/RMBG-1.4/onnx/quantize_config.json @@ -0,0 +1,24 @@ +{ + "per_channel": false, + "reduce_range": false, + "per_model_config": { + "model": { + "op_types": [ + "Concat", + "MaxPool", + "Resize", + "Conv", + "Unsqueeze", + "Cast", + "Shape", + "Relu", + "Sigmoid", + "Gather", + "Constant", + "Slice", + "Add" + ], + "weight_type": "QUInt8" + } + } +} \ No newline at end of file diff --git a/models/RMBG-1.4/preprocessor_config.json b/models/RMBG-1.4/preprocessor_config.json new file mode 100644 index 0000000000000000000000000000000000000000..ef29ef503ab28b350a34ea32785f8015af34b5c1 --- /dev/null +++ b/models/RMBG-1.4/preprocessor_config.json @@ -0,0 +1,23 @@ +{ + "do_normalize": true, + "do_pad": false, + "do_rescale": true, + "do_resize": true, + "image_mean": [ + 0.5, + 0.5, + 0.5 + ], + "feature_extractor_type": "ImageFeatureExtractor", + "image_std": [ + 1, + 1, + 1 + ], + "resample": 2, + "rescale_factor": 0.00392156862745098, + "size": { + "width": 1024, + "height": 1024 + } +} \ No newline at end of file diff --git a/models/RMBG-1.4/pytorch_model.bin b/models/RMBG-1.4/pytorch_model.bin new file mode 100644 index 0000000000000000000000000000000000000000..49f0ef6038eff1254a986d9635cc7dfde3e1973b --- /dev/null +++ b/models/RMBG-1.4/pytorch_model.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:59569acdb281ac9fc9f78f9d33b6f9f17f68e25086b74f9025c35bb5f2848967 +size 176574018 diff --git a/models/RMBG-1.4/requirements.txt b/models/RMBG-1.4/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..24fc6e3706df0ed5e8bf8ece6d3a14bf261679d6 --- /dev/null +++ b/models/RMBG-1.4/requirements.txt @@ -0,0 +1,8 @@ +torch +torchvision +pillow +numpy +typing +scikit-image +huggingface_hub +transformers>=4.39.1 \ No newline at end of file diff --git a/models/RMBG-1.4/utilities.py b/models/RMBG-1.4/utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..77bf9e6f0ea22608c69a6e1c5528a3f225b48a63 --- /dev/null +++ b/models/RMBG-1.4/utilities.py @@ -0,0 +1,25 @@ +import torch +import torch.nn.functional as F +from torchvision.transforms.functional import normalize +import numpy as np + +def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor: + if len(im.shape) < 3: + im = im[:, :, np.newaxis] + # orig_im_size=im.shape[0:2] + im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1) + im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear').type(torch.uint8) + image = torch.divide(im_tensor,255.0) + image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0]) + return image + + +def postprocess_image(result: torch.Tensor, im_size: list)-> np.ndarray: + result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear') ,0) + ma = torch.max(result) + mi = torch.min(result) + result = (result-mi)/(ma-mi) + im_array = (result*255).permute(1,2,0).cpu().data.numpy().astype(np.uint8) + im_array = np.squeeze(im_array) + return im_array + \ No 
newline at end of file diff --git a/models/RRDB_ESRGAN_x4.pth b/models/RRDB_ESRGAN_x4.pth new file mode 100644 index 0000000000000000000000000000000000000000..101716c6c861c6059e5820332256f3d57d97d117 --- /dev/null +++ b/models/RRDB_ESRGAN_x4.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65fece06e1ccb48853242aa972bdf00ad07a7dd8938d2dcbdf4221b59f6372ce +size 66929193 diff --git a/models/SadTalker_V0.0.2_256.safetensors b/models/SadTalker_V0.0.2_256.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..1d0eb9787332ff6c6a603c0e79ebada49010270a --- /dev/null +++ b/models/SadTalker_V0.0.2_256.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c211f5d6de003516bf1bbda9f47049a4c9c99133b1ab565c6961e5af16477bff +size 725066984 diff --git a/models/deoldify_artistic.onnx b/models/deoldify_artistic.onnx new file mode 100644 index 0000000000000000000000000000000000000000..05aa08f4872ea5bd6830e9b5ec7a5d23982c923b --- /dev/null +++ b/models/deoldify_artistic.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be026e17c47c85527b3084cacad352f7ca0e021c33aa827062c5997ebe72c61f +size 255024891 diff --git a/models/inswapper_128.onnx b/models/inswapper_128.onnx new file mode 100644 index 0000000000000000000000000000000000000000..cb672b799d74fdf7ab8b172a1b1d78411f6400f5 --- /dev/null +++ b/models/inswapper_128.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e4a3f08c753cb72d04e10aa0f7dbe3deebbf39567d4ead6dce08e98aa49e16af +size 554253681 diff --git a/models/mapping_00109-model.pth.tar b/models/mapping_00109-model.pth.tar new file mode 100644 index 0000000000000000000000000000000000000000..009c3190f5d903c56a2fb0a085d605dc782a83c9 --- /dev/null +++ b/models/mapping_00109-model.pth.tar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:84a8642468a3fcfdd9ab6be955267043116c2bec2284686a5262f1eaf017f64c +size 155779231 diff --git a/models/mask2former-swin-large-ade-semantic/.gitattributes b/models/mask2former-swin-large-ade-semantic/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..c7d9f3332a950355d5a77d85000f05e6f45435ea --- /dev/null +++ b/models/mask2former-swin-large-ade-semantic/.gitattributes @@ -0,0 +1,34 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs 
merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
diff --git a/models/mask2former-swin-large-ade-semantic/README.md b/models/mask2former-swin-large-ade-semantic/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e8149a252c425cbeab21072c5577f26c4a8e6087
--- /dev/null
+++ b/models/mask2former-swin-large-ade-semantic/README.md
@@ -0,0 +1,67 @@
+---
+license: other
+tags:
+- vision
+- image-segmentation
+datasets:
+- coco
+widget:
+- src: http://images.cocodataset.org/val2017/000000039769.jpg
+  example_title: Cats
+- src: http://images.cocodataset.org/val2017/000000039770.jpg
+  example_title: Castle
+---
+
+# Mask2Former
+
+Mask2Former model trained on ADE20k semantic segmentation (large-sized version, Swin backbone). It was introduced in the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) and first released in [this repository](https://github.com/facebookresearch/Mask2Former/).
+
+Disclaimer: The team releasing Mask2Former did not write a model card for this model, so this model card has been written by the Hugging Face team.
+
+## Model description
+
+Mask2Former addresses instance, semantic and panoptic segmentation with the same paradigm: by predicting a set of masks and corresponding labels. Hence, all 3 tasks are treated as if they were instance segmentation. Mask2Former outperforms the previous SOTA,
+[MaskFormer](https://arxiv.org/abs/2107.06278), both in terms of performance and efficiency, by (i) replacing the pixel decoder with a more advanced multi-scale deformable attention Transformer, (ii) adopting a Transformer decoder with masked attention to boost performance
+without introducing additional computation, and (iii) improving training efficiency by calculating the loss on subsampled points instead of whole masks.
+
+![model image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/mask2former_architecture.png)
+
+## Intended uses & limitations
+
+You can use this particular checkpoint for semantic segmentation. See the [model hub](https://huggingface.co/models?search=mask2former) to look for other
+fine-tuned versions on a task that interests you.
+ +### How to use + +Here is how to use this model: + +```python +import requests +import torch +from PIL import Image +from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation + + +# load Mask2Former fine-tuned on ADE20k semantic segmentation +processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-large-ade-semantic") +model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-large-ade-semantic") + +url = "http://images.cocodataset.org/val2017/000000039769.jpg" +image = Image.open(requests.get(url, stream=True).raw) +inputs = processor(images=image, return_tensors="pt") + +with torch.no_grad(): + outputs = model(**inputs) + +# model predicts class_queries_logits of shape `(batch_size, num_queries)` +# and masks_queries_logits of shape `(batch_size, num_queries, height, width)` +class_queries_logits = outputs.class_queries_logits +masks_queries_logits = outputs.masks_queries_logits + +# you can pass them to processor for postprocessing +predicted_semantic_map = processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0] +# we refer to the demo notebooks for visualization (see "Resources" section in the Mask2Former docs) +``` + +For more code examples, we refer to the [documentation](https://huggingface.co/docs/transformers/master/en/model_doc/mask2former). \ No newline at end of file diff --git a/models/mask2former-swin-large-ade-semantic/config.json b/models/mask2former-swin-large-ade-semantic/config.json new file mode 100644 index 0000000000000000000000000000000000000000..6cc89c35ce70d7227b8956eab0db1c6240ae18a7 --- /dev/null +++ b/models/mask2former-swin-large-ade-semantic/config.json @@ -0,0 +1,2451 @@ +{ + "_commit_hash": null, + "activation_function": "relu", + "architectures": [ + "Mask2FormerForUniversalSegmentation" + ], + "backbone_config": { + "_name_or_path": "", + "add_cross_attention": false, + "architectures": [ + "SwinForImageClassification" + ], + "attention_probs_dropout_prob": 0.0, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "depths": [ + 2, + 2, + 18, + 2 + ], + "diversity_penalty": 0.0, + "do_sample": false, + "drop_path_rate": 0.3, + "early_stopping": false, + "embed_dim": 192, + "encoder_no_repeat_ngram_size": 0, + "encoder_stride": 32, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "hidden_size": 1536, + "id2label": { + "0": "tench, Tinca tinca", + "1": "goldfish, Carassius auratus", + "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", + "3": "tiger shark, Galeocerdo cuvieri", + "4": "hammerhead, hammerhead shark", + "5": "electric ray, crampfish, numbfish, torpedo", + "6": "stingray", + "7": "cock", + "8": "hen", + "9": "ostrich, Struthio camelus", + "10": "brambling, Fringilla montifringilla", + "11": "goldfinch, Carduelis carduelis", + "12": "house finch, linnet, Carpodacus mexicanus", + "13": "junco, snowbird", + "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea", + "15": "robin, American robin, Turdus migratorius", + "16": "bulbul", + "17": "jay", + "18": "magpie", + "19": "chickadee", + "20": "water ouzel, dipper", + "21": "kite", + "22": "bald eagle, American eagle, Haliaeetus leucocephalus", + 
"23": "vulture", + "24": "great grey owl, great gray owl, Strix nebulosa", + "25": "European fire salamander, Salamandra salamandra", + "26": "common newt, Triturus vulgaris", + "27": "eft", + "28": "spotted salamander, Ambystoma maculatum", + "29": "axolotl, mud puppy, Ambystoma mexicanum", + "30": "bullfrog, Rana catesbeiana", + "31": "tree frog, tree-frog", + "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", + "33": "loggerhead, loggerhead turtle, Caretta caretta", + "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", + "35": "mud turtle", + "36": "terrapin", + "37": "box turtle, box tortoise", + "38": "banded gecko", + "39": "common iguana, iguana, Iguana iguana", + "40": "American chameleon, anole, Anolis carolinensis", + "41": "whiptail, whiptail lizard", + "42": "agama", + "43": "frilled lizard, Chlamydosaurus kingi", + "44": "alligator lizard", + "45": "Gila monster, Heloderma suspectum", + "46": "green lizard, Lacerta viridis", + "47": "African chameleon, Chamaeleo chamaeleon", + "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", + "49": "African crocodile, Nile crocodile, Crocodylus niloticus", + "50": "American alligator, Alligator mississipiensis", + "51": "triceratops", + "52": "thunder snake, worm snake, Carphophis amoenus", + "53": "ringneck snake, ring-necked snake, ring snake", + "54": "hognose snake, puff adder, sand viper", + "55": "green snake, grass snake", + "56": "king snake, kingsnake", + "57": "garter snake, grass snake", + "58": "water snake", + "59": "vine snake", + "60": "night snake, Hypsiglena torquata", + "61": "boa constrictor, Constrictor constrictor", + "62": "rock python, rock snake, Python sebae", + "63": "Indian cobra, Naja naja", + "64": "green mamba", + "65": "sea snake", + "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", + "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus", + "68": "sidewinder, horned rattlesnake, Crotalus cerastes", + "69": "trilobite", + "70": "harvestman, daddy longlegs, Phalangium opilio", + "71": "scorpion", + "72": "black and gold garden spider, Argiope aurantia", + "73": "barn spider, Araneus cavaticus", + "74": "garden spider, Aranea diademata", + "75": "black widow, Latrodectus mactans", + "76": "tarantula", + "77": "wolf spider, hunting spider", + "78": "tick", + "79": "centipede", + "80": "black grouse", + "81": "ptarmigan", + "82": "ruffed grouse, partridge, Bonasa umbellus", + "83": "prairie chicken, prairie grouse, prairie fowl", + "84": "peacock", + "85": "quail", + "86": "partridge", + "87": "African grey, African gray, Psittacus erithacus", + "88": "macaw", + "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", + "90": "lorikeet", + "91": "coucal", + "92": "bee eater", + "93": "hornbill", + "94": "hummingbird", + "95": "jacamar", + "96": "toucan", + "97": "drake", + "98": "red-breasted merganser, Mergus serrator", + "99": "goose", + "100": "black swan, Cygnus atratus", + "101": "tusker", + "102": "echidna, spiny anteater, anteater", + "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", + "104": "wallaby, brush kangaroo", + "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", + "106": "wombat", + "107": "jellyfish", + "108": "sea anemone, anemone", + "109": "brain coral", + "110": "flatworm, platyhelminth", + "111": "nematode, nematode worm, roundworm", + "112": "conch", + "113": "snail", + "114": 
"slug", + "115": "sea slug, nudibranch", + "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore", + "117": "chambered nautilus, pearly nautilus, nautilus", + "118": "Dungeness crab, Cancer magister", + "119": "rock crab, Cancer irroratus", + "120": "fiddler crab", + "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", + "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus", + "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", + "124": "crayfish, crawfish, crawdad, crawdaddy", + "125": "hermit crab", + "126": "isopod", + "127": "white stork, Ciconia ciconia", + "128": "black stork, Ciconia nigra", + "129": "spoonbill", + "130": "flamingo", + "131": "little blue heron, Egretta caerulea", + "132": "American egret, great white heron, Egretta albus", + "133": "bittern", + "134": "crane", + "135": "limpkin, Aramus pictus", + "136": "European gallinule, Porphyrio porphyrio", + "137": "American coot, marsh hen, mud hen, water hen, Fulica americana", + "138": "bustard", + "139": "ruddy turnstone, Arenaria interpres", + "140": "red-backed sandpiper, dunlin, Erolia alpina", + "141": "redshank, Tringa totanus", + "142": "dowitcher", + "143": "oystercatcher, oyster catcher", + "144": "pelican", + "145": "king penguin, Aptenodytes patagonica", + "146": "albatross, mollymawk", + "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", + "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", + "149": "dugong, Dugong dugon", + "150": "sea lion", + "151": "Chihuahua", + "152": "Japanese spaniel", + "153": "Maltese dog, Maltese terrier, Maltese", + "154": "Pekinese, Pekingese, Peke", + "155": "Shih-Tzu", + "156": "Blenheim spaniel", + "157": "papillon", + "158": "toy terrier", + "159": "Rhodesian ridgeback", + "160": "Afghan hound, Afghan", + "161": "basset, basset hound", + "162": "beagle", + "163": "bloodhound, sleuthhound", + "164": "bluetick", + "165": "black-and-tan coonhound", + "166": "Walker hound, Walker foxhound", + "167": "English foxhound", + "168": "redbone", + "169": "borzoi, Russian wolfhound", + "170": "Irish wolfhound", + "171": "Italian greyhound", + "172": "whippet", + "173": "Ibizan hound, Ibizan Podenco", + "174": "Norwegian elkhound, elkhound", + "175": "otterhound, otter hound", + "176": "Saluki, gazelle hound", + "177": "Scottish deerhound, deerhound", + "178": "Weimaraner", + "179": "Staffordshire bullterrier, Staffordshire bull terrier", + "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", + "181": "Bedlington terrier", + "182": "Border terrier", + "183": "Kerry blue terrier", + "184": "Irish terrier", + "185": "Norfolk terrier", + "186": "Norwich terrier", + "187": "Yorkshire terrier", + "188": "wire-haired fox terrier", + "189": "Lakeland terrier", + "190": "Sealyham terrier, Sealyham", + "191": "Airedale, Airedale terrier", + "192": "cairn, cairn terrier", + "193": "Australian terrier", + "194": "Dandie Dinmont, Dandie Dinmont terrier", + "195": "Boston bull, Boston terrier", + "196": "miniature schnauzer", + "197": "giant schnauzer", + "198": "standard schnauzer", + "199": "Scotch terrier, Scottish terrier, Scottie", + "200": "Tibetan terrier, chrysanthemum dog", + "201": "silky terrier, Sydney silky", + "202": "soft-coated wheaten terrier", + "203": "West Highland white terrier", + "204": "Lhasa, Lhasa apso", + "205": "flat-coated retriever", + "206": "curly-coated 
retriever", + "207": "golden retriever", + "208": "Labrador retriever", + "209": "Chesapeake Bay retriever", + "210": "German short-haired pointer", + "211": "vizsla, Hungarian pointer", + "212": "English setter", + "213": "Irish setter, red setter", + "214": "Gordon setter", + "215": "Brittany spaniel", + "216": "clumber, clumber spaniel", + "217": "English springer, English springer spaniel", + "218": "Welsh springer spaniel", + "219": "cocker spaniel, English cocker spaniel, cocker", + "220": "Sussex spaniel", + "221": "Irish water spaniel", + "222": "kuvasz", + "223": "schipperke", + "224": "groenendael", + "225": "malinois", + "226": "briard", + "227": "kelpie", + "228": "komondor", + "229": "Old English sheepdog, bobtail", + "230": "Shetland sheepdog, Shetland sheep dog, Shetland", + "231": "collie", + "232": "Border collie", + "233": "Bouvier des Flandres, Bouviers des Flandres", + "234": "Rottweiler", + "235": "German shepherd, German shepherd dog, German police dog, alsatian", + "236": "Doberman, Doberman pinscher", + "237": "miniature pinscher", + "238": "Greater Swiss Mountain dog", + "239": "Bernese mountain dog", + "240": "Appenzeller", + "241": "EntleBucher", + "242": "boxer", + "243": "bull mastiff", + "244": "Tibetan mastiff", + "245": "French bulldog", + "246": "Great Dane", + "247": "Saint Bernard, St Bernard", + "248": "Eskimo dog, husky", + "249": "malamute, malemute, Alaskan malamute", + "250": "Siberian husky", + "251": "dalmatian, coach dog, carriage dog", + "252": "affenpinscher, monkey pinscher, monkey dog", + "253": "basenji", + "254": "pug, pug-dog", + "255": "Leonberg", + "256": "Newfoundland, Newfoundland dog", + "257": "Great Pyrenees", + "258": "Samoyed, Samoyede", + "259": "Pomeranian", + "260": "chow, chow chow", + "261": "keeshond", + "262": "Brabancon griffon", + "263": "Pembroke, Pembroke Welsh corgi", + "264": "Cardigan, Cardigan Welsh corgi", + "265": "toy poodle", + "266": "miniature poodle", + "267": "standard poodle", + "268": "Mexican hairless", + "269": "timber wolf, grey wolf, gray wolf, Canis lupus", + "270": "white wolf, Arctic wolf, Canis lupus tundrarum", + "271": "red wolf, maned wolf, Canis rufus, Canis niger", + "272": "coyote, prairie wolf, brush wolf, Canis latrans", + "273": "dingo, warrigal, warragal, Canis dingo", + "274": "dhole, Cuon alpinus", + "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", + "276": "hyena, hyaena", + "277": "red fox, Vulpes vulpes", + "278": "kit fox, Vulpes macrotis", + "279": "Arctic fox, white fox, Alopex lagopus", + "280": "grey fox, gray fox, Urocyon cinereoargenteus", + "281": "tabby, tabby cat", + "282": "tiger cat", + "283": "Persian cat", + "284": "Siamese cat, Siamese", + "285": "Egyptian cat", + "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", + "287": "lynx, catamount", + "288": "leopard, Panthera pardus", + "289": "snow leopard, ounce, Panthera uncia", + "290": "jaguar, panther, Panthera onca, Felis onca", + "291": "lion, king of beasts, Panthera leo", + "292": "tiger, Panthera tigris", + "293": "cheetah, chetah, Acinonyx jubatus", + "294": "brown bear, bruin, Ursus arctos", + "295": "American black bear, black bear, Ursus americanus, Euarctos americanus", + "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus", + "297": "sloth bear, Melursus ursinus, Ursus ursinus", + "298": "mongoose", + "299": "meerkat, mierkat", + "300": "tiger beetle", + "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", + "302": 
"ground beetle, carabid beetle", + "303": "long-horned beetle, longicorn, longicorn beetle", + "304": "leaf beetle, chrysomelid", + "305": "dung beetle", + "306": "rhinoceros beetle", + "307": "weevil", + "308": "fly", + "309": "bee", + "310": "ant, emmet, pismire", + "311": "grasshopper, hopper", + "312": "cricket", + "313": "walking stick, walkingstick, stick insect", + "314": "cockroach, roach", + "315": "mantis, mantid", + "316": "cicada, cicala", + "317": "leafhopper", + "318": "lacewing, lacewing fly", + "319": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", + "320": "damselfly", + "321": "admiral", + "322": "ringlet, ringlet butterfly", + "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", + "324": "cabbage butterfly", + "325": "sulphur butterfly, sulfur butterfly", + "326": "lycaenid, lycaenid butterfly", + "327": "starfish, sea star", + "328": "sea urchin", + "329": "sea cucumber, holothurian", + "330": "wood rabbit, cottontail, cottontail rabbit", + "331": "hare", + "332": "Angora, Angora rabbit", + "333": "hamster", + "334": "porcupine, hedgehog", + "335": "fox squirrel, eastern fox squirrel, Sciurus niger", + "336": "marmot", + "337": "beaver", + "338": "guinea pig, Cavia cobaya", + "339": "sorrel", + "340": "zebra", + "341": "hog, pig, grunter, squealer, Sus scrofa", + "342": "wild boar, boar, Sus scrofa", + "343": "warthog", + "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius", + "345": "ox", + "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis", + "347": "bison", + "348": "ram, tup", + "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", + "350": "ibex, Capra ibex", + "351": "hartebeest", + "352": "impala, Aepyceros melampus", + "353": "gazelle", + "354": "Arabian camel, dromedary, Camelus dromedarius", + "355": "llama", + "356": "weasel", + "357": "mink", + "358": "polecat, fitch, foulmart, foumart, Mustela putorius", + "359": "black-footed ferret, ferret, Mustela nigripes", + "360": "otter", + "361": "skunk, polecat, wood pussy", + "362": "badger", + "363": "armadillo", + "364": "three-toed sloth, ai, Bradypus tridactylus", + "365": "orangutan, orang, orangutang, Pongo pygmaeus", + "366": "gorilla, Gorilla gorilla", + "367": "chimpanzee, chimp, Pan troglodytes", + "368": "gibbon, Hylobates lar", + "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus", + "370": "guenon, guenon monkey", + "371": "patas, hussar monkey, Erythrocebus patas", + "372": "baboon", + "373": "macaque", + "374": "langur", + "375": "colobus, colobus monkey", + "376": "proboscis monkey, Nasalis larvatus", + "377": "marmoset", + "378": "capuchin, ringtail, Cebus capucinus", + "379": "howler monkey, howler", + "380": "titi, titi monkey", + "381": "spider monkey, Ateles geoffroyi", + "382": "squirrel monkey, Saimiri sciureus", + "383": "Madagascar cat, ring-tailed lemur, Lemur catta", + "384": "indri, indris, Indri indri, Indri brevicaudatus", + "385": "Indian elephant, Elephas maximus", + "386": "African elephant, Loxodonta africana", + "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens", + "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca", + "389": "barracouta, snoek", + "390": "eel", + "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", + "392": "rock beauty, Holocanthus tricolor", + "393": "anemone fish", + "394": "sturgeon", + 
"395": "gar, garfish, garpike, billfish, Lepisosteus osseus", + "396": "lionfish", + "397": "puffer, pufferfish, blowfish, globefish", + "398": "abacus", + "399": "abaya", + "400": "academic gown, academic robe, judge's robe", + "401": "accordion, piano accordion, squeeze box", + "402": "acoustic guitar", + "403": "aircraft carrier, carrier, flattop, attack aircraft carrier", + "404": "airliner", + "405": "airship, dirigible", + "406": "altar", + "407": "ambulance", + "408": "amphibian, amphibious vehicle", + "409": "analog clock", + "410": "apiary, bee house", + "411": "apron", + "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", + "413": "assault rifle, assault gun", + "414": "backpack, back pack, knapsack, packsack, rucksack, haversack", + "415": "bakery, bakeshop, bakehouse", + "416": "balance beam, beam", + "417": "balloon", + "418": "ballpoint, ballpoint pen, ballpen, Biro", + "419": "Band Aid", + "420": "banjo", + "421": "bannister, banister, balustrade, balusters, handrail", + "422": "barbell", + "423": "barber chair", + "424": "barbershop", + "425": "barn", + "426": "barometer", + "427": "barrel, cask", + "428": "barrow, garden cart, lawn cart, wheelbarrow", + "429": "baseball", + "430": "basketball", + "431": "bassinet", + "432": "bassoon", + "433": "bathing cap, swimming cap", + "434": "bath towel", + "435": "bathtub, bathing tub, bath, tub", + "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", + "437": "beacon, lighthouse, beacon light, pharos", + "438": "beaker", + "439": "bearskin, busby, shako", + "440": "beer bottle", + "441": "beer glass", + "442": "bell cote, bell cot", + "443": "bib", + "444": "bicycle-built-for-two, tandem bicycle, tandem", + "445": "bikini, two-piece", + "446": "binder, ring-binder", + "447": "binoculars, field glasses, opera glasses", + "448": "birdhouse", + "449": "boathouse", + "450": "bobsled, bobsleigh, bob", + "451": "bolo tie, bolo, bola tie, bola", + "452": "bonnet, poke bonnet", + "453": "bookcase", + "454": "bookshop, bookstore, bookstall", + "455": "bottlecap", + "456": "bow", + "457": "bow tie, bow-tie, bowtie", + "458": "brass, memorial tablet, plaque", + "459": "brassiere, bra, bandeau", + "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty", + "461": "breastplate, aegis, egis", + "462": "broom", + "463": "bucket, pail", + "464": "buckle", + "465": "bulletproof vest", + "466": "bullet train, bullet", + "467": "butcher shop, meat market", + "468": "cab, hack, taxi, taxicab", + "469": "caldron, cauldron", + "470": "candle, taper, wax light", + "471": "cannon", + "472": "canoe", + "473": "can opener, tin opener", + "474": "cardigan", + "475": "car mirror", + "476": "carousel, carrousel, merry-go-round, roundabout, whirligig", + "477": "carpenter's kit, tool kit", + "478": "carton", + "479": "car wheel", + "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", + "481": "cassette", + "482": "cassette player", + "483": "castle", + "484": "catamaran", + "485": "CD player", + "486": "cello, violoncello", + "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone", + "488": "chain", + "489": "chainlink fence", + "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", + "491": "chain saw, chainsaw", + "492": "chest", + "493": "chiffonier, commode", + "494": "chime, bell, gong", + "495": "china cabinet, china 
closet", + "496": "Christmas stocking", + "497": "church, church building", + "498": "cinema, movie theater, movie theatre, movie house, picture palace", + "499": "cleaver, meat cleaver, chopper", + "500": "cliff dwelling", + "501": "cloak", + "502": "clog, geta, patten, sabot", + "503": "cocktail shaker", + "504": "coffee mug", + "505": "coffeepot", + "506": "coil, spiral, volute, whorl, helix", + "507": "combination lock", + "508": "computer keyboard, keypad", + "509": "confectionery, confectionary, candy store", + "510": "container ship, containership, container vessel", + "511": "convertible", + "512": "corkscrew, bottle screw", + "513": "cornet, horn, trumpet, trump", + "514": "cowboy boot", + "515": "cowboy hat, ten-gallon hat", + "516": "cradle", + "517": "crane", + "518": "crash helmet", + "519": "crate", + "520": "crib, cot", + "521": "Crock Pot", + "522": "croquet ball", + "523": "crutch", + "524": "cuirass", + "525": "dam, dike, dyke", + "526": "desk", + "527": "desktop computer", + "528": "dial telephone, dial phone", + "529": "diaper, nappy, napkin", + "530": "digital clock", + "531": "digital watch", + "532": "dining table, board", + "533": "dishrag, dishcloth", + "534": "dishwasher, dish washer, dishwashing machine", + "535": "disk brake, disc brake", + "536": "dock, dockage, docking facility", + "537": "dogsled, dog sled, dog sleigh", + "538": "dome", + "539": "doormat, welcome mat", + "540": "drilling platform, offshore rig", + "541": "drum, membranophone, tympan", + "542": "drumstick", + "543": "dumbbell", + "544": "Dutch oven", + "545": "electric fan, blower", + "546": "electric guitar", + "547": "electric locomotive", + "548": "entertainment center", + "549": "envelope", + "550": "espresso maker", + "551": "face powder", + "552": "feather boa, boa", + "553": "file, file cabinet, filing cabinet", + "554": "fireboat", + "555": "fire engine, fire truck", + "556": "fire screen, fireguard", + "557": "flagpole, flagstaff", + "558": "flute, transverse flute", + "559": "folding chair", + "560": "football helmet", + "561": "forklift", + "562": "fountain", + "563": "fountain pen", + "564": "four-poster", + "565": "freight car", + "566": "French horn, horn", + "567": "frying pan, frypan, skillet", + "568": "fur coat", + "569": "garbage truck, dustcart", + "570": "gasmask, respirator, gas helmet", + "571": "gas pump, gasoline pump, petrol pump, island dispenser", + "572": "goblet", + "573": "go-kart", + "574": "golf ball", + "575": "golfcart, golf cart", + "576": "gondola", + "577": "gong, tam-tam", + "578": "gown", + "579": "grand piano, grand", + "580": "greenhouse, nursery, glasshouse", + "581": "grille, radiator grille", + "582": "grocery store, grocery, food market, market", + "583": "guillotine", + "584": "hair slide", + "585": "hair spray", + "586": "half track", + "587": "hammer", + "588": "hamper", + "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier", + "590": "hand-held computer, hand-held microcomputer", + "591": "handkerchief, hankie, hanky, hankey", + "592": "hard disc, hard disk, fixed disk", + "593": "harmonica, mouth organ, harp, mouth harp", + "594": "harp", + "595": "harvester, reaper", + "596": "hatchet", + "597": "holster", + "598": "home theater, home theatre", + "599": "honeycomb", + "600": "hook, claw", + "601": "hoopskirt, crinoline", + "602": "horizontal bar, high bar", + "603": "horse cart, horse-cart", + "604": "hourglass", + "605": "iPod", + "606": "iron, smoothing iron", + "607": "jack-o'-lantern", + "608": "jean, blue jean, denim", + 
"609": "jeep, landrover", + "610": "jersey, T-shirt, tee shirt", + "611": "jigsaw puzzle", + "612": "jinrikisha, ricksha, rickshaw", + "613": "joystick", + "614": "kimono", + "615": "knee pad", + "616": "knot", + "617": "lab coat, laboratory coat", + "618": "ladle", + "619": "lampshade, lamp shade", + "620": "laptop, laptop computer", + "621": "lawn mower, mower", + "622": "lens cap, lens cover", + "623": "letter opener, paper knife, paperknife", + "624": "library", + "625": "lifeboat", + "626": "lighter, light, igniter, ignitor", + "627": "limousine, limo", + "628": "liner, ocean liner", + "629": "lipstick, lip rouge", + "630": "Loafer", + "631": "lotion", + "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", + "633": "loupe, jeweler's loupe", + "634": "lumbermill, sawmill", + "635": "magnetic compass", + "636": "mailbag, postbag", + "637": "mailbox, letter box", + "638": "maillot", + "639": "maillot, tank suit", + "640": "manhole cover", + "641": "maraca", + "642": "marimba, xylophone", + "643": "mask", + "644": "matchstick", + "645": "maypole", + "646": "maze, labyrinth", + "647": "measuring cup", + "648": "medicine chest, medicine cabinet", + "649": "megalith, megalithic structure", + "650": "microphone, mike", + "651": "microwave, microwave oven", + "652": "military uniform", + "653": "milk can", + "654": "minibus", + "655": "miniskirt, mini", + "656": "minivan", + "657": "missile", + "658": "mitten", + "659": "mixing bowl", + "660": "mobile home, manufactured home", + "661": "Model T", + "662": "modem", + "663": "monastery", + "664": "monitor", + "665": "moped", + "666": "mortar", + "667": "mortarboard", + "668": "mosque", + "669": "mosquito net", + "670": "motor scooter, scooter", + "671": "mountain bike, all-terrain bike, off-roader", + "672": "mountain tent", + "673": "mouse, computer mouse", + "674": "mousetrap", + "675": "moving van", + "676": "muzzle", + "677": "nail", + "678": "neck brace", + "679": "necklace", + "680": "nipple", + "681": "notebook, notebook computer", + "682": "obelisk", + "683": "oboe, hautboy, hautbois", + "684": "ocarina, sweet potato", + "685": "odometer, hodometer, mileometer, milometer", + "686": "oil filter", + "687": "organ, pipe organ", + "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO", + "689": "overskirt", + "690": "oxcart", + "691": "oxygen mask", + "692": "packet", + "693": "paddle, boat paddle", + "694": "paddlewheel, paddle wheel", + "695": "padlock", + "696": "paintbrush", + "697": "pajama, pyjama, pj's, jammies", + "698": "palace", + "699": "panpipe, pandean pipe, syrinx", + "700": "paper towel", + "701": "parachute, chute", + "702": "parallel bars, bars", + "703": "park bench", + "704": "parking meter", + "705": "passenger car, coach, carriage", + "706": "patio, terrace", + "707": "pay-phone, pay-station", + "708": "pedestal, plinth, footstall", + "709": "pencil box, pencil case", + "710": "pencil sharpener", + "711": "perfume, essence", + "712": "Petri dish", + "713": "photocopier", + "714": "pick, plectrum, plectron", + "715": "pickelhaube", + "716": "picket fence, paling", + "717": "pickup, pickup truck", + "718": "pier", + "719": "piggy bank, penny bank", + "720": "pill bottle", + "721": "pillow", + "722": "ping-pong ball", + "723": "pinwheel", + "724": "pirate, pirate ship", + "725": "pitcher, ewer", + "726": "plane, carpenter's plane, woodworking plane", + "727": "planetarium", + "728": "plastic bag", + "729": "plate rack", + "730": "plow, plough", + "731": "plunger, plumber's helper", + "732": 
"Polaroid camera, Polaroid Land camera", + "733": "pole", + "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", + "735": "poncho", + "736": "pool table, billiard table, snooker table", + "737": "pop bottle, soda bottle", + "738": "pot, flowerpot", + "739": "potter's wheel", + "740": "power drill", + "741": "prayer rug, prayer mat", + "742": "printer", + "743": "prison, prison house", + "744": "projectile, missile", + "745": "projector", + "746": "puck, hockey puck", + "747": "punching bag, punch bag, punching ball, punchball", + "748": "purse", + "749": "quill, quill pen", + "750": "quilt, comforter, comfort, puff", + "751": "racer, race car, racing car", + "752": "racket, racquet", + "753": "radiator", + "754": "radio, wireless", + "755": "radio telescope, radio reflector", + "756": "rain barrel", + "757": "recreational vehicle, RV, R.V.", + "758": "reel", + "759": "reflex camera", + "760": "refrigerator, icebox", + "761": "remote control, remote", + "762": "restaurant, eating house, eating place, eatery", + "763": "revolver, six-gun, six-shooter", + "764": "rifle", + "765": "rocking chair, rocker", + "766": "rotisserie", + "767": "rubber eraser, rubber, pencil eraser", + "768": "rugby ball", + "769": "rule, ruler", + "770": "running shoe", + "771": "safe", + "772": "safety pin", + "773": "saltshaker, salt shaker", + "774": "sandal", + "775": "sarong", + "776": "sax, saxophone", + "777": "scabbard", + "778": "scale, weighing machine", + "779": "school bus", + "780": "schooner", + "781": "scoreboard", + "782": "screen, CRT screen", + "783": "screw", + "784": "screwdriver", + "785": "seat belt, seatbelt", + "786": "sewing machine", + "787": "shield, buckler", + "788": "shoe shop, shoe-shop, shoe store", + "789": "shoji", + "790": "shopping basket", + "791": "shopping cart", + "792": "shovel", + "793": "shower cap", + "794": "shower curtain", + "795": "ski", + "796": "ski mask", + "797": "sleeping bag", + "798": "slide rule, slipstick", + "799": "sliding door", + "800": "slot, one-armed bandit", + "801": "snorkel", + "802": "snowmobile", + "803": "snowplow, snowplough", + "804": "soap dispenser", + "805": "soccer ball", + "806": "sock", + "807": "solar dish, solar collector, solar furnace", + "808": "sombrero", + "809": "soup bowl", + "810": "space bar", + "811": "space heater", + "812": "space shuttle", + "813": "spatula", + "814": "speedboat", + "815": "spider web, spider's web", + "816": "spindle", + "817": "sports car, sport car", + "818": "spotlight, spot", + "819": "stage", + "820": "steam locomotive", + "821": "steel arch bridge", + "822": "steel drum", + "823": "stethoscope", + "824": "stole", + "825": "stone wall", + "826": "stopwatch, stop watch", + "827": "stove", + "828": "strainer", + "829": "streetcar, tram, tramcar, trolley, trolley car", + "830": "stretcher", + "831": "studio couch, day bed", + "832": "stupa, tope", + "833": "submarine, pigboat, sub, U-boat", + "834": "suit, suit of clothes", + "835": "sundial", + "836": "sunglass", + "837": "sunglasses, dark glasses, shades", + "838": "sunscreen, sunblock, sun blocker", + "839": "suspension bridge", + "840": "swab, swob, mop", + "841": "sweatshirt", + "842": "swimming trunks, bathing trunks", + "843": "swing", + "844": "switch, electric switch, electrical switch", + "845": "syringe", + "846": "table lamp", + "847": "tank, army tank, armored combat vehicle, armoured combat vehicle", + "848": "tape player", + "849": "teapot", + "850": "teddy, teddy bear", + "851": "television, television system", 
+ "852": "tennis ball", + "853": "thatch, thatched roof", + "854": "theater curtain, theatre curtain", + "855": "thimble", + "856": "thresher, thrasher, threshing machine", + "857": "throne", + "858": "tile roof", + "859": "toaster", + "860": "tobacco shop, tobacconist shop, tobacconist", + "861": "toilet seat", + "862": "torch", + "863": "totem pole", + "864": "tow truck, tow car, wrecker", + "865": "toyshop", + "866": "tractor", + "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", + "868": "tray", + "869": "trench coat", + "870": "tricycle, trike, velocipede", + "871": "trimaran", + "872": "tripod", + "873": "triumphal arch", + "874": "trolleybus, trolley coach, trackless trolley", + "875": "trombone", + "876": "tub, vat", + "877": "turnstile", + "878": "typewriter keyboard", + "879": "umbrella", + "880": "unicycle, monocycle", + "881": "upright, upright piano", + "882": "vacuum, vacuum cleaner", + "883": "vase", + "884": "vault", + "885": "velvet", + "886": "vending machine", + "887": "vestment", + "888": "viaduct", + "889": "violin, fiddle", + "890": "volleyball", + "891": "waffle iron", + "892": "wall clock", + "893": "wallet, billfold, notecase, pocketbook", + "894": "wardrobe, closet, press", + "895": "warplane, military plane", + "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin", + "897": "washer, automatic washer, washing machine", + "898": "water bottle", + "899": "water jug", + "900": "water tower", + "901": "whiskey jug", + "902": "whistle", + "903": "wig", + "904": "window screen", + "905": "window shade", + "906": "Windsor tie", + "907": "wine bottle", + "908": "wing", + "909": "wok", + "910": "wooden spoon", + "911": "wool, woolen, woollen", + "912": "worm fence, snake fence, snake-rail fence, Virginia fence", + "913": "wreck", + "914": "yawl", + "915": "yurt", + "916": "web site, website, internet site, site", + "917": "comic book", + "918": "crossword puzzle, crossword", + "919": "street sign", + "920": "traffic light, traffic signal, stoplight", + "921": "book jacket, dust cover, dust jacket, dust wrapper", + "922": "menu", + "923": "plate", + "924": "guacamole", + "925": "consomme", + "926": "hot pot, hotpot", + "927": "trifle", + "928": "ice cream, icecream", + "929": "ice lolly, lolly, lollipop, popsicle", + "930": "French loaf", + "931": "bagel, beigel", + "932": "pretzel", + "933": "cheeseburger", + "934": "hotdog, hot dog, red hot", + "935": "mashed potato", + "936": "head cabbage", + "937": "broccoli", + "938": "cauliflower", + "939": "zucchini, courgette", + "940": "spaghetti squash", + "941": "acorn squash", + "942": "butternut squash", + "943": "cucumber, cuke", + "944": "artichoke, globe artichoke", + "945": "bell pepper", + "946": "cardoon", + "947": "mushroom", + "948": "Granny Smith", + "949": "strawberry", + "950": "orange", + "951": "lemon", + "952": "fig", + "953": "pineapple, ananas", + "954": "banana", + "955": "jackfruit, jak, jack", + "956": "custard apple", + "957": "pomegranate", + "958": "hay", + "959": "carbonara", + "960": "chocolate sauce, chocolate syrup", + "961": "dough", + "962": "meat loaf, meatloaf", + "963": "pizza, pizza pie", + "964": "potpie", + "965": "burrito", + "966": "red wine", + "967": "espresso", + "968": "cup", + "969": "eggnog", + "970": "alp", + "971": "bubble", + "972": "cliff, drop, drop-off", + "973": "coral reef", + "974": "geyser", + "975": "lakeside, lakeshore", + "976": "promontory, headland, head, foreland", + "977": "sandbar, sand bar", + "978": "seashore, coast, 
seacoast, sea-coast", + "979": "valley, vale", + "980": "volcano", + "981": "ballplayer, baseball player", + "982": "groom, bridegroom", + "983": "scuba diver", + "984": "rapeseed", + "985": "daisy", + "986": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", + "987": "corn", + "988": "acorn", + "989": "hip, rose hip, rosehip", + "990": "buckeye, horse chestnut, conker", + "991": "coral fungus", + "992": "agaric", + "993": "gyromitra", + "994": "stinkhorn, carrion fungus", + "995": "earthstar", + "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", + "997": "bolete", + "998": "ear, spike, capitulum", + "999": "toilet tissue, toilet paper, bathroom tissue" + }, + "image_size": 384, + "initializer_range": 0.02, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "Afghan hound, Afghan": 160, + "African chameleon, Chamaeleo chamaeleon": 47, + "African crocodile, Nile crocodile, Crocodylus niloticus": 49, + "African elephant, Loxodonta africana": 386, + "African grey, African gray, Psittacus erithacus": 87, + "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus": 275, + "Airedale, Airedale terrier": 191, + "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier": 180, + "American alligator, Alligator mississipiensis": 50, + "American black bear, black bear, Ursus americanus, Euarctos americanus": 295, + "American chameleon, anole, Anolis carolinensis": 40, + "American coot, marsh hen, mud hen, water hen, Fulica americana": 137, + "American egret, great white heron, Egretta albus": 132, + "American lobster, Northern lobster, Maine lobster, Homarus americanus": 122, + "Angora, Angora rabbit": 332, + "Appenzeller": 240, + "Arabian camel, dromedary, Camelus dromedarius": 354, + "Arctic fox, white fox, Alopex lagopus": 279, + "Australian terrier": 193, + "Band Aid": 419, + "Bedlington terrier": 181, + "Bernese mountain dog": 239, + "Blenheim spaniel": 156, + "Border collie": 232, + "Border terrier": 182, + "Boston bull, Boston terrier": 195, + "Bouvier des Flandres, Bouviers des Flandres": 233, + "Brabancon griffon": 262, + "Brittany spaniel": 215, + "CD player": 485, + "Cardigan, Cardigan Welsh corgi": 264, + "Chesapeake Bay retriever": 209, + "Chihuahua": 151, + "Christmas stocking": 496, + "Crock Pot": 521, + "Dandie Dinmont, Dandie Dinmont terrier": 194, + "Doberman, Doberman pinscher": 236, + "Dungeness crab, Cancer magister": 118, + "Dutch oven": 544, + "Egyptian cat": 285, + "English foxhound": 167, + "English setter": 212, + "English springer, English springer spaniel": 217, + "EntleBucher": 241, + "Eskimo dog, husky": 248, + "European fire salamander, Salamandra salamandra": 25, + "European gallinule, Porphyrio porphyrio": 136, + "French bulldog": 245, + "French horn, horn": 566, + "French loaf": 930, + "German shepherd, German shepherd dog, German police dog, alsatian": 235, + "German short-haired pointer": 210, + "Gila monster, Heloderma suspectum": 45, + "Gordon setter": 214, + "Granny Smith": 948, + "Great Dane": 246, + "Great Pyrenees": 257, + "Greater Swiss Mountain dog": 238, + "Ibizan hound, Ibizan Podenco": 173, + "Indian cobra, Naja naja": 63, + "Indian elephant, Elephas maximus": 385, + "Irish setter, red setter": 213, + "Irish terrier": 184, + "Irish water spaniel": 221, + "Irish wolfhound": 170, + "Italian greyhound": 171, + "Japanese spaniel": 152, + "Kerry blue terrier": 183, + "Komodo dragon, Komodo lizard, dragon lizard, giant 
lizard, Varanus komodoensis": 48, + "Labrador retriever": 208, + "Lakeland terrier": 189, + "Leonberg": 255, + "Lhasa, Lhasa apso": 204, + "Loafer": 630, + "Madagascar cat, ring-tailed lemur, Lemur catta": 383, + "Maltese dog, Maltese terrier, Maltese": 153, + "Mexican hairless": 268, + "Model T": 661, + "Newfoundland, Newfoundland dog": 256, + "Norfolk terrier": 185, + "Norwegian elkhound, elkhound": 174, + "Norwich terrier": 186, + "Old English sheepdog, bobtail": 229, + "Pekinese, Pekingese, Peke": 154, + "Pembroke, Pembroke Welsh corgi": 263, + "Persian cat": 283, + "Petri dish": 712, + "Polaroid camera, Polaroid Land camera": 732, + "Pomeranian": 259, + "Rhodesian ridgeback": 159, + "Rottweiler": 234, + "Saint Bernard, St Bernard": 247, + "Saluki, gazelle hound": 176, + "Samoyed, Samoyede": 258, + "Scotch terrier, Scottish terrier, Scottie": 199, + "Scottish deerhound, deerhound": 177, + "Sealyham terrier, Sealyham": 190, + "Shetland sheepdog, Shetland sheep dog, Shetland": 230, + "Shih-Tzu": 155, + "Siamese cat, Siamese": 284, + "Siberian husky": 250, + "Staffordshire bullterrier, Staffordshire bull terrier": 179, + "Sussex spaniel": 220, + "Tibetan mastiff": 244, + "Tibetan terrier, chrysanthemum dog": 200, + "Walker hound, Walker foxhound": 166, + "Weimaraner": 178, + "Welsh springer spaniel": 218, + "West Highland white terrier": 203, + "Windsor tie": 906, + "Yorkshire terrier": 187, + "abacus": 398, + "abaya": 399, + "academic gown, academic robe, judge's robe": 400, + "accordion, piano accordion, squeeze box": 401, + "acorn": 988, + "acorn squash": 941, + "acoustic guitar": 402, + "admiral": 321, + "affenpinscher, monkey pinscher, monkey dog": 252, + "agama": 42, + "agaric": 992, + "aircraft carrier, carrier, flattop, attack aircraft carrier": 403, + "airliner": 404, + "airship, dirigible": 405, + "albatross, mollymawk": 146, + "alligator lizard": 44, + "alp": 970, + "altar": 406, + "ambulance": 407, + "amphibian, amphibious vehicle": 408, + "analog clock": 409, + "anemone fish": 393, + "ant, emmet, pismire": 310, + "apiary, bee house": 410, + "apron": 411, + "armadillo": 363, + "artichoke, globe artichoke": 944, + "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin": 412, + "assault rifle, assault gun": 413, + "axolotl, mud puppy, Ambystoma mexicanum": 29, + "baboon": 372, + "backpack, back pack, knapsack, packsack, rucksack, haversack": 414, + "badger": 362, + "bagel, beigel": 931, + "bakery, bakeshop, bakehouse": 415, + "balance beam, beam": 416, + "bald eagle, American eagle, Haliaeetus leucocephalus": 22, + "balloon": 417, + "ballplayer, baseball player": 981, + "ballpoint, ballpoint pen, ballpen, Biro": 418, + "banana": 954, + "banded gecko": 38, + "banjo": 420, + "bannister, banister, balustrade, balusters, handrail": 421, + "barbell": 422, + "barber chair": 423, + "barbershop": 424, + "barn": 425, + "barn spider, Araneus cavaticus": 73, + "barometer": 426, + "barracouta, snoek": 389, + "barrel, cask": 427, + "barrow, garden cart, lawn cart, wheelbarrow": 428, + "baseball": 429, + "basenji": 253, + "basketball": 430, + "basset, basset hound": 161, + "bassinet": 431, + "bassoon": 432, + "bath towel": 434, + "bathing cap, swimming cap": 433, + "bathtub, bathing tub, bath, tub": 435, + "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon": 436, + "beacon, lighthouse, beacon light, pharos": 437, + "beagle": 162, + "beaker": 438, + "bearskin, busby, shako": 439, + "beaver": 337, + "bee": 309, + 
"bee eater": 92, + "beer bottle": 440, + "beer glass": 441, + "bell cote, bell cot": 442, + "bell pepper": 945, + "bib": 443, + "bicycle-built-for-two, tandem bicycle, tandem": 444, + "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis": 349, + "bikini, two-piece": 445, + "binder, ring-binder": 446, + "binoculars, field glasses, opera glasses": 447, + "birdhouse": 448, + "bison": 347, + "bittern": 133, + "black and gold garden spider, Argiope aurantia": 72, + "black grouse": 80, + "black stork, Ciconia nigra": 128, + "black swan, Cygnus atratus": 100, + "black widow, Latrodectus mactans": 75, + "black-and-tan coonhound": 165, + "black-footed ferret, ferret, Mustela nigripes": 359, + "bloodhound, sleuthhound": 163, + "bluetick": 164, + "boa constrictor, Constrictor constrictor": 61, + "boathouse": 449, + "bobsled, bobsleigh, bob": 450, + "bolete": 997, + "bolo tie, bolo, bola tie, bola": 451, + "bonnet, poke bonnet": 452, + "book jacket, dust cover, dust jacket, dust wrapper": 921, + "bookcase": 453, + "bookshop, bookstore, bookstall": 454, + "borzoi, Russian wolfhound": 169, + "bottlecap": 455, + "bow": 456, + "bow tie, bow-tie, bowtie": 457, + "box turtle, box tortoise": 37, + "boxer": 242, + "brain coral": 109, + "brambling, Fringilla montifringilla": 10, + "brass, memorial tablet, plaque": 458, + "brassiere, bra, bandeau": 459, + "breakwater, groin, groyne, mole, bulwark, seawall, jetty": 460, + "breastplate, aegis, egis": 461, + "briard": 226, + "broccoli": 937, + "broom": 462, + "brown bear, bruin, Ursus arctos": 294, + "bubble": 971, + "bucket, pail": 463, + "buckeye, horse chestnut, conker": 990, + "buckle": 464, + "bulbul": 16, + "bull mastiff": 243, + "bullet train, bullet": 466, + "bulletproof vest": 465, + "bullfrog, Rana catesbeiana": 30, + "burrito": 965, + "bustard": 138, + "butcher shop, meat market": 467, + "butternut squash": 942, + "cab, hack, taxi, taxicab": 468, + "cabbage butterfly": 324, + "cairn, cairn terrier": 192, + "caldron, cauldron": 469, + "can opener, tin opener": 473, + "candle, taper, wax light": 470, + "cannon": 471, + "canoe": 472, + "capuchin, ringtail, Cebus capucinus": 378, + "car mirror": 475, + "car wheel": 479, + "carbonara": 959, + "cardigan": 474, + "cardoon": 946, + "carousel, carrousel, merry-go-round, roundabout, whirligig": 476, + "carpenter's kit, tool kit": 477, + "carton": 478, + "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM": 480, + "cassette": 481, + "cassette player": 482, + "castle": 483, + "catamaran": 484, + "cauliflower": 938, + "cello, violoncello": 486, + "cellular telephone, cellular phone, cellphone, cell, mobile phone": 487, + "centipede": 79, + "chain": 488, + "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour": 490, + "chain saw, chainsaw": 491, + "chainlink fence": 489, + "chambered nautilus, pearly nautilus, nautilus": 117, + "cheeseburger": 933, + "cheetah, chetah, Acinonyx jubatus": 293, + "chest": 492, + "chickadee": 19, + "chiffonier, commode": 493, + "chime, bell, gong": 494, + "chimpanzee, chimp, Pan troglodytes": 367, + "china cabinet, china closet": 495, + "chiton, coat-of-mail shell, sea cradle, polyplacophore": 116, + "chocolate sauce, chocolate syrup": 960, + "chow, chow chow": 260, + "church, church building": 497, + "cicada, cicala": 316, + "cinema, movie theater, movie theatre, movie house, picture palace": 498, + "cleaver, meat cleaver, chopper": 499, + "cliff 
dwelling": 500, + "cliff, drop, drop-off": 972, + "cloak": 501, + "clog, geta, patten, sabot": 502, + "clumber, clumber spaniel": 216, + "cock": 7, + "cocker spaniel, English cocker spaniel, cocker": 219, + "cockroach, roach": 314, + "cocktail shaker": 503, + "coffee mug": 504, + "coffeepot": 505, + "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch": 391, + "coil, spiral, volute, whorl, helix": 506, + "collie": 231, + "colobus, colobus monkey": 375, + "combination lock": 507, + "comic book": 917, + "common iguana, iguana, Iguana iguana": 39, + "common newt, Triturus vulgaris": 26, + "computer keyboard, keypad": 508, + "conch": 112, + "confectionery, confectionary, candy store": 509, + "consomme": 925, + "container ship, containership, container vessel": 510, + "convertible": 511, + "coral fungus": 991, + "coral reef": 973, + "corkscrew, bottle screw": 512, + "corn": 987, + "cornet, horn, trumpet, trump": 513, + "coucal": 91, + "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor": 286, + "cowboy boot": 514, + "cowboy hat, ten-gallon hat": 515, + "coyote, prairie wolf, brush wolf, Canis latrans": 272, + "cradle": 516, + "crane": 517, + "crash helmet": 518, + "crate": 519, + "crayfish, crawfish, crawdad, crawdaddy": 124, + "crib, cot": 520, + "cricket": 312, + "croquet ball": 522, + "crossword puzzle, crossword": 918, + "crutch": 523, + "cucumber, cuke": 943, + "cuirass": 524, + "cup": 968, + "curly-coated retriever": 206, + "custard apple": 956, + "daisy": 985, + "dalmatian, coach dog, carriage dog": 251, + "dam, dike, dyke": 525, + "damselfly": 320, + "desk": 526, + "desktop computer": 527, + "dhole, Cuon alpinus": 274, + "dial telephone, dial phone": 528, + "diamondback, diamondback rattlesnake, Crotalus adamanteus": 67, + "diaper, nappy, napkin": 529, + "digital clock": 530, + "digital watch": 531, + "dingo, warrigal, warragal, Canis dingo": 273, + "dining table, board": 532, + "dishrag, dishcloth": 533, + "dishwasher, dish washer, dishwashing machine": 534, + "disk brake, disc brake": 535, + "dock, dockage, docking facility": 536, + "dogsled, dog sled, dog sleigh": 537, + "dome": 538, + "doormat, welcome mat": 539, + "dough": 961, + "dowitcher": 142, + "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk": 319, + "drake": 97, + "drilling platform, offshore rig": 540, + "drum, membranophone, tympan": 541, + "drumstick": 542, + "dugong, Dugong dugon": 149, + "dumbbell": 543, + "dung beetle": 305, + "ear, spike, capitulum": 998, + "earthstar": 995, + "echidna, spiny anteater, anteater": 102, + "eel": 390, + "eft": 27, + "eggnog": 969, + "electric fan, blower": 545, + "electric guitar": 546, + "electric locomotive": 547, + "electric ray, crampfish, numbfish, torpedo": 5, + "entertainment center": 548, + "envelope": 549, + "espresso": 967, + "espresso maker": 550, + "face powder": 551, + "feather boa, boa": 552, + "fiddler crab": 120, + "fig": 952, + "file, file cabinet, filing cabinet": 553, + "fire engine, fire truck": 555, + "fire screen, fireguard": 556, + "fireboat": 554, + "flagpole, flagstaff": 557, + "flamingo": 130, + "flat-coated retriever": 205, + "flatworm, platyhelminth": 110, + "flute, transverse flute": 558, + "fly": 308, + "folding chair": 559, + "football helmet": 560, + "forklift": 561, + "fountain": 562, + "fountain pen": 563, + "four-poster": 564, + "fox squirrel, eastern fox squirrel, Sciurus niger": 335, + "freight car": 565, + "frilled lizard, 
Chlamydosaurus kingi": 43, + "frying pan, frypan, skillet": 567, + "fur coat": 568, + "gar, garfish, garpike, billfish, Lepisosteus osseus": 395, + "garbage truck, dustcart": 569, + "garden spider, Aranea diademata": 74, + "garter snake, grass snake": 57, + "gas pump, gasoline pump, petrol pump, island dispenser": 571, + "gasmask, respirator, gas helmet": 570, + "gazelle": 353, + "geyser": 974, + "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca": 388, + "giant schnauzer": 197, + "gibbon, Hylobates lar": 368, + "go-kart": 573, + "goblet": 572, + "golden retriever": 207, + "goldfinch, Carduelis carduelis": 11, + "goldfish, Carassius auratus": 1, + "golf ball": 574, + "golfcart, golf cart": 575, + "gondola": 576, + "gong, tam-tam": 577, + "goose": 99, + "gorilla, Gorilla gorilla": 366, + "gown": 578, + "grand piano, grand": 579, + "grasshopper, hopper": 311, + "great grey owl, great gray owl, Strix nebulosa": 24, + "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias": 2, + "green lizard, Lacerta viridis": 46, + "green mamba": 64, + "green snake, grass snake": 55, + "greenhouse, nursery, glasshouse": 580, + "grey fox, gray fox, Urocyon cinereoargenteus": 280, + "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus": 147, + "grille, radiator grille": 581, + "grocery store, grocery, food market, market": 582, + "groenendael": 224, + "groom, bridegroom": 982, + "ground beetle, carabid beetle": 302, + "guacamole": 924, + "guenon, guenon monkey": 370, + "guillotine": 583, + "guinea pig, Cavia cobaya": 338, + "gyromitra": 993, + "hair slide": 584, + "hair spray": 585, + "half track": 586, + "hammer": 587, + "hammerhead, hammerhead shark": 4, + "hamper": 588, + "hamster": 333, + "hand blower, blow dryer, blow drier, hair dryer, hair drier": 589, + "hand-held computer, hand-held microcomputer": 590, + "handkerchief, hankie, hanky, hankey": 591, + "hard disc, hard disk, fixed disk": 592, + "hare": 331, + "harmonica, mouth organ, harp, mouth harp": 593, + "harp": 594, + "hartebeest": 351, + "harvester, reaper": 595, + "harvestman, daddy longlegs, Phalangium opilio": 70, + "hatchet": 596, + "hay": 958, + "head cabbage": 936, + "hen": 8, + "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa": 996, + "hermit crab": 125, + "hip, rose hip, rosehip": 989, + "hippopotamus, hippo, river horse, Hippopotamus amphibius": 344, + "hog, pig, grunter, squealer, Sus scrofa": 341, + "hognose snake, puff adder, sand viper": 54, + "holster": 597, + "home theater, home theatre": 598, + "honeycomb": 599, + "hook, claw": 600, + "hoopskirt, crinoline": 601, + "horizontal bar, high bar": 602, + "hornbill": 93, + "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus": 66, + "horse cart, horse-cart": 603, + "hot pot, hotpot": 926, + "hotdog, hot dog, red hot": 934, + "hourglass": 604, + "house finch, linnet, Carpodacus mexicanus": 12, + "howler monkey, howler": 379, + "hummingbird": 94, + "hyena, hyaena": 276, + "iPod": 605, + "ibex, Capra ibex": 350, + "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus": 296, + "ice cream, icecream": 928, + "ice lolly, lolly, lollipop, popsicle": 929, + "impala, Aepyceros melampus": 352, + "indigo bunting, indigo finch, indigo bird, Passerina cyanea": 14, + "indri, indris, Indri indri, Indri brevicaudatus": 384, + "iron, smoothing iron": 606, + "isopod": 126, + "jacamar": 95, + "jack-o'-lantern": 607, + "jackfruit, jak, jack": 955, + "jaguar, panther, Panthera onca, Felis 
onca": 290, + "jay": 17, + "jean, blue jean, denim": 608, + "jeep, landrover": 609, + "jellyfish": 107, + "jersey, T-shirt, tee shirt": 610, + "jigsaw puzzle": 611, + "jinrikisha, ricksha, rickshaw": 612, + "joystick": 613, + "junco, snowbird": 13, + "keeshond": 261, + "kelpie": 227, + "killer whale, killer, orca, grampus, sea wolf, Orcinus orca": 148, + "kimono": 614, + "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica": 121, + "king penguin, Aptenodytes patagonica": 145, + "king snake, kingsnake": 56, + "kit fox, Vulpes macrotis": 278, + "kite": 21, + "knee pad": 615, + "knot": 616, + "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus": 105, + "komondor": 228, + "kuvasz": 222, + "lab coat, laboratory coat": 617, + "lacewing, lacewing fly": 318, + "ladle": 618, + "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle": 301, + "lakeside, lakeshore": 975, + "lampshade, lamp shade": 619, + "langur": 374, + "laptop, laptop computer": 620, + "lawn mower, mower": 621, + "leaf beetle, chrysomelid": 304, + "leafhopper": 317, + "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea": 34, + "lemon": 951, + "lens cap, lens cover": 622, + "leopard, Panthera pardus": 288, + "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens": 387, + "letter opener, paper knife, paperknife": 623, + "library": 624, + "lifeboat": 625, + "lighter, light, igniter, ignitor": 626, + "limousine, limo": 627, + "limpkin, Aramus pictus": 135, + "liner, ocean liner": 628, + "lion, king of beasts, Panthera leo": 291, + "lionfish": 396, + "lipstick, lip rouge": 629, + "little blue heron, Egretta caerulea": 131, + "llama": 355, + "loggerhead, loggerhead turtle, Caretta caretta": 33, + "long-horned beetle, longicorn, longicorn beetle": 303, + "lorikeet": 90, + "lotion": 631, + "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system": 632, + "loupe, jeweler's loupe": 633, + "lumbermill, sawmill": 634, + "lycaenid, lycaenid butterfly": 326, + "lynx, catamount": 287, + "macaque": 373, + "macaw": 88, + "magnetic compass": 635, + "magpie": 18, + "mailbag, postbag": 636, + "mailbox, letter box": 637, + "maillot": 638, + "maillot, tank suit": 639, + "malamute, malemute, Alaskan malamute": 249, + "malinois": 225, + "manhole cover": 640, + "mantis, mantid": 315, + "maraca": 641, + "marimba, xylophone": 642, + "marmoset": 377, + "marmot": 336, + "mashed potato": 935, + "mask": 643, + "matchstick": 644, + "maypole": 645, + "maze, labyrinth": 646, + "measuring cup": 647, + "meat loaf, meatloaf": 962, + "medicine chest, medicine cabinet": 648, + "meerkat, mierkat": 299, + "megalith, megalithic structure": 649, + "menu": 922, + "microphone, mike": 650, + "microwave, microwave oven": 651, + "military uniform": 652, + "milk can": 653, + "miniature pinscher": 237, + "miniature poodle": 266, + "miniature schnauzer": 196, + "minibus": 654, + "miniskirt, mini": 655, + "minivan": 656, + "mink": 357, + "missile": 657, + "mitten": 658, + "mixing bowl": 659, + "mobile home, manufactured home": 660, + "modem": 662, + "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus": 323, + "monastery": 663, + "mongoose": 298, + "monitor": 664, + "moped": 665, + "mortar": 666, + "mortarboard": 667, + "mosque": 668, + "mosquito net": 669, + "motor scooter, scooter": 670, + "mountain bike, all-terrain bike, off-roader": 671, + "mountain tent": 672, + "mouse, computer mouse": 673, + "mousetrap": 674, + "moving van": 675, + "mud turtle": 35, + 
"mushroom": 947, + "muzzle": 676, + "nail": 677, + "neck brace": 678, + "necklace": 679, + "nematode, nematode worm, roundworm": 111, + "night snake, Hypsiglena torquata": 60, + "nipple": 680, + "notebook, notebook computer": 681, + "obelisk": 682, + "oboe, hautboy, hautbois": 683, + "ocarina, sweet potato": 684, + "odometer, hodometer, mileometer, milometer": 685, + "oil filter": 686, + "orange": 950, + "orangutan, orang, orangutang, Pongo pygmaeus": 365, + "organ, pipe organ": 687, + "oscilloscope, scope, cathode-ray oscilloscope, CRO": 688, + "ostrich, Struthio camelus": 9, + "otter": 360, + "otterhound, otter hound": 175, + "overskirt": 689, + "ox": 345, + "oxcart": 690, + "oxygen mask": 691, + "oystercatcher, oyster catcher": 143, + "packet": 692, + "paddle, boat paddle": 693, + "paddlewheel, paddle wheel": 694, + "padlock": 695, + "paintbrush": 696, + "pajama, pyjama, pj's, jammies": 697, + "palace": 698, + "panpipe, pandean pipe, syrinx": 699, + "paper towel": 700, + "papillon": 157, + "parachute, chute": 701, + "parallel bars, bars": 702, + "park bench": 703, + "parking meter": 704, + "partridge": 86, + "passenger car, coach, carriage": 705, + "patas, hussar monkey, Erythrocebus patas": 371, + "patio, terrace": 706, + "pay-phone, pay-station": 707, + "peacock": 84, + "pedestal, plinth, footstall": 708, + "pelican": 144, + "pencil box, pencil case": 709, + "pencil sharpener": 710, + "perfume, essence": 711, + "photocopier": 713, + "pick, plectrum, plectron": 714, + "pickelhaube": 715, + "picket fence, paling": 716, + "pickup, pickup truck": 717, + "pier": 718, + "piggy bank, penny bank": 719, + "pill bottle": 720, + "pillow": 721, + "pineapple, ananas": 953, + "ping-pong ball": 722, + "pinwheel": 723, + "pirate, pirate ship": 724, + "pitcher, ewer": 725, + "pizza, pizza pie": 963, + "plane, carpenter's plane, woodworking plane": 726, + "planetarium": 727, + "plastic bag": 728, + "plate": 923, + "plate rack": 729, + "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus": 103, + "plow, plough": 730, + "plunger, plumber's helper": 731, + "pole": 733, + "polecat, fitch, foulmart, foumart, Mustela putorius": 358, + "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria": 734, + "pomegranate": 957, + "poncho": 735, + "pool table, billiard table, snooker table": 736, + "pop bottle, soda bottle": 737, + "porcupine, hedgehog": 334, + "pot, flowerpot": 738, + "potpie": 964, + "potter's wheel": 739, + "power drill": 740, + "prairie chicken, prairie grouse, prairie fowl": 83, + "prayer rug, prayer mat": 741, + "pretzel": 932, + "printer": 742, + "prison, prison house": 743, + "proboscis monkey, Nasalis larvatus": 376, + "projectile, missile": 744, + "projector": 745, + "promontory, headland, head, foreland": 976, + "ptarmigan": 81, + "puck, hockey puck": 746, + "puffer, pufferfish, blowfish, globefish": 397, + "pug, pug-dog": 254, + "punching bag, punch bag, punching ball, punchball": 747, + "purse": 748, + "quail": 85, + "quill, quill pen": 749, + "quilt, comforter, comfort, puff": 750, + "racer, race car, racing car": 751, + "racket, racquet": 752, + "radiator": 753, + "radio telescope, radio reflector": 755, + "radio, wireless": 754, + "rain barrel": 756, + "ram, tup": 348, + "rapeseed": 984, + "recreational vehicle, RV, R.V.": 757, + "red fox, Vulpes vulpes": 277, + "red wine": 966, + "red wolf, maned wolf, Canis rufus, Canis niger": 271, + "red-backed sandpiper, dunlin, Erolia alpina": 140, + "red-breasted merganser, Mergus serrator": 
98, + "redbone": 168, + "redshank, Tringa totanus": 141, + "reel": 758, + "reflex camera": 759, + "refrigerator, icebox": 760, + "remote control, remote": 761, + "restaurant, eating house, eating place, eatery": 762, + "revolver, six-gun, six-shooter": 763, + "rhinoceros beetle": 306, + "rifle": 764, + "ringlet, ringlet butterfly": 322, + "ringneck snake, ring-necked snake, ring snake": 53, + "robin, American robin, Turdus migratorius": 15, + "rock beauty, Holocanthus tricolor": 392, + "rock crab, Cancer irroratus": 119, + "rock python, rock snake, Python sebae": 62, + "rocking chair, rocker": 765, + "rotisserie": 766, + "rubber eraser, rubber, pencil eraser": 767, + "ruddy turnstone, Arenaria interpres": 139, + "ruffed grouse, partridge, Bonasa umbellus": 82, + "rugby ball": 768, + "rule, ruler": 769, + "running shoe": 770, + "safe": 771, + "safety pin": 772, + "saltshaker, salt shaker": 773, + "sandal": 774, + "sandbar, sand bar": 977, + "sarong": 775, + "sax, saxophone": 776, + "scabbard": 777, + "scale, weighing machine": 778, + "schipperke": 223, + "school bus": 779, + "schooner": 780, + "scoreboard": 781, + "scorpion": 71, + "screen, CRT screen": 782, + "screw": 783, + "screwdriver": 784, + "scuba diver": 983, + "sea anemone, anemone": 108, + "sea cucumber, holothurian": 329, + "sea lion": 150, + "sea slug, nudibranch": 115, + "sea snake": 65, + "sea urchin": 328, + "seashore, coast, seacoast, sea-coast": 978, + "seat belt, seatbelt": 785, + "sewing machine": 786, + "shield, buckler": 787, + "shoe shop, shoe-shop, shoe store": 788, + "shoji": 789, + "shopping basket": 790, + "shopping cart": 791, + "shovel": 792, + "shower cap": 793, + "shower curtain": 794, + "siamang, Hylobates syndactylus, Symphalangus syndactylus": 369, + "sidewinder, horned rattlesnake, Crotalus cerastes": 68, + "silky terrier, Sydney silky": 201, + "ski": 795, + "ski mask": 796, + "skunk, polecat, wood pussy": 361, + "sleeping bag": 797, + "slide rule, slipstick": 798, + "sliding door": 799, + "slot, one-armed bandit": 800, + "sloth bear, Melursus ursinus, Ursus ursinus": 297, + "slug": 114, + "snail": 113, + "snorkel": 801, + "snow leopard, ounce, Panthera uncia": 289, + "snowmobile": 802, + "snowplow, snowplough": 803, + "soap dispenser": 804, + "soccer ball": 805, + "sock": 806, + "soft-coated wheaten terrier": 202, + "solar dish, solar collector, solar furnace": 807, + "sombrero": 808, + "sorrel": 339, + "soup bowl": 809, + "space bar": 810, + "space heater": 811, + "space shuttle": 812, + "spaghetti squash": 940, + "spatula": 813, + "speedboat": 814, + "spider monkey, Ateles geoffroyi": 381, + "spider web, spider's web": 815, + "spindle": 816, + "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish": 123, + "spoonbill": 129, + "sports car, sport car": 817, + "spotlight, spot": 818, + "spotted salamander, Ambystoma maculatum": 28, + "squirrel monkey, Saimiri sciureus": 382, + "stage": 819, + "standard poodle": 267, + "standard schnauzer": 198, + "starfish, sea star": 327, + "steam locomotive": 820, + "steel arch bridge": 821, + "steel drum": 822, + "stethoscope": 823, + "stingray": 6, + "stinkhorn, carrion fungus": 994, + "stole": 824, + "stone wall": 825, + "stopwatch, stop watch": 826, + "stove": 827, + "strainer": 828, + "strawberry": 949, + "street sign": 919, + "streetcar, tram, tramcar, trolley, trolley car": 829, + "stretcher": 830, + "studio couch, day bed": 831, + "stupa, tope": 832, + "sturgeon": 394, + "submarine, pigboat, sub, U-boat": 833, + "suit, suit of clothes": 834, + 
"sulphur butterfly, sulfur butterfly": 325, + "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita": 89, + "sundial": 835, + "sunglass": 836, + "sunglasses, dark glasses, shades": 837, + "sunscreen, sunblock, sun blocker": 838, + "suspension bridge": 839, + "swab, swob, mop": 840, + "sweatshirt": 841, + "swimming trunks, bathing trunks": 842, + "swing": 843, + "switch, electric switch, electrical switch": 844, + "syringe": 845, + "tabby, tabby cat": 281, + "table lamp": 846, + "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui": 32, + "tank, army tank, armored combat vehicle, armoured combat vehicle": 847, + "tape player": 848, + "tarantula": 76, + "teapot": 849, + "teddy, teddy bear": 850, + "television, television system": 851, + "tench, Tinca tinca": 0, + "tennis ball": 852, + "terrapin": 36, + "thatch, thatched roof": 853, + "theater curtain, theatre curtain": 854, + "thimble": 855, + "three-toed sloth, ai, Bradypus tridactylus": 364, + "thresher, thrasher, threshing machine": 856, + "throne": 857, + "thunder snake, worm snake, Carphophis amoenus": 52, + "tick": 78, + "tiger beetle": 300, + "tiger cat": 282, + "tiger shark, Galeocerdo cuvieri": 3, + "tiger, Panthera tigris": 292, + "tile roof": 858, + "timber wolf, grey wolf, gray wolf, Canis lupus": 269, + "titi, titi monkey": 380, + "toaster": 859, + "tobacco shop, tobacconist shop, tobacconist": 860, + "toilet seat": 861, + "toilet tissue, toilet paper, bathroom tissue": 999, + "torch": 862, + "totem pole": 863, + "toucan": 96, + "tow truck, tow car, wrecker": 864, + "toy poodle": 265, + "toy terrier": 158, + "toyshop": 865, + "tractor": 866, + "traffic light, traffic signal, stoplight": 920, + "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi": 867, + "tray": 868, + "tree frog, tree-frog": 31, + "trench coat": 869, + "triceratops": 51, + "tricycle, trike, velocipede": 870, + "trifle": 927, + "trilobite": 69, + "trimaran": 871, + "tripod": 872, + "triumphal arch": 873, + "trolleybus, trolley coach, trackless trolley": 874, + "trombone": 875, + "tub, vat": 876, + "turnstile": 877, + "tusker": 101, + "typewriter keyboard": 878, + "umbrella": 879, + "unicycle, monocycle": 880, + "upright, upright piano": 881, + "vacuum, vacuum cleaner": 882, + "valley, vale": 979, + "vase": 883, + "vault": 884, + "velvet": 885, + "vending machine": 886, + "vestment": 887, + "viaduct": 888, + "vine snake": 59, + "violin, fiddle": 889, + "vizsla, Hungarian pointer": 211, + "volcano": 980, + "volleyball": 890, + "vulture": 23, + "waffle iron": 891, + "walking stick, walkingstick, stick insect": 313, + "wall clock": 892, + "wallaby, brush kangaroo": 104, + "wallet, billfold, notecase, pocketbook": 893, + "wardrobe, closet, press": 894, + "warplane, military plane": 895, + "warthog": 343, + "washbasin, handbasin, washbowl, lavabo, wash-hand basin": 896, + "washer, automatic washer, washing machine": 897, + "water bottle": 898, + "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis": 346, + "water jug": 899, + "water ouzel, dipper": 20, + "water snake": 58, + "water tower": 900, + "weasel": 356, + "web site, website, internet site, site": 916, + "weevil": 307, + "whippet": 172, + "whiptail, whiptail lizard": 41, + "whiskey jug": 901, + "whistle": 902, + "white stork, Ciconia ciconia": 127, + "white wolf, Arctic wolf, Canis lupus tundrarum": 270, + "wig": 903, + "wild boar, boar, Sus scrofa": 342, + "window screen": 904, + "window shade": 905, + "wine bottle": 907, + "wing": 908, + "wire-haired fox 
terrier": 188, + "wok": 909, + "wolf spider, hunting spider": 77, + "wombat": 106, + "wood rabbit, cottontail, cottontail rabbit": 330, + "wooden spoon": 910, + "wool, woolen, woollen": 911, + "worm fence, snake fence, snake-rail fence, Virginia fence": 912, + "wreck": 913, + "yawl": 914, + "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum": 986, + "yurt": 915, + "zebra": 340, + "zucchini, courgette": 939 + }, + "layer_norm_eps": 1e-05, + "length_penalty": 1.0, + "max_length": 20, + "min_length": 0, + "mlp_ratio": 4.0, + "model_type": "swin", + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_channels": 3, + "num_heads": [ + 6, + 12, + 24, + 48 + ], + "num_layers": 4, + "num_return_sequences": 1, + "out_features": [ + "stage1", + "stage2", + "stage3", + "stage4" + ], + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "patch_size": 4, + "path_norm": true, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "qkv_bias": true, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "stage_names": [ + "stem", + "stage1", + "stage2", + "stage3", + "stage4" + ], + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": "float32", + "torchscript": false, + "transformers_version": "4.26.0.dev0", + "typical_p": 1.0, + "use_absolute_embeddings": false, + "use_bfloat16": false, + "window_size": 12 + }, + "class_weight": 2.0, + "common_stride": 4, + "decoder_layers": 10, + "dice_weight": 5.0, + "dim_feedforward": 2048, + "dropout": 0.0, + "encoder_feedforward_dim": 1024, + "encoder_layers": 6, + "enforce_input_proj": false, + "enforce_input_projection": false, + "feature_size": 256, + "feature_strides": [ + 4, + 8, + 16, + 32 + ], + "hidden_dim": 256, + "id2label": { + "0": "wall", + "1": "building", + "2": "sky", + "3": "floor", + "4": "tree", + "5": "ceiling", + "6": "road", + "7": "bed ", + "8": "windowpane", + "9": "grass", + "10": "cabinet", + "11": "sidewalk", + "12": "person", + "13": "earth", + "14": "door", + "15": "table", + "16": "mountain", + "17": "plant", + "18": "curtain", + "19": "chair", + "20": "car", + "21": "water", + "22": "painting", + "23": "sofa", + "24": "shelf", + "25": "house", + "26": "sea", + "27": "mirror", + "28": "rug", + "29": "field", + "30": "armchair", + "31": "seat", + "32": "fence", + "33": "desk", + "34": "rock", + "35": "wardrobe", + "36": "lamp", + "37": "bathtub", + "38": "railing", + "39": "cushion", + "40": "base", + "41": "box", + "42": "column", + "43": "signboard", + "44": "chest of drawers", + "45": "counter", + "46": "sand", + "47": "sink", + "48": "skyscraper", + "49": "fireplace", + "50": "refrigerator", + "51": "grandstand", + "52": "path", + "53": "stairs", + "54": "runway", + "55": "case", + "56": "pool table", + "57": "pillow", + "58": "screen door", + "59": "stairway", + "60": "river", + "61": "bridge", + "62": "bookcase", + "63": "blind", + "64": "coffee table", + "65": "toilet", + "66": "flower", + "67": "book", + "68": "hill", + "69": "bench", + "70": "countertop", + "71": "stove", + "72": "palm", + "73": "kitchen island", + "74": "computer", + "75": "swivel chair", + "76": "boat", + "77": "bar", + "78": "arcade machine", + 
"79": "hovel", + "80": "bus", + "81": "towel", + "82": "light", + "83": "truck", + "84": "tower", + "85": "chandelier", + "86": "awning", + "87": "streetlight", + "88": "booth", + "89": "television receiver", + "90": "airplane", + "91": "dirt track", + "92": "apparel", + "93": "pole", + "94": "land", + "95": "bannister", + "96": "escalator", + "97": "ottoman", + "98": "bottle", + "99": "buffet", + "100": "poster", + "101": "stage", + "102": "van", + "103": "ship", + "104": "fountain", + "105": "conveyer belt", + "106": "canopy", + "107": "washer", + "108": "plaything", + "109": "swimming pool", + "110": "stool", + "111": "barrel", + "112": "basket", + "113": "waterfall", + "114": "tent", + "115": "bag", + "116": "minibike", + "117": "cradle", + "118": "oven", + "119": "ball", + "120": "food", + "121": "step", + "122": "tank", + "123": "trade name", + "124": "microwave", + "125": "pot", + "126": "animal", + "127": "bicycle", + "128": "lake", + "129": "dishwasher", + "130": "screen", + "131": "blanket", + "132": "sculpture", + "133": "hood", + "134": "sconce", + "135": "vase", + "136": "traffic light", + "137": "tray", + "138": "ashcan", + "139": "fan", + "140": "pier", + "141": "crt screen", + "142": "plate", + "143": "monitor", + "144": "bulletin board", + "145": "shower", + "146": "radiator", + "147": "glass", + "148": "clock", + "149": "flag" + }, + "ignore_value": 255, + "importance_sample_ratio": 0.75, + "init_std": 0.02, + "init_xavier_std": 1.0, + "label2id": { + "airplane": 90, + "animal": 126, + "apparel": 92, + "arcade machine": 78, + "armchair": 30, + "ashcan": 138, + "awning": 86, + "bag": 115, + "ball": 119, + "bannister": 95, + "bar": 77, + "barrel": 111, + "base": 40, + "basket": 112, + "bathtub": 37, + "bed ": 7, + "bench": 69, + "bicycle": 127, + "blanket": 131, + "blind": 63, + "boat": 76, + "book": 67, + "bookcase": 62, + "booth": 88, + "bottle": 98, + "box": 41, + "bridge": 61, + "buffet": 99, + "building": 1, + "bulletin board": 144, + "bus": 80, + "cabinet": 10, + "canopy": 106, + "car": 20, + "case": 55, + "ceiling": 5, + "chair": 19, + "chandelier": 85, + "chest of drawers": 44, + "clock": 148, + "coffee table": 64, + "column": 42, + "computer": 74, + "conveyer belt": 105, + "counter": 45, + "countertop": 70, + "cradle": 117, + "crt screen": 141, + "curtain": 18, + "cushion": 39, + "desk": 33, + "dirt track": 91, + "dishwasher": 129, + "door": 14, + "earth": 13, + "escalator": 96, + "fan": 139, + "fence": 32, + "field": 29, + "fireplace": 49, + "flag": 149, + "floor": 3, + "flower": 66, + "food": 120, + "fountain": 104, + "glass": 147, + "grandstand": 51, + "grass": 9, + "hill": 68, + "hood": 133, + "house": 25, + "hovel": 79, + "kitchen island": 73, + "lake": 128, + "lamp": 36, + "land": 94, + "light": 82, + "microwave": 124, + "minibike": 116, + "mirror": 27, + "monitor": 143, + "mountain": 16, + "ottoman": 97, + "oven": 118, + "painting": 22, + "palm": 72, + "path": 52, + "person": 12, + "pier": 140, + "pillow": 57, + "plant": 17, + "plate": 142, + "plaything": 108, + "pole": 93, + "pool table": 56, + "poster": 100, + "pot": 125, + "radiator": 146, + "railing": 38, + "refrigerator": 50, + "river": 60, + "road": 6, + "rock": 34, + "rug": 28, + "runway": 54, + "sand": 46, + "sconce": 134, + "screen": 130, + "screen door": 58, + "sculpture": 132, + "sea": 26, + "seat": 31, + "shelf": 24, + "ship": 103, + "shower": 145, + "sidewalk": 11, + "signboard": 43, + "sink": 47, + "sky": 2, + "skyscraper": 48, + "sofa": 23, + "stage": 101, + "stairs": 53, + "stairway": 59, + 
"step": 121, + "stool": 110, + "stove": 71, + "streetlight": 87, + "swimming pool": 109, + "swivel chair": 75, + "table": 15, + "tank": 122, + "television receiver": 89, + "tent": 114, + "toilet": 65, + "towel": 81, + "tower": 84, + "trade name": 123, + "traffic light": 136, + "tray": 137, + "tree": 4, + "truck": 83, + "van": 102, + "vase": 135, + "wall": 0, + "wardrobe": 35, + "washer": 107, + "water": 21, + "waterfall": 113, + "windowpane": 8 + }, + "mask_feature_size": 256, + "mask_weight": 5.0, + "model_type": "mask2former", + "no_object_weight": 0.1, + "num_attention_heads": 8, + "num_hidden_layers": 10, + "num_queries": 100, + "output_auxiliary_logits": null, + "oversample_ratio": 3.0, + "pre_norm": false, + "torch_dtype": "float32", + "train_num_points": 12544, + "transformers_version": null, + "use_auxiliary_loss": true +} diff --git a/models/mask2former-swin-large-ade-semantic/model.safetensors b/models/mask2former-swin-large-ade-semantic/model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..50fd1febdfa679c3aae8d6b5573645fa4b0c02d9 --- /dev/null +++ b/models/mask2former-swin-large-ade-semantic/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b143c144341c15b4f20165cc6d2c9305fb1b66792f68a6e0e06d2b20dc063b14 +size 866052064 diff --git a/models/mask2former-swin-large-ade-semantic/preprocessor_config.json b/models/mask2former-swin-large-ade-semantic/preprocessor_config.json new file mode 100644 index 0000000000000000000000000000000000000000..0723e20e276898ab516725965619f83deadbdc34 --- /dev/null +++ b/models/mask2former-swin-large-ade-semantic/preprocessor_config.json @@ -0,0 +1,27 @@ +{ + "_max_size": 2560, + "do_normalize": true, + "do_rescale": true, + "do_resize": true, + "ignore_index": 255, + "image_mean": [ + 0.48500001430511475, + 0.4560000002384186, + 0.4059999883174896 + ], + "image_processor_type": "Mask2FormerImageProcessor", + "image_std": [ + 0.2290000021457672, + 0.2239999920129776, + 0.22499999403953552 + ], + "num_labels": 150, + "reduce_labels": false, + "resample": 2, + "rescale_factor": 0.00392156862745098, + "size": { + "height": 384, + "width": 384 + }, + "size_divisor": 32 +} diff --git a/models/mask2former-swin-large-ade-semantic/pytorch_model.bin b/models/mask2former-swin-large-ade-semantic/pytorch_model.bin new file mode 100644 index 0000000000000000000000000000000000000000..304a01a63f954cafbc9cd414f1fbe3fc409cabd8 --- /dev/null +++ b/models/mask2former-swin-large-ade-semantic/pytorch_model.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dba4b80c7f85372fa985e5b192d1bf7d726e971fdbeebb67dc8c5a96d4f6f873 +size 866216517 diff --git a/models/realesrgan/RealESRGAN_x2plus.pth b/models/realesrgan/RealESRGAN_x2plus.pth new file mode 100644 index 0000000000000000000000000000000000000000..77cc0ef1e8d238fa5cfb409cda2e619a9459ddc9 --- /dev/null +++ b/models/realesrgan/RealESRGAN_x2plus.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:49fafd45f8fd7aa8d31ab2a22d14d91b536c34494a5cfe31eb5d89c2fa266abb +size 67061725 diff --git a/models/sam_ffhq_aging.pt b/models/sam_ffhq_aging.pt new file mode 100644 index 0000000000000000000000000000000000000000..74a5d1177ef1358f808143e6d073c5809d2d45f0 --- /dev/null +++ b/models/sam_ffhq_aging.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6fbd6085e2e6001f51f6dee896ae5ba7bd6c3fab5c06fadba75f2a12188ba143 +size 2270547237 diff --git a/models/shape_predictor_68_face_landmarks.dat 
b/models/shape_predictor_68_face_landmarks.dat new file mode 100644 index 0000000000000000000000000000000000000000..1e5da4f9a556bec8582e6c55b89b3e6bfdd60021 --- /dev/null +++ b/models/shape_predictor_68_face_landmarks.dat @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fbdc2cb80eb9aa7a758672cbfdda32ba6300efe9b6e6c7a299ff7e736b11b92f +size 99693937 diff --git a/models/zeroscratches/.gitattributes b/models/zeroscratches/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..c7d9f3332a950355d5a77d85000f05e6f45435ea --- /dev/null +++ b/models/zeroscratches/.gitattributes @@ -0,0 +1,34 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/models/zeroscratches/README.md b/models/zeroscratches/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ec6840a4f19bf3afe401c3ec1465e012b49698f6 --- /dev/null +++ b/models/zeroscratches/README.md @@ -0,0 +1,44 @@ +--- +license: mit +--- + +# Zero Scratches +## Old Photo Restoration + +This is a lightweight implementation of [Microsoft Bringing Old Photos Back to Life](https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life). + +Try the [Hugging Face App](https://huggingface.co/spaces/leonelhs/ZeroScratches). + +### For developers +```shell +pip install zeroscratches +``` +### Basic usage +```python + +import PIL.Image +from zeroscratches import EraseScratches + + +image_path = "/path/to/image-scratched.jpg" +eraser = EraseScratches() + +image = PIL.Image.open(image_path) +new_img = eraser.erase(image) + +new_img = PIL.Image.fromarray(new_img) +new_img.show() +``` + +Get the pretrained models at [Hugging Face Zero Scratches](https://huggingface.co/leonelhs/zeroscratches). + +## Some apps using the library: + +### [Face Shine](https://github.com/leonelhs/face-shine) +Face Shine is a backend server for photo enhancement and restoration. + +### [Super Face](https://github.com/leonelhs/SuperFace/) +Super Face is a Python Qt frontend for the Face Shine server.
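+### Saving the result +A small follow-up to the basic usage above: after `PIL.Image.fromarray`, `new_img` is a regular PIL image, so the restored photo can also be written straight to disk (the output path is only illustrative): +```python +new_img.save("/path/to/image-restored.jpg") +```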
+ + + diff --git a/models/zeroscratches/detection/FT_Epoch_latest.pt b/models/zeroscratches/detection/FT_Epoch_latest.pt new file mode 100644 index 0000000000000000000000000000000000000000..2b688be71a8803225894e709332ef33f05044fc4 --- /dev/null +++ b/models/zeroscratches/detection/FT_Epoch_latest.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b2d7ab04e9b3885c6b1991bb7a0b823129dd6e3ac078a9fd059ebd2a7ba59a95 +size 451663663 diff --git a/models/zeroscratches/restoration/VAE_A_quality/latest_net_G.pth b/models/zeroscratches/restoration/VAE_A_quality/latest_net_G.pth new file mode 100644 index 0000000000000000000000000000000000000000..1173e8d2510cc046bd1fa3a5997a5349ea1107db --- /dev/null +++ b/models/zeroscratches/restoration/VAE_A_quality/latest_net_G.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de39b6d0081746995afd2393a94a11a10ae052fd9484b4fb06b7fa6bd28dcf5d +size 3498887 diff --git a/models/zeroscratches/restoration/VAE_B_scratch/latest_net_G.pth b/models/zeroscratches/restoration/VAE_B_scratch/latest_net_G.pth new file mode 100644 index 0000000000000000000000000000000000000000..790534655029db70f8f091040eca687ba8305e4d --- /dev/null +++ b/models/zeroscratches/restoration/VAE_B_scratch/latest_net_G.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e22bd7274229beaabf27427646d4623c7b5a302829dd9d466514ff639ba40f6a +size 3498887 diff --git a/models/zeroscratches/restoration/mapping_Patch_Attention/latest_net_mapping_net.pth b/models/zeroscratches/restoration/mapping_Patch_Attention/latest_net_mapping_net.pth new file mode 100644 index 0000000000000000000000000000000000000000..a2578262ecae61936d730f08b6efb23d9135a2c4 --- /dev/null +++ b/models/zeroscratches/restoration/mapping_Patch_Attention/latest_net_mapping_net.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:15f7a145b4eb94565d2f5988f6ca9948d04cef1599f4ca6eb99429e7d2c7d783 +size 295645915 diff --git a/scripts/BG.py b/scripts/BG.py new file mode 100644 index 0000000000000000000000000000000000000000..d214b51b8b69871f0685e54250b8c4964e68e8d7 --- /dev/null +++ b/scripts/BG.py @@ -0,0 +1,92 @@ +import cv2 +import os +import numpy as np +import torch +import torch.nn.functional as F +from torchvision.transforms.functional import normalize +from .briarmbg import BriaRMBG +import PIL +from PIL import Image +from typing import Tuple + +class BG: + def __init__(self): + self.net = BriaRMBG.from_pretrained("./models/RMBG-1.4") + self.device = "cpu" + + def _resize_image(self,image): + image = image.convert('RGB') + model_input_size = (1024, 1024) + image = image.resize(model_input_size, Image.BILINEAR) + return image + + def _BG_mask(self, image_rgb): + orig_image = Image.fromarray(image_rgb) + w, h = orig_image.size + image_rgb = self._resize_image(orig_image) + im_np = np.array(image_rgb) + im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1) + im_tensor = torch.unsqueeze(im_tensor, 0) + im_tensor = torch.divide(im_tensor, 255.0) + im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]) + im_tensor = im_tensor.to(self.device) + + # inference + with torch.no_grad(): + result = self.net(im_tensor) + + # post process + result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0) + ma = torch.max(result) + mi = torch.min(result) + result = (result - mi) / (ma - mi) + + threshold = 0.5 + mask_np = torch.where(result > threshold, torch.tensor(1), torch.tensor(0)) + mask_np = 1 - mask_np + + 
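+        # After the inversion above, mask_np marks background pixels with 1 (the subject was 1 before); it is converted to uint8 and cleaned up with morphology below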
mask_np = mask_np.squeeze(0).cpu().numpy().astype(np.uint8) + + if np.count_nonzero(mask_np) == 0: + return None + + # Set kernel size based on image size + kernel_size = max(w, h) // 30 # Adjust this factor according to your preference + + # Morphological operations to remove gaps + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) + processed_mask = cv2.morphologyEx(mask_np, cv2.MORPH_OPEN, kernel) + processed_mask = cv2.morphologyEx(processed_mask, cv2.MORPH_CLOSE, kernel) + + # Additional dilation and erosion to remove small gaps within the mask + processed_mask = cv2.dilate(processed_mask, kernel, iterations=2) + processed_mask = cv2.erode(processed_mask, kernel, iterations=1) + + # Mask off the areas specified by the processed mask + new_mask = cv2.bitwise_and(mask_np, processed_mask) + + return new_mask + + def BG_remove(self,image_rgb,gamma=None): + + + mask = self._BG_mask(image_rgb) + if mask is None: + return image_rgb + binary_mask = np.uint8(mask) * 255 + + + if gamma: + binary_mask = cv2.GaussianBlur(~binary_mask, (15, 15), gamma) + binary_mask = ~binary_mask + + image_bgra = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2RGBA) + image_bgra[:, :, 3] = ~binary_mask + + return image_bgra + + def __del__(self): + del self.net + +if __name__ == "__main__": + pass diff --git a/scripts/RRDBNet_arch.py b/scripts/RRDBNet_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..09a7ed5d360630b72775e266b7fbb215be1eb15a --- /dev/null +++ b/scripts/RRDBNet_arch.py @@ -0,0 +1,78 @@ +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def make_layer(block, n_layers): + layers = [] + for _ in range(n_layers): + layers.append(block()) + return nn.Sequential(*layers) + + +class ResidualDenseBlock_5C(nn.Module): + def __init__(self, nf=64, gc=32, bias=True): + super(ResidualDenseBlock_5C, self).__init__() + # gc: growth channel, i.e. 
intermediate channels + self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +class RRDB(nn.Module): + '''Residual in Residual Dense Block''' + + def __init__(self, nf, gc=32): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C(nf, gc) + self.RDB2 = ResidualDenseBlock_5C(nf, gc) + self.RDB3 = ResidualDenseBlock_5C(nf, gc) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x + + +class RRDBNet(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, gc=32): + super(RRDBNet, self).__init__() + RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) + + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.RRDB_trunk = make_layer(RRDB_block_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + #### upsampling + self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + fea = self.conv_first(x) + trunk = self.trunk_conv(self.RRDB_trunk(fea)) + fea = fea + trunk + + fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) + fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.HRconv(fea))) + + return out diff --git a/scripts/SKY.py b/scripts/SKY.py new file mode 100644 index 0000000000000000000000000000000000000000..4f249de7776d2e7879fa5494eec9555621c62864 --- /dev/null +++ b/scripts/SKY.py @@ -0,0 +1,62 @@ +from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation +import numpy as np +import cv2 +import os +import torch +from skimage.metrics import structural_similarity as ssim + +class SKY: + def __init__(self): + self.device = ("cuda" if torch.cuda.is_available() else "cpu") + self.processor = AutoImageProcessor.from_pretrained("models/mask2former-swin-large-ade-semantic") + self.model = Mask2FormerForUniversalSegmentation.from_pretrained("models/mask2former-swin-large-ade-semantic").to(self.device) + + def _SKY_mask(self,image_rgb): + + inputs = self.processor(images=image_rgb, return_tensors="pt").to(self.device) + + outputs = self.model(**inputs) + inputs.to("cpu") + del inputs + predicted_semantic_map = self.processor.post_process_semantic_segmentation(outputs, target_sizes=[image_rgb.shape[:2]])[0] + mask = predicted_semantic_map.cpu().numpy() + predicted_semantic_map.to("cpu") + del predicted_semantic_map + + mask_np = (mask == 2) + + if np.count_nonzero(mask_np)==0: + return None + + return mask_np + + def SKY_remove(self,image_rgb,gamma=None): + + + mask = self._SKY_mask(image_rgb) + if mask is None: + 
return image_rgb + binary_mask = np.uint8(mask) * 255 + + + if gamma: + binary_mask = cv2.GaussianBlur(~binary_mask, (15, 15), gamma) + binary_mask = ~binary_mask + + image_bgra = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2RGBA) + image_bgra[:, :, 3] = ~binary_mask + + return image_bgra + + + def __del__(self): + self.model = None + self.processor = None + del self.model + del self.processor + torch.cuda.empty_cache() + import gc + gc.collect() + +if __name__ == "__main__": + pass diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..64abe7eab939d28268c5682ef2d811cddf166f75 --- /dev/null +++ b/scripts/__init__.py @@ -0,0 +1,10 @@ + +__appname__ = "zeroscratches" +__version__ = "1.0.2" + +from .erase_scratches import EraseScratches + + + + + diff --git a/scripts/__pycache__/BG.cpython-310.pyc b/scripts/__pycache__/BG.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05efd5893b76a8de22fd824eab659556774afad7 Binary files /dev/null and b/scripts/__pycache__/BG.cpython-310.pyc differ diff --git a/scripts/__pycache__/RRDBNet_arch.cpython-310.pyc b/scripts/__pycache__/RRDBNet_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e909ba6a3ebd33926907e609972771415bbd726 Binary files /dev/null and b/scripts/__pycache__/RRDBNet_arch.cpython-310.pyc differ diff --git a/scripts/__pycache__/SKY.cpython-310.pyc b/scripts/__pycache__/SKY.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9f3220d1d4d3d43306a0d0a369fc38201364aad Binary files /dev/null and b/scripts/__pycache__/SKY.cpython-310.pyc differ diff --git a/scripts/__pycache__/__init__.cpython-310.pyc b/scripts/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff9e19da3f77a90eade3d94c577b874602e72d28 Binary files /dev/null and b/scripts/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/__pycache__/align_all_parallel.cpython-310.pyc b/scripts/__pycache__/align_all_parallel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd3d4ae08b383e8805361f9932ae14dd58e88f42 Binary files /dev/null and b/scripts/__pycache__/align_all_parallel.cpython-310.pyc differ diff --git a/scripts/__pycache__/augmentations.cpython-310.pyc b/scripts/__pycache__/augmentations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..161f0473dc818d8bc53f6823241eb6eb3c770e8b Binary files /dev/null and b/scripts/__pycache__/augmentations.cpython-310.pyc differ diff --git a/scripts/__pycache__/briarmbg.cpython-310.pyc b/scripts/__pycache__/briarmbg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1d9fb169f7a6bbc8878886e54b743751035cbb0 Binary files /dev/null and b/scripts/__pycache__/briarmbg.cpython-310.pyc differ diff --git a/scripts/__pycache__/colorizer.cpython-310.pyc b/scripts/__pycache__/colorizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53e859ad48699aa7014206c4e9d5c6a6b5445f22 Binary files /dev/null and b/scripts/__pycache__/colorizer.cpython-310.pyc differ diff --git a/scripts/__pycache__/common.cpython-310.pyc b/scripts/__pycache__/common.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f897f0049d2d6847aed2ec6ee1eb0bd31328f7c2 Binary files /dev/null and b/scripts/__pycache__/common.cpython-310.pyc differ diff --git 
a/scripts/__pycache__/erase_scratches.cpython-310.pyc b/scripts/__pycache__/erase_scratches.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f2592028853958e24ffc598d476b2e39b7a2f65 Binary files /dev/null and b/scripts/__pycache__/erase_scratches.cpython-310.pyc differ diff --git a/scripts/__pycache__/generate_batch.cpython-310.pyc b/scripts/__pycache__/generate_batch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2181cd2f763ef4ab49098233bbcdba44ef96e78d Binary files /dev/null and b/scripts/__pycache__/generate_batch.cpython-310.pyc differ diff --git a/scripts/__pycache__/generate_facerender_batch.cpython-310.pyc b/scripts/__pycache__/generate_facerender_batch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee1319f1cd6b9243b1928007c68fed43e0550921 Binary files /dev/null and b/scripts/__pycache__/generate_facerender_batch.cpython-310.pyc differ diff --git a/scripts/__pycache__/morph_video.cpython-310.pyc b/scripts/__pycache__/morph_video.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c949a5d8c92f67e4167842ad0c537e83914d1fc Binary files /dev/null and b/scripts/__pycache__/morph_video.cpython-310.pyc differ diff --git a/scripts/__pycache__/paths_config.cpython-310.pyc b/scripts/__pycache__/paths_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43214afd0085a9c6fae2cc0a7307945e74aeb07f Binary files /dev/null and b/scripts/__pycache__/paths_config.cpython-310.pyc differ diff --git a/scripts/__pycache__/psp.cpython-310.pyc b/scripts/__pycache__/psp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da3ce2f61424589a7e822a4de29a0030cfbdecd9 Binary files /dev/null and b/scripts/__pycache__/psp.cpython-310.pyc differ diff --git a/scripts/__pycache__/talker.cpython-310.pyc b/scripts/__pycache__/talker.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67da45d8245b64c25a94cea6d8021477c091cd0f Binary files /dev/null and b/scripts/__pycache__/talker.cpython-310.pyc differ diff --git a/scripts/__pycache__/test_audio2coeff.cpython-310.pyc b/scripts/__pycache__/test_audio2coeff.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52e982f04f15311a6fe9d963ef29716e9dd2462e Binary files /dev/null and b/scripts/__pycache__/test_audio2coeff.cpython-310.pyc differ diff --git a/scripts/__pycache__/upscale_helper.cpython-310.pyc b/scripts/__pycache__/upscale_helper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..844b593458c2fd2376bf0d6af4a351e6c0e40203 Binary files /dev/null and b/scripts/__pycache__/upscale_helper.cpython-310.pyc differ diff --git a/scripts/align_all_parallel.py b/scripts/align_all_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..c3cf5f5def48b4ba127c27d81aca447a2747d15b --- /dev/null +++ b/scripts/align_all_parallel.py @@ -0,0 +1,208 @@ +""" +brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset) +author: lzhbrian (https://lzhbrian.me) +date: 2020.1.5 +note: code is heavily borrowed from + https://github.com/NVlabs/ffhq-dataset + http://dlib.net/face_landmark_detection.py.html + +requirements: + apt install cmake + conda install Pillow numpy scipy + pip install dlib + # download face landmark model from: + # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 +""" +from argparse import ArgumentParser 
+import time +import numpy as np +import PIL +import PIL.Image +import os +import scipy +import scipy.ndimage +import dlib +import multiprocessing as mp +import math + +from scripts.paths_config import model_paths +SHAPE_PREDICTOR_PATH = model_paths["shape_predictor"] + + +def get_landmark(filepath, predictor): + """get landmark with dlib + :return: np.array shape=(68, 2) + """ + detector = dlib.get_frontal_face_detector() + + img = dlib.load_rgb_image(filepath) + dets = detector(img, 1) + + shape = None + + for k, d in enumerate(dets): + shape = predictor(img, d) + + if not shape: + raise Exception("Could not find face in image. Try another!") + + t = list(shape.parts()) + a = [] + for tt in t: + a.append([tt.x, tt.y]) + lm = np.array(a) + return lm + + +def align_face(filepath, predictor): + """ + :param filepath: str + :return: PIL Image + """ + + lm = get_landmark(filepath, predictor) + + lm_chin = lm[0: 17] # left-right + lm_eyebrow_left = lm[17: 22] # left-right + lm_eyebrow_right = lm[22: 27] # left-right + lm_nose = lm[27: 31] # top-down + lm_nostrils = lm[31: 36] # top-down + lm_eye_left = lm[36: 42] # left-clockwise + lm_eye_right = lm[42: 48] # left-clockwise + lm_mouth_outer = lm[48: 60] # left-clockwise + lm_mouth_inner = lm[60: 68] # left-clockwise + + # Calculate auxiliary vectors. + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + eye_avg = (eye_left + eye_right) * 0.5 + eye_to_eye = eye_right - eye_left + mouth_left = lm_mouth_outer[0] + mouth_right = lm_mouth_outer[6] + mouth_avg = (mouth_left + mouth_right) * 0.5 + eye_to_mouth = mouth_avg - eye_avg + + # Choose oriented crop rectangle. + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + x /= np.hypot(*x) + x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) + y = np.flipud(x) * [-1, 1] + c = eye_avg + eye_to_mouth * 0.1 + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + qsize = np.hypot(*x) * 2 + + # read image + img = PIL.Image.open(filepath).convert("RGB") + + output_size = 256 + transform_size = 256 + enable_padding = True + + # Shrink. + shrink = int(np.floor(qsize / output_size * 0.5)) + if shrink > 1: + rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) + img = img.resize(rsize, PIL.Image.ANTIALIAS) + quad /= shrink + qsize /= shrink + + # Crop. + border = max(int(np.rint(qsize * 0.1)), 3) + crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), + min(crop[3] + border, img.size[1])) + if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: + img = img.crop(crop) + quad -= crop[0:2] + + # Pad. 
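+    # (reflect-pad the borders and blur-blend them so the crop can extend past the image edges without hard seams)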
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), + max(pad[3] - img.size[1] + border, 0)) + if enable_padding and max(pad) > border - 4: + pad = np.maximum(pad, int(np.rint(qsize * 0.3))) + img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + h, w, _ = img.shape + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) + blur = qsize * 0.02 + img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) + img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') + quad += pad[:2] + + # Transform. + img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) + if output_size < transform_size: + img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) + + # Save aligned image. + return img + + +def chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i:i + n] + + +def extract_on_paths(file_paths): + predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH) + pid = mp.current_process().name + print(f'\t{pid} is starting to extract on #{len(file_paths)} images') + tot_count = len(file_paths) + count = 0 + for file_path, res_path in file_paths: + count += 1 + if count % 100 == 0: + print(f'{pid} done with {count}/{tot_count}') + try: + res = align_face(file_path, predictor) + res = res.convert('RGB') + os.makedirs(os.path.dirname(res_path), exist_ok=True) + res.save(res_path) + except Exception: + continue + print('\tDone!') + + +def parse_args(): + parser = ArgumentParser(add_help=False) + parser.add_argument('--num_threads', type=int, default=1) + parser.add_argument('--root_path', type=str, default='') + args = parser.parse_args() + return args + + +def run(args): + root_path = args.root_path + out_crops_path = root_path + '_crops' + if not os.path.exists(out_crops_path): + os.makedirs(out_crops_path, exist_ok=True) + + file_paths = [] + for root, dirs, files in os.walk(root_path): + for file in files: + file_path = os.path.join(root, file) + fname = os.path.join(out_crops_path, os.path.relpath(file_path, root_path)) + res_path = f'{os.path.splitext(fname)[0]}.jpg' + if os.path.splitext(file_path)[1] == '.txt' or os.path.exists(res_path): + continue + file_paths.append((file_path, res_path)) + + file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads)))) + print(len(file_chunks)) + pool = mp.Pool(args.num_threads) + print(f'Running on {len(file_paths)} paths\nHere we goooo') + tic = time.time() + pool.map(extract_on_paths, file_chunks) + toc = time.time() + print(f'Mischief managed in {str(toc - tic)}s') + + +if __name__ == '__main__': + args = parse_args() + run(args) diff --git a/scripts/audio2exp_models/__pycache__/audio2exp.cpython-310.pyc b/scripts/audio2exp_models/__pycache__/audio2exp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7b63acf91fd41d46a3a824b02ee36eaed089efe Binary files /dev/null and b/scripts/audio2exp_models/__pycache__/audio2exp.cpython-310.pyc differ diff --git 
a/scripts/audio2exp_models/__pycache__/networks.cpython-310.pyc b/scripts/audio2exp_models/__pycache__/networks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1d5f29d102a36c24c8660dfc4be9df6fd55b641 Binary files /dev/null and b/scripts/audio2exp_models/__pycache__/networks.cpython-310.pyc differ diff --git a/scripts/audio2exp_models/audio2exp.py b/scripts/audio2exp_models/audio2exp.py new file mode 100644 index 0000000000000000000000000000000000000000..9e79a929560592687a505e13188796e2b0ca8772 --- /dev/null +++ b/scripts/audio2exp_models/audio2exp.py @@ -0,0 +1,41 @@ +from tqdm import tqdm +import torch +from torch import nn + + +class Audio2Exp(nn.Module): + def __init__(self, netG, cfg, device, prepare_training_loss=False): + super(Audio2Exp, self).__init__() + self.cfg = cfg + self.device = device + self.netG = netG.to(device) + + def test(self, batch): + + mel_input = batch['indiv_mels'] # bs T 1 80 16 + bs = mel_input.shape[0] + T = mel_input.shape[1] + + exp_coeff_pred = [] + + for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames + + current_mel_input = mel_input[:,i:i+10] + + #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64 + ref = batch['ref'][:, :, :64][:, i:i+10] + ratio = batch['ratio_gt'][:, i:i+10] #bs T + + audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16 + + curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64 + + exp_coeff_pred += [curr_exp_coeff_pred] + + # BS x T x 64 + results_dict = { + 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1) + } + return results_dict + + diff --git a/scripts/audio2exp_models/networks.py b/scripts/audio2exp_models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..f052e18101f5446a527ae354b3621e7d0d4991cc --- /dev/null +++ b/scripts/audio2exp_models/networks.py @@ -0,0 +1,74 @@ +import torch +import torch.nn.functional as F +from torch import nn + +class Conv2d(nn.Module): + def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.Conv2d(cin, cout, kernel_size, stride, padding), + nn.BatchNorm2d(cout) + ) + self.act = nn.ReLU() + self.residual = residual + self.use_act = use_act + + def forward(self, x): + out = self.conv_block(x) + if self.residual: + out += x + + if self.use_act: + return self.act(out) + else: + return out + +class SimpleWrapperV2(nn.Module): + def __init__(self) -> None: + super().__init__() + self.audio_encoder = nn.Sequential( + Conv2d(1, 32, kernel_size=3, stride=1, padding=1), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(64, 128, kernel_size=3, stride=3, padding=1), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(256, 512, kernel_size=3, stride=1, padding=0), + Conv2d(512, 512, kernel_size=1, stride=1, padding=0), + ) + + #### load the pre-trained audio_encoder + #self.audio_encoder = self.audio_encoder.to(device) + ''' + 
wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict'] +        state_dict = self.audio_encoder.state_dict() + +        for k,v in wav2lip_state_dict.items(): +            if 'audio_encoder' in k: +                print('init:', k) +                state_dict[k.replace('module.audio_encoder.', '')] = v +        self.audio_encoder.load_state_dict(state_dict) +        ''' + +        self.mapping1 = nn.Linear(512+64+1, 64) +        #self.mapping2 = nn.Linear(30, 64) +        #nn.init.constant_(self.mapping1.weight, 0.) +        nn.init.constant_(self.mapping1.bias, 0.) + +    def forward(self, x, ref, ratio): +        x = self.audio_encoder(x).view(x.size(0), -1) +        ref_reshape = ref.reshape(x.size(0), -1) +        ratio = ratio.reshape(x.size(0), -1) + +        y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1)) +        out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # residual +        return out diff --git a/scripts/audio2pose_models/__pycache__/audio2pose.cpython-310.pyc b/scripts/audio2pose_models/__pycache__/audio2pose.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87e2950688912a99f17adfa6801265f25e11dc42 Binary files /dev/null and b/scripts/audio2pose_models/__pycache__/audio2pose.cpython-310.pyc differ diff --git a/scripts/audio2pose_models/__pycache__/audio_encoder.cpython-310.pyc b/scripts/audio2pose_models/__pycache__/audio_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b878fd608024e908eabd326c24e28e5edbc6853 Binary files /dev/null and b/scripts/audio2pose_models/__pycache__/audio_encoder.cpython-310.pyc differ diff --git a/scripts/audio2pose_models/__pycache__/cvae.cpython-310.pyc b/scripts/audio2pose_models/__pycache__/cvae.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..242633e786ce1d3025824ee46e29954c53eb96ad Binary files /dev/null and b/scripts/audio2pose_models/__pycache__/cvae.cpython-310.pyc differ diff --git a/scripts/audio2pose_models/__pycache__/discriminator.cpython-310.pyc b/scripts/audio2pose_models/__pycache__/discriminator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..323cd9477f715af95cc9abc9341ab77c7c377ac9 Binary files /dev/null and b/scripts/audio2pose_models/__pycache__/discriminator.cpython-310.pyc differ diff --git a/scripts/audio2pose_models/__pycache__/networks.cpython-310.pyc b/scripts/audio2pose_models/__pycache__/networks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ad01eacb795204089134e464c2dc6c6c96cadfa Binary files /dev/null and b/scripts/audio2pose_models/__pycache__/networks.cpython-310.pyc differ diff --git a/scripts/audio2pose_models/__pycache__/res_unet.cpython-310.pyc b/scripts/audio2pose_models/__pycache__/res_unet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16f545c51e6a8e0194006b0f89dec9b9917a425a Binary files /dev/null and b/scripts/audio2pose_models/__pycache__/res_unet.cpython-310.pyc differ diff --git a/scripts/audio2pose_models/audio2pose.py b/scripts/audio2pose_models/audio2pose.py new file mode 100644 index 0000000000000000000000000000000000000000..60a55dec87112e7d0e2986f5c9a3a4188ed9d57c --- /dev/null +++ b/scripts/audio2pose_models/audio2pose.py @@ -0,0 +1,94 @@ +import torch +from torch import nn +from scripts.audio2pose_models.cvae import CVAE +from scripts.audio2pose_models.discriminator import PoseSequenceDiscriminator +from scripts.audio2pose_models.audio_encoder import AudioEncoder + +class Audio2Pose(nn.Module): +    def __init__(self, cfg,
wav2lip_checkpoint, device='cuda'): + super().__init__() + self.cfg = cfg + self.seq_len = cfg.MODEL.CVAE.SEQ_LEN + self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE + self.device = device + + self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device) + self.audio_encoder.eval() + for param in self.audio_encoder.parameters(): + param.requires_grad = False + + self.netG = CVAE(cfg) + self.netD_motion = PoseSequenceDiscriminator(cfg) + + + def forward(self, x): + + batch = {} + coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73 + batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6 + batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6 + batch['class'] = x['class'].squeeze(0).cuda() # bs + indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16 + + # forward + audio_emb_list = [] + audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512 + batch['audio_emb'] = audio_emb + batch = self.netG(batch) + + pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6 + pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6 + pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6 + + batch['pose_pred'] = pose_pred + batch['pose_gt'] = pose_gt + + return batch + + def test(self, x): + + batch = {} + ref = x['ref'] #bs 1 70 + batch['ref'] = x['ref'][:,0,-6:] + batch['class'] = x['class'] + bs = ref.shape[0] + + indiv_mels= x['indiv_mels'] # bs T 1 80 16 + indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame + num_frames = x['num_frames'] + num_frames = int(num_frames) - 1 + + # + div = num_frames//self.seq_len + re = num_frames%self.seq_len + audio_emb_list = [] + pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype, + device=batch['ref'].device)] + + for i in range(div): + z = torch.randn(bs, self.latent_dim).to(ref.device) + batch['z'] = z + audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512 + batch['audio_emb'] = audio_emb + batch = self.netG.test(batch) + pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6 + + if re != 0: + z = torch.randn(bs, self.latent_dim).to(ref.device) + batch['z'] = z + audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512 + if audio_emb.shape[1] != self.seq_len: + pad_dim = self.seq_len-audio_emb.shape[1] + pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1) + audio_emb = torch.cat([pad_audio_emb, audio_emb], 1) + batch['audio_emb'] = audio_emb + batch = self.netG.test(batch) + pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:]) + + pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1) + batch['pose_motion_pred'] = pose_motion_pred + + pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6 + + batch['pose_pred'] = pose_pred + return batch diff --git a/scripts/audio2pose_models/audio_encoder.py b/scripts/audio2pose_models/audio_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6279d2014a2e786a6c549f084339e18d00e50331 --- /dev/null +++ b/scripts/audio2pose_models/audio_encoder.py @@ -0,0 +1,64 @@ +import torch +from torch import nn +from torch.nn import functional as F + +class Conv2d(nn.Module): + def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.Conv2d(cin, cout, kernel_size, stride, padding), + nn.BatchNorm2d(cout) + ) + self.act = 
nn.ReLU() + self.residual = residual + + def forward(self, x): + out = self.conv_block(x) + if self.residual: + out += x + return self.act(out) + +class AudioEncoder(nn.Module): + def __init__(self, wav2lip_checkpoint, device): + super(AudioEncoder, self).__init__() + + self.audio_encoder = nn.Sequential( + Conv2d(1, 32, kernel_size=3, stride=1, padding=1), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(64, 128, kernel_size=3, stride=3, padding=1), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(256, 512, kernel_size=3, stride=1, padding=0), + Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) + + #### load the pre-trained audio_encoder, we do not need to load wav2lip model here. + # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict'] + # state_dict = self.audio_encoder.state_dict() + + # for k,v in wav2lip_state_dict.items(): + # if 'audio_encoder' in k: + # state_dict[k.replace('module.audio_encoder.', '')] = v + # self.audio_encoder.load_state_dict(state_dict) + + + def forward(self, audio_sequences): + # audio_sequences = (B, T, 1, 80, 16) + B = audio_sequences.size(0) + + audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0) + + audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1 + dim = audio_embedding.shape[1] + audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1)) + + return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512 diff --git a/scripts/audio2pose_models/cvae.py b/scripts/audio2pose_models/cvae.py new file mode 100644 index 0000000000000000000000000000000000000000..2b8c9a540a4775b044058fa1657c1251d179127f --- /dev/null +++ b/scripts/audio2pose_models/cvae.py @@ -0,0 +1,149 @@ +import torch +import torch.nn.functional as F +from torch import nn +from scripts.audio2pose_models.res_unet import ResUnet + +def class2onehot(idx, class_num): + + assert torch.max(idx).item() < class_num + onehot = torch.zeros(idx.size(0), class_num).to(idx.device) + onehot.scatter_(1, idx, 1) + return onehot + +class CVAE(nn.Module): + def __init__(self, cfg): + super().__init__() + encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES + decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES + latent_size = cfg.MODEL.CVAE.LATENT_SIZE + num_classes = cfg.DATASET.NUM_CLASSES + audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE + audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE + seq_len = cfg.MODEL.CVAE.SEQ_LEN + + self.latent_size = latent_size + + self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes, + audio_emb_in_size, audio_emb_out_size, seq_len) + self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes, + audio_emb_in_size, audio_emb_out_size, seq_len) + def reparameterize(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + def forward(self, batch): + batch = self.encoder(batch) + mu = batch['mu'] + logvar = batch['logvar'] + z = 
self.reparameterize(mu, logvar) +        batch['z'] = z +        return self.decoder(batch) + +    def test(self, batch): +        ''' +        class_id = batch['class'] +        z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device) +        batch['z'] = z +        ''' +        return self.decoder(batch) + +class ENCODER(nn.Module): +    def __init__(self, layer_sizes, latent_size, num_classes, +                audio_emb_in_size, audio_emb_out_size, seq_len): +        super().__init__() + +        self.resunet = ResUnet() +        self.num_classes = num_classes +        self.seq_len = seq_len + +        self.MLP = nn.Sequential() +        layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6 +        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])): +            self.MLP.add_module( +                name="L{:d}".format(i), module=nn.Linear(in_size, out_size)) +            self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU()) + +        self.linear_means = nn.Linear(layer_sizes[-1], latent_size) +        self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size) +        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size) + +        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size)) + +    def forward(self, batch): +        class_id = batch['class'] +        pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6 +        ref = batch['ref'] #bs 6 +        bs = pose_motion_gt.shape[0] +        audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size + +        #pose encode +        pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6 +        pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6 + +        #audio mapping +        #print(audio_in.shape) +        audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size +        audio_out = audio_out.reshape(bs, -1) + +        class_bias = self.classbias[class_id] #bs latent_size +        x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size +        x_out = self.MLP(x_in) + +        mu = self.linear_means(x_out) +        logvar = self.linear_logvar(x_out) #bs latent_size + +        batch.update({'mu':mu, 'logvar':logvar}) +        return batch + +class DECODER(nn.Module): +    def __init__(self, layer_sizes, latent_size, num_classes, +                audio_emb_in_size, audio_emb_out_size, seq_len): +        super().__init__() + +        self.resunet = ResUnet() +        self.num_classes = num_classes +        self.seq_len = seq_len + +        self.MLP = nn.Sequential() +        input_size = latent_size + seq_len*audio_emb_out_size + 6 +        for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)): +            self.MLP.add_module( +                name="L{:d}".format(i), module=nn.Linear(in_size, out_size)) +            if i+1 < len(layer_sizes): +                self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU()) +            else: +                self.MLP.add_module(name="sigmoid", module=nn.Sigmoid()) + +        self.pose_linear = nn.Linear(6, 6) +        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size) + +        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size)) + +    def forward(self, batch): + +        z = batch['z'] #bs latent_size +        bs = z.shape[0] +        class_id = batch['class'] +        ref = batch['ref'] #bs 6 +        audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size +        #print('audio_in: ', audio_in[:, :, :10]) + +        audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size +        #print('audio_out: ', audio_out[:, :, :10]) +        audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size +        class_bias = self.classbias[class_id] #bs latent_size + +        z = z + class_bias +        x_in = torch.cat([ref, z, audio_out], dim=-1) +        x_out = self.MLP(x_in) # bs layer_sizes[-1] +        x_out = x_out.reshape((bs, self.seq_len, -1)) + +        #print('x_out: ', x_out) +
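+        # x_out is now (bs, seq_len, -1): one feature vector per frame; the residual U-Net below refines it before the final 6-dim pose head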
+ pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6 + + pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6 + + batch.update({'pose_motion_pred':pose_motion_pred}) + return batch diff --git a/scripts/audio2pose_models/discriminator.py b/scripts/audio2pose_models/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..339c38e4812ff38a810f0f3a1c01812f6d5d78db --- /dev/null +++ b/scripts/audio2pose_models/discriminator.py @@ -0,0 +1,76 @@ +import torch +import torch.nn.functional as F +from torch import nn + +class ConvNormRelu(nn.Module): + def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False, + kernel_size=None, stride=None, padding=None, norm='BN', leaky=False): + super().__init__() + if kernel_size is None: + if downsample: + kernel_size, stride, padding = 4, 2, 1 + else: + kernel_size, stride, padding = 3, 1, 1 + + if conv_type == '2d': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + bias=False, + ) + if norm == 'BN': + self.norm = nn.BatchNorm2d(out_channels) + elif norm == 'IN': + self.norm = nn.InstanceNorm2d(out_channels) + else: + raise NotImplementedError + elif conv_type == '1d': + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + bias=False, + ) + if norm == 'BN': + self.norm = nn.BatchNorm1d(out_channels) + elif norm == 'IN': + self.norm = nn.InstanceNorm1d(out_channels) + else: + raise NotImplementedError + nn.init.kaiming_normal_(self.conv.weight) + + self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + if isinstance(self.norm, nn.InstanceNorm1d): + x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C] + else: + x = self.norm(x) + x = self.act(x) + return x + + +class PoseSequenceDiscriminator(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU + + self.seq = nn.Sequential( + ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64 + ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32 + ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16 + nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16 + ) + + def forward(self, x): + x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2) + x = self.seq(x) + x = x.squeeze(1) + return x \ No newline at end of file diff --git a/scripts/audio2pose_models/networks.py b/scripts/audio2pose_models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa0b1390e7b4bb0e16057ac94d2fe84f48421af --- /dev/null +++ b/scripts/audio2pose_models/networks.py @@ -0,0 +1,140 @@ +import torch.nn as nn +import torch + + +class ResidualConv(nn.Module): + def __init__(self, input_dim, output_dim, stride, padding): + super(ResidualConv, self).__init__() + + self.conv_block = nn.Sequential( + nn.BatchNorm2d(input_dim), + nn.ReLU(), + nn.Conv2d( + input_dim, output_dim, kernel_size=3, stride=stride, padding=padding + ), + nn.BatchNorm2d(output_dim), + nn.ReLU(), + nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), + ) + self.conv_skip = nn.Sequential( + nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1), + nn.BatchNorm2d(output_dim), + ) + + def forward(self, x): + + return self.conv_block(x) + self.conv_skip(x) + 
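+        # conv_skip applies the same stride as conv_block, so both branches produce matching shapes and the sum stays a valid residual connection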
+ +class Upsample(nn.Module): + def __init__(self, input_dim, output_dim, kernel, stride): + super(Upsample, self).__init__() + + self.upsample = nn.ConvTranspose2d( + input_dim, output_dim, kernel_size=kernel, stride=stride + ) + + def forward(self, x): + return self.upsample(x) + + +class Squeeze_Excite_Block(nn.Module): + def __init__(self, channel, reduction=16): + super(Squeeze_Excite_Block, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), + nn.Sigmoid(), + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y.expand_as(x) + + +class ASPP(nn.Module): + def __init__(self, in_dims, out_dims, rate=[6, 12, 18]): + super(ASPP, self).__init__() + + self.aspp_block1 = nn.Sequential( + nn.Conv2d( + in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0] + ), + nn.ReLU(inplace=True), + nn.BatchNorm2d(out_dims), + ) + self.aspp_block2 = nn.Sequential( + nn.Conv2d( + in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1] + ), + nn.ReLU(inplace=True), + nn.BatchNorm2d(out_dims), + ) + self.aspp_block3 = nn.Sequential( + nn.Conv2d( + in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2] + ), + nn.ReLU(inplace=True), + nn.BatchNorm2d(out_dims), + ) + + self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1) + self._init_weights() + + def forward(self, x): + x1 = self.aspp_block1(x) + x2 = self.aspp_block2(x) + x3 = self.aspp_block3(x) + out = torch.cat([x1, x2, x3], dim=1) + return self.output(out) + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + +class Upsample_(nn.Module): + def __init__(self, scale=2): + super(Upsample_, self).__init__() + + self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale) + + def forward(self, x): + return self.upsample(x) + + +class AttentionBlock(nn.Module): + def __init__(self, input_encoder, input_decoder, output_dim): + super(AttentionBlock, self).__init__() + + self.conv_encoder = nn.Sequential( + nn.BatchNorm2d(input_encoder), + nn.ReLU(), + nn.Conv2d(input_encoder, output_dim, 3, padding=1), + nn.MaxPool2d(2, 2), + ) + + self.conv_decoder = nn.Sequential( + nn.BatchNorm2d(input_decoder), + nn.ReLU(), + nn.Conv2d(input_decoder, output_dim, 3, padding=1), + ) + + self.conv_attn = nn.Sequential( + nn.BatchNorm2d(output_dim), + nn.ReLU(), + nn.Conv2d(output_dim, 1, 1), + ) + + def forward(self, x1, x2): + out = self.conv_encoder(x1) + self.conv_decoder(x2) + out = self.conv_attn(out) + return out * x2 \ No newline at end of file diff --git a/scripts/audio2pose_models/res_unet.py b/scripts/audio2pose_models/res_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..0c95968ae7f0b417ffdd721301849c680fba35b8 --- /dev/null +++ b/scripts/audio2pose_models/res_unet.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +from scripts.audio2pose_models.networks import ResidualConv, Upsample + + +class ResUnet(nn.Module): + def __init__(self, channel=1, filters=[32, 64, 128, 256]): + super(ResUnet, self).__init__() + + self.input_layer = nn.Sequential( + nn.Conv2d(channel, filters[0], kernel_size=3, padding=1), + nn.BatchNorm2d(filters[0]), + nn.ReLU(), + nn.Conv2d(filters[0], filters[0], 
kernel_size=3, padding=1), + ) + self.input_skip = nn.Sequential( + nn.Conv2d(channel, filters[0], kernel_size=3, padding=1) + ) + + self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1) + self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1) + + self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1) + + self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1)) + self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1) + + self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1)) + self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1) + + self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1)) + self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1) + + self.output_layer = nn.Sequential( + nn.Conv2d(filters[0], 1, 1, 1), + nn.Sigmoid(), + ) + + def forward(self, x): + # Encode + x1 = self.input_layer(x) + self.input_skip(x) + x2 = self.residual_conv_1(x1) + x3 = self.residual_conv_2(x2) + # Bridge + x4 = self.bridge(x3) + + # Decode + x4 = self.upsample_1(x4) + x5 = torch.cat([x4, x3], dim=1) + + x6 = self.up_residual_conv1(x5) + + x6 = self.upsample_2(x6) + x7 = torch.cat([x6, x2], dim=1) + + x8 = self.up_residual_conv2(x7) + + x8 = self.upsample_3(x8) + x9 = torch.cat([x8, x1], dim=1) + + x10 = self.up_residual_conv3(x9) + + output = self.output_layer(x10) + + return output \ No newline at end of file diff --git a/scripts/augmentations.py b/scripts/augmentations.py new file mode 100644 index 0000000000000000000000000000000000000000..efdb02eb045680d4ac8e4217c9b0b72fd1096db2 --- /dev/null +++ b/scripts/augmentations.py @@ -0,0 +1,24 @@ +import numpy as np +import torch + + +class AgeTransformer(object): + + def __init__(self, target_age): + self.target_age = target_age + + def __call__(self, img): + img = self.add_aging_channel(img) + return img + + def add_aging_channel(self, img): + target_age = self.__get_target_age() + target_age = int(target_age) / 100 # normalize aging amount to be in range [-1,1] + img = torch.cat((img, target_age * torch.ones((1, img.shape[1], img.shape[2])))) + return img + + def __get_target_age(self): + if self.target_age == "uniform_random": + return np.random.randint(low=0., high=101, size=1)[0] + else: + return self.target_age diff --git a/scripts/basicsr/VERSION b/scripts/basicsr/VERSION new file mode 100644 index 0000000000000000000000000000000000000000..1892b926767774e9ba91f1e584fa71b4c56abb69 --- /dev/null +++ b/scripts/basicsr/VERSION @@ -0,0 +1 @@ +1.3.2 diff --git a/scripts/basicsr/__init__.py b/scripts/basicsr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7ffcccd7fc0f33b59d99d73d0436d60e561b0fc --- /dev/null +++ b/scripts/basicsr/__init__.py @@ -0,0 +1,11 @@ +# https://github.com/xinntao/BasicSR +# flake8: noqa +from .archs import * +from .data import * +from .losses import * +from .metrics import * +from .models import * +from .ops import * +from .train import * +from .utils import * +from .version import __gitsha__, __version__ diff --git a/scripts/basicsr/__pycache__/__init__.cpython-310.pyc b/scripts/basicsr/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6a5f3195e2769a9f787ec03306b44664f95fb12 Binary files /dev/null and b/scripts/basicsr/__pycache__/__init__.cpython-310.pyc differ diff 
--git a/scripts/basicsr/__pycache__/train.cpython-310.pyc b/scripts/basicsr/__pycache__/train.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52233a8b92833c11dd5a3f2c7d3431628d2d778c Binary files /dev/null and b/scripts/basicsr/__pycache__/train.cpython-310.pyc differ diff --git a/scripts/basicsr/__pycache__/version.cpython-310.pyc b/scripts/basicsr/__pycache__/version.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2865299ecb2541daa921ebecccb0265becc6217e Binary files /dev/null and b/scripts/basicsr/__pycache__/version.cpython-310.pyc differ diff --git a/scripts/basicsr/archs/__init__.py b/scripts/basicsr/archs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57f9c373b83eb889e6e9401182e085f9158c8ec3 --- /dev/null +++ b/scripts/basicsr/archs/__init__.py @@ -0,0 +1,25 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from scripts.basicsr.utils import get_root_logger, scandir +from scripts.basicsr.utils.registry import ARCH_REGISTRY + +__all__ = ['build_network'] + +# automatically scan and import arch modules for registry +# scan all the files under the 'archs' folder and collect files ending with +# '_arch.py' +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] +# import all the arch modules +_arch_modules = [importlib.import_module(f'scripts.basicsr.archs.{file_name}') for file_name in arch_filenames] + + +def build_network(opt): + opt = deepcopy(opt) + network_type = opt.pop('type') + net = ARCH_REGISTRY.get(network_type)(**opt) + logger = get_root_logger() + logger.info(f'Network [{net.__class__.__name__}] is created.') + return net diff --git a/scripts/basicsr/archs/__pycache__/__init__.cpython-310.pyc b/scripts/basicsr/archs/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d49111cdfea4c3e27160a7aedee017290945255 Binary files /dev/null and b/scripts/basicsr/archs/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/basicsr/archs/__pycache__/arcface_arch.cpython-310.pyc b/scripts/basicsr/archs/__pycache__/arcface_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7acd37dbd61638ea5babd22b6667766463e7d710 Binary files /dev/null and b/scripts/basicsr/archs/__pycache__/arcface_arch.cpython-310.pyc differ diff --git a/scripts/basicsr/archs/__pycache__/arch_util.cpython-310.pyc b/scripts/basicsr/archs/__pycache__/arch_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c12faedd0d5b45996357d1f6bf639da0e070bb2 Binary files /dev/null and b/scripts/basicsr/archs/__pycache__/arch_util.cpython-310.pyc differ diff --git a/scripts/basicsr/archs/__pycache__/codeformer_arch.cpython-310.pyc b/scripts/basicsr/archs/__pycache__/codeformer_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1627c66ad0f2c3db058bb0f425a0665d2c5b9707 Binary files /dev/null and b/scripts/basicsr/archs/__pycache__/codeformer_arch.cpython-310.pyc differ diff --git a/scripts/basicsr/archs/__pycache__/rrdbnet_arch.cpython-310.pyc b/scripts/basicsr/archs/__pycache__/rrdbnet_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77eb4a7e9445b191b7a5a7d2ba9df0802528483c Binary files /dev/null and b/scripts/basicsr/archs/__pycache__/rrdbnet_arch.cpython-310.pyc differ diff --git 
a/scripts/basicsr/archs/__pycache__/vgg_arch.cpython-310.pyc b/scripts/basicsr/archs/__pycache__/vgg_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63d0f6164d94273e8f8bb8b4d8fbbf4b6b4e0c0c Binary files /dev/null and b/scripts/basicsr/archs/__pycache__/vgg_arch.cpython-310.pyc differ diff --git a/scripts/basicsr/archs/__pycache__/vqgan_arch.cpython-310.pyc b/scripts/basicsr/archs/__pycache__/vqgan_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d90bddbf4eb956e8342da298d271020cacc57b78 Binary files /dev/null and b/scripts/basicsr/archs/__pycache__/vqgan_arch.cpython-310.pyc differ diff --git a/scripts/basicsr/archs/arcface_arch.py b/scripts/basicsr/archs/arcface_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..0574e62af331938acd225e61bc9b1dcbcc6b8ab4 --- /dev/null +++ b/scripts/basicsr/archs/arcface_arch.py @@ -0,0 +1,245 @@ +import torch.nn as nn +from scripts.basicsr.utils.registry import ARCH_REGISTRY + + +def conv3x3(inplanes, outplanes, stride=1): + """A simple wrapper for 3x3 convolution with padding. + + Args: + inplanes (int): Channel number of inputs. + outplanes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + """ + return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + """Basic residual block used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + """ + expansion = 1 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class IRBlock(nn.Module): + """Improved residual block (IR Block) used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. 
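+
+ Example (illustrative, not part of the original patch; shapes are preserved
+ when stride=1 and no downsample module is given):
+ >>> block = IRBlock(64, 64, use_se=True)
+ >>> block(torch.randn(2, 64, 56, 56)).shape
+ torch.Size([2, 64, 56, 56])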
+ """ + expansion = 1 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True): + super(IRBlock, self).__init__() + self.bn0 = nn.BatchNorm2d(inplanes) + self.conv1 = conv3x3(inplanes, inplanes) + self.bn1 = nn.BatchNorm2d(inplanes) + self.prelu = nn.PReLU() + self.conv2 = conv3x3(inplanes, planes, stride) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + self.use_se = use_se + if self.use_se: + self.se = SEBlock(planes) + + def forward(self, x): + residual = x + out = self.bn0(x) + out = self.conv1(out) + out = self.bn1(out) + out = self.prelu(out) + + out = self.conv2(out) + out = self.bn2(out) + if self.use_se: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.prelu(out) + + return out + + +class Bottleneck(nn.Module): + """Bottleneck block used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + """ + expansion = 4 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class SEBlock(nn.Module): + """The squeeze-and-excitation block (SEBlock) used in the IRBlock. + + Args: + channel (int): Channel number of inputs. + reduction (int): Channel reduction ration. Default: 16. + """ + + def __init__(self, channel, reduction=16): + super(SEBlock, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel), + nn.Sigmoid()) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +@ARCH_REGISTRY.register() +class ResNetArcFace(nn.Module): + """ArcFace with ResNet architectures. + + Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition. + + Args: + block (str): Block used in the ArcFace architecture. + layers (tuple(int)): Block numbers in each layer. + use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. 
+ """ + + def __init__(self, block, layers, use_se=True): + if block == 'IRBlock': + block = IRBlock + self.inplanes = 64 + self.use_se = use_se + super(ResNetArcFace, self).__init__() + + self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.prelu = nn.PReLU() + self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.bn4 = nn.BatchNorm2d(512) + self.dropout = nn.Dropout() + self.fc5 = nn.Linear(512 * 8 * 8, 512) + self.bn5 = nn.BatchNorm1d(512) + + # initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, num_blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se)) + self.inplanes = planes + for _ in range(1, num_blocks): + layers.append(block(self.inplanes, planes, use_se=self.use_se)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn4(x) + x = self.dropout(x) + x = x.view(x.size(0), -1) + x = self.fc5(x) + x = self.bn5(x) + + return x \ No newline at end of file diff --git a/scripts/basicsr/archs/arch_util.py b/scripts/basicsr/archs/arch_util.py new file mode 100644 index 0000000000000000000000000000000000000000..e75f403422b8aa381b0a43807b99ed0018f38e97 --- /dev/null +++ b/scripts/basicsr/archs/arch_util.py @@ -0,0 +1,318 @@ +import collections.abc +import math +import torch +import torchvision +import warnings +from distutils.version import LooseVersion +from itertools import repeat +from torch import nn as nn +from torch.nn import functional as F +from torch.nn import init as init +from torch.nn.modules.batchnorm import _BatchNorm + +from scripts.basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv +from scripts.basicsr.utils import get_root_logger + + +@torch.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. 
+ """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + + Args: + basic_block (nn.module): nn.module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +class ResidualBlockNoBN(nn.Module): + """Residual block without BN. + + It has a style of: + ---Conv-ReLU-Conv-+- + |________________| + + Args: + num_feat (int): Channel number of intermediate features. + Default: 64. + res_scale (float): Residual scale. Default: 1. + pytorch_init (bool): If set to True, use pytorch default init, + otherwise, use default_init_weights. Default: False. + """ + + def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): + super(ResidualBlockNoBN, self).__init__() + self.res_scale = res_scale + self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.relu = nn.ReLU(inplace=True) + + if not pytorch_init: + default_init_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = self.conv2(self.relu(self.conv1(x))) + return identity + out * self.res_scale + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + Returns: + Tensor: Warped image or feature map. 
+ """ + assert x.size()[-2:] == flow.size()[1:3] + _, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners) + + # TODO, what if align_corners=False + return output + + +def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False): + """Resize a flow according to ratio or shape. + + Args: + flow (Tensor): Precomputed flow. shape [N, 2, H, W]. + size_type (str): 'ratio' or 'shape'. + sizes (list[int | float]): the ratio for resizing or the final output + shape. + 1) The order of ratio should be [ratio_h, ratio_w]. For + downsampling, the ratio should be smaller than 1.0 (i.e., ratio + < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., + ratio > 1.0). + 2) The order of output_size should be [out_h, out_w]. + interp_mode (str): The mode of interpolation for resizing. + Default: 'bilinear'. + align_corners (bool): Whether align corners. Default: False. + + Returns: + Tensor: Resized flow. + """ + _, _, flow_h, flow_w = flow.size() + if size_type == 'ratio': + output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) + elif size_type == 'shape': + output_h, output_w = sizes[0], sizes[1] + else: + raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.') + + input_flow = flow.clone() + ratio_h = output_h / flow_h + ratio_w = output_w / flow_w + input_flow[:, 0, :, :] *= ratio_w + input_flow[:, 1, :, :] *= ratio_h + resized_flow = F.interpolate( + input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners) + return resized_flow + + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +class DCNv2Pack(ModulatedDeformConvPack): + """Modulated deformable conv for deformable alignment. + + Different from the official DCNv2Pack, which generates offsets and masks + from the preceding features, this DCNv2Pack takes another different + features to generate offsets and masks. + + Ref: + Delving Deep into Deformable Alignment in Video Super-Resolution. 
+ """ + + def forward(self, x, feat): + out = self.conv_offset(feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + + offset_absmean = torch.mean(torch.abs(offset)) + if offset_absmean > 50: + logger = get_root_logger() + logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.') + + if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'): + return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, + self.dilation, mask) + else: + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups, self.deformable_groups) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ' + 'The distribution of values may be incorrect.', + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + low = norm_cdf((a - mean) / std) + up = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [low, up], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * low - 1, 2 * up - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + + The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +# From PyTorch +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple \ No newline at end of file diff --git a/scripts/basicsr/archs/codeformer_arch.py b/scripts/basicsr/archs/codeformer_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..777c552cfa656615a4053ece35b1a96551cf33d3 --- /dev/null +++ b/scripts/basicsr/archs/codeformer_arch.py @@ -0,0 +1,280 @@ +import math +import numpy as np +import torch +from torch import nn, Tensor +import torch.nn.functional as F +from typing import Optional, List + +from scripts.basicsr.archs.vqgan_arch import * +from scripts.basicsr.utils import get_root_logger +from scripts.basicsr.utils.registry import ARCH_REGISTRY + +def calc_mean_std(feat, eps=1e-5): + """Calculate mean and std for adaptive_instance_normalization. + + Args: + feat (Tensor): 4D tensor. + eps (float): A small value added to the variance to avoid + divide-by-zero. Default: 1e-5. + """ + size = feat.size() + assert len(size) == 4, 'The input feature should be 4D tensor.' + b, c = size[:2] + feat_var = feat.view(b, c, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(b, c, 1, 1) + feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1) + return feat_mean, feat_std + + +def adaptive_instance_normalization(content_feat, style_feat): + """Adaptive instance normalization. + + Adjust the reference features to have the similar color and illuminations + as those in the degradate features. + + Args: + content_feat (Tensor): The reference feature. + style_feat (Tensor): The degradate features. + """ + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. 
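+
+ Example (illustrative, not part of the original patch; the output channel
+ count is 2 * num_pos_feats, independent of the input channel count):
+ >>> pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
+ >>> pe(torch.randn(2, 256, 16, 16)).shape
+ torch.Size([2, 256, 16, 16])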
+ """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x, mask=None): + if mask is None: + mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + +class TransformerSALayer(nn.Module): + def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"): + super().__init__() + self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout) + # Implementation of Feedforward model - MLP + self.linear1 = nn.Linear(embed_dim, dim_mlp) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_mlp, embed_dim) + + self.norm1 = nn.LayerNorm(embed_dim) + self.norm2 = nn.LayerNorm(embed_dim) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + + # self attention + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + + # ffn + tgt2 = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout2(tgt2) + return tgt + +class Fuse_sft_block(nn.Module): + def __init__(self, in_ch, out_ch): + super().__init__() + self.encode_enc = ResBlock(2*in_ch, out_ch) + + self.scale = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) + + self.shift = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) + + def forward(self, enc_feat, dec_feat, w=1): + enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1)) + scale = self.scale(enc_feat) + shift = self.shift(enc_feat) + 
residual = w * (dec_feat * scale + shift) + out = dec_feat + residual + return out + + +@ARCH_REGISTRY.register() +class CodeFormer(VQAutoEncoder): + def __init__(self, dim_embd=512, n_head=8, n_layers=9, + codebook_size=1024, latent_size=256, + connect_list=['32', '64', '128', '256'], + fix_modules=['quantize','generator'], vqgan_path=None): + super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) + + if vqgan_path is not None: + self.load_state_dict( + torch.load(vqgan_path, map_location='cpu')['params_ema']) + + if fix_modules is not None: + for module in fix_modules: + for param in getattr(self, module).parameters(): + param.requires_grad = False + + self.connect_list = connect_list + self.n_layers = n_layers + self.dim_embd = dim_embd + self.dim_mlp = dim_embd*2 + + self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) + self.feat_emb = nn.Linear(256, self.dim_embd) + + # transformer + self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) + for _ in range(self.n_layers)]) + + # logits_predict head + self.idx_pred_layer = nn.Sequential( + nn.LayerNorm(dim_embd), + nn.Linear(dim_embd, codebook_size, bias=False)) + + self.channels = { + '16': 512, + '32': 256, + '64': 256, + '128': 128, + '256': 128, + '512': 64, + } + + # after second residual block for > 16, before attn layer for ==16 + self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18} + # after first residual block for > 16, before attn layer for ==16 + self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21} + + # fuse_convs_dict + self.fuse_convs_dict = nn.ModuleDict() + for f_size in self.connect_list: + in_ch = self.channels[f_size] + self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch) + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, x, w=0, detach_16=True, code_only=False, adain=False): + # ################### Encoder ##################### + enc_feat_dict = {} + out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] + for i, block in enumerate(self.encoder.blocks): + x = block(x) + if i in out_list: + enc_feat_dict[str(x.shape[-1])] = x.clone() + + lq_feat = x + # ################# Transformer ################### + # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat) + pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1) + # BCHW -> BC(HW) -> (HW)BC + feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1)) + query_emb = feat_emb + # Transformer encoder + for layer in self.ft_layers: + query_emb = layer(query_emb, query_pos=pos_emb) + + # output logits + logits = self.idx_pred_layer(query_emb) # (hw)bn + logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n + + if code_only: # for training stage II + # logits doesn't need softmax before cross_entropy loss + return logits, lq_feat + + # ################# Quantization ################### + # if self.training: + # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight]) + # # b(hw)c -> bc(hw) -> bchw + # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape) + # ------------ + soft_one_hot = F.softmax(logits, dim=2) 
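+ # Greedy code selection: take the top-1 codebook index per spatial token and
+ # look up its quantized feature; get_codebook_feat returns (b, 256, 16, 16).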
+ _, top_idx = torch.topk(soft_one_hot, 1, dim=2) + quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256]) + # preserve gradients + # quant_feat = lq_feat + (quant_feat - lq_feat).detach() + + if detach_16: + quant_feat = quant_feat.detach() # for training stage III + if adain: + quant_feat = adaptive_instance_normalization(quant_feat, lq_feat) + + # ################## Generator #################### + x = quant_feat + fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] + + for i, block in enumerate(self.generator.blocks): + x = block(x) + if i in fuse_list: # fuse after i-th block + f_size = str(x.shape[-1]) + if w>0: + x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) + out = x + # logits doesn't need softmax before cross_entropy loss + return out, logits, lq_feat \ No newline at end of file diff --git a/scripts/basicsr/archs/rrdbnet_arch.py b/scripts/basicsr/archs/rrdbnet_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..aabc379523633f08afb32972710b310163366e9a --- /dev/null +++ b/scripts/basicsr/archs/rrdbnet_arch.py @@ -0,0 +1,119 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from scripts.basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import default_init_weights, make_layer, pixel_unshuffle + + +class ResidualDenseBlock(nn.Module): + """Residual Dense Block. + + Used in RRDB block in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat=64, num_grow_ch=32): + super(ResidualDenseBlock, self).__init__() + self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) + self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + # Emperically, we use 0.2 to scale the residual for better performance + return x5 * 0.2 + x + + +class RRDB(nn.Module): + """Residual in Residual Dense Block. + + Used in RRDB-Net in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat, num_grow_ch=32): + super(RRDB, self).__init__() + self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) + + def forward(self, x): + out = self.rdb1(x) + out = self.rdb2(out) + out = self.rdb3(out) + # Emperically, we use 0.2 to scale the residual for better performance + return out * 0.2 + x + + +@ARCH_REGISTRY.register() +class RRDBNet(nn.Module): + """Networks consisting of Residual in Residual Dense Block, which is used + in ESRGAN. + + ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. + + We extend ESRGAN for scale x2 and scale x1. 
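+
+ Example (illustrative, not part of the original patch; the default scale=4
+ upsamples by 4x via the two nearest-neighbor interpolation stages):
+ >>> net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4)
+ >>> net(torch.randn(1, 3, 64, 64)).shape
+ torch.Size([1, 3, 256, 256])
+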
+ Note: This is one option for scale 1, scale 2 in RRDBNet. + We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size + and enlarge the channel size before feeding inputs into the main ESRGAN architecture. + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_feat (int): Channel number of intermediate features. + Default: 64 + num_block (int): Block number in the trunk network. Defaults: 23 + num_grow_ch (int): Channels for each growth. Default: 32. + """ + + def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): + super(RRDBNet, self).__init__() + self.scale = scale + if scale == 2: + num_in_ch = num_in_ch * 4 + elif scale == 1: + num_in_ch = num_in_ch * 16 + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) + self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + # upsample + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + if self.scale == 2: + feat = pixel_unshuffle(x, scale=2) + elif self.scale == 1: + feat = pixel_unshuffle(x, scale=4) + else: + feat = x + feat = self.conv_first(feat) + body_feat = self.conv_body(self.body(feat)) + feat = feat + body_feat + # upsample + feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) + feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.conv_hr(feat))) + return out \ No newline at end of file diff --git a/scripts/basicsr/archs/vgg_arch.py b/scripts/basicsr/archs/vgg_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..9a96f66f3c5279364340502381e4ee5370256b5a --- /dev/null +++ b/scripts/basicsr/archs/vgg_arch.py @@ -0,0 +1,161 @@ +import os +import torch +from collections import OrderedDict +from torch import nn as nn +from torchvision.models import vgg as vgg + +from scripts.basicsr.utils.registry import ARCH_REGISTRY + +VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' +NAMES = { + 'vgg11': [ + 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', + 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', + 'pool5' + ], + 'vgg13': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', + 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' + ], + 'vgg16': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', + 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', + 'pool5' + ], + 'vgg19': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', + 'relu4_1', 
'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', + 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' + ] +} + + +def insert_bn(names): + """Insert bn layer after each conv. + + Args: + names (list): The list of layer names. + + Returns: + list: The list of layer names with bn layers. + """ + names_bn = [] + for name in names: + names_bn.append(name) + if 'conv' in name: + position = name.replace('conv', '') + names_bn.append('bn' + position) + return names_bn + + +@ARCH_REGISTRY.register() +class VGGFeatureExtractor(nn.Module): + """VGG network for feature extraction. + + In this implementation, we allow users to choose whether use normalization + in the input feature and the type of vgg network. Note that the pretrained + path must fit the vgg type. + + Args: + layer_name_list (list[str]): Forward function returns the corresponding + features according to the layer_name_list. + Example: {'relu1_1', 'relu2_1', 'relu3_1'}. + vgg_type (str): Set the type of vgg network. Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image. Importantly, + the input feature must in the range [0, 1]. Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + requires_grad (bool): If true, the parameters of VGG network will be + optimized. Default: False. + remove_pooling (bool): If true, the max pooling operations in VGG net + will be removed. Default: False. + pooling_stride (int): The stride of max pooling operation. Default: 2. + """ + + def __init__(self, + layer_name_list, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + requires_grad=False, + remove_pooling=False, + pooling_stride=2): + super(VGGFeatureExtractor, self).__init__() + + self.layer_name_list = layer_name_list + self.use_input_norm = use_input_norm + self.range_norm = range_norm + + self.names = NAMES[vgg_type.replace('_bn', '')] + if 'bn' in vgg_type: + self.names = insert_bn(self.names) + + # only borrow layers that will be used to avoid unused params + max_idx = 0 + for v in layer_name_list: + idx = self.names.index(v) + if idx > max_idx: + max_idx = idx + + if os.path.exists(VGG_PRETRAIN_PATH): + vgg_net = getattr(vgg, vgg_type)(pretrained=False) + state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) + vgg_net.load_state_dict(state_dict) + else: + vgg_net = getattr(vgg, vgg_type)(pretrained=True) + + features = vgg_net.features[:max_idx + 1] + + modified_net = OrderedDict() + for k, v in zip(self.names, features): + if 'pool' in k: + # if remove_pooling is true, pooling operation will be removed + if remove_pooling: + continue + else: + # in some cases, we may want to change the default stride + modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) + else: + modified_net[k] = v + + self.vgg_net = nn.Sequential(modified_net) + + if not requires_grad: + self.vgg_net.eval() + for param in self.parameters(): + param.requires_grad = False + else: + self.vgg_net.train() + for param in self.parameters(): + param.requires_grad = True + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). 
+ + Returns: + Tensor: Forward results. + """ + if self.range_norm: + x = (x + 1) / 2 + if self.use_input_norm: + x = (x - self.mean) / self.std + output = {} + + for key, layer in self.vgg_net._modules.items(): + x = layer(x) + if key in self.layer_name_list: + output[key] = x.clone() + + return output diff --git a/scripts/basicsr/archs/vqgan_arch.py b/scripts/basicsr/archs/vqgan_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..65c7fb6108df8fe05d8c5c2b32e175bab2a81f28 --- /dev/null +++ b/scripts/basicsr/archs/vqgan_arch.py @@ -0,0 +1,434 @@ +''' +VQGAN code, adapted from the original created by the Unleashing Transformers authors: +https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py + +''' +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import copy +from scripts.basicsr.utils import get_root_logger +from scripts.basicsr.utils.registry import ARCH_REGISTRY + +def normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +@torch.jit.script +def swish(x): + return x*torch.sigmoid(x) + + +# Define VQVAE classes +class VectorQuantizer(nn.Module): + def __init__(self, codebook_size, emb_dim, beta): + super(VectorQuantizer, self).__init__() + self.codebook_size = codebook_size # number of embeddings + self.emb_dim = emb_dim # dimension of embedding + self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + self.embedding = nn.Embedding(self.codebook_size, self.emb_dim) + self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.emb_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \ + 2 * torch.matmul(z_flattened, self.embedding.weight.t()) + + mean_distance = torch.mean(d) + # find closest encodings + min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + # min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False) + # [0-1], higher score, higher confidence + # min_encoding_scores = torch.exp(-min_encoding_scores/10) + + min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + # compute loss for embedding + loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, loss, { + "perplexity": perplexity, + "min_encodings": min_encodings, + "min_encoding_indices": min_encoding_indices, + "mean_distance": mean_distance + } + + def get_codebook_feat(self, indices, shape): + # input indices: batch*token_num -> (batch*token_num)*1 + # shape: batch, height, width, channel + indices = indices.view(-1,1) + min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices) + min_encodings.scatter_(1, indices, 1) + # get quantized latent vectors + z_q = 
torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: # reshape back to match original input shape + z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantizer(nn.Module): + def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0): + super().__init__() + self.codebook_size = codebook_size # number of embeddings + self.emb_dim = emb_dim # dimension of embedding + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits + self.embed = nn.Embedding(codebook_size, emb_dim) + + def forward(self, z): + hard = self.straight_through if self.training else True + + logits = self.proj(z) + + soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard) + + z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean() + min_encoding_indices = soft_one_hot.argmax(dim=1) + + return z_q, diff, { + "min_encoding_indices": min_encoding_indices + } + + +class Downsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + + return x + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels=None): + super(ResBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.norm1 = normalize(in_channels) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = normalize(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x_in): + x = x_in + x = self.norm1(x) + x = swish(x) + x = self.conv1(x) + x = self.norm2(x) + x = swish(x) + x = self.conv2(x) + if self.in_channels != self.out_channels: + x_in = self.conv_out(x_in) + + return x + x_in + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, 
h*w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h*w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c)**(-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h*w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Encoder(nn.Module): + def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions): + super().__init__() + self.nf = nf + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.attn_resolutions = attn_resolutions + + curr_res = self.resolution + in_ch_mult = (1,)+tuple(ch_mult) + + blocks = [] + # initial convultion + blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)) + + # residual and downsampling blocks, with attention on smaller res (16x16) + for i in range(self.num_resolutions): + block_in_ch = nf * in_ch_mult[i] + block_out_ch = nf * ch_mult[i] + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch)) + block_in_ch = block_out_ch + if curr_res in attn_resolutions: + blocks.append(AttnBlock(block_in_ch)) + + if i != self.num_resolutions - 1: + blocks.append(Downsample(block_in_ch)) + curr_res = curr_res // 2 + + # non-local attention block + blocks.append(ResBlock(block_in_ch, block_in_ch)) + blocks.append(AttnBlock(block_in_ch)) + blocks.append(ResBlock(block_in_ch, block_in_ch)) + + # normalise and convert to latent size + blocks.append(normalize(block_in_ch)) + blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)) + self.blocks = nn.ModuleList(blocks) + + def forward(self, x): + for block in self.blocks: + x = block(x) + + return x + + +class Generator(nn.Module): + def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): + super().__init__() + self.nf = nf + self.ch_mult = ch_mult + self.num_resolutions = len(self.ch_mult) + self.num_res_blocks = res_blocks + self.resolution = img_size + self.attn_resolutions = attn_resolutions + self.in_channels = emb_dim + self.out_channels = 3 + block_in_ch = self.nf * self.ch_mult[-1] + curr_res = self.resolution // 2 ** (self.num_resolutions-1) + + blocks = [] + # initial conv + blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)) + + # non-local attention block + blocks.append(ResBlock(block_in_ch, block_in_ch)) + blocks.append(AttnBlock(block_in_ch)) + blocks.append(ResBlock(block_in_ch, block_in_ch)) + + for i in reversed(range(self.num_resolutions)): + block_out_ch = self.nf * self.ch_mult[i] + + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch)) + block_in_ch = block_out_ch + + if curr_res in self.attn_resolutions: + blocks.append(AttnBlock(block_in_ch)) + + if i != 0: + blocks.append(Upsample(block_in_ch)) + curr_res = curr_res * 2 + + blocks.append(normalize(block_in_ch)) + blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) + + self.blocks = nn.ModuleList(blocks) + + + def forward(self, x): + for block in self.blocks: + x = block(x) + + return x + + +@ARCH_REGISTRY.register() +class VQAutoEncoder(nn.Module): + def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256, + beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): + super().__init__() + logger = get_root_logger() + self.in_channels = 
3 + self.nf = nf + self.n_blocks = res_blocks + self.codebook_size = codebook_size + self.embed_dim = emb_dim + self.ch_mult = ch_mult + self.resolution = img_size + self.attn_resolutions = attn_resolutions + self.quantizer_type = quantizer + self.encoder = Encoder( + self.in_channels, + self.nf, + self.embed_dim, + self.ch_mult, + self.n_blocks, + self.resolution, + self.attn_resolutions + ) + if self.quantizer_type == "nearest": + self.beta = beta #0.25 + self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta) + elif self.quantizer_type == "gumbel": + self.gumbel_num_hiddens = emb_dim + self.straight_through = gumbel_straight_through + self.kl_weight = gumbel_kl_weight + self.quantize = GumbelQuantizer( + self.codebook_size, + self.embed_dim, + self.gumbel_num_hiddens, + self.straight_through, + self.kl_weight + ) + self.generator = Generator( + self.nf, + self.embed_dim, + self.ch_mult, + self.n_blocks, + self.resolution, + self.attn_resolutions + ) + + if model_path is not None: + chkpt = torch.load(model_path, map_location='cpu') + if 'params_ema' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema']) + logger.info(f'vqgan is loaded from: {model_path} [params_ema]') + elif 'params' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) + logger.info(f'vqgan is loaded from: {model_path} [params]') + else: + raise ValueError(f'Wrong params!') + + + def forward(self, x): + x = self.encoder(x) + quant, codebook_loss, quant_stats = self.quantize(x) + x = self.generator(quant) + return x, codebook_loss, quant_stats + + + +# patch based discriminator +@ARCH_REGISTRY.register() +class VQGANDiscriminator(nn.Module): + def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None): + super().__init__() + + layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)] + ndf_mult = 1 + ndf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + ndf_mult_prev = ndf_mult + ndf_mult = min(2 ** n, 8) + layers += [ + nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False), + nn.BatchNorm2d(ndf * ndf_mult), + nn.LeakyReLU(0.2, True) + ] + + ndf_mult_prev = ndf_mult + ndf_mult = min(2 ** n_layers, 8) + + layers += [ + nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False), + nn.BatchNorm2d(ndf * ndf_mult), + nn.LeakyReLU(0.2, True) + ] + + layers += [ + nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map + self.main = nn.Sequential(*layers) + + if model_path is not None: + chkpt = torch.load(model_path, map_location='cpu') + if 'params_d' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d']) + elif 'params' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) + else: + raise ValueError(f'Wrong params!') + + def forward(self, x): + return self.main(x) \ No newline at end of file diff --git a/scripts/basicsr/data/__init__.py b/scripts/basicsr/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b1bd8f179737394d2318f740dd97fb7890cb63be --- /dev/null +++ b/scripts/basicsr/data/__init__.py @@ -0,0 +1,100 @@ +import importlib +import numpy as np +import random +import torch +import torch.utils.data +from copy import deepcopy +from functools import partial +from os import path as osp + +from scripts.basicsr.data.prefetch_dataloader 
import PrefetchDataLoader +from scripts.basicsr.utils import get_root_logger, scandir +from scripts.basicsr.utils.dist_util import get_dist_info +from scripts.basicsr.utils.registry import DATASET_REGISTRY + +__all__ = ['build_dataset', 'build_dataloader'] + +# automatically scan and import dataset modules for registry +# scan all the files under the data folder with '_dataset' in file names +data_folder = osp.dirname(osp.abspath(__file__)) +dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] +# import all the dataset modules +_dataset_modules = [importlib.import_module(f'scripts.basicsr.data.{file_name}') for file_name in dataset_filenames] + + +def build_dataset(dataset_opt): + """Build dataset from options. + + Args: + dataset_opt (dict): Configuration for dataset. It must constain: + name (str): Dataset name. + type (str): Dataset type. + """ + dataset_opt = deepcopy(dataset_opt) + dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) + logger = get_root_logger() + logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.') + return dataset + + +def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): + """Build dataloader. + + Args: + dataset (torch.utils.data.Dataset): Dataset. + dataset_opt (dict): Dataset options. It contains the following keys: + phase (str): 'train' or 'val'. + num_worker_per_gpu (int): Number of workers for each GPU. + batch_size_per_gpu (int): Training batch size for each GPU. + num_gpu (int): Number of GPUs. Used only in the train phase. + Default: 1. + dist (bool): Whether in distributed training. Used only in the train + phase. Default: False. + sampler (torch.utils.data.sampler): Data sampler. Default: None. + seed (int | None): Seed. Default: None + """ + phase = dataset_opt['phase'] + rank, _ = get_dist_info() + if phase == 'train': + if dist: # distributed training + batch_size = dataset_opt['batch_size_per_gpu'] + num_workers = dataset_opt['num_worker_per_gpu'] + else: # non-distributed training + multiplier = 1 if num_gpu == 0 else num_gpu + batch_size = dataset_opt['batch_size_per_gpu'] * multiplier + num_workers = dataset_opt['num_worker_per_gpu'] * multiplier + dataloader_args = dict( + dataset=dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + sampler=sampler, + drop_last=True) + if sampler is None: + dataloader_args['shuffle'] = True + dataloader_args['worker_init_fn'] = partial( + worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None + elif phase in ['val', 'test']: # validation + dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) + else: + raise ValueError(f'Wrong dataset phase: {phase}. 
' "Supported ones are 'train', 'val' and 'test'.") + + dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) + + prefetch_mode = dataset_opt.get('prefetch_mode') + if prefetch_mode == 'cpu': # CPUPrefetcher + num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) + logger = get_root_logger() + logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}') + return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) + else: + # prefetch_mode=None: Normal dataloader + # prefetch_mode='cuda': dataloader for CUDAPrefetcher + return torch.utils.data.DataLoader(**dataloader_args) + + +def worker_init_fn(worker_id, num_workers, rank, seed): + # Set the worker seed to num_workers * rank + worker_id + seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) diff --git a/scripts/basicsr/data/__pycache__/__init__.cpython-310.pyc b/scripts/basicsr/data/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d19bc002bd2ad6892cce5196c6016bd96dfdceeb Binary files /dev/null and b/scripts/basicsr/data/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/basicsr/data/__pycache__/data_sampler.cpython-310.pyc b/scripts/basicsr/data/__pycache__/data_sampler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76149c31fb89083d8e8ea3ed7deed55689a33a67 Binary files /dev/null and b/scripts/basicsr/data/__pycache__/data_sampler.cpython-310.pyc differ diff --git a/scripts/basicsr/data/__pycache__/data_util.cpython-310.pyc b/scripts/basicsr/data/__pycache__/data_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3c2d376087ed8623c5c0169683780e6c9d9c40f Binary files /dev/null and b/scripts/basicsr/data/__pycache__/data_util.cpython-310.pyc differ diff --git a/scripts/basicsr/data/__pycache__/ffhq_blind_dataset.cpython-310.pyc b/scripts/basicsr/data/__pycache__/ffhq_blind_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be7b200d09a4b0cec46ad168fa7bfece432fb0da Binary files /dev/null and b/scripts/basicsr/data/__pycache__/ffhq_blind_dataset.cpython-310.pyc differ diff --git a/scripts/basicsr/data/__pycache__/ffhq_blind_joint_dataset.cpython-310.pyc b/scripts/basicsr/data/__pycache__/ffhq_blind_joint_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f3a577fc6cbb0aeb626988ff0457f4d27978f11 Binary files /dev/null and b/scripts/basicsr/data/__pycache__/ffhq_blind_joint_dataset.cpython-310.pyc differ diff --git a/scripts/basicsr/data/__pycache__/gaussian_kernels.cpython-310.pyc b/scripts/basicsr/data/__pycache__/gaussian_kernels.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f36d45ce79a5f5924cd8c25d6371ce52f7882211 Binary files /dev/null and b/scripts/basicsr/data/__pycache__/gaussian_kernels.cpython-310.pyc differ diff --git a/scripts/basicsr/data/__pycache__/paired_image_dataset.cpython-310.pyc b/scripts/basicsr/data/__pycache__/paired_image_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b212a4ebdcdffb96c9175f409fb89f45d68e685 Binary files /dev/null and b/scripts/basicsr/data/__pycache__/paired_image_dataset.cpython-310.pyc differ diff --git a/scripts/basicsr/data/__pycache__/prefetch_dataloader.cpython-310.pyc 
b/scripts/basicsr/data/__pycache__/prefetch_dataloader.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d00f2eb8f5c4532411633f736bac9c9c387e25d Binary files /dev/null and b/scripts/basicsr/data/__pycache__/prefetch_dataloader.cpython-310.pyc differ diff --git a/scripts/basicsr/data/__pycache__/transforms.cpython-310.pyc b/scripts/basicsr/data/__pycache__/transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62bf69720464217b2d24fa3047ff879fc71cede3 Binary files /dev/null and b/scripts/basicsr/data/__pycache__/transforms.cpython-310.pyc differ diff --git a/scripts/basicsr/data/data_sampler.py b/scripts/basicsr/data/data_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..575452d9f844a928f7f42296c81635cfbadec7c2 --- /dev/null +++ b/scripts/basicsr/data/data_sampler.py @@ -0,0 +1,48 @@ +import math +import torch +from torch.utils.data.sampler import Sampler + + +class EnlargedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + Modified from torch.utils.data.distributed.DistributedSampler + Support enlarging the dataset for iteration-based training, for saving + time when restart the dataloader after each epoch + + Args: + dataset (torch.utils.data.Dataset): Dataset used for sampling. + num_replicas (int | None): Number of processes participating in + the training. It is usually the world_size. + rank (int | None): Rank of the current process within num_replicas. + ratio (int): Enlarging ratio. Default: 1. + """ + + def __init__(self, dataset, num_replicas, rank, ratio=1): + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.total_size, generator=g).tolist() + + dataset_size = len(self.dataset) + indices = [v % dataset_size for v in indices] + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/scripts/basicsr/data/data_util.py b/scripts/basicsr/data/data_util.py new file mode 100644 index 0000000000000000000000000000000000000000..cdbfdd4502d0c82e11b1abc791c375e687ceb716 --- /dev/null +++ b/scripts/basicsr/data/data_util.py @@ -0,0 +1,392 @@ +import cv2 +import math +import numpy as np +import torch +from os import path as osp +from PIL import Image, ImageDraw +from torch.nn import functional as F + +from scripts.basicsr.data.transforms import mod_crop +from scripts.basicsr.utils import img2tensor, scandir + + +def read_img_seq(path, require_mod_crop=False, scale=1): + """Read a sequence of images from a given folder path. + + Args: + path (list[str] | str): List of image paths or image folder path. + require_mod_crop (bool): Require mod crop for each image. + Default: False. + scale (int): Scale factor for mod_crop. Default: 1. + + Returns: + Tensor: size (t, c, h, w), RGB, [0, 1]. + """ + if isinstance(path, list): + img_paths = path + else: + img_paths = sorted(list(scandir(path, full_path=True))) + imgs = [cv2.imread(v).astype(np.float32) / 255. 
for v in img_paths]
+    if require_mod_crop:
+        imgs = [mod_crop(img, scale) for img in imgs]
+    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
+    imgs = torch.stack(imgs, dim=0)
+    return imgs
+
+
+def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
+    """Generate an index list for reading `num_frames` frames from a sequence
+    of images.
+
+    Args:
+        crt_idx (int): Current center index.
+        max_frame_num (int): Max number of the sequence of images (from 1).
+        num_frames (int): Reading num_frames frames.
+        padding (str): Padding mode, one of
+            'replicate' | 'reflection' | 'reflection_circle' | 'circle'
+            Examples: current_idx = 0, num_frames = 5
+            The generated frame indices under different padding mode:
+            replicate: [0, 0, 0, 1, 2]
+            reflection: [2, 1, 0, 1, 2]
+            reflection_circle: [4, 3, 0, 1, 2]
+            circle: [3, 4, 0, 1, 2]
+
+    Returns:
+        list[int]: A list of indices.
+    """
+    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
+    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
+
+    max_frame_num = max_frame_num - 1  # start from 0
+    num_pad = num_frames // 2
+
+    indices = []
+    for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
+        if i < 0:
+            if padding == 'replicate':
+                pad_idx = 0
+            elif padding == 'reflection':
+                pad_idx = -i
+            elif padding == 'reflection_circle':
+                pad_idx = crt_idx + num_pad - i
+            else:
+                pad_idx = num_frames + i
+        elif i > max_frame_num:
+            if padding == 'replicate':
+                pad_idx = max_frame_num
+            elif padding == 'reflection':
+                pad_idx = max_frame_num * 2 - i
+            elif padding == 'reflection_circle':
+                pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
+            else:
+                pad_idx = i - num_frames
+        else:
+            pad_idx = i
+        indices.append(pad_idx)
+    return indices
+
+
+def paired_paths_from_lmdb(folders, keys):
+    """Generate paired paths from lmdb files.
+
+    Contents of lmdb. Taking `lq.lmdb` as an example, the file structure is:
+
+    lq.lmdb
+    ├── data.mdb
+    ├── lock.mdb
+    ├── meta_info.txt
+
+    The data.mdb and lock.mdb are standard lmdb files and you can refer to
+    https://lmdb.readthedocs.io/en/release/ for more details.
+
+    The meta_info.txt is a specified txt file to record the meta information
+    of our datasets. It will be automatically created when preparing
+    datasets by our provided dataset tools.
+    Each line in the txt file records
+    1) image name (with extension),
+    2) image shape,
+    3) compression level, separated by a white space.
+    Example: `baboon.png (120,125,3) 1`
+
+    We use the image name without extension as the lmdb key.
+    Note that we use the same key for the corresponding lq and gt images.
+
+    Args:
+        folders (list[str]): A list of folder paths. The order of the list should
+            be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+            Note that this key is different from lmdb keys.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+                               f'But got {len(folders)}')
+    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
+    input_folder, gt_folder = folders
+    input_key, gt_key = keys
+
+    if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
+        raise ValueError(f'{input_key} folder and {gt_key} folder should both be in lmdb '
+                         f'format. But received {input_key}: {input_folder}; '
+                         f'{gt_key}: {gt_folder}')
+    # ensure that the two meta_info files are the same
+    with open(osp.join(input_folder, 'meta_info.txt')) as fin:
+        input_lmdb_keys = [line.split('.')[0] for line in fin]
+    with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
+        gt_lmdb_keys = [line.split('.')[0] for line in fin]
+    if set(input_lmdb_keys) != set(gt_lmdb_keys):
+        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
+    else:
+        paths = []
+        for lmdb_key in sorted(input_lmdb_keys):
+            paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
+        return paths
+
+
+def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
+    """Generate paired paths from a meta information file.
+
+    Each line in the meta information file contains the image names and
+    image shape (usually for gt), separated by a white space.
+
+    Example of a meta information file:
+    ```
+    0001_s001.png (480,480,3)
+    0001_s002.png (480,480,3)
+    ```
+
+    Args:
+        folders (list[str]): A list of folder paths. The order of the list should
+            be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+        meta_info_file (str): Path to the meta information file.
+        filename_tmpl (str): Template for each filename. Note that the
+            template excludes the file extension. Usually the filename_tmpl is
+            for files in the input folder.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+                               f'But got {len(folders)}')
+    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
+    input_folder, gt_folder = folders
+    input_key, gt_key = keys
+
+    with open(meta_info_file, 'r') as fin:
+        gt_names = [line.split(' ')[0] for line in fin]
+
+    paths = []
+    for gt_name in gt_names:
+        basename, ext = osp.splitext(osp.basename(gt_name))
+        input_name = f'{filename_tmpl.format(basename)}{ext}'
+        input_path = osp.join(input_folder, input_name)
+        gt_path = osp.join(gt_folder, gt_name)
+        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+    return paths
+
+
+def paired_paths_from_folder(folders, keys, filename_tmpl):
+    """Generate paired paths from folders.
+
+    Args:
+        folders (list[str]): A list of folder paths. The order of the list should
+            be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+        filename_tmpl (str): Template for each filename. Note that the
+            template excludes the file extension. Usually the filename_tmpl is
+            for files in the input folder.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+                               f'But got {len(folders)}')
+    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. '
+                            f'But got {len(keys)}')
+    input_folder, gt_folder = folders
+    input_key, gt_key = keys
+
+    input_paths = list(scandir(input_folder))
+    gt_paths = list(scandir(gt_folder))
+    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
+                                               f'{len(input_paths)}, {len(gt_paths)}.')
+    paths = []
+    for gt_path in gt_paths:
+        basename, ext = osp.splitext(osp.basename(gt_path))
+        input_name = f'{filename_tmpl.format(basename)}{ext}'
+        input_path = osp.join(input_folder, input_name)
+        assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.')
+        gt_path = osp.join(gt_folder, gt_path)
+        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+    return paths
+
+
+def paths_from_folder(folder):
+    """Generate paths from folder.
+
+    Args:
+        folder (str): Folder path.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+
+    paths = list(scandir(folder))
+    paths = [osp.join(folder, path) for path in paths]
+    return paths
+
+
+def paths_from_lmdb(folder):
+    """Generate paths from lmdb.
+
+    Args:
+        folder (str): Folder path.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+    if not folder.endswith('.lmdb'):
+        raise ValueError(f'Folder {folder} should be in lmdb format.')
+    with open(osp.join(folder, 'meta_info.txt')) as fin:
+        paths = [line.split('.')[0] for line in fin]
+    return paths
+
+
+def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
+    """Generate Gaussian kernel used in `duf_downsample`.
+
+    Args:
+        kernel_size (int): Kernel size. Default: 13.
+        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
+
+    Returns:
+        np.array: The Gaussian kernel.
+    """
+    from scipy.ndimage import filters as filters
+    kernel = np.zeros((kernel_size, kernel_size))
+    # set element at the middle to one, a dirac delta
+    kernel[kernel_size // 2, kernel_size // 2] = 1
+    # gaussian-smooth the dirac, resulting in a gaussian filter
+    return filters.gaussian_filter(kernel, sigma)
+
+
+def duf_downsample(x, kernel_size=13, scale=4):
+    """Downsampling with Gaussian kernel used in the DUF official code.
+
+    Args:
+        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
+        kernel_size (int): Kernel size. Default: 13.
+        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
+            Default: 4.
+
+    Returns:
+        Tensor: DUF downsampled frames.
+    """
+    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
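For reference, once the body of `duf_downsample` (continued below) is complete, its shape contract can be sanity-checked as follows. A minimal sketch; the tensor sizes are illustrative and not taken from this repository:

```python
import torch

# duf_downsample pads by kernel_size // 2 + 2 * scale, blurs with a Gaussian
# kernel (sigma = 0.4 * scale), strides the conv by `scale`, then crops 2 px
# per side, so spatial dims shrink exactly by the factor `scale`.
frames = torch.rand(2, 5, 3, 64, 64)             # (b, t, c, h, w), illustrative
print(duf_downsample(frames, kernel_size=13, scale=4).shape)
# -> torch.Size([2, 5, 3, 16, 16])

clip = torch.rand(5, 3, 64, 64)                  # 4-D input is unsqueezed internally
print(duf_downsample(clip, kernel_size=13, scale=2).shape)
# -> torch.Size([5, 3, 32, 32])
```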
+
+    squeeze_flag = False
+    if x.ndim == 4:
+        squeeze_flag = True
+        x = x.unsqueeze(0)
+    b, t, c, h, w = x.size()
+    x = x.view(-1, 1, h, w)
+    pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
+    x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
+
+    gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
+    gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
+    x = F.conv2d(x, gaussian_filter, stride=scale)
+    x = x[:, :, 2:-2, 2:-2]
+    x = x.view(b, t, c, x.size(2), x.size(3))
+    if squeeze_flag:
+        x = x.squeeze(0)
+    return x
+
+
+def brush_stroke_mask(img, color=(255,255,255)):
+    min_num_vertex = 8
+    max_num_vertex = 28
+    mean_angle = 2*math.pi / 5
+    angle_range = 2*math.pi / 12
+    # training large mask ratio (training setting)
+    min_width = 30
+    max_width = 70
+    # very large mask ratio (test setting and refine after 200k)
+    # min_width = 80
+    # max_width = 120
+    def generate_mask(H, W, img=None):
+        average_radius = math.sqrt(H*H+W*W) / 8
+        mask = Image.new('RGB', (W, H), 0)
+        if img is not None: mask = img # Image.fromarray(img)
+
+        for _ in range(np.random.randint(1, 4)):
+            num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
+            angle_min = mean_angle - np.random.uniform(0, angle_range)
+            angle_max = mean_angle + np.random.uniform(0, angle_range)
+            angles = []
+            vertex = []
+            for i in range(num_vertex):
+                if i % 2 == 0:
+                    angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
+                else:
+                    angles.append(np.random.uniform(angle_min, angle_max))
+
+            w, h = mask.size  # PIL size is (width, height)
+            vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
+            for i in range(num_vertex):
+                r = np.clip(
+                    np.random.normal(loc=average_radius, scale=average_radius//2),
+                    0, 2*average_radius)
+                new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
+                new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
+                vertex.append((int(new_x), int(new_y)))
+
+            draw = ImageDraw.Draw(mask)
+            width = int(np.random.uniform(min_width, max_width))
+            draw.line(vertex, fill=color, width=width)
+            for v in vertex:
+                draw.ellipse((v[0] - width//2,
+                              v[1] - width//2,
+                              v[0] + width//2,
+                              v[1] + width//2),
+                             fill=color)
+
+        return mask
+
+    width, height = img.size
+    mask = generate_mask(height, width, img)
+    return mask
+
+
+def random_ff_mask(shape, max_angle = 10, max_len = 100, max_width = 70, times = 10):
+    """Generate a random free-form mask.
+    Args:
+        shape (tuple): (height, width) of the mask to generate.
+        max_angle (int): Maximum turning angle of a stroke. Default: 10.
+        max_len (int): Maximum length of one stroke segment. Default: 100.
+        max_width (int): Maximum brush width. Default: 70.
+        times (int): Upper bound on the number of strokes. Default: 10.
+    Returns:
+        np.ndarray: Float32 mask of shape (height, width), with stroked
+            regions set to 1.
+    Link:
+        https://github.com/csqiangwen/DeepFillv2_Pytorch/blob/master/train_dataset.py
+    """
+    height = shape[0]
+    width = shape[1]
+    mask = np.zeros((height, width), np.float32)
+    times = np.random.randint(times-5, times)
+    for i in range(times):
+        start_x = np.random.randint(width)
+        start_y = np.random.randint(height)
+        for j in range(1 + np.random.randint(5)):
+            angle = 0.01 + np.random.randint(max_angle)
+            if i % 2 == 0:
+                angle = 2 * 3.1415926 - angle
+            length = 10 + np.random.randint(max_len-20, max_len)
+            brush_w = 5 + np.random.randint(max_width-30, max_width)
+            end_x = (start_x + length * np.sin(angle)).astype(np.int32)
+            end_y = (start_y + length * np.cos(angle)).astype(np.int32)
+            cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)
+            start_x, start_y = end_x, end_y
+    return mask.astype(np.float32)
\ No newline at end of file
diff --git a/scripts/basicsr/data/ffhq_blind_dataset.py b/scripts/basicsr/data/ffhq_blind_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f230010c6778efe2f84bf9563053ecd1bba9113
--- /dev/null
+++ b/scripts/basicsr/data/ffhq_blind_dataset.py
@@ -0,0 +1,299 @@
+import cv2
+import math
+import random
+import numpy as np
+import os.path as osp
+from scipy.io import loadmat
+from PIL import Image
+import torch
+import torch.utils.data as data
+from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
+                                               adjust_hue, adjust_saturation, normalize)
+from scripts.basicsr.data import gaussian_kernels as gaussian_kernels
+from scripts.basicsr.data.transforms import augment
+from scripts.basicsr.data.data_util import paths_from_folder, brush_stroke_mask, random_ff_mask
+from scripts.basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from scripts.basicsr.utils.registry import DATASET_REGISTRY
+
+@DATASET_REGISTRY.register()
+class FFHQBlindDataset(data.Dataset):
+
+    def __init__(self, opt):
+        super(FFHQBlindDataset, self).__init__()
+        logger = get_root_logger()
+        self.opt = opt
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+
+        self.gt_folder = opt['dataroot_gt']
+        self.gt_size = opt.get('gt_size', 512)
+        self.in_size = opt.get('in_size', 512)
+        assert self.gt_size >= self.in_size, 'Wrong setting.'
+ + self.mean = opt.get('mean', [0.5, 0.5, 0.5]) + self.std = opt.get('std', [0.5, 0.5, 0.5]) + + self.component_path = opt.get('component_path', None) + self.latent_gt_path = opt.get('latent_gt_path', None) + + if self.component_path is not None: + self.crop_components = True + self.components_dict = torch.load(self.component_path) + self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4) + self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1) + self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3) + else: + self.crop_components = False + + if self.latent_gt_path is not None: + self.load_latent_gt = True + self.latent_gt_dict = torch.load(self.latent_gt_path) + else: + self.load_latent_gt = False + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = self.gt_folder + if not self.gt_folder.endswith('.lmdb'): + raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}') + with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: + self.paths = [line.split('.')[0] for line in fin] + else: + self.paths = paths_from_folder(self.gt_folder) + + # inpainting mask + self.gen_inpaint_mask = opt.get('gen_inpaint_mask', False) + if self.gen_inpaint_mask: + logger.info(f'generate mask ...') + # self.mask_max_angle = opt.get('mask_max_angle', 10) + # self.mask_max_len = opt.get('mask_max_len', 150) + # self.mask_max_width = opt.get('mask_max_width', 50) + # self.mask_draw_times = opt.get('mask_draw_times', 10) + # # print + # logger.info(f'mask_max_angle: {self.mask_max_angle}') + # logger.info(f'mask_max_len: {self.mask_max_len}') + # logger.info(f'mask_max_width: {self.mask_max_width}') + # logger.info(f'mask_draw_times: {self.mask_draw_times}') + + # perform corrupt + self.use_corrupt = opt.get('use_corrupt', True) + self.use_motion_kernel = False + # self.use_motion_kernel = opt.get('use_motion_kernel', True) + + if self.use_motion_kernel: + self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001) + motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth') + self.motion_kernels = torch.load(motion_kernel_path) + + if self.use_corrupt and not self.gen_inpaint_mask: + # degradation configurations + self.blur_kernel_size = opt['blur_kernel_size'] + self.blur_sigma = opt['blur_sigma'] + self.kernel_list = opt['kernel_list'] + self.kernel_prob = opt['kernel_prob'] + self.downsample_range = opt['downsample_range'] + self.noise_range = opt['noise_range'] + self.jpeg_range = opt['jpeg_range'] + # print + logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]') + logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]') + logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]') + logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]') + + # color jitter + self.color_jitter_prob = opt.get('color_jitter_prob', None) + self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None) + self.color_jitter_shift = opt.get('color_jitter_shift', 20) + if self.color_jitter_prob is not None: + logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}') + + # to gray + self.gray_prob = opt.get('gray_prob', 0.0) + if self.gray_prob is not None: + logger.info(f'Use random gray. Prob: {self.gray_prob}') + self.color_jitter_shift /= 255. 
+ + @staticmethod + def color_jitter(img, shift): + """jitter color: randomly jitter the RGB values, in numpy formats""" + jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32) + img = img + jitter_val + img = np.clip(img, 0, 1) + return img + + @staticmethod + def color_jitter_pt(img, brightness, contrast, saturation, hue): + """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats""" + fn_idx = torch.randperm(4) + for fn_id in fn_idx: + if fn_id == 0 and brightness is not None: + brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() + img = adjust_brightness(img, brightness_factor) + + if fn_id == 1 and contrast is not None: + contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() + img = adjust_contrast(img, contrast_factor) + + if fn_id == 2 and saturation is not None: + saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() + img = adjust_saturation(img, saturation_factor) + + if fn_id == 3 and hue is not None: + hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() + img = adjust_hue(img, hue_factor) + return img + + + def get_component_locations(self, name, status): + components_bbox = self.components_dict[name] + if status[0]: # hflip + # exchange right and left eye + tmp = components_bbox['left_eye'] + components_bbox['left_eye'] = components_bbox['right_eye'] + components_bbox['right_eye'] = tmp + # modify the width coordinate + components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0] + components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0] + components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0] + components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0] + + locations_gt = {} + locations_in = {} + for part in ['left_eye', 'right_eye', 'nose', 'mouth']: + mean = components_bbox[part][0:2] + half_len = components_bbox[part][2] + if 'eye' in part: + half_len *= self.eye_enlarge_ratio + elif part == 'nose': + half_len *= self.nose_enlarge_ratio + elif part == 'mouth': + half_len *= self.mouth_enlarge_ratio + loc = np.hstack((mean - half_len + 1, mean + half_len)) + loc = torch.from_numpy(loc).float() + locations_gt[part] = loc + loc_in = loc/(self.gt_size//self.in_size) + locations_in[part] = loc_in + return locations_gt, locations_in + + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # load gt image + gt_path = self.paths[index] + name = osp.basename(gt_path)[:-4] + img_bytes = self.file_client.get(gt_path) + img_gt = imfrombytes(img_bytes, float32=True) + + # random horizontal flip + img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True) + + if self.load_latent_gt: + if status[0]: + latent_gt = self.latent_gt_dict['hflip'][name] + else: + latent_gt = self.latent_gt_dict['orig'][name] + + if self.crop_components: + locations_gt, locations_in = self.get_component_locations(name, status) + + # generate in image + img_in = img_gt + if self.use_corrupt and not self.gen_inpaint_mask: + # motion blur + if self.use_motion_kernel and random.random() < self.motion_kernel_prob: + m_i = random.randint(0,31) + k = self.motion_kernels[f'{m_i:02d}'] + img_in = cv2.filter2D(img_in,-1,k) + + # gaussian blur + kernel = gaussian_kernels.random_mixed_kernels( + self.kernel_list, + self.kernel_prob, + self.blur_kernel_size, + 
self.blur_sigma, + self.blur_sigma, + [-math.pi, math.pi], + noise_range=None) + img_in = cv2.filter2D(img_in, -1, kernel) + + # downsample + scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1]) + img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR) + + # noise + if self.noise_range is not None: + noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.) + noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma + img_in = img_in + noise + img_in = np.clip(img_in, 0, 1) + + # jpeg + if self.jpeg_range is not None: + jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1]) + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_p] + _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param) + img_in = np.float32(cv2.imdecode(encimg, 1)) / 255. + + # resize to in_size + img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR) + + # if self.gen_inpaint_mask: + # inpaint_mask = random_ff_mask(shape=(self.gt_size,self.gt_size), + # max_angle = self.mask_max_angle, max_len = self.mask_max_len, + # max_width = self.mask_max_width, times = self.mask_draw_times) + # img_in = img_in * (1 - inpaint_mask.reshape(self.gt_size,self.gt_size,1)) + \ + # 1.0 * inpaint_mask.reshape(self.gt_size,self.gt_size,1) + + # inpaint_mask = torch.from_numpy(inpaint_mask).view(1,self.gt_size,self.gt_size) + + if self.gen_inpaint_mask: + img_in = (img_in*255).astype('uint8') + img_in = brush_stroke_mask(Image.fromarray(img_in)) + img_in = np.array(img_in) / 255. + + # random color jitter (only for lq) + if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob): + img_in = self.color_jitter(img_in, self.color_jitter_shift) + # random to gray (only for lq) + if self.gray_prob and np.random.uniform() < self.gray_prob: + img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY) + img_in = np.tile(img_in[:, :, None], [1, 1, 3]) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_in, img_gt = img2tensor([img_in, img_gt], bgr2rgb=True, float32=True) + + # random color jitter (pytorch version) (only for lq) + if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob): + brightness = self.opt.get('brightness', (0.5, 1.5)) + contrast = self.opt.get('contrast', (0.5, 1.5)) + saturation = self.opt.get('saturation', (0, 1.5)) + hue = self.opt.get('hue', (-0.1, 0.1)) + img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue) + + # round and clip + img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255. 
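The corruption branch above applies the standard blind-restoration degradation chain: Gaussian blur, downsample, additive Gaussian noise, a JPEG round-trip, then resize back to `in_size`. A self-contained sketch of the same chain on a single image; `cv2.GaussianBlur` stands in for the random mixed kernels, and every parameter is an illustrative fixed value rather than the random draws used above:

```python
import cv2
import numpy as np

def degrade_once(img, sigma=3.0, scale=4.0, noise_sigma=10 / 255., jpeg_q=70, out_size=512):
    """img: float32 BGR array in [0, 1], shape (out_size, out_size, 3)."""
    img = cv2.GaussianBlur(img, (41, 41), sigma)                            # blur
    small = int(out_size // scale)
    img = cv2.resize(img, (small, small), interpolation=cv2.INTER_LINEAR)   # downsample
    img = np.clip(img + np.random.randn(*img.shape).astype(np.float32) * noise_sigma, 0, 1)
    _, enc = cv2.imencode('.jpg', (img * 255).astype(np.uint8),
                          [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_q])          # JPEG round-trip
    img = np.float32(cv2.imdecode(enc, 1)) / 255.
    return cv2.resize(img, (out_size, out_size), interpolation=cv2.INTER_LINEAR)
```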
+ + # Set vgg range_norm=True if use the normalization here + # normalize + normalize(img_in, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) + + return_dict = {'in': img_in, 'gt': img_gt, 'gt_path': gt_path} + + if self.crop_components: + return_dict['locations_in'] = locations_in + return_dict['locations_gt'] = locations_gt + + if self.load_latent_gt: + return_dict['latent_gt'] = latent_gt + + # if self.gen_inpaint_mask: + # return_dict['inpaint_mask'] = inpaint_mask + + return return_dict + + + def __len__(self): + return len(self.paths) \ No newline at end of file diff --git a/scripts/basicsr/data/ffhq_blind_joint_dataset.py b/scripts/basicsr/data/ffhq_blind_joint_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..214fd1acd5405811f54b6c37ed47814e917e33e3 --- /dev/null +++ b/scripts/basicsr/data/ffhq_blind_joint_dataset.py @@ -0,0 +1,324 @@ +import cv2 +import math +import random +import numpy as np +import os.path as osp +from scipy.io import loadmat +import torch +import torch.utils.data as data +from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, + adjust_hue, adjust_saturation, normalize) +from scripts.basicsr.data import gaussian_kernels as gaussian_kernels +from scripts.basicsr.data.transforms import augment +from scripts.basicsr.data.data_util import paths_from_folder +from scripts.basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from scripts.basicsr.utils.registry import DATASET_REGISTRY + +@DATASET_REGISTRY.register() +class FFHQBlindJointDataset(data.Dataset): + + def __init__(self, opt): + super(FFHQBlindJointDataset, self).__init__() + logger = get_root_logger() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + + self.gt_folder = opt['dataroot_gt'] + self.gt_size = opt.get('gt_size', 512) + self.in_size = opt.get('in_size', 512) + assert self.gt_size >= self.in_size, 'Wrong setting.' 
+ + self.mean = opt.get('mean', [0.5, 0.5, 0.5]) + self.std = opt.get('std', [0.5, 0.5, 0.5]) + + self.component_path = opt.get('component_path', None) + self.latent_gt_path = opt.get('latent_gt_path', None) + + if self.component_path is not None: + self.crop_components = True + self.components_dict = torch.load(self.component_path) + self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4) + self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1) + self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3) + else: + self.crop_components = False + + if self.latent_gt_path is not None: + self.load_latent_gt = True + self.latent_gt_dict = torch.load(self.latent_gt_path) + else: + self.load_latent_gt = False + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = self.gt_folder + if not self.gt_folder.endswith('.lmdb'): + raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}') + with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: + self.paths = [line.split('.')[0] for line in fin] + else: + self.paths = paths_from_folder(self.gt_folder) + + # perform corrupt + self.use_corrupt = opt.get('use_corrupt', True) + self.use_motion_kernel = False + # self.use_motion_kernel = opt.get('use_motion_kernel', True) + + if self.use_motion_kernel: + self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001) + motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth') + self.motion_kernels = torch.load(motion_kernel_path) + + if self.use_corrupt: + # degradation configurations + self.blur_kernel_size = self.opt['blur_kernel_size'] + self.kernel_list = self.opt['kernel_list'] + self.kernel_prob = self.opt['kernel_prob'] + # Small degradation + self.blur_sigma = self.opt['blur_sigma'] + self.downsample_range = self.opt['downsample_range'] + self.noise_range = self.opt['noise_range'] + self.jpeg_range = self.opt['jpeg_range'] + # Large degradation + self.blur_sigma_large = self.opt['blur_sigma_large'] + self.downsample_range_large = self.opt['downsample_range_large'] + self.noise_range_large = self.opt['noise_range_large'] + self.jpeg_range_large = self.opt['jpeg_range_large'] + + # print + logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]') + logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]') + logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]') + logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]') + + # color jitter + self.color_jitter_prob = opt.get('color_jitter_prob', None) + self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None) + self.color_jitter_shift = opt.get('color_jitter_shift', 20) + if self.color_jitter_prob is not None: + logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}') + + # to gray + self.gray_prob = opt.get('gray_prob', 0.0) + if self.gray_prob is not None: + logger.info(f'Use random gray. Prob: {self.gray_prob}') + self.color_jitter_shift /= 255. 
+ + @staticmethod + def color_jitter(img, shift): + """jitter color: randomly jitter the RGB values, in numpy formats""" + jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32) + img = img + jitter_val + img = np.clip(img, 0, 1) + return img + + @staticmethod + def color_jitter_pt(img, brightness, contrast, saturation, hue): + """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats""" + fn_idx = torch.randperm(4) + for fn_id in fn_idx: + if fn_id == 0 and brightness is not None: + brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() + img = adjust_brightness(img, brightness_factor) + + if fn_id == 1 and contrast is not None: + contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() + img = adjust_contrast(img, contrast_factor) + + if fn_id == 2 and saturation is not None: + saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() + img = adjust_saturation(img, saturation_factor) + + if fn_id == 3 and hue is not None: + hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() + img = adjust_hue(img, hue_factor) + return img + + + def get_component_locations(self, name, status): + components_bbox = self.components_dict[name] + if status[0]: # hflip + # exchange right and left eye + tmp = components_bbox['left_eye'] + components_bbox['left_eye'] = components_bbox['right_eye'] + components_bbox['right_eye'] = tmp + # modify the width coordinate + components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0] + components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0] + components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0] + components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0] + + locations_gt = {} + locations_in = {} + for part in ['left_eye', 'right_eye', 'nose', 'mouth']: + mean = components_bbox[part][0:2] + half_len = components_bbox[part][2] + if 'eye' in part: + half_len *= self.eye_enlarge_ratio + elif part == 'nose': + half_len *= self.nose_enlarge_ratio + elif part == 'mouth': + half_len *= self.mouth_enlarge_ratio + loc = np.hstack((mean - half_len + 1, mean + half_len)) + loc = torch.from_numpy(loc).float() + locations_gt[part] = loc + loc_in = loc/(self.gt_size//self.in_size) + locations_in[part] = loc_in + return locations_gt, locations_in + + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # load gt image + gt_path = self.paths[index] + name = osp.basename(gt_path)[:-4] + img_bytes = self.file_client.get(gt_path) + img_gt = imfrombytes(img_bytes, float32=True) + + # random horizontal flip + img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True) + + if self.load_latent_gt: + if status[0]: + latent_gt = self.latent_gt_dict['hflip'][name] + else: + latent_gt = self.latent_gt_dict['orig'][name] + + if self.crop_components: + locations_gt, locations_in = self.get_component_locations(name, status) + + # generate in image + img_in = img_gt + if self.use_corrupt: + # motion blur + if self.use_motion_kernel and random.random() < self.motion_kernel_prob: + m_i = random.randint(0,31) + k = self.motion_kernels[f'{m_i:02d}'] + img_in = cv2.filter2D(img_in,-1,k) + + # gaussian blur + kernel = gaussian_kernels.random_mixed_kernels( + self.kernel_list, + self.kernel_prob, + self.blur_kernel_size, + self.blur_sigma, + 
self.blur_sigma, + [-math.pi, math.pi], + noise_range=None) + img_in = cv2.filter2D(img_in, -1, kernel) + + # downsample + scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1]) + img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR) + + # noise + if self.noise_range is not None: + noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.) + noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma + img_in = img_in + noise + img_in = np.clip(img_in, 0, 1) + + # jpeg + if self.jpeg_range is not None: + jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1]) + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_p] + _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param) + img_in = np.float32(cv2.imdecode(encimg, 1)) / 255. + + # resize to in_size + img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR) + + + # generate in_large with large degradation + img_in_large = img_gt + + if self.use_corrupt: + # motion blur + if self.use_motion_kernel and random.random() < self.motion_kernel_prob: + m_i = random.randint(0,31) + k = self.motion_kernels[f'{m_i:02d}'] + img_in_large = cv2.filter2D(img_in_large,-1,k) + + # gaussian blur + kernel = gaussian_kernels.random_mixed_kernels( + self.kernel_list, + self.kernel_prob, + self.blur_kernel_size, + self.blur_sigma_large, + self.blur_sigma_large, + [-math.pi, math.pi], + noise_range=None) + img_in_large = cv2.filter2D(img_in_large, -1, kernel) + + # downsample + scale = np.random.uniform(self.downsample_range_large[0], self.downsample_range_large[1]) + img_in_large = cv2.resize(img_in_large, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR) + + # noise + if self.noise_range_large is not None: + noise_sigma = np.random.uniform(self.noise_range_large[0] / 255., self.noise_range_large[1] / 255.) + noise = np.float32(np.random.randn(*(img_in_large.shape))) * noise_sigma + img_in_large = img_in_large + noise + img_in_large = np.clip(img_in_large, 0, 1) + + # jpeg + if self.jpeg_range_large is not None: + jpeg_p = np.random.uniform(self.jpeg_range_large[0], self.jpeg_range_large[1]) + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_p] + _, encimg = cv2.imencode('.jpg', img_in_large * 255., encode_param) + img_in_large = np.float32(cv2.imdecode(encimg, 1)) / 255. 
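Because the `*_large` ranges dominate the base ranges, `in_large_de` should typically land several dB below `in` when measured against the shared GT. One illustrative way to eyeball that, reusing the hypothetical `opt_joint` from the sketch above (note all three tensors are normalized to [-1, 1] at the end of `__getitem__`):

```python
import torch

def psnr(a, b):
    # a, b in [-1, 1]; halving the difference puts the error on a [0, 1] scale
    mse = torch.mean(((a - b) / 2) ** 2)
    return -10 * torch.log10(mse)

s = FFHQBlindJointDataset(opt_joint)[0]
print(psnr(s['in'], s['gt']), psnr(s['in_large_de'], s['gt']))
# the second value should usually be noticeably lower
```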
+ + # resize to in_size + img_in_large = cv2.resize(img_in_large, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR) + + + # random color jitter (only for lq) + if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob): + img_in = self.color_jitter(img_in, self.color_jitter_shift) + img_in_large = self.color_jitter(img_in_large, self.color_jitter_shift) + # random to gray (only for lq) + if self.gray_prob and np.random.uniform() < self.gray_prob: + img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY) + img_in = np.tile(img_in[:, :, None], [1, 1, 3]) + img_in_large = cv2.cvtColor(img_in_large, cv2.COLOR_BGR2GRAY) + img_in_large = np.tile(img_in_large[:, :, None], [1, 1, 3]) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_in, img_in_large, img_gt = img2tensor([img_in, img_in_large, img_gt], bgr2rgb=True, float32=True) + + # random color jitter (pytorch version) (only for lq) + if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob): + brightness = self.opt.get('brightness', (0.5, 1.5)) + contrast = self.opt.get('contrast', (0.5, 1.5)) + saturation = self.opt.get('saturation', (0, 1.5)) + hue = self.opt.get('hue', (-0.1, 0.1)) + img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue) + img_in_large = self.color_jitter_pt(img_in_large, brightness, contrast, saturation, hue) + + # round and clip + img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255. + img_in_large = np.clip((img_in_large * 255.0).round(), 0, 255) / 255. + + # Set vgg range_norm=True if use the normalization here + # normalize + normalize(img_in, self.mean, self.std, inplace=True) + normalize(img_in_large, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) + + return_dict = {'in': img_in, 'in_large_de': img_in_large, 'gt': img_gt, 'gt_path': gt_path} + + if self.crop_components: + return_dict['locations_in'] = locations_in + return_dict['locations_gt'] = locations_gt + + if self.load_latent_gt: + return_dict['latent_gt'] = latent_gt + + return return_dict + + + def __len__(self): + return len(self.paths) diff --git a/scripts/basicsr/data/gaussian_kernels.py b/scripts/basicsr/data/gaussian_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..0ce57f0ae52bb4efce9212dd09960ac9c7358c3a --- /dev/null +++ b/scripts/basicsr/data/gaussian_kernels.py @@ -0,0 +1,690 @@ +import math +import numpy as np +import random +from scipy.ndimage.interpolation import shift +from scipy.stats import multivariate_normal + + +def sigma_matrix2(sig_x, sig_y, theta): + """Calculate the rotated sigma matrix (two dimensional matrix). + Args: + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + Returns: + ndarray: Rotated sigma matrix. + """ + D = np.array([[sig_x**2, 0], [0, sig_y**2]]) + U = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + return np.dot(U, np.dot(D, U.T)) + + +def mesh_grid(kernel_size): + """Generate the mesh grid, centering at zero. + Args: + kernel_size (int): + Returns: + xy (ndarray): with the shape (kernel_size, kernel_size, 2) + xx (ndarray): with the shape (kernel_size, kernel_size) + yy (ndarray): with the shape (kernel_size, kernel_size) + """ + ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.) 
+ xx, yy = np.meshgrid(ax, ax) + xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), + yy.reshape(kernel_size * kernel_size, + 1))).reshape(kernel_size, kernel_size, 2) + return xy, xx, yy + + +def pdf2(sigma_matrix, grid): + """Calculate PDF of the bivariate Gaussian distribution. + Args: + sigma_matrix (ndarray): with the shape (2, 2) + grid (ndarray): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. + Returns: + kernel (ndarrray): un-normalized kernel. + """ + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2)) + return kernel + + +def cdf2(D, grid): + """Calculate the CDF of the standard bivariate Gaussian distribution. + Used in skewed Gaussian distribution. + Args: + D (ndarrasy): skew matrix. + grid (ndarray): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. + Returns: + cdf (ndarray): skewed cdf. + """ + rv = multivariate_normal([0, 0], [[1, 0], [0, 1]]) + grid = np.dot(grid, D) + cdf = rv.cdf(grid) + return cdf + + +def bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid=None): + """Generate a bivariate skew Gaussian kernel. + Described in `A multivariate skew normal distribution`_ by Shi et. al (2004). + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + D (ndarrasy): skew matrix. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): normalized kernel. + .. _A multivariate skew normal distribution: + https://www.sciencedirect.com/science/article/pii/S0047259X03001313 + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + pdf = pdf2(sigma_matrix, grid) + cdf = cdf2(D, grid) + kernel = pdf * cdf + kernel = kernel / np.sum(kernel) + return kernel + + +def mass_center_shift(kernel_size, kernel): + """Calculate the shift of the mass center of a kenrel. + Args: + kernel_size (int): + kernel (ndarray): normalized kernel. + Returns: + delta_h (float): + delta_w (float): + """ + ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.) + col_sum, row_sum = np.sum(kernel, axis=0), np.sum(kernel, axis=1) + delta_h = np.dot(row_sum, ax) + delta_w = np.dot(col_sum, ax) + return delta_h, delta_w + + +def bivariate_skew_Gaussian_center(kernel_size, + sig_x, + sig_y, + theta, + D, + grid=None): + """Generate a bivariate skew Gaussian kernel at center. Shift with nearest padding. + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + D (ndarrasy): skew matrix. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): centered and normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + kernel = bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid) + delta_h, delta_w = mass_center_shift(kernel_size, kernel) + kernel = shift(kernel, [-delta_h, -delta_w], mode='nearest') + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_anisotropic_Gaussian(kernel_size, + sig_x, + sig_y, + theta, + grid=None): + """Generate a bivariate anisotropic Gaussian kernel. + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. 
+ grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + kernel = pdf2(sigma_matrix, grid) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_isotropic_Gaussian(kernel_size, sig, grid=None): + """Generate a bivariate isotropic Gaussian kernel. + Args: + kernel_size (int): + sig (float): + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + sigma_matrix = np.array([[sig**2, 0], [0, sig**2]]) + kernel = pdf2(sigma_matrix, grid) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_generalized_Gaussian(kernel_size, + sig_x, + sig_y, + theta, + beta, + grid=None): + """Generate a bivariate generalized Gaussian kernel. + Described in `Parameter Estimation For Multivariate Generalized Gaussian Distributions`_ + by Pascal et. al (2013). + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + beta (float): shape parameter, beta = 1 is the normal distribution. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): normalized kernel. + .. _Parameter Estimation For Multivariate Generalized Gaussian Distributions: + https://arxiv.org/abs/1302.6498 + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.exp( + -0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta)) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_plateau_type1(kernel_size, sig_x, sig_y, theta, beta, grid=None): + """Generate a plateau-like anisotropic kernel. + 1 / (1+x^(beta)) + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + beta (float): shape parameter, beta = 1 is the normal distribution. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.reciprocal( + np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_plateau_type1_iso(kernel_size, sig, beta, grid=None): + """Generate a plateau-like isotropic kernel. + 1 / (1+x^(beta)) + Args: + kernel_size (int): + sig (float): + beta (float): shape parameter, beta = 1 is the normal distribution. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): normalized kernel. 
+    """
+    if grid is None:
+        grid, _, _ = mesh_grid(kernel_size)
+    sigma_matrix = np.array([[sig**2, 0], [0, sig**2]])
+    inverse_sigma = np.linalg.inv(sigma_matrix)
+    kernel = np.reciprocal(
+        np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
+    kernel = kernel / np.sum(kernel)
+    return kernel
+
+
+def random_bivariate_skew_Gaussian_center(kernel_size,
+                                          sigma_x_range,
+                                          sigma_y_range,
+                                          rotation_range,
+                                          noise_range=None,
+                                          strict=False):
+    """Randomly generate bivariate skew Gaussian kernels at center.
+    Args:
+        kernel_size (int):
+        sigma_x_range (tuple): [0.6, 5]
+        sigma_y_range (tuple): [0.6, 5]
+        rotation_range (tuple): [-math.pi, math.pi]
+        noise_range (tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+    Returns:
+        kernel (ndarray):
+    """
+    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+    if strict:
+        sigma_max = np.max([sigma_x, sigma_y])
+        sigma_min = np.min([sigma_x, sigma_y])
+        sigma_x, sigma_y = sigma_max, sigma_min
+    rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+
+    sigma_max = np.max([sigma_x, sigma_y])
+    thres = 3 / sigma_max
+    D = [[np.random.uniform(-thres, thres),
+          np.random.uniform(-thres, thres)],
+         [np.random.uniform(-thres, thres),
+          np.random.uniform(-thres, thres)]]
+
+    kernel = bivariate_skew_Gaussian_center(kernel_size, sigma_x, sigma_y,
+                                            rotation, D)
+
+    # add multiplicative noise
+    if noise_range is not None:
+        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+        noise = np.random.uniform(
+            noise_range[0], noise_range[1], size=kernel.shape)
+        kernel = kernel * noise
+        kernel = kernel / np.sum(kernel)
+    if strict:
+        return kernel, sigma_x, sigma_y, rotation, D
+    else:
+        return kernel
+
+
+def random_bivariate_anisotropic_Gaussian(kernel_size,
+                                          sigma_x_range,
+                                          sigma_y_range,
+                                          rotation_range,
+                                          noise_range=None,
+                                          strict=False):
+    """Randomly generate bivariate anisotropic Gaussian kernels.
+    Args:
+        kernel_size (int):
+        sigma_x_range (tuple): [0.6, 5]
+        sigma_y_range (tuple): [0.6, 5]
+        rotation_range (tuple): [-math.pi, math.pi]
+        noise_range (tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+    Returns:
+        kernel (ndarray):
+    """
+    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+    if strict:
+        sigma_max = np.max([sigma_x, sigma_y])
+        sigma_min = np.min([sigma_x, sigma_y])
+        sigma_x, sigma_y = sigma_max, sigma_min
+    rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+
+    kernel = bivariate_anisotropic_Gaussian(kernel_size, sigma_x, sigma_y,
+                                            rotation)
+
+    # add multiplicative noise
+    if noise_range is not None:
+        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+        noise = np.random.uniform(
+            noise_range[0], noise_range[1], size=kernel.shape)
+        kernel = kernel * noise
+        kernel = kernel / np.sum(kernel)
+    if strict:
+        return kernel, sigma_x, sigma_y, rotation
+    else:
+        return kernel
+
+
+def random_bivariate_isotropic_Gaussian(kernel_size,
+                                        sigma_range,
+                                        noise_range=None,
+                                        strict=False):
+    """Randomly generate bivariate isotropic Gaussian kernels.
+    Args:
+        kernel_size (int):
+        sigma_range (tuple): [0.6, 5]
+        noise_range (tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+    Returns:
+        kernel (ndarray):
+    """
+    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+    assert sigma_range[0] < sigma_range[1], 'Wrong sigma_range.'
+    sigma = np.random.uniform(sigma_range[0], sigma_range[1])
+
+    kernel = bivariate_isotropic_Gaussian(kernel_size, sigma)
+
+    # add multiplicative noise
+    if noise_range is not None:
+        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+        noise = np.random.uniform(
+            noise_range[0], noise_range[1], size=kernel.shape)
+        kernel = kernel * noise
+        kernel = kernel / np.sum(kernel)
+    if strict:
+        return kernel, sigma
+    else:
+        return kernel
+
+
+def random_bivariate_generalized_Gaussian(kernel_size,
+                                          sigma_x_range,
+                                          sigma_y_range,
+                                          rotation_range,
+                                          beta_range,
+                                          noise_range=None,
+                                          strict=False):
+    """Randomly generate bivariate generalized Gaussian kernels.
+    Args:
+        kernel_size (int):
+        sigma_x_range (tuple): [0.6, 5]
+        sigma_y_range (tuple): [0.6, 5]
+        rotation_range (tuple): [-math.pi, math.pi]
+        beta_range (tuple): [0.5, 8]
+        noise_range (tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+    Returns:
+        kernel (ndarray):
+    """
+    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+    if strict:
+        sigma_max = np.max([sigma_x, sigma_y])
+        sigma_min = np.min([sigma_x, sigma_y])
+        sigma_x, sigma_y = sigma_max, sigma_min
+    rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+    if np.random.uniform() < 0.5:
+        beta = np.random.uniform(beta_range[0], 1)
+    else:
+        beta = np.random.uniform(1, beta_range[1])
+
+    kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y,
+                                            rotation, beta)
+
+    # add multiplicative noise
+    if noise_range is not None:
+        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+        noise = np.random.uniform(
+            noise_range[0], noise_range[1], size=kernel.shape)
+        kernel = kernel * noise
+        kernel = kernel / np.sum(kernel)
+    if strict:
+        return kernel, sigma_x, sigma_y, rotation, beta
+    else:
+        return kernel
+
+
+def random_bivariate_plateau_type1(kernel_size,
+                                   sigma_x_range,
+                                   sigma_y_range,
+                                   rotation_range,
+                                   beta_range,
+                                   noise_range=None,
+                                   strict=False):
+    """Randomly generate bivariate plateau type1 kernels.
+    Args:
+        kernel_size (int):
+        sigma_x_range (tuple): [0.6, 5]
+        sigma_y_range (tuple): [0.6, 5]
+        rotation_range (tuple): [-math.pi/2, math.pi/2]
+        beta_range (tuple): [1, 4]
+        noise_range (tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+    Returns:
+        kernel (ndarray):
+    """
+    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+    if strict:
+        sigma_max = np.max([sigma_x, sigma_y])
+        sigma_min = np.min([sigma_x, sigma_y])
+        sigma_x, sigma_y = sigma_max, sigma_min
+    rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+    if np.random.uniform() < 0.5:
+        beta = np.random.uniform(beta_range[0], 1)
+    else:
+        beta = np.random.uniform(1, beta_range[1])
+
+    kernel = bivariate_plateau_type1(kernel_size, sigma_x, sigma_y, rotation,
+                                     beta)
+
+    # add multiplicative noise
+    if noise_range is not None:
+        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+        noise = np.random.uniform(
+            noise_range[0], noise_range[1], size=kernel.shape)
+        kernel = kernel * noise
+        kernel = kernel / np.sum(kernel)
+    if strict:
+        return kernel, sigma_x, sigma_y, rotation, beta
+    else:
+        return kernel
+
+
+def random_bivariate_plateau_type1_iso(kernel_size,
+                                       sigma_range,
+                                       beta_range,
+                                       noise_range=None,
+                                       strict=False):
+    """Randomly generate bivariate plateau type1 kernels (iso).
+    Args:
+        kernel_size (int):
+        sigma_range (tuple): [0.6, 5]
+        beta_range (tuple): [1, 4]
+        noise_range (tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+    Returns:
+        kernel (ndarray):
+    """
+    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+    assert sigma_range[0] < sigma_range[1], 'Wrong sigma_range.'
+    sigma = np.random.uniform(sigma_range[0], sigma_range[1])
+    beta = np.random.uniform(beta_range[0], beta_range[1])
+
+    kernel = bivariate_plateau_type1_iso(kernel_size, sigma, beta)
+
+    # add multiplicative noise
+    if noise_range is not None:
+        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+        noise = np.random.uniform(
+            noise_range[0], noise_range[1], size=kernel.shape)
+        kernel = kernel * noise
+        kernel = kernel / np.sum(kernel)
+    if strict:
+        return kernel, sigma, beta
+    else:
+        return kernel
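+
+
+# Illustrative sketch of the random_* helpers above (comments only; the value
+# ranges mirror the ones suggested in the docstrings, they are not enforced):
+#
+#   >>> kernel = random_bivariate_isotropic_Gaussian(21, sigma_range=[0.6, 5])
+#   >>> kernel.shape
+#   (21, 21)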
+
+
+def random_mixed_kernels(kernel_list,
+                         kernel_prob,
+                         kernel_size=21,
+                         sigma_x_range=[0.6, 5],
+                         sigma_y_range=[0.6, 5],
+                         rotation_range=[-math.pi, math.pi],
+                         beta_range=[0.5, 8],
+                         noise_range=None):
+    """Randomly generate mixed kernels.
+    Args:
+        kernel_list (tuple): a list of kernel type names,
+            support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso', 'plateau_aniso']
+        kernel_prob (tuple): corresponding kernel probability for each kernel type
+        kernel_size (int):
+        sigma_x_range (tuple): [0.6, 5]
+        sigma_y_range (tuple): [0.6, 5]
+        rotation_range (tuple): [-math.pi, math.pi]
+        beta_range (tuple): [0.5, 8]
+        noise_range (tuple, optional): multiplicative kernel noise, [0.75, 1.25].
+            Default: None
+    Returns:
+        kernel (ndarray):
+    """
+    kernel_type = random.choices(kernel_list, kernel_prob)[0]
+    if kernel_type == 'iso':
+        kernel = random_bivariate_isotropic_Gaussian(
+            kernel_size, sigma_x_range, noise_range=noise_range)
+    elif kernel_type == 'aniso':
+        kernel = random_bivariate_anisotropic_Gaussian(
+            kernel_size,
+            sigma_x_range,
+            sigma_y_range,
+            rotation_range,
+            noise_range=noise_range)
+    elif kernel_type == 'skew':
+        kernel = random_bivariate_skew_Gaussian_center(
+            kernel_size,
+            sigma_x_range,
+            sigma_y_range,
+            rotation_range,
+            noise_range=noise_range)
+    elif kernel_type == 'generalized':
+        kernel = random_bivariate_generalized_Gaussian(
+            kernel_size,
+            sigma_x_range,
+            sigma_y_range,
+            rotation_range,
+            beta_range,
+            noise_range=noise_range)
+    elif kernel_type == 'plateau_iso':
+        kernel = random_bivariate_plateau_type1_iso(
+            kernel_size, sigma_x_range, beta_range, noise_range=noise_range)
+    elif kernel_type == 'plateau_aniso':
+        kernel = random_bivariate_plateau_type1(
+            kernel_size,
+            sigma_x_range,
+            sigma_y_range,
+            rotation_range,
+            beta_range,
+            noise_range=noise_range)
+    # multiplicative noise (if requested) has already been applied inside the
+    # sampled helper above, so it is not applied a second time here
+    return kernel
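+
+
+# Illustrative random_mixed_kernels call (comments only; random.choices
+# treats kernel_prob as relative weights):
+#
+#   >>> kernel = random_mixed_kernels(
+#   ...     kernel_list=['iso', 'aniso', 'plateau_iso'],
+#   ...     kernel_prob=[0.5, 0.3, 0.2],
+#   ...     kernel_size=21)
+#   >>> kernel.shape
+#   (21, 21)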
+
+
+def show_one_kernel():
+    import matplotlib.pyplot as plt
+    kernel_size = 21
+
+    # bivariate skew Gaussian
+    D = [[0, 0], [0, 0]]
+    D = [[3 / 4, 0], [0, 0.5]]
+    kernel = bivariate_skew_Gaussian_center(kernel_size, 2, 4, -math.pi / 4, D)
+    # bivariate anisotropic Gaussian
+    kernel = bivariate_anisotropic_Gaussian(kernel_size, 2, 4, -math.pi / 4)
+    # bivariate isotropic Gaussian
+    kernel = bivariate_isotropic_Gaussian(kernel_size, 1)
+    # bivariate generalized Gaussian (only this last kernel is visualized;
+    # the assignments above overwrite one another and are kept for reference)
+    kernel = bivariate_generalized_Gaussian(
+        kernel_size, 2, 4, -math.pi / 4, beta=4)
+
+    delta_h, delta_w = mass_center_shift(kernel_size, kernel)
+    print(delta_h, delta_w)
+
+    fig, axs = plt.subplots(nrows=2, ncols=2)
+    # axs.set_axis_off()
+    ax = axs[0][0]
+    im = ax.matshow(kernel, cmap='jet', origin='upper')
+    fig.colorbar(im, ax=ax)
+
+    # image
+    ax = axs[0][1]
+    kernel_vis = kernel - np.min(kernel)
+    kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
+    ax.imshow(kernel_vis, interpolation='nearest')
+
+    _, xx, yy = mesh_grid(kernel_size)
+    # contour
+    ax = axs[1][0]
+    CS = ax.contour(xx, yy, kernel, origin='upper')
+    ax.clabel(CS, inline=1, fontsize=3)
+
+    # contourf
+    ax = axs[1][1]
+    kernel = kernel / np.max(kernel)
+    p = ax.contourf(
+        xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
+    fig.colorbar(p)
+
+    plt.show()
+
+
+def show_plateau_kernel():
+    import matplotlib.pyplot as plt
+    kernel_size = 21
+
+    kernel = bivariate_plateau_type1(kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
+    kernel_norm = bivariate_isotropic_Gaussian(kernel_size, 5)
+    kernel_gau = bivariate_generalized_Gaussian(
+        kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
+    delta_h, delta_w = mass_center_shift(kernel_size, kernel)
+    print(delta_h, delta_w)
+
+    # kernel_slice = kernel[10, :]
+    # kernel_gau_slice = kernel_gau[10, :]
+    # kernel_norm_slice = kernel_norm[10, :]
+    # fig, ax = plt.subplots()
+    # t = list(range(1, 22))
+
+    # ax.plot(t, kernel_gau_slice)
+    # ax.plot(t, kernel_slice)
+    # ax.plot(t, kernel_norm_slice)
+
+    # t = np.arange(0, 10, 0.1)
+    # y = np.exp(-0.5 * t)
+    # y2 = np.reciprocal(1 + t)
+    # print(t.shape)
+    # print(y.shape)
+    # ax.plot(t, y)
+    # ax.plot(t, y2)
+    # plt.show()
+
+    fig, axs = plt.subplots(nrows=2, ncols=2)
+    # axs.set_axis_off()
+    ax = axs[0][0]
+    im = ax.matshow(kernel, cmap='jet', origin='upper')
+    fig.colorbar(im, ax=ax)
+
+    # image
+    ax = axs[0][1]
+    kernel_vis = kernel - np.min(kernel)
+    kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
+    ax.imshow(kernel_vis, interpolation='nearest')
+
+    _, xx, yy = mesh_grid(kernel_size)
+    # contour
+    ax = axs[1][0]
+    CS = ax.contour(xx, yy, kernel, origin='upper')
+    ax.clabel(CS, inline=1, fontsize=3)
+
+    # contourf
+    ax = axs[1][1]
+    kernel = kernel / np.max(kernel)
+    p = ax.contourf(
+        xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
+    fig.colorbar(p)
+
+    plt.show()
diff --git a/scripts/basicsr/data/paired_image_dataset.py b/scripts/basicsr/data/paired_image_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdaab1e22f2111a155b7e10a7c0fc7b635670eb6
--- /dev/null
+++ b/scripts/basicsr/data/paired_image_dataset.py
@@ -0,0 +1,101 @@
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+
+from scripts.basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
+from scripts.basicsr.data.transforms import augment, paired_random_crop
+from scripts.basicsr.utils import FileClient, imfrombytes, img2tensor
+from scripts.basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class PairedImageDataset(data.Dataset):
+    """Paired image dataset for image restoration.
+
+    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
+    GT image pairs.
+
+    There are three modes:
+    1. 'lmdb': Use lmdb files.
+        If opt['io_backend'] == lmdb.
+    2. 'meta_info_file': Use meta information file to generate paths.
+        If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
+    3. 'folder': Scan folders to generate paths.
+        The rest.
+
+    Args:
+        opt (dict): Config for train datasets. It contains the following keys:
+            dataroot_gt (str): Data root path for gt.
+            dataroot_lq (str): Data root path for lq.
+            meta_info_file (str): Path for meta information file.
+            io_backend (dict): IO backend type and other kwarg.
+            filename_tmpl (str): Template for each filename. Note that the
+                template excludes the file extension. Default: '{}'.
+            gt_size (int): Cropped patch size for gt patches.
+            use_flip (bool): Use horizontal flips.
+            use_rot (bool): Use rotation (use vertical flip and transposing h
+                and w for implementation).
+
+            scale (int): Scale factor, which will be added automatically.
+            phase (str): 'train' or 'val'.
+    """
+
+    def __init__(self, opt):
+        super(PairedImageDataset, self).__init__()
+        self.opt = opt
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        self.mean = opt['mean'] if 'mean' in opt else None
+        self.std = opt['std'] if 'std' in opt else None
+
+        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
+        if 'filename_tmpl' in opt:
+            self.filename_tmpl = opt['filename_tmpl']
+        else:
+            self.filename_tmpl = '{}'
+
+        if self.io_backend_opt['type'] == 'lmdb':
+            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
+            self.io_backend_opt['client_keys'] = ['lq', 'gt']
+            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
+        elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
+            self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
+                                                          self.opt['meta_info_file'], self.filename_tmpl)
+        else:
+            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        scale = self.opt['scale']
+
+        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
+        # image range: [0, 1], float32.
+        gt_path = self.paths[index]['gt_path']
+        img_bytes = self.file_client.get(gt_path, 'gt')
+        img_gt = imfrombytes(img_bytes, float32=True)
+        lq_path = self.paths[index]['lq_path']
+        img_bytes = self.file_client.get(lq_path, 'lq')
+        img_lq = imfrombytes(img_bytes, float32=True)
+
+        # augmentation for training
+        if self.opt['phase'] == 'train':
+            gt_size = self.opt['gt_size']
+            # random crop
+            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
+            # flip, rotation
+            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])
+
+        # TODO: color space transform
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
+        # normalize
+        if self.mean is not None or self.std is not None:
+            normalize(img_lq, self.mean, self.std, inplace=True)
+            normalize(img_gt, self.mean, self.std, inplace=True)
+
+        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
+
+    def __len__(self):
+        return len(self.paths)
diff --git a/scripts/basicsr/data/prefetch_dataloader.py b/scripts/basicsr/data/prefetch_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..5088425050d4cc98114a9b93eb50ea60273f35a0
--- /dev/null
+++ b/scripts/basicsr/data/prefetch_dataloader.py
@@ -0,0 +1,125 @@
+import queue as Queue
+import threading
+import torch
+from torch.utils.data import DataLoader
+
+
+class PrefetchGenerator(threading.Thread):
+    """A general prefetch generator.
+
+    Ref:
+    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
+
+    Args:
+        generator: Python generator.
+        num_prefetch_queue (int): Number of prefetch queue.
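+
+    Example (illustrative; any finite iterable works as the generator):
+        >>> gen = PrefetchGenerator(iter(range(3)), num_prefetch_queue=1)
+        >>> list(gen)
+        [0, 1, 2]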
+    """
+
+    def __init__(self, generator, num_prefetch_queue):
+        threading.Thread.__init__(self)
+        self.queue = Queue.Queue(num_prefetch_queue)
+        self.generator = generator
+        self.daemon = True
+        self.start()
+
+    def run(self):
+        for item in self.generator:
+            self.queue.put(item)
+        self.queue.put(None)
+
+    def __next__(self):
+        next_item = self.queue.get()
+        if next_item is None:
+            raise StopIteration
+        return next_item
+
+    def __iter__(self):
+        return self
+
+
+class PrefetchDataLoader(DataLoader):
+    """Prefetch version of dataloader.
+
+    Ref:
+    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
+
+    TODO:
+    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
+    ddp.
+
+    Args:
+        num_prefetch_queue (int): Number of prefetch queue.
+        kwargs (dict): Other arguments for dataloader.
+    """
+
+    def __init__(self, num_prefetch_queue, **kwargs):
+        self.num_prefetch_queue = num_prefetch_queue
+        super(PrefetchDataLoader, self).__init__(**kwargs)
+
+    def __iter__(self):
+        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
+
+
+class CPUPrefetcher():
+    """CPU prefetcher.
+
+    Args:
+        loader: Dataloader.
+    """
+
+    def __init__(self, loader):
+        self.ori_loader = loader
+        self.loader = iter(loader)
+
+    def next(self):
+        try:
+            return next(self.loader)
+        except StopIteration:
+            return None
+
+    def reset(self):
+        self.loader = iter(self.ori_loader)
+
+
+class CUDAPrefetcher():
+    """CUDA prefetcher.
+
+    Ref:
+    https://github.com/NVIDIA/apex/issues/304#
+
+    It may consume more GPU memory.
+
+    Args:
+        loader: Dataloader.
+        opt (dict): Options.
+    """
+
+    def __init__(self, loader, opt):
+        self.ori_loader = loader
+        self.loader = iter(loader)
+        self.opt = opt
+        self.stream = torch.cuda.Stream()
+        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+        self.preload()
+
+    def preload(self):
+        try:
+            self.batch = next(self.loader)  # self.batch is a dict
+        except StopIteration:
+            self.batch = None
+            return None
+        # put tensors to gpu
+        with torch.cuda.stream(self.stream):
+            for k, v in self.batch.items():
+                if torch.is_tensor(v):
+                    self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
+
+    def next(self):
+        torch.cuda.current_stream().wait_stream(self.stream)
+        batch = self.batch
+        self.preload()
+        return batch
+
+    def reset(self):
+        self.loader = iter(self.ori_loader)
+        self.preload()
diff --git a/scripts/basicsr/data/transforms.py b/scripts/basicsr/data/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..aead9dc73ed063e1c5865040eaa2652b26aa3ad3
--- /dev/null
+++ b/scripts/basicsr/data/transforms.py
@@ -0,0 +1,165 @@
+import cv2
+import random
+
+
+def mod_crop(img, scale):
+    """Mod crop images, used during testing.
+
+    Args:
+        img (ndarray): Input image.
+        scale (int): Scale factor.
+
+    Returns:
+        ndarray: Result image.
+    """
+    img = img.copy()
+    if img.ndim in (2, 3):
+        h, w = img.shape[0], img.shape[1]
+        h_remainder, w_remainder = h % scale, w % scale
+        img = img[:h - h_remainder, :w - w_remainder, ...]
+    else:
+        raise ValueError(f'Wrong img ndim: {img.ndim}.')
+    return img
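+
+
+# Illustrative mod_crop behavior (comments only; assumes numpy imported as
+# np): a 7x9 image cropped for scale 4 becomes 4x8, i.e. both sides are
+# trimmed down to multiples of the scale.
+#
+#   >>> img = np.zeros((7, 9, 3), dtype=np.float32)
+#   >>> mod_crop(img, 4).shape
+#   (4, 8, 3)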
+
+
+def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
+    """Paired random crop.
+
+    It crops lists of lq and gt images with corresponding locations.
+
+    Args:
+        img_gts (list[ndarray] | ndarray): GT images. Note that all images
+            should have the same shape. If the input is an ndarray, it will
+            be transformed to a list containing itself.
+        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
+            should have the same shape. If the input is an ndarray, it will
+            be transformed to a list containing itself.
+        gt_patch_size (int): GT patch size.
+        scale (int): Scale factor.
+        gt_path (str): Path to ground-truth.
+
+    Returns:
+        list[ndarray] | ndarray: GT images and LQ images. If returned results
+            only have one element, just return ndarray.
+    """
+
+    if not isinstance(img_gts, list):
+        img_gts = [img_gts]
+    if not isinstance(img_lqs, list):
+        img_lqs = [img_lqs]
+
+    h_lq, w_lq, _ = img_lqs[0].shape
+    h_gt, w_gt, _ = img_gts[0].shape
+    lq_patch_size = gt_patch_size // scale
+
+    if h_gt != h_lq * scale or w_gt != w_lq * scale:
+        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
+                         f'multiplication of LQ ({h_lq}, {w_lq}).')
+    if h_lq < lq_patch_size or w_lq < lq_patch_size:
+        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
+                         f'({lq_patch_size}, {lq_patch_size}). '
+                         f'Please remove {gt_path}.')
+
+    # randomly choose top and left coordinates for lq patch
+    top = random.randint(0, h_lq - lq_patch_size)
+    left = random.randint(0, w_lq - lq_patch_size)
+
+    # crop lq patch
+    img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
+
+    # crop corresponding gt patch
+    top_gt, left_gt = int(top * scale), int(left * scale)
+    img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
+    if len(img_gts) == 1:
+        img_gts = img_gts[0]
+    if len(img_lqs) == 1:
+        img_lqs = img_lqs[0]
+    return img_gts, img_lqs
+
+
+def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
+    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
+
+    We use vertical flip and transpose for rotation implementation.
+    All the images in the list use the same augmentation.
+
+    Args:
+        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
+            is an ndarray, it will be transformed to a list.
+        hflip (bool): Horizontal flip. Default: True.
+        rotation (bool): Rotation. Default: True.
+        flows (list[ndarray], optional): Flows to be augmented. If the input
+            is an ndarray, it will be transformed to a list.
+            Dimension is (h, w, 2). Default: None.
+        return_status (bool): Return the status of flip and rotation.
+            Default: False.
+
+    Returns:
+        list[ndarray] | ndarray: Augmented images and flows. If returned
+            results only have one element, just return ndarray.
+
+    """
+    hflip = hflip and random.random() < 0.5
+    vflip = rotation and random.random() < 0.5
+    rot90 = rotation and random.random() < 0.5
+
+    def _augment(img):
+        if hflip:  # horizontal
+            cv2.flip(img, 1, img)
+        if vflip:  # vertical
+            cv2.flip(img, 0, img)
+        if rot90:
+            img = img.transpose(1, 0, 2)
+        return img
+
+    def _augment_flow(flow):
+        if hflip:  # horizontal
+            cv2.flip(flow, 1, flow)
+            flow[:, :, 0] *= -1
+        if vflip:  # vertical
+            cv2.flip(flow, 0, flow)
+            flow[:, :, 1] *= -1
+        if rot90:
+            flow = flow.transpose(1, 0, 2)
+            flow = flow[:, :, [1, 0]]
+        return flow
+
+    if not isinstance(imgs, list):
+        imgs = [imgs]
+    imgs = [_augment(img) for img in imgs]
+    if len(imgs) == 1:
+        imgs = imgs[0]
+
+    if flows is not None:
+        if not isinstance(flows, list):
+            flows = [flows]
+        flows = [_augment_flow(flow) for flow in flows]
+        if len(flows) == 1:
+            flows = flows[0]
+        return imgs, flows
+    else:
+        if return_status:
+            return imgs, (hflip, vflip, rot90)
+        else:
+            return imgs
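+
+
+# Illustrative augment call (comments only; assumes numpy imported as np,
+# images are HWC ndarrays):
+#
+#   >>> img = np.random.rand(8, 8, 3).astype(np.float32)
+#   >>> out, status = augment(img, hflip=True, rotation=True,
+#   ...                       return_status=True)
+#   >>> isinstance(status, tuple)  # (hflip, vflip, rot90) flags applied
+#   True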
+
+
+def img_rotate(img, angle, center=None, scale=1.0):
+    """Rotate image.
+
+    Args:
+        img (ndarray): Image to be rotated.
+        angle (float): Rotation angle in degrees. Positive values mean
+            counter-clockwise rotation.
+        center (tuple[int]): Rotation center. If the center is None,
+            initialize it as the center of the image. Default: None.
+        scale (float): Isotropic scale factor. Default: 1.0.
+    """
+    (h, w) = img.shape[:2]
+
+    if center is None:
+        center = (w // 2, h // 2)
+
+    matrix = cv2.getRotationMatrix2D(center, angle, scale)
+    rotated_img = cv2.warpAffine(img, matrix, (w, h))
+    return rotated_img
diff --git a/scripts/basicsr/losses/__init__.py b/scripts/basicsr/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..69358acf433ed5860771992bff74ce6eb9cd2a9d
--- /dev/null
+++ b/scripts/basicsr/losses/__init__.py
@@ -0,0 +1,26 @@
+from copy import deepcopy
+
+from scripts.basicsr.utils import get_root_logger
+from scripts.basicsr.utils.registry import LOSS_REGISTRY
+from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
+                     gradient_penalty_loss, r1_penalty)
+
+__all__ = [
+    'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
+    'r1_penalty', 'g_path_regularize'
+]
+
+
+def build_loss(opt):
+    """Build loss from options.
+
+    Args:
+        opt (dict): Configuration. It must contain:
+            type (str): Model type.
+    """
+    opt = deepcopy(opt)
+    loss_type = opt.pop('type')
+    loss = LOSS_REGISTRY.get(loss_type)(**opt)
+    logger = get_root_logger()
+    logger.info(f'Loss [{loss.__class__.__name__}] is created.')
+    return loss
diff --git a/scripts/basicsr/losses/__pycache__/__init__.cpython-310.pyc b/scripts/basicsr/losses/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..00e4bfec7d4fadda6f92aa4e3b646cd4c1b8e386
Binary files /dev/null and b/scripts/basicsr/losses/__pycache__/__init__.cpython-310.pyc differ
diff --git a/scripts/basicsr/losses/__pycache__/loss_util.cpython-310.pyc b/scripts/basicsr/losses/__pycache__/loss_util.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8dc1de1b772ea11402263b5b3a46cddcbe8d28a9
Binary files /dev/null and b/scripts/basicsr/losses/__pycache__/loss_util.cpython-310.pyc differ
diff --git a/scripts/basicsr/losses/__pycache__/losses.cpython-310.pyc b/scripts/basicsr/losses/__pycache__/losses.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a0a7dc3c2a350cd221d5cb407c5101749c91a0c
Binary files /dev/null and b/scripts/basicsr/losses/__pycache__/losses.cpython-310.pyc differ
diff --git a/scripts/basicsr/losses/loss_util.py b/scripts/basicsr/losses/loss_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..744eeb46d1f3b5a7b4553ca23237ddd9c899a698
--- /dev/null
+++ b/scripts/basicsr/losses/loss_util.py
@@ -0,0 +1,95 @@
+import functools
+from torch.nn import functional as F
+
+
+def reduce_loss(loss, reduction):
+    """Reduce loss as specified.
+
+    Args:
+        loss (Tensor): Elementwise loss tensor.
+        reduction (str): Options are 'none', 'mean' and 'sum'.
+
+    Returns:
+        Tensor: Reduced loss tensor.
+    """
+    reduction_enum = F._Reduction.get_enum(reduction)
+    # none: 0, elementwise_mean: 1, sum: 2
+    if reduction_enum == 0:
+        return loss
+    elif reduction_enum == 1:
+        return loss.mean()
+    else:
+        return loss.sum()
+
+
+def weight_reduce_loss(loss, weight=None, reduction='mean'):
+    """Apply element-wise weight and reduce loss.
+
+    Args:
+        loss (Tensor): Element-wise loss.
+        weight (Tensor): Element-wise weights. Default: None.
+        reduction (str): Same as built-in losses of PyTorch. Options are
+            'none', 'mean' and 'sum'. Default: 'mean'.
+
+    Returns:
+        Tensor: Loss values.
+    """
+    # if weight is specified, apply element-wise weight
+    if weight is not None:
+        assert weight.dim() == loss.dim()
+        assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
+        loss = loss * weight
+
+    # if weight is not specified or reduction is sum, just reduce the loss
+    if weight is None or reduction == 'sum':
+        loss = reduce_loss(loss, reduction)
+    # if reduction is mean, then compute mean over weight region
+    elif reduction == 'mean':
+        if weight.size(1) > 1:
+            weight = weight.sum()
+        else:
+            weight = weight.sum() * loss.size(1)
+        loss = loss.sum() / weight
+
+    return loss
+
+
+def weighted_loss(loss_func):
+    """Create a weighted version of a given loss function.
+
+    To use this decorator, the loss function must have the signature like
+    `loss_func(pred, target, **kwargs)`. The function only needs to compute
+    element-wise loss without any reduction. This decorator will add weight
+    and reduction arguments to the function. The decorated function will have
+    the signature like `loss_func(pred, target, weight=None, reduction='mean',
+    **kwargs)`.
+
+    :Example:
+
+    >>> import torch
+    >>> @weighted_loss
+    >>> def l1_loss(pred, target):
+    >>>     return (pred - target).abs()
+
+    >>> pred = torch.Tensor([0, 2, 3])
+    >>> target = torch.Tensor([1, 1, 1])
+    >>> weight = torch.Tensor([1, 0, 1])
+
+    >>> l1_loss(pred, target)
+    tensor(1.3333)
+    >>> l1_loss(pred, target, weight)
+    tensor(1.5000)
+    >>> l1_loss(pred, target, reduction='none')
+    tensor([1., 1., 2.])
+    >>> l1_loss(pred, target, weight, reduction='sum')
+    tensor(3.)
+    """
+
+    @functools.wraps(loss_func)
+    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
+        # get element-wise loss
+        loss = loss_func(pred, target, **kwargs)
+        loss = weight_reduce_loss(loss, weight, reduction)
+        return loss
+
+    return wrapper
diff --git a/scripts/basicsr/losses/losses.py b/scripts/basicsr/losses/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..15320b301f3d278c0f773eb2b5106a99a6d5309b
--- /dev/null
+++ b/scripts/basicsr/losses/losses.py
@@ -0,0 +1,455 @@
+import math
+import lpips
+import torch
+from torch import autograd as autograd
+from torch import nn as nn
+from torch.nn import functional as F
+
+from scripts.basicsr.archs.vgg_arch import VGGFeatureExtractor
+from scripts.basicsr.utils.registry import LOSS_REGISTRY
+from .loss_util import weighted_loss
+
+_reduction_modes = ['none', 'mean', 'sum']
+
+
+@weighted_loss
+def l1_loss(pred, target):
+    return F.l1_loss(pred, target, reduction='none')
+
+
+@weighted_loss
+def mse_loss(pred, target):
+    return F.mse_loss(pred, target, reduction='none')
+
+
+@weighted_loss
+def charbonnier_loss(pred, target, eps=1e-12):
+    return torch.sqrt((pred - target)**2 + eps)
+
+
+@LOSS_REGISTRY.register()
+class L1Loss(nn.Module):
+    """L1 (mean absolute error, MAE) loss.
+
+    Args:
+        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
+        reduction (str): Specifies the reduction to apply to the output.
+            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+    """
+
+    def __init__(self, loss_weight=1.0, reduction='mean'):
+        super(L1Loss, self).__init__()
+        if reduction not in ['none', 'mean', 'sum']:
+            raise ValueError(f'Unsupported reduction mode: {reduction}. '
+                             f'Supported ones are: {_reduction_modes}')
+
+        self.loss_weight = loss_weight
+        self.reduction = reduction
+
+    def forward(self, pred, target, weight=None, **kwargs):
+        """
+        Args:
+            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+                weights. Default: None.
+        """
+        return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class MSELoss(nn.Module):
+    """MSE (L2) loss.
+
+    Args:
+        loss_weight (float): Loss weight for MSE loss. Default: 1.0.
+        reduction (str): Specifies the reduction to apply to the output.
+            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+    """
+
+    def __init__(self, loss_weight=1.0, reduction='mean'):
+        super(MSELoss, self).__init__()
+        if reduction not in ['none', 'mean', 'sum']:
+            raise ValueError(f'Unsupported reduction mode: {reduction}. '
+                             f'Supported ones are: {_reduction_modes}')
+
+        self.loss_weight = loss_weight
+        self.reduction = reduction
+
+    def forward(self, pred, target, weight=None, **kwargs):
+        """
+        Args:
+            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+                weights. Default: None.
+        """
+        return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class CharbonnierLoss(nn.Module):
+    """Charbonnier loss (one variant of Robust L1Loss, a differentiable
+    variant of L1Loss).
+
+    Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
+    Super-Resolution".
+
+    Args:
+        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
+        reduction (str): Specifies the reduction to apply to the output.
+            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+        eps (float): A value used to control the curvature near zero.
+            Default: 1e-12.
+    """
+
+    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
+        super(CharbonnierLoss, self).__init__()
+        if reduction not in ['none', 'mean', 'sum']:
+            raise ValueError(f'Unsupported reduction mode: {reduction}. '
+                             f'Supported ones are: {_reduction_modes}')
+
+        self.loss_weight = loss_weight
+        self.reduction = reduction
+        self.eps = eps
+
+    def forward(self, pred, target, weight=None, **kwargs):
+        """
+        Args:
+            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+                weights. Default: None.
+        """
+        return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class WeightedTVLoss(L1Loss):
+    """Weighted TV loss.
+
+    Args:
+        loss_weight (float): Loss weight. Default: 1.0.
+    """
+
+    def __init__(self, loss_weight=1.0):
+        super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)
+
+    def forward(self, pred, weight=None):
+        # guard against weight=None, which the original slicing would crash on
+        if weight is None:
+            y_weight = None
+            x_weight = None
+        else:
+            y_weight = weight[:, :, :-1, :]
+            x_weight = weight[:, :, :, :-1]
+
+        y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
+        x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)
+
+        loss = x_diff + y_diff
+
+        return loss
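+
+
+# Illustrative WeightedTVLoss usage (comments only; shapes are arbitrary):
+#
+#   >>> tv_loss = WeightedTVLoss(loss_weight=1.0)
+#   >>> pred = torch.rand(1, 3, 8, 8)
+#   >>> tv_loss(pred).dim()  # weight=None falls back to unweighted TV
+#   0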
+
+
+@LOSS_REGISTRY.register()
+class PerceptualLoss(nn.Module):
+    """Perceptual loss with commonly used style loss.
+
+    Args:
+        layer_weights (dict): The weight for each layer of vgg feature.
+            Here is an example: {'conv5_4': 1.}, which means the conv5_4
+            feature layer (before relu5_4) will be extracted with weight
+            1.0 in calculating losses.
+        vgg_type (str): The type of vgg network used as feature extractor.
+            Default: 'vgg19'.
+        use_input_norm (bool): If True, normalize the input image in vgg.
+            Default: True.
+        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
+            Default: False.
+        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
+            loss will be calculated and the loss will be multiplied by the
+            weight. Default: 1.0.
+        style_weight (float): If `style_weight > 0`, the style loss will be
+            calculated and the loss will be multiplied by the weight.
+            Default: 0.
+        criterion (str): Criterion used for perceptual loss. Default: 'l1'.
+    """
+
+    def __init__(self,
+                 layer_weights,
+                 vgg_type='vgg19',
+                 use_input_norm=True,
+                 range_norm=False,
+                 perceptual_weight=1.0,
+                 style_weight=0.,
+                 criterion='l1'):
+        super(PerceptualLoss, self).__init__()
+        self.perceptual_weight = perceptual_weight
+        self.style_weight = style_weight
+        self.layer_weights = layer_weights
+        self.vgg = VGGFeatureExtractor(
+            layer_name_list=list(layer_weights.keys()),
+            vgg_type=vgg_type,
+            use_input_norm=use_input_norm,
+            range_norm=range_norm)
+
+        self.criterion_type = criterion
+        if self.criterion_type == 'l1':
+            self.criterion = torch.nn.L1Loss()
+        elif self.criterion_type == 'l2':
+            # torch.nn has no L2loss; MSELoss is the L2 criterion
+            self.criterion = torch.nn.MSELoss()
+        elif self.criterion_type == 'mse':
+            self.criterion = torch.nn.MSELoss(reduction='mean')
+        elif self.criterion_type == 'fro':
+            self.criterion = None
+        else:
+            raise NotImplementedError(f'{criterion} criterion has not been supported.')
+
+    def forward(self, x, gt):
+        """Forward function.
+
+        Args:
+            x (Tensor): Input tensor with shape (n, c, h, w).
+            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
+
+        Returns:
+            Tensor: Forward results.
+        """
+        # extract vgg features
+        x_features = self.vgg(x)
+        gt_features = self.vgg(gt.detach())
+
+        # calculate perceptual loss
+        if self.perceptual_weight > 0:
+            percep_loss = 0
+            for k in x_features.keys():
+                if self.criterion_type == 'fro':
+                    percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
+                else:
+                    percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
+            percep_loss *= self.perceptual_weight
+        else:
+            percep_loss = None
+
+        # calculate style loss
+        if self.style_weight > 0:
+            style_loss = 0
+            for k in x_features.keys():
+                if self.criterion_type == 'fro':
+                    style_loss += torch.norm(
+                        self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
+                else:
+                    style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
+                        gt_features[k])) * self.layer_weights[k]
+            style_loss *= self.style_weight
+        else:
+            style_loss = None
+
+        return percep_loss, style_loss
+
+    def _gram_mat(self, x):
+        """Calculate Gram matrix.
+
+        Args:
+            x (torch.Tensor): Tensor with shape of (n, c, h, w).
+
+        Returns:
+            torch.Tensor: Gram matrix.
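+
+        Example (illustrative): an input of shape `(2, 8, 4, 4)` yields a
+        Gram matrix of shape `(2, 8, 8)`.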
+        """
+        n, c, h, w = x.size()
+        features = x.view(n, c, w * h)
+        features_t = features.transpose(1, 2)
+        gram = features.bmm(features_t) / (c * h * w)
+        return gram
+
+
+@LOSS_REGISTRY.register()
+class LPIPSLoss(nn.Module):
+    def __init__(self,
+                 loss_weight=1.0,
+                 use_input_norm=True,
+                 range_norm=False):
+        super(LPIPSLoss, self).__init__()
+        self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
+        self.loss_weight = loss_weight
+        self.use_input_norm = use_input_norm
+        self.range_norm = range_norm
+
+        if self.use_input_norm:
+            # the mean is for image with range [0, 1]
+            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+            # the std is for image with range [0, 1]
+            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+    def forward(self, pred, target):
+        if self.range_norm:
+            pred = (pred + 1) / 2
+            target = (target + 1) / 2
+        if self.use_input_norm:
+            pred = (pred - self.mean) / self.std
+            target = (target - self.mean) / self.std
+        lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
+        return self.loss_weight * lpips_loss.mean()
+
+
+@LOSS_REGISTRY.register()
+class GANLoss(nn.Module):
+    """Define GAN loss.
+
+    Args:
+        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'wgan_softplus'
+            and 'hinge'.
+        real_label_val (float): The value for real label. Default: 1.0.
+        fake_label_val (float): The value for fake label. Default: 0.0.
+        loss_weight (float): Loss weight. Default: 1.0.
+            Note that loss_weight is only for generators; and it is always 1.0
+            for discriminators.
+    """
+
+    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
+        super(GANLoss, self).__init__()
+        self.gan_type = gan_type
+        self.loss_weight = loss_weight
+        self.real_label_val = real_label_val
+        self.fake_label_val = fake_label_val
+
+        if self.gan_type == 'vanilla':
+            self.loss = nn.BCEWithLogitsLoss()
+        elif self.gan_type == 'lsgan':
+            self.loss = nn.MSELoss()
+        elif self.gan_type == 'wgan':
+            self.loss = self._wgan_loss
+        elif self.gan_type == 'wgan_softplus':
+            self.loss = self._wgan_softplus_loss
+        elif self.gan_type == 'hinge':
+            self.loss = nn.ReLU()
+        else:
+            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
+
+    def _wgan_loss(self, input, target):
+        """wgan loss.
+
+        Args:
+            input (Tensor): Input tensor.
+            target (bool): Target label.
+
+        Returns:
+            Tensor: wgan loss.
+        """
+        return -input.mean() if target else input.mean()
+
+    def _wgan_softplus_loss(self, input, target):
+        """wgan loss with softplus, a smooth approximation to the
+        ReLU function.
+
+        In StyleGAN2, it is called:
+        Logistic loss for discriminator;
+        Non-saturating loss for generator.
+
+        Args:
+            input (Tensor): Input tensor.
+            target (bool): Target label.
+
+        Returns:
+            Tensor: wgan loss.
+        """
+        return F.softplus(-input).mean() if target else F.softplus(input).mean()
+
+    def get_target_label(self, input, target_is_real):
+        """Get target label.
+
+        Args:
+            input (Tensor): Input tensor.
+            target_is_real (bool): Whether the target is real or fake.
+
+        Returns:
+            (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
+                return Tensor.
+        """
+
+        if self.gan_type in ['wgan', 'wgan_softplus']:
+            return target_is_real
+        target_val = (self.real_label_val if target_is_real else self.fake_label_val)
+        return input.new_ones(input.size()) * target_val
+
+    def forward(self, input, target_is_real, is_disc=False):
+        """
+        Args:
+            input (Tensor): The input for the loss module, i.e., the network
+                prediction.
+            target_is_real (bool): Whether the target is real or fake.
+            is_disc (bool): Whether the loss is for discriminators or not.
+                Default: False.
+
+        Returns:
+            Tensor: GAN loss value.
+        """
+        if self.gan_type == 'hinge':
+            if is_disc:  # for discriminators in hinge-gan
+                input = -input if target_is_real else input
+                loss = self.loss(1 + input).mean()
+            else:  # for generators in hinge-gan
+                loss = -input.mean()
+        else:  # other gan types
+            target_label = self.get_target_label(input, target_is_real)
+            loss = self.loss(input, target_label)
+
+        # loss_weight is always 1.0 for discriminators
+        return loss if is_disc else loss * self.loss_weight
+
+
+def r1_penalty(real_pred, real_img):
+    """R1 regularization for discriminator. The core idea is to
+    penalize the gradient on real data alone: when the
+    generator distribution produces the true data distribution
+    and the discriminator is equal to 0 on the data manifold, the
+    gradient penalty ensures that the discriminator cannot create
+    a non-zero gradient orthogonal to the data manifold without
+    suffering a loss in the GAN game.
+
+    Ref:
+    Eq. 9 in "Which training methods for GANs do actually converge?".
+    """
+    grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
+    grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
+    return grad_penalty
+
+
+def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
+    noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
+    grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
+    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
+
+    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
+
+    path_penalty = (path_lengths - path_mean).pow(2).mean()
+
+    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
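+
+
+# Illustrative r1_penalty usage (comments only; the mean over channels and
+# pixels below is a toy stand-in for a discriminator, not part of this file):
+#
+#   >>> real_img = torch.randn(4, 3, 16, 16, requires_grad=True)
+#   >>> real_pred = real_img.mean(dim=(1, 2, 3))
+#   >>> r1_penalty(real_pred, real_img).dim()
+#   0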
+
+
+def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
+    """Calculate gradient penalty for wgan-gp.
+
+    Args:
+        discriminator (nn.Module): Network for the discriminator.
+        real_data (Tensor): Real input data.
+        fake_data (Tensor): Fake input data.
+        weight (Tensor): Weight tensor. Default: None.
+
+    Returns:
+        Tensor: A tensor for gradient penalty.
+    """
+
+    batch_size = real_data.size(0)
+    alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
+
+    # interpolate between real_data and fake_data
+    interpolates = alpha * real_data + (1. - alpha) * fake_data
+    # torch.autograd.Variable is deprecated; request gradients directly
+    interpolates = interpolates.detach().requires_grad_(True)
+
+    disc_interpolates = discriminator(interpolates)
+    gradients = autograd.grad(
+        outputs=disc_interpolates,
+        inputs=interpolates,
+        grad_outputs=torch.ones_like(disc_interpolates),
+        create_graph=True,
+        retain_graph=True,
+        only_inputs=True)[0]
+
+    if weight is not None:
+        gradients = gradients * weight
+
+    gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
+    if weight is not None:
+        gradients_penalty /= torch.mean(weight)
+
+    return gradients_penalty
diff --git a/scripts/basicsr/metrics/__init__.py b/scripts/basicsr/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e59370f599457a95357b2e8ad08b8052bfda98e7
--- /dev/null
+++ b/scripts/basicsr/metrics/__init__.py
@@ -0,0 +1,19 @@
+from copy import deepcopy
+
+from scripts.basicsr.utils.registry import METRIC_REGISTRY
+from .psnr_ssim import calculate_psnr, calculate_ssim
+
+__all__ = ['calculate_psnr', 'calculate_ssim']
+
+
+def calculate_metric(data, opt):
+    """Calculate metric from data and options.
+
+    Args:
+        opt (dict): Configuration. It must contain:
+            type (str): Model type.
+    """
+    opt = deepcopy(opt)
+    metric_type = opt.pop('type')
+    metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
+    return metric
diff --git a/scripts/basicsr/metrics/__pycache__/__init__.cpython-310.pyc b/scripts/basicsr/metrics/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..146691a555c28deb4639790f7699abef6712feba
Binary files /dev/null and b/scripts/basicsr/metrics/__pycache__/__init__.cpython-310.pyc differ
diff --git a/scripts/basicsr/metrics/__pycache__/metric_util.cpython-310.pyc b/scripts/basicsr/metrics/__pycache__/metric_util.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8209a5bce45ce753718483f01778df60c34f712b
Binary files /dev/null and b/scripts/basicsr/metrics/__pycache__/metric_util.cpython-310.pyc differ
diff --git a/scripts/basicsr/metrics/__pycache__/psnr_ssim.cpython-310.pyc b/scripts/basicsr/metrics/__pycache__/psnr_ssim.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de4a586a8800e8b32af318553e1bb36cf25b3a47
Binary files /dev/null and b/scripts/basicsr/metrics/__pycache__/psnr_ssim.cpython-310.pyc differ
diff --git a/scripts/basicsr/metrics/metric_util.py b/scripts/basicsr/metrics/metric_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebe6c079c8b2c135e4a94b9746a7e3229981b57f
--- /dev/null
+++ b/scripts/basicsr/metrics/metric_util.py
@@ -0,0 +1,45 @@
+import numpy as np
+
+from scripts.basicsr.utils.matlab_functions import bgr2ycbcr
+
+
+def reorder_image(img, input_order='HWC'):
+    """Reorder images to 'HWC' order.
+
+    If the input_order is (h, w), return (h, w, 1);
+    If the input_order is (c, h, w), return (h, w, c);
+    If the input_order is (h, w, c), return as it is.
+
+    Args:
+        img (ndarray): Input image.
+        input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            If the input image shape is (h, w), input_order will not have
+            effects. Default: 'HWC'.
+
+    Returns:
+        ndarray: reordered image.
+    """
+
+    if input_order not in ['HWC', 'CHW']:
+        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
+    if len(img.shape) == 2:
+        img = img[..., None]
+    if input_order == 'CHW':
+        img = img.transpose(1, 2, 0)
+    return img
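+
+
+# Illustrative reorder_image behavior (comments only):
+#
+#   >>> img_chw = np.zeros((3, 16, 24))
+#   >>> reorder_image(img_chw, input_order='CHW').shape
+#   (16, 24, 3)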
+
+
+def to_y_channel(img):
+    """Change to Y channel of YCbCr.
+
+    Args:
+        img (ndarray): Images with range [0, 255].
+
+    Returns:
+        (ndarray): Images with range [0, 255] (float type) without round.
+    """
+    img = img.astype(np.float32) / 255.
+    if img.ndim == 3 and img.shape[2] == 3:
+        img = bgr2ycbcr(img, y_only=True)
+        img = img[..., None]
+    return img * 255.
diff --git a/scripts/basicsr/metrics/psnr_ssim.py b/scripts/basicsr/metrics/psnr_ssim.py
new file mode 100644
index 0000000000000000000000000000000000000000..9781cb82594574fe8d49bb7bde6b7a05b092a3e4
--- /dev/null
+++ b/scripts/basicsr/metrics/psnr_ssim.py
@@ -0,0 +1,128 @@
+import cv2
+import numpy as np
+
+from scripts.basicsr.metrics.metric_util import reorder_image, to_y_channel
+from scripts.basicsr.utils.registry import METRIC_REGISTRY
+
+
+@METRIC_REGISTRY.register()
+def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
+    """Calculate PSNR (Peak Signal-to-Noise Ratio).
+
+    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+    Args:
+        img1 (ndarray): Images with range [0, 255].
+        img2 (ndarray): Images with range [0, 255].
+        crop_border (int): Cropped pixels in each edge of an image. These
+            pixels are not involved in the PSNR calculation.
+        input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            Default: 'HWC'.
+        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+    Returns:
+        float: psnr result.
+    """
+
+    assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
+    if input_order not in ['HWC', 'CHW']:
+        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
+    img1 = reorder_image(img1, input_order=input_order)
+    img2 = reorder_image(img2, input_order=input_order)
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+
+    if crop_border != 0:
+        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
+        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+    if test_y_channel:
+        img1 = to_y_channel(img1)
+        img2 = to_y_channel(img2)
+
+    mse = np.mean((img1 - img2)**2)
+    if mse == 0:
+        return float('inf')
+    return 20. * np.log10(255. / np.sqrt(mse))
+
+
+def _ssim(img1, img2):
+    """Calculate SSIM (structural similarity) for one channel images.
+
+    It is called by func:`calculate_ssim`.
+
+    Args:
+        img1 (ndarray): Images with range [0, 255] with order 'HWC'.
+        img2 (ndarray): Images with range [0, 255] with order 'HWC'.
+
+    Returns:
+        float: ssim result.
+    """
+
+    C1 = (0.01 * 255)**2
+    C2 = (0.03 * 255)**2
+
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+    kernel = cv2.getGaussianKernel(11, 1.5)
+    window = np.outer(kernel, kernel.transpose())
+
+    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
+    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+    mu1_sq = mu1**2
+    mu2_sq = mu2**2
+    mu1_mu2 = mu1 * mu2
+    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+    return ssim_map.mean()
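+
+
+# Illustrative calculate_psnr sanity check (comments only): identical inputs
+# have zero MSE, so the function returns infinity.
+#
+#   >>> img = np.random.randint(0, 256, (32, 32, 3)).astype(np.float64)
+#   >>> calculate_psnr(img, img, crop_border=0)
+#   inf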
+
+
+@METRIC_REGISTRY.register()
+def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
+    """Calculate SSIM (structural similarity).
+
+    Ref:
+    Image quality assessment: From error visibility to structural similarity
+
+    The results are the same as that of the official released MATLAB code in
+    https://ece.uwaterloo.ca/~z70wang/research/ssim/.
+
+    For three-channel images, SSIM is calculated for each channel and then
+    averaged.
+
+    Args:
+        img1 (ndarray): Images with range [0, 255].
+        img2 (ndarray): Images with range [0, 255].
+        crop_border (int): Cropped pixels in each edge of an image. These
+            pixels are not involved in the SSIM calculation.
+        input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            Default: 'HWC'.
+        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+    Returns:
+        float: ssim result.
+    """
+
+    assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
+    if input_order not in ['HWC', 'CHW']:
+        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
+    img1 = reorder_image(img1, input_order=input_order)
+    img2 = reorder_image(img2, input_order=input_order)
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+
+    if crop_border != 0:
+        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
+        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+    if test_y_channel:
+        img1 = to_y_channel(img1)
+        img2 = to_y_channel(img2)
+
+    ssims = []
+    for i in range(img1.shape[2]):
+        ssims.append(_ssim(img1[..., i], img2[..., i]))
+    return np.array(ssims).mean()
diff --git a/scripts/basicsr/models/__init__.py b/scripts/basicsr/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b65346d6805516aee3b4d0ddd37278f3af7fd8ac
--- /dev/null
+++ b/scripts/basicsr/models/__init__.py
@@ -0,0 +1,30 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from scripts.basicsr.utils import get_root_logger, scandir
+from scripts.basicsr.utils.registry import MODEL_REGISTRY
+
+__all__ = ['build_model']
+
+# automatically scan and import model modules for registry
+# scan all the files under the 'models' folder and collect files ending with
+# '_model.py'
+model_folder = osp.dirname(osp.abspath(__file__))
+model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
+# import all the model modules
+_model_modules = [importlib.import_module(f'scripts.basicsr.models.{file_name}') for file_name in model_filenames]
+
+
+def build_model(opt):
+    """Build model from options.
+
+    Args:
+        opt (dict): Configuration. It must contain:
+            model_type (str): Model type.
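+
+    Example (illustrative; 'SRModel' stands in for any registered model
+    class, and real configs carry many more keys):
+        >>> opt = {'model_type': 'SRModel', 'is_train': False}
+        >>> model = build_model(opt)  # doctest: +SKIP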
+    """
+    opt = deepcopy(opt)
+    model = MODEL_REGISTRY.get(opt['model_type'])(opt)
+    logger = get_root_logger()
+    logger.info(f'Model [{model.__class__.__name__}] is created.')
+    return model
diff --git a/scripts/basicsr/models/__pycache__/__init__.cpython-310.pyc b/scripts/basicsr/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5b137ab7c14d25462cb09b85a1459cf71ad6deb
Binary files /dev/null and b/scripts/basicsr/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/scripts/basicsr/models/__pycache__/base_model.cpython-310.pyc b/scripts/basicsr/models/__pycache__/base_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..870a2915aaa89a9f7678013698bdad4e2b31b722
Binary files /dev/null and b/scripts/basicsr/models/__pycache__/base_model.cpython-310.pyc differ
diff --git a/scripts/basicsr/models/__pycache__/codeformer_idx_model.cpython-310.pyc b/scripts/basicsr/models/__pycache__/codeformer_idx_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..68692dc7c10f8ad50378a01b5c5f0d7ac30b2a4e
Binary files /dev/null and b/scripts/basicsr/models/__pycache__/codeformer_idx_model.cpython-310.pyc differ
diff --git a/scripts/basicsr/models/__pycache__/codeformer_joint_model.cpython-310.pyc b/scripts/basicsr/models/__pycache__/codeformer_joint_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b43535fc36898a58b964a0e14e08169764e43fd3
Binary files /dev/null and b/scripts/basicsr/models/__pycache__/codeformer_joint_model.cpython-310.pyc differ
diff --git a/scripts/basicsr/models/__pycache__/codeformer_model.cpython-310.pyc b/scripts/basicsr/models/__pycache__/codeformer_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce6703293cbb382015eed48070372a4ca20465f2
Binary files /dev/null and b/scripts/basicsr/models/__pycache__/codeformer_model.cpython-310.pyc differ
diff --git a/scripts/basicsr/models/__pycache__/lr_scheduler.cpython-310.pyc b/scripts/basicsr/models/__pycache__/lr_scheduler.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82443ad55c79a16f1da10e4fd4fa849e4988ac5f
Binary files /dev/null and b/scripts/basicsr/models/__pycache__/lr_scheduler.cpython-310.pyc differ
diff --git a/scripts/basicsr/models/__pycache__/sr_model.cpython-310.pyc b/scripts/basicsr/models/__pycache__/sr_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a893e92b40a2964813c47be8d27c43223f353422
Binary files /dev/null and b/scripts/basicsr/models/__pycache__/sr_model.cpython-310.pyc differ
diff --git a/scripts/basicsr/models/__pycache__/vqgan_model.cpython-310.pyc b/scripts/basicsr/models/__pycache__/vqgan_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..752e3c74a7ef678770dc65b5cf99a2581dcceb17
Binary files /dev/null and b/scripts/basicsr/models/__pycache__/vqgan_model.cpython-310.pyc differ
diff --git a/scripts/basicsr/models/base_model.py b/scripts/basicsr/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e20fba25ff92dd76983593f5fa565fe80eb6d028
--- /dev/null
+++ b/scripts/basicsr/models/base_model.py
@@ -0,0 +1,322 @@
+import logging
+import os
+import torch
+from collections import OrderedDict
+from copy import deepcopy
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+from scripts.basicsr.models import lr_scheduler
lr_scheduler
+from scripts.basicsr.utils.dist_util import master_only
+
+logger = logging.getLogger('basicsr')
+
+
+class BaseModel():
+    """Base model."""
+
+    def __init__(self, opt):
+        self.opt = opt
+        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+        self.is_train = opt['is_train']
+        self.schedulers = []
+        self.optimizers = []
+
+    def feed_data(self, data):
+        pass
+
+    def optimize_parameters(self):
+        pass
+
+    def get_current_visuals(self):
+        pass
+
+    def save(self, epoch, current_iter):
+        """Save networks and training state."""
+        pass
+
+    def validation(self, dataloader, current_iter, tb_logger, save_img=False):
+        """Validation function.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader): Validation dataloader.
+            current_iter (int): Current iteration.
+            tb_logger (tensorboard logger): Tensorboard logger.
+            save_img (bool): Whether to save images. Default: False.
+        """
+        if self.opt['dist']:
+            self.dist_validation(dataloader, current_iter, tb_logger, save_img)
+        else:
+            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+    def model_ema(self, decay=0.999):
+        net_g = self.get_bare_model(self.net_g)
+
+        net_g_params = dict(net_g.named_parameters())
+        net_g_ema_params = dict(self.net_g_ema.named_parameters())
+
+        for k in net_g_ema_params.keys():
+            net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)
+
+    def get_current_log(self):
+        return self.log_dict
+
+    def model_to_device(self, net):
+        """Model to device. It also wraps models with DistributedDataParallel
+        or DataParallel.
+
+        Args:
+            net (nn.Module)
+        """
+        net = net.to(self.device)
+        if self.opt['dist']:
+            find_unused_parameters = self.opt.get('find_unused_parameters', False)
+            net = DistributedDataParallel(
+                net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
+        elif self.opt['num_gpu'] > 1:
+            net = DataParallel(net)
+        return net
+
+    def get_optimizer(self, optim_type, params, lr, **kwargs):
+        if optim_type == 'Adam':
+            optimizer = torch.optim.Adam(params, lr, **kwargs)
+        else:
+            raise NotImplementedError(f'optimizer {optim_type} is not supported yet.')
+        return optimizer
+
+    def setup_schedulers(self):
+        """Set up schedulers."""
+        train_opt = self.opt['train']
+        scheduler_type = train_opt['scheduler'].pop('type')
+        if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
+            for optimizer in self.optimizers:
+                self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
+        elif scheduler_type == 'CosineAnnealingRestartLR':
+            for optimizer in self.optimizers:
+                self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
+        else:
+            raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
+
+    def get_bare_model(self, net):
+        """Get bare model, especially under wrapping with
+        DistributedDataParallel or DataParallel.
+        """
+        if isinstance(net, (DataParallel, DistributedDataParallel)):
+            net = net.module
+        return net
+
+    @master_only
+    def print_network(self, net):
+        """Print the str and parameter number of a network.
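+
+        Both the wrapper class name (when the network is wrapped) and the bare
+        network are reported, together with the total parameter count.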
+
+        Args:
+            net (nn.Module)
+        """
+        if isinstance(net, (DataParallel, DistributedDataParallel)):
+            net_cls_str = (f'{net.__class__.__name__} - ' f'{net.module.__class__.__name__}')
+        else:
+            net_cls_str = f'{net.__class__.__name__}'
+
+        net = self.get_bare_model(net)
+        net_str = str(net)
+        net_params = sum(map(lambda x: x.numel(), net.parameters()))
+
+        logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
+        logger.info(net_str)
+
+    def _set_lr(self, lr_groups_l):
+        """Set learning rate for warmup.
+
+        Args:
+            lr_groups_l (list): List for lr_groups, each for an optimizer.
+        """
+        for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
+            for param_group, lr in zip(optimizer.param_groups, lr_groups):
+                param_group['lr'] = lr
+
+    def _get_init_lr(self):
+        """Get the initial lr, which is set by the scheduler.
+        """
+        init_lr_groups_l = []
+        for optimizer in self.optimizers:
+            init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
+        return init_lr_groups_l
+
+    def update_learning_rate(self, current_iter, warmup_iter=-1):
+        """Update learning rate.
+
+        Args:
+            current_iter (int): Current iteration.
+            warmup_iter (int): Warmup iter numbers. -1 for no warmup.
+                Default: -1.
+        """
+        if current_iter > 1:
+            for scheduler in self.schedulers:
+                scheduler.step()
+        # set up warm-up learning rate
+        if current_iter < warmup_iter:
+            # get initial lr for each group
+            init_lr_g_l = self._get_init_lr()
+            # modify warming-up learning rates
+            # currently only support linearly warm up
+            warm_up_lr_l = []
+            for init_lr_g in init_lr_g_l:
+                warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
+            # set learning rate
+            self._set_lr(warm_up_lr_l)
+
+    def get_current_learning_rate(self):
+        return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
+
+    @master_only
+    def save_network(self, net, net_label, current_iter, param_key='params'):
+        """Save networks.
+
+        Args:
+            net (nn.Module | list[nn.Module]): Network(s) to be saved.
+            net_label (str): Network label.
+            current_iter (int): Current iter number.
+            param_key (str | list[str]): The parameter key(s) to save network.
+                Default: 'params'.
+        """
+        if current_iter == -1:
+            current_iter = 'latest'
+        save_filename = f'{net_label}_{current_iter}.pth'
+        save_path = os.path.join(self.opt['path']['models'], save_filename)
+
+        net = net if isinstance(net, list) else [net]
+        param_key = param_key if isinstance(param_key, list) else [param_key]
+        assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'
+
+        save_dict = {}
+        for net_, param_key_ in zip(net, param_key):
+            net_ = self.get_bare_model(net_)
+            state_dict = net_.state_dict()
+            for key, param in state_dict.items():
+                if key.startswith('module.'):  # remove unnecessary 'module.'
+                    key = key[7:]
+                state_dict[key] = param.cpu()
+            save_dict[param_key_] = state_dict
+
+        torch.save(save_dict, save_path)
+
+    def _print_different_keys_loading(self, crt_net, load_net, strict=True):
+        """Print keys with different name or different size when loading models.
+
+        1. Print keys with different names.
+        2. If strict=False, print the same key but with different tensor size.
+            It also ignores these keys with different sizes (they are not loaded).
+
+        Args:
+            crt_net (torch model): Current network.
+            load_net (dict): Loaded network.
+            strict (bool): Whether strictly loaded. Default: True.
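+
+        Note: with strict=False, a size-mismatched key such as 'conv1.weight'
+        (a hypothetical example) is renamed to 'conv1.weight.ignore' in
+        load_net, so a subsequent load_state_dict(..., strict=False) skips it
+        instead of raising.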
+ """ + crt_net = self.get_bare_model(crt_net) + crt_net = crt_net.state_dict() + crt_net_keys = set(crt_net.keys()) + load_net_keys = set(load_net.keys()) + + if crt_net_keys != load_net_keys: + logger.warning('Current net - loaded net:') + for v in sorted(list(crt_net_keys - load_net_keys)): + logger.warning(f' {v}') + logger.warning('Loaded net - current net:') + for v in sorted(list(load_net_keys - crt_net_keys)): + logger.warning(f' {v}') + + # check the size for the same keys + if not strict: + common_keys = crt_net_keys & load_net_keys + for k in common_keys: + if crt_net[k].size() != load_net[k].size(): + logger.warning(f'Size different, ignore [{k}]: crt_net: ' + f'{crt_net[k].shape}; load_net: {load_net[k].shape}') + load_net[k + '.ignore'] = load_net.pop(k) + + def load_network(self, net, load_path, strict=True, param_key='params'): + """Load network. + + Args: + load_path (str): The path of networks to be loaded. + net (nn.Module): Network. + strict (bool): Whether strictly loaded. + param_key (str): The parameter key of loaded network. If set to + None, use the root 'path'. + Default: 'params'. + """ + net = self.get_bare_model(net) + logger.info(f'Loading {net.__class__.__name__} model from {load_path}.') + load_net = torch.load(load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + logger.info('Loading: params_ema does not exist, use params.') + load_net = load_net[param_key] + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + self._print_different_keys_loading(net, load_net, strict) + net.load_state_dict(load_net, strict=strict) + + @master_only + def save_training_state(self, epoch, current_iter): + """Save training states during training, which will be used for + resuming. + + Args: + epoch (int): Current epoch. + current_iter (int): Current iteration. + """ + if current_iter != -1: + state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []} + for o in self.optimizers: + state['optimizers'].append(o.state_dict()) + for s in self.schedulers: + state['schedulers'].append(s.state_dict()) + save_filename = f'{current_iter}.state' + save_path = os.path.join(self.opt['path']['training_states'], save_filename) + torch.save(state, save_path) + + def resume_training(self, resume_state): + """Reload the optimizers and schedulers for resumed training. + + Args: + resume_state (dict): Resume state. + """ + resume_optimizers = resume_state['optimizers'] + resume_schedulers = resume_state['schedulers'] + assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' + assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' + for i, o in enumerate(resume_optimizers): + self.optimizers[i].load_state_dict(o) + for i, s in enumerate(resume_schedulers): + self.schedulers[i].load_state_dict(s) + + def reduce_loss_dict(self, loss_dict): + """reduce loss dict. + + In distributed training, it averages the losses among different GPUs . + + Args: + loss_dict (OrderedDict): Loss dict. 
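+
+        Returns:
+            OrderedDict: Loss dict whose values are reduced to Python floats
+                (averaged across GPUs on rank 0 under distributed training).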
+        """
+        with torch.no_grad():
+            if self.opt['dist']:
+                keys = []
+                losses = []
+                for name, value in loss_dict.items():
+                    keys.append(name)
+                    losses.append(value)
+                losses = torch.stack(losses, 0)
+                torch.distributed.reduce(losses, dst=0)
+                if self.opt['rank'] == 0:
+                    losses /= self.opt['world_size']
+                loss_dict = {key: loss for key, loss in zip(keys, losses)}
+
+            log_dict = OrderedDict()
+            for name, value in loss_dict.items():
+                log_dict[name] = value.mean().item()
+
+            return log_dict
diff --git a/scripts/basicsr/models/codeformer_idx_model.py b/scripts/basicsr/models/codeformer_idx_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..63a65fa750135aed2692c643f3a4834cfd0b6db6
--- /dev/null
+++ b/scripts/basicsr/models/codeformer_idx_model.py
@@ -0,0 +1,220 @@
+import torch
+from collections import OrderedDict
+from os import path as osp
+from tqdm import tqdm
+
+from scripts.basicsr.archs import build_network
+from scripts.basicsr.metrics import calculate_metric
+from scripts.basicsr.utils import get_root_logger, imwrite, tensor2img
+from scripts.basicsr.utils.registry import MODEL_REGISTRY
+import torch.nn.functional as F
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class CodeFormerIdxModel(SRModel):
+    def feed_data(self, data):
+        self.gt = data['gt'].to(self.device)
+        self.input = data['in'].to(self.device)
+        self.b = self.gt.shape[0]
+
+        if 'latent_gt' in data:
+            self.idx_gt = data['latent_gt'].to(self.device)
+            self.idx_gt = self.idx_gt.view(self.b, -1)
+        else:
+            self.idx_gt = None
+
+    def init_training_settings(self):
+        logger = get_root_logger()
+        train_opt = self.opt['train']
+
+        self.ema_decay = train_opt.get('ema_decay', 0)
+        if self.ema_decay > 0:
+            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+            # define network net_g with Exponential Moving Average (EMA)
+            # net_g_ema is used only for testing on one GPU and saving
+            # There is no need to wrap with DistributedDataParallel
+            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+            # load pretrained model
+            load_path = self.opt['path'].get('pretrain_network_g', None)
+            if load_path is not None:
+                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+            else:
+                self.model_ema(0)  # copy net_g weight
+            self.net_g_ema.eval()
+
+        if self.opt['datasets']['train'].get('latent_gt_path', None) is not None:
+            self.generate_idx_gt = False
+        elif self.opt.get('network_vqgan', None) is not None:
+            self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
+            self.hq_vqgan_fix.eval()
+            self.generate_idx_gt = True
+            for param in self.hq_vqgan_fix.parameters():
+                param.requires_grad = False
+        else:
+            raise NotImplementedError('Should have network_vqgan config or pre-calculated latent code.')
+
+        logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}')
+
+        self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
+        self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
+        self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
+        self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
+
+        self.net_g.train()
+
+        # set up optimizers and schedulers
+        self.setup_optimizers()
+        self.setup_schedulers()
+
+
+    def setup_optimizers(self):
+        train_opt = self.opt['train']
+        # optimizer g
+        optim_params_g = []
+        for k, v in self.net_g.named_parameters():
+            if v.requires_grad:
+                optim_params_g.append(v)
+            else:
+                logger = get_root_logger()
+                logger.warning(f'Params {k} will not be optimized.')
+        optim_type = train_opt['optim_g'].pop('type')
+        self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
+        self.optimizers.append(self.optimizer_g)
+
+
+    def optimize_parameters(self, current_iter):
+        logger = get_root_logger()
+        # optimize net_g
+        self.optimizer_g.zero_grad()
+
+        if self.generate_idx_gt:
+            x = self.hq_vqgan_fix.encoder(self.gt)
+            _, _, quant_stats = self.hq_vqgan_fix.quantize(x)
+            min_encoding_indices = quant_stats['min_encoding_indices']
+            self.idx_gt = min_encoding_indices.view(self.b, -1)
+
+        if self.hq_feat_loss:
+            # quant_feats
+            quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b, 16, 16, 256])
+
+        logits, lq_feat = self.net_g(self.input, w=0, code_only=True)
+
+        l_g_total = 0
+        loss_dict = OrderedDict()
+        # hq_feat_loss
+        if self.hq_feat_loss:  # codebook loss
+            l_feat_encoder = torch.mean((quant_feat_gt.detach() - lq_feat)**2) * self.feat_loss_weight
+            l_g_total += l_feat_encoder
+            loss_dict['l_feat_encoder'] = l_feat_encoder
+
+        # cross_entropy_loss
+        if self.cross_entropy_loss:
+            # b(hw)n -> bn(hw)
+            cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight
+            l_g_total += cross_entropy_loss
+            loss_dict['cross_entropy_loss'] = cross_entropy_loss
+
+        l_g_total.backward()
+        self.optimizer_g.step()
+
+        if self.ema_decay > 0:
+            self.model_ema(decay=self.ema_decay)
+
+        self.log_dict = self.reduce_loss_dict(loss_dict)
+
+
+    def test(self):
+        with torch.no_grad():
+            if hasattr(self, 'net_g_ema'):
+                self.net_g_ema.eval()
+                self.output, _, _ = self.net_g_ema(self.input, w=0)
+            else:
+                logger = get_root_logger()
+                logger.warning('Do not have self.net_g_ema, use self.net_g.')
+                self.net_g.eval()
+                self.output, _, _ = self.net_g(self.input, w=0)
+                self.net_g.train()
+
+
+    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+        if self.opt['rank'] == 0:
+            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+
+    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+        dataset_name = dataloader.dataset.opt['name']
+        with_metrics = self.opt['val'].get('metrics') is not None
+        if with_metrics:
+            self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+        pbar = tqdm(total=len(dataloader), unit='image')
+
+        for idx, val_data in enumerate(dataloader):
+            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
+            self.feed_data(val_data)
+            self.test()
+
+            visuals = self.get_current_visuals()
+            sr_img = tensor2img([visuals['result']])
+            if 'gt' in visuals:
+                gt_img = tensor2img([visuals['gt']])
+                del self.gt
+
+            # tentative for out of GPU memory
+            del self.input  # this model stores its input in self.input (there is no self.lq)
+            del self.output
+            torch.cuda.empty_cache()
+
+            if save_img:
+                if self.opt['is_train']:
+                    save_img_path = osp.join(self.opt['path']['visualization'], img_name,
+                                             f'{img_name}_{current_iter}.png')
+                else:
+                    if self.opt['val']['suffix']:
+                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+                                                 f'{img_name}_{self.opt["val"]["suffix"]}.png')
+                    else:
+                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+                                                 f'{img_name}_{self.opt["name"]}.png')
+                imwrite(sr_img, save_img_path)
+
+            if with_metrics:
+                # calculate metrics
+                for name, opt_ in self.opt['val']['metrics'].items():
+                    metric_data = dict(img1=sr_img, img2=gt_img)
+                    self.metric_results[name] += calculate_metric(metric_data, opt_)
+            pbar.update(1)
+            pbar.set_description(f'Test {img_name}')
+        pbar.close()
+
+        if with_metrics:
+            for metric in self.metric_results.keys():
+                self.metric_results[metric] /= (idx + 1)
+
+            self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+
+    def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
+        log_str = f'Validation {dataset_name}\n'
+        for metric, value in self.metric_results.items():
+            log_str += f'\t # {metric}: {value:.4f}\n'
+        logger = get_root_logger()
+        logger.info(log_str)
+        if tb_logger:
+            for metric, value in self.metric_results.items():
+                tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
+
+
+    def get_current_visuals(self):
+        out_dict = OrderedDict()
+        out_dict['gt'] = self.gt.detach().cpu()
+        out_dict['result'] = self.output.detach().cpu()
+        return out_dict
+
+
+    def save(self, epoch, current_iter):
+        if self.ema_decay > 0:
+            self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+        else:
+            self.save_network(self.net_g, 'net_g', current_iter)
+        self.save_training_state(epoch, current_iter)
diff --git a/scripts/basicsr/models/codeformer_joint_model.py b/scripts/basicsr/models/codeformer_joint_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d2d3159a490e8ce438600ee1af6d6f4b3d67a6e
--- /dev/null
+++ b/scripts/basicsr/models/codeformer_joint_model.py
@@ -0,0 +1,350 @@
+import torch
+from collections import OrderedDict
+from os import path as osp
+from tqdm import tqdm
+
+
+from scripts.basicsr.archs import build_network
+from scripts.basicsr.losses import build_loss
+from scripts.basicsr.metrics import calculate_metric
+from scripts.basicsr.utils import get_root_logger, imwrite, tensor2img
+from scripts.basicsr.utils.registry import MODEL_REGISTRY
+import torch.nn.functional as F
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class CodeFormerJointModel(SRModel):
+    def feed_data(self, data):
+        self.gt = data['gt'].to(self.device)
+        self.input = data['in'].to(self.device)
+        self.input_large_de = data['in_large_de'].to(self.device)
+        self.b = self.gt.shape[0]
+
+        if 'latent_gt' in data:
+            self.idx_gt = data['latent_gt'].to(self.device)
+            self.idx_gt = self.idx_gt.view(self.b, -1)
+        else:
+            self.idx_gt = None
+
+    def init_training_settings(self):
+        logger = get_root_logger()
+        train_opt = self.opt['train']
+
+        self.ema_decay = train_opt.get('ema_decay', 0)
+        if self.ema_decay > 0:
+            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+            # define network net_g with Exponential Moving Average (EMA)
+            # net_g_ema is used only for testing on one GPU and saving
+            # There is no need to wrap with DistributedDataParallel
+            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+            # load pretrained model
+            load_path = self.opt['path'].get('pretrain_network_g', None)
+            if load_path is not None:
+                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+            else:
+                self.model_ema(0)  # copy net_g weight
+            self.net_g_ema.eval()
+
+        if self.opt['datasets']['train'].get('latent_gt_path', None) is not None:
+            self.generate_idx_gt = False
+        elif self.opt.get('network_vqgan', None) is not None:
+            self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
+            self.hq_vqgan_fix.eval()
+            self.generate_idx_gt = True
+            for param in self.hq_vqgan_fix.parameters():
+                param.requires_grad = False
+        else:
+            raise NotImplementedError('Should have network_vqgan config or pre-calculated latent code.')
+
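+        # When no pre-computed latent codes are supplied, the frozen HQ VQGAN
+        # built above is run on-the-fly in optimize_parameters() to produce the
+        # ground-truth codebook indices, at the cost of an extra forward pass.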
+ logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}') + + self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True) + self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0) + self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True) + self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5) + self.scale_adaptive_gan_weight = train_opt.get('scale_adaptive_gan_weight', 0.8) + + # define network net_d + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_d', None) + if load_path is not None: + self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True)) + + self.net_g.train() + self.net_d.train() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if train_opt.get('gan_opt'): + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + + + self.fix_generator = train_opt.get('fix_generator', True) + logger.info(f'fix_generator: {self.fix_generator}') + + self.net_g_start_iter = train_opt.get('net_g_start_iter', 0) + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_start_iter = train_opt.get('net_d_start_iter', 0) + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max): + recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + + d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach() + return d_weight + + def setup_optimizers(self): + train_opt = self.opt['train'] + # optimizer g + optim_params_g = [] + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + optim_params_g.append(v) + else: + logger = get_root_logger() + logger.warning(f'Params {k} will not be optimized.') + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + # optimizer d + optim_type = train_opt['optim_d'].pop('type') + self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) + self.optimizers.append(self.optimizer_d) + + def gray_resize_for_identity(self, out, size=128): + out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :]) + out_gray = out_gray.unsqueeze(1) + out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False) + return out_gray + + def optimize_parameters(self, current_iter): + logger = get_root_logger() + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + + if self.generate_idx_gt: + x = self.hq_vqgan_fix.encoder(self.gt) + output, _, quant_stats = self.hq_vqgan_fix.quantize(x) + min_encoding_indices = quant_stats['min_encoding_indices'] + self.idx_gt = min_encoding_indices.view(self.b, -1) + + if current_iter <= 40000: # small degradation + small_per_n = 1 + w = 1 + elif current_iter <= 80000: # small 
degradation + small_per_n = 1 + w = 1.3 + elif current_iter <= 120000: # large degradation + small_per_n = 120000 + w = 0 + else: # mixed degradation + small_per_n = 15 + w = 1.3 + + if current_iter % small_per_n == 0: + self.output, logits, lq_feat = self.net_g(self.input, w=w, detach_16=True) + large_de = False + else: + logits, lq_feat = self.net_g(self.input_large_de, code_only=True) + large_de = True + + if self.hq_feat_loss: + # quant_feats + quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256]) + + l_g_total = 0 + loss_dict = OrderedDict() + if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter: + # hq_feat_loss + if not 'transformer' in self.opt['network_g']['fix_modules']: + if self.hq_feat_loss: # codebook loss + l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight + l_g_total += l_feat_encoder + loss_dict['l_feat_encoder'] = l_feat_encoder + + # cross_entropy_loss + if self.cross_entropy_loss: + # b(hw)n -> bn(hw) + cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight + l_g_total += cross_entropy_loss + loss_dict['cross_entropy_loss'] = cross_entropy_loss + + # pixel loss + if not large_de: # when large degradation don't need image-level loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + + # perceptual loss + if self.cri_perceptual: + l_g_percep = self.cri_perceptual(self.output, self.gt) + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + + # gan loss + if current_iter > self.net_d_start_iter: + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + recon_loss = l_g_pix + l_g_percep + if not self.fix_generator: + last_layer = self.net_g.module.generator.blocks[-1].weight + d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0) + else: + largest_fuse_size = self.opt['network_g']['connect_list'][-1] + last_layer = self.net_g.module.fuse_convs_dict[largest_fuse_size].shift[-1].weight + d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0) + + d_weight *= self.scale_adaptive_gan_weight # 0.8 + loss_dict['d_weight'] = d_weight + l_g_total += d_weight * l_g_gan + loss_dict['l_g_gan'] = d_weight * l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + # optimize net_d + if not large_de: + if current_iter > self.net_d_start_iter: + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + real_d_pred = self.net_d(self.gt) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach()) + l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + l_d_fake.backward() + + self.optimizer_d.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + + def test(self): + with torch.no_grad(): + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + self.output, _, _ = self.net_g_ema(self.input, w=1) + else: + logger = get_root_logger() + logger.warning('Do not have self.net_g_ema, use self.net_g.') + self.net_g.eval() + self.output, _, _ 
= self.net_g(self.input, w=1) + self.net_g.train() + + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + if with_metrics: + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals() + sr_img = tensor2img([visuals['result']]) + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']]) + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + metric_data = dict(img1=sr_img, img2=gt_img) + self.metric_results[name] += calculate_metric(metric_data, opt_) + pbar.update(1) + pbar.set_description(f'Test {img_name}') + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}\n' + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) + + + def get_current_visuals(self): + out_dict = OrderedDict() + out_dict['gt'] = self.gt.detach().cpu() + out_dict['result'] = self.output.detach().cpu() + return out_dict + + + def save(self, epoch, current_iter): + if self.ema_decay > 0: + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_network(self.net_d, 'net_d', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/scripts/basicsr/models/codeformer_model.py b/scripts/basicsr/models/codeformer_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b83feebe34c62d57022f6c7d0f5447bcc478ebd1 --- /dev/null +++ b/scripts/basicsr/models/codeformer_model.py @@ -0,0 +1,332 @@ +import torch +from collections import OrderedDict +from os import path as osp +from tqdm import tqdm + +from scripts.basicsr.archs import build_network +from scripts.basicsr.losses import build_loss +from scripts.basicsr.metrics import calculate_metric +from scripts.basicsr.utils import get_root_logger, imwrite, tensor2img +from scripts.basicsr.utils.registry import MODEL_REGISTRY +import 
torch.nn.functional as F +from .sr_model import SRModel + + +@MODEL_REGISTRY.register() +class CodeFormerModel(SRModel): + def feed_data(self, data): + self.gt = data['gt'].to(self.device) + self.input = data['in'].to(self.device) + self.b = self.gt.shape[0] + + if 'latent_gt' in data: + self.idx_gt = data['latent_gt'].to(self.device) + self.idx_gt = self.idx_gt.view(self.b, -1) + else: + self.idx_gt = None + + def init_training_settings(self): + logger = get_root_logger() + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema is used only for testing on one GPU and saving + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + if self.opt.get('network_vqgan', None) is not None and self.opt['datasets'].get('latent_gt_path') is None: + self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device) + self.hq_vqgan_fix.eval() + self.generate_idx_gt = True + for param in self.hq_vqgan_fix.parameters(): + param.requires_grad = False + else: + self.generate_idx_gt = False + + self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True) + self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0) + self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True) + self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5) + self.fidelity_weight = train_opt.get('fidelity_weight', 1.0) + self.scale_adaptive_gan_weight = train_opt.get('scale_adaptive_gan_weight', 0.8) + + + self.net_g.train() + # define network net_d + if self.fidelity_weight > 0: + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_d', None) + if load_path is not None: + self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True)) + + self.net_d.train() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if train_opt.get('gan_opt'): + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + + + self.fix_generator = train_opt.get('fix_generator', True) + logger.info(f'fix_generator: {self.fix_generator}') + + self.net_g_start_iter = train_opt.get('net_g_start_iter', 0) + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_start_iter = train_opt.get('net_d_start_iter', 0) + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max): + recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + + d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4) + 
d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach() + return d_weight + + def setup_optimizers(self): + train_opt = self.opt['train'] + # optimizer g + optim_params_g = [] + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + optim_params_g.append(v) + else: + logger = get_root_logger() + logger.warning(f'Params {k} will not be optimized.') + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + # optimizer d + if self.fidelity_weight > 0: + optim_type = train_opt['optim_d'].pop('type') + self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) + self.optimizers.append(self.optimizer_d) + + def gray_resize_for_identity(self, out, size=128): + out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :]) + out_gray = out_gray.unsqueeze(1) + out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False) + return out_gray + + def optimize_parameters(self, current_iter): + logger = get_root_logger() + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + + if self.generate_idx_gt: + x = self.hq_vqgan_fix.encoder(self.gt) + output, _, quant_stats = self.hq_vqgan_fix.quantize(x) + min_encoding_indices = quant_stats['min_encoding_indices'] + self.idx_gt = min_encoding_indices.view(self.b, -1) + + if self.fidelity_weight > 0: + self.output, logits, lq_feat = self.net_g(self.input, w=self.fidelity_weight, detach_16=True) + else: + logits, lq_feat = self.net_g(self.input, w=0, code_only=True) + + if self.hq_feat_loss: + # quant_feats + quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256]) + + l_g_total = 0 + loss_dict = OrderedDict() + if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter: + # hq_feat_loss + if self.hq_feat_loss: # codebook loss + l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight + l_g_total += l_feat_encoder + loss_dict['l_feat_encoder'] = l_feat_encoder + + # cross_entropy_loss + if self.cross_entropy_loss: + # b(hw)n -> bn(hw) + cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight + l_g_total += cross_entropy_loss + loss_dict['cross_entropy_loss'] = cross_entropy_loss + + if self.fidelity_weight > 0: # when fidelity_weight == 0 don't need image-level loss + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + + # perceptual loss + if self.cri_perceptual: + l_g_percep = self.cri_perceptual(self.output, self.gt) + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + + # gan loss + if current_iter > self.net_d_start_iter: + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + recon_loss = l_g_pix + l_g_percep + if not self.fix_generator: + last_layer = self.net_g.module.generator.blocks[-1].weight + d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0) + else: + largest_fuse_size = self.opt['network_g']['connect_list'][-1] + last_layer = self.net_g.module.fuse_convs_dict[largest_fuse_size].shift[-1].weight + d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0) + + d_weight *= 
self.scale_adaptive_gan_weight # 0.8 + loss_dict['d_weight'] = d_weight + l_g_total += d_weight * l_g_gan + loss_dict['l_g_gan'] = d_weight * l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + # optimize net_d + if current_iter > self.net_d_start_iter and self.fidelity_weight > 0: + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + real_d_pred = self.net_d(self.gt) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach()) + l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + l_d_fake.backward() + + self.optimizer_d.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + + def test(self): + with torch.no_grad(): + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + self.output, _, _ = self.net_g_ema(self.input, w=self.fidelity_weight) + else: + logger = get_root_logger() + logger.warning('Do not have self.net_g_ema, use self.net_g.') + self.net_g.eval() + self.output, _, _ = self.net_g(self.input, w=self.fidelity_weight) + self.net_g.train() + + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + if with_metrics: + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals() + sr_img = tensor2img([visuals['result']]) + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']]) + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + metric_data = dict(img1=sr_img, img2=gt_img) + self.metric_results[name] += calculate_metric(metric_data, opt_) + pbar.update(1) + pbar.set_description(f'Test {img_name}') + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}\n' + logger = get_root_logger() + logger.info(log_str) + if 
tb_logger:
+            for metric, value in self.metric_results.items():
+                tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
+
+
+    def get_current_visuals(self):
+        out_dict = OrderedDict()
+        out_dict['gt'] = self.gt.detach().cpu()
+        out_dict['result'] = self.output.detach().cpu()
+        return out_dict
+
+
+    def save(self, epoch, current_iter):
+        if self.ema_decay > 0:
+            self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+        else:
+            self.save_network(self.net_g, 'net_g', current_iter)
+        if self.fidelity_weight > 0:
+            self.save_network(self.net_d, 'net_d', current_iter)
+        self.save_training_state(epoch, current_iter)
diff --git a/scripts/basicsr/models/lr_scheduler.py b/scripts/basicsr/models/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a423ce656c044ed5861a0056020074a517f04156
--- /dev/null
+++ b/scripts/basicsr/models/lr_scheduler.py
@@ -0,0 +1,96 @@
+import math
+from collections import Counter
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class MultiStepRestartLR(_LRScheduler):
+    """ MultiStep with restarts learning rate scheme.
+
+    Args:
+        optimizer (torch.nn.optimizer): Torch optimizer.
+        milestones (list): Iterations that will decrease learning rate.
+        gamma (float): Decrease ratio. Default: 0.1.
+        restarts (list): Restart iterations. Default: [0].
+        restart_weights (list): Restart weights at each restart iteration.
+            Default: [1].
+        last_epoch (int): Used in _LRScheduler. Default: -1.
+    """
+
+    def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
+        self.milestones = Counter(milestones)
+        self.gamma = gamma
+        self.restarts = restarts
+        self.restart_weights = restart_weights
+        assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
+        super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        if self.last_epoch in self.restarts:
+            weight = self.restart_weights[self.restarts.index(self.last_epoch)]
+            return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
+        if self.last_epoch not in self.milestones:
+            return [group['lr'] for group in self.optimizer.param_groups]
+        return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups]
+
+
+def get_position_from_periods(iteration, cumulative_period):
+    """Get the position from a period list.
+
+    It will return the index of the right-closest number in the period list.
+    For example, the cumulative_period = [100, 200, 300, 400],
+    if iteration == 50, return 0;
+    if iteration == 210, return 2;
+    if iteration == 300, return 2.
+
+    Args:
+        iteration (int): Current iteration.
+        cumulative_period (list[int]): Cumulative period list.
+
+    Returns:
+        int: The position of the right-closest number in the period list.
+    """
+    for i, period in enumerate(cumulative_period):
+        if iteration <= period:
+            return i
+
+
+class CosineAnnealingRestartLR(_LRScheduler):
+    """ Cosine annealing with restarts learning rate scheme.
+
+    An example of config:
+    periods = [10, 10, 10, 10]
+    restart_weights = [1, 0.5, 0.5, 0.5]
+    eta_min=1e-7
+
+    It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
+    scheduler will restart with the weights in restart_weights.
+
+    Args:
+        optimizer (torch.nn.optimizer): Torch optimizer.
+        periods (list): Period for each cosine annealing cycle.
+        restart_weights (list): Restart weights at each restart iteration.
+            Default: [1].
+        eta_min (float): The minimum lr. Default: 0.
+        last_epoch (int): Used in _LRScheduler. Default: -1.
+    """
+
+    def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
+        self.periods = periods
+        self.restart_weights = restart_weights
+        self.eta_min = eta_min
+        assert (len(self.periods) == len(
+            self.restart_weights)), 'periods and restart_weights should have the same length.'
+        self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
+        super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
+        current_weight = self.restart_weights[idx]
+        nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
+        current_period = self.periods[idx]
+
+        return [
+            self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
+            (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
+            for base_lr in self.base_lrs
+        ]
diff --git a/scripts/basicsr/models/sr_model.py b/scripts/basicsr/models/sr_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..31a2b92ec5b4fdf1cc829076d9666f0544737dfa
--- /dev/null
+++ b/scripts/basicsr/models/sr_model.py
@@ -0,0 +1,209 @@
+import torch
+from collections import OrderedDict
+from os import path as osp
+from tqdm import tqdm
+
+from scripts.basicsr.archs import build_network
+from scripts.basicsr.losses import build_loss
+from scripts.basicsr.metrics import calculate_metric
+from scripts.basicsr.utils import get_root_logger, imwrite, tensor2img
+from scripts.basicsr.utils.registry import MODEL_REGISTRY
+from .base_model import BaseModel
+
+@MODEL_REGISTRY.register()
+class SRModel(BaseModel):
+    """Base SR model for single image super-resolution."""
+
+    def __init__(self, opt):
+        super(SRModel, self).__init__(opt)
+
+        # define network
+        self.net_g = build_network(opt['network_g'])
+        self.net_g = self.model_to_device(self.net_g)
+        self.print_network(self.net_g)
+
+        # load pretrained models
+        load_path = self.opt['path'].get('pretrain_network_g', None)
+        if load_path is not None:
+            param_key = self.opt['path'].get('param_key_g', 'params')
+            self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
+
+        if self.is_train:
+            self.init_training_settings()
+
+    def init_training_settings(self):
+        self.net_g.train()
+        train_opt = self.opt['train']
+
+        self.ema_decay = train_opt.get('ema_decay', 0)
+        if self.ema_decay > 0:
+            logger = get_root_logger()
+            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+            # define network net_g with Exponential Moving Average (EMA)
+            # net_g_ema is used only for testing on one GPU and saving
+            # There is no need to wrap with DistributedDataParallel
+            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+            # load pretrained model
+            load_path = self.opt['path'].get('pretrain_network_g', None)
+            if load_path is not None:
+                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+            else:
+                self.model_ema(0)  # copy net_g weight
+            self.net_g_ema.eval()
+
+        # define losses
+        if train_opt.get('pixel_opt'):
+            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+        else:
+            self.cri_pix = None
+
+        if train_opt.get('perceptual_opt'):
+            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+        else:
+            self.cri_perceptual = None
+
+        if self.cri_pix is None and self.cri_perceptual is None:
+            raise ValueError('Both pixel and perceptual losses are None.')
+
+        # set up optimizers and schedulers
+        self.setup_optimizers()
+        self.setup_schedulers()
+
+    def setup_optimizers(self):
+        train_opt = self.opt['train']
+        optim_params = []
+        for k, v in self.net_g.named_parameters():
+            if v.requires_grad:
+                optim_params.append(v)
+            else:
+                logger = get_root_logger()
+                logger.warning(f'Params {k} will not be optimized.')
+
+        optim_type = train_opt['optim_g'].pop('type')
+        self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
+        self.optimizers.append(self.optimizer_g)
+
+    def feed_data(self, data):
+        self.lq = data['lq'].to(self.device)
+        if 'gt' in data:
+            self.gt = data['gt'].to(self.device)
+
+    def optimize_parameters(self, current_iter):
+        self.optimizer_g.zero_grad()
+        self.output = self.net_g(self.lq)
+
+        l_total = 0
+        loss_dict = OrderedDict()
+        # pixel loss
+        if self.cri_pix:
+            l_pix = self.cri_pix(self.output, self.gt)
+            l_total += l_pix
+            loss_dict['l_pix'] = l_pix
+        # perceptual loss
+        if self.cri_perceptual:
+            l_percep, l_style = self.cri_perceptual(self.output, self.gt)
+            if l_percep is not None:
+                l_total += l_percep
+                loss_dict['l_percep'] = l_percep
+            if l_style is not None:
+                l_total += l_style
+                loss_dict['l_style'] = l_style
+
+        l_total.backward()
+        self.optimizer_g.step()
+
+        self.log_dict = self.reduce_loss_dict(loss_dict)
+
+        if self.ema_decay > 0:
+            self.model_ema(decay=self.ema_decay)
+
+    def test(self):
+        # check for the EMA network itself: checking ema_decay alone would wrongly
+        # take this branch when ema_decay == 0 and net_g_ema was never built
+        if hasattr(self, 'net_g_ema'):
+            self.net_g_ema.eval()
+            with torch.no_grad():
+                self.output = self.net_g_ema(self.lq)
+        else:
+            self.net_g.eval()
+            with torch.no_grad():
+                self.output = self.net_g(self.lq)
+            self.net_g.train()
+
+    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+        if self.opt['rank'] == 0:
+            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+        dataset_name = dataloader.dataset.opt['name']
+        with_metrics = self.opt['val'].get('metrics') is not None
+        if with_metrics:
+            self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+        pbar = tqdm(total=len(dataloader), unit='image')
+
+        for idx, val_data in enumerate(dataloader):
+            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
+            self.feed_data(val_data)
+            self.test()
+
+            visuals = self.get_current_visuals()
+            sr_img = tensor2img([visuals['result']])
+            if 'gt' in visuals:
+                gt_img = tensor2img([visuals['gt']])
+                del self.gt
+
+            # tentative for out of GPU memory
+            del self.lq
+            del self.output
+            torch.cuda.empty_cache()
+
+            if save_img:
+                if self.opt['is_train']:
+                    save_img_path = osp.join(self.opt['path']['visualization'], img_name,
+                                             f'{img_name}_{current_iter}.png')
+                else:
+                    if self.opt['val']['suffix']:
+                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+                                                 f'{img_name}_{self.opt["val"]["suffix"]}.png')
+                    else:
+                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+                                                 f'{img_name}_{self.opt["name"]}.png')
+                imwrite(sr_img, save_img_path)
+
+            if with_metrics:
+                # calculate metrics
+                for name, opt_ in self.opt['val']['metrics'].items():
+                    metric_data = dict(img1=sr_img, img2=gt_img)
+                    self.metric_results[name] += calculate_metric(metric_data, opt_)
+            pbar.update(1)
+            pbar.set_description(f'Test {img_name}')
+        pbar.close()
+
+        if with_metrics:
+            for metric in self.metric_results.keys():
+                self.metric_results[metric] /= (idx + 1)
+
+            self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+    def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
+        log_str = f'Validation {dataset_name}\n'
+        for metric, value in self.metric_results.items():
+            log_str += f'\t # {metric}: {value:.4f}\n'
+        logger = get_root_logger()
+        logger.info(log_str)
+        if tb_logger:
+            for metric, value in self.metric_results.items():
+                tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
+
+    def get_current_visuals(self):
+        out_dict = OrderedDict()
+        out_dict['lq'] = self.lq.detach().cpu()
+        out_dict['result'] = self.output.detach().cpu()
+        if hasattr(self, 'gt'):
+            out_dict['gt'] = self.gt.detach().cpu()
+        return out_dict
+
+    def save(self, epoch, current_iter):
+        # same fix as in test(): only save the EMA copy when it actually exists
+        if hasattr(self, 'net_g_ema'):
+            self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+        else:
+            self.save_network(self.net_g, 'net_g', current_iter)
+        self.save_training_state(epoch, current_iter)
diff --git a/scripts/basicsr/models/vqgan_model.py b/scripts/basicsr/models/vqgan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..519d3ce2ff9d3be8fc56f43194ed3608cb7f08f2
--- /dev/null
+++ b/scripts/basicsr/models/vqgan_model.py
@@ -0,0 +1,285 @@
+import torch
+from collections import OrderedDict
+from os import path as osp
+from tqdm import tqdm
+
+from scripts.basicsr.archs import build_network
+from scripts.basicsr.losses import build_loss
+from scripts.basicsr.metrics import calculate_metric
+from scripts.basicsr.utils import get_root_logger, imwrite, tensor2img
+from scripts.basicsr.utils.registry import MODEL_REGISTRY
+import torch.nn.functional as F
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class VQGANModel(SRModel):
+    def feed_data(self, data):
+        self.gt = data['gt'].to(self.device)
+        self.b = self.gt.shape[0]
+
+
+    def init_training_settings(self):
+        logger = get_root_logger()
+        train_opt = self.opt['train']
+
+        self.ema_decay = train_opt.get('ema_decay', 0)
+        if self.ema_decay > 0:
+            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+            # define network net_g with Exponential Moving Average (EMA)
+            # net_g_ema is used only for testing on one GPU and saving
+            # There is no need to wrap with DistributedDataParallel
+            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+            # load pretrained model
+            load_path = self.opt['path'].get('pretrain_network_g', None)
+            if load_path is not None:
+                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+            else:
+                self.model_ema(0)  # copy net_g weight
+            self.net_g_ema.eval()
+
+        # define network net_d
+        self.net_d = build_network(self.opt['network_d'])
+        self.net_d = self.model_to_device(self.net_d)
+        self.print_network(self.net_d)
+
+        # load pretrained models
+        load_path = self.opt['path'].get('pretrain_network_d', None)
+        if load_path is not None:
+            self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
+
+        self.net_g.train()
+        self.net_d.train()
+
+        # define losses
+        if train_opt.get('pixel_opt'):
+            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+        else:
+            self.cri_pix = None
+
+        if train_opt.get('perceptual_opt'):
+            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+        else:
+            self.cri_perceptual = None
+
+        if train_opt.get('gan_opt'):
+            self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+
+        if train_opt.get('codebook_opt'):
+            self.l_weight_codebook = train_opt['codebook_opt'].get('loss_weight', 1.0)
+        else:
+            self.l_weight_codebook = 1.0
+
+        self.vqgan_quantizer = self.opt['network_g']['quantizer']
+        logger.info(f'vqgan_quantizer: {self.vqgan_quantizer}')
+
+        self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
+        self.net_d_iters = train_opt.get('net_d_iters', 1)
+        self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
+        self.disc_weight = train_opt.get('disc_weight', 0.8)
+
+        # set up optimizers and schedulers
+        self.setup_optimizers()
+        self.setup_schedulers()
+
+    def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
+        recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
+        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+
+        d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
+        d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
+        return d_weight
+
+    def adopt_weight(self, weight, global_step, threshold=0, value=0.):
+        if global_step < threshold:
+            weight = value
+        return weight
+
+    def setup_optimizers(self):
+        train_opt = self.opt['train']
+        # optimizer g
+        optim_params_g = []
+        for k, v in self.net_g.named_parameters():
+            if v.requires_grad:
+                optim_params_g.append(v)
+            else:
+                logger = get_root_logger()
+                logger.warning(f'Params {k} will not be optimized.')
+        optim_type = train_opt['optim_g'].pop('type')
+        self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
+        self.optimizers.append(self.optimizer_g)
+        # optimizer d
+        optim_type = train_opt['optim_d'].pop('type')
+        self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
+        self.optimizers.append(self.optimizer_d)
+
+
+    def optimize_parameters(self, current_iter):
+        logger = get_root_logger()
+        loss_dict = OrderedDict()
+        if self.opt['network_g']['quantizer'] == 'gumbel':
+            self.net_g.module.quantize.temperature = max(1/16, ((-1/160000) * current_iter) + 1)
+            if current_iter % 1000 == 0:
+                logger.info(f'temperature: {self.net_g.module.quantize.temperature}')
+
+        # optimize net_g
+        for p in self.net_d.parameters():
+            p.requires_grad = False
+
+        self.optimizer_g.zero_grad()
+        self.output, l_codebook, quant_stats = self.net_g(self.gt)
+
+        l_codebook = l_codebook * self.l_weight_codebook
+
+        l_g_total = 0
+        if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
+            # pixel loss
+            if self.cri_pix:
+                l_g_pix = self.cri_pix(self.output, self.gt)
+                l_g_total += l_g_pix
+                loss_dict['l_g_pix'] = l_g_pix
+            # perceptual loss
+            if self.cri_perceptual:
+                l_g_percep = self.cri_perceptual(self.output, self.gt)
+                l_g_total += l_g_percep
+                loss_dict['l_g_percep'] = l_g_percep
+
+            # gan loss
+            if current_iter > self.net_d_start_iter:
+                # fake_g_pred = self.net_d(self.output_1024)
+                fake_g_pred = self.net_d(self.output)
+                l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
+                recon_loss = l_g_total
+                last_layer = self.net_g.module.generator.blocks[-1].weight
+                d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
+                d_weight *= self.adopt_weight(1, current_iter, self.net_d_start_iter)
+                d_weight *= self.disc_weight  # taming setting 0.8
+                l_g_total += d_weight * l_g_gan
+                loss_dict['l_g_gan'] = d_weight * l_g_gan
+
+        l_g_total += l_codebook
+        loss_dict['l_codebook'] = l_codebook
+
+        l_g_total.backward()
+        self.optimizer_g.step()
+
+        # optimize net_d
+        if current_iter >
self.net_d_start_iter: + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + real_d_pred = self.net_d(self.gt) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach()) + l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + l_d_fake.backward() + self.optimizer_d.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + + def test(self): + with torch.no_grad(): + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + self.output, _, _ = self.net_g_ema(self.gt) + else: + logger = get_root_logger() + logger.warning('Do not have self.net_g_ema, use self.net_g.') + self.net_g.eval() + self.output, _, _ = self.net_g(self.gt) + self.net_g.train() + + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + if with_metrics: + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals() + sr_img = tensor2img([visuals['result']]) + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']]) + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + metric_data = dict(img1=sr_img, img2=gt_img) + self.metric_results[name] += calculate_metric(metric_data, opt_) + pbar.update(1) + pbar.set_description(f'Test {img_name}') + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}\n' + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) + + + def get_current_visuals(self): + out_dict = OrderedDict() + out_dict['gt'] = self.gt.detach().cpu() + out_dict['result'] = self.output.detach().cpu() + return out_dict + + def save(self, epoch, 
current_iter): + if self.ema_decay > 0: + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_network(self.net_d, 'net_d', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/scripts/basicsr/ops/__init__.py b/scripts/basicsr/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/basicsr/ops/__pycache__/__init__.cpython-310.pyc b/scripts/basicsr/ops/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..110be97fc55b7206e629162b710875da109aaf4b Binary files /dev/null and b/scripts/basicsr/ops/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/basicsr/ops/dcn/__init__.py b/scripts/basicsr/ops/dcn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..32e3592f896d61b4127e09d0476381b9d55e32ff --- /dev/null +++ b/scripts/basicsr/ops/dcn/__init__.py @@ -0,0 +1,7 @@ +from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv, + modulated_deform_conv) + +__all__ = [ + 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', + 'modulated_deform_conv' +] diff --git a/scripts/basicsr/ops/dcn/__pycache__/__init__.cpython-310.pyc b/scripts/basicsr/ops/dcn/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c18b3404b06ab17c86e0802215125ac24faabb7 Binary files /dev/null and b/scripts/basicsr/ops/dcn/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/basicsr/ops/dcn/__pycache__/deform_conv.cpython-310.pyc b/scripts/basicsr/ops/dcn/__pycache__/deform_conv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d01c180909f97701358f0b249b3a0557abc783fa Binary files /dev/null and b/scripts/basicsr/ops/dcn/__pycache__/deform_conv.cpython-310.pyc differ diff --git a/scripts/basicsr/ops/dcn/deform_conv.py b/scripts/basicsr/ops/dcn/deform_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..734154f9ed9447d585eae7df6886acb136f8a3cf --- /dev/null +++ b/scripts/basicsr/ops/dcn/deform_conv.py @@ -0,0 +1,377 @@ +import math +import torch +from torch import nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn import functional as F +from torch.nn.modules.utils import _pair, _single + +try: + from . 
import deform_conv_ext +except ImportError: + import os + BASICSR_JIT = os.getenv('BASICSR_JIT') + if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + deform_conv_ext = load( + 'deform_conv', + sources=[ + os.path.join(module_path, 'src', 'deform_conv_ext.cpp'), + os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'), + os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'), + ], + ) + + +class DeformConvFunction(Function): + + @staticmethod + def forward(ctx, + input, + offset, + weight, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + im2col_step=64): + if input is not None and input.dim() != 4: + raise ValueError(f'Expected 4D tensor as input, got {input.dim()}' 'D tensor instead.') + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.im2col_step = im2col_step + + ctx.save_for_backward(input, offset, weight) + + output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride)) + + ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones + + if not input.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' + deform_conv_ext.deform_conv_forward(input, weight, + offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], + ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, + ctx.deformable_groups, cur_im2col_step) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, weight = ctx.saved_tensors + + grad_input = grad_offset = grad_weight = None + + if not grad_output.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input, + grad_offset, weight, ctx.bufs_[0], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], + ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, + ctx.deformable_groups, cur_im2col_step) + + if ctx.needs_input_grad[2]: + grad_weight = torch.zeros_like(weight) + deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight, + ctx.bufs_[0], ctx.bufs_[1], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], + ctx.padding[1], ctx.padding[0], ctx.dilation[1], + ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1, + cur_im2col_step) + + return (grad_input, grad_offset, grad_weight, None, None, None, None, None) + + @staticmethod + def _output_size(input, weight, padding, dilation, stride): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = padding[d] + kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError('convolution input is too small (output would be ' f'{"x".join(map(str, 
output_size))})') + return output_size + + +class ModulatedDeformConvFunction(Function): + + @staticmethod + def forward(ctx, + input, + offset, + mask, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1): + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.with_bias = bias is not None + if not ctx.with_bias: + bias = input.new_empty(1) # fake tensor + if not input.is_cuda: + raise NotImplementedError + if weight.requires_grad or mask.requires_grad or offset.requires_grad \ + or input.requires_grad: + ctx.save_for_backward(input, offset, mask, weight, bias) + output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight)) + ctx._bufs = [input.new_empty(0), input.new_empty(0)] + deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output, + ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.groups, ctx.deformable_groups, ctx.with_bias) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + grad_mask = torch.zeros_like(mask) + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) + deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], + grad_input, grad_weight, grad_bias, grad_offset, grad_mask, + grad_output, weight.shape[2], weight.shape[3], ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.groups, ctx.deformable_groups, ctx.with_bias) + if not ctx.with_bias: + grad_bias = None + + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None) + + @staticmethod + def _infer_shape(ctx, input, weight): + n = input.size(0) + channels_out = weight.size(0) + height, width = input.shape[2:4] + kernel_h, kernel_w = weight.shape[2:4] + height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1 + width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1 + return n, channels_out, height_out, width_out + + +deform_conv = DeformConvFunction.apply +modulated_deform_conv = ModulatedDeformConvFunction.apply + + +class DeformConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=False): + super(DeformConv, self).__init__() + + assert not bias + assert in_channels % groups == 0, \ + f'in_channels {in_channels} is not divisible by groups {groups}' + assert out_channels % groups == 0, \ + f'out_channels {out_channels} is not divisible ' \ + f'by groups {groups}' + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deformable_groups = deformable_groups + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)) + + self.reset_parameters() + + def 
reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + + def forward(self, x, offset): + # To fix an assert error in deform_conv_cuda.cpp:128 + # input image is smaller than kernel + input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1]) + if input_pad: + pad_h = max(self.kernel_size[0] - x.size(2), 0) + pad_w = max(self.kernel_size[1] - x.size(3), 0) + x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, + self.deformable_groups) + if input_pad: + out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous() + return out + + +class DeformConvPack(DeformConv): + """A Deformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(DeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + offset = self.conv_offset(x) + return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, + self.deformable_groups) + + +class ModulatedDeformConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=True): + super(ModulatedDeformConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.deformable_groups = deformable_groups + self.with_bias = bias + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.init_weights() + + def init_weights(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + def forward(self, x, offset, mask): + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) + + +class ModulatedDeformConvPack(ModulatedDeformConv): + """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers. 
+
+    Args:
+        in_channels (int): Same as nn.Conv2d.
+        out_channels (int): Same as nn.Conv2d.
+        kernel_size (int or tuple[int]): Same as nn.Conv2d.
+        stride (int or tuple[int]): Same as nn.Conv2d.
+        padding (int or tuple[int]): Same as nn.Conv2d.
+        dilation (int or tuple[int]): Same as nn.Conv2d.
+        groups (int): Same as nn.Conv2d.
+        bias (bool or str): If specified as `auto`, it will be decided by the
+            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+            False.
+    """
+
+    _version = 2
+
+    def __init__(self, *args, **kwargs):
+        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
+
+        self.conv_offset = nn.Conv2d(
+            self.in_channels,
+            self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
+            kernel_size=self.kernel_size,
+            stride=_pair(self.stride),
+            padding=_pair(self.padding),
+            dilation=_pair(self.dilation),
+            bias=True)
+        self.init_weights()
+
+    def init_weights(self):
+        super(ModulatedDeformConvPack, self).init_weights()
+        if hasattr(self, 'conv_offset'):
+            self.conv_offset.weight.data.zero_()
+            self.conv_offset.bias.data.zero_()
+
+    def forward(self, x):
+        out = self.conv_offset(x)
+        o1, o2, mask = torch.chunk(out, 3, dim=1)
+        offset = torch.cat((o1, o2), dim=1)
+        mask = torch.sigmoid(mask)
+        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
+                                     self.groups, self.deformable_groups)
diff --git a/scripts/basicsr/ops/dcn/src/deform_conv_cuda.cpp b/scripts/basicsr/ops/dcn/src/deform_conv_cuda.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5d9424908ed2dbd4ac3cdb98d13e09287a4d2f2d
--- /dev/null
+++ b/scripts/basicsr/ops/dcn/src/deform_conv_cuda.cpp
@@ -0,0 +1,685 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include <torch/extension.h>
+#include <ATen/DivRtn.h>
+
+#include <cmath>
+#include <vector>
+
+void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
+                       const int channels, const int height, const int width,
+                       const int ksize_h, const int ksize_w, const int pad_h,
+                       const int pad_w, const int stride_h, const int stride_w,
+                       const int dilation_h, const int dilation_w,
+                       const int parallel_imgs, const int deformable_group,
+                       at::Tensor data_col);
+
+void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
+                       const int channels, const int height, const int width,
+                       const int ksize_h, const int ksize_w, const int pad_h,
+                       const int pad_w, const int stride_h, const int stride_w,
+                       const int dilation_h, const int dilation_w,
+                       const int parallel_imgs, const int deformable_group,
+                       at::Tensor grad_im);
+
+void deformable_col2im_coord(
+    const at::Tensor data_col, const at::Tensor data_im,
+    const at::Tensor data_offset, const int channels, const int height,
+    const int width, const int ksize_h, const int ksize_w, const int pad_h,
+    const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w, const int parallel_imgs,
+    const int deformable_group, at::Tensor grad_offset);
+
+void modulated_deformable_im2col_cuda(
+    const at::Tensor data_im, const at::Tensor data_offset,
+    const at::Tensor data_mask, const int batch_size, const int channels,
+    const int height_im, const int width_im, const int height_col,
+    const int width_col, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w, const int deformable_group,
+    at::Tensor data_col);
+
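+// Note: the declarations above follow the standard im2col + GEMM lowering of
+// (deformable) convolution: the im2col helpers gather bilinearly-sampled input
+// patches at the learned offsets into a column matrix, after which the
+// convolution reduces to a grouped matrix multiply against the flattened
+// weights. The small helper below is illustrative only (it is not referenced
+// by these kernels); every output size computed in this file uses this formula.
+static inline int conv_output_size(const int in_size, const int kernel,
+                                   const int pad, const int stride,
+                                   const int dilation) {
+  // e.g. in_size=64, kernel=3, pad=1, stride=1, dilation=1 -> 64
+  return (in_size + 2 * pad - (dilation * (kernel - 1) + 1)) / stride + 1;
+}
+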
+void modulated_deformable_col2im_cuda(
+    const at::Tensor data_col, const at::Tensor data_offset,
+    const at::Tensor data_mask, const int batch_size, const int channels,
+    const int height_im, const int width_im, const int height_col,
+    const int width_col, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w, const int deformable_group,
+    at::Tensor grad_im);
+
+void modulated_deformable_col2im_coord_cuda(
+    const at::Tensor data_col, const at::Tensor data_im,
+    const at::Tensor data_offset, const at::Tensor data_mask,
+    const int batch_size, const int channels, const int height_im,
+    const int width_im, const int height_col, const int width_col,
+    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w, const int dilation_h,
+    const int dilation_w, const int deformable_group, at::Tensor grad_offset,
+    at::Tensor grad_mask);
+
+void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
+                 at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
+                 int padW, int dilationH, int dilationW, int group,
+                 int deformable_group) {
+  TORCH_CHECK(weight.ndimension() == 4,
+           "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
+           "but got: %s",
+           weight.ndimension());
+
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+  TORCH_CHECK(kW > 0 && kH > 0,
+           "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
+           kW);
+
+  TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
+           "kernel size should be consistent with weight, ",
+           "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
+           kW, weight.size(2), weight.size(3));
+
+  TORCH_CHECK(dW > 0 && dH > 0,
+           "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
+
+  TORCH_CHECK(
+      dilationW > 0 && dilationH > 0,
+      "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
+      dilationH, dilationW);
+
+  int ndim = input.ndimension();
+  int dimf = 0;
+  int dimh = 1;
+  int dimw = 2;
+
+  if (ndim == 4) {
+    dimf++;
+    dimh++;
+    dimw++;
+  }
+
+  TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
+           ndim);
+
+  long nInputPlane = weight.size(1) * group;
+  long inputHeight = input.size(dimh);
+  long inputWidth = input.size(dimw);
+  long nOutputPlane = weight.size(0);
+  long outputHeight =
+      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+  long outputWidth =
+      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+
+  TORCH_CHECK(nInputPlane % deformable_group == 0,
+           "input channels must divide deformable group size");
+
+  if (outputWidth < 1 || outputHeight < 1)
+    AT_ERROR(
+        "Given input size: (%ld x %ld x %ld). "
+        "Calculated output size: (%ld x %ld x %ld). 
Output size is too small",
+        nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
+        outputWidth);
+
+  TORCH_CHECK(input.size(1) == nInputPlane,
+           "invalid number of input planes, expected: %d, but got: %d",
+           nInputPlane, input.size(1));
+
+  TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
+           "input image is smaller than kernel");
+
+  TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
+           "invalid spatial size of offset, expected height: %d width: %d, but "
+           "got height: %d width: %d",
+           outputHeight, outputWidth, offset.size(2), offset.size(3));
+
+  TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
+           "invalid number of channels of offset");
+
+  if (gradOutput != NULL) {
+    TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
+             "invalid number of gradOutput planes, expected: %d, but got: %d",
+             nOutputPlane, gradOutput->size(dimf));
+
+    TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
+              gradOutput->size(dimw) == outputWidth),
+             "invalid size of gradOutput, expected height: %d width: %d , but "
+             "got height: %d width: %d",
+             outputHeight, outputWidth, gradOutput->size(dimh),
+             gradOutput->size(dimw));
+  }
+}
+
+int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+                             at::Tensor offset, at::Tensor output,
+                             at::Tensor columns, at::Tensor ones, int kW,
+                             int kH, int dW, int dH, int padW, int padH,
+                             int dilationW, int dilationH, int group,
+                             int deformable_group, int im2col_step) {
+  // todo: resize columns to include im2col: done
+  // todo: add im2col_step as input
+  // todo: add new output buffer and transpose it to output (or directly
+  // transpose output) todo: possibly change data indexing because of
+  // parallel_imgs
+
+  shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
+              dilationH, dilationW, group, deformable_group);
+  at::DeviceGuard guard(input.device());
+
+  input = input.contiguous();
+  offset = offset.contiguous();
+  weight = weight.contiguous();
+
+  int batch = 1;
+  if (input.ndimension() == 3) {
+    // Force batch
+    batch = 0;
+    input.unsqueeze_(0);
+    offset.unsqueeze_(0);
+  }
+
+  // todo: assert batchsize divisible by im2col_step
+
+  long batchSize = input.size(0);
+  long nInputPlane = input.size(1);
+  long inputHeight = input.size(2);
+  long inputWidth = input.size(3);
+
+  long nOutputPlane = weight.size(0);
+
+  long outputWidth =
+      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+  long outputHeight =
+      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+  output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
+                        outputHeight, outputWidth});
+  columns = at::zeros(
+      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+      input.options());
+
+  if (ones.ndimension() != 2 ||
+      ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
+    ones = at::ones({outputHeight, outputWidth}, input.options());
+  }
+
+  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+                      inputHeight, inputWidth});
+  offset =
+      offset.view({batchSize / im2col_step, im2col_step,
+                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+  at::Tensor output_buffer =
+      at::zeros({batchSize / im2col_step, nOutputPlane,
+                 im2col_step * outputHeight, outputWidth},
+                output.options());
+
+  output_buffer = output_buffer.view(
+      {output_buffer.size(0), group, output_buffer.size(1) / group,
+       output_buffer.size(2), output_buffer.size(3)});
+
+  for (int elt = 0; elt 
< batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + output_buffer[elt][g] = output_buffer[elt][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output_buffer[elt][g]); + } + } + + output_buffer = output_buffer.view( + {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), + output_buffer.size(3), output_buffer.size(4)}); + + output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step, outputHeight, outputWidth}); + output_buffer.transpose_(1, 2); + output.copy_(output_buffer); + output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + output = output.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) { + shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view({1, input.size(0), input.size(1), input.size(2)}); + offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + // change order of grad output + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, 
+ outputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + // divide into groups + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), group, gradOutput.size(1) / group, + gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); + + for (int g = 0; g < group; g++) { + columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + gradOutput[elt][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), + gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); + + deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, + inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, + dilationH, dilationW, im2col_step, deformable_group, + gradOffset[elt]); + + deformable_col2im(columns, offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, gradInput[elt]); + } + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + gradOffset = + gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + // todo: transpose and reshape outGrad + // todo: reshape columns + // todo: add im2col_step as input + + shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH, + padW, dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view( + at::IntList({1, input.size(0), input.size(1), input.size(2)})); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = gradWeight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight 
= + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step, + outputHeight, outputWidth}); + gradOutputBuffer.copy_(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}); + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + // divide into group + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, + gradOutputBuffer.size(2), gradOutputBuffer.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + gradWeight = + gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3)}); + + for (int g = 0; g < group; g++) { + gradWeight[g] = gradWeight[g] + .flatten(1) + .addmm_(gradOutputBuffer[elt][g].flatten(1), + columns[g].transpose(1, 0), 1.0, scale) + .view_as(gradWeight[g]); + } + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), + gradOutputBuffer.size(1) * gradOutputBuffer.size(2), + gradOutputBuffer.size(3), gradOutputBuffer.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3), + gradWeight.size(4)}); + } + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + } + + return 1; +} + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int 
channels_kernel = weight.size(1);
+  const int kernel_h_ = weight.size(2);
+  const int kernel_w_ = weight.size(3);
+
+  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
+             kernel_h, kernel_w, kernel_h_, kernel_w_);
+  if (channels != channels_kernel * group)
+    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
+             channels, channels_kernel * group);
+
+  const int height_out =
+      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+  const int width_out =
+      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+  if (ones.ndimension() != 2 ||
+      ones.size(0) * ones.size(1) < height_out * width_out) {
+    // Resize plane and fill with ones...
+    ones = at::ones({height_out, width_out}, input.options());
+  }
+
+  // resize output
+  output = output.view({batch, channels_out, height_out, width_out}).zero_();
+  // resize temporary columns
+  columns =
+      at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
+                input.options());
+
+  output = output.view({output.size(0), group, output.size(1) / group,
+                        output.size(2), output.size(3)});
+
+  for (int b = 0; b < batch; b++) {
+    modulated_deformable_im2col_cuda(
+        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+        dilation_h, dilation_w, deformable_group, columns);
+
+    // divide into group
+    weight = weight.view({group, weight.size(0) / group, weight.size(1),
+                          weight.size(2), weight.size(3)});
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+
+    for (int g = 0; g < group; g++) {
+      output[b][g] = output[b][g]
+                         .flatten(1)
+                         .addmm_(weight[g].flatten(1), columns[g])
+                         .view_as(output[b][g]);
+    }
+
+    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+                          weight.size(3), weight.size(4)});
+    columns =
+        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+  }
+
+  output = output.view({output.size(0), output.size(1) * output.size(2),
+                        output.size(3), output.size(4)});
+
+  if (with_bias) {
+    output += bias.view({1, bias.size(0), 1, 1});
+  }
+}
+
+void modulated_deform_conv_cuda_backward(
+    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+    at::Tensor offset, at::Tensor mask, at::Tensor columns,
+    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+    const bool with_bias) {
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+  at::DeviceGuard guard(input.device());
+
+  const int batch = input.size(0);
+  const int channels = input.size(1);
+  const int height = input.size(2);
+  const int width = input.size(3);
+
+  const int channels_kernel = weight.size(1);
+  const int kernel_h_ = weight.size(2);
+  const int kernel_w_ = weight.size(3);
+  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
+             kernel_h, kernel_w, kernel_h_, kernel_w_);
+  if (channels != channels_kernel * group)
+    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
+             channels, channels_kernel * group);
+
+  const int height_out =
+      (height + 2 * pad_h - (dilation_h * 
(kernel_h - 1) + 1)) / stride_h + 1;
+  const int width_out =
+      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+  if (ones.ndimension() != 2 ||
+      ones.size(0) * ones.size(1) < height_out * width_out) {
+    // Resize plane and fill with ones...
+    ones = at::ones({height_out, width_out}, input.options());
+  }
+
+  grad_input = grad_input.view({batch, channels, height, width});
+  columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
+                      input.options());
+
+  grad_output =
+      grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
+                        grad_output.size(2), grad_output.size(3)});
+
+  for (int b = 0; b < batch; b++) {
+    // divide into group
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    weight = weight.view({group, weight.size(0) / group, weight.size(1),
+                          weight.size(2), weight.size(3)});
+
+    for (int g = 0; g < group; g++) {
+      columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+                        grad_output[b][g].flatten(1), 0.0f, 1.0f);
+    }
+
+    columns =
+        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+                          weight.size(3), weight.size(4)});
+
+    // gradient w.r.t. input coordinate data
+    modulated_deformable_col2im_coord_cuda(
+        columns, input[b], offset[b], mask[b], 1, channels, height, width,
+        height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
+        stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
+        grad_mask[b]);
+    // gradient w.r.t. input data
+    modulated_deformable_col2im_cuda(
+        columns, offset[b], mask[b], 1, channels, height, width, height_out,
+        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+        dilation_h, dilation_w, deformable_group, grad_input[b]);
+
+    // gradient w.r.t. 
weight, dWeight should accumulate across the batch and + // group + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + grad_weight = grad_weight.view({group, grad_weight.size(0) / group, + grad_weight.size(1), grad_weight.size(2), + grad_weight.size(3)}); + if (with_bias) + grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); + + for (int g = 0; g < group; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) + .view_as(grad_weight[g]); + if (with_bias) { + grad_bias[g] = + grad_bias[g] + .view({-1, 1}) + .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) + .view(-1); + } + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), grad_weight.size(3), + grad_weight.size(4)}); + if (with_bias) + grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); + } + grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), + grad_output.size(2), grad_output.size(3), + grad_output.size(4)}); +} diff --git a/scripts/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu b/scripts/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..98752dccf8c58817ca1a952554dd3f33188a2d34 --- /dev/null +++ b/scripts/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu @@ -0,0 +1,867 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * CONTRIBUTION AGREEMENT
+ *
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
+ * or otherwise, the contributor releases their content to the
+ * license and copyright terms herein.
+ *
+ ***************** END Caffe Copyright Notice and Disclaimer ********************
+ *
+ * Copyright (c) 2018 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file modulated_deformable_im2col.cuh
+ * \brief Function definitions of converting an image to
+ * column matrix based on kernel, padding, dilation, and offset.
+ * These functions are mainly used in deformable convolution operators.
+ * \ref: https://arxiv.org/abs/1703.06211
+ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
+ */
+
+// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <THC/THCAtomics.cuh>
+#include <stdio.h>
+#include <math.h>
+#include <float.h>
+
+using namespace at;
+
+#define CUDA_KERNEL_LOOP(i, n)                                 \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+const int kMaxGridNum = 65535;
+
+inline int GET_BLOCKS(const int N)
+{
+  return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
+}
+
+template <typename scalar_t>
+__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+                                               const int height, const int width, scalar_t h, scalar_t w)
+{
+
+  int h_low = floor(h);
+  int w_low = floor(w);
+  int h_high = h_low + 1;
+  int w_high = w_low + 1;
+
+  scalar_t lh = h - h_low;
+  scalar_t lw = w - w_low;
+  scalar_t hh = 1 - lh, hw = 1 - lw;
+
+  scalar_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0)
+    v1 = bottom_data[h_low * data_width + w_low];
+  scalar_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1)
+    v2 = bottom_data[h_low * data_width + w_high];
+  scalar_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0)
+    v3 = bottom_data[h_high * data_width + w_low];
+  scalar_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1)
+    v4 = bottom_data[h_high * data_width + w_high];
+
+  scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  return val;
+}
+
+template <typename scalar_t>
+__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                        const int h, const int w, const int height, const int width)
+{
+
+  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+  {
+    //empty
+    return 0;
+  }
+
+  int argmax_h_low = floor(argmax_h);
+  int argmax_w_low = floor(argmax_w);
+  int argmax_h_high = argmax_h_low + 1;
+  int argmax_w_high = argmax_w_low + 1;
+
+  scalar_t weight = 0;
+  if (h == argmax_h_low && w == argmax_w_low)
+    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+  if (h == argmax_h_low && w == argmax_w_high)
+    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+  if (h == argmax_h_high && w == argmax_w_low)
+    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+  if (h == argmax_h_high && w == argmax_w_high)
+    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+  return weight;
+}
+
+template <typename scalar_t>
+__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                          const int height, const int width, const scalar_t *im_data,
+                                          const int data_width, const int bp_dir)
+{
+
+  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+  {
+    //empty
+    return 0;
+  }
+
+  int argmax_h_low = floor(argmax_h);
+  int argmax_w_low = floor(argmax_w);
+  int argmax_h_high = argmax_h_low + 1;
+  int argmax_w_high = argmax_w_low + 1;
+
+  scalar_t weight = 0;
+
+  if (bp_dir == 0)
+  {
+    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+  }
+  else if (bp_dir == 1)
+  {
+    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+  }
+
+  return weight;
+}
+
+template <typename scalar_t>
+__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
+                                             const int height, const int width, const int kernel_h, const int kernel_w,
+                                             const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+                                             const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
+                                             const int batch_size, const int num_channels, const int deformable_group,
+                                             const int height_col, const int width_col,
+                                             scalar_t *data_col)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    // index index of output matrix
+    const int w_col = index % width_col;
+    const int h_col = (index / width_col) % height_col;
+    const int b_col = (index / width_col / height_col) % batch_size;
+    const int c_im = (index / width_col / height_col) / batch_size;
+    const int c_col = c_im * kernel_h * kernel_w;
+
+    // compute deformable group index
+    const int deformable_group_index = c_im / channel_per_deformable_group;
+
+    const int h_in = h_col * stride_h - pad_h;
+    const int w_in = w_col * stride_w - pad_w;
+    scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+    //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+    const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+    const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+    for (int i = 0; i < kernel_h; ++i)
+    {
+      for (int j = 0; j < kernel_w; ++j)
+      {
+        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+        const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+        const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+        scalar_t val = static_cast<scalar_t>(0);
+        const scalar_t h_im = h_in + i * dilation_h + offset_h;
+        const scalar_t w_im = w_in + j * dilation_w + offset_w;
+        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+        {
+          //const scalar_t map_h = i * dilation_h + offset_h;
+          //const scalar_t map_w = j * dilation_w + offset_w;
+          //const int cur_height = height - h_in;
+          //const int cur_width = width - w_in;
+          //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+          val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+        }
+        *data_col_ptr = val;
+        data_col_ptr += batch_size * height_col * width_col;
+      }
+    }
+  }
+}
+
+void deformable_im2col(
+    const at::Tensor data_im, const at::Tensor data_offset, const int channels,
+    const int height, const int width, const int ksize_h, const int ksize_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w, const int parallel_imgs,
+    const int deformable_group, at::Tensor data_col)
+{
+  // num_axes should be smaller than block size
+  // todo: check parallel_imgs is correctly passed in
+  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+  int num_kernels = channels * height_col * width_col * parallel_imgs;
+  int channel_per_deformable_group = channels / deformable_group;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
+        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+
+        deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+            num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
+            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+            channel_per_deformable_group, parallel_imgs, channels, deformable_group,
+            height_col, width_col, data_col_);
+      }));
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
+  }
+}
+
+template <typename scalar_t>
+__global__ void deformable_col2im_gpu_kernel(
+    const int n, const scalar_t *data_col, const scalar_t *data_offset,
+    const int channels, const int height, const int width,
+    const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int channel_per_deformable_group,
+    const int batch_size, const int deformable_group,
+    const int height_col, const int width_col,
+    scalar_t *grad_im)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    const int j = (index / width_col / height_col / batch_size) % kernel_w;
+    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+    // compute the start and end of the output
+
+    const int deformable_group_index = c / channel_per_deformable_group;
+
+    int w_out = index % width_col;
+    int h_out = (index / width_col) % height_col;
+    int b = (index / width_col / height_col) % batch_size;
+    int w_in = w_out * stride_w - pad_w;
+    int h_in = h_out * stride_h - pad_h;
+
+    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
+                                                        2 * kernel_h * kernel_w * height_col * width_col;
+    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+    const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+    const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+    const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+    const scalar_t cur_top_grad = data_col[index];
+    const int cur_h = (int)cur_inv_h_data;
+    const int cur_w = (int)cur_inv_w_data;
+    for (int dy = -2; dy <= 2; dy++)
+    {
+      for (int dx = -2; dx <= 2; dx++)
+      {
+        if (cur_h + dy >= 0 && cur_h + dy < height &&
+            cur_w + dx >= 0 && cur_w + dx < width &&
+            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+            abs(cur_inv_w_data - (cur_w + dx)) < 1)
+        {
+          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+          scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+        }
+      }
+    }
+  }
+}
+
+void deformable_col2im(
+    const at::Tensor data_col, const at::Tensor data_offset, const int channels,
+    const int height, const int width, const int ksize_h,
+    const int ksize_w, const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int parallel_imgs, const int deformable_group,
+    at::Tensor grad_im)
+{
+
+  // todo: make sure parallel_imgs is passed in correctly
+  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+  int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
+  int channel_per_deformable_group = channels / deformable_group;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
+        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
+
+        deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+            num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
+            ksize_w, pad_h, pad_w, stride_h, stride_w,
+            dilation_h, dilation_w, channel_per_deformable_group,
+            parallel_imgs, deformable_group, height_col, width_col, grad_im_);
+      }));
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
+  }
+}
+
+template <typename scalar_t>
+__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
+                                                   const scalar_t *data_im, const scalar_t *data_offset,
+                                                   const int channels, const int height, const int width,
+                                                   const int kernel_h, const int kernel_w,
+                                                   const int pad_h, const int pad_w,
+                                                   const int stride_h, const int stride_w,
+                                                   const int dilation_h, const int dilation_w,
+                                                   const int channel_per_deformable_group,
+                                                   const int batch_size, const int offset_channels, const int deformable_group,
+                                                   const int height_col, const int width_col, scalar_t *grad_offset)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    scalar_t val = 0;
+    int w = index % width_col;
+    int h = (index / width_col) % height_col;
+    int c = (index / width_col / height_col) % offset_channels;
+    int b = (index / width_col / height_col) / offset_channels;
+    // compute the start and end of the output
+
+    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+    const int col_step = kernel_h * kernel_w;
+    int cnt = 0;
+    const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
+                                                  batch_size * width_col * height_col;
+    const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
+                                                channel_per_deformable_group / kernel_h / kernel_w * height * width;
+    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
+                                                        kernel_h * kernel_w * height_col * width_col;
+
+    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+    {
+      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+      const int bp_dir = offset_c % 2;
+
+      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+      int w_out = col_pos % width_col;
+      int h_out = (col_pos / width_col) % height_col;
+      int w_in = w_out * stride_w - pad_w;
+      int h_in = h_out * stride_h - pad_h;
+      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+      const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+      const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+      scalar_t inv_h = h_in + i * dilation_h + offset_h;
+      scalar_t inv_w = w_in + j * dilation_w + offset_w;
+      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+      {
+        inv_h = inv_w = -2;
+      }
+      const scalar_t weight = get_coordinate_weight(
+          inv_h, inv_w,
+          height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+      val += weight * data_col_ptr[col_pos];
+      cnt += 1;
+    }
+
+    grad_offset[index] = val;
+  }
+}
+
+void deformable_col2im_coord(
+    const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
+    const int channels, const int height, const int width, const int ksize_h,
+    const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
+    const int stride_w, const int dilation_h, const int dilation_w,
+    const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
+{
+
+  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+  int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
+  int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
+        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
+
+        deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+            num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
+            ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
+            dilation_h, dilation_w, channel_per_deformable_group,
+            parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
+            height_col, width_col, grad_offset_);
+      }));
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+                                         const int height, const int width, scalar_t h, scalar_t w)
+{
+  int h_low = floor(h);
+  int w_low = floor(w);
+  int h_high = h_low + 1;
+  int w_high = w_low + 1;
+
+  scalar_t lh = h - h_low;
+  scalar_t lw = w - w_low;
+  scalar_t hh = 1 - lh, hw = 1 - lw;
+
+  scalar_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0)
+    v1 = bottom_data[h_low * data_width + w_low];
+  scalar_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1)
+    v2 = bottom_data[h_low * data_width + w_high];
+  scalar_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0)
+    v3 = bottom_data[h_high * data_width + w_low];
+  scalar_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1)
+    v4 = bottom_data[h_high * data_width + w_high];
+
+  scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  return val;
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                             const int h, const int w, const int height, const int width)
+{
+  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+  {
+    //empty
+    return 0;
+  }
+
+  int argmax_h_low = floor(argmax_h);
+  int argmax_w_low = floor(argmax_w);
+  int argmax_h_high = argmax_h_low + 1;
+  int argmax_w_high = argmax_w_low + 1;
+
+  scalar_t weight = 0;
+  if (h == argmax_h_low && w == argmax_w_low)
+    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+  if (h == argmax_h_low && w == argmax_w_high)
+    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+  if (h == argmax_h_high && w == argmax_w_low)
+    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+  if (h == argmax_h_high && w == argmax_w_high)
+    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+  return weight;
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+                                               const int height, const int width, const scalar_t *im_data,
+                                               const int data_width, const int bp_dir)
+{
+  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+  {
+    //empty
+    return 0;
+  }
+
+  int argmax_h_low = floor(argmax_h);
+  int argmax_w_low = floor(argmax_w);
+  int argmax_h_high = argmax_h_low + 1;
+  int argmax_w_high = argmax_w_low + 1;
+
+  scalar_t weight = 0;
+
+  if (bp_dir == 0)
+  {
+    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+  }
+  else if (bp_dir == 1)
+  {
+    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+      weight += 
(argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const float map_h = i * dilation_h + offset_h; + //const float map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +template +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + 
const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_offset, scalar_t *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * 
batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const scalar_t weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *data_col_ = 
data_col.data_ptr<scalar_t>();
+
+        modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+            num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w,
+            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
+            batch_size, channels, deformable_group, height_col, width_col, data_col_);
+      }));
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+  }
+}
+
+void modulated_deformable_col2im_cuda(
+    const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
+    const int batch_size, const int channels, const int height_im, const int width_im,
+    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int deformable_group, at::Tensor grad_im)
+{
+
+  const int channel_per_deformable_group = channels / deformable_group;
+  const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
+        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+        scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
+
+        modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+            num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
+            kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+            dilation_h, dilation_w, channel_per_deformable_group,
+            batch_size, deformable_group, height_col, width_col, grad_im_);
+      }));
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+  }
+}
+
+void modulated_deformable_col2im_coord_cuda(
+    const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+    const int batch_size, const int channels, const int height_im, const int width_im,
+    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int deformable_group,
+    at::Tensor grad_offset, at::Tensor grad_mask)
+{
+  const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
+  const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
+        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+        scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
+        scalar_t *grad_mask_ = grad_mask.data_ptr<scalar_t>();
+
+        modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+            num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
+            kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+            dilation_h, dilation_w, channel_per_deformable_group,
+            batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
+            grad_offset_, grad_mask_);
+      }));
+  cudaError_t err = cudaGetLastError();
+  if (err !=
cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + } +} diff --git a/scripts/basicsr/ops/dcn/src/deform_conv_ext.cpp b/scripts/basicsr/ops/dcn/src/deform_conv_ext.cpp new file mode 100644 index 0000000000000000000000000000000000000000..41c6df6f721bd95a525fd6a03dd9882e863de042 --- /dev/null +++ b/scripts/basicsr/ops/dcn/src/deform_conv_ext.cpp @@ -0,0 +1,164 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +#define WITH_CUDA // always use cuda +#ifdef WITH_CUDA +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step); + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias); + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias); +#endif + +int deform_conv_forward(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_forward_cuda(input, weight, offset, output, columns, + ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group, + deformable_group, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +int deform_conv_backward_input(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA 
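+  // Same dispatch pattern as deform_conv_forward above: route to the CUDA
+  // implementation when the input lives on the GPU; the CPU fallthrough
+  // below only raises, since no CPU kernels are provided.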
+ return deform_conv_backward_input_cuda(input, offset, gradOutput, + gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH, + dilationW, dilationH, group, deformable_group, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +int deform_conv_backward_parameters( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_parameters_cuda(input, offset, gradOutput, + gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW, + dilationH, group, deformable_group, scale, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_forward(input, weight, bias, ones, + offset, mask, output, columns, kernel_h, kernel_w, stride_h, + stride_w, pad_h, pad_w, dilation_h, dilation_w, group, + deformable_group, with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_backward(input, weight, bias, ones, + offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset, + grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, deformable_group, + with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("deform_conv_forward", &deform_conv_forward, + "deform forward"); + m.def("deform_conv_backward_input", &deform_conv_backward_input, + "deform_conv_backward_input"); + m.def("deform_conv_backward_parameters", + &deform_conv_backward_parameters, + "deform_conv_backward_parameters"); + m.def("modulated_deform_conv_forward", + &modulated_deform_conv_forward, + "modulated deform conv forward"); + m.def("modulated_deform_conv_backward", + &modulated_deform_conv_backward, + "modulated deform conv backward"); +} diff --git a/scripts/basicsr/ops/fused_act/__init__.py 
b/scripts/basicsr/ops/fused_act/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..241dc0754fae7d88dbbd9a02e665ca30a73c7422 --- /dev/null +++ b/scripts/basicsr/ops/fused_act/__init__.py @@ -0,0 +1,3 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu + +__all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] diff --git a/scripts/basicsr/ops/fused_act/fused_act.py b/scripts/basicsr/ops/fused_act/fused_act.py new file mode 100644 index 0000000000000000000000000000000000000000..588f815e596ab0fc83ab0f9d21426c22ec5ed7c3 --- /dev/null +++ b/scripts/basicsr/ops/fused_act/fused_act.py @@ -0,0 +1,89 @@ +# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 + +import torch +from torch import nn +from torch.autograd import Function + +try: + from . import fused_act_ext +except ImportError: + import os + BASICSR_JIT = os.getenv('BASICSR_JIT') + if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + fused_act_ext = load( + 'fused', + sources=[ + os.path.join(module_path, 'src', 'fused_bias_act.cpp'), + os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), + ], + ) + + +class FusedLeakyReLUFunctionBackward(Function): + + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, + ctx.scale) + + return gradgrad_out, None, None, None + + +class FusedLeakyReLUFunction(Function): + + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + + def __init__(self, channel, negative_slope=0.2, scale=2**0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/scripts/basicsr/ops/fused_act/src/fused_bias_act.cpp b/scripts/basicsr/ops/fused_act/src/fused_bias_act.cpp new file mode 100644 index 0000000000000000000000000000000000000000..85ed0a79fb9c75f83470ac834090f03608d998ee --- /dev/null +++ b/scripts/basicsr/ops/fused_act/src/fused_bias_act.cpp @@ -0,0 +1,26 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp +#include + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, + const 
torch::Tensor& bias,
+                                 const torch::Tensor& refer,
+                                 int act, int grad, float alpha, float scale);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor fused_bias_act(const torch::Tensor& input,
+                             const torch::Tensor& bias,
+                             const torch::Tensor& refer,
+                             int act, int grad, float alpha, float scale) {
+  CHECK_CUDA(input);
+  CHECK_CUDA(bias);
+
+  return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+}
diff --git a/scripts/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu b/scripts/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..54c7ff53ce8306db2b3c582ec7fa6696a38b4df0
--- /dev/null
+++ b/scripts/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
@@ -0,0 +1,100 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+template <typename scalar_t>
+static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
+    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
+    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
+
+    scalar_t zero = 0.0;
+
+    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
+        scalar_t x = p_x[xi];
+
+        if (use_bias) {
+            x += p_b[(xi / step_b) % size_b];
+        }
+
+        scalar_t ref = use_ref ? p_ref[xi] : zero;
+
+        scalar_t y;
+
+        // act * 10 + grad selects linear (act=1) or leaky ReLU (act=3) and its
+        // forward (grad=0), first (grad=1), or second (grad=2) derivative.
+        switch (act * 10 + grad) {
+            default:
+            case 10: y = x; break;
+            case 11: y = x; break;
+            case 12: y = 0.0; break;
+
+            case 30: y = (x > 0.0) ? x : x * alpha; break;
+            case 31: y = (ref > 0.0) ? x : x * alpha; break;
+            case 32: y = 0.0; break;
+        }
+
+        out[xi] = y * scale;
+    }
+}
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+    int act, int grad, float alpha, float scale) {
+    int curDevice = -1;
+    cudaGetDevice(&curDevice);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+    auto x = input.contiguous();
+    auto b = bias.contiguous();
+    auto ref = refer.contiguous();
+
+    int use_bias = b.numel() ? 1 : 0;
+    int use_ref = ref.numel() ?
1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<<>>( + y.data_ptr(), + x.data_ptr(), + b.data_ptr(), + ref.data_ptr(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} diff --git a/scripts/basicsr/ops/upfirdn2d/__init__.py b/scripts/basicsr/ops/upfirdn2d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..397e85bea063e97fc4c12ad4d3e15669b69290bd --- /dev/null +++ b/scripts/basicsr/ops/upfirdn2d/__init__.py @@ -0,0 +1,3 @@ +from .upfirdn2d import upfirdn2d + +__all__ = ['upfirdn2d'] diff --git a/scripts/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp b/scripts/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..43d0b6783a5b512b55815a291fcac2bebeea31e0 --- /dev/null +++ b/scripts/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp @@ -0,0 +1,24 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp +#include + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} diff --git a/scripts/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu b/scripts/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..8870063bae4468deab2e721f0978fe9facfb01b1 --- /dev/null +++ b/scripts/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu @@ -0,0 +1,370 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. 
+// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + +template +__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; + int out_y = minor_idx / p.minor_dim; + minor_idx -= out_y * p.minor_dim; + int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; + int major_idx_base = blockIdx.z * p.loop_major; + + if (out_x_base >= p.out_w || out_y >= p.out_h || + major_idx_base >= p.major_dim) { + return; + } + + int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; + int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); + int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; + int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major && major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, out_x = out_x_base; + loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { + int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; + int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); + int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; + int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; + + const scalar_t *x_p = + &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + + minor_idx]; + const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; + int x_px = p.minor_dim; + int k_px = -p.up_x; + int x_py = p.in_w * p.minor_dim; + int k_py = -p.up_y * p.kernel_w; + + scalar_t v = 0.0f; + + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + v += static_cast(*x_p) * static_cast(*k_p); + x_p += x_px; + k_p += k_px; + } + + x_p += x_py - w * x_px; + k_p += k_py - w * k_px; + } + + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } +} + +template +__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | + major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; + tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w 
+ (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major & major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; + loop_x < p.loop_x & tile_out_x < p.out_w; + loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; + in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * + p.minor_dim + + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; + out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + +#pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) +#pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * + sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } + } + } +} + +torch::Tensor upfirdn2d_op(const torch::Tensor &input, + const torch::Tensor &kernel, int up_x, int up_y, + int down_x, int down_y, int pad_x0, int pad_x1, + int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / + p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / + p.down_x; + + auto out = + at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h = -1; + int tile_out_w = -1; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + 
tile_out_h = 16;
+    tile_out_w = 64;
+  }
+
+  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+      p.kernel_h <= 2 && p.kernel_w <= 2) {
+    mode = 4;
+    tile_out_h = 16;
+    tile_out_w = 64;
+  }
+
+  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+      p.kernel_h <= 4 && p.kernel_w <= 4) {
+    mode = 5;
+    tile_out_h = 8;
+    tile_out_w = 32;
+  }
+
+  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+      p.kernel_h <= 2 && p.kernel_w <= 2) {
+    mode = 6;
+    tile_out_h = 8;
+    tile_out_w = 32;
+  }
+
+  dim3 block_size;
+  dim3 grid_size;
+
+  // Modes 1-6 map the common (up, down, kernel-size) combinations onto tiled
+  // kernels specialized at compile time; anything else falls back to the
+  // generic upfirdn2d_kernel_large path below.
+  if (tile_out_h > 0 && tile_out_w > 0) {
+    p.loop_major = (p.major_dim - 1) / 16384 + 1;
+    p.loop_x = 1;
+    block_size = dim3(32 * 8, 1, 1);
+    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
+                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
+                     (p.major_dim - 1) / p.loop_major + 1);
+  } else {
+    p.loop_major = (p.major_dim - 1) / 16384 + 1;
+    p.loop_x = 4;
+    block_size = dim3(4, 32, 1);
+    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
+                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
+                     (p.major_dim - 1) / p.loop_major + 1);
+  }
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
+    switch (mode) {
+    case 1:
+      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 2:
+      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 3:
+      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 4:
+      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 5:
+      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 6:
+      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    default:
+      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
+          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
+          k.data_ptr<scalar_t>(), p);
+    }
+  });
+
+  return out;
+}
diff --git a/scripts/basicsr/ops/upfirdn2d/upfirdn2d.py b/scripts/basicsr/ops/upfirdn2d/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..667f96e1ded35d48f163f37e21d1ed8ff191aac3
--- /dev/null
+++ b/scripts/basicsr/ops/upfirdn2d/upfirdn2d.py
@@ -0,0 +1,186 @@
+# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py  # noqa:E501
+
+import torch
+from torch.autograd import Function
+from torch.nn import functional as F
+
+try:
+    from .
import upfirdn2d_ext +except ImportError: + import os + BASICSR_JIT = os.getenv('BASICSR_JIT') + if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + upfirdn2d_ext = load( + 'upfirdn2d', + sources=[ + os.path.join(module_path, 'src', 'upfirdn2d.cpp'), + os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'), + ], + ) + + +class UpFirDn2dBackward(Function): + + @staticmethod + def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_ext.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_ext.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], + # ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + if input.device.type == 'cpu': + out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + else: + out = UpFirDn2d.apply(input, 
kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) + + return out + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/scripts/basicsr/setup.py b/scripts/basicsr/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..b24d0450a016f2f36d634c4f79516e8137639cb8 --- /dev/null +++ b/scripts/basicsr/setup.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python + +from setuptools import find_packages, setup + +import os +import subprocess +import sys +import time +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension +from utils.misc import gpu_is_available + +version_file = './basicsr/version.py' + + +def readme(): + with open('README.md', encoding='utf-8') as f: + content = f.read() + return content + + +def get_git_hash(): + + def _minimal_ext_cmd(cmd): + # construct minimal environment + env = {} + for k in ['SYSTEMROOT', 'PATH', 'HOME']: + v = os.environ.get(k) + if v is not None: + env[k] = v + # LANGUAGE is used on win32 + env['LANGUAGE'] = 'C' + env['LANG'] = 'C' + env['LC_ALL'] = 'C' + out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + return out + + try: + out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) + sha = out.strip().decode('ascii') + except OSError: + sha = 'unknown' + + return sha + + +def get_hash(): + if os.path.exists('.git'): + sha = get_git_hash()[:7] + elif os.path.exists(version_file): + try: + from version import __version__ + sha = __version__.split('+')[-1] + except ImportError: + raise ImportError('Unable to get git version') + else: + sha = 'unknown' + + return sha + + +def write_version_py(): + content = """# GENERATED VERSION FILE +# TIME: {} +__version__ = '{}' +__gitsha__ = '{}' +version_info = ({}) +""" + sha = get_hash() + with open('./basicsr/VERSION', 'r') as f: + SHORT_VERSION = f.read().strip() + VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) + + version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) + with open(version_file, 'w') as f: + f.write(version_file_str) + + +def get_version(): + with open(version_file, 'r') as f: + exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__'] + + +def make_cuda_ext(name, module, sources, sources_cuda=None): + if sources_cuda is None: + sources_cuda = [] + 
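+    # Defaulting to None (instead of a mutable list) avoids sharing a single
+    # list object across calls to make_cuda_ext.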
define_macros = []
+    extra_compile_args = {'cxx': []}
+
+    # if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
+    if gpu_is_available() or os.getenv('FORCE_CUDA', '0') == '1':
+        define_macros += [('WITH_CUDA', None)]
+        extension = CUDAExtension
+        extra_compile_args['nvcc'] = [
+            '-D__CUDA_NO_HALF_OPERATORS__',
+            '-D__CUDA_NO_HALF_CONVERSIONS__',
+            '-D__CUDA_NO_HALF2_OPERATORS__',
+        ]
+        sources += sources_cuda
+    else:
+        print(f'Compiling {name} without CUDA')
+        extension = CppExtension
+
+    return extension(
+        name=f'{module}.{name}',
+        sources=[os.path.join(*module.split('.'), p) for p in sources],
+        define_macros=define_macros,
+        extra_compile_args=extra_compile_args)
+
+
+def get_requirements(filename='requirements.txt'):
+    with open(os.path.join('.', filename), 'r') as f:
+        requires = [line.replace('\n', '') for line in f.readlines()]
+    return requires
+
+
+if __name__ == '__main__':
+    if '--cuda_ext' in sys.argv:
+        ext_modules = [
+            make_cuda_ext(
+                name='deform_conv_ext',
+                module='ops.dcn',
+                sources=['src/deform_conv_ext.cpp'],
+                sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
+            make_cuda_ext(
+                name='fused_act_ext',
+                module='ops.fused_act',
+                sources=['src/fused_bias_act.cpp'],
+                sources_cuda=['src/fused_bias_act_kernel.cu']),
+            make_cuda_ext(
+                name='upfirdn2d_ext',
+                module='ops.upfirdn2d',
+                sources=['src/upfirdn2d.cpp'],
+                sources_cuda=['src/upfirdn2d_kernel.cu']),
+        ]
+        sys.argv.remove('--cuda_ext')
+    else:
+        ext_modules = []
+
+    write_version_py()
+    setup(
+        name='basicsr',
+        version=get_version(),
+        description='Open Source Image and Video Super-Resolution Toolbox',
+        long_description=readme(),
+        long_description_content_type='text/markdown',
+        author='Xintao Wang',
+        author_email='xintao.wang@outlook.com',
+        keywords='computer vision, restoration, super resolution',
+        url='https://github.com/xinntao/BasicSR',
+        include_package_data=True,
+        packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
+        classifiers=[
+            'Development Status :: 4 - Beta',
+            'License :: OSI Approved :: Apache Software License',
+            'Operating System :: OS Independent',
+            'Programming Language :: Python :: 3',
+            'Programming Language :: Python :: 3.7',
+            'Programming Language :: Python :: 3.8',
+        ],
+        license='Apache License 2.0',
+        setup_requires=['cython', 'numpy'],
+        install_requires=get_requirements(),
+        ext_modules=ext_modules,
+        cmdclass={'build_ext': BuildExtension},
+        zip_safe=False)
diff --git a/scripts/basicsr/train.py b/scripts/basicsr/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb5e31b6249ef9c027c05a2f9d708205ffc091f2
--- /dev/null
+++ b/scripts/basicsr/train.py
@@ -0,0 +1,225 @@
+import argparse
+import datetime
+import logging
+import math
+import copy
+import random
+import time
+import torch
+from os import path as osp
+
+from scripts.basicsr.data import build_dataloader, build_dataset
+from scripts.basicsr.data.data_sampler import EnlargedSampler
+from scripts.basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
+from scripts.basicsr.models import build_model
+from scripts.basicsr.utils import (MessageLogger, check_resume, get_env_info, get_root_logger, init_tb_logger,
+                                   init_wandb_logger, make_exp_dirs, mkdir_and_rename, set_random_seed)
+from scripts.basicsr.utils.dist_util import get_dist_info, init_dist
+from scripts.basicsr.utils.options import dict2str, parse
+
+import warnings
+# ignore UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`.
before `optimizer.step()`. +warnings.filterwarnings("ignore", category=UserWarning) + +def parse_options(root_path, is_train=True): + parser = argparse.ArgumentParser() + parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.') + parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + opt = parse(args.opt, root_path, is_train=is_train) + + # distributed settings + if args.launcher == 'none': + opt['dist'] = False + print('Disable distributed.', flush=True) + else: + opt['dist'] = True + if args.launcher == 'slurm' and 'dist_params' in opt: + init_dist(args.launcher, **opt['dist_params']) + else: + init_dist(args.launcher) + + opt['rank'], opt['world_size'] = get_dist_info() + + # random seed + seed = opt.get('manual_seed') + if seed is None: + seed = random.randint(1, 10000) + opt['manual_seed'] = seed + set_random_seed(seed + opt['rank']) + + return opt + + +def init_loggers(opt): + log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log") + logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) + logger.info(get_env_info()) + logger.info(dict2str(opt)) + + # initialize wandb logger before tensorboard logger to allow proper sync: + if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None): + assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb') + init_wandb_logger(opt) + tb_logger = None + if opt['logger'].get('use_tb_logger'): + tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name'])) + return logger, tb_logger + + +def create_train_val_dataloader(opt, logger): + # create train and val dataloaders + train_loader, val_loader = None, None + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1) + train_set = build_dataset(dataset_opt) + train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio) + train_loader = build_dataloader( + train_set, + dataset_opt, + num_gpu=opt['num_gpu'], + dist=opt['dist'], + sampler=train_sampler, + seed=opt['manual_seed']) + + num_iter_per_epoch = math.ceil( + len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size'])) + total_iters = int(opt['train']['total_iter']) + total_epochs = math.ceil(total_iters / (num_iter_per_epoch)) + logger.info('Training statistics:' + f'\n\tNumber of train images: {len(train_set)}' + f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}' + f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}' + f'\n\tWorld size (gpu number): {opt["world_size"]}' + f'\n\tRequire iter number per epoch: {num_iter_per_epoch}' + f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.') + + elif phase == 'val': + val_set = build_dataset(dataset_opt) + val_loader = build_dataloader( + val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) + logger.info(f'Number of val images/folders in {dataset_opt["name"]}: ' f'{len(val_set)}') + else: + raise ValueError(f'Dataset phase {phase} is not recognized.') + + return train_loader, train_sampler, val_loader, total_epochs, total_iters + + +def train_pipeline(root_path): + # parse options, set distributed setting, set ramdom seed + opt = parse_options(root_path, is_train=True) + 
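+    # cudnn.benchmark lets cuDNN auto-tune convolution algorithms for fixed
+    # input shapes, trading exact reproducibility for speed; the commented
+    # `deterministic` flag below is the switch for the opposite trade-off.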
+ torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # load resume states if necessary + if opt['path'].get('resume_state'): + device_id = torch.cuda.current_device() + resume_state = torch.load( + opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) + else: + resume_state = None + + # mkdir for experiments and logger + if resume_state is None: + make_exp_dirs(opt) + if opt['logger'].get('use_tb_logger') and opt['rank'] == 0: + mkdir_and_rename(osp.join('tb_logger', opt['name'])) + + # initialize loggers + logger, tb_logger = init_loggers(opt) + + # create train and validation dataloaders + result = create_train_val_dataloader(opt, logger) + train_loader, train_sampler, val_loader, total_epochs, total_iters = result + + # create model + if resume_state: # resume training + check_resume(opt, resume_state['iter']) + model = build_model(opt) + model.resume_training(resume_state) # handle optimizers and schedulers + logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.") + start_epoch = resume_state['epoch'] + current_iter = resume_state['iter'] + else: + model = build_model(opt) + start_epoch = 0 + current_iter = 0 + + # create message logger (formatted outputs) + msg_logger = MessageLogger(opt, current_iter, tb_logger) + + # dataloader prefetcher + prefetch_mode = opt['datasets']['train'].get('prefetch_mode') + if prefetch_mode is None or prefetch_mode == 'cpu': + prefetcher = CPUPrefetcher(train_loader) + elif prefetch_mode == 'cuda': + prefetcher = CUDAPrefetcher(train_loader, opt) + logger.info(f'Use {prefetch_mode} prefetch dataloader') + if opt['datasets']['train'].get('pin_memory') is not True: + raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') + else: + raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.' 
"Supported ones are: None, 'cuda', 'cpu'.") + + # training + logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter+1}') + data_time, iter_time = time.time(), time.time() + start_time = time.time() + + for epoch in range(start_epoch, total_epochs + 1): + train_sampler.set_epoch(epoch) + prefetcher.reset() + train_data = prefetcher.next() + + while train_data is not None: + data_time = time.time() - data_time + + current_iter += 1 + if current_iter > total_iters: + break + # update learning rate + model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1)) + # training + model.feed_data(train_data) + model.optimize_parameters(current_iter) + iter_time = time.time() - iter_time + # log + if current_iter % opt['logger']['print_freq'] == 0: + log_vars = {'epoch': epoch, 'iter': current_iter} + log_vars.update({'lrs': model.get_current_learning_rate()}) + log_vars.update({'time': iter_time, 'data_time': data_time}) + log_vars.update(model.get_current_log()) + msg_logger(log_vars) + + # save models and training states + if current_iter % opt['logger']['save_checkpoint_freq'] == 0: + logger.info('Saving models and training states.') + model.save(epoch, current_iter) + + # validation + if opt.get('val') is not None and opt['datasets'].get('val') is not None \ + and (current_iter % opt['val']['val_freq'] == 0): + model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) + + data_time = time.time() + iter_time = time.time() + train_data = prefetcher.next() + # end of iter + + # end of epoch + + consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time))) + logger.info(f'End of training. Time consumed: {consumed_time}') + logger.info('Save the latest model.') + model.save(epoch=-1, current_iter=-1) # -1 stands for the latest + if opt.get('val') is not None and opt['datasets'].get('val'): + model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) + if tb_logger: + tb_logger.close() + + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + train_pipeline(root_path) diff --git a/scripts/basicsr/utils/__init__.py b/scripts/basicsr/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5fcc1d540462712387523d1e326d1dfc2bcfbf32 --- /dev/null +++ b/scripts/basicsr/utils/__init__.py @@ -0,0 +1,29 @@ +from .file_client import FileClient +from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img +from .logger import MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger +from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt + +__all__ = [ + # file_client.py + 'FileClient', + # img_util.py + 'img2tensor', + 'tensor2img', + 'imfrombytes', + 'imwrite', + 'crop_border', + # logger.py + 'MessageLogger', + 'init_tb_logger', + 'init_wandb_logger', + 'get_root_logger', + 'get_env_info', + # misc.py + 'set_random_seed', + 'get_time_str', + 'mkdir_and_rename', + 'make_exp_dirs', + 'scandir', + 'check_resume', + 'sizeof_fmt' +] diff --git a/scripts/basicsr/utils/__pycache__/__init__.cpython-310.pyc b/scripts/basicsr/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65d34aa7b17ef01f4f06ab110c6fd4c63bfe7d9f Binary files /dev/null and b/scripts/basicsr/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/basicsr/utils/__pycache__/dist_util.cpython-310.pyc 
b/scripts/basicsr/utils/__pycache__/dist_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2386eedfa7912220e57f91e31e92a2975fea81b Binary files /dev/null and b/scripts/basicsr/utils/__pycache__/dist_util.cpython-310.pyc differ diff --git a/scripts/basicsr/utils/__pycache__/download_util.cpython-310.pyc b/scripts/basicsr/utils/__pycache__/download_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fedeee1e6b41bdd47d2ef70dd7e30127347cd27 Binary files /dev/null and b/scripts/basicsr/utils/__pycache__/download_util.cpython-310.pyc differ diff --git a/scripts/basicsr/utils/__pycache__/file_client.cpython-310.pyc b/scripts/basicsr/utils/__pycache__/file_client.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba90c8cfc131de6e19035a29221fadd1cee092c4 Binary files /dev/null and b/scripts/basicsr/utils/__pycache__/file_client.cpython-310.pyc differ diff --git a/scripts/basicsr/utils/__pycache__/img_util.cpython-310.pyc b/scripts/basicsr/utils/__pycache__/img_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb13406d8c37c21681f5c83bb4bff8430843398b Binary files /dev/null and b/scripts/basicsr/utils/__pycache__/img_util.cpython-310.pyc differ diff --git a/scripts/basicsr/utils/__pycache__/logger.cpython-310.pyc b/scripts/basicsr/utils/__pycache__/logger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0144321b6463def4b58462e354c3f1f621fdc93e Binary files /dev/null and b/scripts/basicsr/utils/__pycache__/logger.cpython-310.pyc differ diff --git a/scripts/basicsr/utils/__pycache__/matlab_functions.cpython-310.pyc b/scripts/basicsr/utils/__pycache__/matlab_functions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c027ed472e8d93245be4027771548234f5cf9152 Binary files /dev/null and b/scripts/basicsr/utils/__pycache__/matlab_functions.cpython-310.pyc differ diff --git a/scripts/basicsr/utils/__pycache__/misc.cpython-310.pyc b/scripts/basicsr/utils/__pycache__/misc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6c463deb1cf0c765a863d6a52c883ed7c94f45c Binary files /dev/null and b/scripts/basicsr/utils/__pycache__/misc.cpython-310.pyc differ diff --git a/scripts/basicsr/utils/__pycache__/options.cpython-310.pyc b/scripts/basicsr/utils/__pycache__/options.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..000ed1dd767d100500b3c580b788165f006e65e0 Binary files /dev/null and b/scripts/basicsr/utils/__pycache__/options.cpython-310.pyc differ diff --git a/scripts/basicsr/utils/__pycache__/realesrgan_utils.cpython-310.pyc b/scripts/basicsr/utils/__pycache__/realesrgan_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87f8955cddebeee8fb616fb17045660d0b853153 Binary files /dev/null and b/scripts/basicsr/utils/__pycache__/realesrgan_utils.cpython-310.pyc differ diff --git a/scripts/basicsr/utils/__pycache__/registry.cpython-310.pyc b/scripts/basicsr/utils/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6634177a2e26efdfff7aad55fdc5ef1263ec89f7 Binary files /dev/null and b/scripts/basicsr/utils/__pycache__/registry.cpython-310.pyc differ diff --git a/scripts/basicsr/utils/dist_util.py b/scripts/basicsr/utils/dist_util.py new file mode 100644 index 
0000000000000000000000000000000000000000..0fab887b2cb1ce8533d2e8fdee72ae0c24f68fd0 --- /dev/null +++ b/scripts/basicsr/utils/dist_util.py @@ -0,0 +1,82 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 +import functools +import os +import subprocess +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. + """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper diff --git a/scripts/basicsr/utils/download_util.py b/scripts/basicsr/utils/download_util.py new file mode 100644 index 0000000000000000000000000000000000000000..2a267915743ee3f3232bc8fe992466b52468979a --- /dev/null +++ b/scripts/basicsr/utils/download_util.py @@ -0,0 +1,95 @@ +import math +import os +import requests +from torch.hub import download_url_to_file, get_dir +from tqdm import tqdm +from urllib.parse import urlparse + +from .misc import sizeof_fmt + + +def download_file_from_google_drive(file_id, save_path): + """Download files from google drive. + Ref: + https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 + Args: + file_id (str): File id. + save_path (str): Save path. 
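+
+    Example (illustrative; the file id is hypothetical):
+        >>> download_file_from_google_drive('1a2b3c...', './weights.pth')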
+ """ + + session = requests.Session() + URL = 'https://docs.google.com/uc?export=download' + params = {'id': file_id} + + response = session.get(URL, params=params, stream=True) + token = get_confirm_token(response) + if token: + params['confirm'] = token + response = session.get(URL, params=params, stream=True) + + # get file size + response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) + print(response_file_size) + if 'Content-Range' in response_file_size.headers: + file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) + else: + file_size = None + + save_response_content(response, save_path, file_size) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + +def save_response_content(response, destination, file_size=None, chunk_size=32768): + if file_size is not None: + pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') + + readable_file_size = sizeof_fmt(file_size) + else: + pbar = None + + with open(destination, 'wb') as f: + downloaded_size = 0 + for chunk in response.iter_content(chunk_size): + downloaded_size += chunk_size + if pbar is not None: + pbar.update(1) + pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if pbar is not None: + pbar.close() + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Load file form http url, will download models if necessary. + Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + Args: + url (str): URL to be downloaded. + model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. + Default: None. + progress (bool): Whether to show the download progress. Default: True. + file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. + Returns: + str: The path to the downloaded file. + """ + if model_dir is None: # use the pytorch hub_dir + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file \ No newline at end of file diff --git a/scripts/basicsr/utils/file_client.py b/scripts/basicsr/utils/file_client.py new file mode 100644 index 0000000000000000000000000000000000000000..7f38d9796da3899048924f2f803d1088927966b0 --- /dev/null +++ b/scripts/basicsr/utils/file_client.py @@ -0,0 +1,167 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 +from abc import ABCMeta, abstractmethod + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. 
+
+    Attributes:
+        server_list_cfg (str): Config file for memcached server list.
+        client_cfg (str): Config file for memcached client.
+        sys_path (str | None): Additional path to be appended to `sys.path`.
+            Default: None.
+    """
+
+    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
+        if sys_path is not None:
+            import sys
+            sys.path.append(sys_path)
+        try:
+            import mc
+        except ImportError:
+            raise ImportError('Please install memcached to enable MemcachedBackend.')
+
+        self.server_list_cfg = server_list_cfg
+        self.client_cfg = client_cfg
+        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
+        # mc.pyvector serves as a pointer to a memory cache
+        self._mc_buffer = mc.pyvector()
+
+    def get(self, filepath):
+        filepath = str(filepath)
+        import mc
+        self._client.Get(filepath, self._mc_buffer)
+        value_buf = mc.ConvertBuffer(self._mc_buffer)
+        return value_buf
+
+    def get_text(self, filepath):
+        raise NotImplementedError
+
+
+class HardDiskBackend(BaseStorageBackend):
+    """Raw hard disks storage backend."""
+
+    def get(self, filepath):
+        filepath = str(filepath)
+        with open(filepath, 'rb') as f:
+            value_buf = f.read()
+        return value_buf
+
+    def get_text(self, filepath):
+        filepath = str(filepath)
+        with open(filepath, 'r') as f:
+            value_buf = f.read()
+        return value_buf
+
+
+class LmdbBackend(BaseStorageBackend):
+    """Lmdb storage backend.
+
+    Args:
+        db_paths (str | list[str]): Lmdb database paths.
+        client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
+        readonly (bool, optional): Lmdb environment parameter. If True,
+            disallow any write operations. Default: True.
+        lock (bool, optional): Lmdb environment parameter. If False, when
+            concurrent access occurs, do not lock the database. Default: False.
+        readahead (bool, optional): Lmdb environment parameter. If False,
+            disable the OS filesystem readahead mechanism, which may improve
+            random read performance when a database is larger than RAM.
+            Default: False.
+
+    Attributes:
+        db_paths (list): Lmdb database path.
+        _client (list): A list of several lmdb envs.
+    """
+
+    def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
+        try:
+            import lmdb
+        except ImportError:
+            raise ImportError('Please install lmdb to enable LmdbBackend.')
+
+        if isinstance(client_keys, str):
+            client_keys = [client_keys]
+
+        if isinstance(db_paths, list):
+            self.db_paths = [str(v) for v in db_paths]
+        elif isinstance(db_paths, str):
+            self.db_paths = [str(db_paths)]
+        assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
+                                                        f'but received {len(client_keys)} and {len(self.db_paths)}.')
+
+        self._client = {}
+        for client, path in zip(client_keys, self.db_paths):
+            self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
+
+    def get(self, filepath, client_key):
+        """Get values according to the filepath from one lmdb named client_key.
+
+        Args:
+            filepath (str | obj:`Path`): Here, filepath is the lmdb key.
+            client_key (str): Used for distinguishing different lmdb envs.
+ """ + filepath = str(filepath) + assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.') + client = self._client[client_key] + with client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class FileClient(object): + """A general file client to access files in different backend. + + The client loads a file or text in a specified backend from its path + and return it as a binary file. it can also register other backend + accessor with a given name and backend class. + + Attributes: + backend (str): The storage backend type. Options are "disk", + "memcached" and "lmdb". + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + 'disk': HardDiskBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend, + } + + def __init__(self, backend='disk', **kwargs): + if backend not in self._backends: + raise ValueError(f'Backend {backend} is not supported. Currently supported ones' + f' are {list(self._backends.keys())}') + self.backend = backend + self.client = self._backends[backend](**kwargs) + + def get(self, filepath, client_key='default'): + # client_key is used only for lmdb, where different fileclients have + # different lmdb environments. + if self.backend == 'lmdb': + return self.client.get(filepath, client_key) + else: + return self.client.get(filepath) + + def get_text(self, filepath): + return self.client.get_text(filepath) diff --git a/scripts/basicsr/utils/img_util.py b/scripts/basicsr/utils/img_util.py new file mode 100644 index 0000000000000000000000000000000000000000..5aba82ce08eefaeb3e56ea5a3a09c342ae513522 --- /dev/null +++ b/scripts/basicsr/utils/img_util.py @@ -0,0 +1,171 @@ +import cv2 +import math +import numpy as np +import os +import torch +from torchvision.utils import make_grid + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. 
+ """ + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result + + +def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): + """This implementation is slightly faster than tensor2img. + It now only supports torch tensor with shape (1, c, h, w). + + Args: + tensor (Tensor): Now only support torch tensor with (1, c, h, w). + rgb2bgr (bool): Whether to change rgb to bgr. Default: True. + min_max (tuple[int]): min and max values for clamp. + """ + output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) + output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 + output = output.type(torch.uint8).cpu().numpy() + if rgb2bgr: + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + +def imfrombytes(content, flag='color', float32=False): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} + img = cv2.imdecode(img_np, imread_flags[flag]) + if float32: + img = img.astype(np.float32) / 255. + return img + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def crop_border(imgs, crop_border): + """Crop borders of images. + + Args: + imgs (list[ndarray] | ndarray): Images with shape (h, w, c). + crop_border (int): Crop border for each end of height and weight. + + Returns: + list[ndarray]: Cropped images. 
+ """ + if crop_border == 0: + return imgs + else: + if isinstance(imgs, list): + return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] + else: + return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] + \ No newline at end of file diff --git a/scripts/basicsr/utils/lmdb_util.py b/scripts/basicsr/utils/lmdb_util.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a10f60ffca2e36ac5f5564aafd70e79d06a723 --- /dev/null +++ b/scripts/basicsr/utils/lmdb_util.py @@ -0,0 +1,196 @@ +import cv2 +import lmdb +import sys +from multiprocessing import Pool +from os import path as osp +from tqdm import tqdm + + +def make_lmdb_from_imgs(data_path, + lmdb_path, + img_path_list, + keys, + batch=5000, + compress_level=1, + multiprocessing_read=False, + n_thread=40, + map_size=None): + """Make lmdb from images. + + Contents of lmdb. The file structure is: + example.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records 1)image name (with extension), + 2)image shape, and 3)compression level, separated by a white space. + + For example, the meta information could be: + `000_00000000.png (720,1280,3) 1`, which means: + 1) image name (with extension): 000_00000000.png; + 2) image shape: (720,1280,3); + 3) compression level: 1 + + We use the image name without extension as the lmdb key. + + If `multiprocessing_read` is True, it will read all the images to memory + using multiprocessing. Thus, your server needs to have enough memory. + + Args: + data_path (str): Data path for reading images. + lmdb_path (str): Lmdb save path. + img_path_list (str): Image path list. + keys (str): Used for lmdb keys. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + multiprocessing_read (bool): Whether use multiprocessing to read all + the images to memory. Default: False. + n_thread (int): For multiprocessing. + map_size (int | None): Map size for lmdb env. If None, use the + estimated size from images. Default: None + """ + + assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, ' + f'but got {len(img_path_list)} and {len(keys)}') + print(f'Create lmdb for {data_path}, save to {lmdb_path}...') + print(f'Totoal images: {len(img_path_list)}') + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. 
Exit.') + sys.exit(1) + + if multiprocessing_read: + # read all the images to memory (multiprocessing) + dataset = {} # use dict to keep the order for multiprocessing + shapes = {} + print(f'Read images with multiprocessing, #thread: {n_thread} ...') + pbar = tqdm(total=len(img_path_list), unit='image') + + def callback(arg): + """get the image data and update pbar.""" + key, dataset[key], shapes[key] = arg + pbar.update(1) + pbar.set_description(f'Read {key}') + + pool = Pool(n_thread) + for path, key in zip(img_path_list, keys): + pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) + pool.close() + pool.join() + pbar.close() + print(f'Finish reading {len(img_path_list)} images.') + + # create lmdb environment + if map_size is None: + # obtain data size for one image + img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + data_size_per_img = img_byte.nbytes + print('Data size per image is: ', data_size_per_img) + data_size = data_size_per_img * len(img_path_list) + map_size = data_size * 10 + + env = lmdb.open(lmdb_path, map_size=map_size) + + # write data to lmdb + pbar = tqdm(total=len(img_path_list), unit='chunk') + txn = env.begin(write=True) + txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + for idx, (path, key) in enumerate(zip(img_path_list, keys)): + pbar.update(1) + pbar.set_description(f'Write {key}') + key_byte = key.encode('ascii') + if multiprocessing_read: + img_byte = dataset[key] + h, w, c = shapes[key] + else: + _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) + h, w, c = img_shape + + txn.put(key_byte, img_byte) + # write meta information + txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') + if idx % batch == 0: + txn.commit() + txn = env.begin(write=True) + pbar.close() + txn.commit() + env.close() + txt_file.close() + print('\nFinish writing lmdb.') + + +def read_img_worker(path, key, compress_level): + """Read image worker. + + Args: + path (str): Image path. + key (str): Image key. + compress_level (int): Compress level when encoding images. + + Returns: + str: Image key. + byte: Image byte. + tuple[int]: Image shape. + """ + + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if img.ndim == 2: + h, w = img.shape + c = 1 + else: + h, w, c = img.shape + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + return (key, img_byte, (h, w, c)) + + +class LmdbMaker(): + """LMDB Maker. + + Args: + lmdb_path (str): Lmdb save path. + map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + """ + + def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1): + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. 
Exit.') + sys.exit(1) + + self.lmdb_path = lmdb_path + self.batch = batch + self.compress_level = compress_level + self.env = lmdb.open(lmdb_path, map_size=map_size) + self.txn = self.env.begin(write=True) + self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + self.counter = 0 + + def put(self, img_byte, key, img_shape): + self.counter += 1 + key_byte = key.encode('ascii') + self.txn.put(key_byte, img_byte) + # write meta information + h, w, c = img_shape + self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') + if self.counter % self.batch == 0: + self.txn.commit() + self.txn = self.env.begin(write=True) + + def close(self): + self.txn.commit() + self.env.close() + self.txt_file.close() diff --git a/scripts/basicsr/utils/logger.py b/scripts/basicsr/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..f692d3c7dbaaf2aa40890a02f6dbd5eae8d0c5a1 --- /dev/null +++ b/scripts/basicsr/utils/logger.py @@ -0,0 +1,169 @@ +import datetime +import logging +import time + +from .dist_util import get_dist_info, master_only + +initialized_logger = {} + + +class MessageLogger(): + """Message logger for printing. + Args: + opt (dict): Config. It contains the following keys: + name (str): Exp name. + logger (dict): Contains 'print_freq' (str) for logger interval. + train (dict): Contains 'total_iter' (int) for total iters. + use_tb_logger (bool): Use tensorboard logger. + start_iter (int): Start iter. Default: 1. + tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. + """ + + def __init__(self, opt, start_iter=1, tb_logger=None): + self.exp_name = opt['name'] + self.interval = opt['logger']['print_freq'] + self.start_iter = start_iter + self.max_iters = opt['train']['total_iter'] + self.use_tb_logger = opt['logger']['use_tb_logger'] + self.tb_logger = tb_logger + self.start_time = time.time() + self.logger = get_root_logger() + + @master_only + def __call__(self, log_vars): + """Format logging message. + Args: + log_vars (dict): It contains the following keys: + epoch (int): Epoch number. + iter (int): Current iter. + lrs (list): List for learning rates. + time (float): Iter time. + data_time (float): Data time for each iter. 
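+
+        A rendered message looks roughly like:
+        ``[exp..][epoch:  1, iter:     200, lr:(1.000e-04,)] [eta: 0:10:00,
+        time (data): 0.040 (0.010)] l_pix: 1.0000e-02``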
+ """ + # epoch, iter, learning rates + epoch = log_vars.pop('epoch') + current_iter = log_vars.pop('iter') + lrs = log_vars.pop('lrs') + + message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(') + for v in lrs: + message += f'{v:.3e},' + message += ')] ' + + # time and estimated time + if 'time' in log_vars.keys(): + iter_time = log_vars.pop('time') + data_time = log_vars.pop('data_time') + + total_time = time.time() - self.start_time + time_sec_avg = total_time / (current_iter - self.start_iter + 1) + eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) + eta_str = str(datetime.timedelta(seconds=int(eta_sec))) + message += f'[eta: {eta_str}, ' + message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' + + # other items, especially losses + for k, v in log_vars.items(): + message += f'{k}: {v:.4e} ' + # tensorboard logger + if self.use_tb_logger: + # if k.startswith('l_'): + # self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) + # else: + self.tb_logger.add_scalar(k, v, current_iter) + self.logger.info(message) + + +@master_only +def init_tb_logger(log_dir): + from torch.utils.tensorboard import SummaryWriter + tb_logger = SummaryWriter(log_dir=log_dir) + return tb_logger + + +@master_only +def init_wandb_logger(opt): + """We now only use wandb to sync tensorboard log.""" + import wandb + logger = logging.getLogger('basicsr') + + project = opt['logger']['wandb']['project'] + resume_id = opt['logger']['wandb'].get('resume_id') + if resume_id: + wandb_id = resume_id + resume = 'allow' + logger.warning(f'Resume wandb logger with id={wandb_id}.') + else: + wandb_id = wandb.util.generate_id() + resume = 'never' + + wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) + + logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') + + +def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): + """Get the root logger. + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. + Args: + logger_name (str): root logger name. Default: 'basicsr'. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + Returns: + logging.Logger: The root logger. + """ + logger = logging.getLogger(logger_name) + # if the logger has been initialized, just return it + if logger_name in initialized_logger: + return logger + + format_str = '%(asctime)s %(levelname)s: %(message)s' + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(logging.Formatter(format_str)) + logger.addHandler(stream_handler) + logger.propagate = False + rank, _ = get_dist_info() + if rank != 0: + logger.setLevel('ERROR') + elif log_file is not None: + logger.setLevel(log_level) + # add file handler + # file_handler = logging.FileHandler(log_file, 'w') + file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log + file_handler.setFormatter(logging.Formatter(format_str)) + file_handler.setLevel(log_level) + logger.addHandler(file_handler) + initialized_logger[logger_name] = True + return logger + + +def get_env_info(): + """Get environment information. + Currently, only log the software version. 
+ """ + import torch + import torchvision + + from scripts.basicsr.version import __version__ + msg = r""" + ____ _ _____ ____ + / __ ) ____ _ _____ (_)_____/ ___/ / __ \ + / __ |/ __ `// ___// // ___/\__ \ / /_/ / + / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ + /_____/ \__,_//____//_/ \___//____//_/ |_| + ______ __ __ __ __ + / ____/____ ____ ____/ / / / __ __ _____ / /__ / / + / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / + / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ + \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) + """ + msg += ('\nVersion Information: ' + f'\n\tBasicSR: {__version__}' + f'\n\tPyTorch: {torch.__version__}' + f'\n\tTorchVision: {torchvision.__version__}') + return msg \ No newline at end of file diff --git a/scripts/basicsr/utils/matlab_functions.py b/scripts/basicsr/utils/matlab_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..c6ce1004a2c9f8521505c4b5889d3c24a909c70d --- /dev/null +++ b/scripts/basicsr/utils/matlab_functions.py @@ -0,0 +1,347 @@ +import math +import numpy as np +import torch + + +def cubic(x): + """cubic function used for calculate_weights_indices.""" + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ( + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) * + (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + """Calculate weights and indices, used for imresize function. + + Args: + in_length (int): Input length. + out_length (int): Output length. + scale (float): Scale factor. + kernel_width (int): Kernel width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + """ + + if (scale < 1) and antialiasing: + # Use a modified kernel (larger kernel width) to simultaneously + # interpolate and antialias + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5 + scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + p = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand( + out_length, p) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices + + # apply cubic kernel + if (scale < 1) and antialiasing: + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, p) + + # If a column in weights is all zero, get rid of it. only consider the + # first and last column. 
+ weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, p - 2) + weights = weights.narrow(1, 1, p - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, p - 2) + weights = weights.narrow(1, 0, p - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +@torch.no_grad() +def imresize(img, scale, antialiasing=True): + """imresize function same as MATLAB. + + It now only supports bicubic. + The same scale applies for both height and width. + + Args: + img (Tensor | Numpy array): + Tensor: Input image with shape (c, h, w), [0, 1] range. + Numpy: Input image with shape (h, w, c), [0, 1] range. + scale (float): Scale factor. The same scale applies for both height + and width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + Default: True. + + Returns: + Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round. + """ + if type(img).__module__ == np.__name__: # numpy type + numpy_type = True + img = torch.from_numpy(img.transpose(2, 0, 1)).float() + else: + numpy_type = False + + in_c, in_h, in_w = img.size() + out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale) + kernel_width = 4 + kernel = 'cubic' + + # get weights and indices + weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width, + antialiasing) + weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width, + antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) + img_aug.narrow(1, sym_len_hs, in_h).copy_(img) + + sym_patch = img[:, :sym_len_hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_he:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_c, out_h, in_w) + kernel_width = weights_h.size(1) + for i in range(out_h): + idx = int(indices_h[i][0]) + for j in range(in_c): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) + out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_we:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_c, out_h, out_w) + kernel_width = weights_w.size(1) + for i in range(out_w): + idx = int(indices_w[i][0]) + for j in range(in_c): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i]) + + if numpy_type: + out_2 = out_2.numpy().transpose(1, 2, 
0) + return out_2 + + +def rgb2ycbcr(img, y_only=False): + """Convert a RGB image to YCbCr image. + + This function produces the same results as Matlab's `rgb2ycbcr` function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 + else: + out_img = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2rgb(img): + """Convert a YCbCr image to RGB image. + + This function produces the same results as Matlab's ycbcr2rgb function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted RGB image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2bgr(img): + """Convert a YCbCr image to BGR image. + + The bgr version of ycbcr2rgb. 
+ It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted BGR image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0], + [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + convertion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}') + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace convertion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. + """ + if dst_type not in (np.uint8, np.float32): + raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}') + if dst_type == np.uint8: + img = img.round() + else: + img /= 255. 
+ return img.astype(dst_type) diff --git a/scripts/basicsr/utils/misc.py b/scripts/basicsr/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..f425d68ebede7ee1c79858ac3c8b05c6f90b74c5 --- /dev/null +++ b/scripts/basicsr/utils/misc.py @@ -0,0 +1,157 @@ +import os +import re +import random +import time +import torch +import numpy as np +from os import path as osp + +from .dist_util import master_only +from .logger import get_root_logger + +IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ + torch.__version__)[0][:3])] >= [1, 12, 0] + +def gpu_is_available(): + if IS_HIGH_VERSION: + if torch.backends.mps.is_available(): + return True + return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False + +def get_device(gpu_id=None): + if gpu_id is None: + gpu_str = '' + elif isinstance(gpu_id, int): + gpu_str = f':{gpu_id}' + else: + raise TypeError('Input should be int value.') + + if IS_HIGH_VERSION: + if torch.backends.mps.is_available(): + return torch.device('mps'+gpu_str) + return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu') + + +def set_random_seed(seed): + """Set random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def mkdir_and_rename(path): + """mkdirs. If path exists, rename it with timestamp and create a new one. + + Args: + path (str): Folder path. + """ + if osp.exists(path): + new_name = path + '_archived_' + get_time_str() + print(f'Path already exists. Rename it to {new_name}', flush=True) + os.rename(path, new_name) + os.makedirs(path, exist_ok=True) + + +@master_only +def make_exp_dirs(opt): + """Make dirs for experiments.""" + path_opt = opt['path'].copy() + if opt['is_train']: + mkdir_and_rename(path_opt.pop('experiments_root')) + else: + mkdir_and_rename(path_opt.pop('results_root')) + for key, path in path_opt.items(): + if ('strict_load' not in key) and ('pretrain_network' not in key) and ('resume' not in key): + os.makedirs(path, exist_ok=True) + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative pathes. 
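+
+    Example (illustrative; the folder layout is hypothetical):
+        >>> paths = list(scandir('datasets/val', suffix='.png', recursive=True))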
+ """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +def check_resume(opt, resume_iter): + """Check resume states and pretrain_network paths. + + Args: + opt (dict): Options. + resume_iter (int): Resume iteration. + """ + logger = get_root_logger() + if opt['path']['resume_state']: + # get all the networks + networks = [key for key in opt.keys() if key.startswith('network_')] + flag_pretrain = False + for network in networks: + if opt['path'].get(f'pretrain_{network}') is not None: + flag_pretrain = True + if flag_pretrain: + logger.warning('pretrain_network path will be ignored during resuming.') + # set pretrained model paths + for network in networks: + name = f'pretrain_{network}' + basename = network.replace('network_', '') + if opt['path'].get('ignore_resume_networks') is None or (basename + not in opt['path']['ignore_resume_networks']): + opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') + logger.info(f"Set {name} to {opt['path'][name]}") + + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + + Return: + str: Formated file siz. + """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' diff --git a/scripts/basicsr/utils/options.py b/scripts/basicsr/utils/options.py new file mode 100644 index 0000000000000000000000000000000000000000..5e29b1a40ee1a517d89b8508ddc056aed51e4550 --- /dev/null +++ b/scripts/basicsr/utils/options.py @@ -0,0 +1,108 @@ +import yaml +import time +from collections import OrderedDict +from os import path as osp +from scripts.basicsr.utils.misc import get_time_str + +def ordered_yaml(): + """Support OrderedDict for yaml. + + Returns: + yaml Loader and Dumper. + """ + try: + from yaml import CDumper as Dumper + from yaml import CLoader as Loader + except ImportError: + from yaml import Dumper, Loader + + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + + +def parse(opt_path, root_path, is_train=True): + """Parse option file. + + Args: + opt_path (str): Option file path. + is_train (str): Indicate whether in training or not. Default: True. + + Returns: + (dict): Options. 
+ """ + with open(opt_path, mode='r') as f: + Loader, _ = ordered_yaml() + opt = yaml.load(f, Loader=Loader) + + opt['is_train'] = is_train + + # opt['name'] = f"{get_time_str()}_{opt['name']}" + if opt['path'].get('resume_state', None): # Shangchen added + resume_state_path = opt['path'].get('resume_state') + opt['name'] = resume_state_path.split("/")[-3] + else: + opt['name'] = f"{get_time_str()}_{opt['name']}" + + + # datasets + for phase, dataset in opt['datasets'].items(): + # for several datasets, e.g., test_1, test_2 + phase = phase.split('_')[0] + dataset['phase'] = phase + if 'scale' in opt: + dataset['scale'] = opt['scale'] + if dataset.get('dataroot_gt') is not None: + dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) + if dataset.get('dataroot_lq') is not None: + dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) + + # paths + for key, val in opt['path'].items(): + if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): + opt['path'][key] = osp.expanduser(val) + + if is_train: + experiments_root = osp.join(root_path, 'experiments', opt['name']) + opt['path']['experiments_root'] = experiments_root + opt['path']['models'] = osp.join(experiments_root, 'models') + opt['path']['training_states'] = osp.join(experiments_root, 'training_states') + opt['path']['log'] = experiments_root + opt['path']['visualization'] = osp.join(experiments_root, 'visualization') + + else: # test + results_root = osp.join(root_path, 'results', opt['name']) + opt['path']['results_root'] = results_root + opt['path']['log'] = results_root + opt['path']['visualization'] = osp.join(results_root, 'visualization') + + return opt + + +def dict2str(opt, indent_level=1): + """dict to string for printing options. + + Args: + opt (dict): Option dict. + indent_level (int): Indent level. Default: 1. + + Return: + (str): Option string for printing. + """ + msg = '\n' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_level * 2) + k + ':[' + msg += dict2str(v, indent_level + 1) + msg += ' ' * (indent_level * 2) + ']\n' + else: + msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' + return msg diff --git a/scripts/basicsr/utils/realesrgan_utils.py b/scripts/basicsr/utils/realesrgan_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bbe4c6e49b8ebd1e0e8a4916283eb2970fa09369 --- /dev/null +++ b/scripts/basicsr/utils/realesrgan_utils.py @@ -0,0 +1,302 @@ +import cv2 +import math +import numpy as np +import os +import queue +import threading +import torch +from torch.nn import functional as F +from scripts.basicsr.utils.download_util import load_file_from_url +from scripts.basicsr.utils.misc import get_device + +# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +class RealESRGANer(): + """A helper class for upsampling images with RealESRGAN. + + Args: + scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4. + model_path (str): The path to the pretrained model. It can be urls (will first download it automatically). + model (nn.Module): The defined network. Default: None. + tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop + input images into tiles, and then process each of them. Finally, they will be merged into one image. + 0 denotes for do not use tile. Default: 0. + tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10. 
+ pre_pad (int): Pad the input images to avoid border artifacts. Default: 10. + half (float): Whether to use half precision during inference. Default: False. + """ + + def __init__(self, + scale, + model_path, + model=None, + tile=0, + tile_pad=10, + pre_pad=10, + half=False, + device=None, + gpu_id=None): + self.scale = scale + self.tile_size = tile + self.tile_pad = tile_pad + self.pre_pad = pre_pad + self.mod_scale = None + self.half = half + + # initialize model + # if gpu_id: + # self.device = torch.device( + # f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device + # else: + # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device + + self.device = get_device(gpu_id) if device is None else device + + # if the model_path starts with https, it will first download models to the folder: realesrgan/weights + if model_path.startswith('https://'): + model_path = load_file_from_url( + url=model_path, model_dir=os.path.join('weights/realesrgan'), progress=True, file_name=None) + loadnet = torch.load(model_path, map_location=torch.device('cpu')) + # prefer to use params_ema + if 'params_ema' in loadnet: + keyname = 'params_ema' + else: + keyname = 'params' + model.load_state_dict(loadnet[keyname], strict=True) + model.eval() + self.model = model.to(self.device) + if self.half: + self.model = self.model.half() + + def pre_process(self, img): + """Pre-process, such as pre-pad and mod pad, so that the images can be divisible + """ + img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() + self.img = img.unsqueeze(0).to(self.device) + if self.half: + self.img = self.img.half() + + # pre_pad + if self.pre_pad != 0: + self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect') + # mod pad for divisible borders + if self.scale == 2: + self.mod_scale = 2 + elif self.scale == 1: + self.mod_scale = 4 + if self.mod_scale is not None: + self.mod_pad_h, self.mod_pad_w = 0, 0 + _, _, h, w = self.img.size() + if (h % self.mod_scale != 0): + self.mod_pad_h = (self.mod_scale - h % self.mod_scale) + if (w % self.mod_scale != 0): + self.mod_pad_w = (self.mod_scale - w % self.mod_scale) + self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') + + def process(self): + # model inference + self.output = self.model(self.img) + + def tile_process(self): + """It will first crop input images to tiles, and then process each tile. + Finally, all the processed tiles are merged into one images. 
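+        The tile size bounds peak GPU memory, while `tile_pad` adds overlap so
+        that border artifacts fall inside the padded margin, which is cropped
+        away when each output tile is pasted back.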
+ + Modified from: https://github.com/ata4/esrgan-launcher + """ + batch, channel, height, width = self.img.shape + output_height = height * self.scale + output_width = width * self.scale + output_shape = (batch, channel, output_height, output_width) + + # start with black image + self.output = self.img.new_zeros(output_shape) + tiles_x = math.ceil(width / self.tile_size) + tiles_y = math.ceil(height / self.tile_size) + + # loop over all tiles + for y in range(tiles_y): + for x in range(tiles_x): + # extract tile from input image + ofs_x = x * self.tile_size + ofs_y = y * self.tile_size + # input tile area on total image + input_start_x = ofs_x + input_end_x = min(ofs_x + self.tile_size, width) + input_start_y = ofs_y + input_end_y = min(ofs_y + self.tile_size, height) + + # input tile area on total image with padding + input_start_x_pad = max(input_start_x - self.tile_pad, 0) + input_end_x_pad = min(input_end_x + self.tile_pad, width) + input_start_y_pad = max(input_start_y - self.tile_pad, 0) + input_end_y_pad = min(input_end_y + self.tile_pad, height) + + # input tile dimensions + input_tile_width = input_end_x - input_start_x + input_tile_height = input_end_y - input_start_y + tile_idx = y * tiles_x + x + 1 + input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] + + # upscale tile + try: + with torch.no_grad(): + output_tile = self.model(input_tile) + except RuntimeError as error: + print('Error', error) + # print(f'\tTile {tile_idx}/{tiles_x * tiles_y}') + + # output tile area on total image + output_start_x = input_start_x * self.scale + output_end_x = input_end_x * self.scale + output_start_y = input_start_y * self.scale + output_end_y = input_end_y * self.scale + + # output tile area without padding + output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale + output_end_x_tile = output_start_x_tile + input_tile_width * self.scale + output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale + output_end_y_tile = output_start_y_tile + input_tile_height * self.scale + + # put tile into output image + self.output[:, :, output_start_y:output_end_y, + output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile, + output_start_x_tile:output_end_x_tile] + + def post_process(self): + # remove extra pad + if self.mod_scale is not None: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale] + # remove prepad + if self.pre_pad != 0: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale] + return self.output + + @torch.no_grad() + def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'): + h_input, w_input = img.shape[0:2] + # img: numpy + img = img.astype(np.float32) + if np.max(img) > 256: # 16-bit image + max_range = 65535 + print('\tInput is a 16-bit image') + else: + max_range = 255 + img = img / max_range + if len(img.shape) == 2: # gray image + img_mode = 'L' + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + elif img.shape[2] == 4: # RGBA image with alpha channel + img_mode = 'RGBA' + alpha = img[:, :, 3] + img = img[:, :, 0:3] + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if alpha_upsampler == 'realesrgan': + alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) + else: + img_mode = 'RGB' + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # ------------------- process image (without the alpha channel) ------------------- # + try: 
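+            # NOTE: the matching `except` below only logs the error, so on a
+            # failure (e.g. CUDA OOM) `output_img` is never assigned and the
+            # RGBA/merge code that follows raises UnboundLocalError; callers
+            # should treat a logged failure here as fatal.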
+ with torch.no_grad(): + self.pre_process(img) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_img_t = self.post_process() + output_img = output_img_t.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) + if img_mode == 'L': + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) + del output_img_t + torch.cuda.empty_cache() + except RuntimeError as error: + print(f"Failed inference for RealESRGAN: {error}") + + # ------------------- process the alpha channel if necessary ------------------- # + if img_mode == 'RGBA': + if alpha_upsampler == 'realesrgan': + self.pre_process(alpha) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_alpha = self.post_process() + output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) + output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) + else: # use the cv2 resize for alpha channel + h, w = alpha.shape[0:2] + output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR) + + # merge the alpha channel + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) + output_img[:, :, 3] = output_alpha + + # ------------------------------ return ------------------------------ # + if max_range == 65535: # 16-bit image + output = (output_img * 65535.0).round().astype(np.uint16) + else: + output = (output_img * 255.0).round().astype(np.uint8) + + if outscale is not None and outscale != float(self.scale): + output = cv2.resize( + output, ( + int(w_input * outscale), + int(h_input * outscale), + ), interpolation=cv2.INTER_LANCZOS4) + + return output, img_mode + + +class PrefetchReader(threading.Thread): + """Prefetch images. + + Args: + img_list (list[str]): A image list of image paths to be read. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, img_list, num_prefetch_queue): + super().__init__() + self.que = queue.Queue(num_prefetch_queue) + self.img_list = img_list + + def run(self): + for img_path in self.img_list: + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + self.que.put(img) + + self.que.put(None) + + def __next__(self): + next_item = self.que.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class IOConsumer(threading.Thread): + + def __init__(self, opt, que, qid): + super().__init__() + self._queue = que + self.qid = qid + self.opt = opt + + def run(self): + while True: + msg = self._queue.get() + if isinstance(msg, str) and msg == 'quit': + break + + output = msg['output'] + save_path = msg['save_path'] + cv2.imwrite(save_path, output) + print(f'IO worker {self.qid} is done.') \ No newline at end of file diff --git a/scripts/basicsr/utils/registry.py b/scripts/basicsr/utils/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..655753b3b9cbd0cfe73fe93a77cf1fcc3db6d827 --- /dev/null +++ b/scripts/basicsr/utils/registry.py @@ -0,0 +1,82 @@ +# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 + + +class Registry(): + """ + The registry that provides name -> object mapping, to support third-party + users' custom modules. + + To create a registry (e.g. a backbone registry): + + .. code-block:: python + + BACKBONE_REGISTRY = Registry('BACKBONE') + + To register an object: + + .. 
code-block:: python + + @BACKBONE_REGISTRY.register() + class MyBackbone(): + ... + + Or: + + .. code-block:: python + + BACKBONE_REGISTRY.register(MyBackbone) + """ + + def __init__(self, name): + """ + Args: + name (str): the name of this registry + """ + self._name = name + self._obj_map = {} + + def _do_register(self, name, obj): + assert (name not in self._obj_map), (f"An object named '{name}' was already registered " + f"in '{self._name}' registry!") + self._obj_map[name] = obj + + def register(self, obj=None): + """ + Register the given object under the the name `obj.__name__`. + Can be used as either a decorator or not. + See docstring of this class for usage. + """ + if obj is None: + # used as a decorator + def deco(func_or_class): + name = func_or_class.__name__ + self._do_register(name, func_or_class) + return func_or_class + + return deco + + # used as a function call + name = obj.__name__ + self._do_register(name, obj) + + def get(self, name): + ret = self._obj_map.get(name) + if ret is None: + raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") + return ret + + def __contains__(self, name): + return name in self._obj_map + + def __iter__(self): + return iter(self._obj_map.items()) + + def keys(self): + return self._obj_map.keys() + + +DATASET_REGISTRY = Registry('dataset') +ARCH_REGISTRY = Registry('arch') +MODEL_REGISTRY = Registry('model') +LOSS_REGISTRY = Registry('loss') +METRIC_REGISTRY = Registry('metric') diff --git a/scripts/basicsr/utils/video_util.py b/scripts/basicsr/utils/video_util.py new file mode 100644 index 0000000000000000000000000000000000000000..20a2ff14c4016b4ec543051471fc930ad71d83f9 --- /dev/null +++ b/scripts/basicsr/utils/video_util.py @@ -0,0 +1,125 @@ +''' +The code is modified from the Real-ESRGAN: +https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan_video.py + +''' +import cv2 +import sys +import numpy as np + +try: + import ffmpeg +except ImportError: + import pip + pip.main(['install', '--user', 'ffmpeg-python']) + import ffmpeg + +def get_video_meta_info(video_path): + ret = {} + probe = ffmpeg.probe(video_path) + video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video'] + has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams']) + ret['width'] = video_streams[0]['width'] + ret['height'] = video_streams[0]['height'] + ret['fps'] = eval(video_streams[0]['avg_frame_rate']) + ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None + ret['nb_frames'] = int(video_streams[0]['nb_frames']) + return ret + +class VideoReader: + def __init__(self, video_path): + self.paths = [] # for image&folder type + self.audio = None + try: + self.stream_reader = ( + ffmpeg.input(video_path).output('pipe:', format='rawvideo', pix_fmt='bgr24', + loglevel='error').run_async( + pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) + except FileNotFoundError: + print('Please install ffmpeg (not ffmpeg-python) by running\n', + '\t$ conda install -c conda-forge ffmpeg') + sys.exit(0) + + meta = get_video_meta_info(video_path) + self.width = meta['width'] + self.height = meta['height'] + self.input_fps = meta['fps'] + self.audio = meta['audio'] + self.nb_frames = meta['nb_frames'] + + self.idx = 0 + + def get_resolution(self): + return self.height, self.width + + def get_fps(self): + if self.input_fps is not None: + return self.input_fps + return 24 + + def get_audio(self): + return self.audio + + def __len__(self): + return self.nb_frames + + def 
get_frame_from_stream(self): + img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3) # 3 bytes for one pixel + if not img_bytes: + return None + img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3]) + return img + + def get_frame_from_list(self): + if self.idx >= self.nb_frames: + return None + img = cv2.imread(self.paths[self.idx]) + self.idx += 1 + return img + + def get_frame(self): + return self.get_frame_from_stream() + + + def close(self): + self.stream_reader.stdin.close() + self.stream_reader.wait() + + +class VideoWriter: + def __init__(self, video_save_path, height, width, fps, audio): + if height > 2160: + print('You are generating video that is larger than 4K, which will be very slow due to IO speed.', + 'We highly recommend to decrease the outscale(aka, -s).') + if audio is not None: + self.stream_writer = ( + ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}', + framerate=fps).output( + audio, + video_save_path, + pix_fmt='yuv420p', + vcodec='libx264', + loglevel='error', + acodec='copy').overwrite_output().run_async( + pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) + else: + self.stream_writer = ( + ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}', + framerate=fps).output( + video_save_path, pix_fmt='yuv420p', vcodec='libx264', + loglevel='error').overwrite_output().run_async( + pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) + + def write_frame(self, frame): + try: + frame = frame.astype(np.uint8).tobytes() + self.stream_writer.stdin.write(frame) + except BrokenPipeError: + print('Please re-install ffmpeg and libx264 by running\n', + '\t$ conda install -c conda-forge ffmpeg\n', + '\t$ conda install -c conda-forge x264') + sys.exit(0) + + def close(self): + self.stream_writer.stdin.close() + self.stream_writer.wait() \ No newline at end of file diff --git a/scripts/basicsr/version.py b/scripts/basicsr/version.py new file mode 100644 index 0000000000000000000000000000000000000000..773b6c85a224e0c5d2753a20e5a74204f5da59cb --- /dev/null +++ b/scripts/basicsr/version.py @@ -0,0 +1,5 @@ +# GENERATED VERSION FILE +# TIME: Wed Apr 3 11:07:36 2024 +__version__ = '1.3.2' +__gitsha__ = '8392d03' +version_info = (1, 3, 2) diff --git a/scripts/briarmbg.py b/scripts/briarmbg.py new file mode 100644 index 0000000000000000000000000000000000000000..183dca77dbb42bfed714af12022d6ed09eef9035 --- /dev/null +++ b/scripts/briarmbg.py @@ -0,0 +1,456 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from huggingface_hub import PyTorchModelHubMixin + +class REBNCONV(nn.Module): + def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1): + super(REBNCONV,self).__init__() + + self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride) + self.bn_s1 = nn.BatchNorm2d(out_ch) + self.relu_s1 = nn.ReLU(inplace=True) + + def forward(self,x): + + hx = x + xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) + + return xout + +## upsample tensor 'src' to have the same spatial size with tensor 'tar' +def _upsample_like(src,tar): + + src = F.interpolate(src,size=tar.shape[2:],mode='bilinear') + + return src + + +### RSU-7 ### +class RSU7(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512): + super(RSU7,self).__init__() + + self.in_ch = in_ch + self.mid_ch = mid_ch + self.out_ch = out_ch + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2 + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.pool1 = 
nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + b, c, h, w = x.shape + + hx = x + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + hx = self.pool5(hx5) + + hx6 = self.rebnconv6(hx) + + hx7 = self.rebnconv7(hx6) + + hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1)) + hx6dup = _upsample_like(hx6d,hx5) + + hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1)) + hx5dup = _upsample_like(hx5d,hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + + +### RSU-6 ### +class RSU6(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU6,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + + hx6 = self.rebnconv6(hx5) + + + hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1)) + hx5dup = _upsample_like(hx5d,hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d 
= self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + +### RSU-5 ### +class RSU5(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU5,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + + hx5 = self.rebnconv5(hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + +### RSU-4 ### +class RSU4(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + +### RSU-4F ### +class RSU4F(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4F,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2) + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8) + + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx2 = self.rebnconv2(hx1) + hx3 = self.rebnconv3(hx2) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1)) + hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1)) + hx1d = 
self.rebnconv1d(torch.cat((hx2d,hx1),1)) + + return hx1d + hxin + + +class myrebnconv(nn.Module): + def __init__(self, in_ch=3, + out_ch=1, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=1): + super(myrebnconv,self).__init__() + + self.conv = nn.Conv2d(in_ch, + out_ch, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups) + self.bn = nn.BatchNorm2d(out_ch) + self.rl = nn.ReLU(inplace=True) + + def forward(self,x): + return self.rl(self.bn(self.conv(x))) + + +class BriaRMBG(nn.Module, PyTorchModelHubMixin): + + def __init__(self,config:dict={"in_ch":3,"out_ch":1}): + super(BriaRMBG,self).__init__() + in_ch=config["in_ch"] + out_ch=config["out_ch"] + self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1) + self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage1 = RSU7(64,32,64) + self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage2 = RSU6(64,32,128) + self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage3 = RSU5(128,64,256) + self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage4 = RSU4(256,128,512) + self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage5 = RSU4F(512,256,512) + self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage6 = RSU4F(512,256,512) + + # decoder + self.stage5d = RSU4F(1024,256,512) + self.stage4d = RSU4(1024,128,256) + self.stage3d = RSU5(512,64,128) + self.stage2d = RSU6(256,32,64) + self.stage1d = RSU7(128,16,64) + + self.side1 = nn.Conv2d(64,out_ch,3,padding=1) + self.side2 = nn.Conv2d(64,out_ch,3,padding=1) + self.side3 = nn.Conv2d(128,out_ch,3,padding=1) + self.side4 = nn.Conv2d(256,out_ch,3,padding=1) + self.side5 = nn.Conv2d(512,out_ch,3,padding=1) + self.side6 = nn.Conv2d(512,out_ch,3,padding=1) + + # self.outconv = nn.Conv2d(6*out_ch,out_ch,1) + + def forward(self,x): + + hx = x + + hxin = self.conv_in(hx) + #hx = self.pool_in(hxin) + + #stage 1 + hx1 = self.stage1(hxin) + hx = self.pool12(hx1) + + #stage 2 + hx2 = self.stage2(hx) + hx = self.pool23(hx2) + + #stage 3 + hx3 = self.stage3(hx) + hx = self.pool34(hx3) + + #stage 4 + hx4 = self.stage4(hx) + hx = self.pool45(hx4) + + #stage 5 + hx5 = self.stage5(hx) + hx = self.pool56(hx5) + + #stage 6 + hx6 = self.stage6(hx) + hx6up = _upsample_like(hx6,hx5) + + #-------------------- decoder -------------------- + hx5d = self.stage5d(torch.cat((hx6up,hx5),1)) + hx5dup = _upsample_like(hx5d,hx4) + + hx4d = self.stage4d(torch.cat((hx5dup,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.stage3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.stage2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.stage1d(torch.cat((hx2dup,hx1),1)) + + + #side output + d1 = self.side1(hx1d) + d1 = _upsample_like(d1,x) + + d2 = self.side2(hx2d) + d2 = _upsample_like(d2,x) + + d3 = self.side3(hx3d) + d3 = _upsample_like(d3,x) + + d4 = self.side4(hx4d) + d4 = _upsample_like(d4,x) + + d5 = self.side5(hx5d) + d5 = _upsample_like(d5,x) + + d6 = self.side6(hx6) + d6 = _upsample_like(d6,x) + + return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6] \ No newline at end of file diff --git a/scripts/colorizer.py b/scripts/colorizer.py new file mode 100644 index 0000000000000000000000000000000000000000..91eb9df39719aff2c90fc455ff224f09b828acb7 --- /dev/null +++ b/scripts/colorizer.py @@ -0,0 +1,36 @@ +import cv2 +import numpy as np +import 
onnxruntime + +def colorize_image(input_image): + + ort_session = onnxruntime.InferenceSession('models/deoldify_artistic.onnx') + input_name = ort_session.get_inputs()[0].name + output_name = ort_session.get_outputs()[0].name + + # Preprocess image + temp_frame = cv2.cvtColor(input_image, cv2.COLOR_BGR2GRAY) + temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) + temp_frame = cv2.resize(temp_frame, (256, 256)) + temp_frame = temp_frame.transpose((2, 0, 1)) + temp_frame = np.expand_dims(temp_frame, axis=0).astype(np.float32) + + # Run inference + ort_outs = ort_session.run([output_name], {input_name: temp_frame}) + result = ort_outs[0][0] + + # Postprocess result + colorized_frame = result.transpose(1, 2, 0) + colorized_frame = cv2.resize(colorized_frame, (input_image.shape[1], input_image.shape[0])) + temp_blue_channel, _, _ = cv2.split(input_image) + colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_BGR2RGB).astype(np.uint8) + colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_BGR2LAB) + _, color_green_channel, color_red_channel = cv2.split(colorized_frame) + colorized_frame = cv2.merge((temp_blue_channel, color_green_channel, color_red_channel)) + colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_LAB2BGR) + + return colorized_frame.astype(np.uint8) + +if __name__ == "__main__": + pass + diff --git a/scripts/common.py b/scripts/common.py new file mode 100644 index 0000000000000000000000000000000000000000..002901428cedc2b9556c92a63b1aec63b22190c4 --- /dev/null +++ b/scripts/common.py @@ -0,0 +1,47 @@ +from PIL import Image +import matplotlib.pyplot as plt + + +# Log images +def log_image(x, opts): + return tensor2im(x) + + +def tensor2im(var): + var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() + var = ((var + 1) / 2) + var[var < 0] = 0 + var[var > 1] = 1 + var = var * 255 + return Image.fromarray(var.astype('uint8')) + + +def vis_faces(log_hooks): + display_count = len(log_hooks) + fig = plt.figure(figsize=(12, 4 * display_count)) + gs = fig.add_gridspec(display_count, 4) + for i in range(display_count): + hooks_dict = log_hooks[i] + vis_faces_with_age(hooks_dict, fig, gs, i) + plt.tight_layout() + return fig + + +def vis_faces_with_age(hooks_dict, fig, gs, i): + fig.add_subplot(gs[i, 0]) + plt.imshow(hooks_dict['input_face']) + plt.title('Input\nOut Sim={:.2f}\nInput Age={:.2f}'.format(float(hooks_dict['diff_input_real']), + float(hooks_dict['input_age_real']))) + fig.add_subplot(gs[i, 1]) + plt.imshow(hooks_dict['target_face']) + plt.title('Target\nIn={:.2f},Out={:.2f}\nTarget Age={:.2f}'.format(float(hooks_dict['diff_views_real']), + float(hooks_dict['diff_target_real']), + float(hooks_dict['target_age_real']))) + fig.add_subplot(gs[i, 2]) + plt.imshow(hooks_dict['output_face']) + plt.title('Output\nTarget Sim={:.2f}\nOuput Age={:.2f}'.format(float(hooks_dict['diff_target_real']), + float(hooks_dict['output_age_real']))) + fig.add_subplot(gs[i, 3]) + plt.imshow(hooks_dict['recovered_face']) + plt.title('Recovered\nTarget Sim={:.2f}\nOuput Age={:.2f}'.format(float(hooks_dict['diff_target_cycle']), + float(hooks_dict['output_age_cycle']))) diff --git a/scripts/config/auido2exp.yaml b/scripts/config/auido2exp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7369dbf350476e14a1d600507f1f8b7d8aa6ecd3 --- /dev/null +++ b/scripts/config/auido2exp.yaml @@ -0,0 +1,58 @@ +DATASET: + TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt + EVAL_FILE_LIST: 
/apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt + TRAIN_BATCH_SIZE: 32 + EVAL_BATCH_SIZE: 32 + EXP: True + EXP_DIM: 64 + FRAME_LEN: 32 + COEFF_LEN: 73 + NUM_CLASSES: 46 + AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav + COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm + LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb + DEBUG: True + NUM_REPEATS: 2 + T: 40 + + +MODEL: + FRAMEWORK: V2 + AUDIOENCODER: + LEAKY_RELU: True + NORM: 'IN' + DISCRIMINATOR: + LEAKY_RELU: False + INPUT_CHANNELS: 6 + CVAE: + AUDIO_EMB_IN_SIZE: 512 + AUDIO_EMB_OUT_SIZE: 128 + SEQ_LEN: 32 + LATENT_SIZE: 256 + ENCODER_LAYER_SIZES: [192, 1024] + DECODER_LAYER_SIZES: [1024, 192] + + +TRAIN: + MAX_EPOCH: 300 + GENERATOR: + LR: 2.0e-5 + DISCRIMINATOR: + LR: 1.0e-5 + LOSS: + W_FEAT: 0 + W_COEFF_EXP: 2 + W_LM: 1.0e-2 + W_LM_MOUTH: 0 + W_REG: 0 + W_SYNC: 0 + W_COLOR: 0 + W_EXPRESSION: 0 + W_LIPREADING: 0.01 + W_LIPREADING_VV: 0 + W_EYE_BLINK: 4 + +TAG: + NAME: small_dataset + + diff --git a/scripts/config/auido2pose.yaml b/scripts/config/auido2pose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bc61f94d12f406f2d8d02545e55b61075051484d --- /dev/null +++ b/scripts/config/auido2pose.yaml @@ -0,0 +1,49 @@ +DATASET: + TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt + EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt + TRAIN_BATCH_SIZE: 64 + EVAL_BATCH_SIZE: 1 + EXP: True + EXP_DIM: 64 + FRAME_LEN: 32 + COEFF_LEN: 73 + NUM_CLASSES: 46 + AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav + COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb + DEBUG: True + + +MODEL: + AUDIOENCODER: + LEAKY_RELU: True + NORM: 'IN' + DISCRIMINATOR: + LEAKY_RELU: False + INPUT_CHANNELS: 6 + CVAE: + AUDIO_EMB_IN_SIZE: 512 + AUDIO_EMB_OUT_SIZE: 6 + SEQ_LEN: 32 + LATENT_SIZE: 64 + ENCODER_LAYER_SIZES: [192, 128] + DECODER_LAYER_SIZES: [128, 192] + + +TRAIN: + MAX_EPOCH: 150 + GENERATOR: + LR: 1.0e-4 + DISCRIMINATOR: + LR: 1.0e-4 + LOSS: + LAMBDA_REG: 1 + LAMBDA_LANDMARKS: 0 + LAMBDA_VERTICES: 0 + LAMBDA_GAN_MOTION: 0.7 + LAMBDA_GAN_COEFF: 0 + LAMBDA_KL: 1 + +TAG: + NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder + + diff --git a/scripts/config/facerender.yaml b/scripts/config/facerender.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9494ef82dfa16b16b7aa0b848ebdd6b23e739e2a --- /dev/null +++ b/scripts/config/facerender.yaml @@ -0,0 +1,45 @@ +model_params: + common_params: + num_kp: 15 + image_channel: 3 + feature_channel: 32 + estimate_jacobian: False # True + kp_detector_params: + temperature: 0.1 + block_expansion: 32 + max_features: 1024 + scale_factor: 0.25 # 0.25 + num_blocks: 5 + reshape_channel: 16384 # 16384 = 1024 * 16 + reshape_depth: 16 + he_estimator_params: + block_expansion: 64 + max_features: 2048 + num_bins: 66 + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 2 + reshape_channel: 32 + reshape_depth: 16 # 512 = 32 * 16 + num_resblocks: 6 + estimate_occlusion_map: True + dense_motion_params: + block_expansion: 32 + max_features: 1024 + num_blocks: 5 + reshape_depth: 16 + compress: 4 + discriminator_params: + scales: [1] + block_expansion: 32 + max_features: 512 + num_blocks: 4 + sn: True + mapping_params: + coeff_nc: 70 + descriptor_nc: 1024 + layer: 3 + num_kp: 15 + num_bins: 66 + diff --git 
a/scripts/config/facerender_still.yaml b/scripts/config/facerender_still.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6b4d66dade3e655ac4cfc25a994ca28e53821d80 --- /dev/null +++ b/scripts/config/facerender_still.yaml @@ -0,0 +1,45 @@ +model_params: + common_params: + num_kp: 15 + image_channel: 3 + feature_channel: 32 + estimate_jacobian: False # True + kp_detector_params: + temperature: 0.1 + block_expansion: 32 + max_features: 1024 + scale_factor: 0.25 # 0.25 + num_blocks: 5 + reshape_channel: 16384 # 16384 = 1024 * 16 + reshape_depth: 16 + he_estimator_params: + block_expansion: 64 + max_features: 2048 + num_bins: 66 + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 2 + reshape_channel: 32 + reshape_depth: 16 # 512 = 32 * 16 + num_resblocks: 6 + estimate_occlusion_map: True + dense_motion_params: + block_expansion: 32 + max_features: 1024 + num_blocks: 5 + reshape_depth: 16 + compress: 4 + discriminator_params: + scales: [1] + block_expansion: 32 + max_features: 512 + num_blocks: 4 + sn: True + mapping_params: + coeff_nc: 73 + descriptor_nc: 1024 + layer: 3 + num_kp: 15 + num_bins: 66 + diff --git a/scripts/config/similarity_Lm3D_all.mat b/scripts/config/similarity_Lm3D_all.mat new file mode 100644 index 0000000000000000000000000000000000000000..a0e23588302bc71fc899eef53ff06df5f4df4c1d Binary files /dev/null and b/scripts/config/similarity_Lm3D_all.mat differ diff --git a/scripts/encoders/__init__.py b/scripts/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/encoders/__pycache__/__init__.cpython-310.pyc b/scripts/encoders/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f3052d26ec1d385f65c3497b68fea39ec1a8d19 Binary files /dev/null and b/scripts/encoders/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/encoders/__pycache__/helpers.cpython-310.pyc b/scripts/encoders/__pycache__/helpers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69df3329e4dc676a7161d42b9cd2c08b3769b3ac Binary files /dev/null and b/scripts/encoders/__pycache__/helpers.cpython-310.pyc differ diff --git a/scripts/encoders/__pycache__/psp_encoders.cpython-310.pyc b/scripts/encoders/__pycache__/psp_encoders.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a528ca00662d5db4536ac6c79b56d364722fdc95 Binary files /dev/null and b/scripts/encoders/__pycache__/psp_encoders.cpython-310.pyc differ diff --git a/scripts/encoders/helpers.py b/scripts/encoders/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..b51fdf97141407fcc1c9d249a086ddbfd042469f --- /dev/null +++ b/scripts/encoders/helpers.py @@ -0,0 +1,119 @@ +from collections import namedtuple +import torch +from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module + +""" +ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Flatten(Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +def l2_norm(input, axis=1): + norm = torch.norm(input, 2, axis, True) + output = torch.div(input, norm) + return output + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + """ A named tuple describing a ResNet block. 
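+
+    `in_channel` and `depth` are the block's input/output channel counts and
+    `stride` controls downsampling; get_block() below chains these namedtuples
+    into one ResNet stage.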
""" + + +def get_block(in_channel, depth, num_units, stride=2): + return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3) + ] + else: + raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) + return blocks + + +class SEModule(Module): + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2d(1) + self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class bottleneck_IR(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class bottleneck_IR_SE(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR_SE, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth), + SEModule(depth, 16) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut diff --git a/scripts/encoders/model_irse.py b/scripts/encoders/model_irse.py new file mode 100644 index 0000000000000000000000000000000000000000..461ba1e332dceabfb59de711dbd408ecfd2cae9e --- /dev/null +++ b/scripts/encoders/model_irse.py @@ -0,0 +1,48 @@ +from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module +from scripts.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm + +""" +Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Backbone(Module): 
+ def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): + super(Backbone, self).__init__() + assert input_size in [112, 224], "input_size should be 112 or 224" + assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" + assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + if input_size == 112: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 7 * 7, 512), + BatchNorm1d(512, affine=affine)) + else: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 14 * 14, 512), + BatchNorm1d(512, affine=affine)) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return l2_norm(x) diff --git a/scripts/encoders/psp_encoders.py b/scripts/encoders/psp_encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..cf0489c8dc15271a1057520ef42c84d13373b4a5 --- /dev/null +++ b/scripts/encoders/psp_encoders.py @@ -0,0 +1,114 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module + +from scripts.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE +from scripts.stylegan2.model import EqualLinear + + +class GradualStyleBlock(Module): + def __init__(self, in_c, out_c, spatial): + super(GradualStyleBlock, self).__init__() + self.out_c = out_c + self.spatial = spatial + num_pools = int(np.log2(spatial)) + modules = [] + modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] + for i in range(num_pools - 1): + modules += [ + Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU() + ] + self.convs = nn.Sequential(*modules) + self.linear = EqualLinear(out_c, out_c, lr_mul=1) + + def forward(self, x): + x = self.convs(x) + x = x.view(-1, self.out_c) + x = self.linear(x) + return x + + +class GradualStyleEncoder(Module): + def __init__(self, num_layers, mode='ir', n_styles=18, opts=None): + super(GradualStyleEncoder, self).__init__() + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + self.styles = nn.ModuleList() + self.style_count = n_styles + self.coarse_ind = 3 + self.middle_ind = 7 + for i in range(self.style_count): + if i < self.coarse_ind: + style = GradualStyleBlock(512, 512, 16) + elif i < self.middle_ind: + style = GradualStyleBlock(512, 512, 32) + else: + style = GradualStyleBlock(512, 512, 64) + self.styles.append(style) + self.latlayer1 = nn.Conv2d(256, 
512, kernel_size=1, stride=1, padding=0) + self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) + + def _upsample_add(self, x, y): + '''Upsample and add two feature maps. + Args: + x: (Variable) top feature map to be upsampled. + y: (Variable) lateral feature map. + Returns: + (Variable) added feature map. + Note in PyTorch, when input size is odd, the upsampled feature map + with `F.upsample(..., scale_factor=2, mode='nearest')` + maybe not equal to the lateral feature map size. + e.g. + original input size: [N,_,15,15] -> + conv2d feature map size: [N,_,8,8] -> + upsampled feature map size: [N,_,16,16] + So we choose bilinear upsample which supports arbitrary output sizes. + ''' + _, _, H, W = y.size() + return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y + + def forward(self, x): + x = self.input_layer(x) + + latents = [] + modulelist = list(self.body._modules.values()) + for i, l in enumerate(modulelist): + x = l(x) + if i == 6: + c1 = x + elif i == 20: + c2 = x + elif i == 23: + c3 = x + + for j in range(self.coarse_ind): + latents.append(self.styles[j](c3)) + + p2 = self._upsample_add(c3, self.latlayer1(c2)) + for j in range(self.coarse_ind, self.middle_ind): + latents.append(self.styles[j](p2)) + + p1 = self._upsample_add(p2, self.latlayer2(c1)) + for j in range(self.middle_ind, self.style_count): + latents.append(self.styles[j](p1)) + + out = torch.stack(latents, dim=1) + return out diff --git a/scripts/erase_scratches.py b/scripts/erase_scratches.py new file mode 100644 index 0000000000000000000000000000000000000000..0dddd2b4728aded81e3aafbac7baf1c199c34e81 --- /dev/null +++ b/scripts/erase_scratches.py @@ -0,0 +1,71 @@ +############################################################################# +# +# Source from: +# https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life +# Forked from: +# +# Reimplemented by: Leonel Hernández +# +############################################################################## +import logging +import os.path + +import PIL.Image +import cv2 +import numpy as np +import torch +from torchvision.transforms import transforms + +from scripts.erasescratches.models import Pix2PixHDModel_Mapping +from scripts.erasescratches.options import Options +from scripts.maskscratches import ScratchesDetector +from scripts.util import irregular_hole_synthesize, tensor_to_ndarray + +REPO_ID = "leonelhs/zeroscratches" + + +class EraseScratches: + + def __init__(self): + + snapshot_folder = './models/zeroscratches' + model_path = os.path.join(snapshot_folder, "restoration") + self.detector = ScratchesDetector(snapshot_folder) + gpu_ids = [] + if torch.cuda.is_available(): + gpu_ids = [d for d in range(torch.cuda.device_count())] + self.options = Options(model_path, gpu_ids) + self.model_scratches = Pix2PixHDModel_Mapping() + self.model_scratches.initialize(self.options) + self.model_scratches.eval() + + def erase(self, image) -> np.array: + transformed, mask = self.detector.process(image) + logging.info("Start erasing scratches") + + img_transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + ) + mask_transform = transforms.ToTensor() + + if self.options.mask_dilation != 0: + kernel = np.ones((3, 3), np.uint8) + mask = np.array(mask) + mask = cv2.dilate(mask, kernel, iterations=self.options.mask_dilation) + mask = PIL.Image.fromarray(mask.astype('uint8')) + + transformed = irregular_hole_synthesize(transformed, mask) + mask = mask_transform(mask) + mask = 
mask[:1, :, :] # Convert to single channel + mask = mask.unsqueeze(0) + transformed = img_transform(transformed) + transformed = transformed.unsqueeze(0) + + try: + with torch.no_grad(): + generated = self.model_scratches.inference(transformed, mask) + except Exception as ex: + raise TypeError("Skip photo due to an error:\n%s" % str(ex)) + + tensor_restored = (generated.data.cpu() + 1.0) / 2.0 + return tensor_to_ndarray(tensor_restored) diff --git a/scripts/erasescratches/__init__.py b/scripts/erasescratches/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..21360c9d459717a41f94ea02386ac0b20bc53b7d --- /dev/null +++ b/scripts/erasescratches/__init__.py @@ -0,0 +1 @@ +from .models.mapping_model import Pix2PixHDModel_Mapping diff --git a/scripts/erasescratches/__pycache__/__init__.cpython-310.pyc b/scripts/erasescratches/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a26e317a5252134aae165eb81e80c0d1ca7a563 Binary files /dev/null and b/scripts/erasescratches/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/erasescratches/__pycache__/__init__.cpython-311.pyc b/scripts/erasescratches/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..506eae825560ac7ba2810b2517bed4bec9eed1ec Binary files /dev/null and b/scripts/erasescratches/__pycache__/__init__.cpython-311.pyc differ diff --git a/scripts/erasescratches/__pycache__/options.cpython-310.pyc b/scripts/erasescratches/__pycache__/options.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f35572f7e0fc3d77efc72e11380e0c4787edb4b Binary files /dev/null and b/scripts/erasescratches/__pycache__/options.cpython-310.pyc differ diff --git a/scripts/erasescratches/__pycache__/options.cpython-311.pyc b/scripts/erasescratches/__pycache__/options.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a74852cf7d7e8a092d24c23096135556b07451b6 Binary files /dev/null and b/scripts/erasescratches/__pycache__/options.cpython-311.pyc differ diff --git a/scripts/erasescratches/models/NonLocal_feature_mapping_model.py b/scripts/erasescratches/models/NonLocal_feature_mapping_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f3aaf084fc4a2d06186908af02c533897ae310e9 --- /dev/null +++ b/scripts/erasescratches/models/NonLocal_feature_mapping_model.py @@ -0,0 +1,192 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import logging +import torch.nn as nn +from . 
import networks + + +class Mapping_Model_with_mask(nn.Module): + def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None): + super(Mapping_Model_with_mask, self).__init__() + + norm_layer = networks.get_norm_layer(norm_type=norm) + activation = nn.ReLU(True) + model = [] + + tmp_nc = 64 + n_up = 4 + + for i in range(n_up): + ic = min(tmp_nc * (2 ** i), mc) + oc = min(tmp_nc * (2 ** (i + 1)), mc) + model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] + + self.before_NL = nn.Sequential(*model) + + if opt.NL_res: + self.NL = networks.NonLocalBlock2D_with_mask_Res( + mc, + mc, + opt.NL_fusion_method, + opt.correlation_renormalize, + opt.softmax_temperature, + opt.use_self, + opt.cosin_similarity, + ) + print("You are using NL + Res") + + model = [] + for i in range(n_blocks): + model += [ + networks.ResnetBlock( + mc, + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + dilation=opt.mapping_net_dilation, + ) + ] + + for i in range(n_up - 1): + ic = min(64 * (2 ** (4 - i)), mc) + oc = min(64 * (2 ** (3 - i)), mc) + model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] + model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)] + if opt.feat_dim > 0 and opt.feat_dim < 64: + model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)] + # model += [nn.Conv2d(64, 1, 1, 1, 0)] + self.after_NL = nn.Sequential(*model) + + def forward(self, input, mask): + x1 = self.before_NL(input) + del input + x2 = self.NL(x1, mask) + del x1, mask + x3 = self.after_NL(x2) + del x2 + + return x3 + + +class Mapping_Model_with_mask_2(nn.Module): # Multi-Scale Patch Attention + def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None): + super(Mapping_Model_with_mask_2, self).__init__() + + norm_layer = networks.get_norm_layer(norm_type=norm) + activation = nn.ReLU(True) + model = [] + + tmp_nc = 64 + n_up = 4 + + for i in range(n_up): + ic = min(tmp_nc * (2 ** i), mc) + oc = min(tmp_nc * (2 ** (i + 1)), mc) + model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] + + for i in range(2): + model += [ + networks.ResnetBlock( + mc, + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + dilation=opt.mapping_net_dilation, + ) + ] + + logging.info("Mapping: You are using multi-scale patch attention, conv combine + mask input") + + self.before_NL = nn.Sequential(*model) + + if opt.mapping_exp == 1: + self.NL_scale_1 = networks.Patch_Attention_4(mc, mc, 8) + + model = [] + for i in range(2): + model += [ + networks.ResnetBlock( + mc, + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + dilation=opt.mapping_net_dilation, + ) + ] + + self.res_block_1 = nn.Sequential(*model) + + if opt.mapping_exp == 1: + self.NL_scale_2 = networks.Patch_Attention_4(mc, mc, 4) + + model = [] + for i in range(2): + model += [ + networks.ResnetBlock( + mc, + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + dilation=opt.mapping_net_dilation, + ) + ] + + self.res_block_2 = nn.Sequential(*model) + + if opt.mapping_exp == 1: + self.NL_scale_3 = networks.Patch_Attention_4(mc, mc, 2) + # self.NL_scale_3=networks.Patch_Attention_2(mc,mc,2) + + model = [] + for i in range(2): + model += [ + networks.ResnetBlock( + mc, + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + dilation=opt.mapping_net_dilation, + ) + ] + + for i in range(n_up - 1): + ic = 
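Both mapping models above share the same widen-then-narrow channel schedule; a quick standalone check of the widths it produces, assuming `mc` arrives as `map_mc = 512` (the value `Pix2PixHDModel_Mapping` passes in):

```python
# min(64 * 2**i, mc) widening convs, then the mirrored narrowing path.
mc, tmp_nc, n_up = 512, 64, 4
down = [(min(tmp_nc * 2 ** i, mc), min(tmp_nc * 2 ** (i + 1), mc))
        for i in range(n_up)]
up = [(min(64 * 2 ** (4 - i), mc), min(64 * 2 ** (3 - i), mc))
      for i in range(n_up - 1)]
print(down)  # [(64, 128), (128, 256), (256, 512), (512, 512)]
print(up)    # [(512, 512), (512, 256), (256, 128)]
```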
min(64 * (2 ** (4 - i)), mc) + oc = min(64 * (2 ** (3 - i)), mc) + model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] + model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)] + if opt.feat_dim > 0 and opt.feat_dim < 64: + model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)] + # model += [nn.Conv2d(64, 1, 1, 1, 0)] + self.after_NL = nn.Sequential(*model) + + def forward(self, input, mask): + x1 = self.before_NL(input) + x2 = self.NL_scale_1(x1, mask) + x3 = self.res_block_1(x2) + x4 = self.NL_scale_2(x3, mask) + x5 = self.res_block_2(x4) + x6 = self.NL_scale_3(x5, mask) + x7 = self.after_NL(x6) + return x7 + + def inference_forward(self, input, mask): + x1 = self.before_NL(input) + del input + x2 = self.NL_scale_1.inference_forward(x1, mask) + del x1 + x3 = self.res_block_1(x2) + del x2 + x4 = self.NL_scale_2.inference_forward(x3, mask) + del x3 + x5 = self.res_block_2(x4) + del x4 + x6 = self.NL_scale_3.inference_forward(x5, mask) + del x5 + x7 = self.after_NL(x6) + del x6 + return x7 diff --git a/scripts/erasescratches/models/__init__.py b/scripts/erasescratches/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cead806d103ab5b2e54a544314c41f3474a1bbab --- /dev/null +++ b/scripts/erasescratches/models/__init__.py @@ -0,0 +1 @@ +from .mapping_model import Pix2PixHDModel_Mapping diff --git a/scripts/erasescratches/models/__pycache__/NonLocal_feature_mapping_model.cpython-310.pyc b/scripts/erasescratches/models/__pycache__/NonLocal_feature_mapping_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d372d10969b74684b3a7fc6fc83b8123bb1e473 Binary files /dev/null and b/scripts/erasescratches/models/__pycache__/NonLocal_feature_mapping_model.cpython-310.pyc differ diff --git a/scripts/erasescratches/models/__pycache__/NonLocal_feature_mapping_model.cpython-311.pyc b/scripts/erasescratches/models/__pycache__/NonLocal_feature_mapping_model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcd6159eb0faf68f94ff40235f36cce3b23107a8 Binary files /dev/null and b/scripts/erasescratches/models/__pycache__/NonLocal_feature_mapping_model.cpython-311.pyc differ diff --git a/scripts/erasescratches/models/__pycache__/__init__.cpython-310.pyc b/scripts/erasescratches/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..807ce794c599a393ab1b4c1e57f502a0d9dfbb83 Binary files /dev/null and b/scripts/erasescratches/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/erasescratches/models/__pycache__/__init__.cpython-311.pyc b/scripts/erasescratches/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57c2a7e0e5721ca31019a617f75a4bc5540e0008 Binary files /dev/null and b/scripts/erasescratches/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/scripts/erasescratches/models/__pycache__/base_model.cpython-310.pyc b/scripts/erasescratches/models/__pycache__/base_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f952cda65a7b441587277416dcb72a02f2767c3f Binary files /dev/null and b/scripts/erasescratches/models/__pycache__/base_model.cpython-310.pyc differ diff --git a/scripts/erasescratches/models/__pycache__/base_model.cpython-311.pyc b/scripts/erasescratches/models/__pycache__/base_model.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..03f2c149032cb8773b3819b79eb5ea0706516445 Binary files /dev/null and b/scripts/erasescratches/models/__pycache__/base_model.cpython-311.pyc differ diff --git a/scripts/erasescratches/models/__pycache__/mapping_model.cpython-310.pyc b/scripts/erasescratches/models/__pycache__/mapping_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bd08185f3cd2f18a6fa59d904d73e6f93eb1d58 Binary files /dev/null and b/scripts/erasescratches/models/__pycache__/mapping_model.cpython-310.pyc differ diff --git a/scripts/erasescratches/models/__pycache__/mapping_model.cpython-311.pyc b/scripts/erasescratches/models/__pycache__/mapping_model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..938ed61c2dd9089b5dd6a0289f5a6b20d9976912 Binary files /dev/null and b/scripts/erasescratches/models/__pycache__/mapping_model.cpython-311.pyc differ diff --git a/scripts/erasescratches/models/__pycache__/networks.cpython-310.pyc b/scripts/erasescratches/models/__pycache__/networks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c9a1015efb9b2f73f3f28d4ae31804e96415031 Binary files /dev/null and b/scripts/erasescratches/models/__pycache__/networks.cpython-310.pyc differ diff --git a/scripts/erasescratches/models/__pycache__/networks.cpython-311.pyc b/scripts/erasescratches/models/__pycache__/networks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3384fe434d616b4eea47d0be458cc659333b72e8 Binary files /dev/null and b/scripts/erasescratches/models/__pycache__/networks.cpython-311.pyc differ diff --git a/scripts/erasescratches/models/base_model.py b/scripts/erasescratches/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4043116050e057f31099cda3ecae6ee3fa46cb2a --- /dev/null +++ b/scripts/erasescratches/models/base_model.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
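`base_model.py` below spends most of `load_network` tolerating checkpoints whose keys or shapes disagree with the model. A condensed sketch of that partial-loading pattern (the helper name is illustrative, not part of the source):

```python
import torch

def load_partial_state(network: torch.nn.Module, checkpoint_path: str) -> None:
    """Load only the checkpoint entries whose names and shapes match the model."""
    pretrained = torch.load(checkpoint_path, map_location="cpu")
    model_state = network.state_dict()
    compatible = {k: v for k, v in pretrained.items()
                  if k in model_state and v.size() == model_state[k].size()}
    model_state.update(compatible)
    network.load_state_dict(model_state)
```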
+ +import os +import torch +import sys + + +class BaseModel(torch.nn.Module): + def name(self): + return "BaseModel" + + def initialize(self, opt): + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) + + def set_input(self, input): + self.input = input + + def forward(self): + pass + + # used in test time, no backprop + def test(self): + pass + + def get_image_paths(self): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + return self.input + + def get_current_errors(self): + return {} + + def save(self, label): + pass + + # helper saving function that can be used by subclasses + def save_network(self, network, network_label, epoch_label, gpu_ids): + save_filename = "%s_net_%s.pth" % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + torch.save(network.cpu().state_dict(), save_path) + if len(gpu_ids) and torch.cuda.is_available(): + network.cuda() + + def save_optimizer(self, optimizer, optimizer_label, epoch_label): + save_filename = "%s_optimizer_%s.pth" % (epoch_label, optimizer_label) + save_path = os.path.join(self.save_dir, save_filename) + torch.save(optimizer.state_dict(), save_path) + + def load_optimizer(self, optimizer, optimizer_label, epoch_label, save_dir=""): + save_filename = "%s_optimizer_%s.pth" % (epoch_label, optimizer_label) + if not save_dir: + save_dir = self.save_dir + save_path = os.path.join(save_dir, save_filename) + + if not os.path.isfile(save_path): + print("%s not exists yet!" % save_path) + else: + optimizer.load_state_dict(torch.load(save_path)) + + # helper loading function that can be used by subclasses + def load_network(self, network, network_label, epoch_label, save_dir=""): + save_filename = "%s_net_%s.pth" % (epoch_label, network_label) + if not save_dir: + save_dir = self.save_dir + + # print(save_dir) + # print(self.save_dir) + save_path = os.path.join(save_dir, save_filename) + if not os.path.isfile(save_path): + print("%s not exists yet!" 
% save_path) + # if network_label == 'G': + # raise('Generator must exist!') + else: + # network.load_state_dict(torch.load(save_path)) + try: + # print(save_path) + network.load_state_dict(torch.load(save_path)) + except: + pretrained_dict = torch.load(save_path) + model_dict = network.state_dict() + try: + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + network.load_state_dict(pretrained_dict) + # if self.opt.verbose: + print( + "Pretrained network %s has excessive layers; Only loading layers that are used" + % network_label + ) + except: + print( + "Pretrained network %s has fewer layers; The following are not initialized:" + % network_label + ) + for k, v in pretrained_dict.items(): + if v.size() == model_dict[k].size(): + model_dict[k] = v + + if sys.version_info >= (3, 0): + not_initialized = set() + else: + from sets import Set + + not_initialized = Set() + + for k, v in model_dict.items(): + if k not in pretrained_dict or v.size() != pretrained_dict[k].size(): + not_initialized.add(k.split(".")[0]) + + print(sorted(not_initialized)) + network.load_state_dict(model_dict) + + def update_learning_rate(): + pass diff --git a/scripts/erasescratches/models/mapping_model.py b/scripts/erasescratches/models/mapping_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c6f0d0bd6971f8d5c1946e676b84c7755f942a49 --- /dev/null +++ b/scripts/erasescratches/models/mapping_model.py @@ -0,0 +1,340 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import torch + +from .NonLocal_feature_mapping_model import * +from .base_model import BaseModel +from scripts.util.image_pool import ImagePool + + +class Mapping_Model(nn.Module): + def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None): + super(Mapping_Model, self).__init__() + + norm_layer = networks.get_norm_layer(norm_type=norm) + activation = nn.ReLU(True) + model = [] + tmp_nc = 64 + n_up = 4 + + print("Mapping: You are using the mapping model without global restoration.") + + for i in range(n_up): + ic = min(tmp_nc * (2 ** i), mc) + oc = min(tmp_nc * (2 ** (i + 1)), mc) + model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] + for i in range(n_blocks): + model += [ + networks.ResnetBlock( + mc, + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + dilation=opt.mapping_net_dilation, + ) + ] + + for i in range(n_up - 1): + ic = min(64 * (2 ** (4 - i)), mc) + oc = min(64 * (2 ** (3 - i)), mc) + model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] + model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)] + if opt.feat_dim > 0 and opt.feat_dim < 64: + model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)] + # model += [nn.Conv2d(64, 1, 1, 1, 0)] + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + + +class Pix2PixHDModel_Mapping(BaseModel): + def name(self): + return "Pix2PixHDModel_Mapping" + + def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss, use_smooth_l1, stage_1_feat_l2): + flags = (True, True, use_gan_feat_loss, use_vgg_loss, True, True, use_smooth_l1, stage_1_feat_l2) + + def loss_filter(g_feat_l2, g_gan, g_gan_feat, g_vgg, d_real, d_fake, smooth_l1, stage_1_feat_l2): + return [ + l + for (l, f) in zip( + (g_feat_l2, g_gan, g_gan_feat, g_vgg, d_real, d_fake, smooth_l1, stage_1_feat_l2), flags + ) + if f + ] + + return loss_filter + + def initialize(self, opt): + BaseModel.initialize(self, 
opt) + if opt.resize_or_crop != "none" or not opt.isTrain: + torch.backends.cudnn.benchmark = True + self.isTrain = opt.isTrain + input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc + + ##### define networks + # Generator network + netG_input_nc = input_nc + self.netG_A = networks.GlobalGenerator_DCDCv2( + netG_input_nc, + opt.output_nc, + opt.ngf, + opt.k_size, + opt.n_downsample_global, + networks.get_norm_layer(norm_type=opt.norm), + opt=opt, + ) + self.netG_B = networks.GlobalGenerator_DCDCv2( + netG_input_nc, + opt.output_nc, + opt.ngf, + opt.k_size, + opt.n_downsample_global, + networks.get_norm_layer(norm_type=opt.norm), + opt=opt, + ) + + if opt.non_local == "Setting_42" or opt.NL_use_mask: + if opt.mapping_exp == 1: + self.mapping_net = Mapping_Model_with_mask_2( + min(opt.ngf * 2 ** opt.n_downsample_global, opt.mc), + opt.map_mc, + n_blocks=opt.mapping_n_block, + opt=opt, + ) + else: + self.mapping_net = Mapping_Model_with_mask( + min(opt.ngf * 2 ** opt.n_downsample_global, opt.mc), + opt.map_mc, + n_blocks=opt.mapping_n_block, + opt=opt, + ) + else: + self.mapping_net = Mapping_Model( + min(opt.ngf * 2 ** opt.n_downsample_global, opt.mc), + opt.map_mc, + n_blocks=opt.mapping_n_block, + opt=opt, + ) + + self.mapping_net.apply(networks.weights_init) + + if opt.load_pretrain != "": + self.load_network(self.mapping_net, "mapping_net", opt.which_epoch, opt.load_pretrain) + + if not opt.no_load_VAE: + + self.load_network(self.netG_A, "G", opt.use_vae_which_epoch, opt.load_pretrainA) + self.load_network(self.netG_B, "G", opt.use_vae_which_epoch, opt.load_pretrainB) + for param in self.netG_A.parameters(): + param.requires_grad = False + for param in self.netG_B.parameters(): + param.requires_grad = False + self.netG_A.eval() + self.netG_B.eval() + + if opt.gpu_ids: + self.netG_A.cuda(opt.gpu_ids[0]) + self.netG_B.cuda(opt.gpu_ids[0]) + self.mapping_net.cuda(opt.gpu_ids[0]) + + if not self.isTrain: + self.load_network(self.mapping_net, "mapping_net", opt.which_epoch) + + # Discriminator network + if self.isTrain: + use_sigmoid = opt.no_lsgan + netD_input_nc = opt.ngf * 2 if opt.feat_gan else input_nc + opt.output_nc + if not opt.no_instance: + netD_input_nc += 1 + + self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid, + opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids) + + # set loss functions and optimizers + if self.isTrain: + if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: + raise NotImplementedError("Fake Pool Not Implemented for MultiGPU") + self.fake_pool = ImagePool(opt.pool_size) + self.old_lr = opt.lr + + # define loss functions + self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss, opt.Smooth_L1, + opt.use_two_stage_mapping) + + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) + + self.criterionFeat = torch.nn.L1Loss() + self.criterionFeat_feat = torch.nn.L1Loss() if opt.use_l1_feat else torch.nn.MSELoss() + + if self.opt.image_L1: + self.criterionImage = torch.nn.L1Loss() + else: + self.criterionImage = torch.nn.SmoothL1Loss() + + print(self.criterionFeat_feat) + if not opt.no_vgg_loss: + self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids) + + # Names so we can breakout loss + self.loss_names = self.loss_filter('G_Feat_L2', 'G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake', + 'Smooth_L1', 'G_Feat_L2_Stage_1') + + # initialize optimizers + # optimizer G + + if opt.no_TTUR: + beta1, beta2 = opt.beta1, 0.999 + G_lr, D_lr = opt.lr, opt.lr + 
else: + beta1, beta2 = 0, 0.9 + G_lr, D_lr = opt.lr / 2, opt.lr * 2 + + if not opt.no_load_VAE: + params = list(self.mapping_net.parameters()) + self.optimizer_mapping = torch.optim.Adam(params, lr=G_lr, betas=(beta1, beta2)) + + # optimizer D + params = list(self.netD.parameters()) + self.optimizer_D = torch.optim.Adam(params, lr=D_lr, betas=(beta1, beta2)) + + print("---------- Optimizers initialized -------------") + + def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False): + if self.opt.label_nc == 0: + input_label = label_map.data.cuda() + else: + # create one-hot vector for label map + size = label_map.size() + oneHot_size = (size[0], self.opt.label_nc, size[2], size[3]) + input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() + input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0) + if self.opt.data_type == 16: + input_label = input_label.half() + + # get edges from instance map + if not self.opt.no_instance: + inst_map = inst_map.data.cuda() + edge_map = self.get_edges(inst_map) + input_label = torch.cat((input_label, edge_map), dim=1) + input_label = Variable(input_label, volatile=infer) + + # real images for training + if real_image is not None: + real_image = Variable(real_image.data.cuda()) + + return input_label, inst_map, real_image, feat_map + + def discriminate(self, input_label, test_image, use_pool=False): + input_concat = torch.cat((input_label, test_image.detach()), dim=1) + if use_pool: + fake_query = self.fake_pool.query(input_concat) + return self.netD.forward(fake_query) + else: + return self.netD.forward(input_concat) + + def forward(self, label, inst, image, feat, pair=True, infer=False, last_label=None, last_image=None): + # Encode Inputs + input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat) + + # Fake Generation + input_concat = input_label + + label_feat = self.netG_A.forward(input_concat, flow='enc') + # print('label:') + # print(label_feat.min(), label_feat.max(), label_feat.mean()) + # label_feat = label_feat / 16.0 + + if self.opt.NL_use_mask: + label_feat_map = self.mapping_net(label_feat.detach(), inst) + else: + label_feat_map = self.mapping_net(label_feat.detach()) + + fake_image = self.netG_B.forward(label_feat_map, flow='dec') + image_feat = self.netG_B.forward(real_image, flow='enc') + + loss_feat_l2_stage_1 = 0 + loss_feat_l2 = self.criterionFeat_feat(label_feat_map, image_feat.data) * self.opt.l2_feat + + if self.opt.feat_gan: + # Fake Detection and Loss + pred_fake_pool = self.discriminate(label_feat.detach(), label_feat_map, use_pool=True) + loss_D_fake = self.criterionGAN(pred_fake_pool, False) + + # Real Detection and Loss + pred_real = self.discriminate(label_feat.detach(), image_feat) + loss_D_real = self.criterionGAN(pred_real, True) + + # GAN loss (Fake Passability Loss) + pred_fake = self.netD.forward(torch.cat((label_feat.detach(), label_feat_map), dim=1)) + loss_G_GAN = self.criterionGAN(pred_fake, True) + else: + # Fake Detection and Loss + pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True) + loss_D_fake = self.criterionGAN(pred_fake_pool, False) + + # Real Detection and Loss + if pair: + pred_real = self.discriminate(input_label, real_image) + else: + pred_real = self.discriminate(last_label, last_image) + loss_D_real = self.criterionGAN(pred_real, True) + + # GAN loss (Fake Passability Loss) + pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1)) + loss_G_GAN = 
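The optimizer block above implements TTUR: when it is enabled, `beta1` drops to 0, `beta2` to 0.9, and the discriminator steps at four times the generator's learning rate. A standalone sketch with toy modules (`base_lr` is illustrative; the real value is `opt.lr`):

```python
import torch

G, D = torch.nn.Linear(8, 8), torch.nn.Linear(8, 1)  # stand-ins for the real nets
base_lr = 2e-4                                       # illustrative only
optimizer_G = torch.optim.Adam(G.parameters(), lr=base_lr / 2, betas=(0.0, 0.9))
optimizer_D = torch.optim.Adam(D.parameters(), lr=base_lr * 2, betas=(0.0, 0.9))
```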
self.criterionGAN(pred_fake, True) + + # GAN feature matching loss + loss_G_GAN_Feat = 0 + if not self.opt.no_ganFeat_loss and pair: + feat_weights = 4.0 / (self.opt.n_layers_D + 1) + D_weights = 1.0 / self.opt.num_D + for i in range(self.opt.num_D): + for j in range(len(pred_fake[i]) - 1): + tmp = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat + loss_G_GAN_Feat += D_weights * feat_weights * tmp + else: + loss_G_GAN_Feat = torch.zeros(1).to(label.device) + + # VGG feature matching loss + loss_G_VGG = 0 + if not self.opt.no_vgg_loss: + loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat if pair else torch.zeros( + 1).to(label.device) + + smooth_l1_loss = 0 + if self.opt.Smooth_L1: + smooth_l1_loss = self.criterionImage(fake_image, real_image) * self.opt.L1_weight + + return [self.loss_filter(loss_feat_l2, loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake, + smooth_l1_loss, loss_feat_l2_stage_1), None if not infer else fake_image] + + def inference(self, label, inst): + + use_gpu = len(self.opt.gpu_ids) > 0 + if use_gpu: + input_concat = label.data.cuda() + inst_data = inst.cuda() + else: + input_concat = label.data + inst_data = inst + + label_feat = self.netG_A.forward(input_concat, flow="enc") + + if self.opt.NL_use_mask: + if self.opt.inference_optimize: + label_feat_map = self.mapping_net.inference_forward(label_feat.detach(), inst_data) + else: + label_feat_map = self.mapping_net(label_feat.detach(), inst_data) + else: + label_feat_map = self.mapping_net(label_feat.detach()) + + fake_image = self.netG_B.forward(label_feat_map, flow="dec") + return fake_image + + +class InferenceModel(Pix2PixHDModel_Mapping): + def forward(self, label, inst): + return self.inference(label, inst) diff --git a/scripts/erasescratches/models/networks.py b/scripts/erasescratches/models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..8e43ea8cae1dd70ded3e4d525dfed17a1a3445f2 --- /dev/null +++ b/scripts/erasescratches/models/networks.py @@ -0,0 +1,857 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
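The feature-matching term in `forward` above weights every intermediate discriminator feature by `4 / (n_layers_D + 1)` and every scale by `1 / num_D`. A numeric check of that weighting for one scale (`lambda_feat` omitted; tensors are synthetic):

```python
import torch
import torch.nn.functional as F

n_layers_D, num_D = 3, 2
feat_weights, D_weights = 4.0 / (n_layers_D + 1), 1.0 / num_D

fake_feats = [torch.randn(1, 8) for _ in range(n_layers_D + 1)]
real_feats = [f + 0.1 for f in fake_feats]  # offset by 0.1 so each L1 term is 0.1

loss = sum(D_weights * feat_weights * F.l1_loss(f, r.detach())
           for f, r in zip(fake_feats, real_feats))
print(loss)  # tensor(0.2000): 4 terms * (1/2) * (4/4) * 0.1
```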
+ +import torch +import torch.nn as nn +import functools +from torch.autograd import Variable +import numpy as np +from torch.nn.utils import spectral_norm + +# from util.util import SwitchNorm2d +import torch.nn.functional as F + + +############################################################################### +# Functions +############################################################################### +def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(0.0, 0.02) + elif classname.find("BatchNorm2d") != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + + +def get_norm_layer(norm_type="instance"): + if norm_type == "batch": + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == "instance": + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) + elif norm_type == "spectral": + norm_layer = spectral_norm() + elif norm_type == "SwitchNorm": + norm_layer = SwitchNorm2d + else: + raise NotImplementedError("normalization layer [%s] is not found" % norm_type) + return norm_layer + + +def print_network(net): + if isinstance(net, list): + net = net[0] + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + print(net) + print("Total number of parameters: %d" % num_params) + + +def define_G(input_nc, output_nc, ngf, netG, k_size=3, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1, + n_blocks_local=3, norm='instance', gpu_ids=[], opt=None): + norm_layer = get_norm_layer(norm_type=norm) + if netG == 'global': + # if opt.self_gen: + if opt.use_v2: + netG = GlobalGenerator_DCDCv2(input_nc, output_nc, ngf, k_size, n_downsample_global, norm_layer, opt=opt) + else: + netG = GlobalGenerator_v2(input_nc, output_nc, ngf, k_size, n_downsample_global, n_blocks_global, + norm_layer, opt=opt) + else: + raise ('generator not implemented!') + print(netG) + if len(gpu_ids) > 0: + assert (torch.cuda.is_available()) + netG.cuda(gpu_ids[0]) + netG.apply(weights_init) + return netG + + +def define_D(input_nc, ndf, n_layers_D, opt, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, + gpu_ids=[]): + norm_layer = get_norm_layer(norm_type=norm) + netD = MultiscaleDiscriminator(input_nc, opt, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat) + print(netD) + if len(gpu_ids) > 0: + assert (torch.cuda.is_available()) + netD.cuda(gpu_ids[0]) + netD.apply(weights_init) + return netD + + +class GlobalGenerator_DCDCv2(nn.Module): + def __init__( + self, + input_nc, + output_nc, + ngf=64, + k_size=3, + n_downsampling=8, + norm_layer=nn.BatchNorm2d, + padding_type="reflect", + opt=None, + ): + super(GlobalGenerator_DCDCv2, self).__init__() + activation = nn.ReLU(True) + + model = [ + nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, min(ngf, opt.mc), kernel_size=7, padding=0), + norm_layer(ngf), + activation, + ] + ### downsample + for i in range(opt.start_r): + mult = 2 ** i + model += [ + nn.Conv2d( + min(ngf * mult, opt.mc), + min(ngf * mult * 2, opt.mc), + kernel_size=k_size, + stride=2, + padding=1, + ), + norm_layer(min(ngf * mult * 2, opt.mc)), + activation, + ] + for i in range(opt.start_r, n_downsampling - 1): + mult = 2 ** i + model += [ + nn.Conv2d( + min(ngf * mult, opt.mc), + min(ngf * mult * 2, opt.mc), + kernel_size=k_size, + stride=2, + padding=1, + ), + norm_layer(min(ngf * mult * 2, opt.mc)), + activation, + ] + model += [ + ResnetBlock( + min(ngf * mult * 2, opt.mc), + padding_type=padding_type, + activation=activation, + 
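`get_norm_layer` above returns a `functools.partial`, not a layer: the generators instantiate it per block with just a channel count. The "instance" branch in isolation:

```python
import functools
import torch.nn as nn

norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
print(norm_layer(64))   # InstanceNorm2d(64, ..., affine=False), built on demand
print(norm_layer(512))  # same recipe, different width
```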
norm_layer=norm_layer, + opt=opt, + ) + ] + model += [ + ResnetBlock( + min(ngf * mult * 2, opt.mc), + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + ) + ] + mult = 2 ** (n_downsampling - 1) + + if opt.spatio_size == 32: + model += [ + nn.Conv2d( + min(ngf * mult, opt.mc), + min(ngf * mult * 2, opt.mc), + kernel_size=k_size, + stride=2, + padding=1, + ), + norm_layer(min(ngf * mult * 2, opt.mc)), + activation, + ] + if opt.spatio_size == 64: + model += [ + ResnetBlock( + min(ngf * mult * 2, opt.mc), + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + ) + ] + model += [ + ResnetBlock( + min(ngf * mult * 2, opt.mc), + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + ) + ] + # model += [nn.Conv2d(min(ngf * mult * 2, opt.mc), min(ngf, opt.mc), 1, 1)] + if opt.feat_dim > 0: + model += [nn.Conv2d(min(ngf * mult * 2, opt.mc), opt.feat_dim, 1, 1)] + self.encoder = nn.Sequential(*model) + + # decode + model = [] + if opt.feat_dim > 0: + model += [nn.Conv2d(opt.feat_dim, min(ngf * mult * 2, opt.mc), 1, 1)] + # model += [nn.Conv2d(min(ngf, opt.mc), min(ngf * mult * 2, opt.mc), 1, 1)] + o_pad = 0 if k_size == 4 else 1 + mult = 2 ** n_downsampling + model += [ + ResnetBlock( + min(ngf * mult, opt.mc), + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + ) + ] + + if opt.spatio_size == 32: + model += [ + nn.ConvTranspose2d( + min(ngf * mult, opt.mc), + min(int(ngf * mult / 2), opt.mc), + kernel_size=k_size, + stride=2, + padding=1, + output_padding=o_pad, + ), + norm_layer(min(int(ngf * mult / 2), opt.mc)), + activation, + ] + if opt.spatio_size == 64: + model += [ + ResnetBlock( + min(ngf * mult, opt.mc), + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + ) + ] + + for i in range(1, n_downsampling - opt.start_r): + mult = 2 ** (n_downsampling - i) + model += [ + ResnetBlock( + min(ngf * mult, opt.mc), + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + ) + ] + model += [ + ResnetBlock( + min(ngf * mult, opt.mc), + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + ) + ] + model += [ + nn.ConvTranspose2d( + min(ngf * mult, opt.mc), + min(int(ngf * mult / 2), opt.mc), + kernel_size=k_size, + stride=2, + padding=1, + output_padding=o_pad, + ), + norm_layer(min(int(ngf * mult / 2), opt.mc)), + activation, + ] + for i in range(n_downsampling - opt.start_r, n_downsampling): + mult = 2 ** (n_downsampling - i) + model += [ + nn.ConvTranspose2d( + min(ngf * mult, opt.mc), + min(int(ngf * mult / 2), opt.mc), + kernel_size=k_size, + stride=2, + padding=1, + output_padding=o_pad, + ), + norm_layer(min(int(ngf * mult / 2), opt.mc)), + activation, + ] + if opt.use_segmentation_model: + model += [nn.ReflectionPad2d(3), nn.Conv2d(min(ngf, opt.mc), output_nc, kernel_size=7, padding=0)] + else: + model += [ + nn.ReflectionPad2d(3), + nn.Conv2d(min(ngf, opt.mc), output_nc, kernel_size=7, padding=0), + nn.Tanh(), + ] + self.decoder = nn.Sequential(*model) + + def forward(self, input, flow="enc_dec"): + if flow == "enc": + return self.encoder(input) + elif flow == "dec": + return self.decoder(input) + elif flow == "enc_dec": + x = self.encoder(input) + x = self.decoder(x) + return x + + +# Define a resnet block +class ResnetBlock(nn.Module): + def __init__( + self, dim, padding_type, norm_layer, opt, activation=nn.ReLU(True), use_dropout=False, 
dilation=1 + ): + super(ResnetBlock, self).__init__() + self.opt = opt + self.dilation = dilation + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout) + + def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout): + conv_block = [] + p = 0 + if padding_type == "reflect": + conv_block += [nn.ReflectionPad2d(self.dilation)] + elif padding_type == "replicate": + conv_block += [nn.ReplicationPad2d(self.dilation)] + elif padding_type == "zero": + p = self.dilation + else: + raise NotImplementedError("padding [%s] is not implemented" % padding_type) + + conv_block += [ + nn.Conv2d(dim, dim, kernel_size=3, padding=p, dilation=self.dilation), + norm_layer(dim), + activation, + ] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == "reflect": + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == "replicate": + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == "zero": + p = 1 + else: + raise NotImplementedError("padding [%s] is not implemented" % padding_type) + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, dilation=1), norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = x + self.conv_block(x) + return out + + +class Encoder(nn.Module): + def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d): + super(Encoder, self).__init__() + self.output_nc = output_nc + + model = [ + nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), + norm_layer(ngf), + nn.ReLU(True), + ] + ### downsample + for i in range(n_downsampling): + mult = 2 ** i + model += [ + nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), + norm_layer(ngf * mult * 2), + nn.ReLU(True), + ] + + ### upsample + for i in range(n_downsampling): + mult = 2 ** (n_downsampling - i) + model += [ + nn.ConvTranspose2d( + ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1 + ), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True), + ] + + model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*model) + + def forward(self, input, inst): + outputs = self.model(input) + + # instance-wise average pooling + outputs_mean = outputs.clone() + inst_list = np.unique(inst.cpu().numpy().astype(int)) + for i in inst_list: + for b in range(input.size()[0]): + indices = (inst[b: b + 1] == int(i)).nonzero() # n x 4 + for j in range(self.output_nc): + output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]] + mean_feat = torch.mean(output_ins).expand_as(output_ins) + outputs_mean[ + indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3] + ] = mean_feat + return outputs_mean + + +def SN(module, mode=True): + if mode: + return torch.nn.utils.spectral_norm(module) + + return module + + +class NonLocalBlock2D_with_mask_Res(nn.Module): + def __init__( + self, + in_channels, + inter_channels, + mode="add", + re_norm=False, + temperature=1.0, + use_self=False, + cosin=False, + ): + super(NonLocalBlock2D_with_mask_Res, self).__init__() + + self.cosin = cosin + self.renorm = re_norm + self.in_channels = in_channels + self.inter_channels = inter_channels + + self.g = nn.Conv2d( + in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 + ) + + self.W = nn.Conv2d( + in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, 
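`build_conv_block` above pairs `ReflectionPad2d(dilation)` with a 3x3 conv of the same dilation, which keeps the spatial size fixed for any dilation. A quick shape check:

```python
import torch
import torch.nn as nn

for d in (1, 2, 4):
    block = nn.Sequential(nn.ReflectionPad2d(d), nn.Conv2d(8, 8, 3, dilation=d))
    out = block(torch.randn(1, 8, 15, 15))
    # padding of d per side offsets the 2d shrink from the dilated 3x3 kernel
    print(d, tuple(out.shape))  # always (1, 8, 15, 15)
```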
stride=1, padding=0 + ) + # for pytorch 0.3.1 + # nn.init.constant(self.W.weight, 0) + # nn.init.constant(self.W.bias, 0) + # for pytorch 0.4.0 + nn.init.constant_(self.W.weight, 0) + nn.init.constant_(self.W.bias, 0) + self.theta = nn.Conv2d( + in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 + ) + + self.phi = nn.Conv2d( + in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 + ) + + self.mode = mode + self.temperature = temperature + self.use_self = use_self + + norm_layer = get_norm_layer(norm_type="instance") + activation = nn.ReLU(True) + + model = [] + for i in range(3): + model += [ + ResnetBlock( + inter_channels, + padding_type="reflect", + activation=activation, + norm_layer=norm_layer, + opt=None, + ) + ] + self.res_block = nn.Sequential(*model) + + def forward(self, x, mask): ## The shape of mask is Batch*1*H*W + batch_size = x.size(0) + + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + + g_x = g_x.permute(0, 2, 1) + + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) + + theta_x = theta_x.permute(0, 2, 1) + + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + + if self.cosin: + theta_x = F.normalize(theta_x, dim=2) + phi_x = F.normalize(phi_x, dim=1) + + f = torch.matmul(theta_x, phi_x) + + f /= self.temperature + + f_div_C = F.softmax(f, dim=2) + + tmp = 1 - mask + mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear") + mask[mask > 0] = 1.0 + mask = 1 - mask + + tmp = F.interpolate(tmp, (x.size(2), x.size(3))) + mask *= tmp + + mask_expand = mask.view(batch_size, 1, -1) + mask_expand = mask_expand.repeat(1, x.size(2) * x.size(3), 1) + + # mask = 1 - mask + # mask=F.interpolate(mask,(x.size(2),x.size(3))) + # mask_expand=mask.view(batch_size,1,-1) + # mask_expand=mask_expand.repeat(1,x.size(2)*x.size(3),1) + + if self.use_self: + mask_expand[:, range(x.size(2) * x.size(3)), range(x.size(2) * x.size(3))] = 1.0 + + # print(mask_expand.shape) + # print(f_div_C.shape) + + f_div_C = mask_expand * f_div_C + if self.renorm: + f_div_C = F.normalize(f_div_C, p=1, dim=2) + + ########################### + + y = torch.matmul(f_div_C, g_x) + + y = y.permute(0, 2, 1).contiguous() + + y = y.view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + + W_y = self.res_block(W_y) + + if self.mode == "combine": + full_mask = mask.repeat(1, self.inter_channels, 1, 1) + z = full_mask * x + (1 - full_mask) * W_y + return z + + +class MultiscaleDiscriminator(nn.Module): + def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, + use_sigmoid=False, num_D=3, getIntermFeat=False): + super(MultiscaleDiscriminator, self).__init__() + self.num_D = num_D + self.n_layers = n_layers + self.getIntermFeat = getIntermFeat + + for i in range(num_D): + netD = NLayerDiscriminator(input_nc, opt, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat) + if getIntermFeat: + for j in range(n_layers + 2): + setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j))) + else: + setattr(self, 'layer' + str(i), netD.model) + + self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) + + def singleD_forward(self, model, input): + if self.getIntermFeat: + result = [input] + for i in range(len(model)): + result.append(model[i](result[-1])) + return result[1:] + else: + return [model(input)] + + def forward(self, input): + num_D = self.num_D + result = [] + input_downsampled = input + for i 
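The masked non-local block above is easier to follow as plain shapes: `theta`, `phi`, and `g` are 1x1 projections, and the affinity matrix is HW x HW before the mask suppresses attention into damaged regions. A sketch with illustrative sizes:

```python
import torch
import torch.nn.functional as F

B, C_inter, H, W = 1, 32, 16, 16
theta = torch.randn(B, C_inter, H * W).permute(0, 2, 1)  # B x HW x C
phi = torch.randn(B, C_inter, H * W)                     # B x C x HW
g = torch.randn(B, C_inter, H * W).permute(0, 2, 1)      # B x HW x C

f_div_C = F.softmax(torch.matmul(theta, phi), dim=2)     # B x HW x HW affinity
mask_expand = torch.ones(B, H * W, H * W)                # stand-in for the real mask
y = torch.matmul(mask_expand * f_div_C, g)               # B x HW x C
y = y.permute(0, 2, 1).reshape(B, C_inter, H, W)
print(y.shape)  # torch.Size([1, 32, 16, 16])
```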
in range(num_D): + if self.getIntermFeat: + model = [getattr(self, 'scale' + str(num_D - 1 - i) + '_layer' + str(j)) for j in + range(self.n_layers + 2)] + else: + model = getattr(self, 'layer' + str(num_D - 1 - i)) + result.append(self.singleD_forward(model, input_downsampled)) + if i != (num_D - 1): + input_downsampled = self.downsample(input_downsampled) + return result + + +# Defines the PatchGAN discriminator with the specified arguments. +class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, + getIntermFeat=False): + super(NLayerDiscriminator, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw - 1.0) / 2)) + sequence = [ + [SN(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), opt.use_SN), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + SN(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), opt.use_SN), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + SN(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), opt.use_SN), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[SN(nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw), opt.use_SN)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model' + str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers + 2): + model = getattr(self, 'model' + str(n)) + res.append(model(res[-1])) + return res[1:] + else: + return self.model(input) + + +class Patch_Attention_4(nn.Module): # While combine the feature map, use conv and mask + def __init__(self, in_channels, inter_channels, patch_size): + super(Patch_Attention_4, self).__init__() + + self.patch_size = patch_size + + self.F_Combine = nn.Conv2d(in_channels=1025, out_channels=512, kernel_size=3, stride=1, padding=1, bias=True) + norm_layer = get_norm_layer(norm_type="instance") + activation = nn.ReLU(True) + + model = [] + for i in range(1): + model += [ + ResnetBlock( + inter_channels, + padding_type="reflect", + activation=activation, + norm_layer=norm_layer, + opt=None, + ) + ] + self.res_block = nn.Sequential(*model) + + def Hard_Compose(self, input, dim, index): + # batch index select + # input: [B,C,HW] + # dim: scalar > 0 + # index: [B, HW] + views = [input.size(0)] + [1 if i != dim else -1 for i in range(1, len(input.size()))] + expanse = list(input.size()) + expanse[0] = -1 + expanse[dim] = -1 + index = index.view(views).expand(expanse) + return torch.gather(input, dim, index) + + def forward(self, z, mask): ## The shape of mask is Batch*1*H*W + + x = self.res_block(z) + + b, c, h, w = x.shape + + ## mask resize + dilation + # tmp = 1 - mask + mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear") + mask[mask > 0] = 1.0 + + # mask = 1 - mask + # tmp = F.interpolate(tmp, (x.size(2), x.size(3))) + # mask *= tmp + # mask=1-mask + ## 1: mask position 0: non-mask + + mask_unfold = F.unfold(mask, kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size) + non_mask_region = (torch.mean(mask_unfold, dim=1, keepdim=True) 
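`Patch_Attention_4` above leans entirely on `F.unfold`/`F.fold` for its patch bookkeeping; with non-overlapping patches (stride equal to kernel size) that round trip is lossless, which is what lets `Hard_Compose` swap whole patch columns and fold the result back:

```python
import torch
import torch.nn.functional as F

x, p = torch.randn(1, 8, 32, 32), 8
cols = F.unfold(x, kernel_size=(p, p), stride=p)  # one column per p x p patch
y = F.fold(cols, output_size=(32, 32), kernel_size=(p, p), stride=p)
print(cols.shape, torch.equal(x, y))  # torch.Size([1, 512, 16]) True
```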
> 0.6).float() + all_patch_num = h * w / self.patch_size / self.patch_size + non_mask_region = non_mask_region.repeat(1, int(all_patch_num), 1) + + x_unfold = F.unfold(x, kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size) + y_unfold = x_unfold.permute(0, 2, 1) + x_unfold_normalized = F.normalize(x_unfold, dim=1) + y_unfold_normalized = F.normalize(y_unfold, dim=2) + correlation_matrix = torch.bmm(y_unfold_normalized, x_unfold_normalized) + correlation_matrix = correlation_matrix.masked_fill(non_mask_region == 1., -1e9) + correlation_matrix = F.softmax(correlation_matrix, dim=2) + + # print(correlation_matrix) + + R, max_arg = torch.max(correlation_matrix, dim=2) + + composed_unfold = self.Hard_Compose(x_unfold, 2, max_arg) + composed_fold = F.fold(composed_unfold, output_size=(h, w), kernel_size=(self.patch_size, self.patch_size), + padding=0, stride=self.patch_size) + + concat_1 = torch.cat((z, composed_fold, mask), dim=1) + concat_1 = self.F_Combine(concat_1) + + return concat_1 + + def inference_forward(self, z, mask): ## Reduce the extra memory cost + + x = self.res_block(z) + + b, c, h, w = x.shape + + ## mask resize + dilation + # tmp = 1 - mask + mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear") + mask[mask > 0] = 1.0 + # mask = 1 - mask + # tmp = F.interpolate(tmp, (x.size(2), x.size(3))) + # mask *= tmp + # mask=1-mask + ## 1: mask position 0: non-mask + + mask_unfold = F.unfold(mask, kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size) + non_mask_region = (torch.mean(mask_unfold, dim=1, keepdim=True) > 0.6).float()[0, 0, :] # 1*1*all_patch_num + + all_patch_num = h * w / self.patch_size / self.patch_size + + mask_index = torch.nonzero(non_mask_region, as_tuple=True)[0] + + if len(mask_index) == 0: ## No mask patch is selected, no attention is needed + + composed_fold = x + + else: + + unmask_index = torch.nonzero(non_mask_region != 1, as_tuple=True)[0] + + x_unfold = F.unfold(x, kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size) + + Query_Patch = torch.index_select(x_unfold, 2, mask_index) + Key_Patch = torch.index_select(x_unfold, 2, unmask_index) + + Query_Patch = Query_Patch.permute(0, 2, 1) + Query_Patch_normalized = F.normalize(Query_Patch, dim=2) + Key_Patch_normalized = F.normalize(Key_Patch, dim=1) + + correlation_matrix = torch.bmm(Query_Patch_normalized, Key_Patch_normalized) + correlation_matrix = F.softmax(correlation_matrix, dim=2) + + R, max_arg = torch.max(correlation_matrix, dim=2) + + composed_unfold = self.Hard_Compose(Key_Patch, 2, max_arg) + x_unfold[:, :, mask_index] = composed_unfold + composed_fold = F.fold(x_unfold, output_size=(h, w), kernel_size=(self.patch_size, self.patch_size), + padding=0, stride=self.patch_size) + + concat_1 = torch.cat((z, composed_fold, mask), dim=1) + concat_1 = self.F_Combine(concat_1) + + return concat_1 + + +############################################################################## +# Losses +############################################################################## +class GANLoss(nn.Module): + def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, + tensor=torch.FloatTensor): + super(GANLoss, self).__init__() + self.real_label = target_real_label + self.fake_label = target_fake_label + self.real_label_var = None + self.fake_label_var = None + self.Tensor = tensor + if use_lsgan: + self.loss = nn.MSELoss() + else: + self.loss = nn.BCELoss() + + def get_target_tensor(self, input, 
target_is_real): + target_tensor = None + if target_is_real: + create_label = ((self.real_label_var is None) or + (self.real_label_var.numel() != input.numel())) + if create_label: + real_tensor = self.Tensor(input.size()).fill_(self.real_label) + self.real_label_var = Variable(real_tensor, requires_grad=False) + target_tensor = self.real_label_var + else: + create_label = ((self.fake_label_var is None) or + (self.fake_label_var.numel() != input.numel())) + if create_label: + fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) + self.fake_label_var = Variable(fake_tensor, requires_grad=False) + target_tensor = self.fake_label_var + return target_tensor + + def __call__(self, input, target_is_real): + if isinstance(input[0], list): + loss = 0 + for input_i in input: + pred = input_i[-1] + target_tensor = self.get_target_tensor(pred, target_is_real) + loss += self.loss(pred, target_tensor) + return loss + else: + target_tensor = self.get_target_tensor(input[-1], target_is_real) + return self.loss(input[-1], target_tensor) + + +# VGG Loss + +from torchvision import models + + +class VGG19_torch(torch.nn.Module): + def __init__(self, requires_grad=False): + super(VGG19_torch, self).__init__() + vgg_pretrained_features = models.vgg19(pretrained=True).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + return out + + +class VGGLoss_torch(nn.Module): + def __init__(self, gpu_ids): + super(VGGLoss_torch, self).__init__() + self.vgg = VGG19_torch().cuda() + self.criterion = nn.L1Loss() + self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] + + def forward(self, x, y): + x_vgg, y_vgg = self.vgg(x), self.vgg(y) + loss = 0 + for i in range(len(x_vgg)): + loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) + return loss diff --git a/scripts/erasescratches/options.py b/scripts/erasescratches/options.py new file mode 100644 index 0000000000000000000000000000000000000000..a6a68e5a51d2a3d9c1b1dacc4ee7ef641dbef05b --- /dev/null +++ b/scripts/erasescratches/options.py @@ -0,0 +1,47 @@ +import os + + +class Options: + def __init__(self, checkpoints, gpu_ids): + self.serial_batches = True # no shuffle + self.no_flip = True # no flip + self.label_nc = 0 + self.n_downsample_global = 3 + self.mc = 64 + self.k_size = 4 + self.start_r = 1 + self.mapping_n_block = 6 + self.map_mc = 512 + self.no_instance = True + self.checkpoints_dir = checkpoints + self.gpu_ids = gpu_ids + self.mapping_net_dilation = 1 + self.use_segmentation_model = False + self.feat_dim = -1 + self.spatio_size = 64 + self.resize_or_crop = 'scale_width' + self.isTrain = False + self.input_nc = 3 + self.output_nc = 3 + self.ngf 
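A hedged usage sketch for `VGGLoss_torch` above: it downloads torchvision's pretrained VGG19 on first use and hard-codes `.cuda()`, so a GPU is assumed here, as are the illustrative shapes:

```python
import torch
from scripts.erasescratches.models.networks import VGGLoss_torch

criterion_vgg = VGGLoss_torch(gpu_ids=[0])  # gpu_ids is accepted but unused
fake = torch.rand(1, 3, 256, 256).cuda()
real = torch.rand(1, 3, 256, 256).cuda()
print(criterion_vgg(fake, real))  # weighted L1 over the five VGG feature slices
```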
= 64 + self.norm = "instance" + self.load_pretrain = "" + self.which_epoch = "latest" + self.load_pretrain = "" + self.no_load_VAE = False + self.use_vae_which_epoch = "latest" + + self.NL_res = True + self.use_SN = True + self.correlation_renormalize = True + self.NL_use_mask = True + self.NL_fusion_method = "combine" + self.non_local = "Setting_42" + # self.name = "mapping_scratch" + self.load_pretrainA = os.path.join(self.checkpoints_dir, "VAE_A_quality") + self.load_pretrainB = os.path.join(self.checkpoints_dir, "VAE_B_scratch") + + self.mapping_exp = 1 + self.inference_optimize = True + self.mask_dilation = 3 + self.name = "mapping_Patch_Attention" diff --git a/scripts/face3d/__pycache__/extract_kp_videos_safe.cpython-310.pyc b/scripts/face3d/__pycache__/extract_kp_videos_safe.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69bd6552068cf14289cb3c4d206cde5dc25bca71 Binary files /dev/null and b/scripts/face3d/__pycache__/extract_kp_videos_safe.cpython-310.pyc differ diff --git a/scripts/face3d/data/__init__.py b/scripts/face3d/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9761c518a1b07c5996165869742af0a52c82bc --- /dev/null +++ b/scripts/face3d/data/__init__.py @@ -0,0 +1,116 @@ +"""This package includes all the modules related to data loading and preprocessing + + To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. + You need to implement four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point from data loader. + -- : (optionally) add dataset-specific options and set default options. + +Now you can use the dataset class by specifying flag '--dataset_mode dummy'. +See our template dataset class 'template_dataset.py' for more details. +""" +import numpy as np +import importlib +import torch.utils.data +from face3d.data.base_dataset import BaseDataset + + +def find_dataset_using_name(dataset_name): + """Import the module "data/[dataset_name]_dataset.py". + + In the file, the class called DatasetNameDataset() will + be instantiated. It has to be a subclass of BaseDataset, + and it is case-insensitive. + """ + dataset_filename = "data." + dataset_name + "_dataset" + datasetlib = importlib.import_module(dataset_filename) + + dataset = None + target_dataset_name = dataset_name.replace('_', '') + 'dataset' + for name, cls in datasetlib.__dict__.items(): + if name.lower() == target_dataset_name.lower() \ + and issubclass(cls, BaseDataset): + dataset = cls + + if dataset is None: + raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) + + return dataset + + +def get_option_setter(dataset_name): + """Return the static method of the dataset class.""" + dataset_class = find_dataset_using_name(dataset_name) + return dataset_class.modify_commandline_options + + +def create_dataset(opt, rank=0): + """Create a dataset given the option. + + This function wraps the class CustomDatasetDataLoader. 
+ This is the main interface between this package and 'train.py'/'test.py' + + Example: + >>> from data import create_dataset + >>> dataset = create_dataset(opt) + """ + data_loader = CustomDatasetDataLoader(opt, rank=rank) + dataset = data_loader.load_data() + return dataset + +class CustomDatasetDataLoader(): + """Wrapper class of Dataset class that performs multi-threaded data loading""" + + def __init__(self, opt, rank=0): + """Initialize this class + + Step 1: create a dataset instance given the name [dataset_mode] + Step 2: create a multi-threaded data loader. + """ + self.opt = opt + dataset_class = find_dataset_using_name(opt.dataset_mode) + self.dataset = dataset_class(opt) + self.sampler = None + print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__)) + if opt.use_ddp and opt.isTrain: + world_size = opt.world_size + self.sampler = torch.utils.data.distributed.DistributedSampler( + self.dataset, + num_replicas=world_size, + rank=rank, + shuffle=not opt.serial_batches + ) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + sampler=self.sampler, + num_workers=int(opt.num_threads / world_size), + batch_size=int(opt.batch_size / world_size), + drop_last=True) + else: + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=opt.batch_size, + shuffle=(not opt.serial_batches) and opt.isTrain, + num_workers=int(opt.num_threads), + drop_last=True + ) + + def set_epoch(self, epoch): + self.dataset.current_epoch = epoch + if self.sampler is not None: + self.sampler.set_epoch(epoch) + + def load_data(self): + return self + + def __len__(self): + """Return the number of data in the dataset""" + return min(len(self.dataset), self.opt.max_dataset_size) + + def __iter__(self): + """Return a batch of data""" + for i, data in enumerate(self.dataloader): + if i * self.opt.batch_size >= self.opt.max_dataset_size: + break + yield data diff --git a/scripts/face3d/data/base_dataset.py b/scripts/face3d/data/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd57d082d519f512d7114b4f867b6695fb7de06 --- /dev/null +++ b/scripts/face3d/data/base_dataset.py @@ -0,0 +1,125 @@ +"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets. + +It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. +""" +import random +import numpy as np +import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms +from abc import ABC, abstractmethod + + +class BaseDataset(data.Dataset, ABC): + """This class is an abstract base class (ABC) for datasets. + + To create a subclass, you need to implement the following four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point. + -- : (optionally) add dataset-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the class; save the options in the class + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + self.opt = opt + # self.root = opt.dataroot + self.current_epoch = 0 + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new dataset-specific options, and rewrite default values for existing options. 
+ + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def __len__(self): + """Return the total number of images in the dataset.""" + return 0 + + @abstractmethod + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index - - a random integer for data indexing + + Returns: + a dictionary of data with their names. It ususally contains the data itself and its metadata information. + """ + pass + + +def get_transform(grayscale=False): + transform_list = [] + if grayscale: + transform_list.append(transforms.Grayscale(1)) + transform_list += [transforms.ToTensor()] + return transforms.Compose(transform_list) + +def get_affine_mat(opt, size): + shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False + w, h = size + + if 'shift' in opt.preprocess: + shift_pixs = int(opt.shift_pixs) + shift_x = random.randint(-shift_pixs, shift_pixs) + shift_y = random.randint(-shift_pixs, shift_pixs) + if 'scale' in opt.preprocess: + scale = 1 + opt.scale_delta * (2 * random.random() - 1) + if 'rot' in opt.preprocess: + rot_angle = opt.rot_angle * (2 * random.random() - 1) + rot_rad = -rot_angle * np.pi/180 + if 'flip' in opt.preprocess: + flip = random.random() > 0.5 + + shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3]) + flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3]) + shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3]) + rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3]) + scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3]) + shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3]) + + affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin + affine_inv = np.linalg.inv(affine) + return affine, affine_inv, flip + +def apply_img_affine(img, affine_inv, method=Image.BICUBIC): + return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC) + +def apply_lm_affine(landmark, affine, flip, size): + _, h = size + lm = landmark.copy() + lm[:, 1] = h - 1 - lm[:, 1] + lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1) + lm = lm @ np.transpose(affine) + lm[:, :2] = lm[:, :2] / lm[:, 2:] + lm = lm[:, :2] + lm[:, 1] = h - 1 - lm[:, 1] + if flip: + lm_ = lm.copy() + lm_[:17] = lm[16::-1] + lm_[17:22] = lm[26:21:-1] + lm_[22:27] = lm[21:16:-1] + lm_[31:36] = lm[35:30:-1] + lm_[36:40] = lm[45:41:-1] + lm_[40:42] = lm[47:45:-1] + lm_[42:46] = lm[39:35:-1] + lm_[46:48] = lm[41:39:-1] + lm_[48:55] = lm[54:47:-1] + lm_[55:60] = lm[59:54:-1] + lm_[60:65] = lm[64:59:-1] + lm_[65:68] = lm[67:64:-1] + lm = lm_ + return lm diff --git a/scripts/face3d/data/flist_dataset.py b/scripts/face3d/data/flist_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c0b6945c80aa756074a5d3c02b9443b15ddcfc57 --- /dev/null +++ b/scripts/face3d/data/flist_dataset.py @@ -0,0 +1,125 @@ +"""This script defines the custom dataset for Deep3DFaceRecon_pytorch +""" + +import os.path +from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine +from data.image_folder import make_dataset +from PIL import Image +import random +import util.util as util 
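The 3x3 matrices in `get_affine_mat` above compose right-to-left around the image center, so the center is a fixed point of scale and rotation. A standalone check of that convention:

```python
import numpy as np

w, h = 256, 256
shift_to_origin = np.array([1, 0, -w // 2, 0, 1, -h // 2, 0, 0, 1]).reshape(3, 3)
shift_to_center = np.array([1, 0, w // 2, 0, 1, h // 2, 0, 0, 1]).reshape(3, 3)
scale_mat = np.diag([0.9, 0.9, 1.0])

affine = shift_to_center @ scale_mat @ shift_to_origin  # scale about the center
center = np.array([w // 2, h // 2, 1.0])
print(affine @ center)  # [128. 128.   1.] -- the center does not move
```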
diff --git a/scripts/face3d/data/flist_dataset.py b/scripts/face3d/data/flist_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0b6945c80aa756074a5d3c02b9443b15ddcfc57
--- /dev/null
+++ b/scripts/face3d/data/flist_dataset.py
@@ -0,0 +1,125 @@
+"""This script defines the custom dataset for Deep3DFaceRecon_pytorch
+"""
+
+import os.path
+from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
+from data.image_folder import make_dataset
+from PIL import Image
+import random
+import util.util as util
+import numpy as np
+import json
+import torch
+from scipy.io import loadmat, savemat
+import pickle
+from util.preprocess import align_img, estimate_norm
+from util.load_mats import load_lm3d
+
+
+def default_flist_reader(flist):
+ """
+ flist format: impath label\nimpath label\n ... (same as caffe's filelist)
+ """
+ imlist = []
+ with open(flist, 'r') as rf:
+ for line in rf.readlines():
+ impath = line.strip()
+ imlist.append(impath)
+
+ return imlist
+
+def jason_flist_reader(flist):
+ # reads a JSON-formatted file list
+ with open(flist, 'r') as fp:
+ info = json.load(fp)
+ return info
+
+def parse_label(label):
+ return torch.tensor(np.array(label).astype(np.float32))
+
+
+class FlistDataset(BaseDataset):
+ """
+ It requires one directory to host training images '/path/to/data/train'
+ You can train the model with the dataset flag '--dataroot /path/to/data'.
+ """
+
+ def __init__(self, opt):
+ """Initialize this dataset class.
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ BaseDataset.__init__(self, opt)
+
+ self.lm3d_std = load_lm3d(opt.bfm_folder)
+
+ msk_names = default_flist_reader(opt.flist)
+ self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
+
+ self.size = len(self.msk_paths)
+ self.opt = opt
+
+ self.name = 'train' if opt.isTrain else 'val'
+ if '_' in opt.flist:
+ self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
+
+
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index (int) -- a random integer for data indexing
+
+ Returns a dictionary that contains imgs, lms, msks, M, im_paths, aug_flag and dataset
+ img (tensor) -- an image in the input domain
+ msk (tensor) -- its corresponding attention mask
+ lm (tensor) -- its corresponding 3d landmarks
+ im_paths (str) -- image paths
+ aug_flag (bool) -- a flag used to tell whether it is raw or augmented
+ """
+ msk_path = self.msk_paths[index % self.size] # make sure index is within the range
+ img_path = msk_path.replace('mask/', '')
+ lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
+
+ raw_img = Image.open(img_path).convert('RGB')
+ raw_msk = Image.open(msk_path).convert('RGB')
+ raw_lm = np.loadtxt(lm_path).astype(np.float32)
+
+ _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
+
+ aug_flag = self.opt.use_aug and self.opt.isTrain
+ if aug_flag:
+ img, lm, msk = self._augmentation(img, lm, self.opt, msk)
+
+ _, H = img.size
+ M = estimate_norm(lm, H)
+ transform = get_transform()
+ img_tensor = transform(img)
+ msk_tensor = transform(msk)[:1, ...]
+ lm_tensor = parse_label(lm)
+ M_tensor = parse_label(M)
+
+
+ return {'imgs': img_tensor,
+ 'lms': lm_tensor,
+ 'msks': msk_tensor,
+ 'M': M_tensor,
+ 'im_paths': img_path,
+ 'aug_flag': aug_flag,
+ 'dataset': self.name}
+
+ def _augmentation(self, img, lm, opt, msk=None):
+ affine, affine_inv, flip = get_affine_mat(opt, img.size)
+ img = apply_img_affine(img, affine_inv)
+ lm = apply_lm_affine(lm, affine, flip, img.size)
+ if msk is not None:
+ msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
+ return img, lm, msk
+
+ def __len__(self):
+ """Return the total number of images in the dataset.
+ """
+ return self.size
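__getitem__ derives the image and landmark paths purely from the mask path, so the three files have to mirror each other on disk. A quick illustration of that convention with a hypothetical path:

```python
# Hypothetical example of the path convention __getitem__ relies on:
# masks live under .../mask/, images one level up, landmarks as .txt
# files under .../landmarks/.
msk_path = '/data/train/mask/0001.png'
img_path = msk_path.replace('mask/', '')              # /data/train/0001.png
lm_path = '.'.join(msk_path.replace('mask', 'landmarks')
                   .split('.')[:-1]) + '.txt'         # /data/train/landmarks/0001.txt
print(img_path, lm_path)
```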
+ """ + return self.size diff --git a/scripts/face3d/data/image_folder.py b/scripts/face3d/data/image_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..efadc2ecbe2fb4b53b78230aba25ec505eff0e55 --- /dev/null +++ b/scripts/face3d/data/image_folder.py @@ -0,0 +1,66 @@ +"""A modified image folder class + +We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) +so that this class can load images from both current directory and its subdirectories. +""" +import numpy as np +import torch.utils.data as data + +from PIL import Image +import os +import os.path + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', + '.tif', '.TIF', '.tiff', '.TIFF', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir, max_dataset_size=float("inf")): + images = [] + assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir, followlinks=True)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + return images[:min(max_dataset_size, len(images))] + + +def default_loader(path): + return Image.open(path).convert('RGB') + + +class ImageFolder(data.Dataset): + + def __init__(self, root, transform=None, return_paths=False, + loader=default_loader): + imgs = make_dataset(root) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in: " + root + "\n" + "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) + + self.root = root + self.imgs = imgs + self.transform = transform + self.return_paths = return_paths + self.loader = loader + + def __getitem__(self, index): + path = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img = self.transform(img) + if self.return_paths: + return img, path + else: + return img + + def __len__(self): + return len(self.imgs) diff --git a/scripts/face3d/data/template_dataset.py b/scripts/face3d/data/template_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..bfdf16be2a8a834b204c45d88c86857b37b9bd25 --- /dev/null +++ b/scripts/face3d/data/template_dataset.py @@ -0,0 +1,75 @@ +"""Dataset class template + +This module provides a template for users to implement custom datasets. +You can specify '--dataset_mode template' to use this dataset. +The class name should be consistent with both the filename and its dataset_mode option. +The filename should be _dataset.py +The class name should be Dataset.py +You need to implement the following functions: + -- : Add dataset-specific options and rewrite default values for existing options. + -- <__init__>: Initialize this dataset class. + -- <__getitem__>: Return a data point and its metadata information. + -- <__len__>: Return the number of images. +""" +from data.base_dataset import BaseDataset, get_transform +# from data.image_folder import make_dataset +# from PIL import Image + + +class TemplateDataset(BaseDataset): + """A template dataset class for you to implement custom datasets.""" + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 
diff --git a/scripts/face3d/data/template_dataset.py b/scripts/face3d/data/template_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfdf16be2a8a834b204c45d88c86857b37b9bd25
--- /dev/null
+++ b/scripts/face3d/data/template_dataset.py
@@ -0,0 +1,75 @@
+"""Dataset class template
+
+This module provides a template for users to implement custom datasets.
+You can specify '--dataset_mode template' to use this dataset.
+The class name should be consistent with both the filename and its dataset_mode option.
+The filename should be <dataset_mode>_dataset.py
+The class name should be <Dataset_mode>Dataset
+You need to implement the following functions:
+ -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
+ -- <__init__>: Initialize this dataset class.
+ -- <__getitem__>: Return a data point and its metadata information.
+ -- <__len__>: Return the number of images.
+"""
+from data.base_dataset import BaseDataset, get_transform
+# from data.image_folder import make_dataset
+# from PIL import Image
+
+
+class TemplateDataset(BaseDataset):
+ """A template dataset class for you to implement custom datasets."""
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ """Add new dataset-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether this is the training phase or the test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
+ parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
+ return parser
+
+ def __init__(self, opt):
+ """Initialize this dataset class.
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+ A few things can be done here.
+ - save the options (already done in BaseDataset)
+ - get image paths and meta information of the dataset.
+ - define the image transformation.
+ """
+ # save the option and dataset root
+ BaseDataset.__init__(self, opt)
+ # get the image paths of your dataset;
+ self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
+ # define the default transform function. You can use <get_transform>; you can also define your custom transform function
+ self.transform = get_transform()
+
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index -- a random integer for data indexing
+
+ Returns:
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
+
+ Step 1: get a random image path: e.g., path = self.image_paths[index]
+ Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
+ Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
+ Step 4: return a data point as a dictionary.
+ """
+ path = 'temp' # needs to be a string
+ data_A = None # needs to be a tensor
+ data_B = None # needs to be a tensor
+ return {'data_A': data_A, 'data_B': data_B, 'path': path}
+
+ def __len__(self):
+ """Return the total number of images."""
+ return len(self.image_paths)
diff --git a/scripts/face3d/extract_kp_videos.py b/scripts/face3d/extract_kp_videos.py
new file mode 100644
index 0000000000000000000000000000000000000000..21616a3b4b5077ffdce99621395237b4edcff58c
--- /dev/null
+++ b/scripts/face3d/extract_kp_videos.py
@@ -0,0 +1,108 @@
+import os
+import cv2
+import time
+import glob
+import argparse
+import face_alignment
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from itertools import cycle
+
+from torch.multiprocessing import Pool, Process, set_start_method
+
+class KeypointExtractor():
+ def __init__(self, device):
+ self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
+ device=device)
+
+ def extract_keypoint(self, images, name=None, info=True):
+ if isinstance(images, list):
+ keypoints = []
+ if info:
+ i_range = tqdm(images, desc='landmark Det:')
+ else:
+ i_range = images
+
+ for image in i_range:
+ current_kp = self.extract_keypoint(image)
+ if np.mean(current_kp) == -1 and keypoints:
+ # no face detected: reuse the previous frame's landmarks
+ keypoints.append(keypoints[-1])
+ else:
+ keypoints.append(current_kp[None])
+
+ keypoints = np.concatenate(keypoints, 0)
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+ return keypoints
+ else:
+ while True:
+ try:
+ keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
+ break
+ except RuntimeError as e:
+ if str(e).startswith('CUDA'):
+ print("Warning: out of memory, sleeping for 1s")
+ time.sleep(1)
+ else:
+ print(e)
+ break
+ except TypeError:
+ print('No face detected in this image')
+ shape = [68, 2]
+ keypoints = -1. * np.ones(shape)
+ break
+ if name is not None:
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+ return keypoints
+
+def read_video(filename):
+ frames = []
+ cap = cv2.VideoCapture(filename)
+ while cap.isOpened():
+ ret, frame = cap.read()
+ if ret:
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ frame = Image.fromarray(frame)
+ frames.append(frame)
+ else:
+ break
+ cap.release()
+ return frames
+
+def run(data):
+ filename, opt, device = data
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
+ kp_extractor = KeypointExtractor(device='cuda') # CUDA_VISIBLE_DEVICES already pins the physical GPU
+ images = read_video(filename)
+ name = filename.split('/')[-2:]
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
+ kp_extractor.extract_keypoint(
+ images,
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
+ )
+
+if __name__ == '__main__':
+ set_start_method('spawn')
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
+ parser.add_argument('--device_ids', type=str, default='0,1')
+ parser.add_argument('--workers', type=int, default=4)
+
+ opt = parser.parse_args()
+ filenames = list()
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
+ extensions = VIDEO_EXTENSIONS
+
+ for ext in extensions:
+ filenames.extend(glob.glob(f'{opt.input_dir}/*.{ext}'))
+ filenames = sorted(filenames)
+ print('Total number of videos:', len(filenames))
+ pool = Pool(opt.workers)
+ args_list = cycle([opt])
+ device_ids = opt.device_ids.split(",")
+ device_ids = cycle(device_ids)
+ for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
+ pass
diff --git a/scripts/face3d/extract_kp_videos_safe.py b/scripts/face3d/extract_kp_videos_safe.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bf984877dfa269927e54d1089dbbe141f3429e5
--- /dev/null
+++ b/scripts/face3d/extract_kp_videos_safe.py
@@ -0,0 +1,151 @@
+import os
+import cv2
+import time
+import glob
+import argparse
+import numpy as np
+from PIL import Image
+import torch
+from tqdm import tqdm
+from itertools import cycle
+from torch.multiprocessing import Pool, Process, set_start_method
+
+from facexlib.alignment import landmark_98_to_68
+from facexlib.detection import init_detection_model
+
+from facexlib.utils import load_file_from_url
+from scripts.face3d.util.my_awing_arch import FAN
+
+def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
+ if model_name == 'awing_fan':
+ model = FAN(num_modules=4, num_landmarks=98, device=device)
+ model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
+ else:
+ raise NotImplementedError(f'{model_name} is not implemented.')
+
+ model_path = load_file_from_url(
+ url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
+ model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True)
+ model.eval()
+ model = model.to(device)
+ return model
+
+
+class KeypointExtractor():
+ def __init__(self, device='cuda'):
+
+ ### gfpgan/weights
+ try:
+ import webui # in webui
+ root_path = 'extensions/SadTalker/gfpgan/weights'
+
+ except ImportError:
+ root_path = 'gfpgan/weights'
+
+ self.detector = init_alignment_model('awing_fan', device=device, model_rootpath=root_path)
+ self.det_net = init_detection_model('retinaface_resnet50', half=False, device=device, model_rootpath=root_path)
+
+ def extract_keypoint(self, images, name=None, info=True):
+ if isinstance(images, list):
+ keypoints = []
+ if info:
+ i_range = tqdm(images, desc='landmark Det:')
+ else:
+ i_range = images
+
+ for image in i_range:
+ current_kp = self.extract_keypoint(image)
+ # current_kp = self.detector.get_landmarks(np.array(image))
+ if np.mean(current_kp) == -1 and keypoints:
+ keypoints.append(keypoints[-1])
+ else:
+ keypoints.append(current_kp[None])
+
+ keypoints = np.concatenate(keypoints, 0)
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+ return keypoints
+ else:
+ while True:
+ try:
+ with torch.no_grad():
+ # face detection -> face alignment.
+ img = np.array(images)
+ bboxes = self.det_net.detect_faces(images, 0.97)
+
+ bboxes = bboxes[0]
+ img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]
+
+ keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0]
+
+ #### map keypoints back to the original image coordinates
+ keypoints[:,0] += int(bboxes[0])
+ keypoints[:,1] += int(bboxes[1])
+
+ break
+ except RuntimeError as e:
+ if str(e).startswith('CUDA'):
+ print("Warning: out of memory, sleeping for 1s")
+ time.sleep(1)
+ else:
+ print(e)
+ break
+ except (TypeError, IndexError):
+ # TypeError: the detector returned None; IndexError: no bounding box found
+ print('No face detected in this image')
+ shape = [68, 2]
+ keypoints = -1. * np.ones(shape)
+ break
+ if name is not None:
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+ return keypoints
+
+def read_video(filename):
+ frames = []
+ cap = cv2.VideoCapture(filename)
+ while cap.isOpened():
+ ret, frame = cap.read()
+ if ret:
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ frame = Image.fromarray(frame)
+ frames.append(frame)
+ else:
+ break
+ cap.release()
+ return frames
+
+def run(data):
+ filename, opt, device = data
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
+ kp_extractor = KeypointExtractor()
+ images = read_video(filename)
+ name = filename.split('/')[-2:]
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
+ kp_extractor.extract_keypoint(
+ images,
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
+ )
+
+if __name__ == '__main__':
+ set_start_method('spawn')
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
+ parser.add_argument('--device_ids', type=str, default='0,1')
+ parser.add_argument('--workers', type=int, default=4)
+
+ opt = parser.parse_args()
+ filenames = list()
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
+ extensions = VIDEO_EXTENSIONS
+
+ for ext in extensions:
+ filenames.extend(glob.glob(f'{opt.input_dir}/*.{ext}'))
+ filenames = sorted(filenames)
+ print('Total number of videos:', len(filenames))
+ pool = Pool(opt.workers)
+ args_list = cycle([opt])
+ device_ids = opt.device_ids.split(",")
+ device_ids = cycle(device_ids)
+ for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
+ pass
diff --git a/scripts/face3d/models/__init__.py b/scripts/face3d/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c62f3c6f1d1c87c4d45b93c55069d6f15f0e4e34
--- /dev/null
+++ b/scripts/face3d/models/__init__.py
@@ -0,0 +1,67 @@
+"""This package contains modules related to objective functions, optimizations, and network architectures.
+
+To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
+You need to implement the following five functions:
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+ -- <set_input>: unpack data from dataset and apply preprocessing.
+ -- <forward>: produce intermediate results.
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+
+In the function <__init__>, you need to define four lists:
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
+ -- self.model_names (str list): define networks used in our training.
+ -- self.visual_names (str list): specify the images that you want to display and save.
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for example usage.
+
+Now you can use the model class by specifying flag '--model dummy'.
+See our template model class 'template_model.py' for more details.
+"""
+
+import importlib
+from scripts.face3d.models.base_model import BaseModel
+
+
+def find_model_using_name(model_name):
+ """Import the module "models/[model_name]_model.py".
+
+ In the file, the class called [ModelName]Model will
+ be instantiated. It has to be a subclass of BaseModel,
+ and the match is case-insensitive.
+ """
+ model_filename = "face3d.models." + model_name + "_model"
+ modellib = importlib.import_module(model_filename)
+ model = None
+ target_model_name = model_name.replace('_', '') + 'model'
+ for name, cls in modellib.__dict__.items():
+ if name.lower() == target_model_name.lower() \
+ and issubclass(cls, BaseModel):
+ model = cls
+
+ if model is None:
+ print("In %s.py, there should be a subclass of BaseModel with a class name that matches %s in lowercase." % (model_filename, target_model_name))
+ exit(1)
+
+ return model
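find_model_using_name is a small import-by-convention lookup: '--model facerecon' resolves to the module face3d.models.facerecon_model, and the class whose lowercased name equals 'facereconmodel' is returned. The same pattern in isolation, with hypothetical package and base-class names standing in for the real ones:

```python
import importlib

# Generic sketch of the import-by-convention lookup above.
# `pkg` and `base_cls` are placeholders, not the repository's API.
def find_class_by_convention(pkg, name, base_cls):
    module = importlib.import_module(f"{pkg}.{name}_model")  # e.g. mypkg.facerecon_model
    target = name.replace('_', '') + 'model'                 # e.g. 'facereconmodel'
    for attr, obj in vars(module).items():
        if attr.lower() == target and isinstance(obj, type) and issubclass(obj, base_cls):
            return obj
    raise ImportError(f"no {base_cls.__name__} subclass matching '{target}' in {module.__name__}")
```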
+
+
+def get_option_setter(model_name):
+ """Return the static method <modify_commandline_options> of the model class."""
+ model_class = find_model_using_name(model_name)
+ return model_class.modify_commandline_options
+
+
+def create_model(opt):
+ """Create a model given the option.
+
+ This function wraps the model class.
+ This is the main interface between this package and 'train.py'/'test.py'
+
+ Example:
+ >>> from models import create_model
+ >>> model = create_model(opt)
+ """
+ model = find_model_using_name(opt.model)
+ instance = model(opt)
+ print("model [%s] was created" % type(instance).__name__)
+ return instance
diff --git a/scripts/face3d/models/__pycache__/__init__.cpython-310.pyc b/scripts/face3d/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f4dd2d781e0930610a7626300c12d8342643e36d
Binary files /dev/null and b/scripts/face3d/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/scripts/face3d/models/__pycache__/base_model.cpython-310.pyc b/scripts/face3d/models/__pycache__/base_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f8e7914eb6eec519b35607223fdf1c2a925a5a41
Binary files /dev/null and b/scripts/face3d/models/__pycache__/base_model.cpython-310.pyc differ
diff --git a/scripts/face3d/models/__pycache__/networks.cpython-310.pyc b/scripts/face3d/models/__pycache__/networks.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d51220522bbcd9bab1813ecc6643fd75f17c9b4f
Binary files /dev/null and b/scripts/face3d/models/__pycache__/networks.cpython-310.pyc differ
diff --git a/scripts/face3d/models/arcface_torch/README.md b/scripts/face3d/models/arcface_torch/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2ee63a861229b68873561fa39bfa7c9a8b53b947
--- /dev/null
+++ b/scripts/face3d/models/arcface_torch/README.md
@@ -0,0 +1,164 @@
+# Distributed Arcface Training in Pytorch
+
+This is a deep learning library that makes face recognition training efficient and effective, and that can train tens of millions of identities on a single server.
+
+## Requirements
+
+- Install [pytorch](http://pytorch.org) (torch>=1.6.0); see our doc [install.md](docs/install.md).
+- `pip install -r requirements.txt`.
+- Download the dataset from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_).
+
+## How to Train
+
+To train a model, run `train.py` with the path to the configs:
+
+### 1. Single node, 8 GPUs:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
+```
+
+### 2. Multiple nodes, each node 8 GPUs:
+
+Node 0:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
+```
+
+Node 1:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
+```
+
+### 3. Training resnet2060 with 8 GPUs:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
+```
+
+## Model Zoo
+
+- The models are available for non-commercial research purposes only.
+- All models can be found here:
+- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
+- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
+
+### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
+
+The ICCV2021-MFR testset consists of non-celebrities, so we can ensure that it has very little overlap with publicly available face
+recognition training sets such as MS1M and CASIA, which were mostly collected from online celebrities.
+As a result, we can fairly evaluate the performance of different algorithms.
+
+For the **ICCV2021-MFR-ALL** set, TAR is measured on the all-to-all 1:1 protocol, with FAR less than 0.000001 (1e-6). The
+globalised multi-racial testset contains 242,143 identities and 1,624,305 images.
+
+For the **ICCV2021-MFR-MASK** set, TAR is measured on the mask-to-nonmask 1:1 protocol, with FAR less than 0.0001 (1e-4).
+The mask testset contains 6,964 identities, 6,964 masked images and 13,928 non-masked images.
+In total there are 13,928 positive pairs and 96,983,824 negative pairs.
+
+| Datasets | backbone | Training throughput | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
+| :---: | :--- | :--- | :--- |:--- |:--- |
+| MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
+| Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
+| MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
+| Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
+| MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
+| Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
+| MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
+| Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
+| MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
+| Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
+
+### Performance on IJB-C and Verification Datasets
+
+| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
+| :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
+| MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
+| MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
+| MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
+| MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
+| MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
+| Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
+| Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
+| Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
+| Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
+
+[comment]: <> (More details see [model.md](docs/modelzoo.md) in docs.)
+
+
+## [Speed Benchmark](docs/speed_benchmark.md)
+
+**Arcface Torch** can train on large-scale face recognition training sets efficiently and quickly. When the number of
+classes in the training set is greater than 300K and training is sufficient, the partial FC sampling strategy achieves the same
+accuracy with several times faster training and a smaller GPU memory footprint.
+Partial FC is a sparse variant of the model parallel architecture for large-scale face recognition. Partial FC uses a
+sparse softmax, where each batch dynamically samples a subset of class centers for training. In each iteration, only a
+sparse part of the parameters is updated, which greatly reduces GPU memory usage and computation. With Partial FC,
+we can scale to a training set of 29 million identities, the largest to date. Partial FC also supports multi-machine distributed
+training and mixed precision training.
+
+![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)
+
+See [speed_benchmark.md](docs/speed_benchmark.md) in docs for more details.
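To make the sampling idea concrete before the benchmarks below, here is a toy single-GPU sketch (not this repository's implementation): classes present in the batch are always kept, random negatives fill the sample up to `sample_rate * num_classes`, and only that slice of the class-center matrix receives gradients.

```python
import torch

# Toy sketch of Partial FC-style class-center sampling (illustrative only).
def sample_centers(weights, labels, sample_rate=0.1):
    num_classes = weights.size(0)
    positives = labels.unique()
    num_sample = max(int(sample_rate * num_classes), positives.numel())
    keep = torch.zeros(num_classes, dtype=torch.bool)
    keep[positives] = True                          # batch classes are always kept
    negatives = torch.nonzero(~keep).squeeze(1)
    num_extra = num_sample - int(keep.sum())
    keep[negatives[torch.randperm(negatives.numel())[:num_extra]]] = True
    index = torch.nonzero(keep).squeeze(1)          # sorted sampled class ids
    new_labels = torch.searchsorted(index, labels)  # remap labels into the slice
    return weights[index], new_labels

W = torch.randn(10_000, 512, requires_grad=True)    # full class-center matrix
labels = torch.randint(0, 10_000, (128,))
W_sub, y = sample_centers(W, labels)                # ~1,000 x 512 slice gets gradients
```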
+
+### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
+
+`-` means training failed because of GPU memory limitations.
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 4681 | 4824 | 5004 |
+|1400000 | **1672** | 3043 | 4738 |
+|5500000 | **-** | **1389** | 3975 |
+|8000000 | **-** | **-** | 3565 |
+|16000000 | **-** | **-** | 2679 |
+|29000000 | **-** | **-** | **1855** |
+
+### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 7358 | 5306 | 4868 |
+|1400000 | 32252 | 11178 | 6056 |
+|5500000 | **-** | 32188 | 9854 |
+|8000000 | **-** | **-** | 12310 |
+|16000000 | **-** | **-** | 19950 |
+|29000000 | **-** | **-** | 32324 |
+
+## Evaluation on ICCV2021-MFR and IJB-C
+
+See [eval.md](docs/eval.md) in docs for more details.
+
+## Test
+
+We tested many versions of PyTorch. Please create an issue if you are having trouble.
+ +- [x] torch 1.6.0 +- [x] torch 1.7.1 +- [x] torch 1.8.0 +- [x] torch 1.9.0 + +## Citation + +``` +@inproceedings{deng2019arcface, + title={Arcface: Additive angular margin loss for deep face recognition}, + author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos}, + booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, + pages={4690--4699}, + year={2019} +} +@inproceedings{an2020partical_fc, + title={Partial FC: Training 10 Million Identities on a Single Machine}, + author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and + Zhang, Debing and Fu Ying}, + booktitle={Arxiv 2010.05222}, + year={2020} +} +``` diff --git a/scripts/face3d/models/arcface_torch/backbones/__init__.py b/scripts/face3d/models/arcface_torch/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..55bd4c5d1889a1a998b52eb56793bbc1eef1b691 --- /dev/null +++ b/scripts/face3d/models/arcface_torch/backbones/__init__.py @@ -0,0 +1,25 @@ +from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200 +from .mobilefacenet import get_mbf + + +def get_model(name, **kwargs): + # resnet + if name == "r18": + return iresnet18(False, **kwargs) + elif name == "r34": + return iresnet34(False, **kwargs) + elif name == "r50": + return iresnet50(False, **kwargs) + elif name == "r100": + return iresnet100(False, **kwargs) + elif name == "r200": + return iresnet200(False, **kwargs) + elif name == "r2060": + from .iresnet2060 import iresnet2060 + return iresnet2060(False, **kwargs) + elif name == "mbf": + fp16 = kwargs.get("fp16", False) + num_features = kwargs.get("num_features", 512) + return get_mbf(fp16=fp16, num_features=num_features) + else: + raise ValueError() \ No newline at end of file diff --git a/scripts/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-310.pyc b/scripts/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d52b7b9080b35650b24cb322f093e07c48eeb09 Binary files /dev/null and b/scripts/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-310.pyc b/scripts/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac59f22cbc7d834c7f09947f6e80d23830e78f32 Binary files /dev/null and b/scripts/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-310.pyc differ diff --git a/scripts/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-310.pyc b/scripts/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0d46935d9e4043e5fc2e17e40330ef88645c46a Binary files /dev/null and b/scripts/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-310.pyc differ diff --git a/scripts/face3d/models/arcface_torch/backbones/iresnet.py b/scripts/face3d/models/arcface_torch/backbones/iresnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c6d3b9c240c24687d432197f976ee01fbf423216 --- /dev/null +++ b/scripts/face3d/models/arcface_torch/backbones/iresnet.py @@ -0,0 +1,187 @@ +import torch +from torch import nn + +__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200'] + + +def conv3x3(in_planes, 
out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False) + + +class IBasicBlock(nn.Module): + expansion = 1 + def __init__(self, inplanes, planes, stride=1, downsample=None, + groups=1, base_width=64, dilation=1): + super(IBasicBlock, self).__init__() + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) + self.conv1 = conv3x3(inplanes, planes) + self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) + self.prelu = nn.PReLU(planes) + self.conv2 = conv3x3(planes, planes, stride) + self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + out = self.bn1(x) + out = self.conv1(out) + out = self.bn2(out) + out = self.prelu(out) + out = self.conv2(out) + out = self.bn3(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + return out + + +class IResNet(nn.Module): + fc_scale = 7 * 7 + def __init__(self, + block, layers, dropout=0, num_features=512, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): + super(IResNet, self).__init__() + self.fp16 = fp16 + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) + self.prelu = nn.PReLU(self.inplanes) + self.layer1 = self._make_layer(block, 64, layers[0], stride=2) + self.layer2 = self._make_layer(block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) + self.dropout = nn.Dropout(p=dropout, inplace=True) + self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) + self.features = nn.BatchNorm1d(num_features, eps=1e-05) + nn.init.constant_(self.features.weight, 1.0) + self.features.weight.requires_grad = False + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, 0, 0.1) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, IBasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * 
block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), + ) + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation)) + + return nn.Sequential(*layers) + + def forward(self, x): + with torch.cuda.amp.autocast(self.fp16): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn2(x) + x = torch.flatten(x, 1) + x = self.dropout(x) + x = self.fc(x.float() if self.fp16 else x) + x = self.features(x) + return x + + +def _iresnet(arch, block, layers, pretrained, progress, **kwargs): + model = IResNet(block, layers, **kwargs) + if pretrained: + raise ValueError() + return model + + +def iresnet18(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, + progress, **kwargs) + + +def iresnet34(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, + progress, **kwargs) + + +def iresnet50(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, + progress, **kwargs) + + +def iresnet100(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, + progress, **kwargs) + + +def iresnet200(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, + progress, **kwargs) + diff --git a/scripts/face3d/models/arcface_torch/backbones/iresnet2060.py b/scripts/face3d/models/arcface_torch/backbones/iresnet2060.py new file mode 100644 index 0000000000000000000000000000000000000000..21d1122144d207637d2444cba1f68fe630c89f31 --- /dev/null +++ b/scripts/face3d/models/arcface_torch/backbones/iresnet2060.py @@ -0,0 +1,176 @@ +import torch +from torch import nn + +assert torch.__version__ >= "1.8.1" +from torch.utils.checkpoint import checkpoint_sequential + +__all__ = ['iresnet2060'] + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False) + + +class IBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + groups=1, base_width=64, dilation=1): + super(IBasicBlock, self).__init__() + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, ) + self.conv1 = conv3x3(inplanes, planes) + self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, ) + self.prelu = nn.PReLU(planes) + self.conv2 = conv3x3(planes, planes, stride) + self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, ) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + out = 
self.bn1(x) + out = self.conv1(out) + out = self.bn2(out) + out = self.prelu(out) + out = self.conv2(out) + out = self.bn3(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + return out + + +class IResNet(nn.Module): + fc_scale = 7 * 7 + + def __init__(self, + block, layers, dropout=0, num_features=512, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): + super(IResNet, self).__init__() + self.fp16 = fp16 + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) + self.prelu = nn.PReLU(self.inplanes) + self.layer1 = self._make_layer(block, 64, layers[0], stride=2) + self.layer2 = self._make_layer(block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, ) + self.dropout = nn.Dropout(p=dropout, inplace=True) + self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) + self.features = nn.BatchNorm1d(num_features, eps=1e-05) + nn.init.constant_(self.features.weight, 1.0) + self.features.weight.requires_grad = False + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, 0, 0.1) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, IBasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), + ) + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation)) + + return nn.Sequential(*layers) + + def checkpoint(self, func, num_seg, x): + if self.training: + return checkpoint_sequential(func, num_seg, x) + else: + return func(x) + + def forward(self, x): + with torch.cuda.amp.autocast(self.fp16): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.layer1(x) + x = self.checkpoint(self.layer2, 20, x) + x = self.checkpoint(self.layer3, 100, x) + x = self.layer4(x) + x = self.bn2(x) + x = torch.flatten(x, 1) + x = self.dropout(x) + x = self.fc(x.float() if self.fp16 else x) + x = self.features(x) + return x + + +def _iresnet(arch, block, layers, pretrained, progress, **kwargs): + model = IResNet(block, 
layers, **kwargs) + if pretrained: + raise ValueError() + return model + + +def iresnet2060(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs) diff --git a/scripts/face3d/models/arcface_torch/backbones/mobilefacenet.py b/scripts/face3d/models/arcface_torch/backbones/mobilefacenet.py new file mode 100644 index 0000000000000000000000000000000000000000..87731491d76f9ff61cc70e57bb3f18c54fae308c --- /dev/null +++ b/scripts/face3d/models/arcface_torch/backbones/mobilefacenet.py @@ -0,0 +1,130 @@ +''' +Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py +Original author cavalleria +''' + +import torch.nn as nn +from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module +import torch + + +class Flatten(Module): + def forward(self, x): + return x.view(x.size(0), -1) + + +class ConvBlock(Module): + def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): + super(ConvBlock, self).__init__() + self.layers = nn.Sequential( + Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False), + BatchNorm2d(num_features=out_c), + PReLU(num_parameters=out_c) + ) + + def forward(self, x): + return self.layers(x) + + +class LinearBlock(Module): + def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): + super(LinearBlock, self).__init__() + self.layers = nn.Sequential( + Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False), + BatchNorm2d(num_features=out_c) + ) + + def forward(self, x): + return self.layers(x) + + +class DepthWise(Module): + def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1): + super(DepthWise, self).__init__() + self.residual = residual + self.layers = nn.Sequential( + ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)), + ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride), + LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) + ) + + def forward(self, x): + short_cut = None + if self.residual: + short_cut = x + x = self.layers(x) + if self.residual: + output = short_cut + x + else: + output = x + return output + + +class Residual(Module): + def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)): + super(Residual, self).__init__() + modules = [] + for _ in range(num_block): + modules.append(DepthWise(c, c, True, kernel, stride, padding, groups)) + self.layers = Sequential(*modules) + + def forward(self, x): + return self.layers(x) + + +class GDC(Module): + def __init__(self, embedding_size): + super(GDC, self).__init__() + self.layers = nn.Sequential( + LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)), + Flatten(), + Linear(512, embedding_size, bias=False), + BatchNorm1d(embedding_size)) + + def forward(self, x): + return self.layers(x) + + +class MobileFaceNet(Module): + def __init__(self, fp16=False, num_features=512): + super(MobileFaceNet, self).__init__() + scale = 2 + self.fp16 = fp16 + self.layers = nn.Sequential( + ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)), + ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64), + DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128), + Residual(64 * scale, num_block=4, 
groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), + DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256), + Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), + DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512), + Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), + ) + self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) + self.features = GDC(num_features) + self._initialize_weights() + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + with torch.cuda.amp.autocast(self.fp16): + x = self.layers(x) + x = self.conv_sep(x.float() if self.fp16 else x) + x = self.features(x) + return x + + +def get_mbf(fp16, num_features): + return MobileFaceNet(fp16, num_features) \ No newline at end of file diff --git a/scripts/face3d/models/arcface_torch/configs/3millions.py b/scripts/face3d/models/arcface_torch/configs/3millions.py new file mode 100644 index 0000000000000000000000000000000000000000..c9edc2f1414e35f93abfd3dfe11a61f1f406580e --- /dev/null +++ b/scripts/face3d/models/arcface_torch/configs/3millions.py @@ -0,0 +1,23 @@ +from easydict import EasyDict as edict + +# configs for test speed + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "synthetic" +config.num_classes = 300 * 10000 +config.num_epoch = 30 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = [] diff --git a/scripts/face3d/models/arcface_torch/configs/3millions_pfc.py b/scripts/face3d/models/arcface_torch/configs/3millions_pfc.py new file mode 100644 index 0000000000000000000000000000000000000000..77caafdbb300d8109d5bfdb844f131710ef81f20 --- /dev/null +++ b/scripts/face3d/models/arcface_torch/configs/3millions_pfc.py @@ -0,0 +1,23 @@ +from easydict import EasyDict as edict + +# configs for test speed + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.1 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "synthetic" +config.num_classes = 300 * 10000 +config.num_epoch = 30 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = [] diff --git a/scripts/face3d/models/arcface_torch/configs/__init__.py b/scripts/face3d/models/arcface_torch/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/face3d/models/arcface_torch/configs/base.py b/scripts/face3d/models/arcface_torch/configs/base.py new file mode 100644 index 
0000000000000000000000000000000000000000..78e4b36a9142b649ec39a8c59331bb2557f2ad57 --- /dev/null +++ b/scripts/face3d/models/arcface_torch/configs/base.py @@ -0,0 +1,56 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = "ms1mv3_arcface_r50" + +config.dataset = "ms1m-retinaface-t1" +config.embedding_size = 512 +config.sample_rate = 1 +config.fp16 = False +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +if config.dataset == "emore": + config.rec = "/train_tmp/faces_emore" + config.num_classes = 85742 + config.num_image = 5822653 + config.num_epoch = 16 + config.warmup_epoch = -1 + config.decay_epoch = [8, 14, ] + config.val_targets = ["lfw", ] + +elif config.dataset == "ms1m-retinaface-t1": + config.rec = "/train_tmp/ms1m-retinaface-t1" + config.num_classes = 93431 + config.num_image = 5179510 + config.num_epoch = 25 + config.warmup_epoch = -1 + config.decay_epoch = [11, 17, 22] + config.val_targets = ["lfw", "cfp_fp", "agedb_30"] + +elif config.dataset == "glint360k": + config.rec = "/train_tmp/glint360k" + config.num_classes = 360232 + config.num_image = 17091657 + config.num_epoch = 20 + config.warmup_epoch = -1 + config.decay_epoch = [8, 12, 15, 18] + config.val_targets = ["lfw", "cfp_fp", "agedb_30"] + +elif config.dataset == "webface": + config.rec = "/train_tmp/faces_webface_112x112" + config.num_classes = 10572 + config.num_image = "forget" + config.num_epoch = 34 + config.warmup_epoch = -1 + config.decay_epoch = [20, 28, 32] + config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/scripts/face3d/models/arcface_torch/configs/glint360k_mbf.py b/scripts/face3d/models/arcface_torch/configs/glint360k_mbf.py new file mode 100644 index 0000000000000000000000000000000000000000..46ae777cc97af41a531cba4e5d1ff31f2efcb468 --- /dev/null +++ b/scripts/face3d/models/arcface_torch/configs/glint360k_mbf.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "cosface" +config.network = "mbf" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.1 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 2e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/scripts/face3d/models/arcface_torch/configs/glint360k_r100.py b/scripts/face3d/models/arcface_torch/configs/glint360k_r100.py new file mode 100644 index 0000000000000000000000000000000000000000..93d0701c0094517cec147c382b005e8063938548 --- /dev/null +++ b/scripts/face3d/models/arcface_torch/configs/glint360k_r100.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "cosface" +config.network = "r100" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # 
batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/scripts/face3d/models/arcface_torch/configs/glint360k_r18.py b/scripts/face3d/models/arcface_torch/configs/glint360k_r18.py new file mode 100644 index 0000000000000000000000000000000000000000..7a8db34cd547e8e667103c93585296e47a894e97 --- /dev/null +++ b/scripts/face3d/models/arcface_torch/configs/glint360k_r18.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "cosface" +config.network = "r18" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/scripts/face3d/models/arcface_torch/configs/glint360k_r34.py b/scripts/face3d/models/arcface_torch/configs/glint360k_r34.py new file mode 100644 index 0000000000000000000000000000000000000000..fda2701758a839a7161d09c25f0ca3d26033baff --- /dev/null +++ b/scripts/face3d/models/arcface_torch/configs/glint360k_r34.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "cosface" +config.network = "r34" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/scripts/face3d/models/arcface_torch/configs/glint360k_r50.py b/scripts/face3d/models/arcface_torch/configs/glint360k_r50.py new file mode 100644 index 0000000000000000000000000000000000000000..37e7922f1f63284e356dcc45a5f979f9c105f25e --- /dev/null +++ b/scripts/face3d/models/arcface_torch/configs/glint360k_r50.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "cosface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/scripts/face3d/models/arcface_torch/configs/ms1mv3_mbf.py b/scripts/face3d/models/arcface_torch/configs/ms1mv3_mbf.py new file mode 100644 index 0000000000000000000000000000000000000000..b8a00d6305eeda5a94788017afc1cda0d4a4cd2a --- 
/dev/null +++ b/scripts/face3d/models/arcface_torch/configs/ms1mv3_mbf.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "mbf" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 2e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 30 +config.warmup_epoch = -1 +config.decay_epoch = [10, 20, 25] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/scripts/face3d/models/arcface_torch/configs/ms1mv3_r18.py b/scripts/face3d/models/arcface_torch/configs/ms1mv3_r18.py new file mode 100644 index 0000000000000000000000000000000000000000..eb4e0d31f1aedf4590628d394e1606920fefb5c9 --- /dev/null +++ b/scripts/face3d/models/arcface_torch/configs/ms1mv3_r18.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r18" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 25 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/scripts/face3d/models/arcface_torch/configs/ms1mv3_r2060.py b/scripts/face3d/models/arcface_torch/configs/ms1mv3_r2060.py new file mode 100644 index 0000000000000000000000000000000000000000..23ad81e082c4b6390b67b164d0ceb84bb0635684 --- /dev/null +++ b/scripts/face3d/models/arcface_torch/configs/ms1mv3_r2060.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r2060" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 64 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 25 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/scripts/face3d/models/arcface_torch/configs/ms1mv3_r34.py b/scripts/face3d/models/arcface_torch/configs/ms1mv3_r34.py new file mode 100644 index 0000000000000000000000000000000000000000..5f78337a3d1f9eb6e9145eb5093618796c6842d2 --- /dev/null +++ b/scripts/face3d/models/arcface_torch/configs/ms1mv3_r34.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r34" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch 
size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 25 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/scripts/face3d/models/arcface_torch/configs/ms1mv3_r50.py b/scripts/face3d/models/arcface_torch/configs/ms1mv3_r50.py new file mode 100644 index 0000000000000000000000000000000000000000..08ba55dbbea6df0afffddbb3d1ed173efad99604 --- /dev/null +++ b/scripts/face3d/models/arcface_torch/configs/ms1mv3_r50.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 25 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/scripts/face3d/models/arcface_torch/configs/speed.py b/scripts/face3d/models/arcface_torch/configs/speed.py new file mode 100644 index 0000000000000000000000000000000000000000..45e95237da65e44f35a172c25ac6dc4e313e4eae --- /dev/null +++ b/scripts/face3d/models/arcface_torch/configs/speed.py @@ -0,0 +1,23 @@ +from easydict import EasyDict as edict + +# configs for test speed + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "synthetic" +config.num_classes = 100 * 10000 +config.num_epoch = 30 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = [] diff --git a/scripts/face3d/models/arcface_torch/dataset.py b/scripts/face3d/models/arcface_torch/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..96bbb8bb6da99122f350bc8e1a6390245840e32b --- /dev/null +++ b/scripts/face3d/models/arcface_torch/dataset.py @@ -0,0 +1,124 @@ +import numbers +import os +import queue as Queue +import threading + +import mxnet as mx +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class BackgroundGenerator(threading.Thread): + def __init__(self, generator, local_rank, max_prefetch=6): + super(BackgroundGenerator, self).__init__() + self.queue = Queue.Queue(max_prefetch) + self.generator = generator + self.local_rank = local_rank + self.daemon = True + self.start() + + def run(self): + torch.cuda.set_device(self.local_rank) + for item in self.generator: + self.queue.put(item) + self.queue.put(None) + + def next(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + def __next__(self): + return self.next() + + def __iter__(self): + return self + + +class DataLoaderX(DataLoader): + + def __init__(self, local_rank, **kwargs): + super(DataLoaderX, self).__init__(**kwargs) + self.stream = torch.cuda.Stream(local_rank) + self.local_rank = local_rank + + def __iter__(self): + self.iter = super(DataLoaderX, self).__iter__() + 
self.iter = BackgroundGenerator(self.iter, self.local_rank) + self.preload() + return self + + def preload(self): + self.batch = next(self.iter, None) + if self.batch is None: + return None + with torch.cuda.stream(self.stream): + for k in range(len(self.batch)): + self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True) + + def __next__(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + if batch is None: + raise StopIteration + self.preload() + return batch + + +class MXFaceDataset(Dataset): + def __init__(self, root_dir, local_rank): + super(MXFaceDataset, self).__init__() + self.transform = transforms.Compose( + [transforms.ToPILImage(), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + self.root_dir = root_dir + self.local_rank = local_rank + path_imgrec = os.path.join(root_dir, 'train.rec') + path_imgidx = os.path.join(root_dir, 'train.idx') + self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') + s = self.imgrec.read_idx(0) + header, _ = mx.recordio.unpack(s) + if header.flag > 0: + self.header0 = (int(header.label[0]), int(header.label[1])) + self.imgidx = np.array(range(1, int(header.label[0]))) + else: + self.imgidx = np.array(list(self.imgrec.keys)) + + def __getitem__(self, index): + idx = self.imgidx[index] + s = self.imgrec.read_idx(idx) + header, img = mx.recordio.unpack(s) + label = header.label + if not isinstance(label, numbers.Number): + label = label[0] + label = torch.tensor(label, dtype=torch.long) + sample = mx.image.imdecode(img).asnumpy() + if self.transform is not None: + sample = self.transform(sample) + return sample, label + + def __len__(self): + return len(self.imgidx) + + +class SyntheticDataset(Dataset): + def __init__(self, local_rank): + super(SyntheticDataset, self).__init__() + img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) + img = np.transpose(img, (2, 0, 1)) + img = torch.from_numpy(img).squeeze(0).float() + img = ((img / 255) - 0.5) / 0.5 + self.img = img + self.label = 1 + + def __getitem__(self, index): + return self.img, self.label + + def __len__(self): + return 1000000 diff --git a/scripts/face3d/models/arcface_torch/docs/eval.md b/scripts/face3d/models/arcface_torch/docs/eval.md new file mode 100644 index 0000000000000000000000000000000000000000..dd1d9e257367b6422680966198646c45e5a2671d --- /dev/null +++ b/scripts/face3d/models/arcface_torch/docs/eval.md @@ -0,0 +1,31 @@ +## Eval on ICCV2021-MFR + +coming soon. + + +## Eval IJBC +You can eval ijbc with pytorch or onnx. + + +1. Eval IJBC With Onnx +```shell +CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50 +``` + +2. 
Eval IJBC With PyTorch
+```shell
+CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \
+--model-prefix ms1mv3_arcface_r50/backbone.pth \
+--image-path IJB_release/IJBC \
+--result-dir ms1mv3_arcface_r50 \
+--batch-size 128 \
+--job ms1mv3_arcface_r50 \
+--target IJBC \
+--network iresnet50
+```
+
+## Inference
+
+```shell
+python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50
+```
diff --git a/scripts/face3d/models/arcface_torch/docs/install.md b/scripts/face3d/models/arcface_torch/docs/install.md
new file mode 100644
index 0000000000000000000000000000000000000000..6314a40441285e9236438e468caf8b71a407531a
--- /dev/null
+++ b/scripts/face3d/models/arcface_torch/docs/install.md
@@ -0,0 +1,51 @@
+## v1.8.0
+### Linux and Windows
+```shell
+# CUDA 11.0
+pip --default-timeout=100 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 10.2
+pip --default-timeout=100 install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0
+
+# CPU only
+pip --default-timeout=100 install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
+
+```
+
+
+## v1.7.1
+### Linux and Windows
+```shell
+# CUDA 11.0
+pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 10.2
+pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2
+
+# CUDA 10.1
+pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 9.2
+pip install torch==1.7.1+cu92 torchvision==0.8.2+cu92 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CPU only
+pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+```
+
+
+## v1.6.0
+
+### Linux and Windows
+```shell
+# CUDA 10.2
+pip install torch==1.6.0 torchvision==0.7.0
+
+# CUDA 10.1
+pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 9.2
+pip install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CPU only
+pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+```
\ No newline at end of file
diff --git a/scripts/face3d/models/arcface_torch/docs/modelzoo.md b/scripts/face3d/models/arcface_torch/docs/modelzoo.md
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/scripts/face3d/models/arcface_torch/docs/speed_benchmark.md b/scripts/face3d/models/arcface_torch/docs/speed_benchmark.md
new file mode 100644
index 0000000000000000000000000000000000000000..055aee0defe2c43a523ced48260242f0f99b7cea
--- /dev/null
+++ b/scripts/face3d/models/arcface_torch/docs/speed_benchmark.md
@@ -0,0 +1,93 @@
+## Test Training Speed
+
+- Test Commands
+
+Use the following two commands to test Partial FC training performance.
+The number of identities is **3 million** (synthetic data), mixed-precision training is enabled, the backbone is ResNet-50,
+and the batch size is 1024.
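+
+Note: the `configs/3millions` and `configs/3millions_pfc` files referenced by the commands below are not part of this change. As a rough sketch only, a `3millions_pfc` config in the style of `configs/speed.py` above might look like the following (every value here is an illustrative assumption, not taken from the repository; `sample_rate = 0.1` is what enables Partial FC sampling):
+
+```python
+# Hypothetical configs/3millions_pfc.py -- modeled on configs/speed.py in this
+# diff; synthetic data with 3M identities, sampling 10% of class centers.
+from easydict import EasyDict as edict
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"            # ResNet-50 backbone, as in the benchmark setup
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.1          # Partial FC: sample 10% of class centers per step
+config.fp16 = True                # mixed-precision training
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128           # 128 per GPU x 8 GPUs = global batch size 1024
+config.lr = 0.1
+
+config.rec = "synthetic"          # synthetic data, like configs/speed.py
+config.num_classes = 300 * 10000  # 3 million identities
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = []           # no verification sets for a speed test
+```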
+```shell
+# Model Parallel
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions
+# Partial FC 0.1
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc
+```
+
+- GPU Memory
+
+```
+# (Model Parallel) gpustat -i
+[0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB
+[1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB
+[2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB
+[3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB
+[4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB
+[5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB
+[6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB
+[7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB
+
+# (Partial FC 0.1) gpustat -i
+[0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB
+[1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB
+[2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB
+[3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB
+[4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB
+[5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB
+[6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB
+[7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB
+```
+
+- Training Speed
+
+```
+# (Model Parallel) training.log
+Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100
+Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
+Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
+Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
+Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
+
+# (Partial FC 0.1) training.log
+Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100
+Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
+Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
+Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
+Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
+```
+
+In this test case, Partial FC 0.1 uses only about 1/3 of the GPU memory of model parallel,
+and trains about 2.5 times faster than model parallel.
+
+
+## Speed Benchmark
+
+1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 4681 | 4824 | 5004 |
+|250000 | 4047 | 4521 | 4976 |
+|500000 | 3087 | 4013 | 4900 |
+|1000000 | 2090 | 3449 | 4803 |
+|1400000 | 1672 | 3043 | 4738 |
+|2000000 | - | 2593 | 4626 |
+|4000000 | - | 1748 | 4208 |
+|5500000 | - | 1389 | 3975 |
+|8000000 | - | - | 3565 |
+|16000000 | - | - | 2679 |
+|29000000 | - | - | 1855 |
+
+2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8.
(Smaller is better) + +| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | +| :--- | :--- | :--- | :--- | +|125000 | 7358 | 5306 | 4868 | +|250000 | 9940 | 5826 | 5004 | +|500000 | 14220 | 7114 | 5202 | +|1000000 | 23708 | 9966 | 5620 | +|1400000 | 32252 | 11178 | 6056 | +|2000000 | - | 13978 | 6472 | +|4000000 | - | 23238 | 8284 | +|5500000 | - | 32188 | 9854 | +|8000000 | - | - | 12310 | +|16000000 | - | - | 19950 | +|29000000 | - | - | 32324 | diff --git a/scripts/face3d/models/arcface_torch/eval/__init__.py b/scripts/face3d/models/arcface_torch/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/face3d/models/arcface_torch/eval/verification.py b/scripts/face3d/models/arcface_torch/eval/verification.py new file mode 100644 index 0000000000000000000000000000000000000000..253343b83dbf9d1bd154d14ec068e098bf0968db --- /dev/null +++ b/scripts/face3d/models/arcface_torch/eval/verification.py @@ -0,0 +1,407 @@ +"""Helper for evaluation on the Labeled Faces in the Wild dataset +""" + +# MIT License +# +# Copyright (c) 2016 David Sandberg +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
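+# Protocol note: pairs are scored by squared L2 distance between embeddings; calculate_roc picks the best threshold on each fold's training split and reports accuracy on the held-out split (10 folds by default), while calculate_val interpolates the threshold that reaches a target FAR.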
+ + +import datetime +import os +import pickle + +import mxnet as mx +import numpy as np +import sklearn +import torch +from mxnet import ndarray as nd +from scipy import interpolate +from sklearn.decomposition import PCA +from sklearn.model_selection import KFold + + +class LFold: + def __init__(self, n_splits=2, shuffle=False): + self.n_splits = n_splits + if self.n_splits > 1: + self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle) + + def split(self, indices): + if self.n_splits > 1: + return self.k_fold.split(indices) + else: + return [(indices, indices)] + + +def calculate_roc(thresholds, + embeddings1, + embeddings2, + actual_issame, + nrof_folds=10, + pca=0): + assert (embeddings1.shape[0] == embeddings2.shape[0]) + assert (embeddings1.shape[1] == embeddings2.shape[1]) + nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) + nrof_thresholds = len(thresholds) + k_fold = LFold(n_splits=nrof_folds, shuffle=False) + + tprs = np.zeros((nrof_folds, nrof_thresholds)) + fprs = np.zeros((nrof_folds, nrof_thresholds)) + accuracy = np.zeros((nrof_folds)) + indices = np.arange(nrof_pairs) + + if pca == 0: + diff = np.subtract(embeddings1, embeddings2) + dist = np.sum(np.square(diff), 1) + + for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): + if pca > 0: + print('doing pca on', fold_idx) + embed1_train = embeddings1[train_set] + embed2_train = embeddings2[train_set] + _embed_train = np.concatenate((embed1_train, embed2_train), axis=0) + pca_model = PCA(n_components=pca) + pca_model.fit(_embed_train) + embed1 = pca_model.transform(embeddings1) + embed2 = pca_model.transform(embeddings2) + embed1 = sklearn.preprocessing.normalize(embed1) + embed2 = sklearn.preprocessing.normalize(embed2) + diff = np.subtract(embed1, embed2) + dist = np.sum(np.square(diff), 1) + + # Find the best threshold for the fold + acc_train = np.zeros((nrof_thresholds)) + for threshold_idx, threshold in enumerate(thresholds): + _, _, acc_train[threshold_idx] = calculate_accuracy( + threshold, dist[train_set], actual_issame[train_set]) + best_threshold_index = np.argmax(acc_train) + for threshold_idx, threshold in enumerate(thresholds): + tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy( + threshold, dist[test_set], + actual_issame[test_set]) + _, _, accuracy[fold_idx] = calculate_accuracy( + thresholds[best_threshold_index], dist[test_set], + actual_issame[test_set]) + + tpr = np.mean(tprs, 0) + fpr = np.mean(fprs, 0) + return tpr, fpr, accuracy + + +def calculate_accuracy(threshold, dist, actual_issame): + predict_issame = np.less(dist, threshold) + tp = np.sum(np.logical_and(predict_issame, actual_issame)) + fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame))) + tn = np.sum( + np.logical_and(np.logical_not(predict_issame), + np.logical_not(actual_issame))) + fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame)) + + tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn) + fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn) + acc = float(tp + tn) / dist.size + return tpr, fpr, acc + + +def calculate_val(thresholds, + embeddings1, + embeddings2, + actual_issame, + far_target, + nrof_folds=10): + assert (embeddings1.shape[0] == embeddings2.shape[0]) + assert (embeddings1.shape[1] == embeddings2.shape[1]) + nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) + nrof_thresholds = len(thresholds) + k_fold = LFold(n_splits=nrof_folds, shuffle=False) + + val = np.zeros(nrof_folds) + far = 
np.zeros(nrof_folds) + + diff = np.subtract(embeddings1, embeddings2) + dist = np.sum(np.square(diff), 1) + indices = np.arange(nrof_pairs) + + for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): + + # Find the threshold that gives FAR = far_target + far_train = np.zeros(nrof_thresholds) + for threshold_idx, threshold in enumerate(thresholds): + _, far_train[threshold_idx] = calculate_val_far( + threshold, dist[train_set], actual_issame[train_set]) + if np.max(far_train) >= far_target: + f = interpolate.interp1d(far_train, thresholds, kind='slinear') + threshold = f(far_target) + else: + threshold = 0.0 + + val[fold_idx], far[fold_idx] = calculate_val_far( + threshold, dist[test_set], actual_issame[test_set]) + + val_mean = np.mean(val) + far_mean = np.mean(far) + val_std = np.std(val) + return val_mean, val_std, far_mean + + +def calculate_val_far(threshold, dist, actual_issame): + predict_issame = np.less(dist, threshold) + true_accept = np.sum(np.logical_and(predict_issame, actual_issame)) + false_accept = np.sum( + np.logical_and(predict_issame, np.logical_not(actual_issame))) + n_same = np.sum(actual_issame) + n_diff = np.sum(np.logical_not(actual_issame)) + # print(true_accept, false_accept) + # print(n_same, n_diff) + val = float(true_accept) / float(n_same) + far = float(false_accept) / float(n_diff) + return val, far + + +def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0): + # Calculate evaluation metrics + thresholds = np.arange(0, 4, 0.01) + embeddings1 = embeddings[0::2] + embeddings2 = embeddings[1::2] + tpr, fpr, accuracy = calculate_roc(thresholds, + embeddings1, + embeddings2, + np.asarray(actual_issame), + nrof_folds=nrof_folds, + pca=pca) + thresholds = np.arange(0, 4, 0.001) + val, val_std, far = calculate_val(thresholds, + embeddings1, + embeddings2, + np.asarray(actual_issame), + 1e-3, + nrof_folds=nrof_folds) + return tpr, fpr, accuracy, val, val_std, far + +@torch.no_grad() +def load_bin(path, image_size): + try: + with open(path, 'rb') as f: + bins, issame_list = pickle.load(f) # py2 + except UnicodeDecodeError as e: + with open(path, 'rb') as f: + bins, issame_list = pickle.load(f, encoding='bytes') # py3 + data_list = [] + for flip in [0, 1]: + data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1])) + data_list.append(data) + for idx in range(len(issame_list) * 2): + _bin = bins[idx] + img = mx.image.imdecode(_bin) + if img.shape[1] != image_size[0]: + img = mx.image.resize_short(img, image_size[0]) + img = nd.transpose(img, axes=(2, 0, 1)) + for flip in [0, 1]: + if flip == 1: + img = mx.ndarray.flip(data=img, axis=2) + data_list[flip][idx][:] = torch.from_numpy(img.asnumpy()) + if idx % 1000 == 0: + print('loading bin', idx) + print(data_list[0].shape) + return data_list, issame_list + +@torch.no_grad() +def test(data_set, backbone, batch_size, nfolds=10): + print('testing verification..') + data_list = data_set[0] + issame_list = data_set[1] + embeddings_list = [] + time_consumed = 0.0 + for i in range(len(data_list)): + data = data_list[i] + embeddings = None + ba = 0 + while ba < data.shape[0]: + bb = min(ba + batch_size, data.shape[0]) + count = bb - ba + _data = data[bb - batch_size: bb] + time0 = datetime.datetime.now() + img = ((_data / 255) - 0.5) / 0.5 + net_out: torch.Tensor = backbone(img) + _embeddings = net_out.detach().cpu().numpy() + time_now = datetime.datetime.now() + diff = time_now - time0 + time_consumed += diff.total_seconds() + if embeddings is None: + embeddings = np.zeros((data.shape[0], 
_embeddings.shape[1])) + embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] + ba = bb + embeddings_list.append(embeddings) + + _xnorm = 0.0 + _xnorm_cnt = 0 + for embed in embeddings_list: + for i in range(embed.shape[0]): + _em = embed[i] + _norm = np.linalg.norm(_em) + _xnorm += _norm + _xnorm_cnt += 1 + _xnorm /= _xnorm_cnt + + acc1 = 0.0 + std1 = 0.0 + embeddings = embeddings_list[0] + embeddings_list[1] + embeddings = sklearn.preprocessing.normalize(embeddings) + print(embeddings.shape) + print('infer time', time_consumed) + _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds) + acc2, std2 = np.mean(accuracy), np.std(accuracy) + return acc1, std1, acc2, std2, _xnorm, embeddings_list + + +def dumpR(data_set, + backbone, + batch_size, + name='', + data_extra=None, + label_shape=None): + print('dump verification embedding..') + data_list = data_set[0] + issame_list = data_set[1] + embeddings_list = [] + time_consumed = 0.0 + for i in range(len(data_list)): + data = data_list[i] + embeddings = None + ba = 0 + while ba < data.shape[0]: + bb = min(ba + batch_size, data.shape[0]) + count = bb - ba + + _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb) + time0 = datetime.datetime.now() + if data_extra is None: + db = mx.io.DataBatch(data=(_data,), label=(_label,)) + else: + db = mx.io.DataBatch(data=(_data, _data_extra), + label=(_label,)) + model.forward(db, is_train=False) + net_out = model.get_outputs() + _embeddings = net_out[0].asnumpy() + time_now = datetime.datetime.now() + diff = time_now - time0 + time_consumed += diff.total_seconds() + if embeddings is None: + embeddings = np.zeros((data.shape[0], _embeddings.shape[1])) + embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] + ba = bb + embeddings_list.append(embeddings) + embeddings = embeddings_list[0] + embeddings_list[1] + embeddings = sklearn.preprocessing.normalize(embeddings) + actual_issame = np.asarray(issame_list) + outname = os.path.join('temp.bin') + with open(outname, 'wb') as f: + pickle.dump((embeddings, issame_list), + f, + protocol=pickle.HIGHEST_PROTOCOL) + + +# if __name__ == '__main__': +# +# parser = argparse.ArgumentParser(description='do verification') +# # general +# parser.add_argument('--data-dir', default='', help='') +# parser.add_argument('--model', +# default='../model/softmax,50', +# help='path to load model.') +# parser.add_argument('--target', +# default='lfw,cfp_ff,cfp_fp,agedb_30', +# help='test targets.') +# parser.add_argument('--gpu', default=0, type=int, help='gpu id') +# parser.add_argument('--batch-size', default=32, type=int, help='') +# parser.add_argument('--max', default='', type=str, help='') +# parser.add_argument('--mode', default=0, type=int, help='') +# parser.add_argument('--nfolds', default=10, type=int, help='') +# args = parser.parse_args() +# image_size = [112, 112] +# print('image_size', image_size) +# ctx = mx.gpu(args.gpu) +# nets = [] +# vec = args.model.split(',') +# prefix = args.model.split(',')[0] +# epochs = [] +# if len(vec) == 1: +# pdir = os.path.dirname(prefix) +# for fname in os.listdir(pdir): +# if not fname.endswith('.params'): +# continue +# _file = os.path.join(pdir, fname) +# if _file.startswith(prefix): +# epoch = int(fname.split('.')[0].split('-')[1]) +# epochs.append(epoch) +# epochs = sorted(epochs, reverse=True) +# if len(args.max) > 0: +# _max = [int(x) for x in args.max.split(',')] +# assert len(_max) == 2 +# if len(epochs) > _max[1]: +# epochs = epochs[_max[0]:_max[1]] +# +# else: +# 
epochs = [int(x) for x in vec[1].split('|')]
+# print('model number', len(epochs))
+# time0 = datetime.datetime.now()
+# for epoch in epochs:
+# print('loading', prefix, epoch)
+# sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
+# # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
+# all_layers = sym.get_internals()
+# sym = all_layers['fc1_output']
+# model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
+# # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
+# model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0],
+# image_size[1]))])
+# model.set_params(arg_params, aux_params)
+# nets.append(model)
+# time_now = datetime.datetime.now()
+# diff = time_now - time0
+# print('model loading time', diff.total_seconds())
+#
+# ver_list = []
+# ver_name_list = []
+# for name in args.target.split(','):
+# path = os.path.join(args.data_dir, name + ".bin")
+# if os.path.exists(path):
+# print('loading.. ', name)
+# data_set = load_bin(path, image_size)
+# ver_list.append(data_set)
+# ver_name_list.append(name)
+#
+# if args.mode == 0:
+# for i in range(len(ver_list)):
+# results = []
+# for model in nets:
+# acc1, std1, acc2, std2, xnorm, embeddings_list = test(
+# ver_list[i], model, args.batch_size, args.nfolds)
+# print('[%s]XNorm: %f' % (ver_name_list[i], xnorm))
+# print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1))
+# print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2))
+# results.append(acc2)
+# print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results)))
+# elif args.mode == 1:
+# raise ValueError
+# else:
+# model = nets[0]
+# dumpR(ver_list[0], model, args.batch_size, args.target)
diff --git a/scripts/face3d/models/arcface_torch/eval_ijbc.py b/scripts/face3d/models/arcface_torch/eval_ijbc.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c5a650d486d18eb02d6f60d448fc3b315261f5d
--- /dev/null
+++ b/scripts/face3d/models/arcface_torch/eval_ijbc.py
@@ -0,0 +1,483 @@
+# coding: utf-8
+
+import os
+import pickle
+
+import matplotlib
+import pandas as pd
+
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import timeit
+import sklearn
+import argparse
+import cv2
+import numpy as np
+import torch
+from skimage import transform as trans
+from backbones import get_model
+from sklearn.metrics import roc_curve, auc
+
+from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
+from prettytable import PrettyTable
+from pathlib import Path
+
+import sys
+import warnings
+
+sys.path.insert(0, "../")
+warnings.filterwarnings("ignore")
+
+parser = argparse.ArgumentParser(description='do ijb test')
+# general
+parser.add_argument('--model-prefix', default='', help='path to load model.')
+parser.add_argument('--image-path', default='', type=str, help='')
+parser.add_argument('--result-dir', default='.', type=str, help='')
+parser.add_argument('--batch-size', default=128, type=int, help='')
+parser.add_argument('--network', default='iresnet50', type=str, help='')
+parser.add_argument('--job', default='insightface', type=str, help='job name')
+parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')
+args = parser.parse_args()
+
+target = args.target
+model_path = args.model_prefix
+image_path = args.image_path
+result_dir = args.result_dir
+gpu_id = None
+use_norm_score = True  # if True, TestMode(N1)
+use_detector_score = True  # if True, TestMode(D1)
+use_flip_test = True  # if True, TestMode(F1)
+job = args.job
+batch_size = args.batch_size
+
+
+class Embedding(object):
+    def __init__(self, prefix, data_shape, batch_size=1):
+        image_size = (112, 112)
+        self.image_size = image_size
+        weight = torch.load(prefix)
+        resnet = get_model(args.network, dropout=0, fp16=False).cuda()
+        resnet.load_state_dict(weight)
+        model = torch.nn.DataParallel(resnet)
+        self.model = model
+        self.model.eval()
+        src = np.array([
+            [30.2946, 51.6963],
+            [65.5318, 51.5014],
+            [48.0252, 71.7366],
+            [33.5493, 92.3655],
+            [62.7299, 92.2041]], dtype=np.float32)
+        src[:, 0] += 8.0
+        self.src = src
+        self.batch_size = batch_size
+        self.data_shape = data_shape
+
+    def get(self, rimg, landmark):
+
+        assert landmark.shape[0] == 68 or landmark.shape[0] == 5
+        assert landmark.shape[1] == 2
+        if landmark.shape[0] == 68:
+            landmark5 = np.zeros((5, 2), dtype=np.float32)
+            landmark5[0] = (landmark[36] + landmark[39]) / 2
+            landmark5[1] = (landmark[42] + landmark[45]) / 2
+            landmark5[2] = landmark[30]
+            landmark5[3] = landmark[48]
+            landmark5[4] = landmark[54]
+        else:
+            landmark5 = landmark
+        tform = trans.SimilarityTransform()
+        tform.estimate(landmark5, self.src)
+        M = tform.params[0:2, :]
+        img = cv2.warpAffine(rimg,
+                             M, (self.image_size[1], self.image_size[0]),
+                             borderValue=0.0)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img_flip = np.fliplr(img)
+        img = np.transpose(img, (2, 0, 1))  # 3*112*112, RGB
+        img_flip = np.transpose(img_flip, (2, 0, 1))
+        input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8)
+        input_blob[0] = img
+        input_blob[1] = img_flip
+        return input_blob
+
+    @torch.no_grad()
+    def forward_db(self, batch_data):
+        imgs = torch.Tensor(batch_data).cuda()
+        imgs.div_(255).sub_(0.5).div_(0.5)
+        feat = self.model(imgs)
+        feat = feat.reshape([self.batch_size, 2 * feat.shape[1]])
+        return feat.cpu().numpy()
+
+
+# Split a list into n parts as evenly as possible; guarantees len(result) == n,
+# so when n exceeds the number of elements the extra parts are empty lists.
+def divideIntoNstrand(listTemp, n):
+    twoList = [[] for i in range(n)]
+    for i, e in enumerate(listTemp):
+        twoList[i % n].append(e)
+    return twoList
+
+
+def read_template_media_list(path):
+    # ijb_meta = np.loadtxt(path, dtype=str)
+    ijb_meta = pd.read_csv(path, sep=' ', header=None).values
+    templates = ijb_meta[:, 1].astype(np.int)
+    medias = ijb_meta[:, 2].astype(np.int)
+    return templates, medias
+
+
+# In[ ]:
+
+
+def read_template_pair_list(path):
+    # pairs = np.loadtxt(path, dtype=str)
+    pairs = pd.read_csv(path, sep=' ', header=None).values
+    # print(pairs.shape)
+    # print(pairs[:, 0].astype(np.int))
+    t1 = pairs[:, 0].astype(np.int)
+    t2 = pairs[:, 1].astype(np.int)
+    label = pairs[:, 2].astype(np.int)
+    return t1, t2, label
+
+
+# In[ ]:
+
+
+def read_image_feature(path):
+    with open(path, 'rb') as fid:
+        img_feats = pickle.load(fid)
+    return img_feats
+
+
+# In[ ]:
+
+
+def get_image_feature(img_path, files_list, model_path, epoch, gpu_id):
+    batch_size = args.batch_size
+    data_shape = (3, 112, 112)
+
+    files = files_list
+    print('files:', len(files))
+    rare_size = len(files) % batch_size
+    faceness_scores = []
+    batch = 0
+    img_feats = np.empty((len(files), 1024), dtype=np.float32)
+
+    batch_data = np.empty((2 * batch_size, 3, 112, 112))
+    embedding = Embedding(model_path, data_shape, batch_size)
+    for img_index, each_line in enumerate(files[:len(files) - rare_size]):
+        name_lmk_score = each_line.strip().split(' ')
+        img_name = os.path.join(img_path, name_lmk_score[0])
+        img = cv2.imread(img_name)
+        lmk = np.array([float(x) for x in
name_lmk_score[1:-1]], + dtype=np.float32) + lmk = lmk.reshape((5, 2)) + input_blob = embedding.get(img, lmk) + + batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0] + batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1] + if (img_index + 1) % batch_size == 0: + print('batch', batch) + img_feats[batch * batch_size:batch * batch_size + + batch_size][:] = embedding.forward_db(batch_data) + batch += 1 + faceness_scores.append(name_lmk_score[-1]) + + batch_data = np.empty((2 * rare_size, 3, 112, 112)) + embedding = Embedding(model_path, data_shape, rare_size) + for img_index, each_line in enumerate(files[len(files) - rare_size:]): + name_lmk_score = each_line.strip().split(' ') + img_name = os.path.join(img_path, name_lmk_score[0]) + img = cv2.imread(img_name) + lmk = np.array([float(x) for x in name_lmk_score[1:-1]], + dtype=np.float32) + lmk = lmk.reshape((5, 2)) + input_blob = embedding.get(img, lmk) + batch_data[2 * img_index][:] = input_blob[0] + batch_data[2 * img_index + 1][:] = input_blob[1] + if (img_index + 1) % rare_size == 0: + print('batch', batch) + img_feats[len(files) - + rare_size:][:] = embedding.forward_db(batch_data) + batch += 1 + faceness_scores.append(name_lmk_score[-1]) + faceness_scores = np.array(faceness_scores).astype(np.float32) + # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01 + # faceness_scores = np.ones( (len(files), ), dtype=np.float32 ) + return img_feats, faceness_scores + + +# In[ ]: + + +def image2template_feature(img_feats=None, templates=None, medias=None): + # ========================================================== + # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim] + # 2. compute media feature. + # 3. compute template feature. + # ========================================================== + unique_templates = np.unique(templates) + template_feats = np.zeros((len(unique_templates), img_feats.shape[1])) + + for count_template, uqt in enumerate(unique_templates): + + (ind_t,) = np.where(templates == uqt) + face_norm_feats = img_feats[ind_t] + face_medias = medias[ind_t] + unique_medias, unique_media_counts = np.unique(face_medias, + return_counts=True) + media_norm_feats = [] + for u, ct in zip(unique_medias, unique_media_counts): + (ind_m,) = np.where(face_medias == u) + if ct == 1: + media_norm_feats += [face_norm_feats[ind_m]] + else: # image features from the same video will be aggregated into one feature + media_norm_feats += [ + np.mean(face_norm_feats[ind_m], axis=0, keepdims=True) + ] + media_norm_feats = np.array(media_norm_feats) + # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True)) + template_feats[count_template] = np.sum(media_norm_feats, axis=0) + if count_template % 2000 == 0: + print('Finish Calculating {} template features.'.format( + count_template)) + # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True)) + template_norm_feats = sklearn.preprocessing.normalize(template_feats) + # print(template_norm_feats.shape) + return template_norm_feats, unique_templates + + +# In[ ]: + + +def verification(template_norm_feats=None, + unique_templates=None, + p1=None, + p2=None): + # ========================================================== + # Compute set-to-set Similarity Score. 
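+    # (the dot products below are cosine similarities, since template features are L2-normalized by image2template_feature)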
+    # ==========================================================
+    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+    for count_template, uqt in enumerate(unique_templates):
+        template2id[uqt] = count_template
+
+    score = np.zeros((len(p1),))  # save cosine distance between pairs
+
+    total_pairs = np.array(range(len(p1)))
+    batchsize = 100000  # small batchsize instead of all pairs in one batch due to the memory limitation
+    sublists = [
+        total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
+    ]
+    total_sublists = len(sublists)
+    for c, s in enumerate(sublists):
+        feat1 = template_norm_feats[template2id[p1[s]]]
+        feat2 = template_norm_feats[template2id[p2[s]]]
+        similarity_score = np.sum(feat1 * feat2, -1)
+        score[s] = similarity_score.flatten()
+        if c % 10 == 0:
+            print('Finish {}/{} pairs.'.format(c, total_sublists))
+    return score
+
+
+# In[ ]:
+def verification2(template_norm_feats=None,
+                  unique_templates=None,
+                  p1=None,
+                  p2=None):
+    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+    for count_template, uqt in enumerate(unique_templates):
+        template2id[uqt] = count_template
+    score = np.zeros((len(p1),))  # save cosine distance between pairs
+    total_pairs = np.array(range(len(p1)))
+    batchsize = 100000  # small batchsize instead of all pairs in one batch due to the memory limitation
+    sublists = [
+        total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
+    ]
+    total_sublists = len(sublists)
+    for c, s in enumerate(sublists):
+        feat1 = template_norm_feats[template2id[p1[s]]]
+        feat2 = template_norm_feats[template2id[p2[s]]]
+        similarity_score = np.sum(feat1 * feat2, -1)
+        score[s] = similarity_score.flatten()
+        if c % 10 == 0:
+            print('Finish {}/{} pairs.'.format(c, total_sublists))
+    return score
+
+
+def read_score(path):
+    with open(path, 'rb') as fid:
+        img_feats = pickle.load(fid)
+    return img_feats
+
+
+# # Step 1: Load Meta Data
+
+# In[ ]:
+
+assert target == 'IJBC' or target == 'IJBB'
+
+# =============================================================
+# load image and template relationships for template feature embedding
+# tid --> template id, mid --> media id
+# format:
+# image_name tid mid
+# =============================================================
+start = timeit.default_timer()
+templates, medias = read_template_media_list(
+    os.path.join('%s/meta' % image_path,
+                 '%s_face_tid_mid.txt' % target.lower()))
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# In[ ]:
+
+# =============================================================
+# load template pairs for template-to-template verification
+# tid : template id, label : 1/0
+# format:
+# tid_1 tid_2 label
+# =============================================================
+start = timeit.default_timer()
+p1, p2, label = read_template_pair_list(
+    os.path.join('%s/meta' % image_path,
+                 '%s_template_pair_label.txt' % target.lower()))
+stop = timeit.default_timer()
+print('Time: %.2f s. 
' % (stop - start)) + +# # Step 2: Get Image Features + +# In[ ]: + +# ============================================================= +# load image features +# format: +# img_feats: [image_num x feats_dim] (227630, 512) +# ============================================================= +start = timeit.default_timer() +img_path = '%s/loose_crop' % image_path +img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower()) +img_list = open(img_list_path) +files = img_list.readlines() +# files_list = divideIntoNstrand(files, rank_size) +files_list = files + +# img_feats +# for i in range(rank_size): +img_feats, faceness_scores = get_image_feature(img_path, files_list, + model_path, 0, gpu_id) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) +print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], + img_feats.shape[1])) + +# # Step3: Get Template Features + +# In[ ]: + +# ============================================================= +# compute template features from image features. +# ============================================================= +start = timeit.default_timer() +# ========================================================== +# Norm feature before aggregation into template feature? +# Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face). +# ========================================================== +# 1. FaceScore (Feature Norm) +# 2. FaceScore (Detector) + +if use_flip_test: + # concat --- F1 + # img_input_feats = img_feats + # add --- F2 + img_input_feats = img_feats[:, 0:img_feats.shape[1] // + 2] + img_feats[:, img_feats.shape[1] // 2:] +else: + img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + +if use_norm_score: + img_input_feats = img_input_feats +else: + # normalise features to remove norm information + img_input_feats = img_input_feats / np.sqrt( + np.sum(img_input_feats ** 2, -1, keepdims=True)) + +if use_detector_score: + print(img_input_feats.shape, faceness_scores.shape) + img_input_feats = img_input_feats * faceness_scores[:, np.newaxis] +else: + img_input_feats = img_input_feats + +template_norm_feats, unique_templates = image2template_feature( + img_input_feats, templates, medias) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) + +# # Step 4: Get Template Similarity Scores + +# In[ ]: + +# ============================================================= +# compute verification scores between template pairs. +# ============================================================= +start = timeit.default_timer() +score = verification(template_norm_feats, unique_templates, p1, p2) +stop = timeit.default_timer() +print('Time: %.2f s. 
' % (stop - start)) + +# In[ ]: +save_path = os.path.join(result_dir, args.job) +# save_path = result_dir + '/%s_result' % target + +if not os.path.exists(save_path): + os.makedirs(save_path) + +score_save_file = os.path.join(save_path, "%s.npy" % target.lower()) +np.save(score_save_file, score) + +# # Step 5: Get ROC Curves and TPR@FPR Table + +# In[ ]: + +files = [score_save_file] +methods = [] +scores = [] +for file in files: + methods.append(Path(file).stem) + scores.append(np.load(file)) + +methods = np.array(methods) +scores = dict(zip(methods, scores)) +colours = dict( + zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) +x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] +tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) +fig = plt.figure() +for method in methods: + fpr, tpr, _ = roc_curve(label, scores[method]) + roc_auc = auc(fpr, tpr) + fpr = np.flipud(fpr) + tpr = np.flipud(tpr) # select largest tpr at same fpr + plt.plot(fpr, + tpr, + color=colours[method], + lw=1, + label=('[%s (AUC = %0.4f %%)]' % + (method.split('-')[-1], roc_auc * 100))) + tpr_fpr_row = [] + tpr_fpr_row.append("%s-%s" % (method, target)) + for fpr_iter in np.arange(len(x_labels)): + _, min_index = min( + list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) + tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) + tpr_fpr_table.add_row(tpr_fpr_row) +plt.xlim([10 ** -6, 0.1]) +plt.ylim([0.3, 1.0]) +plt.grid(linestyle='--', linewidth=1) +plt.xticks(x_labels) +plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) +plt.xscale('log') +plt.xlabel('False Positive Rate') +plt.ylabel('True Positive Rate') +plt.title('ROC on IJB') +plt.legend(loc="lower right") +fig.savefig(os.path.join(save_path, '%s.pdf' % target.lower())) +print(tpr_fpr_table) diff --git a/scripts/face3d/models/arcface_torch/inference.py b/scripts/face3d/models/arcface_torch/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..3e5156e8d649954837e397c2ff15ec29995e7502 --- /dev/null +++ b/scripts/face3d/models/arcface_torch/inference.py @@ -0,0 +1,35 @@ +import argparse + +import cv2 +import numpy as np +import torch + +from backbones import get_model + + +@torch.no_grad() +def inference(weight, name, img): + if img is None: + img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8) + else: + img = cv2.imread(img) + img = cv2.resize(img, (112, 112)) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = np.transpose(img, (2, 0, 1)) + img = torch.from_numpy(img).unsqueeze(0).float() + img.div_(255).sub_(0.5).div_(0.5) + net = get_model(name, fp16=False) + net.load_state_dict(torch.load(weight)) + net.eval() + feat = net(img).numpy() + print(feat) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') + parser.add_argument('--network', type=str, default='r50', help='backbone network') + parser.add_argument('--weight', type=str, default='') + parser.add_argument('--img', type=str, default=None) + args = parser.parse_args() + inference(args.weight, args.network, args.img) diff --git a/scripts/face3d/models/arcface_torch/losses.py b/scripts/face3d/models/arcface_torch/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..87aeaa107af4d53f5a6132b3739d5cafdcded7fc --- /dev/null +++ b/scripts/face3d/models/arcface_torch/losses.py @@ -0,0 +1,42 @@ +import torch +from torch import nn + + +def get_loss(name): + if name == "cosface": + return CosFace() + elif name == "arcface": + 
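+        # additive angular margin loss, applied in angle space (see ArcFace.forward below)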
return ArcFace() + else: + raise ValueError() + + +class CosFace(nn.Module): + def __init__(self, s=64.0, m=0.40): + super(CosFace, self).__init__() + self.s = s + self.m = m + + def forward(self, cosine, label): + index = torch.where(label != -1)[0] + m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) + m_hot.scatter_(1, label[index, None], self.m) + cosine[index] -= m_hot + ret = cosine * self.s + return ret + + +class ArcFace(nn.Module): + def __init__(self, s=64.0, m=0.5): + super(ArcFace, self).__init__() + self.s = s + self.m = m + + def forward(self, cosine: torch.Tensor, label): + index = torch.where(label != -1)[0] + m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) + m_hot.scatter_(1, label[index, None], self.m) + cosine.acos_() + cosine[index] += m_hot + cosine.cos_().mul_(self.s) + return cosine diff --git a/scripts/face3d/models/arcface_torch/onnx_helper.py b/scripts/face3d/models/arcface_torch/onnx_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..ca922ca6d410655029e459cf8fd1c323d276c34c --- /dev/null +++ b/scripts/face3d/models/arcface_torch/onnx_helper.py @@ -0,0 +1,250 @@ +from __future__ import division +import datetime +import os +import os.path as osp +import glob +import numpy as np +import cv2 +import sys +import onnxruntime +import onnx +import argparse +from onnx import numpy_helper +from insightface.data import get_image + +class ArcFaceORT: + def __init__(self, model_path, cpu=False): + self.model_path = model_path + # providers = None will use available provider, for onnxruntime-gpu it will be "CUDAExecutionProvider" + self.providers = ['CPUExecutionProvider'] if cpu else None + + #input_size is (w,h), return error message, return None if success + def check(self, track='cfat', test_img = None): + #default is cfat + max_model_size_mb=1024 + max_feat_dim=512 + max_time_cost=15 + if track.startswith('ms1m'): + max_model_size_mb=1024 + max_feat_dim=512 + max_time_cost=10 + elif track.startswith('glint'): + max_model_size_mb=1024 + max_feat_dim=1024 + max_time_cost=20 + elif track.startswith('cfat'): + max_model_size_mb = 1024 + max_feat_dim = 512 + max_time_cost = 15 + elif track.startswith('unconstrained'): + max_model_size_mb=1024 + max_feat_dim=1024 + max_time_cost=30 + else: + return "track not found" + + if not os.path.exists(self.model_path): + return "model_path not exists" + if not os.path.isdir(self.model_path): + return "model_path should be directory" + onnx_files = [] + for _file in os.listdir(self.model_path): + if _file.endswith('.onnx'): + onnx_files.append(osp.join(self.model_path, _file)) + if len(onnx_files)==0: + return "do not have onnx files" + self.model_file = sorted(onnx_files)[-1] + print('use onnx-model:', self.model_file) + try: + session = onnxruntime.InferenceSession(self.model_file, providers=self.providers) + except: + return "load onnx failed" + input_cfg = session.get_inputs()[0] + input_shape = input_cfg.shape + print('input-shape:', input_shape) + if len(input_shape)!=4: + return "length of input_shape should be 4" + if not isinstance(input_shape[0], str): + #return "input_shape[0] should be str to support batch-inference" + print('reset input-shape[0] to None') + model = onnx.load(self.model_file) + model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' + new_model_file = osp.join(self.model_path, 'zzzzrefined.onnx') + onnx.save(model, new_model_file) + self.model_file = new_model_file + print('use new onnx-model:', self.model_file) + 
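+            # reload the rewritten model so the session below picks up the dynamic batch dimension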
try:
+                session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)
+            except:
+                return "load onnx failed"
+            input_cfg = session.get_inputs()[0]
+            input_shape = input_cfg.shape
+            print('new-input-shape:', input_shape)
+
+        self.image_size = tuple(input_shape[2:4][::-1])
+        #print('image_size:', self.image_size)
+        input_name = input_cfg.name
+        outputs = session.get_outputs()
+        output_names = []
+        for o in outputs:
+            output_names.append(o.name)
+            #print(o.name, o.shape)
+        if len(output_names)!=1:
+            return "number of output nodes should be 1"
+        self.session = session
+        self.input_name = input_name
+        self.output_names = output_names
+        #print(self.output_names)
+        model = onnx.load(self.model_file)
+        graph = model.graph
+        if len(graph.node)<8:
+            return "too small onnx graph"
+
+        input_size = (112,112)
+        self.crop = None
+        if track=='cfat':
+            crop_file = osp.join(self.model_path, 'crop.txt')
+            if osp.exists(crop_file):
+                lines = open(crop_file,'r').readlines()
+                if len(lines)!=6:
+                    return "crop.txt should contain 6 lines"
+                lines = [int(x) for x in lines]
+                self.crop = lines[:4]
+                input_size = tuple(lines[4:6])
+        if input_size!=self.image_size:
+            return "input-size is inconsistent with onnx model input, %s vs %s"%(input_size, self.image_size)
+
+        self.model_size_mb = os.path.getsize(self.model_file) / float(1024*1024)
+        if self.model_size_mb > max_model_size_mb:
+            return "max model size exceed, given %.3f-MB"%self.model_size_mb
+
+        input_mean = None
+        input_std = None
+        if track=='cfat':
+            pn_file = osp.join(self.model_path, 'pixel_norm.txt')
+            if osp.exists(pn_file):
+                lines = open(pn_file,'r').readlines()
+                if len(lines)!=2:
+                    return "pixel_norm.txt should contain 2 lines"
+                input_mean = float(lines[0])
+                input_std = float(lines[1])
+        if input_mean is not None or input_std is not None:
+            if input_mean is None or input_std is None:
+                return "please set input_mean and input_std simultaneously"
+        else:
+            find_sub = False
+            find_mul = False
+            for nid, node in enumerate(graph.node[:8]):
+                print(nid, node.name)
+                if node.name.startswith('Sub') or node.name.startswith('_minus'):
+                    find_sub = True
+                if node.name.startswith('Mul') or node.name.startswith('_mul') or node.name.startswith('Div'):
+                    find_mul = True
+            if find_sub and find_mul:
+                print("find sub and mul")
+                #mxnet arcface model
+                input_mean = 0.0
+                input_std = 1.0
+            else:
+                input_mean = 127.5
+                input_std = 127.5
+        self.input_mean = input_mean
+        self.input_std = input_std
+        for initn in graph.initializer:
+            weight_array = numpy_helper.to_array(initn)
+            dt = weight_array.dtype
+            if dt.itemsize<4:
+                return 'invalid weight type - (%s:%s)' % (initn.name, dt.name)
+        if test_img is None:
+            test_img = get_image('Tom_Hanks_54745')
+            test_img = cv2.resize(test_img, self.image_size)
+        else:
+            test_img = cv2.resize(test_img, self.image_size)
+        feat, cost = self.benchmark(test_img)
+        batch_result = self.check_batch(test_img)
+        batch_result_sum = float(np.sum(batch_result))
+        if batch_result_sum in [float('inf'), -float('inf')] or batch_result_sum != batch_result_sum:
+            print(batch_result)
+            print(batch_result_sum)
+            return "batch result output contains NaN!"
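+        # the feature returned by benchmark() must be 2-D: (batch, feat_dim)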
+ + if len(feat.shape) < 2: + return "the shape of the feature must be two, but get {}".format(str(feat.shape)) + + if feat.shape[1] > max_feat_dim: + return "max feat dim exceed, given %d"%feat.shape[1] + self.feat_dim = feat.shape[1] + cost_ms = cost*1000 + if cost_ms>max_time_cost: + return "max time cost exceed, given %.4f"%cost_ms + self.cost_ms = cost_ms + print('check stat:, model-size-mb: %.4f, feat-dim: %d, time-cost-ms: %.4f, input-mean: %.3f, input-std: %.3f'%(self.model_size_mb, self.feat_dim, self.cost_ms, self.input_mean, self.input_std)) + return None + + def check_batch(self, img): + if not isinstance(img, list): + imgs = [img, ] * 32 + if self.crop is not None: + nimgs = [] + for img in imgs: + nimg = img[self.crop[1]:self.crop[3], self.crop[0]:self.crop[2], :] + if nimg.shape[0] != self.image_size[1] or nimg.shape[1] != self.image_size[0]: + nimg = cv2.resize(nimg, self.image_size) + nimgs.append(nimg) + imgs = nimgs + blob = cv2.dnn.blobFromImages( + images=imgs, scalefactor=1.0 / self.input_std, size=self.image_size, + mean=(self.input_mean, self.input_mean, self.input_mean), swapRB=True) + net_out = self.session.run(self.output_names, {self.input_name: blob})[0] + return net_out + + + def meta_info(self): + return {'model-size-mb':self.model_size_mb, 'feature-dim':self.feat_dim, 'infer': self.cost_ms} + + + def forward(self, imgs): + if not isinstance(imgs, list): + imgs = [imgs] + input_size = self.image_size + if self.crop is not None: + nimgs = [] + for img in imgs: + nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:] + if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]: + nimg = cv2.resize(nimg, input_size) + nimgs.append(nimg) + imgs = nimgs + blob = cv2.dnn.blobFromImages(imgs, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + net_out = self.session.run(self.output_names, {self.input_name : blob})[0] + return net_out + + def benchmark(self, img): + input_size = self.image_size + if self.crop is not None: + nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:] + if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]: + nimg = cv2.resize(nimg, input_size) + img = nimg + blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + costs = [] + for _ in range(50): + ta = datetime.datetime.now() + net_out = self.session.run(self.output_names, {self.input_name : blob})[0] + tb = datetime.datetime.now() + cost = (tb-ta).total_seconds() + costs.append(cost) + costs = sorted(costs) + cost = costs[5] + return net_out, cost + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='') + # general + parser.add_argument('workdir', help='submitted work dir', type=str) + parser.add_argument('--track', help='track name, for different challenge', type=str, default='cfat') + args = parser.parse_args() + handler = ArcFaceORT(args.workdir) + err = handler.check(args.track) + print('err:', err) diff --git a/scripts/face3d/models/arcface_torch/onnx_ijbc.py b/scripts/face3d/models/arcface_torch/onnx_ijbc.py new file mode 100644 index 0000000000000000000000000000000000000000..05b50bfad4b4cf38903b89f596263a8e29a50d3e --- /dev/null +++ b/scripts/face3d/models/arcface_torch/onnx_ijbc.py @@ -0,0 +1,267 @@ +import argparse +import os +import pickle +import timeit + +import cv2 +import mxnet as mx +import numpy as np +import pandas as pd +import prettytable +import skimage.transform +from 
sklearn.metrics import roc_curve +from sklearn.preprocessing import normalize + +from onnx_helper import ArcFaceORT + +SRC = np.array( + [ + [30.2946, 51.6963], + [65.5318, 51.5014], + [48.0252, 71.7366], + [33.5493, 92.3655], + [62.7299, 92.2041]] + , dtype=np.float32) +SRC[:, 0] += 8.0 + + +class AlignedDataSet(mx.gluon.data.Dataset): + def __init__(self, root, lines, align=True): + self.lines = lines + self.root = root + self.align = align + + def __len__(self): + return len(self.lines) + + def __getitem__(self, idx): + each_line = self.lines[idx] + name_lmk_score = each_line.strip().split(' ') + name = os.path.join(self.root, name_lmk_score[0]) + img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB) + landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2)) + st = skimage.transform.SimilarityTransform() + st.estimate(landmark5, SRC) + img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0) + img_1 = np.expand_dims(img, 0) + img_2 = np.expand_dims(np.fliplr(img), 0) + output = np.concatenate((img_1, img_2), axis=0).astype(np.float32) + output = np.transpose(output, (0, 3, 1, 2)) + output = mx.nd.array(output) + return output + + +def extract(model_root, dataset): + model = ArcFaceORT(model_path=model_root) + model.check() + feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim)) + + def batchify_fn(data): + return mx.nd.concat(*data, dim=0) + + data_loader = mx.gluon.data.DataLoader( + dataset, 128, last_batch='keep', num_workers=4, + thread_pool=True, prefetch=16, batchify_fn=batchify_fn) + num_iter = 0 + for batch in data_loader: + batch = batch.asnumpy() + batch = (batch - model.input_mean) / model.input_std + feat = model.session.run(model.output_names, {model.input_name: batch})[0] + feat = np.reshape(feat, (-1, model.feat_dim * 2)) + feat_mat[128 * num_iter: 128 * num_iter + feat.shape[0], :] = feat + num_iter += 1 + if num_iter % 50 == 0: + print(num_iter) + return feat_mat + + +def read_template_media_list(path): + ijb_meta = pd.read_csv(path, sep=' ', header=None).values + templates = ijb_meta[:, 1].astype(np.int) + medias = ijb_meta[:, 2].astype(np.int) + return templates, medias + + +def read_template_pair_list(path): + pairs = pd.read_csv(path, sep=' ', header=None).values + t1 = pairs[:, 0].astype(np.int) + t2 = pairs[:, 1].astype(np.int) + label = pairs[:, 2].astype(np.int) + return t1, t2, label + + +def read_image_feature(path): + with open(path, 'rb') as fid: + img_feats = pickle.load(fid) + return img_feats + + +def image2template_feature(img_feats=None, + templates=None, + medias=None): + unique_templates = np.unique(templates) + template_feats = np.zeros((len(unique_templates), img_feats.shape[1])) + for count_template, uqt in enumerate(unique_templates): + (ind_t,) = np.where(templates == uqt) + face_norm_feats = img_feats[ind_t] + face_medias = medias[ind_t] + unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True) + media_norm_feats = [] + for u, ct in zip(unique_medias, unique_media_counts): + (ind_m,) = np.where(face_medias == u) + if ct == 1: + media_norm_feats += [face_norm_feats[ind_m]] + else: # image features from the same video will be aggregated into one feature + media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, keepdims=True), ] + media_norm_feats = np.array(media_norm_feats) + template_feats[count_template] = np.sum(media_norm_feats, axis=0) + if count_template % 2000 == 0: + print('Finish Calculating {} template features.'.format( + count_template)) + 
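(Aside.) The warp inside AlignedDataSet.__getitem__ is the standard ArcFace five-point similarity alignment. Pulled out as a self-contained helper (align_face is an illustrative name; the template values are just SRC above with the +8 x-shift already applied):

import cv2
import numpy as np
import skimage.transform

# Canonical 112x112 ArcFace template: SRC after SRC[:, 0] += 8.0.
TEMPLATE = np.array([[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
                     [41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32)

def align_face(img, landmark5):
    # Solve the similarity transform mapping the detected 5 landmarks onto the
    # template, then warp into the 112x112 canonical crop, exactly as the
    # dataset class does.
    st = skimage.transform.SimilarityTransform()
    st.estimate(np.asarray(landmark5, dtype=np.float32), TEMPLATE)
    return cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0)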
template_norm_feats = normalize(template_feats) + return template_norm_feats, unique_templates + + +def verification(template_norm_feats=None, + unique_templates=None, + p1=None, + p2=None): + template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) + for count_template, uqt in enumerate(unique_templates): + template2id[uqt] = count_template + score = np.zeros((len(p1),)) + total_pairs = np.array(range(len(p1))) + batchsize = 100000 + sublists = [total_pairs[i: i + batchsize] for i in range(0, len(p1), batchsize)] + total_sublists = len(sublists) + for c, s in enumerate(sublists): + feat1 = template_norm_feats[template2id[p1[s]]] + feat2 = template_norm_feats[template2id[p2[s]]] + similarity_score = np.sum(feat1 * feat2, -1) + score[s] = similarity_score.flatten() + if c % 10 == 0: + print('Finish {}/{} pairs.'.format(c, total_sublists)) + return score + + +def verification2(template_norm_feats=None, + unique_templates=None, + p1=None, + p2=None): + template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) + for count_template, uqt in enumerate(unique_templates): + template2id[uqt] = count_template + score = np.zeros((len(p1),)) # save cosine distance between pairs + total_pairs = np.array(range(len(p1))) + batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation + sublists = [total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)] + total_sublists = len(sublists) + for c, s in enumerate(sublists): + feat1 = template_norm_feats[template2id[p1[s]]] + feat2 = template_norm_feats[template2id[p2[s]]] + similarity_score = np.sum(feat1 * feat2, -1) + score[s] = similarity_score.flatten() + if c % 10 == 0: + print('Finish {}/{} pairs.'.format(c, total_sublists)) + return score + + +def main(args): + use_norm_score = True # if Ture, TestMode(N1) + use_detector_score = True # if Ture, TestMode(D1) + use_flip_test = True # if Ture, TestMode(F1) + assert args.target == 'IJBC' or args.target == 'IJBB' + + start = timeit.default_timer() + templates, medias = read_template_media_list( + os.path.join('%s/meta' % args.image_path, '%s_face_tid_mid.txt' % args.target.lower())) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + + start = timeit.default_timer() + p1, p2, label = read_template_pair_list( + os.path.join('%s/meta' % args.image_path, + '%s_template_pair_label.txt' % args.target.lower())) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + + start = timeit.default_timer() + img_path = '%s/loose_crop' % args.image_path + img_list_path = '%s/meta/%s_name_5pts_score.txt' % (args.image_path, args.target.lower()) + img_list = open(img_list_path) + files = img_list.readlines() + dataset = AlignedDataSet(root=img_path, lines=files, align=True) + img_feats = extract(args.model_root, dataset) + + faceness_scores = [] + for each_line in files: + name_lmk_score = each_line.split() + faceness_scores.append(name_lmk_score[-1]) + faceness_scores = np.array(faceness_scores).astype(np.float32) + stop = timeit.default_timer() + print('Time: %.2f s. 
' % (stop - start)) + print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], img_feats.shape[1])) + start = timeit.default_timer() + + if use_flip_test: + img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2:] + else: + img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + + if use_norm_score: + img_input_feats = img_input_feats + else: + img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats ** 2, -1, keepdims=True)) + + if use_detector_score: + print(img_input_feats.shape, faceness_scores.shape) + img_input_feats = img_input_feats * faceness_scores[:, np.newaxis] + else: + img_input_feats = img_input_feats + + template_norm_feats, unique_templates = image2template_feature( + img_input_feats, templates, medias) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + + start = timeit.default_timer() + score = verification(template_norm_feats, unique_templates, p1, p2) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + save_path = os.path.join(args.result_dir, "{}_result".format(args.target)) + if not os.path.exists(save_path): + os.makedirs(save_path) + score_save_file = os.path.join(save_path, "{}.npy".format(args.model_root)) + np.save(score_save_file, score) + files = [score_save_file] + methods = [] + scores = [] + for file in files: + methods.append(os.path.basename(file)) + scores.append(np.load(file)) + methods = np.array(methods) + scores = dict(zip(methods, scores)) + x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] + tpr_fpr_table = prettytable.PrettyTable(['Methods'] + [str(x) for x in x_labels]) + for method in methods: + fpr, tpr, _ = roc_curve(label, scores[method]) + fpr = np.flipud(fpr) + tpr = np.flipud(tpr) + tpr_fpr_row = [] + tpr_fpr_row.append("%s-%s" % (method, args.target)) + for fpr_iter in np.arange(len(x_labels)): + _, min_index = min( + list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) + tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) + tpr_fpr_table.add_row(tpr_fpr_row) + print(tpr_fpr_table) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='do ijb test') + # general + parser.add_argument('--model-root', default='', help='path to load model.') + parser.add_argument('--image-path', default='', type=str, help='') + parser.add_argument('--result-dir', default='.', type=str, help='') + parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB') + main(parser.parse_args()) diff --git a/scripts/face3d/models/arcface_torch/partial_fc.py b/scripts/face3d/models/arcface_torch/partial_fc.py new file mode 100644 index 0000000000000000000000000000000000000000..17e2d25715d10ba446c957e1d2528b0687ed71d5 --- /dev/null +++ b/scripts/face3d/models/arcface_torch/partial_fc.py @@ -0,0 +1,222 @@ +import logging +import os + +import torch +import torch.distributed as dist +from torch.nn import Module +from torch.nn.functional import normalize, linear +from torch.nn.parameter import Parameter + + +class PartialFC(Module): + """ + Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint, + Partial FC: Training 10 Million Identities on a Single Machine + See the original paper: + https://arxiv.org/abs/2010.05222 + """ + + @torch.no_grad() + def __init__(self, rank, local_rank, world_size, batch_size, resume, + margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"): + """ + rank: int + Unique process(GPU) ID from 0 to world_size - 1. 
+ local_rank: int + Unique process(GPU) ID within the server from 0 to 7. + world_size: int + Number of GPU. + batch_size: int + Batch size on current rank(GPU). + resume: bool + Select whether to restore the weight of softmax. + margin_softmax: callable + A function of margin softmax, eg: cosface, arcface. + num_classes: int + The number of class center storage in current rank(CPU/GPU), usually is total_classes // world_size, + required. + sample_rate: float + The partial fc sampling rate, when the number of classes increases to more than 2 millions, Sampling + can greatly speed up training, and reduce a lot of GPU memory, default is 1.0. + embedding_size: int + The feature dimension, default is 512. + prefix: str + Path for save checkpoint, default is './'. + """ + super(PartialFC, self).__init__() + # + self.num_classes: int = num_classes + self.rank: int = rank + self.local_rank: int = local_rank + self.device: torch.device = torch.device("cuda:{}".format(self.local_rank)) + self.world_size: int = world_size + self.batch_size: int = batch_size + self.margin_softmax: callable = margin_softmax + self.sample_rate: float = sample_rate + self.embedding_size: int = embedding_size + self.prefix: str = prefix + self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size) + self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size) + self.num_sample: int = int(self.sample_rate * self.num_local) + + self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank)) + self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank)) + + if resume: + try: + self.weight: torch.Tensor = torch.load(self.weight_name) + self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name) + if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local: + raise IndexError + logging.info("softmax weight resume successfully!") + logging.info("softmax weight mom resume successfully!") + except (FileNotFoundError, KeyError, IndexError): + self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) + self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) + logging.info("softmax weight init!") + logging.info("softmax weight mom init!") + else: + self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) + self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) + logging.info("softmax weight init successfully!") + logging.info("softmax weight mom init successfully!") + self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank) + + self.index = None + if int(self.sample_rate) == 1: + self.update = lambda: 0 + self.sub_weight = Parameter(self.weight) + self.sub_weight_mom = self.weight_mom + else: + self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank)) + + def save_params(self): + """ Save softmax weight for each rank on prefix + """ + torch.save(self.weight.data, self.weight_name) + torch.save(self.weight_mom, self.weight_mom_name) + + @torch.no_grad() + def sample(self, total_label): + """ + Sample all positive class centers in each rank, and random select neg class centers to filling a fixed + `num_sample`. + + total_label: tensor + Label after all gather, which cross all GPUs. 
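+        Side effects: labels owned by other ranks are set to -1, local labels
+        are shifted into [0, num_local), and when sample_rate < 1 they are
+        re-indexed against the sampled class centers via torch.searchsorted.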
+ """ + index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local) + total_label[~index_positive] = -1 + total_label[index_positive] -= self.class_start + if int(self.sample_rate) != 1: + positive = torch.unique(total_label[index_positive], sorted=True) + if self.num_sample - positive.size(0) >= 0: + perm = torch.rand(size=[self.num_local], device=self.device) + perm[positive] = 2.0 + index = torch.topk(perm, k=self.num_sample)[1] + index = index.sort()[0] + else: + index = positive + self.index = index + total_label[index_positive] = torch.searchsorted(index, total_label[index_positive]) + self.sub_weight = Parameter(self.weight[index]) + self.sub_weight_mom = self.weight_mom[index] + + def forward(self, total_features, norm_weight): + """ Partial fc forward, `logits = X * sample(W)` + """ + torch.cuda.current_stream().wait_stream(self.stream) + logits = linear(total_features, norm_weight) + return logits + + @torch.no_grad() + def update(self): + """ Set updated weight and weight_mom to memory bank. + """ + self.weight_mom[self.index] = self.sub_weight_mom + self.weight[self.index] = self.sub_weight + + def prepare(self, label, optimizer): + """ + get sampled class centers for cal softmax. + + label: tensor + Label tensor on each rank. + optimizer: opt + Optimizer for partial fc, which need to get weight mom. + """ + with torch.cuda.stream(self.stream): + total_label = torch.zeros( + size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long) + dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label) + self.sample(total_label) + optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None) + optimizer.param_groups[-1]['params'][0] = self.sub_weight + optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom + norm_weight = normalize(self.sub_weight) + return total_label, norm_weight + + def forward_backward(self, label, features, optimizer): + """ + Partial fc forward and backward with model parallel + + label: tensor + Label tensor on each rank(GPU) + features: tensor + Features tensor on each rank(GPU) + optimizer: optimizer + Optimizer for partial fc + + Returns: + -------- + x_grad: tensor + The gradient of features. + loss_v: tensor + Loss value for cross entropy. 
+ """ + total_label, norm_weight = self.prepare(label, optimizer) + total_features = torch.zeros( + size=[self.batch_size * self.world_size, self.embedding_size], device=self.device) + dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data) + total_features.requires_grad = True + + logits = self.forward(total_features, norm_weight) + logits = self.margin_softmax(logits, total_label) + + with torch.no_grad(): + max_fc = torch.max(logits, dim=1, keepdim=True)[0] + dist.all_reduce(max_fc, dist.ReduceOp.MAX) + + # calculate exp(logits) and all-reduce + logits_exp = torch.exp(logits - max_fc) + logits_sum_exp = logits_exp.sum(dim=1, keepdims=True) + dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM) + + # calculate prob + logits_exp.div_(logits_sum_exp) + + # get one-hot + grad = logits_exp + index = torch.where(total_label != -1)[0] + one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device) + one_hot.scatter_(1, total_label[index, None], 1) + + # calculate loss + loss = torch.zeros(grad.size()[0], 1, device=grad.device) + loss[index] = grad[index].gather(1, total_label[index, None]) + dist.all_reduce(loss, dist.ReduceOp.SUM) + loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1) + + # calculate grad + grad[index] -= one_hot + grad.div_(self.batch_size * self.world_size) + + logits.backward(grad) + if total_features.grad is not None: + total_features.grad.detach_() + x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True) + # feature gradient all-reduce + dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0))) + x_grad = x_grad * self.world_size + # backward backbone + return x_grad, loss_v diff --git a/scripts/face3d/models/arcface_torch/requirement.txt b/scripts/face3d/models/arcface_torch/requirement.txt new file mode 100644 index 0000000000000000000000000000000000000000..f72c1b3ba814ae1e0bc1c1f56402026978b9e870 --- /dev/null +++ b/scripts/face3d/models/arcface_torch/requirement.txt @@ -0,0 +1,5 @@ +tensorboard +easydict +mxnet +onnx +sklearn diff --git a/scripts/face3d/models/arcface_torch/run.sh b/scripts/face3d/models/arcface_torch/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..61af4b4950eb11334e55362e3e3c5e2796979a01 --- /dev/null +++ b/scripts/face3d/models/arcface_torch/run.sh @@ -0,0 +1,2 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50 +ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh diff --git a/scripts/face3d/models/arcface_torch/torch2onnx.py b/scripts/face3d/models/arcface_torch/torch2onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..fc26ab82e552331bc8d75b34e81000418f4d38ec --- /dev/null +++ b/scripts/face3d/models/arcface_torch/torch2onnx.py @@ -0,0 +1,59 @@ +import numpy as np +import onnx +import torch + + +def convert_onnx(net, path_module, output, opset=11, simplify=False): + assert isinstance(net, torch.nn.Module) + img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) + img = img.astype(np.float) + img = (img / 255. 
- 0.5) / 0.5 # torch style norm + img = img.transpose((2, 0, 1)) + img = torch.from_numpy(img).unsqueeze(0).float() + + weight = torch.load(path_module) + net.load_state_dict(weight) + net.eval() + torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset) + model = onnx.load(output) + graph = model.graph + graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' + if simplify: + from onnxsim import simplify + model, check = simplify(model) + assert check, "Simplified ONNX model could not be validated" + onnx.save(model, output) + + +if __name__ == '__main__': + import os + import argparse + from backbones import get_model + + parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx') + parser.add_argument('input', type=str, help='input backbone.pth file or path') + parser.add_argument('--output', type=str, default=None, help='output onnx path') + parser.add_argument('--network', type=str, default=None, help='backbone network') + parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify') + args = parser.parse_args() + input_file = args.input + if os.path.isdir(input_file): + input_file = os.path.join(input_file, "backbone.pth") + assert os.path.exists(input_file) + model_name = os.path.basename(os.path.dirname(input_file)).lower() + params = model_name.split("_") + if len(params) >= 3 and params[1] in ('arcface', 'cosface'): + if args.network is None: + args.network = params[2] + assert args.network is not None + print(args) + backbone_onnx = get_model(args.network, dropout=0) + + output_path = args.output + if output_path is None: + output_path = os.path.join(os.path.dirname(__file__), 'onnx') + if not os.path.exists(output_path): + os.makedirs(output_path) + assert os.path.isdir(output_path) + output_file = os.path.join(output_path, "%s.onnx" % model_name) + convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify) diff --git a/scripts/face3d/models/arcface_torch/train.py b/scripts/face3d/models/arcface_torch/train.py new file mode 100644 index 0000000000000000000000000000000000000000..55eca2d0ad9463415970e09bccab8b722e496704 --- /dev/null +++ b/scripts/face3d/models/arcface_torch/train.py @@ -0,0 +1,141 @@ +import argparse +import logging +import os + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch.utils.data.distributed +from torch.nn.utils import clip_grad_norm_ + +import losses +from backbones import get_model +from dataset import MXFaceDataset, SyntheticDataset, DataLoaderX +from partial_fc import PartialFC +from utils.utils_amp import MaxClipGradScaler +from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint +from utils.utils_config import get_config +from utils.utils_logging import AverageMeter, init_logging + + +def main(args): + cfg = get_config(args.config) + try: + world_size = int(os.environ['WORLD_SIZE']) + rank = int(os.environ['RANK']) + dist.init_process_group('nccl') + except KeyError: + world_size = 1 + rank = 0 + dist.init_process_group(backend='nccl', init_method="tcp://127.0.0.1:12584", rank=rank, world_size=world_size) + + local_rank = args.local_rank + torch.cuda.set_device(local_rank) + os.makedirs(cfg.output, exist_ok=True) + init_logging(rank, cfg.output) + + if cfg.rec == "synthetic": + train_set = SyntheticDataset(local_rank=local_rank) + else: + train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank) + + train_sampler = 
torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True) + train_loader = DataLoaderX( + local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size, + sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True) + backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank) + + if cfg.resume: + try: + backbone_pth = os.path.join(cfg.output, "backbone.pth") + backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank))) + if rank == 0: + logging.info("backbone resume successfully!") + except (FileNotFoundError, KeyError, IndexError, RuntimeError): + if rank == 0: + logging.info("resume fail, backbone init successfully!") + + backbone = torch.nn.parallel.DistributedDataParallel( + module=backbone, broadcast_buffers=False, device_ids=[local_rank]) + backbone.train() + margin_softmax = losses.get_loss(cfg.loss) + module_partial_fc = PartialFC( + rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume, + batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes, + sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output) + + opt_backbone = torch.optim.SGD( + params=[{'params': backbone.parameters()}], + lr=cfg.lr / 512 * cfg.batch_size * world_size, + momentum=0.9, weight_decay=cfg.weight_decay) + opt_pfc = torch.optim.SGD( + params=[{'params': module_partial_fc.parameters()}], + lr=cfg.lr / 512 * cfg.batch_size * world_size, + momentum=0.9, weight_decay=cfg.weight_decay) + + num_image = len(train_set) + total_batch_size = cfg.batch_size * world_size + cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch + cfg.total_step = num_image // total_batch_size * cfg.num_epoch + + def lr_step_func(current_step): + cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch] + if current_step < cfg.warmup_step: + return current_step / cfg.warmup_step + else: + return 0.1 ** len([m for m in cfg.decay_step if m <= current_step]) + + scheduler_backbone = torch.optim.lr_scheduler.LambdaLR( + optimizer=opt_backbone, lr_lambda=lr_step_func) + scheduler_pfc = torch.optim.lr_scheduler.LambdaLR( + optimizer=opt_pfc, lr_lambda=lr_step_func) + + for key, value in cfg.items(): + num_space = 25 - len(key) + logging.info(": " + key + " " * num_space + str(value)) + + val_target = cfg.val_targets + callback_verification = CallBackVerification(2000, rank, val_target, cfg.rec) + callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None) + callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output) + + loss = AverageMeter() + start_epoch = 0 + global_step = 0 + grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None + for epoch in range(start_epoch, cfg.num_epoch): + train_sampler.set_epoch(epoch) + for step, (img, label) in enumerate(train_loader): + global_step += 1 + features = F.normalize(backbone(img)) + x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc) + if cfg.fp16: + features.backward(grad_amp.scale(x_grad)) + grad_amp.unscale_(opt_backbone) + clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2) + grad_amp.step(opt_backbone) + grad_amp.update() + else: + features.backward(x_grad) + clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2) + opt_backbone.step() + + opt_pfc.step() + module_partial_fc.update() + opt_backbone.zero_grad() + opt_pfc.zero_grad() + 
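(Aside on the PartialFC math driving this loop.) The hand-written gradient in forward_backward is the textbook softmax cross-entropy derivative, (softmax(logits) - one_hot) / batch. A single-process sanity sketch (all names illustrative) checking it against autograd:

import torch
import torch.nn.functional as F

B, C = 4, 10
logits = torch.randn(B, C, requires_grad=True)
label = torch.randint(0, C, (B,))

loss = F.cross_entropy(logits, label)     # autograd reference (mean reduction)
loss.backward()

prob = F.softmax(logits.detach(), dim=1)
one_hot = F.one_hot(label, C).float()
manual = (prob - one_hot) / B             # the PartialFC-style gradient

assert torch.allclose(logits.grad, manual, atol=1e-6)

The all-reduces on the row max and the exp-sum are what extend this to the model-parallel case: every rank normalizes against the global class set while holding only its own slice of the classifier.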
loss.update(loss_v, 1) + callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp) + callback_verification(global_step, backbone) + scheduler_backbone.step() + scheduler_pfc.step() + callback_checkpoint(global_step, backbone, module_partial_fc) + dist.destroy_process_group() + + +if __name__ == "__main__": + torch.backends.cudnn.benchmark = True + parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') + parser.add_argument('config', type=str, help='py config file') + parser.add_argument('--local_rank', type=int, default=0, help='local_rank') + main(parser.parse_args()) diff --git a/scripts/face3d/models/arcface_torch/utils/__init__.py b/scripts/face3d/models/arcface_torch/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/face3d/models/arcface_torch/utils/plot.py b/scripts/face3d/models/arcface_torch/utils/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..ccc588e5c01ca550b69c385aeb3fd139c59fb88a --- /dev/null +++ b/scripts/face3d/models/arcface_torch/utils/plot.py @@ -0,0 +1,72 @@ +# coding: utf-8 + +import os +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap +from prettytable import PrettyTable +from sklearn.metrics import roc_curve, auc + +image_path = "/data/anxiang/IJB_release/IJBC" +files = [ + "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy" +] + + +def read_template_pair_list(path): + pairs = pd.read_csv(path, sep=' ', header=None).values + t1 = pairs[:, 0].astype(np.int) + t2 = pairs[:, 1].astype(np.int) + label = pairs[:, 2].astype(np.int) + return t1, t2, label + + +p1, p2, label = read_template_pair_list( + os.path.join('%s/meta' % image_path, + '%s_template_pair_label.txt' % 'ijbc')) + +methods = [] +scores = [] +for file in files: + methods.append(file.split('/')[-2]) + scores.append(np.load(file)) + +methods = np.array(methods) +scores = dict(zip(methods, scores)) +colours = dict( + zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) +x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] +tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) +fig = plt.figure() +for method in methods: + fpr, tpr, _ = roc_curve(label, scores[method]) + roc_auc = auc(fpr, tpr) + fpr = np.flipud(fpr) + tpr = np.flipud(tpr) # select largest tpr at same fpr + plt.plot(fpr, + tpr, + color=colours[method], + lw=1, + label=('[%s (AUC = %0.4f %%)]' % + (method.split('-')[-1], roc_auc * 100))) + tpr_fpr_row = [] + tpr_fpr_row.append("%s-%s" % (method, "IJBC")) + for fpr_iter in np.arange(len(x_labels)): + _, min_index = min( + list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) + tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) + tpr_fpr_table.add_row(tpr_fpr_row) +plt.xlim([10 ** -6, 0.1]) +plt.ylim([0.3, 1.0]) +plt.grid(linestyle='--', linewidth=1) +plt.xticks(x_labels) +plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) +plt.xscale('log') +plt.xlabel('False Positive Rate') +plt.ylabel('True Positive Rate') +plt.title('ROC on IJB') +plt.legend(loc="lower right") +print(tpr_fpr_table) diff --git a/scripts/face3d/models/arcface_torch/utils/utils_amp.py b/scripts/face3d/models/arcface_torch/utils/utils_amp.py new file mode 100644 index 0000000000000000000000000000000000000000..9ac2a03f4212faa129faed447a8f4519c0a00a8b --- /dev/null 
+++ b/scripts/face3d/models/arcface_torch/utils/utils_amp.py @@ -0,0 +1,88 @@ +from typing import Dict, List + +import torch + +if torch.__version__ < '1.9': + Iterable = torch._six.container_abcs.Iterable +else: + import collections + + Iterable = collections.abc.Iterable +from torch.cuda.amp import GradScaler + + +class _MultiDeviceReplicator(object): + """ + Lazily serves copies of a tensor to requested devices. Copies are cached per-device. + """ + + def __init__(self, master_tensor: torch.Tensor) -> None: + assert master_tensor.is_cuda + self.master = master_tensor + self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} + + def get(self, device) -> torch.Tensor: + retval = self._per_device_tensors.get(device, None) + if retval is None: + retval = self.master.to(device=device, non_blocking=True, copy=True) + self._per_device_tensors[device] = retval + return retval + + +class MaxClipGradScaler(GradScaler): + def __init__(self, init_scale, max_scale: float, growth_interval=100): + GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval) + self.max_scale = max_scale + + def scale_clip(self): + if self.get_scale() == self.max_scale: + self.set_growth_factor(1) + elif self.get_scale() < self.max_scale: + self.set_growth_factor(2) + elif self.get_scale() > self.max_scale: + self._scale.fill_(self.max_scale) + self.set_growth_factor(1) + + def scale(self, outputs): + """ + Multiplies ('scales') a tensor or list of tensors by the scale factor. + + Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned + unmodified. + + Arguments: + outputs (Tensor or iterable of Tensors): Outputs to scale. + """ + if not self._enabled: + return outputs + self.scale_clip() + # Short-circuit for the common case. + if isinstance(outputs, torch.Tensor): + assert outputs.is_cuda + if self._scale is None: + self._lazy_init_scale_growth_tracker(outputs.device) + assert self._scale is not None + return outputs * self._scale.to(device=outputs.device, non_blocking=True) + + # Invoke the more complex machinery only if we're treating multiple outputs. 
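(Aside before the multi-tensor path below.) In train.py this scaler drives the fp16 branch of the training loop; a minimal single-GPU usage sketch, assuming a CUDA device and a toy model, just to show the call order:

import torch
from utils.utils_amp import MaxClipGradScaler  # the class defined above

device = "cuda"
model = torch.nn.Linear(512, 10).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
batch_size = 64
scaler = MaxClipGradScaler(batch_size, 128 * batch_size, growth_interval=100)

x = torch.randn(batch_size, 512, device=device)
with torch.cuda.amp.autocast():
    loss = model(x).mean()
scaler.scale(loss).backward()     # scale_clip() caps the scale at max_scale
scaler.unscale_(opt)              # unscale so gradients clip in true units
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5, norm_type=2)
scaler.step(opt)
scaler.update()

scale_clip above freezes the growth factor once the scale reaches max_scale, so the loop never pushes the scale past the cap.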
+ stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale + + def apply_scale(val): + if isinstance(val, torch.Tensor): + assert val.is_cuda + if len(stash) == 0: + if self._scale is None: + self._lazy_init_scale_growth_tracker(val.device) + assert self._scale is not None + stash.append(_MultiDeviceReplicator(self._scale)) + return val * stash[0].get(val.device) + elif isinstance(val, Iterable): + iterable = map(apply_scale, val) + if isinstance(val, list) or isinstance(val, tuple): + return type(val)(iterable) + else: + return iterable + else: + raise ValueError("outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) diff --git a/scripts/face3d/models/arcface_torch/utils/utils_callbacks.py b/scripts/face3d/models/arcface_torch/utils/utils_callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..bd2f56cba47c57de102710ff56eaac591e59f4da --- /dev/null +++ b/scripts/face3d/models/arcface_torch/utils/utils_callbacks.py @@ -0,0 +1,117 @@ +import logging +import os +import time +from typing import List + +import torch + +from eval import verification +from utils.utils_logging import AverageMeter + + +class CallBackVerification(object): + def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)): + self.frequent: int = frequent + self.rank: int = rank + self.highest_acc: float = 0.0 + self.highest_acc_list: List[float] = [0.0] * len(val_targets) + self.ver_list: List[object] = [] + self.ver_name_list: List[str] = [] + if self.rank is 0: + self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size) + + def ver_test(self, backbone: torch.nn.Module, global_step: int): + results = [] + for i in range(len(self.ver_list)): + acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test( + self.ver_list[i], backbone, 10, 10) + logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm)) + logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2)) + if acc2 > self.highest_acc_list[i]: + self.highest_acc_list[i] = acc2 + logging.info( + '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i])) + results.append(acc2) + + def init_dataset(self, val_targets, data_dir, image_size): + for name in val_targets: + path = os.path.join(data_dir, name + ".bin") + if os.path.exists(path): + data_set = verification.load_bin(path, image_size) + self.ver_list.append(data_set) + self.ver_name_list.append(name) + + def __call__(self, num_update, backbone: torch.nn.Module): + if self.rank is 0 and num_update > 0 and num_update % self.frequent == 0: + backbone.eval() + self.ver_test(backbone, num_update) + backbone.train() + + +class CallBackLogging(object): + def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None): + self.frequent: int = frequent + self.rank: int = rank + self.time_start = time.time() + self.total_step: int = total_step + self.batch_size: int = batch_size + self.world_size: int = world_size + self.writer = writer + + self.init = False + self.tic = 0 + + def __call__(self, + global_step: int, + loss: AverageMeter, + epoch: int, + fp16: bool, + learning_rate: float, + grad_scaler: torch.cuda.amp.GradScaler): + if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0: + if self.init: + try: + speed: float = self.frequent * self.batch_size / (time.time() - self.tic) + speed_total = speed * self.world_size + 
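+                    # time.time() can equal self.tic on a very fast rank; the
+                    # except arm below reports inf instead of killing logging.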
except ZeroDivisionError: + speed_total = float('inf') + + time_now = (time.time() - self.time_start) / 3600 + time_total = time_now / ((global_step + 1) / self.total_step) + time_for_end = time_total - time_now + if self.writer is not None: + self.writer.add_scalar('time_for_end', time_for_end, global_step) + self.writer.add_scalar('learning_rate', learning_rate, global_step) + self.writer.add_scalar('loss', loss.avg, global_step) + if fp16: + msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \ + "Fp16 Grad Scale: %2.f Required: %1.f hours" % ( + speed_total, loss.avg, learning_rate, epoch, global_step, + grad_scaler.get_scale(), time_for_end + ) + else: + msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \ + "Required: %1.f hours" % ( + speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end + ) + logging.info(msg) + loss.reset() + self.tic = time.time() + else: + self.init = True + self.tic = time.time() + + +class CallBackModelCheckpoint(object): + def __init__(self, rank, output="./"): + self.rank: int = rank + self.output: str = output + + def __call__(self, global_step, backbone, partial_fc, ): + if global_step > 100 and self.rank == 0: + path_module = os.path.join(self.output, "backbone.pth") + torch.save(backbone.module.state_dict(), path_module) + logging.info("Pytorch Model Saved in '{}'".format(path_module)) + + if global_step > 100 and partial_fc is not None: + partial_fc.save_params() diff --git a/scripts/face3d/models/arcface_torch/utils/utils_config.py b/scripts/face3d/models/arcface_torch/utils/utils_config.py new file mode 100644 index 0000000000000000000000000000000000000000..0c02eaf70fc0140aca7925f621c29a496f491cae --- /dev/null +++ b/scripts/face3d/models/arcface_torch/utils/utils_config.py @@ -0,0 +1,16 @@ +import importlib +import os.path as osp + + +def get_config(config_file): + assert config_file.startswith('configs/'), 'config file setting must start with configs/' + temp_config_name = osp.basename(config_file) + temp_module_name = osp.splitext(temp_config_name)[0] + config = importlib.import_module("configs.base") + cfg = config.config + config = importlib.import_module("configs.%s" % temp_module_name) + job_cfg = config.config + cfg.update(job_cfg) + if cfg.output is None: + cfg.output = osp.join('work_dirs', temp_module_name) + return cfg \ No newline at end of file diff --git a/scripts/face3d/models/arcface_torch/utils/utils_logging.py b/scripts/face3d/models/arcface_torch/utils/utils_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..c787b6aae7cd037a4718df44d672b8ffa9e5c249 --- /dev/null +++ b/scripts/face3d/models/arcface_torch/utils/utils_logging.py @@ -0,0 +1,41 @@ +import logging +import os +import sys + + +class AverageMeter(object): + """Computes and stores the average and current value + """ + + def __init__(self): + self.val = None + self.avg = None + self.sum = None + self.count = None + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def init_logging(rank, models_root): + if rank == 0: + log_root = logging.getLogger() + log_root.setLevel(logging.INFO) + formatter = logging.Formatter("Training: %(asctime)s-%(message)s") + handler_file = logging.FileHandler(os.path.join(models_root, "training.log")) + handler_stream = logging.StreamHandler(sys.stdout) + 
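(Aside on the AverageMeter defined a few lines up.) It keeps a running mean over update() calls, weighted by n; a tiny usage sketch:

meter = AverageMeter()
for batch_loss in (0.9, 0.7, 0.5):
    meter.update(batch_loss, n=1)
print(meter.val, round(meter.avg, 3))   # 0.5 (last value), 0.7 (running mean)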
handler_file.setFormatter(formatter) + handler_stream.setFormatter(formatter) + log_root.addHandler(handler_file) + log_root.addHandler(handler_stream) + log_root.info('rank_id: %d' % rank) diff --git a/scripts/face3d/models/arcface_torch/utils/utils_os.py b/scripts/face3d/models/arcface_torch/utils/utils_os.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/face3d/models/base_model.py b/scripts/face3d/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe64a7f739ad8f8cfbf3073a2bf49e1468127fd --- /dev/null +++ b/scripts/face3d/models/base_model.py @@ -0,0 +1,316 @@ +"""This script defines the base network model for Deep3DFaceRecon_pytorch +""" + +import os +import numpy as np +import torch +from collections import OrderedDict +from abc import ABC, abstractmethod +from . import networks + + +class BaseModel(ABC): + """This class is an abstract base class (ABC) for models. + To create a subclass, you need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate losses, gradients, and update network weights. + -- : (optionally) add model-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the BaseModel class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + + When creating your custom class, you need to implement your own initialization. + In this fucntion, you should first call + Then, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): specify the images that you want to display and save. + -- self.visual_names (str list): define networks used in our training. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. + """ + self.opt = opt + self.isTrain = False + self.device = torch.device('cpu') + self.save_dir = " " # os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir + self.loss_names = [] + self.model_names = [] + self.visual_names = [] + self.parallel_names = [] + self.optimizers = [] + self.image_paths = [] + self.metric = 0 # used for learning rate policy 'plateau' + + @staticmethod + def dict_grad_hook_factory(add_func=lambda x: x): + saved_dict = dict() + + def hook_gen(name): + def grad_hook(grad): + saved_vals = add_func(grad) + saved_dict[name] = saved_vals + return grad_hook + return hook_gen, saved_dict + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new model-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input (dict): includes the data itself and its metadata information. 
+ """ + pass + + @abstractmethod + def forward(self): + """Run forward pass; called by both functions and .""" + pass + + @abstractmethod + def optimize_parameters(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + pass + + def setup(self, opt): + """Load and print networks; create schedulers + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + if self.isTrain: + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + + if not self.isTrain or opt.continue_train: + load_suffix = opt.epoch + self.load_networks(load_suffix) + + + # self.print_networks(opt.verbose) + + def parallelize(self, convert_sync_batchnorm=True): + if not self.opt.use_ddp: + for name in self.parallel_names: + if isinstance(name, str): + module = getattr(self, name) + setattr(self, name, module.to(self.device)) + else: + for name in self.model_names: + if isinstance(name, str): + module = getattr(self, name) + if convert_sync_batchnorm: + module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module) + setattr(self, name, torch.nn.parallel.DistributedDataParallel(module.to(self.device), + device_ids=[self.device.index], + find_unused_parameters=True, broadcast_buffers=True)) + + # DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient. + for name in self.parallel_names: + if isinstance(name, str) and name not in self.model_names: + module = getattr(self, name) + setattr(self, name, module.to(self.device)) + + # put state_dict of optimizer to gpu device + if self.opt.phase != 'test': + if self.opt.continue_train: + for optim in self.optimizers: + for state in optim.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(self.device) + + def data_dependent_initialize(self, data): + pass + + def train(self): + """Make models train mode""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + net.train() + + def eval(self): + """Make models eval mode""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + net.eval() + + def test(self): + """Forward function used in test time. + + This function wraps function in no_grad() so we don't save intermediate steps for backprop + It also calls to produce additional visualization results + """ + with torch.no_grad(): + self.forward() + self.compute_visuals() + + def compute_visuals(self): + """Calculate additional output images for visdom and HTML visualization""" + pass + + def get_image_paths(self, name='A'): + """ Return image paths that are used to load current data""" + return self.image_paths if name =='A' else self.image_paths_B + + def update_learning_rate(self): + """Update learning rates for all the networks; called at the end of every epoch""" + for scheduler in self.schedulers: + if self.opt.lr_policy == 'plateau': + scheduler.step(self.metric) + else: + scheduler.step() + + lr = self.optimizers[0].param_groups[0]['lr'] + print('learning rate = %.7f' % lr) + + def get_current_visuals(self): + """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" + visual_ret = OrderedDict() + for name in self.visual_names: + if isinstance(name, str): + visual_ret[name] = getattr(self, name)[:, :3, ...] + return visual_ret + + def get_current_losses(self): + """Return traning losses / errors. 
train.py will print out these errors on console, and save them to a file""" + errors_ret = OrderedDict() + for name in self.loss_names: + if isinstance(name, str): + errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number + return errors_ret + + def save_networks(self, epoch): + """Save all the networks to the disk. + + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + if not os.path.isdir(self.save_dir): + os.makedirs(self.save_dir) + + save_filename = 'epoch_%s.pth' % (epoch) + save_path = os.path.join(self.save_dir, save_filename) + + save_dict = {} + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + if isinstance(net, torch.nn.DataParallel) or isinstance(net, + torch.nn.parallel.DistributedDataParallel): + net = net.module + save_dict[name] = net.state_dict() + + + for i, optim in enumerate(self.optimizers): + save_dict['opt_%02d'%i] = optim.state_dict() + + for i, sched in enumerate(self.schedulers): + save_dict['sched_%02d'%i] = sched.state_dict() + + torch.save(save_dict, save_path) + + def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): + """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" + key = keys[i] + if i + 1 == len(keys): # at the end, pointing to a parameter/buffer + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'running_mean' or key == 'running_var'): + if getattr(module, key) is None: + state_dict.pop('.'.join(keys)) + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'num_batches_tracked'): + state_dict.pop('.'.join(keys)) + else: + self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) + + def load_networks(self, epoch): + """Load all the networks from the disk. 
+ + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + if self.opt.isTrain and self.opt.pretrained_name is not None: + load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name) + else: + load_dir = self.save_dir + load_filename = 'epoch_%s.pth' % (epoch) + load_path = os.path.join(load_dir, load_filename) + state_dict = torch.load(load_path, map_location=self.device) + print('loading the model from %s' % load_path) + + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + if isinstance(net, torch.nn.DataParallel): + net = net.module + net.load_state_dict(state_dict[name]) + + if self.opt.phase != 'test': + if self.opt.continue_train: + print('loading the optim from %s' % load_path) + for i, optim in enumerate(self.optimizers): + optim.load_state_dict(state_dict['opt_%02d'%i]) + + try: + print('loading the sched from %s' % load_path) + for i, sched in enumerate(self.schedulers): + sched.load_state_dict(state_dict['sched_%02d'%i]) + except: + print('Failed to load schedulers, set schedulers according to epoch count manually') + for i, sched in enumerate(self.schedulers): + sched.last_epoch = self.opt.epoch_count - 1 + + + + + def print_networks(self, verbose): + """Print the total number of parameters in the network and (if verbose) network architecture + + Parameters: + verbose (bool) -- if verbose: print the network architecture + """ + print('---------- Networks initialized -------------') + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + if verbose: + print(net) + print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) + print('-----------------------------------------------') + + def set_requires_grad(self, nets, requires_grad=False): + """Set requies_grad=Fasle for all the networks to avoid unnecessary computations + Parameters: + nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad + + def generate_visuals_for_evaluation(self, data, mode): + return {} diff --git a/scripts/face3d/models/bfm.py b/scripts/face3d/models/bfm.py new file mode 100644 index 0000000000000000000000000000000000000000..07a36480262bcdd16aac4fea966f4d594aa6651d --- /dev/null +++ b/scripts/face3d/models/bfm.py @@ -0,0 +1,331 @@ +"""This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch +""" + +import numpy as np +import torch +import torch.nn.functional as F +from scipy.io import loadmat +from scripts.face3d.util.load_mats import transferBFM09 +import os + +def perspective_projection(focal, center): + # return p.T (N, 3) @ (3, 3) + return np.array([ + focal, 0, center, + 0, focal, center, + 0, 0, 1 + ]).reshape([3, 3]).astype(np.float32).transpose() + +class SH: + def __init__(self): + self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)] + self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) 
/ np.sqrt(12 * np.pi)] + + + +class ParametricFaceModel: + def __init__(self, + bfm_folder='./BFM', + recenter=True, + camera_distance=10., + init_lit=np.array([ + 0.8, 0, 0, 0, 0, 0, 0, 0, 0 + ]), + focal=1015., + center=112., + is_train=True, + default_name='BFM_model_front.mat'): + + if not os.path.isfile(os.path.join(bfm_folder, default_name)): + transferBFM09(bfm_folder) + + model = loadmat(os.path.join(bfm_folder, default_name)) + # mean face shape. [3*N,1] + self.mean_shape = model['meanshape'].astype(np.float32) + # identity basis. [3*N,80] + self.id_base = model['idBase'].astype(np.float32) + # expression basis. [3*N,64] + self.exp_base = model['exBase'].astype(np.float32) + # mean face texture. [3*N,1] (0-255) + self.mean_tex = model['meantex'].astype(np.float32) + # texture basis. [3*N,80] + self.tex_base = model['texBase'].astype(np.float32) + # face indices for each vertex that lies in. starts from 0. [N,8] + self.point_buf = model['point_buf'].astype(np.int64) - 1 + # vertex indices for each face. starts from 0. [F,3] + self.face_buf = model['tri'].astype(np.int64) - 1 + # vertex indices for 68 landmarks. starts from 0. [68,1] + self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1 + + if is_train: + # vertex indices for small face region to compute photometric error. starts from 0. + self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1 + # vertex indices for each face from small face region. starts from 0. [f,3] + self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1 + # vertex indices for pre-defined skin region to compute reflectance loss + self.skin_mask = np.squeeze(model['skinmask']) + + if recenter: + mean_shape = self.mean_shape.reshape([-1, 3]) + mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True) + self.mean_shape = mean_shape.reshape([-1, 1]) + + self.persc_proj = perspective_projection(focal, center) + self.device = 'cpu' + self.camera_distance = camera_distance + self.SH = SH() + self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32) + + + def to(self, device): + self.device = device + for key, value in self.__dict__.items(): + if type(value).__module__ == np.__name__: + setattr(self, key, torch.tensor(value).to(device)) + + + def compute_shape(self, id_coeff, exp_coeff): + """ + Return: + face_shape -- torch.tensor, size (B, N, 3) + + Parameters: + id_coeff -- torch.tensor, size (B, 80), identity coeffs + exp_coeff -- torch.tensor, size (B, 64), expression coeffs + """ + batch_size = id_coeff.shape[0] + id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff) + exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff) + face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1]) + return face_shape.reshape([batch_size, -1, 3]) + + + def compute_texture(self, tex_coeff, normalize=True): + """ + Return: + face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.) + + Parameters: + tex_coeff -- torch.tensor, size (B, 80) + """ + batch_size = tex_coeff.shape[0] + face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex + if normalize: + face_texture = face_texture / 255. 
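(Aside.) The shape and texture computations above are plain linear morphable-model arithmetic: output = mean + basis @ coefficients. A toy numpy sketch of compute_shape with made-up sizes (N is illustrative; the real basis matrices come from BFM_model_front.mat):

import numpy as np

N = 5            # toy vertex count; the real BFM has tens of thousands
mean_shape = np.zeros((3 * N, 1), np.float32)
id_base = np.random.randn(3 * N, 80).astype(np.float32)   # 80 identity coeffs
exp_base = np.random.randn(3 * N, 64).astype(np.float32)  # 64 expression coeffs

id_coeff = np.random.randn(1, 80).astype(np.float32)      # batch of 1
exp_coeff = np.random.randn(1, 64).astype(np.float32)

# einsum('ij,aj->ai', base, coeff) is just coeff @ base.T
face_shape = id_coeff @ id_base.T + exp_coeff @ exp_base.T + mean_shape.reshape(1, -1)
face_shape = face_shape.reshape(1, N, 3)                   # (B, N, 3)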
+ return face_texture.reshape([batch_size, -1, 3]) + + + def compute_norm(self, face_shape): + """ + Return: + vertex_norm -- torch.tensor, size (B, N, 3) + + Parameters: + face_shape -- torch.tensor, size (B, N, 3) + """ + + v1 = face_shape[:, self.face_buf[:, 0]] + v2 = face_shape[:, self.face_buf[:, 1]] + v3 = face_shape[:, self.face_buf[:, 2]] + e1 = v1 - v2 + e2 = v2 - v3 + face_norm = torch.cross(e1, e2, dim=-1) + face_norm = F.normalize(face_norm, dim=-1, p=2) + face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1) + + vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2) + vertex_norm = F.normalize(vertex_norm, dim=-1, p=2) + return vertex_norm + + + def compute_color(self, face_texture, face_norm, gamma): + """ + Return: + face_color -- torch.tensor, size (B, N, 3), range (0, 1.) + + Parameters: + face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.) + face_norm -- torch.tensor, size (B, N, 3), rotated face normal + gamma -- torch.tensor, size (B, 27), SH coeffs + """ + batch_size = gamma.shape[0] + v_num = face_texture.shape[1] + a, c = self.SH.a, self.SH.c + gamma = gamma.reshape([batch_size, 3, 9]) + gamma = gamma + self.init_lit + gamma = gamma.permute(0, 2, 1) + Y = torch.cat([ + a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device), + -a[1] * c[1] * face_norm[..., 1:2], + a[1] * c[1] * face_norm[..., 2:], + -a[1] * c[1] * face_norm[..., :1], + a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2], + -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:], + 0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1), + -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:], + 0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2) + ], dim=-1) + r = Y @ gamma[..., :1] + g = Y @ gamma[..., 1:2] + b = Y @ gamma[..., 2:] + face_color = torch.cat([r, g, b], dim=-1) * face_texture + return face_color + + + def compute_rotation(self, angles): + """ + Return: + rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat + + Parameters: + angles -- torch.tensor, size (B, 3), radian + """ + + batch_size = angles.shape[0] + ones = torch.ones([batch_size, 1]).to(self.device) + zeros = torch.zeros([batch_size, 1]).to(self.device) + x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:], + + rot_x = torch.cat([ + ones, zeros, zeros, + zeros, torch.cos(x), -torch.sin(x), + zeros, torch.sin(x), torch.cos(x) + ], dim=1).reshape([batch_size, 3, 3]) + + rot_y = torch.cat([ + torch.cos(y), zeros, torch.sin(y), + zeros, ones, zeros, + -torch.sin(y), zeros, torch.cos(y) + ], dim=1).reshape([batch_size, 3, 3]) + + rot_z = torch.cat([ + torch.cos(z), -torch.sin(z), zeros, + torch.sin(z), torch.cos(z), zeros, + zeros, zeros, ones + ], dim=1).reshape([batch_size, 3, 3]) + + rot = rot_z @ rot_y @ rot_x + return rot.permute(0, 2, 1) + + + def to_camera(self, face_shape): + face_shape[..., -1] = self.camera_distance - face_shape[..., -1] + return face_shape + + def to_image(self, face_shape): + """ + Return: + face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction + + Parameters: + face_shape -- torch.tensor, size (B, N, 3) + """ + # to image_plane + face_proj = face_shape @ self.persc_proj + face_proj = face_proj[..., :2] / face_proj[..., 2:] + + return face_proj + + + def transform(self, face_shape, rot, trans): + """ + Return: + face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans + + Parameters: + face_shape -- torch.tensor, size (B, N, 3) + rot -- 
torch.tensor, size (B, 3, 3)
+            trans           -- torch.tensor, size (B, 3)
+        """
+        return face_shape @ rot + trans.unsqueeze(1)
+
+
+    def get_landmarks(self, face_proj):
+        """
+        Return:
+            face_lms        -- torch.tensor, size (B, 68, 2)
+
+        Parameters:
+            face_proj       -- torch.tensor, size (B, N, 2)
+        """
+        return face_proj[:, self.keypoints]
+
+    def split_coeff(self, coeffs):
+        """
+        Return:
+            coeffs_dict     -- a dict of torch.tensors
+
+        Parameters:
+            coeffs          -- torch.tensor, size (B, 257)
+        """
+        id_coeffs = coeffs[:, :80]
+        exp_coeffs = coeffs[:, 80: 144]
+        tex_coeffs = coeffs[:, 144: 224]
+        angles = coeffs[:, 224: 227]
+        gammas = coeffs[:, 227: 254]
+        translations = coeffs[:, 254:]
+        return {
+            'id': id_coeffs,
+            'exp': exp_coeffs,
+            'tex': tex_coeffs,
+            'angle': angles,
+            'gamma': gammas,
+            'trans': translations
+        }
+
+    def compute_for_render(self, coeffs):
+        """
+        Return:
+            face_vertex     -- torch.tensor, size (B, N, 3), in camera coordinate
+            face_color      -- torch.tensor, size (B, N, 3), in RGB order
+            landmark        -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+        Parameters:
+            coeffs          -- torch.tensor, size (B, 257)
+        """
+        coef_dict = self.split_coeff(coeffs)
+        face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])
+        rotation = self.compute_rotation(coef_dict['angle'])
+
+        face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])
+        face_vertex = self.to_camera(face_shape_transformed)
+
+        face_proj = self.to_image(face_vertex)
+        landmark = self.get_landmarks(face_proj)
+
+        face_texture = self.compute_texture(coef_dict['tex'])
+        face_norm = self.compute_norm(face_shape)
+        face_norm_roted = face_norm @ rotation
+        face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])
+
+        return face_vertex, face_texture, face_color, landmark
+
+    def compute_for_render_woRotation(self, coeffs):
+        """
+        Return:
+            face_vertex     -- torch.tensor, size (B, N, 3), in camera coordinate
+            face_color      -- torch.tensor, size (B, N, 3), in RGB order
+            landmark        -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+        Parameters:
+            coeffs          -- torch.tensor, size (B, 257)
+        """
+        coef_dict = self.split_coeff(coeffs)
+        face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])
+
+        face_vertex = self.to_camera(face_shape)
+
+        face_proj = self.to_image(face_vertex)
+        landmark = self.get_landmarks(face_proj)
+
+        face_texture = self.compute_texture(coef_dict['tex'])
+        face_norm = self.compute_norm(face_shape)
+        face_norm_roted = face_norm  # no rotation applied in this variant
+        face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])
+
+        return face_vertex, face_texture, face_color, landmark
+
+
+if __name__ == '__main__':
+    transferBFM09()
\ No newline at end of file
diff --git a/scripts/face3d/models/facerecon_model.py b/scripts/face3d/models/facerecon_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b23907662468d92b92b4b966d55937d0c00c32f6
--- /dev/null
+++ b/scripts/face3d/models/facerecon_model.py
@@ -0,0 +1,220 @@
+"""This script defines the face reconstruction model for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+import torch
+from scripts.face3d.models.base_model import BaseModel
+from scripts.face3d.models import networks
+from scripts.face3d.models.bfm import ParametricFaceModel
+from scripts.face3d.models.losses import perceptual_loss, photo_loss, reg_loss, reflectance_loss, landmark_loss
+from scripts.face3d.util import util
+from scripts.face3d.util.nvdiffrast import MeshRenderer
+from scripts.face3d.util.preprocess import estimate_norm_torch
+
+import trimesh
+from scipy.io import savemat
+
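+# A rough usage sketch of the ParametricFaceModel defined in bfm.py (assumed
+# checkpoint layout, not part of the original file): a single 257-dim
+# coefficient vector decodes into camera-space geometry, per-vertex colour and
+# 68 projected landmarks:
+#
+#   facemodel = ParametricFaceModel(bfm_folder='./checkpoints/BFM_Fitting/')
+#   facemodel.to('cpu')               # converts the numpy buffers to tensors
+#   coeffs = torch.zeros(1, 257)      # all-zero coefficients -> the mean face
+#   vertex, tex, color, lm = facemodel.compute_for_render(coeffs)
+#   # vertex: (1, N, 3), tex/color: (1, N, 3), lm: (1, 68, 2)
+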
+class FaceReconModel(BaseModel):
+
+    @staticmethod
+    def modify_commandline_options(parser, is_train=False):
+        """Configures options specific for the face reconstruction model
+        """
+        # net structure and parameters
+        parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure')
+        parser.add_argument('--init_path', type=str, default='./checkpoints/init_model/resnet50-0676ba61.pth')
+        parser.add_argument('--use_last_fc', type=util.str2bool, nargs='?', const=True, default=False, help='zero initialize the last fc')
+        parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')
+        parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')
+
+        # renderer parameters
+        parser.add_argument('--focal', type=float, default=1015.)
+        parser.add_argument('--center', type=float, default=112.)
+        parser.add_argument('--camera_d', type=float, default=10.)
+        parser.add_argument('--z_near', type=float, default=5.)
+        parser.add_argument('--z_far', type=float, default=15.)
+
+        if is_train:
+            # training parameters
+            parser.add_argument('--net_recog', type=str, default='r50', choices=['r18', 'r34', 'r50'], help='face recog network structure')
+            parser.add_argument('--net_recog_path', type=str, default='checkpoints/recog_model/ms1mv3_arcface_r50_fp16/backbone.pth')
+            parser.add_argument('--use_crop_face', type=util.str2bool, nargs='?', const=True, default=False, help='use crop mask for photo loss')
+            parser.add_argument('--use_predef_M', type=util.str2bool, nargs='?', const=True, default=False, help='use predefined M for predicted face')
+
+            # augmentation parameters
+            parser.add_argument('--shift_pixs', type=float, default=10., help='shift pixels')
+            parser.add_argument('--scale_delta', type=float, default=0.1, help='delta scale factor')
+            parser.add_argument('--rot_angle', type=float, default=10., help='rot angles, degree')
+
+            # loss weights
+            parser.add_argument('--w_feat', type=float, default=0.2, help='weight for feat loss')
+            parser.add_argument('--w_color', type=float, default=1.92, help='weight for color loss')
+            parser.add_argument('--w_reg', type=float, default=3.0e-4, help='weight for reg loss')
+            parser.add_argument('--w_id', type=float, default=1.0, help='weight for id_reg loss')
+            parser.add_argument('--w_exp', type=float, default=0.8, help='weight for exp_reg loss')
+            parser.add_argument('--w_tex', type=float, default=1.7e-2, help='weight for tex_reg loss')
+            parser.add_argument('--w_gamma', type=float, default=10.0, help='weight for gamma loss')
+            parser.add_argument('--w_lm', type=float, default=1.6e-3, help='weight for lm loss')
+            parser.add_argument('--w_reflc', type=float, default=5.0, help='weight for reflc loss')
+
+        opt, _ = parser.parse_known_args()
+        parser.set_defaults(
+                focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15.
+            )
+        if is_train:
+            parser.set_defaults(
+                use_crop_face=True, use_predef_M=False
+            )
+        return parser
+
+    def __init__(self, opt):
+        """Initialize this model class.
+
+        Parameters:
+            opt -- training/test options
+
+        A few things can be done here.
+        - (required) call the initialization function of BaseModel
+        - define loss function, visualization images, model names, and optimizers
+        """
+        BaseModel.__init__(self, opt)  # call the initialization method of BaseModel
+
+        self.visual_names = ['output_vis']
+        self.model_names = ['net_recon']
+        self.parallel_names = self.model_names + ['renderer']
+
+        self.facemodel = ParametricFaceModel(
+            bfm_folder=opt.bfm_folder, camera_distance=opt.camera_d, focal=opt.focal, center=opt.center,
+            is_train=self.isTrain, default_name=opt.bfm_model
+        )
+
+        fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi
+        self.renderer = MeshRenderer(
+            rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center)
+        )
+
+        if self.isTrain:
+            self.loss_names = ['all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc']
+
+            self.net_recog = networks.define_net_recog(
+                net_recog=opt.net_recog, pretrained_path=opt.net_recog_path
+                )
+            # loss func name: (compute_%s_loss) % loss_name
+            self.compute_feat_loss = perceptual_loss
+            self.compute_color_loss = photo_loss
+            self.compute_lm_loss = landmark_loss
+            self.compute_reg_loss = reg_loss
+            self.compute_reflc_loss = reflectance_loss
+
+            self.optimizer = torch.optim.Adam(self.net_recon.parameters(), lr=opt.lr)
+            self.optimizers = [self.optimizer]
+            self.parallel_names += ['net_recog']
+        # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
+
+    def set_input(self, input):
+        """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+        Parameters:
+            input: a dictionary that contains the data itself and its metadata information.
+        """
+        self.input_img = input['imgs'].to(self.device)
+        self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None
+        self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None
+        self.trans_m = input['M'].to(self.device) if 'M' in input else None
+        self.image_paths = input['im_paths'] if 'im_paths' in input else None
+
+    def forward(self, output_coeff, device):
+        self.facemodel.to(device)
+        self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \
+            self.facemodel.compute_for_render(output_coeff)
+        self.pred_mask, _, self.pred_face = self.renderer(
+            self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color)
+
+        self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff)
+
+    def compute_losses(self):
+        """Calculate the training losses; called in every training iteration"""
+
+        assert not self.net_recog.training
+        trans_m = self.trans_m
+        if not self.opt.use_predef_M:
+            trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2])
+
+        pred_feat = self.net_recog(self.pred_face, trans_m)
+        gt_feat = self.net_recog(self.input_img, self.trans_m)
+        self.loss_feat = self.opt.w_feat * self.compute_feat_loss(pred_feat, gt_feat)
+
+        face_mask = self.pred_mask
+        if self.opt.use_crop_face:
+            face_mask, _, _ = self.renderer(self.pred_vertex, self.facemodel.front_face_buf)
+
+        face_mask = face_mask.detach()
+        self.loss_color = self.opt.w_color * self.compute_color_loss(
+            self.pred_face, self.input_img, self.atten_mask * face_mask)
+
+        loss_reg, loss_gamma = self.compute_reg_loss(self.pred_coeffs_dict, self.opt)
+        self.loss_reg = self.opt.w_reg * loss_reg
+        self.loss_gamma = self.opt.w_gamma * loss_gamma
+
+        self.loss_lm = self.opt.w_lm * self.compute_lm_loss(self.pred_lm, self.gt_lm)
+
+        self.loss_reflc = self.opt.w_reflc * self.compute_reflc_loss(self.pred_tex, self.facemodel.skin_mask)
+
+        self.loss_all = self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma \
+                        + self.loss_lm + self.loss_reflc
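+
+    # The overall objective assembled above is
+    #   L_all = w_feat*L_feat + w_color*L_color + w_reg*L_reg
+    #           + w_gamma*L_gamma + w_lm*L_lm + w_reflc*L_reflc,
+    # with the default weights 0.2, 1.92, 3.0e-4, 10.0, 1.6e-3 and 5.0 set in
+    # modify_commandline_options.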
+
+    def optimize_parameters(self, isTrain=True):
+        """Update network weights; it will be called in every training iteration."""
+        self.forward()
+        self.compute_losses()
+        if isTrain:
+            self.optimizer.zero_grad()
+            self.loss_all.backward()
+            self.optimizer.step()
+
+    def compute_visuals(self):
+        with torch.no_grad():
+            input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy()
+            output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img
+            output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy()
+
+            if self.gt_lm is not None:
+                gt_lm_numpy = self.gt_lm.cpu().numpy()
+                pred_lm_numpy = self.pred_lm.detach().cpu().numpy()
+                output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, 'b')
+                output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, 'r')
+
+                output_vis_numpy = np.concatenate((input_img_numpy,
+                                    output_vis_numpy_raw, output_vis_numpy), axis=-2)
+            else:
+                output_vis_numpy = np.concatenate((input_img_numpy,
+                                    output_vis_numpy_raw), axis=-2)
+
+            self.output_vis = torch.tensor(
+                    output_vis_numpy / 255., dtype=torch.float32
+                ).permute(0, 3, 1, 2).to(self.device)
+
+    def save_mesh(self, name):
+
+        recon_shape = self.pred_vertex  # get reconstructed shape
+        recon_shape[..., -1] = 10 - recon_shape[..., -1]  # from camera space to world space (camera_distance = 10)
+        recon_shape = recon_shape.cpu().numpy()[0]
+        recon_color = self.pred_color
+        recon_color = recon_color.cpu().numpy()[0]
+        tri = self.facemodel.face_buf.cpu().numpy()
+        mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8))
+        mesh.export(name)
+
+    def save_coeff(self, name):
+
+        pred_coeffs = {key: self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict}
+        pred_lm = self.pred_lm.cpu().numpy()
+        pred_lm = np.stack([pred_lm[:, :, 0], self.input_img.shape[2] - 1 - pred_lm[:, :, 1]], axis=2)  # transfer to image coordinate
+        pred_coeffs['lm68'] = pred_lm
+        savemat(name, pred_coeffs)
diff --git a/scripts/face3d/models/losses.py b/scripts/face3d/models/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..09d6a85870af1ef2b857e4a3fdd4b2f7fc991317
--- /dev/null
+++ b/scripts/face3d/models/losses.py
@@ -0,0 +1,113 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from kornia.geometry import warp_affine
+import torch.nn.functional as F
+
+def resize_n_crop(image, M, dsize=112):
+    # image: (b, c, h, w)
+    # M   :  (b, 2, 3)
+    return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True)
+
+### perceptual level loss
+class PerceptualLoss(nn.Module):
+    def __init__(self, recog_net, input_size=112):
+        super(PerceptualLoss, self).__init__()
+        self.recog_net = recog_net
+        self.preprocess = lambda x: 2 * x - 1
+        self.input_size = input_size
+
+    def forward(self, imageA, imageB, M):
+        """
+        1 - cosine distance
+        Parameters:
+            imageA       --torch.tensor (B, 3, H, W), range (0, 1), RGB order
+            imageB       --same as imageA
+        """
+        imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size))
+        imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size))
+
+        # freeze bn
+        self.recog_net.eval()
+
+        id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2)
+        id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2)
+        cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
+        # assert torch.sum((cosine_d > 1).float()) == 0
+        return torch.sum(1 - cosine_d) / cosine_d.shape[0]
+
+def perceptual_loss(id_featureA, id_featureB):
+    cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
+    # assert torch.sum((cosine_d > 1).float()) == 0
+    return torch.sum(1 - cosine_d) / cosine_d.shape[0]
+
+### image level loss
+def photo_loss(imageA, imageB, mask, eps=1e-6):
+    """
+    l2 norm (with sqrt; eps ensures backward stability, otherwise NaN may occur)
+    Parameters:
+        imageA       --torch.tensor (B, 3, H, W), range (0, 1), RGB order
+        imageB       --same as imageA
+    """
+    loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdim=True)) * mask
+    loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device))
+    return loss
+
+def landmark_loss(predict_lm, gt_lm, weight=None):
+    """
+    weighted mse loss
+    Parameters:
+        predict_lm    --torch.tensor (B, 68, 2)
+        gt_lm         --torch.tensor (B, 68, 2)
+        weight        --numpy.array (1, 68)
+    """
+    if weight is None:
+        weight = np.ones([68])
+        weight[28:31] = 20   # nose bridge
+        weight[-8:] = 20     # inner mouth
+        weight = np.expand_dims(weight, 0)
+        weight = torch.tensor(weight).to(predict_lm.device)
+    loss = torch.sum((predict_lm - gt_lm) ** 2, dim=-1) * weight
+    loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1])
+    return loss
+
+
+### regularization
+def reg_loss(coeffs_dict, opt=None):
+    """
+    l2 norm without the sqrt, from yu's implementation (mse)
+    tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss
+    Parameters:
+        coeffs_dict     -- a dict of torch.tensors, keys: id, exp, tex, angle, gamma, trans
+    """
+    # coefficient regularization to ensure plausible 3d faces
+    if opt:
+        w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex
+    else:
+        w_id, w_exp, w_tex = 1, 1, 1
+    creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \
+           w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \
+           w_tex * torch.sum(coeffs_dict['tex'] ** 2)
+    creg_loss = creg_loss / coeffs_dict['id'].shape[0]
+
+    # gamma regularization to ensure a nearly-monochromatic light
+    gamma = coeffs_dict['gamma'].reshape([-1, 3, 9])
+    gamma_mean = torch.mean(gamma, dim=1, keepdim=True)
+    gamma_loss = torch.mean((gamma - gamma_mean) ** 2)
+
+    return creg_loss, gamma_loss
+
+def reflectance_loss(texture, mask):
+    """
+    minimize texture variance (mse), albedo regularization to ensure a uniform skin albedo
+    Parameters:
+        texture       --torch.tensor, (B, N, 3)
+        mask          --torch.tensor, (N), 1 or 0
+    """
+    mask = mask.reshape([1, mask.shape[0], 1])
+    texture_mean = torch.sum(mask * texture, dim=1, keepdim=True) / torch.sum(mask)
+    loss = torch.sum(((texture - texture_mean) * mask) ** 2) / (texture.shape[0] * torch.sum(mask))
+    return loss
diff --git a/scripts/face3d/models/networks.py b/scripts/face3d/models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..ead9cdcb8720b845c233de79dc8a8d1668492108
--- /dev/null
+++ b/scripts/face3d/models/networks.py
@@ -0,0 +1,521 @@
+"""This script defines deep neural networks for Deep3DFaceRecon_pytorch
+"""
+
+import os
+import numpy as np
+import torch.nn.functional as F
+from torch.nn import init
+import functools
+from torch.optim import lr_scheduler
+import torch
+from torch import Tensor
+import torch.nn as nn
+try:
+    from torch.hub import load_state_dict_from_url
+except ImportError:
+    from torch.utils.model_zoo import load_url as load_state_dict_from_url
+from typing import Type, Any, Callable, Union, List, Optional
+from .arcface_torch.backbones import get_model
+from kornia.geometry import warp_affine
+
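+# resize_n_crop below is the same differentiable face-alignment helper as in
+# losses.py. A shape sketch (assumed values, not part of the original file):
+# an identity 2x3 affine M simply resamples the top-left dsize x dsize patch:
+#
+#   img = torch.rand(4, 3, 224, 224)                   # (b, c, h, w)
+#   M = torch.eye(2, 3).unsqueeze(0).repeat(4, 1, 1)   # (b, 2, 3)
+#   crop = resize_n_crop(img, M, dsize=112)            # -> (4, 3, 112, 112)
+#
+# Because kornia's warp_affine is differentiable, gradients flow through the
+# crop back into the input image.
+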
+def resize_n_crop(image, M, dsize=112):
+    # image: (b, c, h, w)
+    # M   :  (b, 2, 3)
+    return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True)
+
+def filter_state_dict(state_dict, remove_name='fc'):
+    new_state_dict = {}
+    for key in state_dict:
+        if remove_name in key:
+            continue
+        new_state_dict[key] = state_dict[key]
+    return new_state_dict
+
+def get_scheduler(optimizer, opt):
+    """Return a learning rate scheduler
+
+    Parameters:
+        optimizer          -- the optimizer of the network
+        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
+                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
+
+    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
+    See https://pytorch.org/docs/stable/optim.html for more details.
+    """
+    if opt.lr_policy == 'linear':
+        def lambda_rule(epoch):
+            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs + 1)
+            return lr_l
+        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+    elif opt.lr_policy == 'step':
+        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=0.2)
+    elif opt.lr_policy == 'plateau':
+        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
+    elif opt.lr_policy == 'cosine':
+        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
+    else:
+        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
+    return scheduler
+
+
+def define_net_recon(net_recon, use_last_fc=False, init_path=None):
+    return ReconNetWrapper(net_recon, use_last_fc=use_last_fc, init_path=init_path)
+
+def define_net_recog(net_recog, pretrained_path=None):
+    net = RecogNetWrapper(net_recog=net_recog, pretrained_path=pretrained_path)
+    net.eval()
+    return net
+
+class ReconNetWrapper(nn.Module):
+    fc_dim = 257
+
+    def __init__(self, net_recon, use_last_fc=False, init_path=None):
+        super(ReconNetWrapper, self).__init__()
+        self.use_last_fc = use_last_fc
+        if net_recon not in func_dict:
+            raise NotImplementedError('network [%s] is not implemented' % net_recon)
+        func, last_dim = func_dict[net_recon]
+        backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim)
+        if init_path and os.path.isfile(init_path):
+            state_dict = filter_state_dict(torch.load(init_path, map_location='cpu'))
+            backbone.load_state_dict(state_dict)
+            print("loading init net_recon %s from %s" % (net_recon, init_path))
+        self.backbone = backbone
+        if not use_last_fc:
+            self.final_layers = nn.ModuleList([
+                conv1x1(last_dim, 80, bias=True),  # id layer
+                conv1x1(last_dim, 64, bias=True),  # exp layer
+                conv1x1(last_dim, 80, bias=True),  # tex layer
+                conv1x1(last_dim, 3, bias=True),   # angle layer
+                conv1x1(last_dim, 27, bias=True),  # gamma layer
+                conv1x1(last_dim, 2, bias=True),   # tx, ty
+                conv1x1(last_dim, 1, bias=True)    # tz
+            ])
+            for m in self.final_layers:
+                nn.init.constant_(m.weight, 0.)
+                nn.init.constant_(m.bias, 0.)
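+
+    # The seven 1x1-conv heads regress the coefficient vector piecewise:
+    # 80 (id) + 64 (exp) + 80 (tex) + 3 (angle) + 27 (gamma) + 2 (tx, ty)
+    # + 1 (tz) = 257 = fc_dim, matching split_coeff in bfm.py. Zero-initialising
+    # their weights and biases makes the untrained network predict the mean face.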
+ + def forward(self, x): + x = self.backbone(x) + if not self.use_last_fc: + output = [] + for layer in self.final_layers: + output.append(layer(x)) + x = torch.flatten(torch.cat(output, dim=1), 1) + return x + + +class RecogNetWrapper(nn.Module): + def __init__(self, net_recog, pretrained_path=None, input_size=112): + super(RecogNetWrapper, self).__init__() + net = get_model(name=net_recog, fp16=False) + if pretrained_path: + state_dict = torch.load(pretrained_path, map_location='cpu') + net.load_state_dict(state_dict) + print("loading pretrained net_recog %s from %s" %(net_recog, pretrained_path)) + for param in net.parameters(): + param.requires_grad = False + self.net = net + self.preprocess = lambda x: 2 * x - 1 + self.input_size=input_size + + def forward(self, image, M): + image = self.preprocess(resize_n_crop(image, M, self.input_size)) + id_feature = F.normalize(self.net(image), dim=-1, p=2) + return id_feature + + +# adapted from https://github.com/pytorch/vision/edit/master/torchvision/models/resnet.py +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = 
self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + use_last_fc: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.use_last_fc = use_last_fc + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = 
nn.AdaptiveAvgPool2d((1, 1)) + + if self.use_last_fc: + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, + stride: int = 1, dilate: bool = False) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + if self.use_last_fc: + x = torch.flatten(x, 1) + x = self.fc(x) + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _resnet( + arch: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + pretrained: bool, + progress: bool, + **kwargs: Any +) -> ResNet: + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + + +def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) + + +def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + +def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +func_dict = { + 'resnet18': (resnet18, 512), + 'resnet50': (resnet50, 2048) +} diff --git a/scripts/face3d/models/template_model.py b/scripts/face3d/models/template_model.py new file mode 100644 index 0000000000000000000000000000000000000000..dac7b33d5889777eb63c9882a3b9fa094dcab293 --- /dev/null +++ b/scripts/face3d/models/template_model.py @@ -0,0 +1,100 @@ +"""Model class template + +This module provides a template for users to implement custom models. +You can specify '--model template' to use this model. +The class name should be consistent with both the filename and its model option. +The filename should be _dataset.py +The class name should be Dataset.py +It implements a simple image-to-image translation baseline based on regression loss. +Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss: + min_ ||netG(data_A) - data_B||_1 +You need to implement the following functions: + : Add model-specific options and rewrite default values for existing options. + <__init__>: Initialize this model class. + : Unpack input data and perform data pre-processing. + : Run forward pass. This will be called by both and . + : Update network weights; it will be called in every training iteration. +""" +import numpy as np +import torch +from .base_model import BaseModel +from . import networks + + +class TemplateModel(BaseModel): + @staticmethod + def modify_commandline_options(parser, is_train=True): + """Add new model-specific options and rewrite default values for existing options. + + Parameters: + parser -- the option parser + is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset. + if is_train: + parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model. + + return parser + + def __init__(self, opt): + """Initialize this model class. + + Parameters: + opt -- training/test options + + A few things can be done here. + - (required) call the initialization function of BaseModel + - define loss function, visualization images, model names, and optimizers + """ + BaseModel.__init__(self, opt) # call the initialization method of BaseModel + # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk. + self.loss_names = ['loss_G'] + # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images. + self.visual_names = ['data_A', 'data_B', 'output'] + # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks. + # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them. 
+ self.model_names = ['G'] + # define networks; you can use opt.isTrain to specify different behaviors for training and test. + self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids) + if self.isTrain: # only defined during training time + # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss. + # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device) + self.criterionLoss = torch.nn.L1Loss() + # define and initialize optimizers. You can define one optimizer for each network. + # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. + self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizers = [self.optimizer] + + # Our program will automatically call to define schedulers, load networks, and print networks + + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input: a dictionary that contains the data itself and its metadata information. + """ + AtoB = self.opt.direction == 'AtoB' # use to swap data_A and data_B + self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A + self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B + self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths + + def forward(self): + """Run forward pass. This will be called by both functions and .""" + self.output = self.netG(self.data_A) # generate output image given the input data_A + + def backward(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + # caculate the intermediate results if necessary; here self.output has been computed during function + # calculate loss given the input and intermediate results + self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression + self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G + + def optimize_parameters(self): + """Update network weights; it will be called in every training iteration.""" + self.forward() # first call forward to calculate intermediate results + self.optimizer.zero_grad() # clear network G's existing gradients + self.backward() # calculate gradients for network G + self.optimizer.step() # update gradients for network G diff --git a/scripts/face3d/options/__init__.py b/scripts/face3d/options/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e7eedebe54aa70169fd25951b3034d819e396c90 --- /dev/null +++ b/scripts/face3d/options/__init__.py @@ -0,0 +1 @@ +"""This package options includes option modules: training options, test options, and basic options (used in both training and test).""" diff --git a/scripts/face3d/options/base_options.py b/scripts/face3d/options/base_options.py new file mode 100644 index 0000000000000000000000000000000000000000..d8f921d5a43434ae802a55a0fa3889c4b7ab9f6d --- /dev/null +++ b/scripts/face3d/options/base_options.py @@ -0,0 +1,169 @@ +"""This script contains base options for Deep3DFaceRecon_pytorch +""" + +import argparse +import os +from util import util +import numpy as np +import torch +import face3d.models as models +import face3d.data as data + + +class BaseOptions(): + """This class defines options used during both training and test time. 
+ + It also implements several helper functions such as parsing, printing, and saving the options. + It also gathers additional options defined in functions in both dataset class and model class. + """ + + def __init__(self, cmd_line=None): + """Reset the class; indicates the class hasn't been initailized""" + self.initialized = False + self.cmd_line = None + if cmd_line is not None: + self.cmd_line = cmd_line.split() + + def initialize(self, parser): + """Define the common options that are used in both training and test.""" + # basic parameters + parser.add_argument('--name', type=str, default='face_recon', help='name of the experiment. It decides where to store samples and models') + parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') + parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') + parser.add_argument('--vis_batch_nums', type=float, default=1, help='batch nums of images for visulization') + parser.add_argument('--eval_batch_nums', type=float, default=float('inf'), help='batch nums of images for evaluation') + parser.add_argument('--use_ddp', type=util.str2bool, nargs='?', const=True, default=True, help='whether use distributed data parallel') + parser.add_argument('--ddp_port', type=str, default='12355', help='ddp port') + parser.add_argument('--display_per_batch', type=util.str2bool, nargs='?', const=True, default=True, help='whether use batch to show losses') + parser.add_argument('--add_image', type=util.str2bool, nargs='?', const=True, default=True, help='whether add image to tensorboard') + parser.add_argument('--world_size', type=int, default=1, help='batch nums of images for evaluation') + + # model parameters + parser.add_argument('--model', type=str, default='facerecon', help='chooses which model to use.') + + # additional parameters + parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') + parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') + parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') + + self.initialized = True + return parser + + def gather_options(self): + """Initialize our parser with basic options(only once). + Add additional model-specific and dataset-specific options. + These options are defined in the function + in model and dataset classes. 
+ """ + if not self.initialized: # check if it has been initialized + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + if self.cmd_line is None: + opt, _ = parser.parse_known_args() + else: + opt, _ = parser.parse_known_args(self.cmd_line) + + # set cuda visible devices + os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids + + # modify model-related parser options + model_name = opt.model + model_option_setter = models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + if self.cmd_line is None: + opt, _ = parser.parse_known_args() # parse again with new defaults + else: + opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults + + # modify dataset-related parser options + if opt.dataset_mode: + dataset_name = opt.dataset_mode + dataset_option_setter = data.get_option_setter(dataset_name) + parser = dataset_option_setter(parser, self.isTrain) + + # save and return the parser + self.parser = parser + if self.cmd_line is None: + return parser.parse_args() + else: + return parser.parse_args(self.cmd_line) + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values(if different). + It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + # save to the disk + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) + try: + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + except PermissionError as error: + print("permission error {}".format(error)) + pass + + def parse(self): + """Parse our options, create checkpoints directory suffix, and set up gpu device.""" + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + # process opt.suffix + if opt.suffix: + suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' + opt.name = opt.name + suffix + + + # set gpu ids + str_ids = opt.gpu_ids.split(',') + gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + gpu_ids.append(id) + opt.world_size = len(gpu_ids) + # if len(opt.gpu_ids) > 0: + # torch.cuda.set_device(gpu_ids[0]) + if opt.world_size == 1: + opt.use_ddp = False + + if opt.phase != 'test': + # set continue_train automatically + if opt.pretrained_name is None: + model_dir = os.path.join(opt.checkpoints_dir, opt.name) + else: + model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name) + if os.path.isdir(model_dir): + model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')] + if os.path.isdir(model_dir) and len(model_pths) != 0: + opt.continue_train= True + + # update the latest epoch count + if opt.continue_train: + if opt.epoch == 'latest': + epoch_counts = [int(i.split('.')[0].split('_')[-1]) for i in model_pths if 'latest' not in i] + if len(epoch_counts) != 0: + opt.epoch_count = max(epoch_counts) + 1 + else: + opt.epoch_count = int(opt.epoch) + 1 + + + self.print_options(opt) + self.opt = opt + return self.opt diff --git 
a/scripts/face3d/options/inference_options.py b/scripts/face3d/options/inference_options.py new file mode 100644 index 0000000000000000000000000000000000000000..c453965959ab4cfb31acbc424f994db68c3d4df5 --- /dev/null +++ b/scripts/face3d/options/inference_options.py @@ -0,0 +1,23 @@ +from face3d.options.base_options import BaseOptions + + +class InferenceOptions(BaseOptions): + """This class includes test options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) # define shared options + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]') + + parser.add_argument('--input_dir', type=str, help='the folder of the input files') + parser.add_argument('--keypoint_dir', type=str, help='the folder of the keypoint files') + parser.add_argument('--output_dir', type=str, default='mp4', help='the output dir to save the extracted coefficients') + parser.add_argument('--save_split_files', action='store_true', help='save split files or not') + parser.add_argument('--inference_batch_size', type=int, default=8) + + # Dropout and Batchnorm has different behavior during training and test. + self.isTrain = False + return parser diff --git a/scripts/face3d/options/test_options.py b/scripts/face3d/options/test_options.py new file mode 100644 index 0000000000000000000000000000000000000000..4ff3ad142779850d1d5a1640bc00f70d34d4a862 --- /dev/null +++ b/scripts/face3d/options/test_options.py @@ -0,0 +1,21 @@ +"""This script contains the test options for Deep3DFaceRecon_pytorch +""" + +from .base_options import BaseOptions + + +class TestOptions(BaseOptions): + """This class includes test options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) # define shared options + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]') + parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.') + + # Dropout and Batchnorm has different behavior during training and test. + self.isTrain = False + return parser diff --git a/scripts/face3d/options/train_options.py b/scripts/face3d/options/train_options.py new file mode 100644 index 0000000000000000000000000000000000000000..1337bfdd5f372b5c686a91b394a2aadbe5741f44 --- /dev/null +++ b/scripts/face3d/options/train_options.py @@ -0,0 +1,53 @@ +"""This script contains the training options for Deep3DFaceRecon_pytorch +""" + +from .base_options import BaseOptions +from util import util + +class TrainOptions(BaseOptions): + """This class includes training options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + # dataset parameters + # for train + parser.add_argument('--data_root', type=str, default='./', help='dataset root') + parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set') + parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. 
[None | flist]') + parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') + parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]') + parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether use data augmentation') + + # for val + parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set') + parser.add_argument('--batch_size_val', type=int, default=32) + + + # visualization parameters + parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen') + parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') + + # network saving and loading parameters + parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') + parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') + parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq') + parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') + parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') + parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') + parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint') + + # training parameters + parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate') + parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') + parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. 
[linear | step | plateau | cosine]') + parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epoches') + + self.isTrain = True + return parser diff --git a/scripts/face3d/util/BBRegressorParam_r.mat b/scripts/face3d/util/BBRegressorParam_r.mat new file mode 100644 index 0000000000000000000000000000000000000000..1430a94ed2ab570a09f9d980d3585e8aaa933084 Binary files /dev/null and b/scripts/face3d/util/BBRegressorParam_r.mat differ diff --git a/scripts/face3d/util/__init__.py b/scripts/face3d/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0717bbdde4b9276c3a519e87cbfa733be6c0d003 --- /dev/null +++ b/scripts/face3d/util/__init__.py @@ -0,0 +1,3 @@ +"""This package includes a miscellaneous collection of useful helper functions.""" +from scripts.face3d.util import * + diff --git a/scripts/face3d/util/__pycache__/__init__.cpython-310.pyc b/scripts/face3d/util/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f42a3658eb8cb521a3ec22726caf0ca23ef93db Binary files /dev/null and b/scripts/face3d/util/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/face3d/util/__pycache__/load_mats.cpython-310.pyc b/scripts/face3d/util/__pycache__/load_mats.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e20f20c51e34dd4896a0d79efb583081c01215a Binary files /dev/null and b/scripts/face3d/util/__pycache__/load_mats.cpython-310.pyc differ diff --git a/scripts/face3d/util/__pycache__/my_awing_arch.cpython-310.pyc b/scripts/face3d/util/__pycache__/my_awing_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f75bfffc26092b7f093049f3b7dbceedfc22830f Binary files /dev/null and b/scripts/face3d/util/__pycache__/my_awing_arch.cpython-310.pyc differ diff --git a/scripts/face3d/util/__pycache__/preprocess.cpython-310.pyc b/scripts/face3d/util/__pycache__/preprocess.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2086b9dd6a4dfbcdbb076c824706a17b7579a237 Binary files /dev/null and b/scripts/face3d/util/__pycache__/preprocess.cpython-310.pyc differ diff --git a/scripts/face3d/util/detect_lm68.py b/scripts/face3d/util/detect_lm68.py new file mode 100644 index 0000000000000000000000000000000000000000..b7e40997289e17405e1fb6c408d21adce7b626ce --- /dev/null +++ b/scripts/face3d/util/detect_lm68.py @@ -0,0 +1,106 @@ +import os +import cv2 +import numpy as np +from scipy.io import loadmat +import tensorflow as tf +from util.preprocess import align_for_lm +from shutil import move + +mean_face = np.loadtxt('util/test_mean_face.txt') +mean_face = mean_face.reshape([68, 2]) + +def save_label(labels, save_path): + np.savetxt(save_path, labels) + +def draw_landmarks(img, landmark, save_name): + landmark = landmark + lm_img = np.zeros([img.shape[0], img.shape[1], 3]) + lm_img[:] = img.astype(np.float32) + landmark = np.round(landmark).astype(np.int32) + + for i in range(len(landmark)): + for j in range(-1, 1): + for k in range(-1, 1): + if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \ + img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \ + landmark[i, 0]+k > 0 and \ + landmark[i, 0]+k < img.shape[1]: + lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k, + :] = np.array([0, 0, 255]) + lm_img = lm_img.astype(np.uint8) + + cv2.imwrite(save_name, lm_img) + + +def load_data(img_name, txt_name): + return cv2.imread(img_name), np.loadtxt(txt_name) + +# create 
tensorflow graph for landmark detector +def load_lm_graph(graph_filename): + with tf.gfile.GFile(graph_filename, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + + with tf.Graph().as_default() as graph: + tf.import_graph_def(graph_def, name='net') + img_224 = graph.get_tensor_by_name('net/input_imgs:0') + output_lm = graph.get_tensor_by_name('net/lm:0') + lm_sess = tf.Session(graph=graph) + + return lm_sess,img_224,output_lm + +# landmark detection +def detect_68p(img_path,sess,input_op,output_op): + print('detecting landmarks......') + names = [i for i in sorted(os.listdir( + img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] + vis_path = os.path.join(img_path, 'vis') + remove_path = os.path.join(img_path, 'remove') + save_path = os.path.join(img_path, 'landmarks') + if not os.path.isdir(vis_path): + os.makedirs(vis_path) + if not os.path.isdir(remove_path): + os.makedirs(remove_path) + if not os.path.isdir(save_path): + os.makedirs(save_path) + + for i in range(0, len(names)): + name = names[i] + print('%05d' % (i), ' ', name) + full_image_name = os.path.join(img_path, name) + txt_name = '.'.join(name.split('.')[:-1]) + '.txt' + full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image + + # if an image does not have detected 5 facial landmarks, remove it from the training list + if not os.path.isfile(full_txt_name): + move(full_image_name, os.path.join(remove_path, name)) + continue + + # load data + img, five_points = load_data(full_image_name, full_txt_name) + input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection + + # if the alignment fails, remove corresponding image from the training list + if scale == 0: + move(full_txt_name, os.path.join( + remove_path, txt_name)) + move(full_image_name, os.path.join(remove_path, name)) + continue + + # detect landmarks + input_img = np.reshape( + input_img, [1, 224, 224, 3]).astype(np.float32) + landmark = sess.run( + output_op, feed_dict={input_op: input_img}) + + # transform back to original image coordinate + landmark = landmark.reshape([68, 2]) + mean_face + landmark[:, 1] = 223 - landmark[:, 1] + landmark = landmark / scale + landmark[:, 0] = landmark[:, 0] + bbox[0] + landmark[:, 1] = landmark[:, 1] + bbox[1] + landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1] + + if i % 100 == 0: + draw_landmarks(img, landmark, os.path.join(vis_path, name)) + save_label(landmark, os.path.join(save_path, txt_name)) diff --git a/scripts/face3d/util/generate_list.py b/scripts/face3d/util/generate_list.py new file mode 100644 index 0000000000000000000000000000000000000000..943d906781063c3584a7e5b5c784f8aac0694985 --- /dev/null +++ b/scripts/face3d/util/generate_list.py @@ -0,0 +1,34 @@ +"""This script is to generate training list files for Deep3DFaceRecon_pytorch +""" + +import os + +# save path to training data +def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''): + save_path = os.path.join(save_folder, mode) + if not os.path.isdir(save_path): + os.makedirs(save_path) + with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd: + fd.writelines([i + '\n' for i in lms_list]) + + with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd: + fd.writelines([i + '\n' for i in imgs_list]) + + with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd: + fd.writelines([i + '\n' for i in msks_list]) + +# check if the path is valid +def 
check_list(rlms_list, rimgs_list, rmsks_list):
+    lms_list, imgs_list, msks_list = [], [], []
+    for i in range(len(rlms_list)):
+        flag = 'false'
+        lm_path = rlms_list[i]
+        im_path = rimgs_list[i]
+        msk_path = rmsks_list[i]
+        if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path):
+            flag = 'true'
+            lms_list.append(rlms_list[i])
+            imgs_list.append(rimgs_list[i])
+            msks_list.append(rmsks_list[i])
+        print(i, rlms_list[i], flag)
+    return lms_list, imgs_list, msks_list diff --git a/scripts/face3d/util/html.py b/scripts/face3d/util/html.py new file mode 100644 index 0000000000000000000000000000000000000000..cc3262a1eafda34842e4dbad47bb6ba72f0c5a68 --- /dev/null +++ b/scripts/face3d/util/html.py @@ -0,0 +1,86 @@
+import dominate
+from dominate.tags import meta, h3, table, tr, td, p, a, img, br
+import os
+
+
+class HTML:
+    """This HTML class allows us to save images and write texts into a single HTML file.
+
+    It consists of functions such as <add_header> (add a text header to the HTML file), <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
+    It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
+    """
+
+    def __init__(self, web_dir, title, refresh=0):
+        """Initialize the HTML class
+
+        Parameters:
+            web_dir (str) -- a directory that stores the webpage. The HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
+            title (str)   -- the webpage name
+            refresh (int) -- how often the website refreshes itself; if 0, no refreshing
+        """
+        self.title = title
+        self.web_dir = web_dir
+        self.img_dir = os.path.join(self.web_dir, 'images')
+        if not os.path.exists(self.web_dir):
+            os.makedirs(self.web_dir)
+        if not os.path.exists(self.img_dir):
+            os.makedirs(self.img_dir)
+
+        self.doc = dominate.document(title=title)
+        if refresh > 0:
+            with self.doc.head:
+                meta(http_equiv="refresh", content=str(refresh))
+
+    def get_image_dir(self):
+        """Return the directory that stores images"""
+        return self.img_dir
+
+    def add_header(self, text):
+        """Insert a header into the HTML file
+
+        Parameters:
+            text (str) -- the header text
+        """
+        with self.doc:
+            h3(text)
+
+    def add_images(self, ims, txts, links, width=400):
+        """Add images to the HTML file
+
+        Parameters:
+            ims (str list)   -- a list of image paths
+            txts (str list)  -- a list of image names shown on the website
+            links (str list) -- a list of hyperlinks; when you click an image, it will redirect you to a new page
+        """
+        self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
+        self.doc.add(self.t)
+        with self.t:
+            with tr():
+                for im, txt, link in zip(ims, txts, links):
+                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
+                        with p():
+                            with a(href=os.path.join('images', link)):
+                                img(style="width:%dpx" % width, src=os.path.join('images', im))
+                            br()
+                            p(txt)
+
+    def save(self):
+        """Save the current content to the HTML file"""
+        html_file = '%s/index.html' % self.web_dir
+        f = open(html_file, 'wt')
+        f.write(self.doc.render())
+        f.close()
+
+
+if __name__ == '__main__':  # we show an example usage here. 
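+    # Running this module directly exercises the class end-to-end: it builds a
+    # page with one header and a single four-image table row, then writes it to
+    # <web_dir>/index.html (the image files themselves are expected under images/).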
+ html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims, txts, links = [], [], [] + for n in range(4): + ims.append('image_%d.png' % n) + txts.append('text_%d' % n) + links.append('image_%d.png' % n) + html.add_images(ims, txts, links) + html.save() diff --git a/scripts/face3d/util/load_mats.py b/scripts/face3d/util/load_mats.py new file mode 100644 index 0000000000000000000000000000000000000000..f9a6fcc71de1d7dad8b0f81c67dc1c213764ff0b --- /dev/null +++ b/scripts/face3d/util/load_mats.py @@ -0,0 +1,120 @@ +"""This script is to load 3D face model for Deep3DFaceRecon_pytorch +""" + +import numpy as np +from PIL import Image +from scipy.io import loadmat, savemat +from array import array +import os.path as osp + +# load expression basis +def LoadExpBasis(bfm_folder='BFM'): + n_vertex = 53215 + Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb') + exp_dim = array('i') + exp_dim.fromfile(Expbin, 1) + expMU = array('f') + expPC = array('f') + expMU.fromfile(Expbin, 3*n_vertex) + expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex) + Expbin.close() + + expPC = np.array(expPC) + expPC = np.reshape(expPC, [exp_dim[0], -1]) + expPC = np.transpose(expPC) + + expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt')) + + return expPC, expEV + + +# transfer original BFM09 to our face model +def transferBFM09(bfm_folder='BFM'): + print('Transfer BFM09 to BFM_model_front......') + original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat')) + shapePC = original_BFM['shapePC'] # shape basis + shapeEV = original_BFM['shapeEV'] # corresponding eigen value + shapeMU = original_BFM['shapeMU'] # mean face + texPC = original_BFM['texPC'] # texture basis + texEV = original_BFM['texEV'] # eigen value + texMU = original_BFM['texMU'] # mean texture + + expPC, expEV = LoadExpBasis(bfm_folder) + + # transfer BFM09 to our face model + + idBase = shapePC*np.reshape(shapeEV, [-1, 199]) + idBase = idBase/1e5 # unify the scale to decimeter + idBase = idBase[:, :80] # use only first 80 basis + + exBase = expPC*np.reshape(expEV, [-1, 79]) + exBase = exBase/1e5 # unify the scale to decimeter + exBase = exBase[:, :64] # use only first 64 basis + + texBase = texPC*np.reshape(texEV, [-1, 199]) + texBase = texBase[:, :80] # use only first 80 basis + + # our face model is cropped along face landmarks and contains only 35709 vertex. + # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex. + # thus we select corresponding vertex to get our face model. 
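+    # BFM_front_idx.mat indexes the cropped 35709-vertex front region inside the
+    # 53215-vertex expression mesh, and BFM_exp_idx.mat maps expression-mesh
+    # vertices into the full 53490-vertex BFM09 mesh; chaining the two below
+    # (index_shape = index_shape[index_exp]) selects matching rows in every basis.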
+ + index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat')) + index_exp = index_exp['idx'].astype(np.int32) - 1 # starts from 0 (to 53215) + + index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat')) + index_shape = index_shape['trimIndex'].astype( + np.int32) - 1 # starts from 0 (to 53490) + index_shape = index_shape[index_exp] + + idBase = np.reshape(idBase, [-1, 3, 80]) + idBase = idBase[index_shape, :, :] + idBase = np.reshape(idBase, [-1, 80]) + + texBase = np.reshape(texBase, [-1, 3, 80]) + texBase = texBase[index_shape, :, :] + texBase = np.reshape(texBase, [-1, 80]) + + exBase = np.reshape(exBase, [-1, 3, 64]) + exBase = exBase[index_exp, :, :] + exBase = np.reshape(exBase, [-1, 64]) + + meanshape = np.reshape(shapeMU, [-1, 3])/1e5 + meanshape = meanshape[index_shape, :] + meanshape = np.reshape(meanshape, [1, -1]) + + meantex = np.reshape(texMU, [-1, 3]) + meantex = meantex[index_shape, :] + meantex = np.reshape(meantex, [1, -1]) + + # other info contains triangles, region used for computing photometric loss, + # region used for skin texture regularization, and 68 landmarks index etc. + other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat')) + frontmask2_idx = other_info['frontmask2_idx'] + skinmask = other_info['skinmask'] + keypoints = other_info['keypoints'] + point_buf = other_info['point_buf'] + tri = other_info['tri'] + tri_mask2 = other_info['tri_mask2'] + + # save our face model + savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase, + 'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': skinmask}) + + +# load landmarks for standard face, which is used for image preprocessing +def load_lm3d(bfm_folder): + + Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat')) + Lm3D = Lm3D['lm'] + + # calculate 5 facial landmarks using 68 landmarks + lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 + Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean( + Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0) + Lm3D = Lm3D[[1, 2, 0, 3, 4], :] + + return Lm3D + + +if __name__ == '__main__': + transferBFM09() \ No newline at end of file diff --git a/scripts/face3d/util/my_awing_arch.py b/scripts/face3d/util/my_awing_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..647fa111e015d925b7ce38164c0501fccabb0034 --- /dev/null +++ b/scripts/face3d/util/my_awing_arch.py @@ -0,0 +1,378 @@ +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def calculate_points(heatmaps): + # change heatmaps to landmarks + B, N, H, W = heatmaps.shape + HW = H * W + BN_range = np.arange(B * N) + + heatline = heatmaps.reshape(B, N, HW) + indexes = np.argmax(heatline, axis=2) + + preds = np.stack((indexes % W, indexes // W), axis=2) + preds = preds.astype(np.float64, copy=False) + + inr = indexes.ravel() + + heatline = heatline.reshape(B * N, HW) + x_up = heatline[BN_range, inr + 1] + x_down = heatline[BN_range, inr - 1] + # y_up = heatline[BN_range, inr + W] + + if any((inr + W) >= 4096): + y_up = heatline[BN_range, 4095] + else: + y_up = heatline[BN_range, inr + W] + if any((inr - W) <= 0): + y_down = heatline[BN_range, 0] + else: + y_down = heatline[BN_range, inr - W] + + think_diff = np.sign(np.stack((x_up - x_down, y_up - y_down), axis=1)) + think_diff *= .25 + + preds += 
think_diff.reshape(B, N, 2) + preds += .5 + return preds + + +class AddCoordsTh(nn.Module): + + def __init__(self, x_dim=64, y_dim=64, with_r=False, with_boundary=False): + super(AddCoordsTh, self).__init__() + self.x_dim = x_dim + self.y_dim = y_dim + self.with_r = with_r + self.with_boundary = with_boundary + + def forward(self, input_tensor, heatmap=None): + """ + input_tensor: (batch, c, x_dim, y_dim) + """ + batch_size_tensor = input_tensor.shape[0] + + xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32, device=input_tensor.device) + xx_ones = xx_ones.unsqueeze(-1) + + xx_range = torch.arange(self.x_dim, dtype=torch.int32, device=input_tensor.device).unsqueeze(0) + xx_range = xx_range.unsqueeze(1) + + xx_channel = torch.matmul(xx_ones.float(), xx_range.float()) + xx_channel = xx_channel.unsqueeze(-1) + + yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32, device=input_tensor.device) + yy_ones = yy_ones.unsqueeze(1) + + yy_range = torch.arange(self.y_dim, dtype=torch.int32, device=input_tensor.device).unsqueeze(0) + yy_range = yy_range.unsqueeze(-1) + + yy_channel = torch.matmul(yy_range.float(), yy_ones.float()) + yy_channel = yy_channel.unsqueeze(-1) + + xx_channel = xx_channel.permute(0, 3, 2, 1) + yy_channel = yy_channel.permute(0, 3, 2, 1) + + xx_channel = xx_channel / (self.x_dim - 1) + yy_channel = yy_channel / (self.y_dim - 1) + + xx_channel = xx_channel * 2 - 1 + yy_channel = yy_channel * 2 - 1 + + xx_channel = xx_channel.repeat(batch_size_tensor, 1, 1, 1) + yy_channel = yy_channel.repeat(batch_size_tensor, 1, 1, 1) + + if self.with_boundary and heatmap is not None: + boundary_channel = torch.clamp(heatmap[:, -1:, :, :], 0.0, 1.0) + + zero_tensor = torch.zeros_like(xx_channel) + xx_boundary_channel = torch.where(boundary_channel > 0.05, xx_channel, zero_tensor) + yy_boundary_channel = torch.where(boundary_channel > 0.05, yy_channel, zero_tensor) + if self.with_boundary and heatmap is not None: + xx_boundary_channel = xx_boundary_channel.to(input_tensor.device) + yy_boundary_channel = yy_boundary_channel.to(input_tensor.device) + ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1) + + if self.with_r: + rr = torch.sqrt(torch.pow(xx_channel, 2) + torch.pow(yy_channel, 2)) + rr = rr / torch.max(rr) + ret = torch.cat([ret, rr], dim=1) + + if self.with_boundary and heatmap is not None: + ret = torch.cat([ret, xx_boundary_channel, yy_boundary_channel], dim=1) + return ret + + +class CoordConvTh(nn.Module): + """CoordConv layer as in the paper.""" + + def __init__(self, x_dim, y_dim, with_r, with_boundary, in_channels, first_one=False, *args, **kwargs): + super(CoordConvTh, self).__init__() + self.addcoords = AddCoordsTh(x_dim=x_dim, y_dim=y_dim, with_r=with_r, with_boundary=with_boundary) + in_channels += 2 + if with_r: + in_channels += 1 + if with_boundary and not first_one: + in_channels += 2 + self.conv = nn.Conv2d(in_channels=in_channels, *args, **kwargs) + + def forward(self, input_tensor, heatmap=None): + ret = self.addcoords(input_tensor, heatmap) + last_channel = ret[:, -2:, :, :] + ret = self.conv(ret) + return ret, last_channel + + +def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False, dilation=1): + '3x3 convolution with padding' + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=strd, padding=padding, bias=bias, dilation=dilation) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, 
stride) + # self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + # self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.relu(out) + + out = self.conv2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ConvBlock(nn.Module): + + def __init__(self, in_planes, out_planes): + super(ConvBlock, self).__init__() + self.bn1 = nn.BatchNorm2d(in_planes) + self.conv1 = conv3x3(in_planes, int(out_planes / 2)) + self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) + self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4), padding=1, dilation=1) + self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) + self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4), padding=1, dilation=1) + + if in_planes != out_planes: + self.downsample = nn.Sequential( + nn.BatchNorm2d(in_planes), + nn.ReLU(True), + nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False), + ) + else: + self.downsample = None + + def forward(self, x): + residual = x + + out1 = self.bn1(x) + out1 = F.relu(out1, True) + out1 = self.conv1(out1) + + out2 = self.bn2(out1) + out2 = F.relu(out2, True) + out2 = self.conv2(out2) + + out3 = self.bn3(out2) + out3 = F.relu(out3, True) + out3 = self.conv3(out3) + + out3 = torch.cat((out1, out2, out3), 1) + + if self.downsample is not None: + residual = self.downsample(residual) + + out3 += residual + + return out3 + + +class HourGlass(nn.Module): + + def __init__(self, num_modules, depth, num_features, first_one=False): + super(HourGlass, self).__init__() + self.num_modules = num_modules + self.depth = depth + self.features = num_features + self.coordconv = CoordConvTh( + x_dim=64, + y_dim=64, + with_r=True, + with_boundary=True, + in_channels=256, + first_one=first_one, + out_channels=256, + kernel_size=1, + stride=1, + padding=0) + self._generate_network(self.depth) + + def _generate_network(self, level): + self.add_module('b1_' + str(level), ConvBlock(256, 256)) + + self.add_module('b2_' + str(level), ConvBlock(256, 256)) + + if level > 1: + self._generate_network(level - 1) + else: + self.add_module('b2_plus_' + str(level), ConvBlock(256, 256)) + + self.add_module('b3_' + str(level), ConvBlock(256, 256)) + + def _forward(self, level, inp): + # Upper branch + up1 = inp + up1 = self._modules['b1_' + str(level)](up1) + + # Lower branch + low1 = F.avg_pool2d(inp, 2, stride=2) + low1 = self._modules['b2_' + str(level)](low1) + + if level > 1: + low2 = self._forward(level - 1, low1) + else: + low2 = low1 + low2 = self._modules['b2_plus_' + str(level)](low2) + + low3 = low2 + low3 = self._modules['b3_' + str(level)](low3) + + up2 = F.interpolate(low3, scale_factor=2, mode='nearest') + + return up1 + up2 + + def forward(self, x, heatmap): + x, last_channel = self.coordconv(x, heatmap) + return self._forward(self.depth, x), last_channel + + +class FAN(nn.Module): + + def __init__(self, num_modules=1, end_relu=False, gray_scale=False, num_landmarks=68, device='cuda'): + super(FAN, self).__init__() + self.device = device + self.num_modules = num_modules + self.gray_scale = gray_scale + self.end_relu = end_relu + self.num_landmarks = num_landmarks + + # Base part + if self.gray_scale: + self.conv1 = CoordConvTh( + x_dim=256, + y_dim=256, + with_r=True, + with_boundary=False, + in_channels=3, + out_channels=64, + kernel_size=7, + stride=2, + 
padding=3) + else: + self.conv1 = CoordConvTh( + x_dim=256, + y_dim=256, + with_r=True, + with_boundary=False, + in_channels=3, + out_channels=64, + kernel_size=7, + stride=2, + padding=3) + self.bn1 = nn.BatchNorm2d(64) + self.conv2 = ConvBlock(64, 128) + self.conv3 = ConvBlock(128, 128) + self.conv4 = ConvBlock(128, 256) + + # Stacking part + for hg_module in range(self.num_modules): + if hg_module == 0: + first_one = True + else: + first_one = False + self.add_module('m' + str(hg_module), HourGlass(1, 4, 256, first_one)) + self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256)) + self.add_module('conv_last' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) + self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) + self.add_module('l' + str(hg_module), nn.Conv2d(256, num_landmarks + 1, kernel_size=1, stride=1, padding=0)) + + if hg_module < self.num_modules - 1: + self.add_module('bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) + self.add_module('al' + str(hg_module), + nn.Conv2d(num_landmarks + 1, 256, kernel_size=1, stride=1, padding=0)) + + def forward(self, x): + x, _ = self.conv1(x) + x = F.relu(self.bn1(x), True) + # x = F.relu(self.bn1(self.conv1(x)), True) + x = F.avg_pool2d(self.conv2(x), 2, stride=2) + x = self.conv3(x) + x = self.conv4(x) + + previous = x + + outputs = [] + boundary_channels = [] + tmp_out = None + for i in range(self.num_modules): + hg, boundary_channel = self._modules['m' + str(i)](previous, tmp_out) + + ll = hg + ll = self._modules['top_m_' + str(i)](ll) + + ll = F.relu(self._modules['bn_end' + str(i)](self._modules['conv_last' + str(i)](ll)), True) + + # Predict heatmaps + tmp_out = self._modules['l' + str(i)](ll) + if self.end_relu: + tmp_out = F.relu(tmp_out) # HACK: Added relu + outputs.append(tmp_out) + boundary_channels.append(boundary_channel) + + if i < self.num_modules - 1: + ll = self._modules['bl' + str(i)](ll) + tmp_out_ = self._modules['al' + str(i)](tmp_out) + previous = previous + ll + tmp_out_ + + return outputs, boundary_channels + + def get_landmarks(self, img): + H, W, _ = img.shape + offset = W / 64, H / 64, 0, 0 + + img = cv2.resize(img, (256, 256)) + inp = img[..., ::-1] + inp = torch.from_numpy(np.ascontiguousarray(inp.transpose((2, 0, 1)))).float() + inp = inp.to(self.device) + inp.div_(255.0).unsqueeze_(0) + + outputs, _ = self.forward(inp) + out = outputs[-1][:, :-1, :, :] + heatmaps = out.detach().cpu().numpy() + + pred = calculate_points(heatmaps).reshape(-1, 2) + + pred *= offset[:2] + pred += offset[-2:] + + return pred diff --git a/scripts/face3d/util/nvdiffrast.py b/scripts/face3d/util/nvdiffrast.py new file mode 100644 index 0000000000000000000000000000000000000000..f3245859c650afbfe841a66b74cddefaf28820d9 --- /dev/null +++ b/scripts/face3d/util/nvdiffrast.py @@ -0,0 +1,126 @@ +"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch + Attention, antialiasing step is missing in current version. 
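+    Despite the file name, rasterization below is performed with PyTorch3D
+    (FoVPerspectiveCameras + MeshRasterizer) rather than the nvdiffrast library.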
+""" +import pytorch3d.ops +import torch +import torch.nn.functional as F +import kornia +from kornia.geometry.camera import pixel2cam +import numpy as np +from typing import List +from scipy.io import loadmat +from torch import nn + +from pytorch3d.structures import Meshes +from pytorch3d.renderer import ( + look_at_view_transform, + FoVPerspectiveCameras, + DirectionalLights, + RasterizationSettings, + MeshRenderer, + MeshRasterizer, + SoftPhongShader, + TexturesUV, +) + +# def ndc_projection(x=0.1, n=1.0, f=50.0): +# return np.array([[n/x, 0, 0, 0], +# [ 0, n/-x, 0, 0], +# [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], +# [ 0, 0, -1, 0]]).astype(np.float32) + +class MeshRenderer(nn.Module): + def __init__(self, + rasterize_fov, + znear=0.1, + zfar=10, + rasterize_size=224): + super(MeshRenderer, self).__init__() + + # x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear + # self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul( + # torch.diag(torch.tensor([1., -1, -1, 1]))) + self.rasterize_size = rasterize_size + self.fov = rasterize_fov + self.znear = znear + self.zfar = zfar + + self.rasterizer = None + + def forward(self, vertex, tri, feat=None): + """ + Return: + mask -- torch.tensor, size (B, 1, H, W) + depth -- torch.tensor, size (B, 1, H, W) + features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None + + Parameters: + vertex -- torch.tensor, size (B, N, 3) + tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles + feat(optional) -- torch.tensor, size (B, N ,C), features + """ + device = vertex.device + rsize = int(self.rasterize_size) + # ndc_proj = self.ndc_proj.to(device) + # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v + if vertex.shape[-1] == 3: + vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1) + vertex[..., 0] = -vertex[..., 0] + + + # vertex_ndc = vertex @ ndc_proj.t() + if self.rasterizer is None: + self.rasterizer = MeshRasterizer() + print("create rasterizer on device cuda:%d"%device.index) + + # ranges = None + # if isinstance(tri, List) or len(tri.shape) == 3: + # vum = vertex_ndc.shape[1] + # fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device) + # fstartidx = torch.cumsum(fnum, dim=0) - fnum + # ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu() + # for i in range(tri.shape[0]): + # tri[i] = tri[i] + i*vum + # vertex_ndc = torch.cat(vertex_ndc, dim=0) + # tri = torch.cat(tri, dim=0) + + # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3] + tri = tri.type(torch.int32).contiguous() + + # rasterize + cameras = FoVPerspectiveCameras( + device=device, + fov=self.fov, + znear=self.znear, + zfar=self.zfar, + ) + + raster_settings = RasterizationSettings( + image_size=rsize + ) + + # print(vertex.shape, tri.shape) + mesh = Meshes(vertex.contiguous()[...,:3], tri.unsqueeze(0).repeat((vertex.shape[0],1,1))) + + fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings) + rast_out = fragments.pix_to_face.squeeze(-1) + depth = fragments.zbuf + + # render depth + depth = depth.permute(0, 3, 1, 2) + mask = (rast_out > 0).float().unsqueeze(1) + depth = mask * depth + + + image = None + if feat is not None: + attributes = feat.reshape(-1,3)[mesh.faces_packed()] + image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face, + fragments.bary_coords, + attributes) + # print(image.shape) + image = image.squeeze(-2).permute(0, 3, 1, 2) + image = mask * image + + 
return mask, depth, image + diff --git a/scripts/face3d/util/preprocess.py b/scripts/face3d/util/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..a807b90bf95024e7894f21006384fe7a53aceb85 --- /dev/null +++ b/scripts/face3d/util/preprocess.py @@ -0,0 +1,109 @@ +"""This script contains the image preprocessing code for Deep3DFaceRecon_pytorch +""" + +import numpy as np +from scipy.io import loadmat +from PIL import Image +import cv2 +import os +from skimage import transform as trans +import torch +import warnings +warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + + +# calculating least square problem for image alignment +def POS(xp, x): + npts = xp.shape[1] + + A = np.zeros([2*npts, 8]) + + A[0:2*npts-1:2, 0:3] = x.transpose() + A[0:2*npts-1:2, 3] = 1 + + A[1:2*npts:2, 4:7] = x.transpose() + A[1:2*npts:2, 7] = 1 + + b = np.reshape(xp.transpose(), [2*npts, 1]) + + k, _, _, _ = np.linalg.lstsq(A, b) + + R1 = k[0:3] + R2 = k[4:7] + sTx = k[3] + sTy = k[7] + s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2 + t = np.stack([sTx, sTy], axis=0) + + return t, s + +# resize and crop images for face reconstruction +def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None): + w0, h0 = img.size + w = (w0*s).astype(np.int32) + h = (h0*s).astype(np.int32) + left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32) + right = left + target_size + up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32) + below = up + target_size + + img = img.resize((w, h), resample=Image.BICUBIC) + img = img.crop((left, up, right, below)) + + if mask is not None: + mask = mask.resize((w, h), resample=Image.BICUBIC) + mask = mask.crop((left, up, right, below)) + + lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] - + t[1] + h0/2], axis=1)*s + lm = lm - np.reshape( + np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2]) + + return img, lm, mask + +# utils for face reconstruction +def extract_5p(lm): + lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 + lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean( + lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0) + lm5p = lm5p[[1, 2, 0, 3, 4], :] + return lm5p + +# utils for face reconstruction +def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.): + """ + Return: + transparams --numpy.array (raw_W, raw_H, scale, tx, ty) + img_new --PIL.Image (target_size, target_size, 3) + lm_new --numpy.array (68, 2), y direction is opposite to v direction + mask_new --PIL.Image (target_size, target_size) + + Parameters: + img --PIL.Image (raw_H, raw_W, 3) + lm --numpy.array (68, 2), y direction is opposite to v direction + lm3D --numpy.array (5, 3) + mask --PIL.Image (raw_H, raw_W, 3) + """ + + w0, h0 = img.size + if lm.shape[0] != 5: + lm5p = extract_5p(lm) + else: + lm5p = lm + + # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face + t, s = POS(lm5p.transpose(), lm3D.transpose()) + s = rescale_factor/s + + # processing the image + img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask) + print("w0:", w0) + print("h0:", h0) + print("s:", s) + print("t:", t) + print("t[0]:", t[0]) + print("t[1]:", t[1]) + trans_params = np.array([float(w0), float(h0), float(s), float(t[0][0]), float(t[1][0])]) + + return trans_params, img_new, lm_new, mask_new diff --git 
a/scripts/face3d/util/skin_mask.py b/scripts/face3d/util/skin_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..a8a74e4c3b40d13b0258b83a12f56321a85bb179 --- /dev/null +++ b/scripts/face3d/util/skin_mask.py @@ -0,0 +1,125 @@ +"""This script is to generate skin attention mask for Deep3DFaceRecon_pytorch +""" + +import math +import numpy as np +import os +import cv2 + +class GMM: + def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv): + self.dim = dim # feature dimension + self.num = num # number of Gaussian components + self.w = w # weights of Gaussian components (a list of scalars) + self.mu= mu # mean of Gaussian components (a list of 1xdim vectors) + self.cov = cov # covariance matrix of Gaussian components (a list of dimxdim matrices) + self.cov_det = cov_det # pre-computed determinet of covariance matrices (a list of scalars) + self.cov_inv = cov_inv # pre-computed inverse covariance matrices (a list of dimxdim matrices) + + self.factor = [0]*num + for i in range(self.num): + self.factor[i] = (2*math.pi)**(self.dim/2) * self.cov_det[i]**0.5 + + def likelihood(self, data): + assert(data.shape[1] == self.dim) + N = data.shape[0] + lh = np.zeros(N) + + for i in range(self.num): + data_ = data - self.mu[i] + + tmp = np.matmul(data_,self.cov_inv[i]) * data_ + tmp = np.sum(tmp,axis=1) + power = -0.5 * tmp + + p = np.array([math.exp(power[j]) for j in range(N)]) + p = p/self.factor[i] + lh += p*self.w[i] + + return lh + + +def _rgb2ycbcr(rgb): + m = np.array([[65.481, 128.553, 24.966], + [-37.797, -74.203, 112], + [112, -93.786, -18.214]]) + shape = rgb.shape + rgb = rgb.reshape((shape[0] * shape[1], 3)) + ycbcr = np.dot(rgb, m.transpose() / 255.) + ycbcr[:, 0] += 16. + ycbcr[:, 1:] += 128. + return ycbcr.reshape(shape) + + +def _bgr2ycbcr(bgr): + rgb = bgr[..., ::-1] + return _rgb2ycbcr(rgb) + + +gmm_skin_w = [0.24063933, 0.16365987, 0.26034665, 0.33535415] +gmm_skin_mu = [np.array([113.71862, 103.39613, 164.08226]), + np.array([150.19858, 105.18467, 155.51428]), + np.array([183.92976, 107.62468, 152.71820]), + np.array([114.90524, 113.59782, 151.38217])] +gmm_skin_cov_det = [5692842.5, 5851930.5, 2329131., 1585971.] 
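+# Each component i contributes w_i * exp(-0.5 * (x - mu_i)^T * inv(Sigma_i) * (x - mu_i))
+#                                  / ((2*pi)^(3/2) * sqrt(det(Sigma_i)))
+# to the likelihood above, which is why only the precomputed determinants and
+# inverses of the covariances (not the covariances themselves) are stored here.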
+gmm_skin_cov_inv = [np.array([[0.0019472069, 0.0020450759, -0.00060243998],[0.0020450759, 0.017700525, 0.0051420014],[-0.00060243998, 0.0051420014, 0.0081308950]]), + np.array([[0.0027110141, 0.0011036990, 0.0023122299],[0.0011036990, 0.010707724, 0.010742856],[0.0023122299, 0.010742856, 0.017481629]]), + np.array([[0.0048026871, 0.00022935172, 0.0077668377],[0.00022935172, 0.011729696, 0.0081661865],[0.0077668377, 0.0081661865, 0.025374353]]), + np.array([[0.0011989699, 0.0022453172, -0.0010748957],[0.0022453172, 0.047758564, 0.020332102],[-0.0010748957, 0.020332102, 0.024502251]])] + +gmm_skin = GMM(3, 4, gmm_skin_w, gmm_skin_mu, [], gmm_skin_cov_det, gmm_skin_cov_inv) + +gmm_nonskin_w = [0.12791070, 0.31130761, 0.34245777, 0.21832393] +gmm_nonskin_mu = [np.array([99.200851, 112.07533, 140.20602]), + np.array([110.91392, 125.52969, 130.19237]), + np.array([129.75864, 129.96107, 126.96808]), + np.array([112.29587, 128.85121, 129.05431])] +gmm_nonskin_cov_det = [458703648., 6466488., 90611376., 133097.63] +gmm_nonskin_cov_inv = [np.array([[0.00085371657, 0.00071197288, 0.00023958916],[0.00071197288, 0.0025935620, 0.00076557708],[0.00023958916, 0.00076557708, 0.0015042332]]), + np.array([[0.00024650150, 0.00045542428, 0.00015019422],[0.00045542428, 0.026412144, 0.018419769],[0.00015019422, 0.018419769, 0.037497383]]), + np.array([[0.00037054974, 0.00038146760, 0.00040408765],[0.00038146760, 0.0085505722, 0.0079136286],[0.00040408765, 0.0079136286, 0.010982352]]), + np.array([[0.00013709733, 0.00051228428, 0.00012777430],[0.00051228428, 0.28237113, 0.10528370],[0.00012777430, 0.10528370, 0.23468947]])] + +gmm_nonskin = GMM(3, 4, gmm_nonskin_w, gmm_nonskin_mu, [], gmm_nonskin_cov_det, gmm_nonskin_cov_inv) + +prior_skin = 0.8 +prior_nonskin = 1 - prior_skin + + +# calculate skin attention mask +def skinmask(imbgr): + im = _bgr2ycbcr(imbgr) + + data = im.reshape((-1,3)) + + lh_skin = gmm_skin.likelihood(data) + lh_nonskin = gmm_nonskin.likelihood(data) + + tmp1 = prior_skin * lh_skin + tmp2 = prior_nonskin * lh_nonskin + post_skin = tmp1 / (tmp1+tmp2) # posterior probability + + post_skin = post_skin.reshape((im.shape[0],im.shape[1])) + + post_skin = np.round(post_skin*255) + post_skin = post_skin.astype(np.uint8) + post_skin = np.tile(np.expand_dims(post_skin,2),[1,1,3]) # reshape to H*W*3 + + return post_skin + + +def get_skin_mask(img_path): + print('generating skin masks......') + names = [i for i in sorted(os.listdir( + img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] + save_path = os.path.join(img_path, 'mask') + if not os.path.isdir(save_path): + os.makedirs(save_path) + + for i in range(0, len(names)): + name = names[i] + print('%05d' % (i), ' ', name) + full_image_name = os.path.join(img_path, name) + img = cv2.imread(full_image_name).astype(np.float32) + skin_img = skinmask(img) + cv2.imwrite(os.path.join(save_path, name), skin_img.astype(np.uint8)) diff --git a/scripts/face3d/util/test_mean_face.txt b/scripts/face3d/util/test_mean_face.txt new file mode 100644 index 0000000000000000000000000000000000000000..3a46d4db7699ffed8f898fcee64099631509946d --- /dev/null +++ b/scripts/face3d/util/test_mean_face.txt @@ -0,0 +1,136 @@ +-5.228591537475585938e+01 +2.078247070312500000e-01 +-5.064269638061523438e+01 +-1.315765380859375000e+01 +-4.952939224243164062e+01 +-2.592591094970703125e+01 +-4.793047332763671875e+01 +-3.832135772705078125e+01 +-4.512159729003906250e+01 +-5.059623336791992188e+01 +-3.917720794677734375e+01 +-6.043736648559570312e+01 
+-2.929953765869140625e+01 +-6.861183166503906250e+01 +-1.719801330566406250e+01 +-7.572736358642578125e+01 +-1.961936950683593750e+00 +-7.862001037597656250e+01 +1.467941284179687500e+01 +-7.607844543457031250e+01 +2.744073486328125000e+01 +-6.915261840820312500e+01 +3.855677795410156250e+01 +-5.950350570678710938e+01 +4.478240966796875000e+01 +-4.867547225952148438e+01 +4.714337158203125000e+01 +-3.800830078125000000e+01 +4.940315246582031250e+01 +-2.496297454833984375e+01 +5.117234802246093750e+01 +-1.241538238525390625e+01 +5.190507507324218750e+01 +8.244247436523437500e-01 +-4.150688934326171875e+01 +2.386329650878906250e+01 +-3.570307159423828125e+01 +3.017010498046875000e+01 +-2.790358734130859375e+01 +3.212951660156250000e+01 +-1.941773223876953125e+01 +3.156523132324218750e+01 +-1.138106536865234375e+01 +2.841992187500000000e+01 +5.993263244628906250e+00 +2.895182800292968750e+01 +1.343590545654296875e+01 +3.189880371093750000e+01 +2.203153991699218750e+01 +3.302221679687500000e+01 +2.992478942871093750e+01 +3.099150085449218750e+01 +3.628388977050781250e+01 +2.765748596191406250e+01 +-1.933914184570312500e+00 +1.405374145507812500e+01 +-2.153038024902343750e+00 +5.772636413574218750e+00 +-2.270050048828125000e+00 +-2.121643066406250000e+00 +-2.218330383300781250e+00 +-1.068978118896484375e+01 +-1.187252044677734375e+01 +-1.997912597656250000e+01 +-6.879402160644531250e+00 +-2.143579864501953125e+01 +-1.227821350097656250e+00 +-2.193494415283203125e+01 +4.623237609863281250e+00 +-2.152721405029296875e+01 +9.721397399902343750e+00 +-1.953671264648437500e+01 +-3.648714447021484375e+01 +9.811126708984375000e+00 +-3.130242919921875000e+01 +1.422447967529296875e+01 +-2.212834930419921875e+01 +1.493019866943359375e+01 +-1.500880432128906250e+01 +1.073588562011718750e+01 +-2.095037078857421875e+01 +9.054298400878906250e+00 +-3.050099182128906250e+01 +8.704177856445312500e+00 +1.173237609863281250e+01 +1.054329681396484375e+01 +1.856353759765625000e+01 +1.535009765625000000e+01 +2.893331909179687500e+01 +1.451992797851562500e+01 +3.452944946289062500e+01 +1.065280151367187500e+01 +2.875990295410156250e+01 +8.654792785644531250e+00 +1.942100524902343750e+01 +9.422447204589843750e+00 +-2.204488372802734375e+01 +-3.983994293212890625e+01 +-1.324458312988281250e+01 +-3.467377471923828125e+01 +-6.749649047851562500e+00 +-3.092894744873046875e+01 +-9.183349609375000000e-01 +-3.196458435058593750e+01 +4.220649719238281250e+00 +-3.090406036376953125e+01 +1.089889526367187500e+01 +-3.497008514404296875e+01 +1.874589538574218750e+01 +-4.065438079833984375e+01 +1.124106597900390625e+01 +-4.438417816162109375e+01 +5.181709289550781250e+00 +-4.649170684814453125e+01 +-1.158607482910156250e+00 +-4.680406951904296875e+01 +-7.918922424316406250e+00 +-4.671575164794921875e+01 +-1.452505493164062500e+01 +-4.416526031494140625e+01 +-2.005007171630859375e+01 +-3.997841644287109375e+01 +-1.054919433593750000e+01 +-3.849683380126953125e+01 +-1.051826477050781250e+00 +-3.794863128662109375e+01 +6.412681579589843750e+00 +-3.804645538330078125e+01 +1.627674865722656250e+01 +-4.039697265625000000e+01 +6.373878479003906250e+00 +-4.087213897705078125e+01 +-8.551712036132812500e-01 +-4.157129669189453125e+01 +-1.014953613281250000e+01 +-4.128469085693359375e+01 diff --git a/scripts/face3d/util/util.py b/scripts/face3d/util/util.py new file mode 100644 index 0000000000000000000000000000000000000000..0d689ca138fc0fbf5bec794511ea0f9e638f9ea9 --- /dev/null +++ b/scripts/face3d/util/util.py @@ -0,0 +1,208 @@ +"""This 
script contains basic utilities for Deep3DFaceRecon_pytorch +""" +from __future__ import print_function +import numpy as np +import torch +from PIL import Image +import os +import importlib +import argparse +from argparse import Namespace +import torchvision + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def copyconf(default_opt, **kwargs): + conf = Namespace(**vars(default_opt)) + for key in kwargs: + setattr(conf, key, kwargs[key]) + return conf + +def genvalconf(train_opt, **kwargs): + conf = Namespace(**vars(train_opt)) + attr_dict = train_opt.__dict__ + for key, value in attr_dict.items(): + if 'val' in key and key.split('_')[0] in attr_dict: + setattr(conf, key.split('_')[0], value) + + for key in kwargs: + setattr(conf, key, kwargs[key]) + + return conf + +def find_class_in_module(target_cls_name, module): + target_cls_name = target_cls_name.replace('_', '').lower() + clslib = importlib.import_module(module) + cls = None + for name, clsobj in clslib.__dict__.items(): + if name.lower() == target_cls_name: + cls = clsobj + + assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name) + + return cls + + +def tensor2im(input_image, imtype=np.uint8): + """"Converts a Tensor array into a numpy image array. + + Parameters: + input_image (tensor) -- the input image tensor array, range(0, 1) + imtype (type) -- the desired type of the converted numpy array + """ + if not isinstance(input_image, np.ndarray): + if isinstance(input_image, torch.Tensor): # get the data from a variable + image_tensor = input_image.data + else: + return input_image + image_numpy = image_tensor.clamp(0.0, 1.0).cpu().float().numpy() # convert it into a numpy array + if image_numpy.shape[0] == 1: # grayscale to RGB + image_numpy = np.tile(image_numpy, (3, 1, 1)) + image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 # post-processing: tranpose and scaling + else: # if it is a numpy array, do nothing + image_numpy = input_image + return image_numpy.astype(imtype) + + +def diagnose_network(net, name='network'): + """Calculate and print the mean of average absolute(gradients) + + Parameters: + net (torch network) -- Torch network + name (str) -- the name of the network + """ + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + +def save_image(image_numpy, image_path, aspect_ratio=1.0): + """Save a numpy image to the disk + + Parameters: + image_numpy (numpy array) -- input numpy array + image_path (str) -- the path of the image + """ + + image_pil = Image.fromarray(image_numpy) + h, w, _ = image_numpy.shape + + if aspect_ratio is None: + pass + elif aspect_ratio > 1.0: + image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) + elif aspect_ratio < 1.0: + image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) + image_pil.save(image_path) + + +def print_numpy(x, val=True, shp=False): + """Print the mean, min, max, median, std, and size of a numpy array + + Parameters: + val (bool) -- if print the values of the numpy array + shp (bool) -- if print the shape of the numpy array + """ + x = x.astype(np.float64) + if shp: + 
print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + + +def mkdirs(paths): + """create empty directories if they don't exist + + Parameters: + paths (str list) -- a list of directory paths + """ + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + """create a single empty directory if it didn't exist + + Parameters: + path (str) -- a single directory path + """ + if not os.path.exists(path): + os.makedirs(path) + + +def correct_resize_label(t, size): + device = t.device + t = t.detach().cpu() + resized = [] + for i in range(t.size(0)): + one_t = t[i, :1] + one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0)) + one_np = one_np[:, :, 0] + one_image = Image.fromarray(one_np).resize(size, Image.NEAREST) + resized_t = torch.from_numpy(np.array(one_image)).long() + resized.append(resized_t) + return torch.stack(resized, dim=0).to(device) + + +def correct_resize(t, size, mode=Image.BICUBIC): + device = t.device + t = t.detach().cpu() + resized = [] + for i in range(t.size(0)): + one_t = t[i:i + 1] + one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC) + resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0 + resized.append(resized_t) + return torch.stack(resized, dim=0).to(device) + +def draw_landmarks(img, landmark, color='r', step=2): + """ + Return: + img -- numpy.array, (B, H, W, 3) img with landmark, RGB order, range (0, 255) + + + Parameters: + img -- numpy.array, (B, H, W, 3), RGB order, range (0, 255) + landmark -- numpy.array, (B, 68, 2), y direction is opposite to v direction + color -- str, 'r' or 'b' (red or blue) + """ + if color =='r': + c = np.array([255., 0, 0]) + else: + c = np.array([0, 0, 255.]) + + _, H, W, _ = img.shape + img, landmark = img.copy(), landmark.copy() + landmark[..., 1] = H - 1 - landmark[..., 1] + landmark = np.round(landmark).astype(np.int32) + for i in range(landmark.shape[1]): + x, y = landmark[:, i, 0], landmark[:, i, 1] + for j in range(-step, step): + for k in range(-step, step): + u = np.clip(x + j, 0, W - 1) + v = np.clip(y + k, 0, H - 1) + for m in range(landmark.shape[0]): + img[m, v[m], u[m]] = c + return img diff --git a/scripts/face3d/util/visualizer.py b/scripts/face3d/util/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..4023a6d4086acba9bc88e079f625194d324d7c9e --- /dev/null +++ b/scripts/face3d/util/visualizer.py @@ -0,0 +1,227 @@ +"""This script defines the visualizer for Deep3DFaceRecon_pytorch +""" + +import numpy as np +import os +import sys +import ntpath +import time +from . import util, html +from subprocess import Popen, PIPE +from torch.utils.tensorboard import SummaryWriter + +def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): + """Save images to the disk. + + Parameters: + webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) + visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs + image_path (str) -- the string is used to create image paths + aspect_ratio (float) -- the aspect ratio of saved images + width (int) -- the images will be resized to width x width + + This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. 
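+    Each image is written to <image_dir>/<label>/<name>.png and then linked into
+    the page via webpage.add_images.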
+    """
+    image_dir = webpage.get_image_dir()
+    short_path = ntpath.basename(image_path[0])
+    name = os.path.splitext(short_path)[0]
+
+    webpage.add_header(name)
+    ims, txts, links = [], [], []
+
+    for label, im_data in visuals.items():
+        im = util.tensor2im(im_data)
+        image_name = '%s/%s.png' % (label, name)
+        os.makedirs(os.path.join(image_dir, label), exist_ok=True)
+        save_path = os.path.join(image_dir, image_name)
+        util.save_image(im, save_path, aspect_ratio=aspect_ratio)
+        ims.append(image_name)
+        txts.append(label)
+        links.append(image_name)
+    webpage.add_images(ims, txts, links, width=width)
+
+
+class Visualizer():
+    """This class includes several functions that can display/save images and print/save logging information.
+
+    It uses tensorboard for display, and the Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
+    """
+
+    def __init__(self, opt):
+        """Initialize the Visualizer class
+
+        Parameters:
+            opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+        Step 1: Cache the training/test options
+        Step 2: create a tensorboard writer
+        Step 3: create an HTML object for saving HTML files
+        Step 4: create a logging file to store training losses
+        """
+        self.opt = opt  # cache the option
+        self.use_html = opt.isTrain and not opt.no_html
+        self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, 'logs', opt.name))
+        self.win_size = opt.display_winsize
+        self.name = opt.name
+        self.saved = False
+        if self.use_html:  # create an HTML object at <checkpoints_dir>/<name>/web/; images will be saved under <checkpoints_dir>/<name>/web/images/
+            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
+            self.img_dir = os.path.join(self.web_dir, 'images')
+            print('create web directory %s...' % self.web_dir)
+            util.mkdirs([self.web_dir, self.img_dir])
+        # create a logging file to store training losses
+        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+        with open(self.log_name, "a") as log_file:
+            now = time.strftime("%c")
+            log_file.write('================ Training Loss (%s) ================\n' % now)
+
+    def reset(self):
+        """Reset the self.saved status"""
+        self.saved = False
+
+
+    def display_current_results(self, visuals, total_iters, epoch, save_result):
+        """Display current results on tensorboard; save current results to an HTML file.
+
+        Parameters:
+            visuals (OrderedDict) -- dictionary of images to display or save
+            total_iters (int) -- total iterations
+            epoch (int) -- the current epoch
+            save_result (bool) -- whether to save the current results to an HTML file
+        """
+        for label, image in visuals.items():
+            self.writer.add_image(label, util.tensor2im(image), total_iters, dataformats='HWC')
+
+        if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved. 
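+            # images are written at most once per reset() unless save_result
+            # explicitly forces a refresh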
+            self.saved = True
+            # save images to the disk
+            for label, image in visuals.items():
+                image_numpy = util.tensor2im(image)
+                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
+                util.save_image(image_numpy, img_path)
+
+            # update website
+            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0)
+            for n in range(epoch, 0, -1):
+                webpage.add_header('epoch [%d]' % n)
+                ims, txts, links = [], [], []
+
+                for label, image_numpy in visuals.items():
+                    image_numpy = util.tensor2im(image)
+                    img_path = 'epoch%.3d_%s.png' % (n, label)
+                    ims.append(img_path)
+                    txts.append(label)
+                    links.append(img_path)
+                webpage.add_images(ims, txts, links, width=self.win_size)
+            webpage.save()
+
+    def plot_current_losses(self, total_iters, losses):
+        # G_loss_collection = {}
+        # D_loss_collection = {}
+        # for name, value in losses.items():
+        #     if 'G' in name or 'NCE' in name or 'idt' in name:
+        #         G_loss_collection[name] = value
+        #     else:
+        #         D_loss_collection[name] = value
+        # self.writer.add_scalars('G_collec', G_loss_collection, total_iters)
+        # self.writer.add_scalars('D_collec', D_loss_collection, total_iters)
+        for name, value in losses.items():
+            self.writer.add_scalar(name, value, total_iters)
+
+    # losses: same format as |losses| of plot_current_losses
+    def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
+        """print current losses on console; also save the losses to the disk
+
+        Parameters:
+            epoch (int) -- current epoch
+            iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
+            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+            t_comp (float) -- computational time per data point (normalized by batch_size)
+            t_data (float) -- data loading time per data point (normalized by batch_size)
+        """
+        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
+        for k, v in losses.items():
+            message += '%s: %.3f ' % (k, v)
+
+        print(message)  # print the message
+        with open(self.log_name, "a") as log_file:
+            log_file.write('%s\n' % message)  # save the message
+
+
+class MyVisualizer:
+    def __init__(self, opt):
+        """Initialize the Visualizer class
+
+        Parameters:
+            opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+        Step 1: Cache the training/test options
+        Step 2: create a tensorboard writer
+        Step 3: create an HTML object for saving HTML files
+        Step 4: create a logging file to store training losses
+        """
+        self.opt = opt  # cache the option
+        self.name = opt.name
+        self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'results')
+
+        if opt.phase != 'test':
+            self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'logs'))
+            # create a logging file to store training losses
+            self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+            with open(self.log_name, "a") as log_file:
+                now = time.strftime("%c")
+                log_file.write('================ Training Loss (%s) ================\n' % now)
+
+
+    def display_current_results(self, visuals, total_iters, epoch, dataset='train', save_results=False, count=0, name=None,
+            add_image=True):
+        """Display current results on tensorboard; save current results to an HTML file. 
+ + Parameters: + visuals (OrderedDict) - - dictionary of images to display or save + total_iters (int) -- total iterations + epoch (int) - - the current epoch + dataset (str) - - 'train' or 'val' or 'test' + """ + # if (not add_image) and (not save_results): return + + for label, image in visuals.items(): + for i in range(image.shape[0]): + image_numpy = util.tensor2im(image[i]) + if add_image: + self.writer.add_image(label + '%s_%02d'%(dataset, i + count), + image_numpy, total_iters, dataformats='HWC') + + if save_results: + save_path = os.path.join(self.img_dir, dataset, 'epoch_%s_%06d'%(epoch, total_iters)) + if not os.path.isdir(save_path): + os.makedirs(save_path) + + if name is not None: + img_path = os.path.join(save_path, '%s.png' % name) + else: + img_path = os.path.join(save_path, '%s_%03d.png' % (label, i + count)) + util.save_image(image_numpy, img_path) + + + def plot_current_losses(self, total_iters, losses, dataset='train'): + for name, value in losses.items(): + self.writer.add_scalar(name + '/%s'%dataset, value, total_iters) + + # losses: same format as |losses| of plot_current_losses + def print_current_losses(self, epoch, iters, losses, t_comp, t_data, dataset='train'): + """print current losses on console; also save the losses to the disk + + Parameters: + epoch (int) -- current epoch + iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) + losses (OrderedDict) -- training losses stored in the format of (name, float) pairs + t_comp (float) -- computational time per data point (normalized by batch_size) + t_data (float) -- data loading time per data point (normalized by batch_size) + """ + message = '(dataset: %s, epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % ( + dataset, epoch, iters, t_comp, t_data) + for k, v in losses.items(): + message += '%s: %.3f ' % (k, v) + + print(message) # print the message + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) # save the message diff --git a/scripts/face3d/visualize.py b/scripts/face3d/visualize.py new file mode 100644 index 0000000000000000000000000000000000000000..67c32101dc3857a262ec740ea7333b26abc0cc92 --- /dev/null +++ b/scripts/face3d/visualize.py @@ -0,0 +1,48 @@ +# check the sync of 3dmm feature and the audio +import cv2 +import numpy as np +from scripts.face3d.models.bfm import ParametricFaceModel +from scripts.face3d.models.facerecon_model import FaceReconModel +import torch +import subprocess, platform +import scipy.io as scio +from tqdm import tqdm + +# draft +def gen_composed_video(args, device, first_frame_coeff, coeff_path, audio_path, save_path, exp_dim=64): + + coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm'] + + coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm'] + + coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0) # 257 + + coeff_full[:, 80:144] = coeff_pred[:, 0:64] + coeff_full[:, 224:227] = coeff_pred[:, 64:67] # 3 dim translation + coeff_full[:, 254:] = coeff_pred[:, 67:] # 3 dim translation + + tmp_video_path = '/tmp/face3dtmp.mp4' + + facemodel = FaceReconModel(args) + + video = cv2.VideoWriter(tmp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224)) + + for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'): + cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device) + + facemodel.forward(cur_coeff_full, device) + + predicted_landmark = facemodel.pred_lm # TODO. 
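+        # the landmarks are brought back to numpy here but are never drawn on the
+        # rendered frame below; only facemodel.pred_face reaches the output video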
+ predicted_landmark = predicted_landmark.cpu().numpy().squeeze() + + rendered_img = facemodel.pred_face + rendered_img = 255. * rendered_img.cpu().numpy().squeeze().transpose(1,2,0) + out_img = rendered_img[:, :, :3].astype(np.uint8) + + video.write(np.uint8(out_img[:,:,::-1])) + + video.release() + + command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path) + subprocess.call(command, shell=platform.system() != 'Windows') + diff --git a/scripts/facelib/detection/__init__.py b/scripts/facelib/detection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a42b3c3a59da1f22877cdb1e2bd1fcae7945343b --- /dev/null +++ b/scripts/facelib/detection/__init__.py @@ -0,0 +1,100 @@ +import os +import torch +from torch import nn +from copy import deepcopy + +from scripts.facelib.utils import load_file_from_url +from scripts.facelib.utils import download_pretrained_models +from scripts.facelib.detection.yolov5face.models.common import Conv + +from .retinaface.retinaface import RetinaFace +from .yolov5face.face_detector import YoloDetector + + +def init_detection_model(model_name, half=False, device='cuda'): + if 'retinaface' in model_name: + model = init_retinaface_model(model_name, half, device) + elif 'YOLOv5' in model_name: + model = init_yolov5face_model(model_name, device) + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + return model + + +def init_retinaface_model(model_name, half=False, device='cuda'): + if model_name == 'retinaface_resnet50': + model = RetinaFace(network_name='resnet50', half=half) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth' + elif model_name == 'retinaface_mobile0.25': + model = RetinaFace(network_name='mobile0.25', half=half) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + # remove unnecessary 'module.' 
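+    # prefix (added by nn.DataParallel when the checkpoint was saved) so the keys
+    # match the bare model's state_dict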
+ for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + model.load_state_dict(load_net, strict=True) + model.eval() + model = model.to(device) + + return model + + +def init_yolov5face_model(model_name, device='cuda'): + if model_name == 'YOLOv5l': + model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth' + elif model_name == 'YOLOv5n': + model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + model.detector.load_state_dict(load_net, strict=True) + model.detector.eval() + model.detector = model.detector.to(device).float() + + for m in model.detector.modules(): + if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: + m.inplace = True # pytorch 1.7.0 compatibility + elif isinstance(m, Conv): + m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + + return model + + +# Download from Google Drive +# def init_yolov5face_model(model_name, device='cuda'): +# if model_name == 'YOLOv5l': +# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device) +# f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'} +# elif model_name == 'YOLOv5n': +# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device) +# f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'} +# else: +# raise NotImplementedError(f'{model_name} is not implemented.') + +# model_path = os.path.join('weights/facelib', list(f_id.keys())[0]) +# if not os.path.exists(model_path): +# download_pretrained_models(file_ids=f_id, save_path_root='weights/facelib') + +# load_net = torch.load(model_path, map_location=lambda storage, loc: storage) +# model.detector.load_state_dict(load_net, strict=True) +# model.detector.eval() +# model.detector = model.detector.to(device).float() + +# for m in model.detector.modules(): +# if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: +# m.inplace = True # pytorch 1.7.0 compatibility +# elif isinstance(m, Conv): +# m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + +# return model \ No newline at end of file diff --git a/scripts/facelib/detection/__pycache__/__init__.cpython-310.pyc b/scripts/facelib/detection/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd87cf6cee159df4790a4ca8bf36dc09b3bc07f6 Binary files /dev/null and b/scripts/facelib/detection/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/facelib/detection/__pycache__/align_trans.cpython-310.pyc b/scripts/facelib/detection/__pycache__/align_trans.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae5ffd153c0fcc94929a305b1fadc05249f3ec3a Binary files /dev/null and b/scripts/facelib/detection/__pycache__/align_trans.cpython-310.pyc differ diff --git a/scripts/facelib/detection/__pycache__/matlab_cp2tform.cpython-310.pyc 
b/scripts/facelib/detection/__pycache__/matlab_cp2tform.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..44a4e5fe5405b3e12c8c221d43239c0a4c5d4752
Binary files /dev/null and b/scripts/facelib/detection/__pycache__/matlab_cp2tform.cpython-310.pyc differ
diff --git a/scripts/facelib/detection/align_trans.py b/scripts/facelib/detection/align_trans.py
new file mode 100644
index 0000000000000000000000000000000000000000..07f1eb365462c2ec5bbac6d1854c786b6fd6be90
--- /dev/null
+++ b/scripts/facelib/detection/align_trans.py
@@ -0,0 +1,219 @@
+import cv2
+import numpy as np
+
+from .matlab_cp2tform import get_similarity_transform_for_cv2
+
+# reference facial points, a list of coordinates (x,y)
+REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278],
+                           [33.54930115, 92.3655014], [62.72990036, 92.20410156]]
+
+DEFAULT_CROP_SIZE = (96, 112)
+
+
+class FaceWarpException(Exception):
+
+    def __str__(self):
+        return 'In File {}:{}'.format(__file__, super().__str__())
+
+
+def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False):
+    """
+    Function:
+    ----------
+        get reference 5 key points according to crop settings:
+        0. Set default crop_size:
+            if default_square:
+                crop_size = (112, 112)
+            else:
+                crop_size = (96, 112)
+        1. Pad the crop_size by inner_padding_factor on each side;
+        2. Resize crop_size into (output_size - outer_padding*2),
+            pad into output_size with outer_padding;
+        3. Output reference_5point;
+    Parameters:
+    ----------
+        @output_size: (w, h) or None
+            size of aligned face image
+        @inner_padding_factor: float in [0, 1]
+            padding factor applied to the inner (w, h)
+        @outer_padding: (w_pad, h_pad)
+            padding (in pixels) added on each outer side of (w, h)
+        @default_square: True or False
+            if True:
+                default crop_size = (112, 112)
+            else:
+                default crop_size = (96, 112);
+        !!!
make sure, if output_size is not None:
+                (output_size - outer_padding)
+                = some_scale * (default crop_size * (1.0 + inner_padding_factor))
+    Returns:
+    ----------
+        @reference_5point: 5x2 np.array
+            each row is a pair of transformed coordinates (x, y)
+    """
+
+    tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
+    tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
+
+    # 0) make the inner region a square
+    if default_square:
+        size_diff = max(tmp_crop_size) - tmp_crop_size
+        tmp_5pts += size_diff / 2
+        tmp_crop_size += size_diff
+
+    if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]):
+
+        return tmp_5pts
+
+    if (inner_padding_factor == 0 and outer_padding == (0, 0)):
+        if output_size is None:
+            return tmp_5pts
+        else:
+            raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
+
+    # check output size
+    if not (0 <= inner_padding_factor <= 1.0):
+        raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
+
+    if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None):
+        # cast after scaling: the scaled size is a float array, so the
+        # parentheses must enclose the product before .astype()
+        output_size = (tmp_crop_size * (1 + inner_padding_factor * 2)).astype(np.int32)
+        output_size += np.array(outer_padding)
+    if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]):
+        raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])')
+
+    # 1) pad the inner region according to inner_padding_factor
+    if inner_padding_factor > 0:
+        size_diff = tmp_crop_size * inner_padding_factor * 2
+        tmp_5pts += size_diff / 2
+        tmp_crop_size += np.round(size_diff).astype(np.int32)
+
+    # 2) resize the padded inner region
+    size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
+
+    if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
+        raise FaceWarpException('Must have (output_size - outer_padding)'
+                                '= some_scale * (crop_size * (1.0 + inner_padding_factor)')
+
+    scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
+    tmp_5pts = tmp_5pts * scale_factor
+    # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
+    # tmp_5pts = tmp_5pts + size_diff / 2
+    tmp_crop_size = size_bf_outer_pad
+
+    # 3) add outer_padding to make output_size
+    reference_5point = tmp_5pts + np.array(outer_padding)
+    tmp_crop_size = output_size
+
+    return reference_5point
+
+
+def get_affine_transform_matrix(src_pts, dst_pts):
+    """
+    Function:
+    ----------
+        get affine transform matrix 'tfm' from src_pts to dst_pts
+    Parameters:
+    ----------
+        @src_pts: Kx2 np.array
+            source points matrix, each row is a pair of coordinates (x, y)
+        @dst_pts: Kx2 np.array
+            destination points matrix, each row is a pair of coordinates (x, y)
+    Returns:
+    ----------
+        @tfm: 2x3 np.array
+            transform matrix from src_pts to dst_pts
+    """
+
+    tfm = np.float32([[1, 0, 0], [0, 1, 0]])
+    n_pts = src_pts.shape[0]
+    ones = np.ones((n_pts, 1), src_pts.dtype)
+    src_pts_ = np.hstack([src_pts, ones])
+    dst_pts_ = np.hstack([dst_pts, ones])
+
+    A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_, rcond=None)
+
+    if rank == 3:
+        tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
+    elif rank == 2:
+        tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])
+
+    return tfm
+
+
+def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='similarity'):
+    """
+    Function:
+    ----------
+        apply affine transform 'trans' to uv
+    Parameters:
+    ----------
+        @src_img: HxWx3 np.array
+            input
image + @facial_pts: could be + 1)a list of K coordinates (x,y) + or + 2) Kx2 or 2xK np.array + each row or col is a pair of coordinates (x, y) + @reference_pts: could be + 1) a list of K coordinates (x,y) + or + 2) Kx2 or 2xK np.array + each row or col is a pair of coordinates (x, y) + or + 3) None + if None, use default reference facial points + @crop_size: (w, h) + output face image size + @align_type: transform type, could be one of + 1) 'similarity': use similarity transform + 2) 'cv2_affine': use the first 3 points to do affine transform, + by calling cv2.getAffineTransform() + 3) 'affine': use all points to do affine transform + Returns: + ---------- + @face_img: output face image with size (w, h) = @crop_size + """ + + if reference_pts is None: + if crop_size[0] == 96 and crop_size[1] == 112: + reference_pts = REFERENCE_FACIAL_POINTS + else: + default_square = False + inner_padding_factor = 0 + outer_padding = (0, 0) + output_size = crop_size + + reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding, + default_square) + + ref_pts = np.float32(reference_pts) + ref_pts_shp = ref_pts.shape + if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: + raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2') + + if ref_pts_shp[0] == 2: + ref_pts = ref_pts.T + + src_pts = np.float32(facial_pts) + src_pts_shp = src_pts.shape + if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: + raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2') + + if src_pts_shp[0] == 2: + src_pts = src_pts.T + + if src_pts.shape != ref_pts.shape: + raise FaceWarpException('facial_pts and reference_pts must have the same shape') + + if align_type == 'cv2_affine': + tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) + elif align_type == 'affine': + tfm = get_affine_transform_matrix(src_pts, ref_pts) + else: + tfm = get_similarity_transform_for_cv2(src_pts, ref_pts) + + face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1])) + + return face_img diff --git a/scripts/facelib/detection/matlab_cp2tform.py b/scripts/facelib/detection/matlab_cp2tform.py new file mode 100644 index 0000000000000000000000000000000000000000..b2a8b54a91709c71437e15c68d3be9a9b0a20a34 --- /dev/null +++ b/scripts/facelib/detection/matlab_cp2tform.py @@ -0,0 +1,317 @@ +import numpy as np +from numpy.linalg import inv, lstsq +from numpy.linalg import matrix_rank as rank +from numpy.linalg import norm + + +class MatlabCp2tormException(Exception): + + def __str__(self): + return 'In File {}:{}'.format(__file__, super.__str__(self)) + + +def tformfwd(trans, uv): + """ + Function: + ---------- + apply affine transform 'trans' to uv + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix + @uv: Kx2 np.array + each row is a pair of coordinates (x, y) + + Returns: + ---------- + @xy: Kx2 np.array + each row is a pair of transformed coordinates (x, y) + """ + uv = np.hstack((uv, np.ones((uv.shape[0], 1)))) + xy = np.dot(uv, trans) + xy = xy[:, 0:-1] + return xy + + +def tforminv(trans, uv): + """ + Function: + ---------- + apply the inverse of affine transform 'trans' to uv + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix + @uv: Kx2 np.array + each row is a pair of coordinates (x, y) + + Returns: + ---------- + @xy: Kx2 np.array + each row is a pair of inverse-transformed coordinates (x, y) + """ + Tinv = inv(trans) + xy = tformfwd(Tinv, uv) + return xy + + +def findNonreflectiveSimilarity(uv, xy, options=None): + options = {'K': 2} 
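+    # A non-reflective similarity has four unknowns r = [sc, ss, tx, ty]
+    # (scale*cos(theta), scale*sin(theta) and the translation), so each point
+    # pair contributes two rows to the linear system X * r = U solved below.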
+
+    K = options['K']
+    M = xy.shape[0]
+    x = xy[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
+    y = xy[:, 1].reshape((-1, 1))  # use reshape to keep a column vector
+
+    tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
+    tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
+    X = np.vstack((tmp1, tmp2))
+
+    u = uv[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
+    v = uv[:, 1].reshape((-1, 1))  # use reshape to keep a column vector
+    U = np.vstack((u, v))
+
+    # We know that X * r = U
+    if rank(X) >= 2 * K:
+        r, _, _, _ = lstsq(X, U, rcond=-1)
+        r = np.squeeze(r)
+    else:
+        raise Exception('cp2tform:twoUniquePointsReq')
+    sc = r[0]
+    ss = r[1]
+    tx = r[2]
+    ty = r[3]
+
+    Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
+    T = inv(Tinv)
+    T[:, 2] = np.array([0, 0, 1])
+
+    return T, Tinv
+
+
+def findSimilarity(uv, xy, options=None):
+    if options is None:
+        options = {'K': 2}
+
+    # uv = np.array(uv)
+    # xy = np.array(xy)
+
+    # Solve for trans1
+    trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
+
+    # Solve for trans2
+
+    # manually reflect the xy data across the Y-axis
+    xyR = xy.copy()  # copy, so the caller's xy is not reflected in place
+    xyR[:, 0] = -1 * xyR[:, 0]
+
+    trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
+
+    # manually reflect the tform to undo the reflection done on xyR
+    TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+    trans2 = np.dot(trans2r, TreflectY)
+
+    # Figure out if trans1 or trans2 is better
+    xy1 = tformfwd(trans1, uv)
+    norm1 = norm(xy1 - xy)
+
+    xy2 = tformfwd(trans2, uv)
+    norm2 = norm(xy2 - xy)
+
+    if norm1 <= norm2:
+        return trans1, trans1_inv
+    else:
+        trans2_inv = inv(trans2)
+        return trans2, trans2_inv
+
+
+def get_similarity_transform(src_pts, dst_pts, reflective=True):
+    """
+    Function:
+    ----------
+        Find Similarity Transform Matrix 'trans':
+            u = src_pts[:, 0]
+            v = src_pts[:, 1]
+            x = dst_pts[:, 0]
+            y = dst_pts[:, 1]
+            [x, y, 1] = [u, v, 1] * trans
+
+    Parameters:
+    ----------
+        @src_pts: Kx2 np.array
+            source points, each row is a pair of coordinates (x, y)
+        @dst_pts: Kx2 np.array
+            destination points, each row is a pair of transformed
+            coordinates (x, y)
+        @reflective: True or False
+            if True:
+                use reflective similarity transform
+            else:
+                use non-reflective similarity transform
+
+    Returns:
+    ----------
+        @trans: 3x3 np.array
+            transform matrix from uv to xy
+        trans_inv: 3x3 np.array
+            inverse of trans, transform matrix from xy to uv
+    """
+
+    if reflective:
+        trans, trans_inv = findSimilarity(src_pts, dst_pts)
+    else:
+        trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
+
+    return trans, trans_inv
+
+
+def cvt_tform_mat_for_cv2(trans):
+    """
+    Function:
+    ----------
+        Convert Transform Matrix 'trans' into 'cv2_trans' which could be
+        directly used by cv2.warpAffine():
+            u = src_pts[:, 0]
+            v = src_pts[:, 1]
+            x = dst_pts[:, 0]
+            y = dst_pts[:, 1]
+            [x, y].T = cv_trans * [u, v, 1].T
+
+    Parameters:
+    ----------
+        @trans: 3x3 np.array
+            transform matrix from uv to xy
+
+    Returns:
+    ----------
+        @cv2_trans: 2x3 np.array
+            transform matrix from src_pts to dst_pts, could be directly used
+            for cv2.warpAffine()
+    """
+    cv2_trans = trans[:, 0:2].T
+
+    return cv2_trans
+
+
+def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
+    """
+    Function:
+    ----------
+        Find Similarity Transform Matrix 'cv2_trans' which could be
+        directly used by cv2.warpAffine():
+            u = src_pts[:, 0]
+            v = src_pts[:, 1]
+            x = dst_pts[:, 0]
+            y = dst_pts[:, 1]
+            [x, y].T = cv_trans * [u, v, 1].T
+
+    Parameters:
+    ----------
+
@src_pts: Kx2 np.array + source points, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points, each row is a pair of transformed + coordinates (x, y) + reflective: True or False + if True: + use reflective similarity transform + else: + use non-reflective similarity transform + + Returns: + ---------- + @cv2_trans: 2x3 np.array + transform matrix from src_pts to dst_pts, could be directly used + for cv2.warpAffine() + """ + trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective) + cv2_trans = cvt_tform_mat_for_cv2(trans) + + return cv2_trans + + +if __name__ == '__main__': + """ + u = [0, 6, -2] + v = [0, 3, 5] + x = [-1, 0, 4] + y = [-1, -10, 4] + + # In Matlab, run: + # + # uv = [u'; v']; + # xy = [x'; y']; + # tform_sim=cp2tform(uv,xy,'similarity'); + # + # trans = tform_sim.tdata.T + # ans = + # -0.0764 -1.6190 0 + # 1.6190 -0.0764 0 + # -3.2156 0.0290 1.0000 + # trans_inv = tform_sim.tdata.Tinv + # ans = + # + # -0.0291 0.6163 0 + # -0.6163 -0.0291 0 + # -0.0756 1.9826 1.0000 + # xy_m=tformfwd(tform_sim, u,v) + # + # xy_m = + # + # -3.2156 0.0290 + # 1.1833 -9.9143 + # 5.0323 2.8853 + # uv_m=tforminv(tform_sim, x,y) + # + # uv_m = + # + # 0.5698 1.3953 + # 6.0872 2.2733 + # -2.6570 4.3314 + """ + u = [0, 6, -2] + v = [0, 3, 5] + x = [-1, 0, 4] + y = [-1, -10, 4] + + uv = np.array((u, v)).T + xy = np.array((x, y)).T + + print('\n--->uv:') + print(uv) + print('\n--->xy:') + print(xy) + + trans, trans_inv = get_similarity_transform(uv, xy) + + print('\n--->trans matrix:') + print(trans) + + print('\n--->trans_inv matrix:') + print(trans_inv) + + print('\n---> apply transform to uv') + print('\nxy_m = uv_augmented * trans') + uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1)))) + xy_m = np.dot(uv_aug, trans) + print(xy_m) + + print('\nxy_m = tformfwd(trans, uv)') + xy_m = tformfwd(trans, uv) + print(xy_m) + + print('\n---> apply inverse transform to xy') + print('\nuv_m = xy_augmented * trans_inv') + xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1)))) + uv_m = np.dot(xy_aug, trans_inv) + print(uv_m) + + print('\nuv_m = tformfwd(trans_inv, xy)') + uv_m = tformfwd(trans_inv, xy) + print(uv_m) + + uv_m = tforminv(trans, xy) + print('\nuv_m = tforminv(trans, xy)') + print(uv_m) diff --git a/scripts/facelib/detection/retinaface/__pycache__/retinaface.cpython-310.pyc b/scripts/facelib/detection/retinaface/__pycache__/retinaface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6de5741ea946df57ca689f0f288127647fc8ede1 Binary files /dev/null and b/scripts/facelib/detection/retinaface/__pycache__/retinaface.cpython-310.pyc differ diff --git a/scripts/facelib/detection/retinaface/__pycache__/retinaface_net.cpython-310.pyc b/scripts/facelib/detection/retinaface/__pycache__/retinaface_net.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c25303631900687063fc286a9c049849d9704a29 Binary files /dev/null and b/scripts/facelib/detection/retinaface/__pycache__/retinaface_net.cpython-310.pyc differ diff --git a/scripts/facelib/detection/retinaface/__pycache__/retinaface_utils.cpython-310.pyc b/scripts/facelib/detection/retinaface/__pycache__/retinaface_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdbed7f0f532809592eac29c40daa0ee5b68c3c8 Binary files /dev/null and b/scripts/facelib/detection/retinaface/__pycache__/retinaface_utils.cpython-310.pyc differ diff --git a/scripts/facelib/detection/retinaface/retinaface.py 
b/scripts/facelib/detection/retinaface/retinaface.py new file mode 100644 index 0000000000000000000000000000000000000000..4c329cac6b9d11e809d0dc976f29d72f6c90d734 --- /dev/null +++ b/scripts/facelib/detection/retinaface/retinaface.py @@ -0,0 +1,372 @@ +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter + +from scripts.facelib.detection.align_trans import get_reference_facial_points, warp_and_crop_face +from scripts.facelib.detection.retinaface.retinaface_net import FPN, SSH, MobileNetV1, make_bbox_head, make_class_head, make_landmark_head +from scripts.facelib.detection.retinaface.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm, + py_cpu_nms) + +from scripts.basicsr.utils.misc import get_device +# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +device = get_device() + + +def generate_config(network_name): + + cfg_mnet = { + 'name': 'mobilenet0.25', + 'min_sizes': [[16, 32], [64, 128], [256, 512]], + 'steps': [8, 16, 32], + 'variance': [0.1, 0.2], + 'clip': False, + 'loc_weight': 2.0, + 'gpu_train': True, + 'batch_size': 32, + 'ngpu': 1, + 'epoch': 250, + 'decay1': 190, + 'decay2': 220, + 'image_size': 640, + 'return_layers': { + 'stage1': 1, + 'stage2': 2, + 'stage3': 3 + }, + 'in_channel': 32, + 'out_channel': 64 + } + + cfg_re50 = { + 'name': 'Resnet50', + 'min_sizes': [[16, 32], [64, 128], [256, 512]], + 'steps': [8, 16, 32], + 'variance': [0.1, 0.2], + 'clip': False, + 'loc_weight': 2.0, + 'gpu_train': True, + 'batch_size': 24, + 'ngpu': 4, + 'epoch': 100, + 'decay1': 70, + 'decay2': 90, + 'image_size': 840, + 'return_layers': { + 'layer2': 1, + 'layer3': 2, + 'layer4': 3 + }, + 'in_channel': 256, + 'out_channel': 256 + } + + if network_name == 'mobile0.25': + return cfg_mnet + elif network_name == 'resnet50': + return cfg_re50 + else: + raise NotImplementedError(f'network_name={network_name}') + + +class RetinaFace(nn.Module): + + def __init__(self, network_name='resnet50', half=False, phase='test'): + super(RetinaFace, self).__init__() + self.half_inference = half + cfg = generate_config(network_name) + self.backbone = cfg['name'] + + self.model_name = f'retinaface_{network_name}' + self.cfg = cfg + self.phase = phase + self.target_size, self.max_size = 1600, 2150 + self.resize, self.scale, self.scale1 = 1., None, None + self.mean_tensor = torch.tensor([[[[104.]], [[117.]], [[123.]]]]).to(device) + self.reference = get_reference_facial_points(default_square=True) + # Build network. 
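+        # For both supported backbones, the three tapped layers carry
+        # in_channel*2, in_channel*4 and in_channel*8 channels; these are the
+        # values that in_channels_list feeds into the FPN below.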
+ backbone = None + if cfg['name'] == 'mobilenet0.25': + backbone = MobileNetV1() + self.body = IntermediateLayerGetter(backbone, cfg['return_layers']) + elif cfg['name'] == 'Resnet50': + import torchvision.models as models + backbone = models.resnet50(pretrained=False) + self.body = IntermediateLayerGetter(backbone, cfg['return_layers']) + + in_channels_stage2 = cfg['in_channel'] + in_channels_list = [ + in_channels_stage2 * 2, + in_channels_stage2 * 4, + in_channels_stage2 * 8, + ] + + out_channels = cfg['out_channel'] + self.fpn = FPN(in_channels_list, out_channels) + self.ssh1 = SSH(out_channels, out_channels) + self.ssh2 = SSH(out_channels, out_channels) + self.ssh3 = SSH(out_channels, out_channels) + + self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg['out_channel']) + self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg['out_channel']) + self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg['out_channel']) + + self.to(device) + self.eval() + if self.half_inference: + self.half() + + def forward(self, inputs): + out = self.body(inputs) + + if self.backbone == 'mobilenet0.25' or self.backbone == 'Resnet50': + out = list(out.values()) + # FPN + fpn = self.fpn(out) + + # SSH + feature1 = self.ssh1(fpn[0]) + feature2 = self.ssh2(fpn[1]) + feature3 = self.ssh3(fpn[2]) + features = [feature1, feature2, feature3] + + bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1) + classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1) + tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)] + ldm_regressions = (torch.cat(tmp, dim=1)) + + if self.phase == 'train': + output = (bbox_regressions, classifications, ldm_regressions) + else: + output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions) + return output + + def __detect_faces(self, inputs): + # get scale + height, width = inputs.shape[2:] + self.scale = torch.tensor([width, height, width, height], dtype=torch.float32).to(device) + tmp = [width, height, width, height, width, height, width, height, width, height] + self.scale1 = torch.tensor(tmp, dtype=torch.float32).to(device) + + # forawrd + inputs = inputs.to(device) + if self.half_inference: + inputs = inputs.half() + loc, conf, landmarks = self(inputs) + + # get priorbox + priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:]) + priors = priorbox.forward().to(device) + + return loc, conf, landmarks, priors + + # single image detection + def transform(self, image, use_origin_size): + # convert to opencv format + if isinstance(image, Image.Image): + image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) + image = image.astype(np.float32) + + # testing scale + im_size_min = np.min(image.shape[0:2]) + im_size_max = np.max(image.shape[0:2]) + resize = float(self.target_size) / float(im_size_min) + + # prevent bigger axis from being more than max_size + if np.round(resize * im_size_max) > self.max_size: + resize = float(self.max_size) / float(im_size_max) + resize = 1 if use_origin_size else resize + + # resize + if resize != 1: + image = cv2.resize(image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR) + + # convert to torch.tensor format + # image -= (104, 117, 123) + image = image.transpose(2, 0, 1) + image = torch.from_numpy(image).unsqueeze(0) + + return image, resize + + def detect_faces( + self, + image, + conf_threshold=0.8, + nms_threshold=0.4, + use_origin_size=True, + ): + """ + Params: + imgs: BGR 
image + """ + image, self.resize = self.transform(image, use_origin_size) + image = image.to(device) + if self.half_inference: + image = image.half() + image = image - self.mean_tensor + + loc, conf, landmarks, priors = self.__detect_faces(image) + + boxes = decode(loc.data.squeeze(0), priors.data, self.cfg['variance']) + boxes = boxes * self.scale / self.resize + boxes = boxes.cpu().numpy() + + scores = conf.squeeze(0).data.cpu().numpy()[:, 1] + + landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg['variance']) + landmarks = landmarks * self.scale1 / self.resize + landmarks = landmarks.cpu().numpy() + + # ignore low scores + inds = np.where(scores > conf_threshold)[0] + boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds] + + # sort + order = scores.argsort()[::-1] + boxes, landmarks, scores = boxes[order], landmarks[order], scores[order] + + # do NMS + bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) + keep = py_cpu_nms(bounding_boxes, nms_threshold) + bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep] + # self.t['forward_pass'].toc() + # print(self.t['forward_pass'].average_time) + # import sys + # sys.stdout.flush() + return np.concatenate((bounding_boxes, landmarks), axis=1) + + def __align_multi(self, image, boxes, landmarks, limit=None): + + if len(boxes) < 1: + return [], [] + + if limit: + boxes = boxes[:limit] + landmarks = landmarks[:limit] + + faces = [] + for landmark in landmarks: + facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)] + + warped_face = warp_and_crop_face(np.array(image), facial5points, self.reference, crop_size=(112, 112)) + faces.append(warped_face) + + return np.concatenate((boxes, landmarks), axis=1), faces + + def align_multi(self, img, conf_threshold=0.8, limit=None): + + rlt = self.detect_faces(img, conf_threshold=conf_threshold) + boxes, landmarks = rlt[:, 0:5], rlt[:, 5:] + + return self.__align_multi(img, boxes, landmarks, limit) + + # batched detection + def batched_transform(self, frames, use_origin_size): + """ + Arguments: + frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c], + type=np.float32, BGR format). + use_origin_size: whether to use origin size. 
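+        Returns:
+            frames: torch.Tensor of shape [n, c, h, w] (BGR, mean not yet
+                subtracted; that happens in batched_detect_faces).
+            resize: the scale factor that was applied (1 when use_origin_size).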
+ """ + from_PIL = True if isinstance(frames[0], Image.Image) else False + + # convert to opencv format + if from_PIL: + frames = [cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames] + frames = np.asarray(frames, dtype=np.float32) + + # testing scale + im_size_min = np.min(frames[0].shape[0:2]) + im_size_max = np.max(frames[0].shape[0:2]) + resize = float(self.target_size) / float(im_size_min) + + # prevent bigger axis from being more than max_size + if np.round(resize * im_size_max) > self.max_size: + resize = float(self.max_size) / float(im_size_max) + resize = 1 if use_origin_size else resize + + # resize + if resize != 1: + if not from_PIL: + frames = F.interpolate(frames, scale_factor=resize) + else: + frames = [ + cv2.resize(frame, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR) + for frame in frames + ] + + # convert to torch.tensor format + if not from_PIL: + frames = frames.transpose(1, 2).transpose(1, 3).contiguous() + else: + frames = frames.transpose((0, 3, 1, 2)) + frames = torch.from_numpy(frames) + + return frames, resize + + def batched_detect_faces(self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True): + """ + Arguments: + frames: a list of PIL.Image, or np.array(shape=[n, h, w, c], + type=np.uint8, BGR format). + conf_threshold: confidence threshold. + nms_threshold: nms threshold. + use_origin_size: whether to use origin size. + Returns: + final_bounding_boxes: list of np.array ([n_boxes, 5], + type=np.float32). + final_landmarks: list of np.array ([n_boxes, 10], type=np.float32). + """ + # self.t['forward_pass'].tic() + frames, self.resize = self.batched_transform(frames, use_origin_size) + frames = frames.to(device) + frames = frames - self.mean_tensor + + b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames) + + final_bounding_boxes, final_landmarks = [], [] + + # decode + priors = priors.unsqueeze(0) + b_loc = batched_decode(b_loc, priors, self.cfg['variance']) * self.scale / self.resize + b_landmarks = batched_decode_landm(b_landmarks, priors, self.cfg['variance']) * self.scale1 / self.resize + b_conf = b_conf[:, :, 1] + + # index for selection + b_indice = b_conf > conf_threshold + + # concat + b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float() + + for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice): + + # ignore low scores + pred, landm = pred[inds, :], landm[inds, :] + if pred.shape[0] == 0: + final_bounding_boxes.append(np.array([], dtype=np.float32)) + final_landmarks.append(np.array([], dtype=np.float32)) + continue + + # sort + # order = score.argsort(descending=True) + # box, landm, score = box[order], landm[order], score[order] + + # to CPU + bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy() + + # NMS + keep = py_cpu_nms(bounding_boxes, nms_threshold) + bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep] + + # append + final_bounding_boxes.append(bounding_boxes) + final_landmarks.append(landmarks) + # self.t['forward_pass'].toc(average=True) + # self.batch_time += self.t['forward_pass'].diff + # self.total_frame += len(frames) + # print(self.batch_time / self.total_frame) + + return final_bounding_boxes, final_landmarks diff --git a/scripts/facelib/detection/retinaface/retinaface_net.py b/scripts/facelib/detection/retinaface/retinaface_net.py new file mode 100644 index 0000000000000000000000000000000000000000..ab6aa82d3e9055a838f1f9076b12f05fdfc154d0 --- /dev/null +++ b/scripts/facelib/detection/retinaface/retinaface_net.py 
@@ -0,0 +1,196 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def conv_bn(inp, oup, stride=1, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True)) + + +def conv_bn_no_relu(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + ) + + +def conv_bn1X1(inp, oup, stride, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True)) + + +def conv_dw(inp, oup, stride, leaky=0.1): + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + ) + + +class SSH(nn.Module): + + def __init__(self, in_channel, out_channel): + super(SSH, self).__init__() + assert out_channel % 4 == 0 + leaky = 0 + if (out_channel <= 64): + leaky = 0.1 + self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) + + self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky) + self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) + + self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky) + self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) + + def forward(self, input): + conv3X3 = self.conv3X3(input) + + conv5X5_1 = self.conv5X5_1(input) + conv5X5 = self.conv5X5_2(conv5X5_1) + + conv7X7_2 = self.conv7X7_2(conv5X5_1) + conv7X7 = self.conv7x7_3(conv7X7_2) + + out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) + out = F.relu(out) + return out + + +class FPN(nn.Module): + + def __init__(self, in_channels_list, out_channels): + super(FPN, self).__init__() + leaky = 0 + if (out_channels <= 64): + leaky = 0.1 + self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky) + self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky) + self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky) + + self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) + self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) + + def forward(self, input): + # names = list(input.keys()) + # input = list(input.values()) + + output1 = self.output1(input[0]) + output2 = self.output2(input[1]) + output3 = self.output3(input[2]) + + up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest') + output2 = output2 + up3 + output2 = self.merge2(output2) + + up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest') + output1 = output1 + up2 + output1 = self.merge1(output1) + + out = [output1, output2, output3] + return out + + +class MobileNetV1(nn.Module): + + def __init__(self): + super(MobileNetV1, self).__init__() + self.stage1 = nn.Sequential( + conv_bn(3, 8, 2, leaky=0.1), # 3 + conv_dw(8, 16, 1), # 7 + conv_dw(16, 32, 2), # 11 + conv_dw(32, 32, 1), # 19 + conv_dw(32, 64, 2), # 27 + conv_dw(64, 64, 1), # 43 + ) + self.stage2 = nn.Sequential( + conv_dw(64, 128, 2), # 43 + 16 = 59 + conv_dw(128, 128, 1), # 59 + 32 = 91 + conv_dw(128, 128, 1), # 91 + 32 = 123 + conv_dw(128, 128, 1), # 123 + 32 = 155 + conv_dw(128, 128, 1), # 155 + 32 = 187 + conv_dw(128, 128, 1), # 187 + 32 = 219 + ) + self.stage3 = 
nn.Sequential( + conv_dw(128, 256, 2), # 219 +3 2 = 241 + conv_dw(256, 256, 1), # 241 + 64 = 301 + ) + self.avg = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(256, 1000) + + def forward(self, x): + x = self.stage1(x) + x = self.stage2(x) + x = self.stage3(x) + x = self.avg(x) + # x = self.model(x) + x = x.view(-1, 256) + x = self.fc(x) + return x + + +class ClassHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(ClassHead, self).__init__() + self.num_anchors = num_anchors + self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 2) + + +class BboxHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(BboxHead, self).__init__() + self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 4) + + +class LandmarkHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(LandmarkHead, self).__init__() + self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 10) + + +def make_class_head(fpn_num=3, inchannels=64, anchor_num=2): + classhead = nn.ModuleList() + for i in range(fpn_num): + classhead.append(ClassHead(inchannels, anchor_num)) + return classhead + + +def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2): + bboxhead = nn.ModuleList() + for i in range(fpn_num): + bboxhead.append(BboxHead(inchannels, anchor_num)) + return bboxhead + + +def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2): + landmarkhead = nn.ModuleList() + for i in range(fpn_num): + landmarkhead.append(LandmarkHead(inchannels, anchor_num)) + return landmarkhead diff --git a/scripts/facelib/detection/retinaface/retinaface_utils.py b/scripts/facelib/detection/retinaface/retinaface_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8c357757741c6d9bd7ce4d8ce740fefd51850fbf --- /dev/null +++ b/scripts/facelib/detection/retinaface/retinaface_utils.py @@ -0,0 +1,421 @@ +import numpy as np +import torch +import torchvision +from itertools import product as product +from math import ceil + + +class PriorBox(object): + + def __init__(self, cfg, image_size=None, phase='train'): + super(PriorBox, self).__init__() + self.min_sizes = cfg['min_sizes'] + self.steps = cfg['steps'] + self.clip = cfg['clip'] + self.image_size = image_size + self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps] + self.name = 's' + + def forward(self): + anchors = [] + for k, f in enumerate(self.feature_maps): + min_sizes = self.min_sizes[k] + for i, j in product(range(f[0]), range(f[1])): + for min_size in min_sizes: + s_kx = min_size / self.image_size[1] + s_ky = min_size / self.image_size[0] + dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]] + dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]] + for cy, cx in product(dense_cy, dense_cx): + anchors += [cx, cy, s_kx, s_ky] + + # back to torch land + output = torch.Tensor(anchors).view(-1, 4) + if self.clip: + output.clamp_(max=1, min=0) + return output + + +def py_cpu_nms(dets, 
thresh):
+    """Pure Python NMS baseline."""
+    keep = torchvision.ops.nms(
+        boxes=torch.Tensor(dets[:, :4]),
+        scores=torch.Tensor(dets[:, 4]),
+        iou_threshold=thresh,
+    )
+
+    return list(keep)
+
+
+def point_form(boxes):
+    """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
+    representation for comparison to point form ground truth data.
+    Args:
+        boxes: (tensor) center-size default boxes from priorbox layers.
+    Return:
+        boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
+    """
+    return torch.cat(
+        (
+            boxes[:, :2] - boxes[:, 2:] / 2,  # xmin, ymin
+            boxes[:, :2] + boxes[:, 2:] / 2),
+        1)  # xmax, ymax
+
+
+def center_size(boxes):
+    """ Convert prior_boxes to (cx, cy, w, h)
+    representation for comparison to center-size form ground truth data.
+    Args:
+        boxes: (tensor) point_form boxes
+    Return:
+        boxes: (tensor) Converted (cx, cy, w, h) form of boxes.
+    """
+    # wrap both tensors in a tuple so torch.cat receives (tensors, dim)
+    return torch.cat(
+        ((boxes[:, 2:] + boxes[:, :2]) / 2,  # cx, cy
+         boxes[:, 2:] - boxes[:, :2]),  # w, h
+        1)
+
+
+def intersect(box_a, box_b):
+    """ We resize both tensors to [A,B,2] without new malloc:
+    [A,2] -> [A,1,2] -> [A,B,2]
+    [B,2] -> [1,B,2] -> [A,B,2]
+    Then we compute the area of intersect between box_a and box_b.
+    Args:
+      box_a: (tensor) bounding boxes, Shape: [A,4].
+      box_b: (tensor) bounding boxes, Shape: [B,4].
+    Return:
+      (tensor) intersection area, Shape: [A,B].
+    """
+    A = box_a.size(0)
+    B = box_b.size(0)
+    max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
+    min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
+    inter = torch.clamp((max_xy - min_xy), min=0)
+    return inter[:, :, 0] * inter[:, :, 1]
+
+
+def jaccard(box_a, box_b):
+    """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
+    is simply the intersection over union of two boxes. Here we operate on
+    ground truth boxes and default boxes.
+    E.g.:
+        A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
+    Args:
+        box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
+        box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
+    Return:
+        jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
+    """
+    inter = intersect(box_a, box_b)
+    area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
+    area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]
+    union = area_a + area_b - inter
+    return inter / union  # [A,B]
+
+
+def matrix_iou(a, b):
+    """
+    return iou of a and b, numpy version for data augmentation
+    """
+    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+    area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
+    return area_i / (area_a[:, np.newaxis] + area_b - area_i)
+
+
+def matrix_iof(a, b):
+    """
+    return iof of a and b, numpy version for data augmentation
+    """
+    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+    return area_i / np.maximum(area_a[:, np.newaxis], 1)
+
+
+def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
+    """Match each prior box with the ground truth box of the highest jaccard
+    overlap, encode the bounding boxes, then return the matched indices
+    corresponding to both confidence and location preds.
+    Args:
+        threshold: (float) The overlap threshold used when matching boxes.
+        truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
+        priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
+        variances: (tensor) Variances corresponding to each prior coord,
+            Shape: [num_priors, 4].
+        labels: (tensor) All the class labels for the image, Shape: [num_obj].
+        landms: (tensor) Ground truth landms, Shape [num_obj, 10].
+        loc_t: (tensor) Tensor to be filled w/ encoded location targets.
+        conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
+        landm_t: (tensor) Tensor to be filled w/ encoded landm targets.
+        idx: (int) current batch index
+    Return:
+        The matched indices corresponding to 1)location 2)confidence
+        3)landm preds.
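+        Note: the encoded targets are written in place into loc_t, conf_t and
+        landm_t at batch index idx; the function itself returns nothing.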
+ """ + # jaccard index + overlaps = jaccard(truths, point_form(priors)) + # (Bipartite Matching) + # [1,num_objects] best prior for each ground truth + best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) + + # ignore hard gt + valid_gt_idx = best_prior_overlap[:, 0] >= 0.2 + best_prior_idx_filter = best_prior_idx[valid_gt_idx, :] + if best_prior_idx_filter.shape[0] <= 0: + loc_t[idx] = 0 + conf_t[idx] = 0 + return + + # [1,num_priors] best ground truth for each prior + best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) + best_truth_idx.squeeze_(0) + best_truth_overlap.squeeze_(0) + best_prior_idx.squeeze_(1) + best_prior_idx_filter.squeeze_(1) + best_prior_overlap.squeeze_(1) + best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior + # TODO refactor: index best_prior_idx with long tensor + # ensure every gt matches with its prior of max overlap + for j in range(best_prior_idx.size(0)): # 判别此anchor是预测哪一个boxes + best_truth_idx[best_prior_idx[j]] = j + matches = truths[best_truth_idx] # Shape: [num_priors,4] 此处为每一个anchor对应的bbox取出来 + conf = labels[best_truth_idx] # Shape: [num_priors] 此处为每一个anchor对应的label取出来 + conf[best_truth_overlap < threshold] = 0 # label as background overlap<0.35的全部作为负样本 + loc = encode(matches, priors, variances) + + matches_landm = landms[best_truth_idx] + landm = encode_landm(matches_landm, priors, variances) + loc_t[idx] = loc # [num_priors,4] encoded offsets to learn + conf_t[idx] = conf # [num_priors] top class label for each prior + landm_t[idx] = landm + + +def encode(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 4]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + encoded boxes (tensor), Shape: [num_priors, 4] + """ + + # dist b/t match center and prior's center + g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, 2:]) + # match wh / prior wh + g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] + g_wh = torch.log(g_wh) / variances[1] + # return target for smooth_l1_loss + return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] + + +def encode_landm(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 10]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. 
+ variances: (list[float]) Variances of priorboxes + Return: + encoded landm (tensor), Shape: [num_priors, 10] + """ + + # dist b/t match center and prior's center + matched = torch.reshape(matched, (matched.size(0), 5, 2)) + priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2) + g_cxcy = matched[:, :, :2] - priors[:, :, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, :, 2:]) + # g_cxcy /= priors[:, :, 2:] + g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1) + # return target for smooth_l1_loss + return g_cxcy + + +# Adapted from https://github.com/Hakuyume/chainer-ssd +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def decode_landm(pre, priors, variances): + """Decode landm from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + pre (tensor): landm predictions for loc layers, + Shape: [num_priors,10] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded landm predictions + """ + tmp = ( + priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:], + ) + landms = torch.cat(tmp, dim=1) + return landms + + +def batched_decode(b_loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + b_loc (tensor): location predictions for loc layers, + Shape: [num_batches,num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [1,num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + boxes = ( + priors[:, :, :2] + b_loc[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, 2:] * torch.exp(b_loc[:, :, 2:] * variances[1]), + ) + boxes = torch.cat(boxes, dim=2) + + boxes[:, :, :2] -= boxes[:, :, 2:] / 2 + boxes[:, :, 2:] += boxes[:, :, :2] + return boxes + + +def batched_decode_landm(pre, priors, variances): + """Decode landm from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + pre (tensor): landm predictions for loc layers, + Shape: [num_batches,num_priors,10] + priors (tensor): Prior boxes in center-offset form. + Shape: [1,num_priors,4]. 
+ variances: (list[float]) Variances of priorboxes + Return: + decoded landm predictions + """ + landms = ( + priors[:, :, :2] + pre[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 2:4] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 4:6] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 6:8] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 8:10] * variances[0] * priors[:, :, 2:], + ) + landms = torch.cat(landms, dim=2) + return landms + + +def log_sum_exp(x): + """Utility function for computing log_sum_exp while determining + This will be used to determine unaveraged confidence loss across + all examples in a batch. + Args: + x (Variable(tensor)): conf_preds from conf layers + """ + x_max = x.data.max() + return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max + + +# Original author: Francisco Massa: +# https://github.com/fmassa/object-detection.torch +# Ported to PyTorch by Max deGroot (02/01/2017) +def nms(boxes, scores, overlap=0.5, top_k=200): + """Apply non-maximum suppression at test time to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. + scores: (tensor) The class predscores for the img, Shape:[num_priors]. + overlap: (float) The overlap thresh for suppressing unnecessary boxes. + top_k: (int) The Maximum number of box preds to consider. + Return: + The indices of the kept boxes with respect to num_priors. + """ + + keep = torch.Tensor(scores.size(0)).fill_(0).long() + if boxes.numel() == 0: + return keep + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + area = torch.mul(x2 - x1, y2 - y1) + v, idx = scores.sort(0) # sort in ascending order + # I = I[v >= 0.01] + idx = idx[-top_k:] # indices of the top-k largest vals + xx1 = boxes.new() + yy1 = boxes.new() + xx2 = boxes.new() + yy2 = boxes.new() + w = boxes.new() + h = boxes.new() + + # keep = torch.Tensor() + count = 0 + while idx.numel() > 0: + i = idx[-1] # index of current largest val + # keep.append(i) + keep[count] = i + count += 1 + if idx.size(0) == 1: + break + idx = idx[:-1] # remove kept element from view + # load bboxes of next highest vals + torch.index_select(x1, 0, idx, out=xx1) + torch.index_select(y1, 0, idx, out=yy1) + torch.index_select(x2, 0, idx, out=xx2) + torch.index_select(y2, 0, idx, out=yy2) + # store element-wise max with next highest score + xx1 = torch.clamp(xx1, min=x1[i]) + yy1 = torch.clamp(yy1, min=y1[i]) + xx2 = torch.clamp(xx2, max=x2[i]) + yy2 = torch.clamp(yy2, max=y2[i]) + w.resize_as_(xx2) + h.resize_as_(yy2) + w = xx2 - xx1 + h = yy2 - yy1 + # check sizes of xx1 and xx2.. 
after each iteration + w = torch.clamp(w, min=0.0) + h = torch.clamp(h, min=0.0) + inter = w * h + # IoU = i / (area(a) + area(b) - i) + rem_areas = torch.index_select(area, 0, idx) # load remaining areas) + union = (rem_areas - inter) + area[i] + IoU = inter / union # store result in iou + # keep only elements with an IoU <= overlap + idx = idx[IoU.le(overlap)] + return keep, count diff --git a/scripts/facelib/detection/yolov5face/__init__.py b/scripts/facelib/detection/yolov5face/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/facelib/detection/yolov5face/__pycache__/__init__.cpython-310.pyc b/scripts/facelib/detection/yolov5face/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf481e4d44340c11bacfbb167210078e50f0a072 Binary files /dev/null and b/scripts/facelib/detection/yolov5face/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/facelib/detection/yolov5face/__pycache__/face_detector.cpython-310.pyc b/scripts/facelib/detection/yolov5face/__pycache__/face_detector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ab3d2682f658c0dd0f4b15c2f21cc8fd5bf1827 Binary files /dev/null and b/scripts/facelib/detection/yolov5face/__pycache__/face_detector.cpython-310.pyc differ diff --git a/scripts/facelib/detection/yolov5face/face_detector.py b/scripts/facelib/detection/yolov5face/face_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..e2e1f6bee4c38c24c47a5004731b6bd0b2f14805 --- /dev/null +++ b/scripts/facelib/detection/yolov5face/face_detector.py @@ -0,0 +1,141 @@ +import cv2 +import copy +import re +import torch +import numpy as np + +from pathlib import Path +from scripts.facelib.detection.yolov5face.models.yolo import Model +from scripts.facelib.detection.yolov5face.utils.datasets import letterbox +from scripts.facelib.detection.yolov5face.utils.general import ( + check_img_size, + non_max_suppression_face, + scale_coords, + scale_coords_landmarks, +) + +# IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 9) +IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ + torch.__version__)[0][:3])] >= [1, 9, 0] + + +def isListempty(inList): + if isinstance(inList, list): # Is a list + return all(map(isListempty, inList)) + return False # Not a list + +class YoloDetector: + def __init__( + self, + config_name, + min_face=10, + target_size=None, + device='cuda', + ): + """ + config_name: name of .yaml config with network configuration from models/ folder. + min_face : minimal face size in pixels. + target_size : target size of smaller image axis (choose lower for faster work). e.g. 480, 720, 1080. + None for original resolution. + """ + self._class_path = Path(__file__).parent.absolute() + self.target_size = target_size + self.min_face = min_face + self.detector = Model(cfg=config_name) + self.device = device + + + def _preprocess(self, imgs): + """ + Preprocessing image before passing through the network. Resize and conversion to torch tensor. 
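+        Returns a float tensor of shape [n, 3, h, w] with values scaled to [0, 1].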
+ """ + pp_imgs = [] + for img in imgs: + h0, w0 = img.shape[:2] # orig hw + if self.target_size: + r = self.target_size / min(h0, w0) # resize image to img_size + if r < 1: + img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR) + + imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max()) # check img_size + img = letterbox(img, new_shape=imgsz)[0] + pp_imgs.append(img) + pp_imgs = np.array(pp_imgs) + pp_imgs = pp_imgs.transpose(0, 3, 1, 2) + pp_imgs = torch.from_numpy(pp_imgs).to(self.device) + pp_imgs = pp_imgs.float() # uint8 to fp16/32 + return pp_imgs / 255.0 # 0 - 255 to 0.0 - 1.0 + + def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres): + """ + Postprocessing of raw pytorch model output. + Returns: + bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. + points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). + """ + bboxes = [[] for _ in range(len(origimgs))] + landmarks = [[] for _ in range(len(origimgs))] + + pred = non_max_suppression_face(pred, conf_thres, iou_thres) + + for image_id, origimg in enumerate(origimgs): + img_shape = origimg.shape + image_height, image_width = img_shape[:2] + gn = torch.tensor(img_shape)[[1, 0, 1, 0]] # normalization gain whwh + gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]] # normalization gain landmarks + det = pred[image_id].cpu() + scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round() + scale_coords_landmarks(imgs[image_id].shape[1:], det[:, 5:15], img_shape).round() + + for j in range(det.size()[0]): + box = (det[j, :4].view(1, 4) / gn).view(-1).tolist() + box = list( + map(int, [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height]) + ) + if box[3] - box[1] < self.min_face: + continue + lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist() + lm = list(map(int, [i * image_width if j % 2 == 0 else i * image_height for j, i in enumerate(lm)])) + lm = [lm[i : i + 2] for i in range(0, len(lm), 2)] + bboxes[image_id].append(box) + landmarks[image_id].append(lm) + return bboxes, landmarks + + def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5): + """ + Get bbox coordinates and keypoints of faces on original image. + Params: + imgs: image or list of images to detect faces on with BGR order (convert to RGB order for inference) + conf_thres: confidence threshold for each prediction + iou_thres: threshold for NMS (filter of intersecting bboxes) + Returns: + bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. + points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). 
+        """
+        # Pass input images through face detector
+        images = imgs if isinstance(imgs, list) else [imgs]
+        images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images]
+        origimgs = copy.deepcopy(images)
+
+        images = self._preprocess(images)
+
+        if IS_HIGH_VERSION:
+            with torch.inference_mode():  # for pytorch>=1.9
+                pred = self.detector(images)[0]
+        else:
+            with torch.no_grad():  # for pytorch<1.9
+                pred = self.detector(images)[0]
+
+        bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres)
+
+        if not isListempty(points):
+            bboxes = np.array(bboxes).reshape(-1,4)
+            points = np.array(points).reshape(-1,10)
+            padding = bboxes[:,0].reshape(-1,1)  # filler column so each row matches the (x1, y1, x2, y2, pad, 10 landmark coords) layout
+            return np.concatenate((bboxes, padding, points), axis=1)
+        else:
+            return None
+
+    def __call__(self, *args):
+        return self.detect_faces(*args)  # the class defines no predict(), so dispatch to detect_faces
diff --git a/scripts/facelib/detection/yolov5face/models/__init__.py b/scripts/facelib/detection/yolov5face/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/scripts/facelib/detection/yolov5face/models/__pycache__/__init__.cpython-310.pyc b/scripts/facelib/detection/yolov5face/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e322e51899dc3f88d9ba9757da0195b357fb5225
Binary files /dev/null and b/scripts/facelib/detection/yolov5face/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/scripts/facelib/detection/yolov5face/models/__pycache__/common.cpython-310.pyc b/scripts/facelib/detection/yolov5face/models/__pycache__/common.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..84ccbd58307c63c1ec6d6abf1856bd5f1b0bc5a2
Binary files /dev/null and b/scripts/facelib/detection/yolov5face/models/__pycache__/common.cpython-310.pyc differ
diff --git a/scripts/facelib/detection/yolov5face/models/__pycache__/experimental.cpython-310.pyc b/scripts/facelib/detection/yolov5face/models/__pycache__/experimental.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bab479f205e3197af0acbc4272baf05140bd9f23
Binary files /dev/null and b/scripts/facelib/detection/yolov5face/models/__pycache__/experimental.cpython-310.pyc differ
diff --git a/scripts/facelib/detection/yolov5face/models/__pycache__/yolo.cpython-310.pyc b/scripts/facelib/detection/yolov5face/models/__pycache__/yolo.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc0214a3ee525a1c28422ec8891c7a59b1548cb3
Binary files /dev/null and b/scripts/facelib/detection/yolov5face/models/__pycache__/yolo.cpython-310.pyc differ
diff --git a/scripts/facelib/detection/yolov5face/models/common.py b/scripts/facelib/detection/yolov5face/models/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..117fccb5f0acc705cc5fa5f83f4116083df1d54d
--- /dev/null
+++ b/scripts/facelib/detection/yolov5face/models/common.py
@@ -0,0 +1,299 @@
+# This file contains modules common to various models
+
+import math
+
+import numpy as np
+import torch
+from torch import nn
+
+from scripts.facelib.detection.yolov5face.utils.datasets import letterbox
+from scripts.facelib.detection.yolov5face.utils.general import (
+    make_divisible,
+    non_max_suppression,
+    scale_coords,
+    xyxy2xywh,
+)
+
+
+def autopad(k, p=None):  # kernel, padding
+    # Pad to 'same'
+    if p is None:
+        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
+    return p
+
+
+def channel_shuffle(x,
groups): + batchsize, num_channels, height, width = x.data.size() + channels_per_group = torch.div(num_channels, groups, rounding_mode="trunc") + + # reshape + x = x.view(batchsize, groups, channels_per_group, height, width) + x = torch.transpose(x, 1, 2).contiguous() + + # flatten + return x.view(batchsize, -1, height, width) + + +def DWConv(c1, c2, k=1, s=1, act=True): + # Depthwise convolution + return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act) + + +class Conv(nn.Module): + # Standard convolution + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + super().__init__() + self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) + self.bn = nn.BatchNorm2d(c2) + self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) + + def forward(self, x): + return self.act(self.bn(self.conv(x))) + + def fuseforward(self, x): + return self.act(self.conv(x)) + + +class StemBlock(nn.Module): + def __init__(self, c1, c2, k=3, s=2, p=None, g=1, act=True): + super().__init__() + self.stem_1 = Conv(c1, c2, k, s, p, g, act) + self.stem_2a = Conv(c2, c2 // 2, 1, 1, 0) + self.stem_2b = Conv(c2 // 2, c2, 3, 2, 1) + self.stem_2p = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True) + self.stem_3 = Conv(c2 * 2, c2, 1, 1, 0) + + def forward(self, x): + stem_1_out = self.stem_1(x) + stem_2a_out = self.stem_2a(stem_1_out) + stem_2b_out = self.stem_2b(stem_2a_out) + stem_2p_out = self.stem_2p(stem_1_out) + return self.stem_3(torch.cat((stem_2b_out, stem_2p_out), 1)) + + +class Bottleneck(nn.Module): + # Standard bottleneck + def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_, c2, 3, 1, g=g) + self.add = shortcut and c1 == c2 + + def forward(self, x): + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class BottleneckCSP(nn.Module): + # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False) + self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False) + self.cv4 = Conv(2 * c_, c2, 1, 1) + self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3) + self.act = nn.LeakyReLU(0.1, inplace=True) + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) + + def forward(self, x): + y1 = self.cv3(self.m(self.cv1(x))) + y2 = self.cv2(x) + return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1)))) + + +class C3(nn.Module): + # CSP Bottleneck with 3 convolutions + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c1, c_, 1, 1) + self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2) + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) + + def forward(self, x): + return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1)) + + +class ShuffleV2Block(nn.Module): + def __init__(self, inp, oup, stride): + super().__init__() + + if not 1 <= stride <= 3: + raise ValueError("illegal stride value") + self.stride = stride + + branch_features = oup // 
2 + + if self.stride > 1: + self.branch1 = nn.Sequential( + self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), + nn.BatchNorm2d(inp), + nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(branch_features), + nn.SiLU(), + ) + else: + self.branch1 = nn.Sequential() + + self.branch2 = nn.Sequential( + nn.Conv2d( + inp if (self.stride > 1) else branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), + nn.BatchNorm2d(branch_features), + nn.SiLU(), + self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), + nn.BatchNorm2d(branch_features), + nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(branch_features), + nn.SiLU(), + ) + + @staticmethod + def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): + return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) + + def forward(self, x): + if self.stride == 1: + x1, x2 = x.chunk(2, dim=1) + out = torch.cat((x1, self.branch2(x2)), dim=1) + else: + out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) + out = channel_shuffle(out, 2) + return out + + +class SPP(nn.Module): + # Spatial pyramid pooling layer used in YOLOv3-SPP + def __init__(self, c1, c2, k=(5, 9, 13)): + super().__init__() + c_ = c1 // 2 # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1) + self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) + + def forward(self, x): + x = self.cv1(x) + return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1)) + + +class Focus(nn.Module): + # Focus wh information into c-space + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + super().__init__() + self.conv = Conv(c1 * 4, c2, k, s, p, g, act) + + def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2) + return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)) + + +class Concat(nn.Module): + # Concatenate a list of tensors along dimension + def __init__(self, dimension=1): + super().__init__() + self.d = dimension + + def forward(self, x): + return torch.cat(x, self.d) + + +class NMS(nn.Module): + # Non-Maximum Suppression (NMS) module + conf = 0.25 # confidence threshold + iou = 0.45 # IoU threshold + classes = None # (optional list) filter by class + + def forward(self, x): + return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) + + +class AutoShape(nn.Module): + # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS + img_size = 640 # inference size (pixels) + conf = 0.25 # NMS confidence threshold + iou = 0.45 # NMS IoU threshold + classes = None # (optional list) filter by class + + def __init__(self, model): + super().__init__() + self.model = model.eval() + + def autoshape(self): + print("autoShape already enabled, skipping... ") # model already converted to model.autoshape() + return self + + def forward(self, imgs, size=640, augment=False, profile=False): + # Inference from various sources. 
For height=720, width=1280, RGB images example inputs are: + # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3) + # PIL: = Image.open('image.jpg') # HWC x(720,1280,3) + # numpy: = np.zeros((720,1280,3)) # HWC + # torch: = torch.zeros(16,3,720,1280) # BCHW + # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images + + p = next(self.model.parameters()) # for device and type + if isinstance(imgs, torch.Tensor): # torch + return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference + + # Pre-process + n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images + shape0, shape1 = [], [] # image and inference shapes + for i, im in enumerate(imgs): + im = np.array(im) # to numpy + if im.shape[0] < 5: # image in CHW + im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1) + im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3) # enforce 3ch input + s = im.shape[:2] # HWC + shape0.append(s) # image shape + g = size / max(s) # gain + shape1.append([y * g for y in s]) + imgs[i] = im # update + shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape + x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad + x = np.stack(x, 0) if n > 1 else x[0][None] # stack + x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW + x = torch.from_numpy(x).to(p.device).type_as(p) / 255.0 # uint8 to fp16/32 + + # Inference + with torch.no_grad(): + y = self.model(x, augment, profile)[0] # forward + y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS + + # Post-process + for i in range(n): + scale_coords(shape1, y[i][:, :4], shape0[i]) + + return Detections(imgs, y, self.names) + + +class Detections: + # detections class for YOLOv5 inference results + def __init__(self, imgs, pred, names=None): + super().__init__() + d = pred[0].device # device + gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1.0, 1.0], device=d) for im in imgs] # normalizations + self.imgs = imgs # list of images as numpy arrays + self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls) + self.names = names # class names + self.xyxy = pred # xyxy pixels + self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels + self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized + self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized + self.n = len(self.pred) + + def __len__(self): + return self.n + + def tolist(self): + # return a list of Detections objects, i.e. 
'for result in results.tolist():' + x = [Detections([self.imgs[i]], [self.pred[i]], self.names) for i in range(self.n)] + for d in x: + for k in ["imgs", "pred", "xyxy", "xyxyn", "xywh", "xywhn"]: + setattr(d, k, getattr(d, k)[0]) # pop out of list + return x diff --git a/scripts/facelib/detection/yolov5face/models/experimental.py b/scripts/facelib/detection/yolov5face/models/experimental.py new file mode 100644 index 0000000000000000000000000000000000000000..828c282d0347e9302cb511569e46c8a157828c41 --- /dev/null +++ b/scripts/facelib/detection/yolov5face/models/experimental.py @@ -0,0 +1,45 @@ +# # This file contains experimental modules + +import numpy as np +import torch +from torch import nn + +from scripts.facelib.detection.yolov5face.models.common import Conv + + +class CrossConv(nn.Module): + # Cross Convolution Downsample + def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): + # ch_in, ch_out, kernel, stride, groups, expansion, shortcut + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, (1, k), (1, s)) + self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g) + self.add = shortcut and c1 == c2 + + def forward(self, x): + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class MixConv2d(nn.Module): + # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595 + def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): + super().__init__() + groups = len(k) + if equal_ch: # equal c_ per group + i = torch.linspace(0, groups - 1e-6, c2).floor() # c2 indices + c_ = [(i == g).sum() for g in range(groups)] # intermediate channels + else: # equal weight.numel() per group + b = [c2] + [0] * groups + a = np.eye(groups + 1, groups, k=-1) + a -= np.roll(a, 1, axis=1) + a *= np.array(k) ** 2 + a[0] = 1 + c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b + + self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)]) + self.bn = nn.BatchNorm2d(c2) + self.act = nn.LeakyReLU(0.1, inplace=True) + + def forward(self, x): + return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) diff --git a/scripts/facelib/detection/yolov5face/models/yolo.py b/scripts/facelib/detection/yolov5face/models/yolo.py new file mode 100644 index 0000000000000000000000000000000000000000..e37c4d1f639a46b22d15e3401505cea6abdce8c6 --- /dev/null +++ b/scripts/facelib/detection/yolov5face/models/yolo.py @@ -0,0 +1,235 @@ +import math +from copy import deepcopy +from pathlib import Path + +import torch +import yaml # for torch hub +from torch import nn + +from scripts.facelib.detection.yolov5face.models.common import ( + C3, + NMS, + SPP, + AutoShape, + Bottleneck, + BottleneckCSP, + Concat, + Conv, + DWConv, + Focus, + ShuffleV2Block, + StemBlock, +) +from scripts.facelib.detection.yolov5face.models.experimental import CrossConv, MixConv2d +from scripts.facelib.detection.yolov5face.utils.autoanchor import check_anchor_order +from scripts.facelib.detection.yolov5face.utils.general import make_divisible +from scripts.facelib.detection.yolov5face.utils.torch_utils import copy_attr, fuse_conv_and_bn + + +class Detect(nn.Module): + stride = None # strides computed during build + export = False # onnx export + + def __init__(self, nc=80, anchors=(), ch=()): # detection layer + super().__init__() + self.nc = nc # number of classes + self.no = nc + 5 + 10 # number of outputs per anchor + + self.nl = len(anchors) # number of detection layers + self.na = len(anchors[0]) // 2 # number 
of anchors + self.grid = [torch.zeros(1)] * self.nl # init grid + a = torch.tensor(anchors).float().view(self.nl, -1, 2) + self.register_buffer("anchors", a) # shape(nl,na,2) + self.register_buffer("anchor_grid", a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2) + self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv + + def forward(self, x): + z = [] # inference output + if self.export: + for i in range(self.nl): + x[i] = self.m[i](x[i]) + return x + for i in range(self.nl): + x[i] = self.m[i](x[i]) # conv + bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85) + x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() + + if not self.training: # inference + if self.grid[i].shape[2:4] != x[i].shape[2:4]: + self.grid[i] = self._make_grid(nx, ny).to(x[i].device) + + y = torch.full_like(x[i], 0) + y[..., [0, 1, 2, 3, 4, 15]] = x[i][..., [0, 1, 2, 3, 4, 15]].sigmoid() + y[..., 5:15] = x[i][..., 5:15] + + y[..., 0:2] = (y[..., 0:2] * 2.0 - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy + y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + + y[..., 5:7] = ( + y[..., 5:7] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x1 y1 + y[..., 7:9] = ( + y[..., 7:9] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x2 y2 + y[..., 9:11] = ( + y[..., 9:11] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x3 y3 + y[..., 11:13] = ( + y[..., 11:13] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x4 y4 + y[..., 13:15] = ( + y[..., 13:15] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x5 y5 + + z.append(y.view(bs, -1, self.no)) + + return x if self.training else (torch.cat(z, 1), x) + + @staticmethod + def _make_grid(nx=20, ny=20): + # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing="ij") # for pytorch>=1.10 + yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)]) + return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float() + + +class Model(nn.Module): + def __init__(self, cfg="yolov5s.yaml", ch=3, nc=None): # model, input channels, number of classes + super().__init__() + self.yaml_file = Path(cfg).name + with Path(cfg).open(encoding="utf8") as f: + self.yaml = yaml.safe_load(f) # model dict + + # Define model + ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels + if nc and nc != self.yaml["nc"]: + self.yaml["nc"] = nc # override yaml value + + self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist + self.names = [str(i) for i in range(self.yaml["nc"])] # default names + + # Build strides, anchors + m = self.model[-1] # Detect() + if isinstance(m, Detect): + s = 128 # 2x min stride + m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward + m.anchors /= m.stride.view(-1, 1, 1) + check_anchor_order(m) + self.stride = m.stride + self._initialize_biases() # only run once + + def forward(self, x): + return self.forward_once(x) # single-scale inference, train + + def forward_once(self, x): + y = [] # outputs + for m in self.model: + if m.f != -1: # if not from previous layer + x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers + + x = m(x) # run + y.append(x if m.i in self.save else None) # save output + + return x + + def _initialize_biases(self, cf=None): # 
initialize biases into Detect(), cf is class frequency + # https://arxiv.org/abs/1708.02002 section 3.3 + m = self.model[-1] # Detect() module + for mi, s in zip(m.m, m.stride): # from + b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85) + b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image) + b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls + mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) + + def _print_biases(self): + m = self.model[-1] # Detect() module + for mi in m.m: # from + b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85) + print(("%6g Conv2d.bias:" + "%10.3g" * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean())) + + def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers + print("Fusing layers... ") + for m in self.model.modules(): + if isinstance(m, Conv) and hasattr(m, "bn"): + m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv + delattr(m, "bn") # remove batchnorm + m.forward = m.fuseforward # update forward + elif type(m) is nn.Upsample: + m.recompute_scale_factor = None # torch 1.11.0 compatibility + return self + + def nms(self, mode=True): # add or remove NMS module + present = isinstance(self.model[-1], NMS) # last layer is NMS + if mode and not present: + print("Adding NMS... ") + m = NMS() # module + m.f = -1 # from + m.i = self.model[-1].i + 1 # index + self.model.add_module(name=str(m.i), module=m) # add + self.eval() + elif not mode and present: + print("Removing NMS... ") + self.model = self.model[:-1] # remove + return self + + def autoshape(self): # add autoShape module + print("Adding autoShape... ") + m = AutoShape(self) # wrap model + copy_attr(m, self, include=("yaml", "nc", "hyp", "names", "stride"), exclude=()) # copy attributes + return m + + +def parse_model(d, ch): # model_dict, input_channels(3) + anchors, nc, gd, gw = d["anchors"], d["nc"], d["depth_multiple"], d["width_multiple"] + na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors + no = na * (nc + 5) # number of outputs = anchors * (classes + 5) + + layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out + for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args + m = eval(m) if isinstance(m, str) else m # eval strings + for j, a in enumerate(args): + try: + args[j] = eval(a) if isinstance(a, str) else a # eval strings + except: + pass + + n = max(round(n * gd), 1) if n > 1 else n # depth gain + if m in [ + Conv, + Bottleneck, + SPP, + DWConv, + MixConv2d, + Focus, + CrossConv, + BottleneckCSP, + C3, + ShuffleV2Block, + StemBlock, + ]: + c1, c2 = ch[f], args[0] + + c2 = make_divisible(c2 * gw, 8) if c2 != no else c2 + + args = [c1, c2, *args[1:]] + if m in [BottleneckCSP, C3]: + args.insert(2, n) + n = 1 + elif m is nn.BatchNorm2d: + args = [ch[f]] + elif m is Concat: + c2 = sum(ch[-1 if x == -1 else x + 1] for x in f) + elif m is Detect: + args.append([ch[x + 1] for x in f]) + if isinstance(args[1], int): # number of anchors + args[1] = [list(range(args[1] * 2))] * len(f) + else: + c2 = ch[f] + + m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module + t = str(m)[8:-2].replace("__main__.", "") # module type + np = sum(x.numel() for x in m_.parameters()) # number params + m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params + save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist + 
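        # NOTE: 'save' records the indices of layers whose outputs are reused later (the Concat/Detect 'from' fields)
+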
layers.append(m_) + ch.append(c2) + return nn.Sequential(*layers), sorted(save) diff --git a/scripts/facelib/detection/yolov5face/models/yolov5l.yaml b/scripts/facelib/detection/yolov5face/models/yolov5l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0532b0e22fa7f59349b178146ffddcfdb368aba6 --- /dev/null +++ b/scripts/facelib/detection/yolov5face/models/yolov5l.yaml @@ -0,0 +1,47 @@ +# parameters +nc: 1 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: + - [4,5, 8,10, 13,16] # P3/8 + - [23,29, 43,55, 73,105] # P4/16 + - [146,217, 231,300, 335,433] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, StemBlock, [64, 3, 2]], # 0-P1/2 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 2-P3/8 + [-1, 9, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 4-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 6-P5/32 + [-1, 1, SPP, [1024, [3,5,7]]], + [-1, 3, C3, [1024, False]], # 8 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 5], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 12 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 3], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 16 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 13], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 19 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 9], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 22 (P5/32-large) + + [[16, 19, 22], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] \ No newline at end of file diff --git a/scripts/facelib/detection/yolov5face/models/yolov5n.yaml b/scripts/facelib/detection/yolov5face/models/yolov5n.yaml new file mode 100644 index 0000000000000000000000000000000000000000..caba6bed674aa2213b110f19e04eb352ffbeaf1e --- /dev/null +++ b/scripts/facelib/detection/yolov5face/models/yolov5n.yaml @@ -0,0 +1,45 @@ +# parameters +nc: 1 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: + - [4,5, 8,10, 13,16] # P3/8 + - [23,29, 43,55, 73,105] # P4/16 + - [146,217, 231,300, 335,433] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, StemBlock, [32, 3, 2]], # 0-P2/4 + [-1, 1, ShuffleV2Block, [128, 2]], # 1-P3/8 + [-1, 3, ShuffleV2Block, [128, 1]], # 2 + [-1, 1, ShuffleV2Block, [256, 2]], # 3-P4/16 + [-1, 7, ShuffleV2Block, [256, 1]], # 4 + [-1, 1, ShuffleV2Block, [512, 2]], # 5-P5/32 + [-1, 3, ShuffleV2Block, [512, 1]], # 6 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [128, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P4 + [-1, 1, C3, [128, False]], # 10 + + [-1, 1, Conv, [128, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 2], 1, Concat, [1]], # cat backbone P3 + [-1, 1, C3, [128, False]], # 14 (P3/8-small) + + [-1, 1, Conv, [128, 3, 2]], + [[-1, 11], 1, Concat, [1]], # cat head P4 + [-1, 1, C3, [128, False]], # 17 (P4/16-medium) + + [-1, 1, Conv, [128, 3, 2]], + [[-1, 7], 1, Concat, [1]], # cat head P5 + [-1, 1, C3, [128, False]], # 20 (P5/32-large) + + [[14, 17, 20], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/scripts/facelib/detection/yolov5face/utils/__init__.py b/scripts/facelib/detection/yolov5face/utils/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/scripts/facelib/detection/yolov5face/utils/__pycache__/__init__.cpython-310.pyc b/scripts/facelib/detection/yolov5face/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df1f1e91f0754dc7f753565c28f566b509d2ae3a
Binary files /dev/null and b/scripts/facelib/detection/yolov5face/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/scripts/facelib/detection/yolov5face/utils/__pycache__/autoanchor.cpython-310.pyc b/scripts/facelib/detection/yolov5face/utils/__pycache__/autoanchor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e42916fda77201bd260719bb2659985a7f3e989
Binary files /dev/null and b/scripts/facelib/detection/yolov5face/utils/__pycache__/autoanchor.cpython-310.pyc differ
diff --git a/scripts/facelib/detection/yolov5face/utils/__pycache__/datasets.cpython-310.pyc b/scripts/facelib/detection/yolov5face/utils/__pycache__/datasets.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..794953f26a4219ebe8fbc8a7d21a6f7b0839bdaf
Binary files /dev/null and b/scripts/facelib/detection/yolov5face/utils/__pycache__/datasets.cpython-310.pyc differ
diff --git a/scripts/facelib/detection/yolov5face/utils/__pycache__/general.cpython-310.pyc b/scripts/facelib/detection/yolov5face/utils/__pycache__/general.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..014d2a699c165df270fcd64e824f02a3ce75303a
Binary files /dev/null and b/scripts/facelib/detection/yolov5face/utils/__pycache__/general.cpython-310.pyc differ
diff --git a/scripts/facelib/detection/yolov5face/utils/__pycache__/torch_utils.cpython-310.pyc b/scripts/facelib/detection/yolov5face/utils/__pycache__/torch_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f4ac32cae47b6388c61f24f3da6ddfb2ee23a05
Binary files /dev/null and b/scripts/facelib/detection/yolov5face/utils/__pycache__/torch_utils.cpython-310.pyc differ
diff --git a/scripts/facelib/detection/yolov5face/utils/autoanchor.py b/scripts/facelib/detection/yolov5face/utils/autoanchor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4eba3e94888709be7d2a7c7499fbcc1808b4a88
--- /dev/null
+++ b/scripts/facelib/detection/yolov5face/utils/autoanchor.py
@@ -0,0 +1,12 @@
+# Auto-anchor utils
+
+
+def check_anchor_order(m):
+    # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
+    a = m.anchor_grid.prod(-1).view(-1)  # anchor area
+    da = a[-1] - a[0]  # delta a
+    ds = m.stride[-1] - m.stride[0]  # delta s
+    if da.sign() != ds.sign():  # anchor order does not match stride order
+        print("Reversing anchor order")
+        m.anchors[:] = m.anchors.flip(0)
+        m.anchor_grid[:] = m.anchor_grid.flip(0)
diff --git a/scripts/facelib/detection/yolov5face/utils/datasets.py b/scripts/facelib/detection/yolov5face/utils/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..e672b136f56fd6b05038e24377908361a54fe519
--- /dev/null
+++ b/scripts/facelib/detection/yolov5face/utils/datasets.py
@@ -0,0 +1,35 @@
+import cv2
+import numpy as np
+
+
+def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True):
+    # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
+    shape = img.shape[:2]  # current shape [height, width]
+    if isinstance(new_shape, int):
+        new_shape = (new_shape,
new_shape) + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + if not scaleup: # only scale down, do not scale up (for better test mAP) + r = min(r, 1.0) + + # Compute padding + ratio = r, r # width, height ratios + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding + if auto: # minimum rectangle + dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding + elif scale_fill: # stretch + dw, dh = 0.0, 0.0 + new_unpad = (new_shape[1], new_shape[0]) + ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios + + dw /= 2 # divide padding into 2 sides + dh /= 2 + + if shape[::-1] != new_unpad: # resize + img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border + return img, ratio, (dw, dh) diff --git a/scripts/facelib/detection/yolov5face/utils/extract_ckpt.py b/scripts/facelib/detection/yolov5face/utils/extract_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..4b8b631348f2d0cdea4e5a3594bb59f3e8f34a0f --- /dev/null +++ b/scripts/facelib/detection/yolov5face/utils/extract_ckpt.py @@ -0,0 +1,5 @@ +import torch +import sys +sys.path.insert(0,'./facelib/detection/yolov5face') +model = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location='cpu')['model'] +torch.save(model.state_dict(),'weights/facelib/yolov5n-face.pth') \ No newline at end of file diff --git a/scripts/facelib/detection/yolov5face/utils/general.py b/scripts/facelib/detection/yolov5face/utils/general.py new file mode 100644 index 0000000000000000000000000000000000000000..1c8e14f56a107ec3a4269c382cfc5168ad780ffc --- /dev/null +++ b/scripts/facelib/detection/yolov5face/utils/general.py @@ -0,0 +1,271 @@ +import math +import time + +import numpy as np +import torch +import torchvision + + +def check_img_size(img_size, s=32): + # Verify img_size is a multiple of stride s + new_size = make_divisible(img_size, int(s)) # ceil gs-multiple + # if new_size != img_size: + # print(f"WARNING: --img-size {img_size:g} must be multiple of max stride {s:g}, updating to {new_size:g}") + return new_size + + +def make_divisible(x, divisor): + # Returns x evenly divisible by divisor + return math.ceil(x / divisor) * divisor + + +def xyxy2xywh(x): + # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center + y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center + y[:, 2] = x[:, 2] - x[:, 0] # width + y[:, 3] = x[:, 3] - x[:, 1] # height + return y + + +def xywh2xyxy(x): + # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x + y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y + y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x + y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y + return y + + +def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None): + # Rescale coords (xyxy) from img1_shape to img0_shape + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad 
= (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + coords[:, [0, 2]] -= pad[0] # x padding + coords[:, [1, 3]] -= pad[1] # y padding + coords[:, :4] /= gain + clip_coords(coords, img0_shape) + return coords + + +def clip_coords(boxes, img_shape): + # Clip bounding xyxy bounding boxes to image shape (height, width) + boxes[:, 0].clamp_(0, img_shape[1]) # x1 + boxes[:, 1].clamp_(0, img_shape[0]) # y1 + boxes[:, 2].clamp_(0, img_shape[1]) # x2 + boxes[:, 3].clamp_(0, img_shape[0]) # y2 + + +def box_iou(box1, box2): + # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py + """ + Return intersection-over-union (Jaccard index) of boxes. + Both sets of boxes are expected to be in (x1, y1, x2, y2) format. + Arguments: + box1 (Tensor[N, 4]) + box2 (Tensor[M, 4]) + Returns: + iou (Tensor[N, M]): the NxM matrix containing the pairwise + IoU values for every element in boxes1 and boxes2 + """ + + def box_area(box): + return (box[2] - box[0]) * (box[3] - box[1]) + + area1 = box_area(box1.T) + area2 = box_area(box2.T) + + inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2) + return inter / (area1[:, None] + area2 - inter) + + +def non_max_suppression_face(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()): + """Performs Non-Maximum Suppression (NMS) on inference results + Returns: + detections with shape: nx6 (x1, y1, x2, y2, conf, cls) + """ + + nc = prediction.shape[2] - 15 # number of classes + xc = prediction[..., 4] > conf_thres # candidates + + # Settings + # (pixels) maximum box width and height + max_wh = 4096 + time_limit = 10.0 # seconds to quit after + redundant = True # require redundant detections + multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) + merge = False # use merge-NMS + + t = time.time() + output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0] + for xi, x in enumerate(prediction): # image index, image inference + # Apply constraints + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]): + label = labels[xi] + v = torch.zeros((len(label), nc + 15), device=x.device) + v[:, :4] = label[:, 1:5] # box + v[:, 4] = 1.0 # conf + v[range(len(label)), label[:, 0].long() + 15] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Compute conf + x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf + + # Box (center x, center y, width, height) to (x1, y1, x2, y2) + box = xywh2xyxy(x[:, :4]) + + # Detections matrix nx6 (xyxy, conf, landmarks, cls) + if multi_label: + i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, j + 15, None], x[:, 5:15], j[:, None].float()), 1) + else: # best class only + conf, j = x[:, 15:].max(1, keepdim=True) + x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] + + # If none remain process next image + n = x.shape[0] # number of boxes + if not n: + continue + + # Batched NMS + c = x[:, 15:16] * (0 if agnostic else max_wh) # classes + boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + + if merge and (1 < n < 3e3): # Merge NMS (boxes merged 
using weighted mean) + # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix + weights = iou * scores[None] # box weights + x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + if redundant: + i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + break # time limit exceeded + + return output + + +def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()): + """Performs Non-Maximum Suppression (NMS) on inference results + + Returns: + detections with shape: nx6 (x1, y1, x2, y2, conf, cls) + """ + + nc = prediction.shape[2] - 5 # number of classes + xc = prediction[..., 4] > conf_thres # candidates + + # Settings + # (pixels) maximum box width and height + max_wh = 4096 + time_limit = 10.0 # seconds to quit after + redundant = True # require redundant detections + multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) + merge = False # use merge-NMS + + t = time.time() + output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] + for xi, x in enumerate(prediction): # image index, image inference + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]): + label_id = labels[xi] + v = torch.zeros((len(label_id), nc + 5), device=x.device) + v[:, :4] = label_id[:, 1:5] # box + v[:, 4] = 1.0 # conf + v[range(len(label_id)), label_id[:, 0].long() + 5] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Compute conf + x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf + + # Box (center x, center y, width, height) to (x1, y1, x2, y2) + box = xywh2xyxy(x[:, :4]) + + # Detections matrix nx6 (xyxy, conf, cls) + if multi_label: + i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) + else: # best class only + conf, j = x[:, 5:].max(1, keepdim=True) + x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] + + # Check shape + n = x.shape[0] # number of boxes + if not n: # no boxes + continue + + x = x[x[:, 4].argsort(descending=True)] # sort by confidence + + # Batched NMS + c = x[:, 5:6] * (0 if agnostic else max_wh) # classes + boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean) + # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix + weights = iou * scores[None] # box weights + x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + if redundant: + i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + print(f"WARNING: NMS time limit {time_limit}s exceeded") + break # time limit exceeded + + return output + + +def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None): + # Rescale coords (xyxy) from img1_shape to img0_shape + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * 
gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + coords[:, [0, 2, 4, 6, 8]] -= pad[0] # x padding + coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding + coords[:, :10] /= gain + coords[:, 0].clamp_(0, img0_shape[1]) # x1 + coords[:, 1].clamp_(0, img0_shape[0]) # y1 + coords[:, 2].clamp_(0, img0_shape[1]) # x2 + coords[:, 3].clamp_(0, img0_shape[0]) # y2 + coords[:, 4].clamp_(0, img0_shape[1]) # x3 + coords[:, 5].clamp_(0, img0_shape[0]) # y3 + coords[:, 6].clamp_(0, img0_shape[1]) # x4 + coords[:, 7].clamp_(0, img0_shape[0]) # y4 + coords[:, 8].clamp_(0, img0_shape[1]) # x5 + coords[:, 9].clamp_(0, img0_shape[0]) # y5 + return coords diff --git a/scripts/facelib/detection/yolov5face/utils/torch_utils.py b/scripts/facelib/detection/yolov5face/utils/torch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..af2d06587b2d07b2eab199a8484380fde1de5c3c --- /dev/null +++ b/scripts/facelib/detection/yolov5face/utils/torch_utils.py @@ -0,0 +1,40 @@ +import torch +from torch import nn + + +def fuse_conv_and_bn(conv, bn): + # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ + fusedconv = ( + nn.Conv2d( + conv.in_channels, + conv.out_channels, + kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + groups=conv.groups, + bias=True, + ) + .requires_grad_(False) + .to(conv.weight.device) + ) + + # prepare filters + w_conv = conv.weight.clone().view(conv.out_channels, -1) + w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) + fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size())) + + # prepare spatial bias + b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias + b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) + fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) + + return fusedconv + + +def copy_attr(a, b, include=(), exclude=()): + # Copy attributes from b to a, options to only include [...] and to exclude [...] 
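+    # e.g. Model.autoshape() in yolo.py uses this to copy ('yaml', 'nc', 'hyp', 'names', 'stride') onto the AutoShape wrapper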
+ for k, v in b.__dict__.items(): + if (include and k not in include) or k.startswith("_") or k in exclude: + continue + + setattr(a, k, v) diff --git a/scripts/facelib/parsing/__init__.py b/scripts/facelib/parsing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9f08263e2476cec53bb50245b8de20c2b5ce5a1f --- /dev/null +++ b/scripts/facelib/parsing/__init__.py @@ -0,0 +1,23 @@ +import torch + +from scripts.facelib.utils import load_file_from_url +from .bisenet import BiSeNet +from .parsenet import ParseNet + + +def init_parsing_model(model_name='bisenet', half=False, device='cuda'): + if model_name == 'bisenet': + model = BiSeNet(num_class=19) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth' + elif model_name == 'parsenet': + model = ParseNet(in_size=512, out_size=512, parsing_ch=19) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + model.load_state_dict(load_net, strict=True) + model.eval() + model = model.to(device) + return model diff --git a/scripts/facelib/parsing/__pycache__/__init__.cpython-310.pyc b/scripts/facelib/parsing/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7e5fa6348677d0a6ef65a3abaef6f541abca73b Binary files /dev/null and b/scripts/facelib/parsing/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/facelib/parsing/__pycache__/bisenet.cpython-310.pyc b/scripts/facelib/parsing/__pycache__/bisenet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6e18163afccc79f17652117e0e255838427a106 Binary files /dev/null and b/scripts/facelib/parsing/__pycache__/bisenet.cpython-310.pyc differ diff --git a/scripts/facelib/parsing/__pycache__/parsenet.cpython-310.pyc b/scripts/facelib/parsing/__pycache__/parsenet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c13d78fc5bd7185ae38b4e94576432d2c07d381 Binary files /dev/null and b/scripts/facelib/parsing/__pycache__/parsenet.cpython-310.pyc differ diff --git a/scripts/facelib/parsing/__pycache__/resnet.cpython-310.pyc b/scripts/facelib/parsing/__pycache__/resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..294ec1cb674c101555f27348b0f924ce95d0b8ba Binary files /dev/null and b/scripts/facelib/parsing/__pycache__/resnet.cpython-310.pyc differ diff --git a/scripts/facelib/parsing/bisenet.py b/scripts/facelib/parsing/bisenet.py new file mode 100644 index 0000000000000000000000000000000000000000..3898cab76ae5876459cd4899c54cafa14234971d --- /dev/null +++ b/scripts/facelib/parsing/bisenet.py @@ -0,0 +1,140 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .resnet import ResNet18 + + +class ConvBNReLU(nn.Module): + + def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1): + super(ConvBNReLU, self).__init__() + self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False) + self.bn = nn.BatchNorm2d(out_chan) + + def forward(self, x): + x = self.conv(x) + x = F.relu(self.bn(x)) + return x + + +class BiSeNetOutput(nn.Module): + + def __init__(self, in_chan, mid_chan, num_class): + super(BiSeNetOutput, 
self).__init__() + self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) + self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False) + + def forward(self, x): + feat = self.conv(x) + out = self.conv_out(feat) + return out, feat + + +class AttentionRefinementModule(nn.Module): + + def __init__(self, in_chan, out_chan): + super(AttentionRefinementModule, self).__init__() + self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) + self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False) + self.bn_atten = nn.BatchNorm2d(out_chan) + self.sigmoid_atten = nn.Sigmoid() + + def forward(self, x): + feat = self.conv(x) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv_atten(atten) + atten = self.bn_atten(atten) + atten = self.sigmoid_atten(atten) + out = torch.mul(feat, atten) + return out + + +class ContextPath(nn.Module): + + def __init__(self): + super(ContextPath, self).__init__() + self.resnet = ResNet18() + self.arm16 = AttentionRefinementModule(256, 128) + self.arm32 = AttentionRefinementModule(512, 128) + self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) + + def forward(self, x): + feat8, feat16, feat32 = self.resnet(x) + h8, w8 = feat8.size()[2:] + h16, w16 = feat16.size()[2:] + h32, w32 = feat32.size()[2:] + + avg = F.avg_pool2d(feat32, feat32.size()[2:]) + avg = self.conv_avg(avg) + avg_up = F.interpolate(avg, (h32, w32), mode='nearest') + + feat32_arm = self.arm32(feat32) + feat32_sum = feat32_arm + avg_up + feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest') + feat32_up = self.conv_head32(feat32_up) + + feat16_arm = self.arm16(feat16) + feat16_sum = feat16_arm + feat32_up + feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest') + feat16_up = self.conv_head16(feat16_up) + + return feat8, feat16_up, feat32_up # x8, x8, x16 + + +class FeatureFusionModule(nn.Module): + + def __init__(self, in_chan, out_chan): + super(FeatureFusionModule, self).__init__() + self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False) + self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False) + self.relu = nn.ReLU(inplace=True) + self.sigmoid = nn.Sigmoid() + + def forward(self, fsp, fcp): + fcat = torch.cat([fsp, fcp], dim=1) + feat = self.convblk(fcat) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv1(atten) + atten = self.relu(atten) + atten = self.conv2(atten) + atten = self.sigmoid(atten) + feat_atten = torch.mul(feat, atten) + feat_out = feat_atten + feat + return feat_out + + +class BiSeNet(nn.Module): + + def __init__(self, num_class): + super(BiSeNet, self).__init__() + self.cp = ContextPath() + self.ffm = FeatureFusionModule(256, 256) + self.conv_out = BiSeNetOutput(256, 256, num_class) + self.conv_out16 = BiSeNetOutput(128, 64, num_class) + self.conv_out32 = BiSeNetOutput(128, 64, num_class) + + def forward(self, x, return_feat=False): + h, w = x.size()[2:] + feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature + feat_sp = feat_res8 # replace spatial path feature with res3b1 feature + feat_fuse = self.ffm(feat_sp, feat_cp8) + + out, feat = self.conv_out(feat_fuse) + out16, feat16 = self.conv_out16(feat_cp8) + out32, feat32 = self.conv_out32(feat_cp16) + + out = 
F.interpolate(out, (h, w), mode='bilinear', align_corners=True) + out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True) + out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True) + + if return_feat: + feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True) + feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True) + feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True) + return out, out16, out32, feat, feat16, feat32 + else: + return out, out16, out32 diff --git a/scripts/facelib/parsing/parsenet.py b/scripts/facelib/parsing/parsenet.py new file mode 100644 index 0000000000000000000000000000000000000000..e178ebe43a1ef666aaea0bc0faf629485c22a24f --- /dev/null +++ b/scripts/facelib/parsing/parsenet.py @@ -0,0 +1,194 @@ +"""Modified from https://github.com/chaofengc/PSFRGAN +""" +import numpy as np +import torch.nn as nn +from torch.nn import functional as F + + +class NormLayer(nn.Module): + """Normalization Layers. + + Args: + channels: input channels, for batch norm and instance norm. + input_size: input shape without batch size, for layer norm. + """ + + def __init__(self, channels, normalize_shape=None, norm_type='bn'): + super(NormLayer, self).__init__() + norm_type = norm_type.lower() + self.norm_type = norm_type + if norm_type == 'bn': + self.norm = nn.BatchNorm2d(channels, affine=True) + elif norm_type == 'in': + self.norm = nn.InstanceNorm2d(channels, affine=False) + elif norm_type == 'gn': + self.norm = nn.GroupNorm(32, channels, affine=True) + elif norm_type == 'pixel': + self.norm = lambda x: F.normalize(x, p=2, dim=1) + elif norm_type == 'layer': + self.norm = nn.LayerNorm(normalize_shape) + elif norm_type == 'none': + self.norm = lambda x: x * 1.0 + else: + assert 1 == 0, f'Norm type {norm_type} not support.' + + def forward(self, x, ref=None): + if self.norm_type == 'spade': + return self.norm(x, ref) + else: + return self.norm(x) + + +class ReluLayer(nn.Module): + """Relu Layer. + + Args: + relu type: type of relu layer, candidates are + - ReLU + - LeakyReLU: default relu slope 0.2 + - PRelu + - SELU + - none: direct pass + """ + + def __init__(self, channels, relu_type='relu'): + super(ReluLayer, self).__init__() + relu_type = relu_type.lower() + if relu_type == 'relu': + self.func = nn.ReLU(True) + elif relu_type == 'leakyrelu': + self.func = nn.LeakyReLU(0.2, inplace=True) + elif relu_type == 'prelu': + self.func = nn.PReLU(channels) + elif relu_type == 'selu': + self.func = nn.SELU(True) + elif relu_type == 'none': + self.func = lambda x: x * 1.0 + else: + assert 1 == 0, f'Relu type {relu_type} not support.' + + def forward(self, x): + return self.func(x) + + +class ConvLayer(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + scale='none', + norm_type='none', + relu_type='none', + use_pad=True, + bias=True): + super(ConvLayer, self).__init__() + self.use_pad = use_pad + self.norm_type = norm_type + if norm_type in ['bn']: + bias = False + + stride = 2 if scale == 'down' else 1 + + self.scale_func = lambda x: x + if scale == 'up': + self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest') + + self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) 
/ 2))) + self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias) + + self.relu = ReluLayer(out_channels, relu_type) + self.norm = NormLayer(out_channels, norm_type=norm_type) + + def forward(self, x): + out = self.scale_func(x) + if self.use_pad: + out = self.reflection_pad(out) + out = self.conv2d(out) + out = self.norm(out) + out = self.relu(out) + return out + + +class ResidualBlock(nn.Module): + """ + Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html + """ + + def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'): + super(ResidualBlock, self).__init__() + + if scale == 'none' and c_in == c_out: + self.shortcut_func = lambda x: x + else: + self.shortcut_func = ConvLayer(c_in, c_out, 3, scale) + + scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']} + scale_conf = scale_config_dict[scale] + + self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type) + self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none') + + def forward(self, x): + identity = self.shortcut_func(x) + + res = self.conv1(x) + res = self.conv2(res) + return identity + res + + +class ParseNet(nn.Module): + + def __init__(self, + in_size=128, + out_size=128, + min_feat_size=32, + base_ch=64, + parsing_ch=19, + res_depth=10, + relu_type='LeakyReLU', + norm_type='bn', + ch_range=[32, 256]): + super().__init__() + self.res_depth = res_depth + act_args = {'norm_type': norm_type, 'relu_type': relu_type} + min_ch, max_ch = ch_range + + ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731 + min_feat_size = min(in_size, min_feat_size) + + down_steps = int(np.log2(in_size // min_feat_size)) + up_steps = int(np.log2(out_size // min_feat_size)) + + # =============== define encoder-body-decoder ==================== + self.encoder = [] + self.encoder.append(ConvLayer(3, base_ch, 3, 1)) + head_ch = base_ch + for i in range(down_steps): + cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2) + self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args)) + head_ch = head_ch * 2 + + self.body = [] + for i in range(res_depth): + self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args)) + + self.decoder = [] + for i in range(up_steps): + cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2) + self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args)) + head_ch = head_ch // 2 + + self.encoder = nn.Sequential(*self.encoder) + self.body = nn.Sequential(*self.body) + self.decoder = nn.Sequential(*self.decoder) + self.out_img_conv = ConvLayer(ch_clip(head_ch), 3) + self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch) + + def forward(self, x): + feat = self.encoder(x) + x = feat + self.body(feat) + x = self.decoder(x) + out_img = self.out_img_conv(x) + out_mask = self.out_mask_conv(x) + return out_mask, out_img diff --git a/scripts/facelib/parsing/resnet.py b/scripts/facelib/parsing/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..fec8e82cf64469fb51be21ad5130217052addbda --- /dev/null +++ b/scripts/facelib/parsing/resnet.py @@ -0,0 +1,69 @@ +import torch.nn as nn +import torch.nn.functional as F + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + + def __init__(self, in_chan, out_chan, stride=1): + 
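        # two 3x3 convs; a 1x1 projection downsample aligns the shortcut when channels or stride change
+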
super(BasicBlock, self).__init__() + self.conv1 = conv3x3(in_chan, out_chan, stride) + self.bn1 = nn.BatchNorm2d(out_chan) + self.conv2 = conv3x3(out_chan, out_chan) + self.bn2 = nn.BatchNorm2d(out_chan) + self.relu = nn.ReLU(inplace=True) + self.downsample = None + if in_chan != out_chan or stride != 1: + self.downsample = nn.Sequential( + nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(out_chan), + ) + + def forward(self, x): + residual = self.conv1(x) + residual = F.relu(self.bn1(residual)) + residual = self.conv2(residual) + residual = self.bn2(residual) + + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + out = shortcut + residual + out = self.relu(out) + return out + + +def create_layer_basic(in_chan, out_chan, bnum, stride=1): + layers = [BasicBlock(in_chan, out_chan, stride=stride)] + for i in range(bnum - 1): + layers.append(BasicBlock(out_chan, out_chan, stride=1)) + return nn.Sequential(*layers) + + +class ResNet18(nn.Module): + + def __init__(self): + super(ResNet18, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) + self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) + self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) + self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(self.bn1(x)) + x = self.maxpool(x) + + x = self.layer1(x) + feat8 = self.layer2(x) # 1/8 + feat16 = self.layer3(feat8) # 1/16 + feat32 = self.layer4(feat16) # 1/32 + return feat8, feat16, feat32 diff --git a/scripts/facelib/utils/__init__.py b/scripts/facelib/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f03b1c2bafcd7759cb7e8722a0c6715f201a46dc --- /dev/null +++ b/scripts/facelib/utils/__init__.py @@ -0,0 +1,7 @@ +from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back +from .misc import img2tensor, load_file_from_url, download_pretrained_models, scandir + +__all__ = [ + 'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url', + 'download_pretrained_models', 'paste_face_back', 'img2tensor', 'scandir' +] diff --git a/scripts/facelib/utils/__pycache__/__init__.cpython-310.pyc b/scripts/facelib/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31f5ff39665cf5576b31a6f94500fb9740f07ad8 Binary files /dev/null and b/scripts/facelib/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/facelib/utils/__pycache__/face_restoration_helper.cpython-310.pyc b/scripts/facelib/utils/__pycache__/face_restoration_helper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0aa1334325603726f79396f40f4af4f36edae84d Binary files /dev/null and b/scripts/facelib/utils/__pycache__/face_restoration_helper.cpython-310.pyc differ diff --git a/scripts/facelib/utils/__pycache__/face_utils.cpython-310.pyc b/scripts/facelib/utils/__pycache__/face_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d1505a5ae50d0219cb98fca2759ad1ae5273424 Binary files /dev/null and b/scripts/facelib/utils/__pycache__/face_utils.cpython-310.pyc differ diff --git a/scripts/facelib/utils/__pycache__/misc.cpython-310.pyc 
b/scripts/facelib/utils/__pycache__/misc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..300f9d635720807457cc68fb965b52cb6d91db72 Binary files /dev/null and b/scripts/facelib/utils/__pycache__/misc.cpython-310.pyc differ diff --git a/scripts/facelib/utils/face_restoration_helper.py b/scripts/facelib/utils/face_restoration_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..157ac31a547fb196c1cc54f4a9dc7768798b32e7 --- /dev/null +++ b/scripts/facelib/utils/face_restoration_helper.py @@ -0,0 +1,525 @@ +import cv2 +import numpy as np +import os +import torch +from torchvision.transforms.functional import normalize + +from scripts.facelib.detection import init_detection_model +from scripts.facelib.parsing import init_parsing_model +from scripts.facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray, adain_npy +from scripts.basicsr.utils.download_util import load_file_from_url +from scripts.basicsr.utils.misc import get_device + +dlib_model_url = { + 'face_detector': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/mmod_human_face_detector-4cb19393.dat', + 'shape_predictor_5': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/shape_predictor_5_face_landmarks-c4b1e980.dat' +} + +def get_largest_face(det_faces, h, w): + + def get_location(val, length): + if val < 0: + return 0 + elif val > length: + return length + else: + return val + + face_areas = [] + for det_face in det_faces: + left = get_location(det_face[0], w) + right = get_location(det_face[2], w) + top = get_location(det_face[1], h) + bottom = get_location(det_face[3], h) + face_area = (right - left) * (bottom - top) + face_areas.append(face_area) + largest_idx = face_areas.index(max(face_areas)) + return det_faces[largest_idx], largest_idx + + +def get_center_face(det_faces, h=0, w=0, center=None): + if center is not None: + center = np.array(center) + else: + center = np.array([w / 2, h / 2]) + center_dist = [] + for det_face in det_faces: + face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2]) + dist = np.linalg.norm(face_center - center) + center_dist.append(dist) + center_idx = center_dist.index(min(center_dist)) + return det_faces[center_idx], center_idx + + +class FaceRestoreHelper(object): + """Helper for the face restoration pipeline (base class).""" + + def __init__(self, + upscale_factor, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + template_3points=False, + pad_blur=False, + use_parse=False, + device=None): + self.template_3points = template_3points # improve robustness + self.upscale_factor = int(upscale_factor) + # the cropped face ratio based on the square face + self.crop_ratio = crop_ratio # (h, w) + assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1' + self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0])) + self.det_model = det_model + + if self.det_model == 'dlib': + # standard 5 landmarks for FFHQ faces with 1024 x 1024 + self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941], + [337.91089109, 488.38613861], [437.95049505, 493.51485149], + [513.58415842, 678.5049505]]) + self.face_template = self.face_template / (1024 // face_size) + elif self.template_3points: + self.face_template = np.array([[192, 240], [319, 240], [257, 371]]) + else: + # standard 5 landmarks for FFHQ faces with 512 x 512 + # facexlib + 
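+            # The five template points are (x, y) positions of the left eye,
+            # right eye, nose tip, and left/right mouth corners on the
+            # canonical 512x512 FFHQ face.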
self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935], + [201.26117, 371.41043], [313.08905, 371.15118]]) + + # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54 + # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894], + # [198.22603, 372.82502], [313.91018, 372.75659]]) + + self.face_template = self.face_template * (face_size / 512.0) + if self.crop_ratio[0] > 1: + self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2 + if self.crop_ratio[1] > 1: + self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2 + self.save_ext = save_ext + self.pad_blur = pad_blur + if self.pad_blur is True: + self.template_3points = False + + self.all_landmarks_5 = [] + self.det_faces = [] + self.affine_matrices = [] + self.inverse_affine_matrices = [] + self.cropped_faces = [] + self.restored_faces = [] + self.pad_input_imgs = [] + + if device is None: + # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = get_device() + else: + self.device = device + + # init face detection model + if self.det_model == 'dlib': + self.face_detector, self.shape_predictor_5 = self.init_dlib(dlib_model_url['face_detector'], dlib_model_url['shape_predictor_5']) + else: + self.face_detector = init_detection_model(det_model, half=False, device=self.device) + + # init face parsing model + self.use_parse = use_parse + self.face_parse = init_parsing_model(model_name='parsenet', device=self.device) + + def set_upscale_factor(self, upscale_factor): + self.upscale_factor = upscale_factor + + def read_image(self, img): + """img can be image path or cv2 loaded image.""" + # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255] + if isinstance(img, str): + img = cv2.imread(img) + + if np.max(img) > 256: # 16-bit image + img = img / 65535 * 255 + if len(img.shape) == 2: # gray image + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif img.shape[2] == 4: # BGRA image with alpha channel + img = img[:, :, 0:3] + + self.input_img = img + self.is_gray = is_gray(img, threshold=10) + if self.is_gray: + print('Grayscale input: True') + + if min(self.input_img.shape[:2])<512: + f = 512.0/min(self.input_img.shape[:2]) + self.input_img = cv2.resize(self.input_img, (0,0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR) + + def init_dlib(self, detection_path, landmark5_path): + """Initialize the dlib detectors and predictors.""" + try: + import dlib + except ImportError: + print('Please install dlib by running:' 'conda install -c conda-forge dlib') + detection_path = load_file_from_url(url=detection_path, model_dir='weights/dlib', progress=True, file_name=None) + landmark5_path = load_file_from_url(url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None) + face_detector = dlib.cnn_face_detection_model_v1(detection_path) + shape_predictor_5 = dlib.shape_predictor(landmark5_path) + return face_detector, shape_predictor_5 + + def get_face_landmarks_5_dlib(self, + only_keep_largest=False, + scale=1): + det_faces = self.face_detector(self.input_img, scale) + + if len(det_faces) == 0: + print('No face detected. 
Try to increase upsample_num_times.') + return 0 + else: + if only_keep_largest: + print('Detect several faces and only keep the largest.') + face_areas = [] + for i in range(len(det_faces)): + face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * ( + det_faces[i].rect.bottom() - det_faces[i].rect.top()) + face_areas.append(face_area) + largest_idx = face_areas.index(max(face_areas)) + self.det_faces = [det_faces[largest_idx]] + else: + self.det_faces = det_faces + + if len(self.det_faces) == 0: + return 0 + + for face in self.det_faces: + shape = self.shape_predictor_5(self.input_img, face.rect) + landmark = np.array([[part.x, part.y] for part in shape.parts()]) + self.all_landmarks_5.append(landmark) + + return len(self.all_landmarks_5) + + + def get_face_landmarks_5(self, + only_keep_largest=False, + only_center_face=False, + resize=None, + blur_ratio=0.01, + eye_dist_threshold=None): + if self.det_model == 'dlib': + return self.get_face_landmarks_5_dlib(only_keep_largest) + + if resize is None: + scale = 1 + input_img = self.input_img + else: + h, w = self.input_img.shape[0:2] + scale = resize / min(h, w) + scale = max(1, scale) # always scale up + h, w = int(h * scale), int(w * scale) + interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR + input_img = cv2.resize(self.input_img, (w, h), interpolation=interp) + + with torch.no_grad(): + bboxes = self.face_detector.detect_faces(input_img) + + if bboxes is None or bboxes.shape[0] == 0: + return 0 + else: + bboxes = bboxes / scale + + for bbox in bboxes: + # remove faces with too small eye distance: side faces or too small faces + eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]]) + if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold): + continue + + if self.template_3points: + landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)]) + else: + landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)]) + self.all_landmarks_5.append(landmark) + self.det_faces.append(bbox[0:5]) + + if len(self.det_faces) == 0: + return 0 + if only_keep_largest: + h, w, _ = self.input_img.shape + self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w) + self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]] + elif only_center_face: + h, w, _ = self.input_img.shape + self.det_faces, center_idx = get_center_face(self.det_faces, h, w) + self.all_landmarks_5 = [self.all_landmarks_5[center_idx]] + + # pad blurry images + if self.pad_blur: + self.pad_input_imgs = [] + for landmarks in self.all_landmarks_5: + # get landmarks + eye_left = landmarks[0, :] + eye_right = landmarks[1, :] + eye_avg = (eye_left + eye_right) * 0.5 + mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5 + eye_to_eye = eye_right - eye_left + eye_to_mouth = mouth_avg - eye_avg + + # Get the oriented crop rectangle + # x: half width of the oriented crop rectangle + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise + # norm with the hypotenuse: get the direction + x /= np.hypot(*x) # get the hypotenuse of a right triangle + rect_scale = 1.5 + x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale) + # y: half height of the oriented crop rectangle + y = np.flipud(x) * [-1, 1] + + # c: center + c = eye_avg + eye_to_mouth * 0.1 + # quad: (left_top, left_bottom, right_bottom, right_top) + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + # qsize: side length of the square + qsize = 
np.hypot(*x) * 2 + border = max(int(np.rint(qsize * 0.1)), 3) + + # get pad + # pad: (width_left, height_top, width_right, height_bottom) + pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = [ + max(-pad[0] + border, 1), + max(-pad[1] + border, 1), + max(pad[2] - self.input_img.shape[0] + border, 1), + max(pad[3] - self.input_img.shape[1] + border, 1) + ] + + if max(pad) > 1: + # pad image + pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + # modify landmark coords + landmarks[:, 0] += pad[0] + landmarks[:, 1] += pad[1] + # blur pad images + h, w, _ = pad_img.shape + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], + np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], + np.float32(h - 1 - y) / pad[3])) + blur = int(qsize * blur_ratio) + if blur % 2 == 0: + blur += 1 + blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur)) + # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0) + + pad_img = pad_img.astype('float32') + pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0) + pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255] + self.pad_input_imgs.append(pad_img) + else: + self.pad_input_imgs.append(np.copy(self.input_img)) + + return len(self.all_landmarks_5) + + def align_warp_face(self, save_cropped_path=None, border_mode='constant'): + """Align and warp faces with face template. + """ + if self.pad_blur: + assert len(self.pad_input_imgs) == len( + self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}' + for idx, landmark in enumerate(self.all_landmarks_5): + # use 5 landmarks to get affine matrix + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0] + self.affine_matrices.append(affine_matrix) + # warp and crop faces + if border_mode == 'constant': + border_mode = cv2.BORDER_CONSTANT + elif border_mode == 'reflect101': + border_mode = cv2.BORDER_REFLECT101 + elif border_mode == 'reflect': + border_mode = cv2.BORDER_REFLECT + if self.pad_blur: + input_img = self.pad_input_imgs[idx] + else: + input_img = self.input_img + cropped_face = cv2.warpAffine( + input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray + self.cropped_faces.append(cropped_face) + # save the cropped face + if save_cropped_path is not None: + path = os.path.splitext(save_cropped_path)[0] + save_path = f'{path}_{idx:02d}.{self.save_ext}' + imwrite(cropped_face, save_path) + + def get_inverse_affine(self, save_inverse_affine_path=None): + """Get inverse affine matrix.""" + for idx, affine_matrix in enumerate(self.affine_matrices): + inverse_affine = cv2.invertAffineTransform(affine_matrix) + inverse_affine *= self.upscale_factor + self.inverse_affine_matrices.append(inverse_affine) + # save inverse affine matrices + if save_inverse_affine_path is not None: + path, _ = os.path.splitext(save_inverse_affine_path) + save_path = f'{path}_{idx:02d}.pth' + torch.save(inverse_affine, save_path) + + + def add_restored_face(self, restored_face, input_face=None): + if self.is_gray: + restored_face = bgr2gray(restored_face) # convert img into grayscale + if input_face 
is not None: + restored_face = adain_npy(restored_face, input_face) # transfer the color + self.restored_faces.append(restored_face) + + + def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None): + h, w, _ = self.input_img.shape + h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor) + + if upsample_img is None: + # simply resize the background + # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4) + upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR) + else: + upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4) + + assert len(self.restored_faces) == len( + self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.') + + inv_mask_borders = [] + for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices): + if face_upsampler is not None: + restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0] + inverse_affine /= self.upscale_factor + inverse_affine[:, 2] *= self.upscale_factor + face_size = (self.face_size[0]*self.upscale_factor, self.face_size[1]*self.upscale_factor) + else: + # Add an offset to inverse affine matrix, for more precise back alignment + if self.upscale_factor > 1: + extra_offset = 0.5 * self.upscale_factor + else: + extra_offset = 0 + inverse_affine[:, 2] += extra_offset + face_size = self.face_size + inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up)) + + # if draw_box or not self.use_parse: # use square parse maps + # mask = np.ones(face_size, dtype=np.float32) + # inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) + # # remove the black borders + # inv_mask_erosion = cv2.erode( + # inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)) + # pasted_face = inv_mask_erosion[:, :, None] * inv_restored + # total_face_area = np.sum(inv_mask_erosion) # // 3 + # # add border + # if draw_box: + # h, w = face_size + # mask_border = np.ones((h, w, 3), dtype=np.float32) + # border = int(1400/np.sqrt(total_face_area)) + # mask_border[border:h-border, border:w-border,:] = 0 + # inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up)) + # inv_mask_borders.append(inv_mask_border) + # if not self.use_parse: + # # compute the fusion edge based on the area of face + # w_edge = int(total_face_area**0.5) // 20 + # erosion_radius = w_edge * 2 + # inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) + # blur_size = w_edge * 2 + # inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) + # if len(upsample_img.shape) == 2: # upsample_img is gray image + # upsample_img = upsample_img[:, :, None] + # inv_soft_mask = inv_soft_mask[:, :, None] + + # always use square mask + mask = np.ones(face_size, dtype=np.float32) + inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) + # remove the black borders + inv_mask_erosion = cv2.erode( + inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)) + pasted_face = inv_mask_erosion[:, :, None] * inv_restored + total_face_area = np.sum(inv_mask_erosion) # // 3 + # add border + if draw_box: + h, w = face_size + mask_border = np.ones((h, w, 3), dtype=np.float32) + border = int(1400/np.sqrt(total_face_area)) + mask_border[border:h-border, border:w-border,:] = 0 + inv_mask_border = 
cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
+                inv_mask_borders.append(inv_mask_border)
+            # compute the fusion edge based on the area of face
+            w_edge = int(total_face_area**0.5) // 20
+            erosion_radius = w_edge * 2
+            inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+            blur_size = w_edge * 2
+            inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+            if len(upsample_img.shape) == 2:  # upsample_img is gray image
+                upsample_img = upsample_img[:, :, None]
+                inv_soft_mask = inv_soft_mask[:, :, None]
+
+            # parse mask
+            if self.use_parse:
+                # inference
+                face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
+                face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
+                normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+                face_input = torch.unsqueeze(face_input, 0).to(self.device)
+                with torch.no_grad():
+                    out = self.face_parse(face_input)[0]
+                out = out.argmax(dim=1).squeeze().cpu().numpy()
+
+                parse_mask = np.zeros(out.shape)
+                MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
+                for idx, color in enumerate(MASK_COLORMAP):
+                    parse_mask[out == idx] = color
+                # blur the mask
+                parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+                parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+                # remove the black borders
+                thres = 10
+                parse_mask[:thres, :] = 0
+                parse_mask[-thres:, :] = 0
+                parse_mask[:, :thres] = 0
+                parse_mask[:, -thres:] = 0
+                parse_mask = parse_mask / 255.
+
+                parse_mask = cv2.resize(parse_mask, face_size)
+                parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
+                inv_soft_parse_mask = parse_mask[:, :, None]
+                # pasted_face = inv_restored
+                fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
+                inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask)
+
+            if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4:  # alpha channel
+                alpha = upsample_img[:, :, 3:]
+                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
+                upsample_img = np.concatenate((upsample_img, alpha), axis=2)
+            else:
+                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
+
+        if np.max(upsample_img) > 256:  # 16-bit image
+            upsample_img = upsample_img.astype(np.uint16)
+        else:
+            upsample_img = upsample_img.astype(np.uint8)
+
+        # draw bounding box
+        if draw_box:
+            # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
+            img_color = np.ones([*upsample_img.shape], dtype=np.float32)
+            img_color[:, :, 0] = 0
+            img_color[:, :, 1] = 255
+            img_color[:, :, 2] = 0
+            for inv_mask_border in inv_mask_borders:
+                upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
+                # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img
+
+        if save_path is not None:
+            path = os.path.splitext(save_path)[0]
+            save_path = f'{path}.{self.save_ext}'
+            imwrite(upsample_img, save_path)
+        return upsample_img
+
+    def clean_all(self):
+        self.all_landmarks_5 = []
+        self.restored_faces = []
+        self.affine_matrices = []
+        self.cropped_faces = []
+        self.inverse_affine_matrices = []
+        self.det_faces = []
+        self.pad_input_imgs = []
\ No newline at end of file
diff --git a/scripts/facelib/utils/face_utils.py b/scripts/facelib/utils/face_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9b89cfde0b1151dc0d497a96746b13c2f058279
--- /dev/null
+++ b/scripts/facelib/utils/face_utils.py
@@ -0,0 +1,248 @@
+import cv2
+import numpy as np
+import torch
+
+
+def compute_increased_bbox(bbox, increase_area, preserve_aspect=True):
+    left, top, right, bot = bbox
+    width = right - left
+    height = bot - top
+
+    if preserve_aspect:
+        width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
+        height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
+    else:
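+        # expand by the same relative margin on every side, without
+        # compensating for the aspect ratio of the box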
+        width_increase = height_increase = increase_area
+    left = int(left - width_increase * width)
+    top = int(top - height_increase * height)
+    right = int(right + width_increase * width)
+    bot = int(bot + height_increase * height)
+    return (left, top, right, bot)
+
+
+def get_valid_bboxes(bboxes, h, w):
+    left = max(bboxes[0], 0)
+    top = max(bboxes[1], 0)
+    right = min(bboxes[2], w)
+    bottom = min(bboxes[3], h)
+    return (left, top, right, bottom)
+
+
+def align_crop_face_landmarks(img,
+                              landmarks,
+                              output_size,
+                              transform_size=None,
+                              enable_padding=True,
+                              return_inverse_affine=False,
+                              shrink_ratio=(1, 1)):
+    """Align and crop face with landmarks.
+
+    The output_size and transform_size are based on width. The height is
+    adjusted based on shrink_ratio_h / shrink_ratio_w.
+
+    Modified from:
+    https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
+
+    Args:
+        img (Numpy array): Input image.
+        landmarks (Numpy array): 5 or 68 or 98 landmarks.
+        output_size (int): Output face size.
+        transform_size (int): Transform size. Usually four times the
+            output_size.
+        enable_padding (bool): Whether to pad the image. Default: True.
+        shrink_ratio (float | tuple[float] | list[float]): Shrink the whole
+            face for height and width (crop larger area). Default: (1, 1).
+
+    Returns:
+        (Numpy array): Cropped face.
+    """
+    lm_type = 'retinaface_5'  # Options: dlib_5, retinaface_5
+
+    if isinstance(shrink_ratio, (float, int)):
+        shrink_ratio = (shrink_ratio, shrink_ratio)
+    if transform_size is None:
+        transform_size = output_size * 4
+
+    # Parse landmarks
+    lm = np.array(landmarks)
+    if lm.shape[0] == 5 and lm_type == 'retinaface_5':
+        eye_left = lm[0]
+        eye_right = lm[1]
+        mouth_avg = (lm[3] + lm[4]) * 0.5
+    elif lm.shape[0] == 5 and lm_type == 'dlib_5':
+        lm_eye_left = lm[2:4]
+        lm_eye_right = lm[0:2]
+        eye_left = np.mean(lm_eye_left, axis=0)
+        eye_right = np.mean(lm_eye_right, axis=0)
+        mouth_avg = lm[4]
+    elif lm.shape[0] == 68:
+        lm_eye_left = lm[36:42]
+        lm_eye_right = lm[42:48]
+        eye_left = np.mean(lm_eye_left, axis=0)
+        eye_right = np.mean(lm_eye_right, axis=0)
+        mouth_avg = (lm[48] + lm[54]) * 0.5
+    elif lm.shape[0] == 98:
+        lm_eye_left = lm[60:68]
+        lm_eye_right = lm[68:76]
+        eye_left = np.mean(lm_eye_left, axis=0)
+        eye_right = np.mean(lm_eye_right, axis=0)
+        mouth_avg = (lm[76] + lm[82]) * 0.5
+
+    eye_avg = (eye_left + eye_right) * 0.5
+    eye_to_eye = eye_right - eye_left
+    eye_to_mouth = mouth_avg - eye_avg
+
+    # Get the oriented crop rectangle
+    # x: half width of the oriented crop rectangle
+    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+    # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 degrees clockwise
+    # norm with the hypotenuse: get the direction
+    x /= np.hypot(*x)  # get the hypotenuse of a right triangle
+    rect_scale = 1  # TODO: you can edit it to get a larger rect
+    x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
+    # y: half height of the oriented crop rectangle
+    y = np.flipud(x) * [-1, 1]
+
+    x *= shrink_ratio[1]  # width
+    y *= shrink_ratio[0]  # height
+
+    # c: center
+    c = eye_avg + eye_to_mouth * 0.1
+    # quad: (left_top, left_bottom, right_bottom, right_top)
+    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+    # qsize: side length of the square
+    qsize = np.hypot(*x) * 2
+
+    quad_ori = np.copy(quad)
+    # Shrink, for large faces
+    # TODO: do we really need the shrink step?
+    shrink = int(np.floor(qsize / output_size * 0.5))
+    if shrink > 1:
+        h, w = img.shape[0:2]
+        rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink)))
+        img =
cv2.resize(img, rsize, interpolation=cv2.INTER_AREA) + quad /= shrink + qsize /= shrink + + # Crop + h, w = img.shape[0:2] + border = max(int(np.rint(qsize * 0.1)), 3) + crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h)) + if crop[2] - crop[0] < w or crop[3] - crop[1] < h: + img = img[crop[1]:crop[3], crop[0]:crop[2], :] + quad -= crop[0:2] + + # Pad + # pad: (width_left, height_top, width_right, height_bottom) + h, w = img.shape[0:2] + pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - w + border, 0), max(pad[3] - h + border, 0)) + if enable_padding and max(pad) > border - 4: + pad = np.maximum(pad, int(np.rint(qsize * 0.3))) + img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + h, w = img.shape[0:2] + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], + np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], + np.float32(h - 1 - y) / pad[3])) + blur = int(qsize * 0.02) + if blur % 2 == 0: + blur += 1 + blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur)) + + img = img.astype('float32') + img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) + img = np.clip(img, 0, 255) # float32, [0, 255] + quad += pad[:2] + + # Transform use cv2 + h_ratio = shrink_ratio[0] / shrink_ratio[1] + dst_h, dst_w = int(transform_size * h_ratio), transform_size + template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]]) + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0] + cropped_face = cv2.warpAffine( + img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132)) # gray + + if output_size < transform_size: + cropped_face = cv2.resize( + cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR) + + if return_inverse_affine: + dst_h, dst_w = int(output_size * h_ratio), output_size + template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]]) + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D( + quad_ori, np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]), method=cv2.LMEDS)[0] + inverse_affine = cv2.invertAffineTransform(affine_matrix) + else: + inverse_affine = None + return cropped_face, inverse_affine + + +def paste_face_back(img, face, inverse_affine): + h, w = img.shape[0:2] + face_h, face_w = face.shape[0:2] + inv_restored = cv2.warpAffine(face, inverse_affine, (w, h)) + mask = np.ones((face_h, face_w, 3), dtype=np.float32) + inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h)) + # remove the black borders + inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8)) + inv_restored_remove_border = inv_mask_erosion * inv_restored + total_face_area = np.sum(inv_mask_erosion) // 3 + # compute the fusion edge based on the area of face + w_edge = int(total_face_area**0.5) // 20 + erosion_radius = w_edge * 2 + 
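+    # feather the seam: erode the warped mask, then Gaussian-blur it so the
+    # pasted face blends smoothly into the surrounding image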
+    inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+    blur_size = w_edge * 2
+    inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+    img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img
+    # float32, [0, 255]
+    return img
+
+
+if __name__ == '__main__':
+    import os
+
+    from scripts.facelib.detection import init_detection_model
+    from scripts.facelib.utils.face_restoration_helper import get_largest_face
+
+    img_path = '/home/wxt/datasets/ffhq/ffhq_wild/00009.png'
+    img_name = os.path.splitext(os.path.basename(img_path))[0]
+
+    # initialize model
+    det_net = init_detection_model('retinaface_resnet50', half=False)
+    img_ori = cv2.imread(img_path)
+    h, w = img_ori.shape[0:2]
+    # if larger than 800, scale it down for detection
+    scale = max(h / 800, w / 800)
+    if scale > 1:
+        img = cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR)
+    else:
+        img = img_ori
+
+    with torch.no_grad():
+        bboxes = det_net.detect_faces(img, 0.97)
+        if scale > 1:
+            bboxes *= scale  # the score is incorrect after rescaling
+        bboxes = get_largest_face(bboxes, h, w)[0]
+
+    landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i in range(5, 15, 2)])
+
+    cropped_face, inverse_affine = align_crop_face_landmarks(
+        img_ori,
+        landmarks,
+        output_size=512,
+        transform_size=None,
+        enable_padding=True,
+        return_inverse_affine=True,
+        shrink_ratio=(1, 1))
+
+    cv2.imwrite(f'tmp/{img_name}_cropped_face.png', cropped_face)
+    img = paste_face_back(img_ori, cropped_face, inverse_affine)
+    cv2.imwrite(f'tmp/{img_name}_back.png', img)
diff --git a/scripts/facelib/utils/misc.py b/scripts/facelib/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..03aa05622cbcb2aa0ee923e3336e793c4ebed56d
--- /dev/null
+++ b/scripts/facelib/utils/misc.py
@@ -0,0 +1,202 @@
+import cv2
+import os
+import os.path as osp
+import numpy as np
+from PIL import Image
+import torch
+from torch.hub import download_url_to_file, get_dir
+from urllib.parse import urlparse
+# from scripts.basicsr.utils.download_util import download_file_from_google_drive
+
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+
+def download_pretrained_models(file_ids, save_path_root):
+    import gdown
+
+    os.makedirs(save_path_root, exist_ok=True)
+
+    for file_name, file_id in file_ids.items():
+        file_url = 'https://drive.google.com/uc?id=' + file_id
+        save_path = osp.abspath(osp.join(save_path_root, file_name))
+        if osp.exists(save_path):
+            user_response = input(f'{file_name} already exists. Do you want to overwrite it? Y/N\n')
+            if user_response.lower() == 'y':
+                print(f'Overwriting {file_name} at {save_path}')
+                gdown.download(file_url, save_path, quiet=False)
+                # download_file_from_google_drive(file_id, save_path)
+            elif user_response.lower() == 'n':
+                print(f'Skipping {file_name}')
+            else:
+                raise ValueError('Wrong input. Only accepts Y/N.')
+        else:
+            print(f'Downloading {file_name} to {save_path}')
+            gdown.download(file_url, save_path, quiet=False)
+            # download_file_from_google_drive(file_id, save_path)
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+    """Write image to file.
+
+    Args:
+        img (ndarray): Image array to be written.
+        file_path (str): Image file path.
+        params (None or list): Same as opencv's :func:`imwrite` interface.
+        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+            whether to create it automatically.
+
+    Returns:
+        bool: Successful or not.
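+
+    Example (illustrative; the output path is hypothetical):
+        >>> ok = imwrite(img, 'results/restored/out.png')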
+ """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + """ + if model_dir is None: + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + Returns: + A generator for all the interested files with relative paths. 
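+
+    Example (illustrative; the directory and suffix are hypothetical):
+        >>> for path in scandir('datasets/faces', suffix='.png', recursive=True):
+        ...     print(path)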
+ """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +def is_gray(img, threshold=10): + img = Image.fromarray(img) + if len(img.getbands()) == 1: + return True + img1 = np.asarray(img.getchannel(channel=0), dtype=np.int16) + img2 = np.asarray(img.getchannel(channel=1), dtype=np.int16) + img3 = np.asarray(img.getchannel(channel=2), dtype=np.int16) + diff1 = (img1 - img2).var() + diff2 = (img2 - img3).var() + diff3 = (img3 - img1).var() + diff_sum = (diff1 + diff2 + diff3) / 3.0 + if diff_sum <= threshold: + return True + else: + return False + +def rgb2gray(img, out_channel=3): + r, g, b = img[:,:,0], img[:,:,1], img[:,:,2] + gray = 0.2989 * r + 0.5870 * g + 0.1140 * b + if out_channel == 3: + gray = gray[:,:,np.newaxis].repeat(3, axis=2) + return gray + +def bgr2gray(img, out_channel=3): + b, g, r = img[:,:,0], img[:,:,1], img[:,:,2] + gray = 0.2989 * r + 0.5870 * g + 0.1140 * b + if out_channel == 3: + gray = gray[:,:,np.newaxis].repeat(3, axis=2) + return gray + + +def calc_mean_std(feat, eps=1e-5): + """ + Args: + feat (numpy): 3D [w h c]s + """ + size = feat.shape + assert len(size) == 3, 'The input feature should be 3D tensor.' + c = size[2] + feat_var = feat.reshape(-1, c).var(axis=0) + eps + feat_std = np.sqrt(feat_var).reshape(1, 1, c) + feat_mean = feat.reshape(-1, c).mean(axis=0).reshape(1, 1, c) + return feat_mean, feat_std + + +def adain_npy(content_feat, style_feat): + """Adaptive instance normalization for numpy. + + Args: + content_feat (numpy): The input feature. + style_feat (numpy): The reference feature. 
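+
+    Each channel of `content_feat` is normalized to zero mean and unit std
+    over its spatial positions, then rescaled to the per-channel statistics
+    of `style_feat`:
+
+        out = (content - mean_c) / std_c * std_s + mean_s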
+ """ + size = content_feat.shape + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - np.broadcast_to(content_mean, size)) / np.broadcast_to(content_std, size) + return normalized_feat * np.broadcast_to(style_std, size) + np.broadcast_to(style_mean, size) \ No newline at end of file diff --git a/scripts/facerender/__pycache__/animate.cpython-310.pyc b/scripts/facerender/__pycache__/animate.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b75435c0acc56b3866085b0b8adf7ed82a1c60c3 Binary files /dev/null and b/scripts/facerender/__pycache__/animate.cpython-310.pyc differ diff --git a/scripts/facerender/animate.py b/scripts/facerender/animate.py new file mode 100644 index 0000000000000000000000000000000000000000..9031cecb04a0efffc392ea34860024f1d9a7c5bc --- /dev/null +++ b/scripts/facerender/animate.py @@ -0,0 +1,255 @@ +import os +import cv2 +import yaml +import numpy as np +import warnings +from skimage import img_as_ubyte +import safetensors +import safetensors.torch +warnings.filterwarnings('ignore') + + +import imageio +import torch +import torchvision + + +from scripts.facerender.modules.keypoint_detector import HEEstimator, KPDetector +from scripts.facerender.modules.mapping import MappingNet +from scripts.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator +from scripts.facerender.modules.make_animation import make_animation + +from pydub import AudioSegment +from scripts.utils.face_enhancer import enhancer_generator_with_len, enhancer_list +from scripts.utils.paste_pic import paste_pic +from scripts.utils.videoio import save_video_with_watermark + +try: + import webui # in webui + in_webui = True +except: + in_webui = False + +class AnimateFromCoeff(): + + def __init__(self, sadtalker_path, device): + + with open(sadtalker_path['facerender_yaml']) as f: + config = yaml.safe_load(f) + + generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'], + **config['model_params']['common_params']) + kp_extractor = KPDetector(**config['model_params']['kp_detector_params'], + **config['model_params']['common_params']) + he_estimator = HEEstimator(**config['model_params']['he_estimator_params'], + **config['model_params']['common_params']) + mapping = MappingNet(**config['model_params']['mapping_params']) + + generator.to(device) + kp_extractor.to(device) + he_estimator.to(device) + mapping.to(device) + for param in generator.parameters(): + param.requires_grad = False + for param in kp_extractor.parameters(): + param.requires_grad = False + for param in he_estimator.parameters(): + param.requires_grad = False + for param in mapping.parameters(): + param.requires_grad = False + + if sadtalker_path is not None: + if 'checkpoint' in sadtalker_path: # use safe tensor + self.load_cpk_facevid2vid_safetensor(sadtalker_path['checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=None) + else: + self.load_cpk_facevid2vid(sadtalker_path['free_view_checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator) + else: + raise AttributeError("Checkpoint should be specified for video head pose estimator.") + + if sadtalker_path['mappingnet_checkpoint'] is not None: + self.load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping) + else: + raise AttributeError("Checkpoint should be specified for video head pose estimator.") + + self.kp_extractor = 
kp_extractor + self.generator = generator + self.he_estimator = he_estimator + self.mapping = mapping + + self.kp_extractor.eval() + self.generator.eval() + self.he_estimator.eval() + self.mapping.eval() + + self.device = device + + def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=None, + kp_detector=None, he_estimator=None, + device="cpu"): + + checkpoint = safetensors.torch.load_file(checkpoint_path) + + if generator is not None: + x_generator = {} + for k,v in checkpoint.items(): + if 'generator' in k: + x_generator[k.replace('generator.', '')] = v + generator.load_state_dict(x_generator) + if kp_detector is not None: + x_generator = {} + for k,v in checkpoint.items(): + if 'kp_extractor' in k: + x_generator[k.replace('kp_extractor.', '')] = v + kp_detector.load_state_dict(x_generator) + if he_estimator is not None: + x_generator = {} + for k,v in checkpoint.items(): + if 'he_estimator' in k: + x_generator[k.replace('he_estimator.', '')] = v + he_estimator.load_state_dict(x_generator) + + return None + + def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None, + kp_detector=None, he_estimator=None, optimizer_generator=None, + optimizer_discriminator=None, optimizer_kp_detector=None, + optimizer_he_estimator=None, device="cpu"): + checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) + if generator is not None: + generator.load_state_dict(checkpoint['generator']) + if kp_detector is not None: + kp_detector.load_state_dict(checkpoint['kp_detector']) + if he_estimator is not None: + he_estimator.load_state_dict(checkpoint['he_estimator']) + if discriminator is not None: + try: + discriminator.load_state_dict(checkpoint['discriminator']) + except: + print ('No discriminator in the state-dict. Dicriminator will be randomly initialized') + if optimizer_generator is not None: + optimizer_generator.load_state_dict(checkpoint['optimizer_generator']) + if optimizer_discriminator is not None: + try: + optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) + except RuntimeError as e: + print ('No discriminator optimizer in the state-dict. 
Optimizer will be not initialized') + if optimizer_kp_detector is not None: + optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector']) + if optimizer_he_estimator is not None: + optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator']) + + return checkpoint['epoch'] + + def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None, + optimizer_mapping=None, optimizer_discriminator=None, device='cpu'): + checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) + if mapping is not None: + mapping.load_state_dict(checkpoint['mapping']) + if discriminator is not None: + discriminator.load_state_dict(checkpoint['discriminator']) + if optimizer_mapping is not None: + optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping']) + if optimizer_discriminator is not None: + optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) + + return checkpoint['epoch'] + + def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256): + + source_image=x['source_image'].type(torch.FloatTensor) + source_semantics=x['source_semantics'].type(torch.FloatTensor) + target_semantics=x['target_semantics_list'].type(torch.FloatTensor) + source_image=source_image.to(self.device) + source_semantics=source_semantics.to(self.device) + target_semantics=target_semantics.to(self.device) + if 'yaw_c_seq' in x: + yaw_c_seq = x['yaw_c_seq'].type(torch.FloatTensor) + yaw_c_seq = x['yaw_c_seq'].to(self.device) + else: + yaw_c_seq = None + if 'pitch_c_seq' in x: + pitch_c_seq = x['pitch_c_seq'].type(torch.FloatTensor) + pitch_c_seq = x['pitch_c_seq'].to(self.device) + else: + pitch_c_seq = None + if 'roll_c_seq' in x: + roll_c_seq = x['roll_c_seq'].type(torch.FloatTensor) + roll_c_seq = x['roll_c_seq'].to(self.device) + else: + roll_c_seq = None + + frame_num = x['frame_num'] + + predictions_video = make_animation(source_image, source_semantics, target_semantics, + self.generator, self.kp_extractor, self.he_estimator, self.mapping, + yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp = True) + + predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:]) + predictions_video = predictions_video[:frame_num] + + video = [] + for idx in range(predictions_video.shape[0]): + image = predictions_video[idx] + image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32) + video.append(image) + result = img_as_ubyte(video) + + ### the generated video is 256x256, so we keep the aspect ratio, + original_size = crop_info[0] + if original_size: + result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ] + + video_name = x['video_name'] + '.mp4' + path = os.path.join(video_save_dir, 'temp_'+video_name) + + imageio.mimsave(path, result, fps=float(25)) + + av_path = os.path.join(video_save_dir, video_name) + return_path = av_path + + audio_path = x['audio_path'] + audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0] + new_audio_path = os.path.join(video_save_dir, audio_name+'.wav') + start_time = 0 + # cog will not keep the .mp3 filename + sound = AudioSegment.from_file(audio_path) + frames = frame_num + end_time = start_time + frames*1/25*1000 + word1=sound.set_frame_rate(16000) + word = word1[start_time:end_time] + word.export(new_audio_path, format="wav") + + save_video_with_watermark(path, new_audio_path, av_path, watermark= False) + + if 'full' in preprocess.lower(): + # 
only add watermark to the full image. + video_name_full = x['video_name'] + '_full.mp4' + full_video_path = os.path.join(video_save_dir, video_name_full) + return_path = full_video_path + paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False) + + else: + full_video_path = av_path + + #### paste back then enhancers + if enhancer: + video_name_enhancer = x['video_name'] + '_enhanced.mp4' + enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer) + av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer) + return_path = av_path_enhancer + + try: + enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer) + imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25)) + except: + enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer) + imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25)) + + save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False) + os.remove(enhanced_path) + + os.remove(path) + os.remove(new_audio_path) + + return return_path + diff --git a/scripts/facerender/modules/__pycache__/dense_motion.cpython-310.pyc b/scripts/facerender/modules/__pycache__/dense_motion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bbae2fbb92bc9f6c7af00dcdaa5c0050c073aa5 Binary files /dev/null and b/scripts/facerender/modules/__pycache__/dense_motion.cpython-310.pyc differ diff --git a/scripts/facerender/modules/__pycache__/generator.cpython-310.pyc b/scripts/facerender/modules/__pycache__/generator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f51968444f7798e7c143deb1356ced3610ff0a37 Binary files /dev/null and b/scripts/facerender/modules/__pycache__/generator.cpython-310.pyc differ diff --git a/scripts/facerender/modules/__pycache__/keypoint_detector.cpython-310.pyc b/scripts/facerender/modules/__pycache__/keypoint_detector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..271adf93cb64927992117eed2be4f981317c5f27 Binary files /dev/null and b/scripts/facerender/modules/__pycache__/keypoint_detector.cpython-310.pyc differ diff --git a/scripts/facerender/modules/__pycache__/make_animation.cpython-310.pyc b/scripts/facerender/modules/__pycache__/make_animation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c501c5408d59620941a5b8839e21aa8cba05c66 Binary files /dev/null and b/scripts/facerender/modules/__pycache__/make_animation.cpython-310.pyc differ diff --git a/scripts/facerender/modules/__pycache__/mapping.cpython-310.pyc b/scripts/facerender/modules/__pycache__/mapping.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5d4a19ee797e8dfe968ee018f14c7cecee09818 Binary files /dev/null and b/scripts/facerender/modules/__pycache__/mapping.cpython-310.pyc differ diff --git a/scripts/facerender/modules/__pycache__/util.cpython-310.pyc b/scripts/facerender/modules/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c032f9b43854348352e891bfb9dd894d9fd2da6 Binary files /dev/null and b/scripts/facerender/modules/__pycache__/util.cpython-310.pyc differ diff --git a/scripts/facerender/modules/dense_motion.py b/scripts/facerender/modules/dense_motion.py new file mode 100644 index 
0000000000000000000000000000000000000000..a850119ea33b8ce468d672c208427c8ce9b13a3f --- /dev/null +++ b/scripts/facerender/modules/dense_motion.py @@ -0,0 +1,121 @@ +from torch import nn +import torch.nn.functional as F +import torch +from scripts.facerender.modules.util import Hourglass, make_coordinate_grid, kp2gaussian + +from scripts.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d + + +class DenseMotionNetwork(nn.Module): + """ + Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving + """ + + def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, + estimate_occlusion_map=False): + super(DenseMotionNetwork, self).__init__() + # self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(feature_channel+1), max_features=max_features, num_blocks=num_blocks) + self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks) + + self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3) + + self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1) + self.norm = BatchNorm3d(compress, affine=True) + + if estimate_occlusion_map: + # self.occlusion = nn.Conv2d(reshape_channel*reshape_depth, 1, kernel_size=7, padding=3) + self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3) + else: + self.occlusion = None + + self.num_kp = num_kp + + + def create_sparse_motions(self, feature, kp_driving, kp_source): + bs, _, d, h, w = feature.shape + identity_grid = make_coordinate_grid((d, h, w), type=kp_source['value'].type()) + identity_grid = identity_grid.view(1, 1, d, h, w, 3) + coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 1, 3) + + # if 'jacobian' in kp_driving: + if 'jacobian' in kp_driving and kp_driving['jacobian'] is not None: + jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian'])) + jacobian = jacobian.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3) + jacobian = jacobian.repeat(1, 1, d, h, w, 1, 1) + coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1)) + coordinate_grid = coordinate_grid.squeeze(-1) + + + driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3) + + #adding background feature + identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1) + sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) #bs num_kp+1 d h w 3 + + # sparse_motions = driving_to_source + + return sparse_motions + + def create_deformed_feature(self, feature, sparse_motions): + bs, _, d, h, w = feature.shape + feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w) + feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w) + sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3) !!!! 
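+        # warp the repeated feature volume once per keypoint (plus the
+        # identity/background grid): grid_sample reads each output voxel from
+        # the location given by the corresponding sparse motion field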
+ sparse_deformed = F.grid_sample(feature_repeat, sparse_motions) + sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w) + return sparse_deformed + + def create_heatmap_representations(self, feature, kp_driving, kp_source): + spatial_size = feature.shape[3:] + gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01) + gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01) + heatmap = gaussian_driving - gaussian_source + + # adding background feature + zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type()) + heatmap = torch.cat([zeros, heatmap], dim=1) + heatmap = heatmap.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w) + return heatmap + + def forward(self, feature, kp_driving, kp_source): + bs, _, d, h, w = feature.shape + + feature = self.compress(feature) + feature = self.norm(feature) + feature = F.relu(feature) + + out_dict = dict() + sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source) + deformed_feature = self.create_deformed_feature(feature, sparse_motion) + + heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source) + + input_ = torch.cat([heatmap, deformed_feature], dim=2) + input_ = input_.view(bs, -1, d, h, w) + + # input = deformed_feature.view(bs, -1, d, h, w) # (bs, num_kp+1 * c, d, h, w) + + prediction = self.hourglass(input_) + + + mask = self.mask(prediction) + mask = F.softmax(mask, dim=1) + out_dict['mask'] = mask + mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w) + + zeros_mask = torch.zeros_like(mask) + mask = torch.where(mask < 1e-3, zeros_mask, mask) + + sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w) + deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w) + deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3) + + out_dict['deformation'] = deformation + + if self.occlusion: + bs, c, d, h, w = prediction.shape + prediction = prediction.view(bs, -1, h, w) + occlusion_map = torch.sigmoid(self.occlusion(prediction)) + out_dict['occlusion_map'] = occlusion_map + + return out_dict diff --git a/scripts/facerender/modules/discriminator.py b/scripts/facerender/modules/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..d4459b07cb075c9f9d345f9b3dffc02cd859313b --- /dev/null +++ b/scripts/facerender/modules/discriminator.py @@ -0,0 +1,90 @@ +from torch import nn +import torch.nn.functional as F +from facerender.modules.util import kp2gaussian +import torch + + +class DownBlock2d(nn.Module): + """ + Simple block for processing video (encoder). 
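+
+    conv -> optional InstanceNorm -> LeakyReLU(0.2) -> optional 2x2 average
+    pooling; spectral normalization is applied to the conv when sn=True.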
+ """ + + def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size) + + if sn: + self.conv = nn.utils.spectral_norm(self.conv) + + if norm: + self.norm = nn.InstanceNorm2d(out_features, affine=True) + else: + self.norm = None + self.pool = pool + + def forward(self, x): + out = x + out = self.conv(out) + if self.norm: + out = self.norm(out) + out = F.leaky_relu(out, 0.2) + if self.pool: + out = F.avg_pool2d(out, (2, 2)) + return out + + +class Discriminator(nn.Module): + """ + Discriminator similar to Pix2Pix + """ + + def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512, + sn=False, **kwargs): + super(Discriminator, self).__init__() + + down_blocks = [] + for i in range(num_blocks): + down_blocks.append( + DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)), + min(max_features, block_expansion * (2 ** (i + 1))), + norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn)) + + self.down_blocks = nn.ModuleList(down_blocks) + self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1) + if sn: + self.conv = nn.utils.spectral_norm(self.conv) + + def forward(self, x): + feature_maps = [] + out = x + + for down_block in self.down_blocks: + feature_maps.append(down_block(out)) + out = feature_maps[-1] + prediction_map = self.conv(out) + + return feature_maps, prediction_map + + +class MultiScaleDiscriminator(nn.Module): + """ + Multi-scale (scale) discriminator + """ + + def __init__(self, scales=(), **kwargs): + super(MultiScaleDiscriminator, self).__init__() + self.scales = scales + discs = {} + for scale in scales: + discs[str(scale).replace('.', '-')] = Discriminator(**kwargs) + self.discs = nn.ModuleDict(discs) + + def forward(self, x): + out_dict = {} + for scale, disc in self.discs.items(): + scale = str(scale).replace('-', '.') + key = 'prediction_' + scale + feature_maps, prediction_map = disc(x[key]) + out_dict['feature_maps_' + scale] = feature_maps + out_dict['prediction_map_' + scale] = prediction_map + return out_dict diff --git a/scripts/facerender/modules/generator.py b/scripts/facerender/modules/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..e868905bba93bc0d63410a890510e4f88dce5235 --- /dev/null +++ b/scripts/facerender/modules/generator.py @@ -0,0 +1,255 @@ +import torch +from torch import nn +import torch.nn.functional as F +from scripts.facerender.modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d, ResBlock3d, SPADEResnetBlock +from scripts.facerender.modules.dense_motion import DenseMotionNetwork + + +class OcclusionAwareGenerator(nn.Module): + """ + Generator follows NVIDIA architecture. 
+ """ + + def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth, + num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): + super(OcclusionAwareGenerator, self).__init__() + + if dense_motion_params is not None: + self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel, + estimate_occlusion_map=estimate_occlusion_map, + **dense_motion_params) + else: + self.dense_motion_network = None + + self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(7, 7), padding=(3, 3)) + + down_blocks = [] + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2 ** i)) + out_features = min(max_features, block_expansion * (2 ** (i + 1))) + down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + self.down_blocks = nn.ModuleList(down_blocks) + + self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) + + self.reshape_channel = reshape_channel + self.reshape_depth = reshape_depth + + self.resblocks_3d = torch.nn.Sequential() + for i in range(num_resblocks): + self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) + + out_features = block_expansion * (2 ** (num_down_blocks)) + self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True) + self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1) + + self.resblocks_2d = torch.nn.Sequential() + for i in range(num_resblocks): + self.resblocks_2d.add_module('2dr' + str(i), ResBlock2d(out_features, kernel_size=3, padding=1)) + + up_blocks = [] + for i in range(num_down_blocks): + in_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i))) + out_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i - 1))) + up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + self.up_blocks = nn.ModuleList(up_blocks) + + self.final = nn.Conv2d(block_expansion, image_channel, kernel_size=(7, 7), padding=(3, 3)) + self.estimate_occlusion_map = estimate_occlusion_map + self.image_channel = image_channel + + def deform_input(self, inp, deformation): + _, d_old, h_old, w_old, _ = deformation.shape + _, _, d, h, w = inp.shape + if d_old != d or h_old != h or w_old != w: + deformation = deformation.permute(0, 4, 1, 2, 3) + deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear') + deformation = deformation.permute(0, 2, 3, 4, 1) + return F.grid_sample(inp, deformation) + + def forward(self, source_image, kp_driving, kp_source): + # Encoding (downsampling) part + out = self.first(source_image) + for i in range(len(self.down_blocks)): + out = self.down_blocks[i](out) + out = self.second(out) + bs, c, h, w = out.shape + # print(out.shape) + feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w) + feature_3d = self.resblocks_3d(feature_3d) + + # Transforming feature representation according to deformation and occlusion + output_dict = {} + if self.dense_motion_network is not None: + dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving, + kp_source=kp_source) + output_dict['mask'] = dense_motion['mask'] + + if 'occlusion_map' in dense_motion: + occlusion_map = dense_motion['occlusion_map'] + output_dict['occlusion_map'] = 
occlusion_map + else: + occlusion_map = None + deformation = dense_motion['deformation'] + out = self.deform_input(feature_3d, deformation) + + bs, c, d, h, w = out.shape + out = out.view(bs, c*d, h, w) + out = self.third(out) + out = self.fourth(out) + + if occlusion_map is not None: + if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: + occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear') + out = out * occlusion_map + + # output_dict["deformed"] = self.deform_input(source_image, deformation) # 3d deformation cannot deform 2d image + + # Decoding part + out = self.resblocks_2d(out) + for i in range(len(self.up_blocks)): + out = self.up_blocks[i](out) + out = self.final(out) + out = F.sigmoid(out) + + output_dict["prediction"] = out + + return output_dict + + +class SPADEDecoder(nn.Module): + def __init__(self): + super().__init__() + ic = 256 + oc = 64 + norm_G = 'spadespectralinstance' + label_nc = 256 + + self.fc = nn.Conv2d(ic, 2 * ic, 3, padding=1) + self.G_middle_0 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) + self.G_middle_1 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) + self.G_middle_2 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) + self.G_middle_3 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) + self.G_middle_4 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) + self.G_middle_5 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) + self.up_0 = SPADEResnetBlock(2 * ic, ic, norm_G, label_nc) + self.up_1 = SPADEResnetBlock(ic, oc, norm_G, label_nc) + self.conv_img = nn.Conv2d(oc, 3, 3, padding=1) + self.up = nn.Upsample(scale_factor=2) + + def forward(self, feature): + seg = feature + x = self.fc(feature) + x = self.G_middle_0(x, seg) + x = self.G_middle_1(x, seg) + x = self.G_middle_2(x, seg) + x = self.G_middle_3(x, seg) + x = self.G_middle_4(x, seg) + x = self.G_middle_5(x, seg) + x = self.up(x) + x = self.up_0(x, seg) # 256, 128, 128 + x = self.up(x) + x = self.up_1(x, seg) # 64, 256, 256 + + x = self.conv_img(F.leaky_relu(x, 2e-1)) + # x = torch.tanh(x) + x = F.sigmoid(x) + + return x + + +class OcclusionAwareSPADEGenerator(nn.Module): + + def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth, + num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): + super(OcclusionAwareSPADEGenerator, self).__init__() + + if dense_motion_params is not None: + self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel, + estimate_occlusion_map=estimate_occlusion_map, + **dense_motion_params) + else: + self.dense_motion_network = None + + self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1)) + + down_blocks = [] + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2 ** i)) + out_features = min(max_features, block_expansion * (2 ** (i + 1))) + down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + self.down_blocks = nn.ModuleList(down_blocks) + + self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) + + self.reshape_channel = reshape_channel + self.reshape_depth = reshape_depth + + self.resblocks_3d = torch.nn.Sequential() + for i in range(num_resblocks): + self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) + + out_features = 
block_expansion * (2 ** (num_down_blocks)) + self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True) + self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1) + + self.estimate_occlusion_map = estimate_occlusion_map + self.image_channel = image_channel + + self.decoder = SPADEDecoder() + + def deform_input(self, inp, deformation): + _, d_old, h_old, w_old, _ = deformation.shape + _, _, d, h, w = inp.shape + if d_old != d or h_old != h or w_old != w: + deformation = deformation.permute(0, 4, 1, 2, 3) + deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear') + deformation = deformation.permute(0, 2, 3, 4, 1) + return F.grid_sample(inp, deformation) + + def forward(self, source_image, kp_driving, kp_source): + # Encoding (downsampling) part + out = self.first(source_image) + for i in range(len(self.down_blocks)): + out = self.down_blocks[i](out) + out = self.second(out) + bs, c, h, w = out.shape + # print(out.shape) + feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w) + feature_3d = self.resblocks_3d(feature_3d) + + # Transforming feature representation according to deformation and occlusion + output_dict = {} + if self.dense_motion_network is not None: + dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving, + kp_source=kp_source) + output_dict['mask'] = dense_motion['mask'] + + # import pdb; pdb.set_trace() + + if 'occlusion_map' in dense_motion: + occlusion_map = dense_motion['occlusion_map'] + output_dict['occlusion_map'] = occlusion_map + else: + occlusion_map = None + deformation = dense_motion['deformation'] + out = self.deform_input(feature_3d, deformation) + + bs, c, d, h, w = out.shape + out = out.view(bs, c*d, h, w) + out = self.third(out) + out = self.fourth(out) + + # occlusion_map = torch.where(occlusion_map < 0.95, 0, occlusion_map) + + if occlusion_map is not None: + if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: + occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear') + out = out * occlusion_map + + # Decoding part + out = self.decoder(out) + + output_dict["prediction"] = out + + return output_dict \ No newline at end of file diff --git a/scripts/facerender/modules/keypoint_detector.py b/scripts/facerender/modules/keypoint_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..6f49c0bf16619bff55925ae97632974e8ec5df0d --- /dev/null +++ b/scripts/facerender/modules/keypoint_detector.py @@ -0,0 +1,179 @@ +from torch import nn +import torch +import torch.nn.functional as F + +from scripts.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d +from scripts.facerender.modules.util import KPHourglass, make_coordinate_grid, AntiAliasInterpolation2d, ResBottleneck + + +class KPDetector(nn.Module): + """ + Detecting canonical keypoints. Return keypoint position and jacobian near each keypoint. 
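KPDetector below recovers coordinates from heatmaps with a soft-argmax: a temperature softmax turns each keypoint channel into a probability volume, and the expected coordinate under that distribution is the keypoint location. A self-contained sketch with toy shapes:

```python
import torch
import torch.nn.functional as F

bs, num_kp, d, h, w = 1, 15, 4, 8, 8
prediction = torch.randn(bs, num_kp, d, h, w)  # raw keypoint logits
temperature = 0.1

# Softmax over all d*h*w positions of each keypoint channel.
heatmap = F.softmax(prediction.view(bs, num_kp, -1) / temperature, dim=2)
heatmap = heatmap.view(bs, num_kp, d, h, w)

# Coordinate grid in [-1, 1], ordered (x, y, z) as in make_coordinate_grid.
zz, yy, xx = torch.meshgrid(
    torch.linspace(-1, 1, d), torch.linspace(-1, 1, h),
    torch.linspace(-1, 1, w), indexing='ij')
grid = torch.stack([xx, yy, zz], dim=-1).view(1, 1, d, h, w, 3)

value = (heatmap.unsqueeze(-1) * grid).sum(dim=(2, 3, 4))
print(value.shape)  # torch.Size([1, 15, 3]) expected (x, y, z) per keypoint
```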
+ """ + + def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, reshape_channel, reshape_depth, + num_blocks, temperature, estimate_jacobian=False, scale_factor=1, single_jacobian_map=False): + super(KPDetector, self).__init__() + + self.predictor = KPHourglass(block_expansion, in_features=image_channel, + max_features=max_features, reshape_features=reshape_channel, reshape_depth=reshape_depth, num_blocks=num_blocks) + + # self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=7, padding=3) + self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=3, padding=1) + + if estimate_jacobian: + self.num_jacobian_maps = 1 if single_jacobian_map else num_kp + # self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=7, padding=3) + self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=3, padding=1) + ''' + initial as: + [[1 0 0] + [0 1 0] + [0 0 1]] + ''' + self.jacobian.weight.data.zero_() + self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float)) + else: + self.jacobian = None + + self.temperature = temperature + self.scale_factor = scale_factor + if self.scale_factor != 1: + self.down = AntiAliasInterpolation2d(image_channel, self.scale_factor) + + def gaussian2kp(self, heatmap): + """ + Extract the mean from a heatmap + """ + shape = heatmap.shape + heatmap = heatmap.unsqueeze(-1) + grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) + value = (heatmap * grid).sum(dim=(2, 3, 4)) + kp = {'value': value} + + return kp + + def forward(self, x): + if self.scale_factor != 1: + x = self.down(x) + + feature_map = self.predictor(x) + prediction = self.kp(feature_map) + + final_shape = prediction.shape + heatmap = prediction.view(final_shape[0], final_shape[1], -1) + heatmap = F.softmax(heatmap / self.temperature, dim=2) + heatmap = heatmap.view(*final_shape) + + out = self.gaussian2kp(heatmap) + + if self.jacobian is not None: + jacobian_map = self.jacobian(feature_map) + jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 9, final_shape[2], + final_shape[3], final_shape[4]) + heatmap = heatmap.unsqueeze(2) + + jacobian = heatmap * jacobian_map + jacobian = jacobian.view(final_shape[0], final_shape[1], 9, -1) + jacobian = jacobian.sum(dim=-1) + jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 3, 3) + out['jacobian'] = jacobian + + return out + + +class HEEstimator(nn.Module): + """ + Estimating head pose and expression. 
+ """ + + def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, num_bins=66, estimate_jacobian=True): + super(HEEstimator, self).__init__() + + self.conv1 = nn.Conv2d(in_channels=image_channel, out_channels=block_expansion, kernel_size=7, padding=3, stride=2) + self.norm1 = BatchNorm2d(block_expansion, affine=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.conv2 = nn.Conv2d(in_channels=block_expansion, out_channels=256, kernel_size=1) + self.norm2 = BatchNorm2d(256, affine=True) + + self.block1 = nn.Sequential() + for i in range(3): + self.block1.add_module('b1_'+ str(i), ResBottleneck(in_features=256, stride=1)) + + self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1) + self.norm3 = BatchNorm2d(512, affine=True) + self.block2 = ResBottleneck(in_features=512, stride=2) + + self.block3 = nn.Sequential() + for i in range(3): + self.block3.add_module('b3_'+ str(i), ResBottleneck(in_features=512, stride=1)) + + self.conv4 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1) + self.norm4 = BatchNorm2d(1024, affine=True) + self.block4 = ResBottleneck(in_features=1024, stride=2) + + self.block5 = nn.Sequential() + for i in range(5): + self.block5.add_module('b5_'+ str(i), ResBottleneck(in_features=1024, stride=1)) + + self.conv5 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1) + self.norm5 = BatchNorm2d(2048, affine=True) + self.block6 = ResBottleneck(in_features=2048, stride=2) + + self.block7 = nn.Sequential() + for i in range(2): + self.block7.add_module('b7_'+ str(i), ResBottleneck(in_features=2048, stride=1)) + + self.fc_roll = nn.Linear(2048, num_bins) + self.fc_pitch = nn.Linear(2048, num_bins) + self.fc_yaw = nn.Linear(2048, num_bins) + + self.fc_t = nn.Linear(2048, 3) + + self.fc_exp = nn.Linear(2048, 3*num_kp) + + def forward(self, x): + out = self.conv1(x) + out = self.norm1(out) + out = F.relu(out) + out = self.maxpool(out) + + out = self.conv2(out) + out = self.norm2(out) + out = F.relu(out) + + out = self.block1(out) + + out = self.conv3(out) + out = self.norm3(out) + out = F.relu(out) + out = self.block2(out) + + out = self.block3(out) + + out = self.conv4(out) + out = self.norm4(out) + out = F.relu(out) + out = self.block4(out) + + out = self.block5(out) + + out = self.conv5(out) + out = self.norm5(out) + out = F.relu(out) + out = self.block6(out) + + out = self.block7(out) + + out = F.adaptive_avg_pool2d(out, 1) + out = out.view(out.shape[0], -1) + + yaw = self.fc_roll(out) + pitch = self.fc_pitch(out) + roll = self.fc_yaw(out) + t = self.fc_t(out) + exp = self.fc_exp(out) + + return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp} + diff --git a/scripts/facerender/modules/make_animation.py b/scripts/facerender/modules/make_animation.py new file mode 100644 index 0000000000000000000000000000000000000000..3360c53501a064f35d7db21a5361f89aa9658b42 --- /dev/null +++ b/scripts/facerender/modules/make_animation.py @@ -0,0 +1,170 @@ +from scipy.spatial import ConvexHull +import torch +import torch.nn.functional as F +import numpy as np +from tqdm import tqdm + +def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False, + use_relative_movement=False, use_relative_jacobian=False): + if adapt_movement_scale: + source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume + driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume + adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) 
+ else: + adapt_movement_scale = 1 + + kp_new = {k: v for k, v in kp_driving.items()} + + if use_relative_movement: + kp_value_diff = (kp_driving['value'] - kp_driving_initial['value']) + kp_value_diff *= adapt_movement_scale + kp_new['value'] = kp_value_diff + kp_source['value'] + + if use_relative_jacobian: + jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian'])) + kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian']) + + return kp_new + +def headpose_pred_to_degree(pred): + device = pred.device + idx_tensor = [idx for idx in range(66)] + idx_tensor = torch.FloatTensor(idx_tensor).type_as(pred).to(device) + pred = F.softmax(pred) + degree = torch.sum(pred*idx_tensor, 1) * 3 - 99 + return degree + +def get_rotation_matrix(yaw, pitch, roll): + yaw = yaw / 180 * 3.14 + pitch = pitch / 180 * 3.14 + roll = roll / 180 * 3.14 + + roll = roll.unsqueeze(1) + pitch = pitch.unsqueeze(1) + yaw = yaw.unsqueeze(1) + + pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch), + torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch), + torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1) + pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3) + + yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw), + torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw), + -torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1) + yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3) + + roll_mat = torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll), + torch.sin(roll), torch.cos(roll), torch.zeros_like(roll), + torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1) + roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3) + + rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat) + + return rot_mat + +def keypoint_transformation(kp_canonical, he, wo_exp=False): + kp = kp_canonical['value'] # (bs, k, 3) + yaw, pitch, roll= he['yaw'], he['pitch'], he['roll'] + yaw = headpose_pred_to_degree(yaw) + pitch = headpose_pred_to_degree(pitch) + roll = headpose_pred_to_degree(roll) + + if 'yaw_in' in he: + yaw = he['yaw_in'] + if 'pitch_in' in he: + pitch = he['pitch_in'] + if 'roll_in' in he: + roll = he['roll_in'] + + rot_mat = get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3) + + t, exp = he['t'], he['exp'] + if wo_exp: + exp = exp*0 + + # keypoint rotation + kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp) + + # keypoint translation + t[:, 0] = t[:, 0]*0 + t[:, 2] = t[:, 2]*0 + t = t.unsqueeze(1).repeat(1, kp.shape[1], 1) + kp_t = kp_rotated + t + + # add expression deviation + exp = exp.view(exp.shape[0], -1, 3) + kp_transformed = kp_t + exp + + return {'value': kp_transformed} + + + +def make_animation(source_image, source_semantics, target_semantics, + generator, kp_detector, he_estimator, mapping, + yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None, + use_exp=True, use_half=False): + with torch.no_grad(): + predictions = [] + + kp_canonical = kp_detector(source_image) + he_source = mapping(source_semantics) + kp_source = keypoint_transformation(kp_canonical, he_source) + + for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'): + # still check the dimension + # print(target_semantics.shape, source_semantics.shape) + target_semantics_frame = target_semantics[:, frame_idx] + he_driving = mapping(target_semantics_frame) + if yaw_c_seq is not None: + he_driving['yaw_in'] = yaw_c_seq[:, 
frame_idx] + if pitch_c_seq is not None: + he_driving['pitch_in'] = pitch_c_seq[:, frame_idx] + if roll_c_seq is not None: + he_driving['roll_in'] = roll_c_seq[:, frame_idx] + + kp_driving = keypoint_transformation(kp_canonical, he_driving) + + kp_norm = kp_driving + out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm) + ''' + source_image_new = out['prediction'].squeeze(1) + kp_canonical_new = kp_detector(source_image_new) + he_source_new = he_estimator(source_image_new) + kp_source_new = keypoint_transformation(kp_canonical_new, he_source_new, wo_exp=True) + kp_driving_new = keypoint_transformation(kp_canonical_new, he_driving, wo_exp=True) + out = generator(source_image_new, kp_source=kp_source_new, kp_driving=kp_driving_new) + ''' + predictions.append(out['prediction']) + predictions_ts = torch.stack(predictions, dim=1) + return predictions_ts + +class AnimateModel(torch.nn.Module): + """ + Merge all generator related updates into single model for better multi-gpu usage + """ + + def __init__(self, generator, kp_extractor, mapping): + super(AnimateModel, self).__init__() + self.kp_extractor = kp_extractor + self.generator = generator + self.mapping = mapping + + self.kp_extractor.eval() + self.generator.eval() + self.mapping.eval() + + def forward(self, x): + + source_image = x['source_image'] + source_semantics = x['source_semantics'] + target_semantics = x['target_semantics'] + yaw_c_seq = x['yaw_c_seq'] + pitch_c_seq = x['pitch_c_seq'] + roll_c_seq = x['roll_c_seq'] + + predictions_video = make_animation(source_image, source_semantics, target_semantics, + self.generator, self.kp_extractor, + self.mapping, use_exp = True, + yaw_c_seq=yaw_c_seq, pitch_c_seq=pitch_c_seq, roll_c_seq=roll_c_seq) + + return predictions_video \ No newline at end of file diff --git a/scripts/facerender/modules/mapping.py b/scripts/facerender/modules/mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..0e3a1c2d1770996080c08e9daafb346f05d7bcdd --- /dev/null +++ b/scripts/facerender/modules/mapping.py @@ -0,0 +1,47 @@ +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MappingNet(nn.Module): + def __init__(self, coeff_nc, descriptor_nc, layer, num_kp, num_bins): + super( MappingNet, self).__init__() + + self.layer = layer + nonlinearity = nn.LeakyReLU(0.1) + + self.first = nn.Sequential( + torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True)) + + for i in range(layer): + net = nn.Sequential(nonlinearity, + torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3)) + setattr(self, 'encoder' + str(i), net) + + self.pooling = nn.AdaptiveAvgPool1d(1) + self.output_nc = descriptor_nc + + self.fc_roll = nn.Linear(descriptor_nc, num_bins) + self.fc_pitch = nn.Linear(descriptor_nc, num_bins) + self.fc_yaw = nn.Linear(descriptor_nc, num_bins) + self.fc_t = nn.Linear(descriptor_nc, 3) + self.fc_exp = nn.Linear(descriptor_nc, 3*num_kp) + + def forward(self, input_3dmm): + out = self.first(input_3dmm) + for i in range(self.layer): + model = getattr(self, 'encoder' + str(i)) + out = model(out) + out[:,:,3:-3] + out = self.pooling(out) + out = out.view(out.shape[0], -1) + #print('out:', out.shape) + + yaw = self.fc_yaw(out) + pitch = self.fc_pitch(out) + roll = self.fc_roll(out) + t = self.fc_t(out) + exp = self.fc_exp(out) + + return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp} \ No newline at end of file diff --git a/scripts/facerender/modules/util.py 
b/scripts/facerender/modules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..8299d19414fb5e3b376026a5ae3f048e0ade534d --- /dev/null +++ b/scripts/facerender/modules/util.py @@ -0,0 +1,564 @@ +from torch import nn + +import torch.nn.functional as F +import torch + +from scripts.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d +from scripts.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d + +import torch.nn.utils.spectral_norm as spectral_norm + + +def kp2gaussian(kp, spatial_size, kp_variance): + """ + Transform a keypoint into gaussian like representation + """ + mean = kp['value'] + + coordinate_grid = make_coordinate_grid(spatial_size, mean.type()) + number_of_leading_dimensions = len(mean.shape) - 1 + shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape + coordinate_grid = coordinate_grid.view(*shape) + repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1) + coordinate_grid = coordinate_grid.repeat(*repeats) + + # Preprocess kp shape + shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3) + mean = mean.view(*shape) + + mean_sub = (coordinate_grid - mean) + + out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance) + + return out + +def make_coordinate_grid_2d(spatial_size, type): + """ + Create a meshgrid [-1,1] x [-1,1] of given spatial_size. + """ + h, w = spatial_size + x = torch.arange(w).type(type) + y = torch.arange(h).type(type) + + x = (2 * (x / (w - 1)) - 1) + y = (2 * (y / (h - 1)) - 1) + + yy = y.view(-1, 1).repeat(1, w) + xx = x.view(1, -1).repeat(h, 1) + + meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) + + return meshed + + +def make_coordinate_grid(spatial_size, type): + d, h, w = spatial_size + x = torch.arange(w).type(type) + y = torch.arange(h).type(type) + z = torch.arange(d).type(type) + + x = (2 * (x / (w - 1)) - 1) + y = (2 * (y / (h - 1)) - 1) + z = (2 * (z / (d - 1)) - 1) + + yy = y.view(1, -1, 1).repeat(d, 1, w) + xx = x.view(1, 1, -1).repeat(d, h, 1) + zz = z.view(-1, 1, 1).repeat(1, h, w) + + meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3) + + return meshed + + +class ResBottleneck(nn.Module): + def __init__(self, in_features, stride): + super(ResBottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features//4, kernel_size=1) + self.conv2 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features, kernel_size=1) + self.norm1 = BatchNorm2d(in_features//4, affine=True) + self.norm2 = BatchNorm2d(in_features//4, affine=True) + self.norm3 = BatchNorm2d(in_features, affine=True) + + self.stride = stride + if self.stride != 1: + self.skip = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=1, stride=stride) + self.norm4 = BatchNorm2d(in_features, affine=True) + + def forward(self, x): + out = self.conv1(x) + out = self.norm1(out) + out = F.relu(out) + out = self.conv2(out) + out = self.norm2(out) + out = F.relu(out) + out = self.conv3(out) + out = self.norm3(out) + if self.stride != 1: + x = self.skip(x) + x = self.norm4(x) + out += x + out = F.relu(out) + return out + + +class ResBlock2d(nn.Module): + """ + Res block, preserve spatial resolution. 
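The residual blocks in this file use pre-activation ordering (norm, ReLU, conv, twice) around an identity skip, so both resolution and channel count are preserved. A compact sketch of the same pattern, substituting plain BatchNorm2d for the synchronized variant used here:

```python
import torch
from torch import nn
import torch.nn.functional as F

class PreActResBlock2d(nn.Module):
    def __init__(self, features, kernel_size=3, padding=1):
        super().__init__()
        self.conv1 = nn.Conv2d(features, features, kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(features, features, kernel_size, padding=padding)
        self.norm1 = nn.BatchNorm2d(features)
        self.norm2 = nn.BatchNorm2d(features)

    def forward(self, x):
        out = self.conv1(F.relu(self.norm1(x)))
        out = self.conv2(F.relu(self.norm2(out)))
        return out + x  # identity skip; shape is unchanged

x = torch.randn(1, 32, 16, 16)
print(PreActResBlock2d(32)(x).shape)  # torch.Size([1, 32, 16, 16])
```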
+ """ + + def __init__(self, in_features, kernel_size, padding): + super(ResBlock2d, self).__init__() + self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.norm1 = BatchNorm2d(in_features, affine=True) + self.norm2 = BatchNorm2d(in_features, affine=True) + + def forward(self, x): + out = self.norm1(x) + out = F.relu(out) + out = self.conv1(out) + out = self.norm2(out) + out = F.relu(out) + out = self.conv2(out) + out += x + return out + + +class ResBlock3d(nn.Module): + """ + Res block, preserve spatial resolution. + """ + + def __init__(self, in_features, kernel_size, padding): + super(ResBlock3d, self).__init__() + self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.norm1 = BatchNorm3d(in_features, affine=True) + self.norm2 = BatchNorm3d(in_features, affine=True) + + def forward(self, x): + out = self.norm1(x) + out = F.relu(out) + out = self.conv1(out) + out = self.norm2(out) + out = F.relu(out) + out = self.conv2(out) + out += x + return out + + +class UpBlock2d(nn.Module): + """ + Upsampling block for use in decoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(UpBlock2d, self).__init__() + + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = BatchNorm2d(out_features, affine=True) + + def forward(self, x): + out = F.interpolate(x, scale_factor=2) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + +class UpBlock3d(nn.Module): + """ + Upsampling block for use in decoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(UpBlock3d, self).__init__() + + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = BatchNorm3d(out_features, affine=True) + + def forward(self, x): + # out = F.interpolate(x, scale_factor=(1, 2, 2), mode='trilinear') + out = F.interpolate(x, scale_factor=(1, 2, 2)) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + + +class DownBlock2d(nn.Module): + """ + Downsampling block for use in encoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = BatchNorm2d(out_features, affine=True) + self.pool = nn.AvgPool2d(kernel_size=(2, 2)) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.pool(out) + return out + + +class DownBlock3d(nn.Module): + """ + Downsampling block for use in encoder. 
+ """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(DownBlock3d, self).__init__() + ''' + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups, stride=(1, 2, 2)) + ''' + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = BatchNorm3d(out_features, affine=True) + self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2)) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.pool(out) + return out + + +class SameBlock2d(nn.Module): + """ + Simple block, preserve spatial resolution. + """ + + def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False): + super(SameBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, + kernel_size=kernel_size, padding=padding, groups=groups) + self.norm = BatchNorm2d(out_features, affine=True) + if lrelu: + self.ac = nn.LeakyReLU() + else: + self.ac = nn.ReLU() + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = self.ac(out) + return out + + +class Encoder(nn.Module): + """ + Hourglass Encoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Encoder, self).__init__() + + down_blocks = [] + for i in range(num_blocks): + down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), + min(max_features, block_expansion * (2 ** (i + 1))), + kernel_size=3, padding=1)) + self.down_blocks = nn.ModuleList(down_blocks) + + def forward(self, x): + outs = [x] + for down_block in self.down_blocks: + outs.append(down_block(outs[-1])) + return outs + + +class Decoder(nn.Module): + """ + Hourglass Decoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Decoder, self).__init__() + + up_blocks = [] + + for i in range(num_blocks)[::-1]: + in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) + out_filters = min(max_features, block_expansion * (2 ** i)) + up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1)) + + self.up_blocks = nn.ModuleList(up_blocks) + # self.out_filters = block_expansion + self.out_filters = block_expansion + in_features + + self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1) + self.norm = BatchNorm3d(self.out_filters, affine=True) + + def forward(self, x): + out = x.pop() + # for up_block in self.up_blocks[:-1]: + for up_block in self.up_blocks: + out = up_block(out) + skip = x.pop() + out = torch.cat([out, skip], dim=1) + # out = self.up_blocks[-1](out) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + + +class Hourglass(nn.Module): + """ + Hourglass architecture. + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Hourglass, self).__init__() + self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) + self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) + self.out_filters = self.decoder.out_filters + + def forward(self, x): + return self.decoder(self.encoder(x)) + + +class KPHourglass(nn.Module): + """ + Hourglass architecture. 
+ """ + + def __init__(self, block_expansion, in_features, reshape_features, reshape_depth, num_blocks=3, max_features=256): + super(KPHourglass, self).__init__() + + self.down_blocks = nn.Sequential() + for i in range(num_blocks): + self.down_blocks.add_module('down'+ str(i), DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), + min(max_features, block_expansion * (2 ** (i + 1))), + kernel_size=3, padding=1)) + + in_filters = min(max_features, block_expansion * (2 ** num_blocks)) + self.conv = nn.Conv2d(in_channels=in_filters, out_channels=reshape_features, kernel_size=1) + + self.up_blocks = nn.Sequential() + for i in range(num_blocks): + in_filters = min(max_features, block_expansion * (2 ** (num_blocks - i))) + out_filters = min(max_features, block_expansion * (2 ** (num_blocks - i - 1))) + self.up_blocks.add_module('up'+ str(i), UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1)) + + self.reshape_depth = reshape_depth + self.out_filters = out_filters + + def forward(self, x): + out = self.down_blocks(x) + out = self.conv(out) + bs, c, h, w = out.shape + out = out.view(bs, c//self.reshape_depth, self.reshape_depth, h, w) + out = self.up_blocks(out) + + return out + + + +class AntiAliasInterpolation2d(nn.Module): + """ + Band-limited downsampling, for better preservation of the input signal. + """ + def __init__(self, channels, scale): + super(AntiAliasInterpolation2d, self).__init__() + sigma = (1 / scale - 1) / 2 + kernel_size = 2 * round(sigma * 4) + 1 + self.ka = kernel_size // 2 + self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka + + kernel_size = [kernel_size, kernel_size] + sigma = [sigma, sigma] + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid( + [ + torch.arange(size, dtype=torch.float32) + for size in kernel_size + ] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2)) + + # Make sure sum of values in gaussian kernel equals 1. 
+ kernel = kernel / torch.sum(kernel) + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer('weight', kernel) + self.groups = channels + self.scale = scale + inv_scale = 1 / scale + self.int_inv_scale = int(inv_scale) + + def forward(self, input): + if self.scale == 1.0: + return input + + out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) + out = F.conv2d(out, weight=self.weight, groups=self.groups) + out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale] + + return out + + +class SPADE(nn.Module): + def __init__(self, norm_nc, label_nc): + super().__init__() + + self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) + nhidden = 128 + + self.mlp_shared = nn.Sequential( + nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1), + nn.ReLU()) + self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) + self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) + + def forward(self, x, segmap): + normalized = self.param_free_norm(x) + segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') + actv = self.mlp_shared(segmap) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + out = normalized * (1 + gamma) + beta + return out + + +class SPADEResnetBlock(nn.Module): + def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1): + super().__init__() + # Attributes + self.learned_shortcut = (fin != fout) + fmiddle = min(fin, fout) + self.use_se = use_se + # create conv layers + self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation) + self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation) + if self.learned_shortcut: + self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) + # apply spectral norm if specified + if 'spectral' in norm_G: + self.conv_0 = spectral_norm(self.conv_0) + self.conv_1 = spectral_norm(self.conv_1) + if self.learned_shortcut: + self.conv_s = spectral_norm(self.conv_s) + # define normalization layers + self.norm_0 = SPADE(fin, label_nc) + self.norm_1 = SPADE(fmiddle, label_nc) + if self.learned_shortcut: + self.norm_s = SPADE(fin, label_nc) + + def forward(self, x, seg1): + x_s = self.shortcut(x, seg1) + dx = self.conv_0(self.actvn(self.norm_0(x, seg1))) + dx = self.conv_1(self.actvn(self.norm_1(dx, seg1))) + out = x_s + dx + return out + + def shortcut(self, x, seg1): + if self.learned_shortcut: + x_s = self.conv_s(self.norm_s(x, seg1)) + else: + x_s = x + return x_s + + def actvn(self, x): + return F.leaky_relu(x, 2e-1) + +class audio2image(nn.Module): + def __init__(self, generator, kp_extractor, he_estimator_video, he_estimator_audio, train_params): + super().__init__() + # Attributes + self.generator = generator + self.kp_extractor = kp_extractor + self.he_estimator_video = he_estimator_video + self.he_estimator_audio = he_estimator_audio + self.train_params = train_params + + def headpose_pred_to_degree(self, pred): + device = pred.device + idx_tensor = [idx for idx in range(66)] + idx_tensor = torch.FloatTensor(idx_tensor).to(device) + pred = F.softmax(pred) + degree = torch.sum(pred*idx_tensor, 1) * 3 - 99 + + return degree + + def get_rotation_matrix(self, yaw, pitch, roll): + yaw = yaw / 180 * 3.14 + pitch = pitch / 180 * 3.14 + roll = roll / 180 * 3.14 + + roll = roll.unsqueeze(1) + pitch = pitch.unsqueeze(1) + yaw = yaw.unsqueeze(1) + + roll_mat = torch.cat([torch.ones_like(roll), 
torch.zeros_like(roll), torch.zeros_like(roll), + torch.zeros_like(roll), torch.cos(roll), -torch.sin(roll), + torch.zeros_like(roll), torch.sin(roll), torch.cos(roll)], dim=1) + roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3) + + pitch_mat = torch.cat([torch.cos(pitch), torch.zeros_like(pitch), torch.sin(pitch), + torch.zeros_like(pitch), torch.ones_like(pitch), torch.zeros_like(pitch), + -torch.sin(pitch), torch.zeros_like(pitch), torch.cos(pitch)], dim=1) + pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3) + + yaw_mat = torch.cat([torch.cos(yaw), -torch.sin(yaw), torch.zeros_like(yaw), + torch.sin(yaw), torch.cos(yaw), torch.zeros_like(yaw), + torch.zeros_like(yaw), torch.zeros_like(yaw), torch.ones_like(yaw)], dim=1) + yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3) + + rot_mat = torch.einsum('bij,bjk,bkm->bim', roll_mat, pitch_mat, yaw_mat) + + return rot_mat + + def keypoint_transformation(self, kp_canonical, he): + kp = kp_canonical['value'] # (bs, k, 3) + yaw, pitch, roll = he['yaw'], he['pitch'], he['roll'] + t, exp = he['t'], he['exp'] + + yaw = self.headpose_pred_to_degree(yaw) + pitch = self.headpose_pred_to_degree(pitch) + roll = self.headpose_pred_to_degree(roll) + + rot_mat = self.get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3) + + # keypoint rotation + kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp) + + + + # keypoint translation + t = t.unsqueeze_(1).repeat(1, kp.shape[1], 1) + kp_t = kp_rotated + t + + # add expression deviation + exp = exp.view(exp.shape[0], -1, 3) + kp_transformed = kp_t + exp + + return {'value': kp_transformed} + + def forward(self, source_image, target_audio): + pose_source = self.he_estimator_video(source_image) + pose_generated = self.he_estimator_audio(target_audio) + kp_canonical = self.kp_extractor(source_image) + kp_source = self.keypoint_transformation(kp_canonical, pose_source) + kp_transformed_generated = self.keypoint_transformation(kp_canonical, pose_generated) + generated = self.generator(source_image, kp_source=kp_source, kp_driving=kp_transformed_generated) + return generated \ No newline at end of file diff --git a/scripts/facerender/sync_batchnorm/__init__.py b/scripts/facerender/sync_batchnorm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bc8709d92c610b36e0bcbd7da20c1eb41dc8cfcf --- /dev/null +++ b/scripts/facerender/sync_batchnorm/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. 
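A hedged usage sketch for this package: a model containing synchronized BatchNorm layers is wrapped with DataParallelWithCallback (both names are exported by the imports below) so per-device replicas can share batch statistics; on a single device it behaves like ordinary BatchNorm.

```python
import torch
from torch import nn
from scripts.facerender.sync_batchnorm import (
    SynchronizedBatchNorm2d, DataParallelWithCallback)

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1),
                      SynchronizedBatchNorm2d(8))
if torch.cuda.device_count() > 1:
    # The replication callback lets each replica register with the master.
    model = DataParallelWithCallback(model.cuda(), device_ids=[0, 1])
```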
+ +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/scripts/facerender/sync_batchnorm/__pycache__/__init__.cpython-310.pyc b/scripts/facerender/sync_batchnorm/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40452d69da02b88698cbe2122d629fa3ee3cdb41 Binary files /dev/null and b/scripts/facerender/sync_batchnorm/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-310.pyc b/scripts/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57cf019e6374452b3657fb701922e7f916cad7be Binary files /dev/null and b/scripts/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-310.pyc differ diff --git a/scripts/facerender/sync_batchnorm/__pycache__/comm.cpython-310.pyc b/scripts/facerender/sync_batchnorm/__pycache__/comm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef55357d8abfcc7a3d81f1ef906130d66dd8be6e Binary files /dev/null and b/scripts/facerender/sync_batchnorm/__pycache__/comm.cpython-310.pyc differ diff --git a/scripts/facerender/sync_batchnorm/__pycache__/replicate.cpython-310.pyc b/scripts/facerender/sync_batchnorm/__pycache__/replicate.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6229dfe96b876bcad4fa4d767900df1756b6edd Binary files /dev/null and b/scripts/facerender/sync_batchnorm/__pycache__/replicate.cpython-310.pyc differ diff --git a/scripts/facerender/sync_batchnorm/batchnorm.py b/scripts/facerender/sync_batchnorm/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..5f4e763f0366dffa10320116413f8c7181a8aeb1 --- /dev/null +++ b/scripts/facerender/sync_batchnorm/batchnorm.py @@ -0,0 +1,315 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import collections + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast + +from .comm import SyncMaster + +__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dementions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 
+ if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. + # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + + return mean, bias_var.clamp(self.eps) ** -0.5 + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. 
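The formula above can be checked numerically against the built-in layer; on a single device the synchronized and stock implementations agree, and training mode normalizes with the biased batch variance:

```python
import torch
from torch import nn

x = torch.randn(20, 100)
bn = nn.BatchNorm1d(100)
bn.train()
y = bn(x)

mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)  # biased variance is used to normalize
y_manual = (x - mean) / torch.sqrt(var + bn.eps) * bn.weight + bn.bias
print(torch.allclose(y, y_manual, atol=1e-5))  # True
```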
+ + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm1d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. 
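The running-estimate rule described here is an exponential moving average: with the default momentum of 0.1, each step blends ten percent of the fresh batch statistic into the stored one, the same update _SynchronizedBatchNorm applies in _compute_mean_std:

```python
import torch

momentum = 0.1
running_mean = torch.zeros(3)               # initial running estimate
batch_mean = torch.tensor([1.0, 2.0, 3.0])  # statistic from the current batch

running_mean = (1 - momentum) * running_mean + momentum * batch_mean
print(running_mean)  # tensor([0.1000, 0.2000, 0.3000])
```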
+ + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm2d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. 
+
+    Shape:
+        - Input: :math:`(N, C, D, H, W)`
+        - Output: :math:`(N, C, D, H, W)` (same shape as input)
+
+    Examples:
+        >>> # With Learnable Parameters
+        >>> m = SynchronizedBatchNorm3d(100)
+        >>> # Without Learnable Parameters
+        >>> m = SynchronizedBatchNorm3d(100, affine=False)
+        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
+        >>> output = m(input)
+    """
+
+    def _check_input_dim(self, input):
+        if input.dim() != 5:
+            raise ValueError('expected 5D input (got {}D input)'
+                             .format(input.dim()))
+        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
diff --git a/scripts/facerender/sync_batchnorm/comm.py b/scripts/facerender/sync_batchnorm/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..922f8c4a3adaa9b32fdcaef09583be03b0d7eb2b
--- /dev/null
+++ b/scripts/facerender/sync_batchnorm/comm.py
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+# File : comm.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import queue
+import collections
+import threading
+
+__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
+
+
+class FutureResult(object):
+    """A thread-safe future implementation. Used only as a one-to-one pipe."""
+
+    def __init__(self):
+        self._result = None
+        self._lock = threading.Lock()
+        self._cond = threading.Condition(self._lock)
+
+    def put(self, result):
+        with self._lock:
+            assert self._result is None, 'Previous result hasn\'t been fetched.'
+            self._result = result
+            self._cond.notify()
+
+    def get(self):
+        with self._lock:
+            if self._result is None:
+                self._cond.wait()
+
+            res = self._result
+            self._result = None
+            return res
+
+
+_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
+_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
+
+
+class SlavePipe(_SlavePipeBase):
+    """Pipe for master-slave communication."""
+
+    def run_slave(self, msg):
+        self.queue.put((self.identifier, msg))
+        ret = self.result.get()
+        self.queue.put(True)
+        return ret
+
+
+class SyncMaster(object):
+    """An abstract `SyncMaster` object.
+
+    - During replication, as the data parallel will trigger a callback on each module, all slave devices should
+    call `register(id)` and obtain a `SlavePipe` to communicate with the master.
+    - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are
+    collected and passed to a registered callback.
+    - After receiving the messages, the master device should gather the information and determine the message to be
+    passed back to each slave device.
+    """
+
+    def __init__(self, master_callback):
+        """
+
+        Args:
+            master_callback: a callback to be invoked after having collected messages from slave devices.
+        """
+        self._master_callback = master_callback
+        self._queue = queue.Queue()
+        self._registry = collections.OrderedDict()
+        self._activated = False
+
+    def __getstate__(self):
+        return {'master_callback': self._master_callback}
+
+    def __setstate__(self, state):
+        self.__init__(state['master_callback'])
+
+    def register_slave(self, identifier):
+        """
+        Register a slave device.
+
+        Args:
+            identifier: an identifier, usually the device id.
+
+        Returns: a `SlavePipe` object which can be used to communicate with the master device.
+
+        """
+        if self._activated:
+            assert self._queue.empty(), 'Queue is not clean before next initialization.'
+            self._activated = False
+            self._registry.clear()
+        future = FutureResult()
+        self._registry[identifier] = _MasterRegistry(future)
+        return SlavePipe(identifier, self._queue, future)
+
+    def run_master(self, master_msg):
+        """
+        Main entry for the master device in each forward pass.
+        The messages are first collected from each device (including the master device), and then
+        a callback is invoked to compute the message to be sent back to each device
+        (including the master device).
+
+        Args:
+            master_msg: the message that the master wants to send to itself. This will be placed as the first
+            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
+
+        Returns: the message to be sent back to the master device.
+
+        """
+        self._activated = True
+
+        intermediates = [(0, master_msg)]
+        for i in range(self.nr_slaves):
+            intermediates.append(self._queue.get())
+
+        results = self._master_callback(intermediates)
+        assert results[0][0] == 0, 'The first result should belong to the master.'
+
+        for i, res in results:
+            if i == 0:
+                continue
+            self._registry[i].result.put(res)
+
+        for i in range(self.nr_slaves):
+            assert self._queue.get() is True
+
+        return results[0][1]
+
+    @property
+    def nr_slaves(self):
+        return len(self._registry)
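+
+# Illustrative usage sketch (added for clarity; the names below are the ones
+# defined in this file). The callback passed to `SyncMaster` receives one
+# (identifier, message) pair per device, master (id 0) first, and must return
+# one (identifier, reply) pair per device in the same master-first order:
+#
+#   master = SyncMaster(lambda msgs: [(i, m) for i, m in msgs])  # echo callback
+#   pipe = master.register_slave(1)  # done on each slave copy at replication time
+#
+# In each forward pass, device 0 calls master.run_master(msg0) while device 1
+# calls pipe.run_slave(msg1); each call returns that device's reply.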
diff --git a/scripts/facerender/sync_batchnorm/replicate.py b/scripts/facerender/sync_batchnorm/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..b71c7b8ed51a1d6c55b1f753bdd8d90bad79bd06
--- /dev/null
+++ b/scripts/facerender/sync_batchnorm/replicate.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+# File : replicate.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import functools
+
+from torch.nn.parallel.data_parallel import DataParallel
+
+__all__ = [
+    'CallbackContext',
+    'execute_replication_callbacks',
+    'DataParallelWithCallback',
+    'patch_replication_callback'
+]
+
+
+class CallbackContext(object):
+    pass
+
+
+def execute_replication_callbacks(modules):
+    """
+    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
+
+    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+    Note that, as all modules are isomorphic, we assign each sub-module a context
+    (shared among multiple copies of this module on different devices).
+    Through this context, different copies can share some information.
+
+    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
+    of any slave copies.
+    """
+    master_copy = modules[0]
+    nr_modules = len(list(master_copy.modules()))
+    ctxs = [CallbackContext() for _ in range(nr_modules)]
+
+    for i, module in enumerate(modules):
+        for j, m in enumerate(module.modules()):
+            if hasattr(m, '__data_parallel_replicate__'):
+                m.__data_parallel_replicate__(ctxs[j], i)
+
+
+class DataParallelWithCallback(DataParallel):
+    """
+    Data Parallel with a replication callback.
+
+    A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by
+    the original `replicate` function.
+    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+    Examples:
+        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+        # sync_bn.__data_parallel_replicate__ will be invoked.
+    """
+
+    def replicate(self, module, device_ids):
+        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
+        execute_replication_callbacks(modules)
+        return modules
+
+
+def patch_replication_callback(data_parallel):
+    """
+    Monkey-patch an existing `DataParallel` object. Add the replication callback.
+    Useful when you have a customized `DataParallel` implementation.
+
+    Examples:
+        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
+        > patch_replication_callback(sync_bn)
+        # this is equivalent to
+        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+    """
+
+    assert isinstance(data_parallel, DataParallel)
+
+    old_replicate = data_parallel.replicate
+
+    @functools.wraps(old_replicate)
+    def new_replicate(module, device_ids):
+        modules = old_replicate(module, device_ids)
+        execute_replication_callbacks(modules)
+        return modules
+
+    data_parallel.replicate = new_replicate
diff --git a/scripts/facerender/sync_batchnorm/unittest.py b/scripts/facerender/sync_batchnorm/unittest.py
new file mode 100644
index 0000000000000000000000000000000000000000..0675c022e4ba85d38d1f813490f6740150909524
--- /dev/null
+++ b/scripts/facerender/sync_batchnorm/unittest.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# File : unittest.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
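+#
+# Test helper: `TorchTestCase.assertTensorClose` below compares tensors via
+# numpy's `allclose` and reports the maximum absolute and relative errors.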
+ +import unittest + +import numpy as np +from torch.autograd import Variable + + +def as_numpy(v): + if isinstance(v, Variable): + v = v.data + return v.cpu().numpy() + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): + npa, npb = as_numpy(a), as_numpy(b) + self.assertTrue( + np.allclose(npa, npb, atol=atol), + 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) + ) diff --git a/scripts/generate_batch.py b/scripts/generate_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..e1f2772fbb5f8be9ec6cc9c58911ad3993e9adf3 --- /dev/null +++ b/scripts/generate_batch.py @@ -0,0 +1,120 @@ +import os + +from tqdm import tqdm +import torch +import numpy as np +import random +import scipy.io as scio +import scripts.utils.audio as audio + +def crop_pad_audio(wav, audio_length): + if len(wav) > audio_length: + wav = wav[:audio_length] + elif len(wav) < audio_length: + wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0) + return wav + +def parse_audio_length(audio_length, sr, fps): + bit_per_frames = sr / fps + + num_frames = int(audio_length / bit_per_frames) + audio_length = int(num_frames * bit_per_frames) + + return audio_length, num_frames + +def generate_blink_seq(num_frames): + ratio = np.zeros((num_frames,1)) + frame_id = 0 + while frame_id in range(num_frames): + start = 80 + if frame_id+start+9<=num_frames - 1: + ratio[frame_id+start:frame_id+start+9, 0] = [0.5,0.6,0.7,0.9,1, 0.9, 0.7,0.6,0.5] + frame_id = frame_id+start+9 + else: + break + return ratio + +def generate_blink_seq_randomly(num_frames): + ratio = np.zeros((num_frames,1)) + if num_frames<=20: + return ratio + frame_id = 0 + while frame_id in range(num_frames): + start = random.choice(range(min(10,num_frames), min(int(num_frames/2), 70))) + if frame_id+start+5<=num_frames - 1: + ratio[frame_id+start:frame_id+start+5, 0] = [0.5, 0.9, 1.0, 0.9, 0.5] + frame_id = frame_id+start+5 + else: + break + return ratio + +def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False, idlemode=False, length_of_audio=False, use_blink=True): + + syncnet_mel_step_size = 16 + fps = 25 + + pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0] + audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0] + + + if idlemode: + num_frames = int(length_of_audio * 25) + indiv_mels = np.zeros((num_frames, 80, 16)) + else: + wav = audio.load_wav(audio_path, 16000) + wav_length, num_frames = parse_audio_length(len(wav), 16000, 25) + wav = crop_pad_audio(wav, wav_length) + orig_mel = audio.melspectrogram(wav).T + spec = orig_mel.copy() # nframes 80 + indiv_mels = [] + + for i in tqdm(range(num_frames), 'mel:'): + start_frame_num = i-2 + start_idx = int(80. 
* (start_frame_num / float(fps))) + end_idx = start_idx + syncnet_mel_step_size + seq = list(range(start_idx, end_idx)) + seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ] + m = spec[seq, :] + indiv_mels.append(m.T) + indiv_mels = np.asarray(indiv_mels) # T 80 16 + + ratio = generate_blink_seq_randomly(num_frames) # T + source_semantics_path = first_coeff_path + source_semantics_dict = scio.loadmat(source_semantics_path) + ref_coeff = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70 + ref_coeff = np.repeat(ref_coeff, num_frames, axis=0) + + if ref_eyeblink_coeff_path is not None: + ratio[:num_frames] = 0 + refeyeblink_coeff_dict = scio.loadmat(ref_eyeblink_coeff_path) + refeyeblink_coeff = refeyeblink_coeff_dict['coeff_3dmm'][:,:64] + refeyeblink_num_frames = refeyeblink_coeff.shape[0] + if refeyeblink_num_frames frame_num: + new_degree_list = new_degree_list[:frame_num] + elif len(new_degree_list) < frame_num: + for _ in range(frame_num-len(new_degree_list)): + new_degree_list.append(new_degree_list[-1]) + print(len(new_degree_list)) + print(frame_num) + + remainder = frame_num%batch_size + if remainder!=0: + for _ in range(batch_size-remainder): + new_degree_list.append(new_degree_list[-1]) + new_degree_np = np.array(new_degree_list).reshape(batch_size, -1) + return new_degree_np + diff --git a/scripts/maskscratches/__init__.py b/scripts/maskscratches/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..812e2c21ba36e36ac7122ae3990e450fdfdc225f --- /dev/null +++ b/scripts/maskscratches/__init__.py @@ -0,0 +1,3 @@ +from .detection_models.antialiasing import Downsample +from .detection_models.sync_batchnorm import DataParallelWithCallback +from .scratches_detector import ScratchesDetector diff --git a/scripts/maskscratches/__pycache__/__init__.cpython-310.pyc b/scripts/maskscratches/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32e14ce8c41a114cbff4b4b6176a337e17fd99ce Binary files /dev/null and b/scripts/maskscratches/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/maskscratches/__pycache__/__init__.cpython-311.pyc b/scripts/maskscratches/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae0370e2eb23ccfeda5c2540ca1e95bab100877e Binary files /dev/null and b/scripts/maskscratches/__pycache__/__init__.cpython-311.pyc differ diff --git a/scripts/maskscratches/__pycache__/scratches_detector.cpython-310.pyc b/scripts/maskscratches/__pycache__/scratches_detector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca0c6e14b030da3ec15ff1431af4dc2a6d1e50df Binary files /dev/null and b/scripts/maskscratches/__pycache__/scratches_detector.cpython-310.pyc differ diff --git a/scripts/maskscratches/__pycache__/scratches_detector.cpython-311.pyc b/scripts/maskscratches/__pycache__/scratches_detector.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5818e8e1cad994eef0dd4782e60380a840d8a0f Binary files /dev/null and b/scripts/maskscratches/__pycache__/scratches_detector.cpython-311.pyc differ diff --git a/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/.gitignore b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..7bbc71c09205c78d790739d246bbe4f9f1881c17 --- /dev/null +++ 
b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/.gitignore @@ -0,0 +1,101 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +.venv +venv/ +ENV/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ diff --git a/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/LICENSE b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4c39939e7e3aa940d405030335ec0e6ff2f2a1ee --- /dev/null +++ b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 Jiayuan MAO + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/README.md b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/README.md new file mode 100644 index 0000000000000000000000000000000000000000..779983436c9727dd0d6301a1c857f2360245b51d --- /dev/null +++ b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/README.md @@ -0,0 +1,118 @@ +# Synchronized-BatchNorm-PyTorch + +**IMPORTANT: Please read the "Implementation details and highlights" section before use.** + +Synchronized Batch Normalization implementation in PyTorch. 
+
+This module differs from the built-in PyTorch BatchNorm as the mean and
+standard-deviation are reduced across all devices during training.
+
+For example, when one uses `nn.DataParallel` to wrap the network during
+training, PyTorch's implementation normalizes the tensor on each device using
+the statistics only on that device, which accelerates the computation and
+is easy to implement, but the statistics might be inaccurate.
+Instead, in this synchronized version, the statistics will be computed
+over all training samples distributed on multiple devices.
+
+Note that, for the one-GPU or CPU-only case, this module behaves exactly
+the same as the built-in PyTorch implementation.
+
+This module is currently only a prototype for research use. As mentioned below,
+it has its limitations and may even suffer from some design problems. If you have any
+questions or suggestions, please feel free to
+[open an issue](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues) or
+[submit a pull request](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues).
+
+## Why Synchronized BatchNorm?
+
+Although the typical implementation of BatchNorm working on multiple devices (GPUs)
+is fast (with no communication overhead), it inevitably reduces the batch size
+seen by each device, which can degrade performance. This is not a significant
+issue in some standard vision tasks such as ImageNet classification (as the
+batch size per device is usually large enough to obtain good statistics).
+However, it will hurt the performance in tasks where the batch size is usually
+very small (e.g., 1 per GPU).
+
+For example, the importance of synchronized batch normalization in object detection
+was recently demonstrated with an extensive analysis in the paper
+[MegDet: A Large Mini-Batch Object Detector](https://arxiv.org/abs/1711.07240).
+
+## Usage
+
+To use Synchronized Batch Normalization, we add a data parallel replication callback.
+This introduces a slight difference from the typical usage of `nn.DataParallel`.
+
+Use it with the provided, customized data parallel wrapper:
+
+```python
+from sync_batchnorm import SynchronizedBatchNorm1d, DataParallelWithCallback
+
+sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+```
+
+Or, if you are using a customized data parallel module, you can use this library via monkey patching.
+
+```python
+from torch.nn import DataParallel  # or your customized DataParallel module
+from sync_batchnorm import SynchronizedBatchNorm1d, patch_replication_callback
+
+sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
+patch_replication_callback(sync_bn)  # monkey-patching
+```
+
+You can use `convert_model` to convert your model to use Synchronized BatchNorm easily.
+
+```python
+import torch.nn as nn
+from torchvision import models
+from sync_batchnorm import convert_model
+# m is a standard pytorch model
+m = models.resnet18(True)
+m = nn.DataParallel(m)
+# after conversion, m is using SyncBN
+m = convert_model(m)
+```
+
+See also `tests/test_sync_batchnorm.py` for a numeric result comparison.
+
+## Implementation details and highlights
+
+If you are interested in how batch statistics are reduced and broadcast among multiple devices, please take a look
+at the code with detailed comments. Here we only emphasize some highlights of the implementation:
+
+- This implementation is in pure Python. No C++ extra extension libs.
+- Easy to use as demonstrated above.
+- It uses unbiased variance to update the moving average, and uses `sqrt(max(var, eps))` instead of `sqrt(var + eps)`.
+- The implementation requires that each module on different devices invoke `batchnorm` exactly the SAME
+number of times in each forward pass. For example, you cannot call `batchnorm` on GPU0 but not on GPU1. The `#i
+(i = 1, 2, 3, ...)` calls of `batchnorm` on each device are viewed as a whole, and their statistics are reduced
+together. This is tricky but is a good way to handle PyTorch's dynamic computation graph. Although it sounds
+complicated, it is usually not an issue for most models, as illustrated by the sketch below.
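+
+As a concrete illustration of the invocation-count constraint above, here is a
+minimal sketch (the `TwoPassNet` module is hypothetical and written for this
+README only; the imports are the ones shown in the Usage section):
+
+```python
+import torch.nn as nn
+
+from sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback
+
+
+class TwoPassNet(nn.Module):
+    """Hypothetical module that invokes the same BN layer twice per forward."""
+
+    def __init__(self):
+        super(TwoPassNet, self).__init__()
+        self.bn = SynchronizedBatchNorm2d(8)
+        self.conv = nn.Conv2d(8, 8, 3, padding=1)
+
+    def forward(self, x):
+        x = self.bn(x)     # call #1: reduced with call #1 on every other replica
+        x = self.conv(x)
+        return self.bn(x)  # call #2: a separate synchronization round
+
+
+# Valid: every replica makes exactly two `batchnorm` calls per forward pass.
+# Invalid would be any control flow that calls `self.bn` on one GPU only.
+net = DataParallelWithCallback(TwoPassNet().cuda(), device_ids=[0, 1])
+```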
+
+## Known issues
+
+#### Runtime error on backward pass.
+
+Due to a [PyTorch Bug](https://github.com/pytorch/pytorch/issues/3883), using old PyTorch libraries will trigger a `RuntimeError` with a message like:
+
+```
+Assertion `pos >= 0 && pos < buffer.size()` failed.
+```
+
+This has already been fixed in the newest PyTorch source, which, unfortunately, has not yet been pushed to the official and anaconda binary releases. Thus, you are required to build the PyTorch package from source according to the
+instructions [here](https://github.com/pytorch/pytorch#from-source).
+
+#### Numeric error.
+
+Because this library does not fuse the normalization and statistics operations in C++ (nor CUDA), it is less
+numerically stable than the original PyTorch implementation. A detailed analysis can be found in
+`tests/test_sync_batchnorm.py`.
+
+## Authors and License:
+
+Copyright (c) 2018-, [Jiayuan Mao](https://vccy.xyz).
+
+**Contributors**: [Tete Xiao](https://tetexiao.com), [DTennant](https://github.com/DTennant).
+
+Distributed under **MIT License** (See LICENSE)
+
diff --git a/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/__init__.py b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/__init__.py b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d9b36c74b1808b56ded68cf080a689db7e0ee4e
--- /dev/null
+++ b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# File : __init__.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
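+#
+# This package-level __init__ re-exports the public API, so users can write
+# `from sync_batchnorm import convert_model, DataParallelWithCallback`, etc.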
+ +from .batchnorm import set_sbn_eps_mode +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .batchnorm import patch_sync_batchnorm, convert_model +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm.py b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..bf8d7a7325b474771a11a137053971fd40426079 --- /dev/null +++ b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm.py @@ -0,0 +1,412 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import collections +import contextlib + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm + +try: + from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast +except ImportError: + ReduceAddCoalesced = Broadcast = None + +try: + from jactorch.parallel.comm import SyncMaster + from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback +except ImportError: + from .comm import SyncMaster + from .replicate import DataParallelWithCallback + +__all__ = [ + 'set_sbn_eps_mode', + 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', + 'patch_sync_batchnorm', 'convert_model' +] + + +SBN_EPS_MODE = 'clamp' + + +def set_sbn_eps_mode(mode): + global SBN_EPS_MODE + assert mode in ('clamp', 'plus') + SBN_EPS_MODE = mode + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dimensions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): + assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' + + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, + track_running_stats=track_running_stats) + + if not self.track_running_stats: + import warnings + warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.') + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). 
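+        # Note: folding all trailing axes into one lets the same code below
+        # serve the 1d/2d/3d variants; per-channel statistics then reduce over
+        # dims 0 and -1 only (see `_sum_ft` above).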
+ input_shape = input.size() + assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features) + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. + # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + if hasattr(torch, 'no_grad'): + with torch.no_grad(): + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + else: + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + + if SBN_EPS_MODE == 'clamp': + return mean, bias_var.clamp(self.eps) ** -0.5 + elif SBN_EPS_MODE == 'plus': + return mean, (bias_var + self.eps) ** -0.5 + else: + raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE)) + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. 
math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. 
The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. 
Default: ``True`` + + Shape:: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + + +@contextlib.contextmanager +def patch_sync_batchnorm(): + import torch.nn as nn + + backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d + + nn.BatchNorm1d = SynchronizedBatchNorm1d + nn.BatchNorm2d = SynchronizedBatchNorm2d + nn.BatchNorm3d = SynchronizedBatchNorm3d + + yield + + nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup + + +def convert_model(module): + """Traverse the input module and its child recursively + and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d + to SynchronizedBatchNorm*N*d + + Args: + module: the input module needs to be convert to SyncBN model + + Examples: + >>> import torch.nn as nn + >>> import torchvision + >>> # m is a standard pytorch model + >>> m = torchvision.models.resnet18(True) + >>> m = nn.DataParallel(m) + >>> # after convert, m is using SyncBN + >>> m = convert_model(m) + """ + if isinstance(module, torch.nn.DataParallel): + mod = module.module + mod = convert_model(mod) + mod = DataParallelWithCallback(mod, device_ids=module.device_ids) + return mod + + mod = module + for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, + torch.nn.modules.batchnorm.BatchNorm2d, + torch.nn.modules.batchnorm.BatchNorm3d], + [SynchronizedBatchNorm1d, + SynchronizedBatchNorm2d, + SynchronizedBatchNorm3d]): + if isinstance(module, pth_module): + mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) + mod.running_mean = module.running_mean + mod.running_var = module.running_var + if module.affine: + mod.weight.data = module.weight.data.clone().detach() + mod.bias.data = module.bias.data.clone().detach() + + for name, child in module.named_children(): + mod.add_module(name, convert_model(child)) + + return mod diff --git a/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm_reimpl.py b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm_reimpl.py new file mode 100644 index 0000000000000000000000000000000000000000..18145c3353e13d482c492ae46df91a537669fca0 --- /dev/null +++ b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm_reimpl.py @@ -0,0 +1,74 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : batchnorm_reimpl.py +# Author : acgtyrant +# Date : 11/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import torch +import torch.nn as nn +import torch.nn.init as init + +__all__ = ['BatchNorm2dReimpl'] + + +class BatchNorm2dReimpl(nn.Module): + """ + A re-implementation of batch normalization, used for testing the numerical + stability. 
+ + Author: acgtyrant + See also: + https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 + """ + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super().__init__() + + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.weight = nn.Parameter(torch.empty(num_features)) + self.bias = nn.Parameter(torch.empty(num_features)) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_running_stats(self): + self.running_mean.zero_() + self.running_var.fill_(1) + + def reset_parameters(self): + self.reset_running_stats() + init.uniform_(self.weight) + init.zeros_(self.bias) + + def forward(self, input_): + batchsize, channels, height, width = input_.size() + numel = batchsize * height * width + input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) + sum_ = input_.sum(1) + sum_of_square = input_.pow(2).sum(1) + mean = sum_ / numel + sumvar = sum_of_square - sum_ * mean + + self.running_mean = ( + (1 - self.momentum) * self.running_mean + + self.momentum * mean.detach() + ) + unbias_var = sumvar / (numel - 1) + self.running_var = ( + (1 - self.momentum) * self.running_var + + self.momentum * unbias_var.detach() + ) + + bias_var = sumvar / numel + inv_std = 1 / (bias_var + self.eps).pow(0.5) + output = ( + (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * + self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) + + return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() + diff --git a/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/comm.py b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..922f8c4a3adaa9b32fdcaef09583be03b0d7eb2b --- /dev/null +++ b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/comm.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. 
+ + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. + - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' + + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/replicate.py b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..b71c7b8ed51a1d6c55b1f753bdd8d90bad79bd06 --- /dev/null +++ b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. 
+ +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/unittest.py b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..998223a0e0242dc4a5b2fcd74af79dc7232794da --- /dev/null +++ b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. 
+ +import unittest +import torch + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, x, y): + adiff = float((x - y).abs().max()) + if (y == 0).all(): + rdiff = 'NaN' + else: + rdiff = float((adiff / y).abs().max()) + + message = ( + 'Tensor close check failed\n' + 'adiff={}\n' + 'rdiff={}\n' + ).format(adiff, rdiff) + self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message) + diff --git a/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/tests/__init__.py b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm.py b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..63661389782806ea2182c049448df5d05fc6d2f1 --- /dev/null +++ b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# File : test_numeric_batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. + +import unittest + +import torch +import torch.nn as nn +from torch.autograd import Variable + +from sync_batchnorm.unittest import TorchTestCase + + +def handy_var(a, unbias=True): + n = a.size(0) + asum = a.sum(dim=0) + as_sum = (a ** 2).sum(dim=0) # a square sum + sumvar = as_sum - asum * asum / n + if unbias: + return sumvar / (n - 1) + else: + return sumvar / n + + +class NumericTestCase(TorchTestCase): + def testNumericBatchNorm(self): + a = torch.rand(16, 10) + bn = nn.BatchNorm1d(10, momentum=1, eps=1e-5, affine=False) + bn.train() + + a_var1 = Variable(a, requires_grad=True) + b_var1 = bn(a_var1) + loss1 = b_var1.sum() + loss1.backward() + + a_var2 = Variable(a, requires_grad=True) + a_mean2 = a_var2.mean(dim=0, keepdim=True) + a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) + # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) + b_var2 = (a_var2 - a_mean2) / a_std2 + loss2 = b_var2.sum() + loss2.backward() + + self.assertTensorClose(bn.running_mean, a.mean(dim=0)) + self.assertTensorClose(bn.running_var, handy_var(a)) + self.assertTensorClose(a_var1.data, a_var2.data) + self.assertTensorClose(b_var1.data, b_var2.data) + self.assertTensorClose(a_var1.grad, a_var2.grad) + + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm_v2.py b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4538ae3c50b4c457a9fa19bf22b5b1a7b666ee --- /dev/null +++ b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm_v2.py @@ -0,0 +1,62 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : test_numeric_batchnorm_v2.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 11/01/2018 +# +# Distributed under terms of the MIT license. + +""" +Test the numerical implementation of batch normalization. + +Author: acgtyrant. 
+See also: https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
+"""
+
+import unittest
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from sync_batchnorm.unittest import TorchTestCase
+from sync_batchnorm.batchnorm_reimpl import BatchNorm2dReimpl
+
+
+class NumericTestCasev2(TorchTestCase):
+    def testNumericBatchNorm(self):
+        CHANNELS = 16
+        batchnorm1 = nn.BatchNorm2d(CHANNELS, momentum=1)
+        optimizer1 = optim.SGD(batchnorm1.parameters(), lr=0.01)
+
+        batchnorm2 = BatchNorm2dReimpl(CHANNELS, momentum=1)
+        batchnorm2.weight.data.copy_(batchnorm1.weight.data)
+        batchnorm2.bias.data.copy_(batchnorm1.bias.data)
+        optimizer2 = optim.SGD(batchnorm2.parameters(), lr=0.01)
+
+        for _ in range(100):
+            input_ = torch.rand(16, CHANNELS, 16, 16)
+
+            input1 = input_.clone().requires_grad_(True)
+            output1 = batchnorm1(input1)
+            output1.sum().backward()
+            optimizer1.step()
+
+            input2 = input_.clone().requires_grad_(True)
+            output2 = batchnorm2(input2)
+            output2.sum().backward()
+            optimizer2.step()
+
+        self.assertTensorClose(input1, input2)
+        self.assertTensorClose(output1, output2)
+        self.assertTensorClose(input1.grad, input2.grad)
+        self.assertTensorClose(batchnorm1.weight.grad, batchnorm2.weight.grad)
+        self.assertTensorClose(batchnorm1.bias.grad, batchnorm2.bias.grad)
+        self.assertTensorClose(batchnorm1.running_mean, batchnorm2.running_mean)
+        self.assertTensorClose(batchnorm1.running_var, batchnorm2.running_var)
+
+
+if __name__ == '__main__':
+    unittest.main()
+
diff --git a/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/tests/test_sync_batchnorm.py b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/tests/test_sync_batchnorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f7b6c64c06fc26348489cd15669501a2098c82f
--- /dev/null
+++ b/scripts/maskscratches/detection_models/Synchronized-BatchNorm-PyTorch/tests/test_sync_batchnorm.py
@@ -0,0 +1,114 @@
+# -*- coding: utf-8 -*-
+# File : test_sync_batchnorm.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
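+#
+# These tests check that each SynchronizedBatchNorm* layer matches the
+# corresponding built-in nn.BatchNorm* in both train and eval mode, on a
+# single device and when replicated across two GPUs.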
+ +import unittest + +import torch +import torch.nn as nn +from torch.autograd import Variable + +from sync_batchnorm import set_sbn_eps_mode +from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback +from sync_batchnorm.unittest import TorchTestCase + +set_sbn_eps_mode('plus') + + +def handy_var(a, unbias=True): + n = a.size(0) + asum = a.sum(dim=0) + as_sum = (a ** 2).sum(dim=0) # a square sum + sumvar = as_sum - asum * asum / n + if unbias: + return sumvar / (n - 1) + else: + return sumvar / n + + +def _find_bn(module): + for m in module.modules(): + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): + return m + + +class SyncTestCase(TorchTestCase): + def _syncParameters(self, bn1, bn2): + bn1.reset_parameters() + bn2.reset_parameters() + if bn1.affine and bn2.affine: + bn2.weight.data.copy_(bn1.weight.data) + bn2.bias.data.copy_(bn1.bias.data) + + def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): + """Check the forward and backward for the customized batch normalization.""" + bn1.train(mode=is_train) + bn2.train(mode=is_train) + + if cuda: + input = input.cuda() + + self._syncParameters(_find_bn(bn1), _find_bn(bn2)) + + input1 = Variable(input, requires_grad=True) + output1 = bn1(input1) + output1.sum().backward() + input2 = Variable(input, requires_grad=True) + output2 = bn2(input2) + output2.sum().backward() + + self.assertTensorClose(input1.data, input2.data) + self.assertTensorClose(output1.data, output2.data) + self.assertTensorClose(input1.grad, input2.grad) + self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) + self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) + + def testSyncBatchNormNormalTrain(self): + bn = nn.BatchNorm1d(10) + sync_bn = SynchronizedBatchNorm1d(10) + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) + + def testSyncBatchNormNormalEval(self): + bn = nn.BatchNorm1d(10) + sync_bn = SynchronizedBatchNorm1d(10) + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) + + def testSyncBatchNormSyncTrain(self): + bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) + + def testSyncBatchNormSyncEval(self): + bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) + + def testSyncBatchNorm2DSyncTrain(self): + bn = nn.BatchNorm2d(10) + sync_bn = SynchronizedBatchNorm2d(10) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/scripts/maskscratches/detection_models/__init__.py b/scripts/maskscratches/detection_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/maskscratches/detection_models/__pycache__/__init__.cpython-310.pyc b/scripts/maskscratches/detection_models/__pycache__/__init__.cpython-310.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..5fd544c23d99f300a2ae5f5c429e86874ae07933 Binary files /dev/null and b/scripts/maskscratches/detection_models/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/maskscratches/detection_models/__pycache__/__init__.cpython-311.pyc b/scripts/maskscratches/detection_models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01500a36a3495e9059e3d924374e8cbd2a4d2148 Binary files /dev/null and b/scripts/maskscratches/detection_models/__pycache__/__init__.cpython-311.pyc differ diff --git a/scripts/maskscratches/detection_models/__pycache__/antialiasing.cpython-310.pyc b/scripts/maskscratches/detection_models/__pycache__/antialiasing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c80269e4744f16d24ca591b092ea2ff0f69fa7d8 Binary files /dev/null and b/scripts/maskscratches/detection_models/__pycache__/antialiasing.cpython-310.pyc differ diff --git a/scripts/maskscratches/detection_models/__pycache__/antialiasing.cpython-311.pyc b/scripts/maskscratches/detection_models/__pycache__/antialiasing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de6266e06924c8b7f08c1ba29600410665be3b04 Binary files /dev/null and b/scripts/maskscratches/detection_models/__pycache__/antialiasing.cpython-311.pyc differ diff --git a/scripts/maskscratches/detection_models/__pycache__/networks.cpython-310.pyc b/scripts/maskscratches/detection_models/__pycache__/networks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53d5f279ddb8a78520bea5611fa613f9287b170a Binary files /dev/null and b/scripts/maskscratches/detection_models/__pycache__/networks.cpython-310.pyc differ diff --git a/scripts/maskscratches/detection_models/__pycache__/networks.cpython-311.pyc b/scripts/maskscratches/detection_models/__pycache__/networks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14d4ad9d8a53a7d2bf299b59411ef77b40fd7285 Binary files /dev/null and b/scripts/maskscratches/detection_models/__pycache__/networks.cpython-311.pyc differ diff --git a/scripts/maskscratches/detection_models/antialiasing.py b/scripts/maskscratches/detection_models/antialiasing.py new file mode 100644 index 0000000000000000000000000000000000000000..03715425d52959f844fff88cf574d026d1a10a67 --- /dev/null +++ b/scripts/maskscratches/detection_models/antialiasing.py @@ -0,0 +1,71 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
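The module that follows is the anti-aliased downsampling from Richard Zhang's "Making Convolutional Networks Shift-Invariant Again" (the adobe/antialiased-cnns repo cited in the code): blur with a binomial kernel, then apply the stride. A minimal sketch of the kernel construction it performs, shown for the default filt_size=3:

    import numpy as np

    a = np.array([1.0, 2.0, 1.0])   # 1-D binomial coefficients for filt_size == 3
    filt = a[:, None] * a[None, :]  # separable outer product -> 2-D blur kernel
    filt = filt / filt.sum()        # normalize so overall brightness is preserved
    # The module registers this kernel once per channel and applies it with a
    # grouped conv2d before the strided subsampling.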
+
+import torch
+import torch.nn.parallel
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Downsample(nn.Module):
+    # https://github.com/adobe/antialiased-cnns
+
+    def __init__(self, pad_type="reflect", filt_size=3, stride=2, channels=None, pad_off=0):
+        super(Downsample, self).__init__()
+        self.filt_size = filt_size
+        self.pad_off = pad_off
+        self.pad_sizes = [
+            int(1.0 * (filt_size - 1) / 2),
+            int(np.ceil(1.0 * (filt_size - 1) / 2)),
+            int(1.0 * (filt_size - 1) / 2),
+            int(np.ceil(1.0 * (filt_size - 1) / 2)),
+        ]
+        self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
+        self.stride = stride
+        self.off = int((self.stride - 1) / 2.0)
+        self.channels = channels
+
+        # 1-D binomial coefficients; the outer product below gives the 2-D blur kernel.
+        if self.filt_size == 1:
+            a = np.array([1.0])
+        elif self.filt_size == 2:
+            a = np.array([1.0, 1.0])
+        elif self.filt_size == 3:
+            a = np.array([1.0, 2.0, 1.0])
+        elif self.filt_size == 4:
+            a = np.array([1.0, 3.0, 3.0, 1.0])
+        elif self.filt_size == 5:
+            a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
+        elif self.filt_size == 6:
+            a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
+        elif self.filt_size == 7:
+            a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])
+        else:
+            raise ValueError("Filter size [%i] not supported" % filt_size)
+
+        filt = torch.Tensor(a[:, None] * a[None, :])
+        filt = filt / torch.sum(filt)
+        self.register_buffer("filt", filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
+
+        self.pad = get_pad_layer(pad_type)(self.pad_sizes)
+
+    def forward(self, inp):
+        if self.filt_size == 1:
+            if self.pad_off == 0:
+                return inp[:, :, :: self.stride, :: self.stride]
+            else:
+                return self.pad(inp)[:, :, :: self.stride, :: self.stride]
+        else:
+            return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
+
+
+def get_pad_layer(pad_type):
+    if pad_type in ["refl", "reflect"]:
+        PadLayer = nn.ReflectionPad2d
+    elif pad_type in ["repl", "replicate"]:
+        PadLayer = nn.ReplicationPad2d
+    elif pad_type == "zero":
+        PadLayer = nn.ZeroPad2d
+    else:
+        # Raising here beats returning None, which would otherwise surface later
+        # as an opaque TypeError at the call site.
+        raise ValueError("Pad type [%s] not recognized" % pad_type)
+    return PadLayer
diff --git a/scripts/maskscratches/detection_models/networks.py b/scripts/maskscratches/detection_models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ed515fda35bee285cb31c817d233b529a89ba8f
--- /dev/null
+++ b/scripts/maskscratches/detection_models/networks.py
@@ -0,0 +1,173 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import torch
+import torch.nn as nn
+
+# Downsample and DataParallelWithCallback are defined in the detection_models
+# subpackage added alongside this file.
+from scripts.maskscratches.detection_models.antialiasing import Downsample
+from scripts.maskscratches.detection_models.sync_batchnorm import DataParallelWithCallback
+
+
+class UNet(nn.Module):
+    def __init__(
+        self,
+        in_channels=3,
+        out_channels=3,
+        depth=5,
+        conv_num=2,
+        wf=6,
+        padding=True,
+        batch_norm=True,
+        up_mode="upsample",
+        with_tanh=False,
+        sync_bn=True,
+        antialiasing=True,
+    ):
+        """
+        Implementation of
+        U-Net: Convolutional Networks for Biomedical Image Segmentation
+        (Ronneberger et al., 2015)
+        https://arxiv.org/abs/1505.04597
+        Using the default arguments will yield the exact version used
+        in the original paper.
+        Args:
+            in_channels (int): number of input channels
+            out_channels (int): number of output channels
+            depth (int): depth of the network
+            conv_num (int): number of convolutions per block
+            wf (int): number of filters in the first layer is 2**wf
+            padding (bool): if True, apply padding such that the input shape
+                is the same as the output.
+                This may introduce artifacts
+            batch_norm (bool): Use BatchNorm after layers with an
+                activation function
+            up_mode (str): one of 'upconv' or 'upsample'.
+ 'upconv' will use transposed convolutions for + learned upsampling. + 'upsample' will use bilinear upsampling. + """ + + super().__init__() + assert up_mode in ("upconv", "upsample") + self.padding = padding + self.depth = depth - 1 + prev_channels = in_channels + + self.first = nn.Sequential( + *[nn.ReflectionPad2d(3), nn.Conv2d(in_channels, 2 ** wf, kernel_size=7), nn.LeakyReLU(0.2, True)] + ) + prev_channels = 2 ** wf + + self.down_path = nn.ModuleList() + self.down_sample = nn.ModuleList() + for i in range(depth): + if antialiasing and depth > 0: + self.down_sample.append( + nn.Sequential( + *[ + nn.ReflectionPad2d(1), + nn.Conv2d(prev_channels, prev_channels, kernel_size=3, stride=1, padding=0), + nn.BatchNorm2d(prev_channels), + nn.LeakyReLU(0.2, True), + Downsample(channels=prev_channels, stride=2), + ] + ) + ) + else: + self.down_sample.append( + nn.Sequential( + *[ + nn.ReflectionPad2d(1), + nn.Conv2d(prev_channels, prev_channels, kernel_size=4, stride=2, padding=0), + nn.BatchNorm2d(prev_channels), + nn.LeakyReLU(0.2, True), + ] + ) + ) + self.down_path.append( + UNetConvBlock(conv_num, prev_channels, 2 ** (wf + i + 1), padding, batch_norm) + ) + prev_channels = 2 ** (wf + i + 1) + + self.up_path = nn.ModuleList() + for i in reversed(range(depth)): + self.up_path.append( + UNetUpBlock(conv_num, prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm) + ) + prev_channels = 2 ** (wf + i) + + if with_tanh: + self.last = nn.Sequential( + *[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3), nn.Tanh()] + ) + else: + self.last = nn.Sequential( + *[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3)] + ) + + if sync_bn: + self = DataParallelWithCallback(self) + + def forward(self, x): + x = self.first(x) + + blocks = [] + for i, down_block in enumerate(self.down_path): + blocks.append(x) + x = self.down_sample[i](x) + x = down_block(x) + + for i, up in enumerate(self.up_path): + x = up(x, blocks[-i - 1]) + + return self.last(x) + + +class UNetConvBlock(nn.Module): + def __init__(self, conv_num, in_size, out_size, padding, batch_norm): + super(UNetConvBlock, self).__init__() + block = [] + + for _ in range(conv_num): + block.append(nn.ReflectionPad2d(padding=int(padding))) + block.append(nn.Conv2d(in_size, out_size, kernel_size=3, padding=0)) + if batch_norm: + block.append(nn.BatchNorm2d(out_size)) + block.append(nn.LeakyReLU(0.2, True)) + in_size = out_size + + self.block = nn.Sequential(*block) + + def forward(self, x): + out = self.block(x) + return out + + +class UNetUpBlock(nn.Module): + def __init__(self, conv_num, in_size, out_size, up_mode, padding, batch_norm): + super(UNetUpBlock, self).__init__() + if up_mode == "upconv": + self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2) + elif up_mode == "upsample": + self.up = nn.Sequential( + nn.Upsample(mode="bilinear", scale_factor=2, align_corners=False), + nn.ReflectionPad2d(1), + nn.Conv2d(in_size, out_size, kernel_size=3, padding=0), + ) + + self.conv_block = UNetConvBlock(conv_num, in_size, out_size, padding, batch_norm) + + def center_crop(self, layer, target_size): + _, _, layer_height, layer_width = layer.size() + diff_y = (layer_height - target_size[0]) // 2 + diff_x = (layer_width - target_size[1]) // 2 + return layer[:, :, diff_y: (diff_y + target_size[0]), diff_x: (diff_x + target_size[1])] + + def forward(self, x, bridge): + up = self.up(x) + crop1 = self.center_crop(bridge, up.shape[2:]) + out = torch.cat([up, crop1], 1) + out = 
self.conv_block(out) + + return out + + diff --git a/scripts/maskscratches/detection_models/sync_batchnorm/__init__.py b/scripts/maskscratches/detection_models/sync_batchnorm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d9b36c74b1808b56ded68cf080a689db7e0ee4e --- /dev/null +++ b/scripts/maskscratches/detection_models/sync_batchnorm/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import set_sbn_eps_mode +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .batchnorm import patch_sync_batchnorm, convert_model +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/__init__.cpython-310.pyc b/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fc9bbb45102928cd6e942d2f94b5e343fcffab4 Binary files /dev/null and b/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/__init__.cpython-311.pyc b/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..681ee7ea3aa8325b414171164ac50e97d6c868f4 Binary files /dev/null and b/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/__init__.cpython-311.pyc differ diff --git a/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/batchnorm.cpython-310.pyc b/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/batchnorm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3d3ee4752bdea30a6e888f005613ac46285b649 Binary files /dev/null and b/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/batchnorm.cpython-310.pyc differ diff --git a/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/batchnorm.cpython-311.pyc b/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/batchnorm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..809d3eb7198f5736047f727f86cd93e6c6056113 Binary files /dev/null and b/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/batchnorm.cpython-311.pyc differ diff --git a/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/comm.cpython-310.pyc b/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/comm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8fd24d22121aae6ca5ba41de9e904719e37ad17e Binary files /dev/null and b/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/comm.cpython-310.pyc differ diff --git a/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/comm.cpython-311.pyc b/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/comm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e484befcc86c97686f48df02a8e584ed90a90b85 Binary files /dev/null and b/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/comm.cpython-311.pyc differ diff --git 
a/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/replicate.cpython-310.pyc b/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/replicate.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..939628a4f63b7d2cccc840f2be71f14a87e072a8 Binary files /dev/null and b/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/replicate.cpython-310.pyc differ diff --git a/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/replicate.cpython-311.pyc b/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/replicate.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db91d9bd7af87e7c4b110477b99589d0663dae63 Binary files /dev/null and b/scripts/maskscratches/detection_models/sync_batchnorm/__pycache__/replicate.cpython-311.pyc differ diff --git a/scripts/maskscratches/detection_models/sync_batchnorm/batchnorm.py b/scripts/maskscratches/detection_models/sync_batchnorm/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..bf8d7a7325b474771a11a137053971fd40426079 --- /dev/null +++ b/scripts/maskscratches/detection_models/sync_batchnorm/batchnorm.py @@ -0,0 +1,412 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import collections +import contextlib + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm + +try: + from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast +except ImportError: + ReduceAddCoalesced = Broadcast = None + +try: + from jactorch.parallel.comm import SyncMaster + from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback +except ImportError: + from .comm import SyncMaster + from .replicate import DataParallelWithCallback + +__all__ = [ + 'set_sbn_eps_mode', + 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', + 'patch_sync_batchnorm', 'convert_model' +] + + +SBN_EPS_MODE = 'clamp' + + +def set_sbn_eps_mode(mode): + global SBN_EPS_MODE + assert mode in ('clamp', 'plus') + SBN_EPS_MODE = mode + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dimensions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): + assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' 
+ + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, + track_running_stats=track_running_stats) + + if not self.track_running_stats: + import warnings + warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.') + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features) + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. + # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 
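+        # One-pass variance: sum((x - m)^2) == ssum - sum * m, so both the
+        # biased and unbiased estimates fall out of the already-reduced sums
+        # without a second pass over the per-device activations.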
+ mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + if hasattr(torch, 'no_grad'): + with torch.no_grad(): + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + else: + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + + if SBN_EPS_MODE == 'clamp': + return mean, bias_var.clamp(self.eps) ** -0.5 + elif SBN_EPS_MODE == 'plus': + return mean, (bias_var + self.eps) ** -0.5 + else: + raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE)) + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. 
math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. 
+ + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + + +@contextlib.contextmanager +def patch_sync_batchnorm(): + import torch.nn as nn + + backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d + + nn.BatchNorm1d = SynchronizedBatchNorm1d + nn.BatchNorm2d = SynchronizedBatchNorm2d + nn.BatchNorm3d = SynchronizedBatchNorm3d + + yield + + nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup + + +def convert_model(module): + """Traverse the input module and its child recursively + and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d + to SynchronizedBatchNorm*N*d + + Args: + module: the input module needs to be convert to SyncBN model + + Examples: + >>> import torch.nn as nn + >>> import torchvision + >>> # m is a standard pytorch model + >>> m = torchvision.models.resnet18(True) + >>> m = nn.DataParallel(m) + >>> # after convert, m is using SyncBN + >>> m = convert_model(m) + """ + if isinstance(module, torch.nn.DataParallel): + mod = module.module + mod = convert_model(mod) + mod = DataParallelWithCallback(mod, device_ids=module.device_ids) + return mod + + mod = module + for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, + torch.nn.modules.batchnorm.BatchNorm2d, + torch.nn.modules.batchnorm.BatchNorm3d], + [SynchronizedBatchNorm1d, + SynchronizedBatchNorm2d, + SynchronizedBatchNorm3d]): + if isinstance(module, pth_module): + mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) + mod.running_mean = module.running_mean + mod.running_var = module.running_var + if module.affine: + mod.weight.data = module.weight.data.clone().detach() + mod.bias.data = module.bias.data.clone().detach() + + for name, child in module.named_children(): + mod.add_module(name, convert_model(child)) + + return mod diff --git a/scripts/maskscratches/detection_models/sync_batchnorm/batchnorm_reimpl.py b/scripts/maskscratches/detection_models/sync_batchnorm/batchnorm_reimpl.py new file mode 100644 index 0000000000000000000000000000000000000000..18145c3353e13d482c492ae46df91a537669fca0 --- /dev/null +++ b/scripts/maskscratches/detection_models/sync_batchnorm/batchnorm_reimpl.py @@ -0,0 +1,74 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : batchnorm_reimpl.py +# Author : acgtyrant +# Date : 11/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. 
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+
+__all__ = ['BatchNorm2dReimpl']
+
+
+class BatchNorm2dReimpl(nn.Module):
+    """
+    A re-implementation of batch normalization, used for testing the numerical
+    stability.
+
+    Author: acgtyrant
+    See also:
+    https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
+    """
+    def __init__(self, num_features, eps=1e-5, momentum=0.1):
+        super().__init__()
+
+        self.num_features = num_features
+        self.eps = eps
+        self.momentum = momentum
+        self.weight = nn.Parameter(torch.empty(num_features))
+        self.bias = nn.Parameter(torch.empty(num_features))
+        self.register_buffer('running_mean', torch.zeros(num_features))
+        self.register_buffer('running_var', torch.ones(num_features))
+        self.reset_parameters()
+
+    def reset_running_stats(self):
+        self.running_mean.zero_()
+        self.running_var.fill_(1)
+
+    def reset_parameters(self):
+        self.reset_running_stats()
+        init.uniform_(self.weight)
+        init.zeros_(self.bias)
+
+    def forward(self, input_):
+        batchsize, channels, height, width = input_.size()
+        numel = batchsize * height * width
+        input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
+        sum_ = input_.sum(1)
+        sum_of_square = input_.pow(2).sum(1)
+        mean = sum_ / numel
+        sumvar = sum_of_square - sum_ * mean
+
+        self.running_mean = (
+            (1 - self.momentum) * self.running_mean
+            + self.momentum * mean.detach()
+        )
+        unbias_var = sumvar / (numel - 1)
+        self.running_var = (
+            (1 - self.momentum) * self.running_var
+            + self.momentum * unbias_var.detach()
+        )
+
+        bias_var = sumvar / numel
+        inv_std = 1 / (bias_var + self.eps).pow(0.5)
+        output = (
+            (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
+            self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
+
+        return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
diff --git a/scripts/maskscratches/detection_models/sync_batchnorm/comm.py b/scripts/maskscratches/detection_models/sync_batchnorm/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..922f8c4a3adaa9b32fdcaef09583be03b0d7eb2b
--- /dev/null
+++ b/scripts/maskscratches/detection_models/sync_batchnorm/comm.py
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+# File   : comm.py
+# Author : Jiayuan Mao
+# Email  : maojiayuan@gmail.com
+# Date   : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import queue
+import collections
+import threading
+
+__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
+
+
+class FutureResult(object):
+    """A thread-safe future implementation. Used only as one-to-one pipe."""
+
+    def __init__(self):
+        self._result = None
+        self._lock = threading.Lock()
+        self._cond = threading.Condition(self._lock)
+
+    def put(self, result):
+        with self._lock:
+            assert self._result is None, "Previous result hasn't been fetched."
+            self._result = result
+            self._cond.notify()
+
+    def get(self):
+        with self._lock:
+            if self._result is None:
+                self._cond.wait()
+
+            res = self._result
+            self._result = None
+            return res
+
+
+_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
+_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
+
+
+class SlavePipe(_SlavePipeBase):
+    """Pipe for master-slave communication."""
+
+    def run_slave(self, msg):
+        self.queue.put((self.identifier, msg))
+        ret = self.result.get()
+        self.queue.put(True)
+        return ret
+
+
+class SyncMaster(object):
+    """An abstract `SyncMaster` object.
+
+    - During replication, as data parallel will trigger a callback on each module, all slave devices should
+      call `register(id)` and obtain a `SlavePipe` to communicate with the master.
+    - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are
+      collected and passed to a registered callback.
+    - After receiving the messages, the master device gathers the information and determines the message to be
+      passed back to each slave device.
+    """
+
+    def __init__(self, master_callback):
+        """
+
+        Args:
+            master_callback: a callback to be invoked after having collected messages from slave devices.
+        """
+        self._master_callback = master_callback
+        self._queue = queue.Queue()
+        self._registry = collections.OrderedDict()
+        self._activated = False
+
+    def __getstate__(self):
+        return {'master_callback': self._master_callback}
+
+    def __setstate__(self, state):
+        self.__init__(state['master_callback'])
+
+    def register_slave(self, identifier):
+        """
+        Register a slave device.
+
+        Args:
+            identifier: an identifier, usually the device id.
+
+        Returns: a `SlavePipe` object which can be used to communicate with the master device.
+
+        """
+        if self._activated:
+            assert self._queue.empty(), 'Queue is not clean before next initialization.'
+            self._activated = False
+            self._registry.clear()
+        future = FutureResult()
+        self._registry[identifier] = _MasterRegistry(future)
+        return SlavePipe(identifier, self._queue, future)
+
+    def run_master(self, master_msg):
+        """
+        Main entry for the master device in each forward pass.
+        The messages are first collected from each device (including the master device), and then
+        a callback is invoked to compute the message to be sent back to each device
+        (including the master device).
+
+        Args:
+            master_msg: the message that the master wants to send to itself. This will be placed as the first
+            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
+
+        Returns: the message to be sent back to the master device.
+
+        """
+        self._activated = True
+
+        intermediates = [(0, master_msg)]
+        for i in range(self.nr_slaves):
+            intermediates.append(self._queue.get())
+
+        results = self._master_callback(intermediates)
+        assert results[0][0] == 0, 'The first result should belong to the master.'
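+        # Hand each slave its share of the results; every slave then puts True
+        # back on the queue (see SlavePipe.run_slave), which the master drains
+        # below to confirm all devices consumed their messages before returning.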
+
+        for i, res in results:
+            if i == 0:
+                continue
+            self._registry[i].result.put(res)
+
+        for i in range(self.nr_slaves):
+            assert self._queue.get() is True
+
+        return results[0][1]
+
+    @property
+    def nr_slaves(self):
+        return len(self._registry)
diff --git a/scripts/maskscratches/detection_models/sync_batchnorm/replicate.py b/scripts/maskscratches/detection_models/sync_batchnorm/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ab6616dc8b4e2d6daf2d96bab2ba35f6f6283f1
--- /dev/null
+++ b/scripts/maskscratches/detection_models/sync_batchnorm/replicate.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+# File   : replicate.py
+# Author : Jiayuan Mao
+# Email  : maojiayuan@gmail.com
+# Date   : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import functools
+
+from torch.nn.parallel.data_parallel import DataParallel
+
+__all__ = [
+    'CallbackContext',
+    'execute_replication_callbacks',
+    'DataParallelWithCallback',
+    'patch_replication_callback'
+]
+
+
+class CallbackContext(object):
+    pass
+
+
+def execute_replication_callbacks(modules):
+    """
+    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
+
+    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+    Note that, as all replicated modules are isomorphic, we assign each sub-module a context
+    (shared among multiple copies of this module on different devices).
+    Through this context, different copies can share some information.
+
+    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
+    of any slave copies.
+    """
+    master_copy = modules[0]
+    nr_modules = len(list(master_copy.modules()))
+    ctxs = [CallbackContext() for _ in range(nr_modules)]
+
+    for i, module in enumerate(modules):
+        for j, m in enumerate(module.modules()):
+            if hasattr(m, '__data_parallel_replicate__'):
+                m.__data_parallel_replicate__(ctxs[j], i)
+
+
+class DataParallelWithCallback(DataParallel):
+    """
+    Data Parallel with a replication callback.
+
+    A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
+    the original `replicate` function.
+    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+    Examples:
+        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+        # sync_bn.__data_parallel_replicate__ will be invoked.
+    """
+
+    def replicate(self, module, device_ids):
+        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
+        execute_replication_callbacks(modules)
+        return modules
+
+
+def patch_replication_callback(data_parallel):
+    """
+    Monkey-patch an existing `DataParallel` object. Add the replication callback.
+    Useful when you have a customized `DataParallel` implementation.
+ + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/scripts/maskscratches/detection_models/sync_batchnorm/unittest.py b/scripts/maskscratches/detection_models/sync_batchnorm/unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..998223a0e0242dc4a5b2fcd74af79dc7232794da --- /dev/null +++ b/scripts/maskscratches/detection_models/sync_batchnorm/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import unittest +import torch + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, x, y): + adiff = float((x - y).abs().max()) + if (y == 0).all(): + rdiff = 'NaN' + else: + rdiff = float((adiff / y).abs().max()) + + message = ( + 'Tensor close check failed\n' + 'adiff={}\n' + 'rdiff={}\n' + ).format(adiff, rdiff) + self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message) + diff --git a/scripts/maskscratches/scratches_detector.py b/scripts/maskscratches/scratches_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..5ccf10e9597571bc7e9edc835fae06d5f49c6aa4 --- /dev/null +++ b/scripts/maskscratches/scratches_detector.py @@ -0,0 +1,68 @@ +############################################################################# +# +# Source from: +# https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life +# Forked from: +# +# Reimplemented by: Leonel Hernández +# +############################################################################## +import logging +import os +import numpy as np +import torch +import torch.nn.functional as F +import torchvision as tv +from PIL import ImageFile, Image + +from scripts.maskscratches.detection_models import networks +from scripts.util import data_transforms, scale_tensor, tensor_to_ndarray + +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +class ScratchesDetector: + + def __init__(self, snapshot_folder): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model_mask = networks.UNet( + in_channels=1, + out_channels=1, + depth=4, + conv_num=2, + wf=6, + padding=True, + batch_norm=True, + up_mode="upsample", + with_tanh=False, + sync_bn=True, + antialiasing=True, + ) + + model_path = os.path.join(snapshot_folder, "detection/FT_Epoch_latest.pt") + checkpoint = torch.load(model_path, map_location=device) + self.model_mask.load_state_dict(checkpoint["model_state"]) + self.model_mask.cpu() + self.model_mask.eval() + + def process(self, image: Image) -> np.array: + logging.info("Start detecting scratches") + transformed_image = data_transforms(image, size="full_size") + image = transformed_image.convert("L") + image = tv.transforms.ToTensor()(image) + image = tv.transforms.Normalize([0.5], [0.5])(image) + 
image = torch.unsqueeze(image, 0) + _, _, ow, oh = image.shape + scratch_image_scale = scale_tensor(image) + + scratch_image_scale = scratch_image_scale.cpu() + with torch.no_grad(): + prediction = torch.sigmoid(self.model_mask(scratch_image_scale)) + + prediction = prediction.data.cpu() + prediction = F.interpolate(prediction, [ow, oh], mode="nearest") + + tensor_mask = (prediction >= 0.4).float() + scratches_mask_image = tensor_to_ndarray(tensor_mask) + transformed_image = np.array(transformed_image) + return transformed_image, scratches_mask_image diff --git a/scripts/morph_video.py b/scripts/morph_video.py new file mode 100644 index 0000000000000000000000000000000000000000..5b278941eaa3353f2bcdf3f90a7fc88472b84892 --- /dev/null +++ b/scripts/morph_video.py @@ -0,0 +1,377 @@ +import sys +import os +import dlib +import numpy as np +from skimage import io +import cv2 +from imutils import face_utils +import argparse +import shutil +import random + +import subprocess + + +class NoFaceFound(Exception): + """Raised when there is no face found""" + pass + +def calculate_margin_help(img1,img2): + size1 = img1.shape + size2 = img2.shape + diff0 = abs(size1[0]-size2[0])//2 + diff1 = abs(size1[1]-size2[1])//2 + avg0 = (size1[0]+size2[0])//2 + avg1 = (size1[1]+size2[1])//2 + + return [size1,size2,diff0,diff1,avg0,avg1] + +def crop_image(img1,img2): + [size1,size2,diff0,diff1,avg0,avg1] = calculate_margin_help(img1,img2) + + if(size1[0] == size2[0] and size1[1] == size2[1]): + return [img1,img2] + + elif(size1[0] <= size2[0] and size1[1] <= size2[1]): + scale0 = size1[0]/size2[0] + scale1 = size1[1]/size2[1] + if(scale0 > scale1): + res = cv2.resize(img2,None,fx=scale0,fy=scale0,interpolation=cv2.INTER_AREA) + else: + res = cv2.resize(img2,None,fx=scale1,fy=scale1,interpolation=cv2.INTER_AREA) + return crop_image_help(img1,res) + + elif(size1[0] >= size2[0] and size1[1] >= size2[1]): + scale0 = size2[0]/size1[0] + scale1 = size2[1]/size1[1] + if(scale0 > scale1): + res = cv2.resize(img1,None,fx=scale0,fy=scale0,interpolation=cv2.INTER_AREA) + else: + res = cv2.resize(img1,None,fx=scale1,fy=scale1,interpolation=cv2.INTER_AREA) + return crop_image_help(res,img2) + + elif(size1[0] >= size2[0] and size1[1] <= size2[1]): + return [img1[diff0:avg0,:],img2[:,-diff1:avg1]] + + else: + return [img1[:,diff1:avg1],img2[-diff0:avg0,:]] + +def crop_image_help(img1,img2): + [size1,size2,diff0,diff1,avg0,avg1] = calculate_margin_help(img1,img2) + + if(size1[0] == size2[0] and size1[1] == size2[1]): + return [img1,img2] + + elif(size1[0] <= size2[0] and size1[1] <= size2[1]): + return [img1,img2[-diff0:avg0,-diff1:avg1]] + + elif(size1[0] >= size2[0] and size1[1] >= size2[1]): + return [img1[diff0:avg0,diff1:avg1],img2] + + elif(size1[0] >= size2[0] and size1[1] <= size2[1]): + return [img1[diff0:avg0,:],img2[:,-diff1:avg1]] + + else: + return [img1[:,diff1:avg1],img2[diff0:avg0,:]] + +def generate_face_correspondences(theImage1, theImage2): + # Detect the points of face. + detector = dlib.get_frontal_face_detector() + predictor = dlib.shape_predictor('./models/shape_predictor_68_face_landmarks.dat') + corresp = np.zeros((68,2)) + + imgList = crop_image(theImage1,theImage2) + list1 = [] + list2 = [] + j = 1 + + for img in imgList: + + size = (img.shape[0],img.shape[1]) + if(j == 1): + currList = list1 + else: + currList = list2 + + # Ask the detector to find the bounding boxes of each face. The 1 in the + # second argument indicates that we should upsample the image 1 time. 
This + # will make everything bigger and allow us to detect more faces. + + dets = detector(img, 1) + + try: + if len(dets) == 0: + raise NoFaceFound + except NoFaceFound: + print("Sorry, but I couldn't find a face in the image.") + + j=j+1 + + for k, rect in enumerate(dets): + + # Get the landmarks/parts for the face in rect. + shape = predictor(img, rect) + # corresp = face_utils.shape_to_np(shape) + + for i in range(0,68): + x = shape.part(i).x + y = shape.part(i).y + currList.append((x, y)) + corresp[i][0] += x + corresp[i][1] += y + # cv2.circle(img, (x, y), 2, (0, 255, 0), 2) + + # Add back the background + currList.append((1,1)) + currList.append((size[1]-1,1)) + currList.append(((size[1]-1)//2,1)) + currList.append((1,size[0]-1)) + currList.append((1,(size[0]-1)//2)) + currList.append(((size[1]-1)//2,size[0]-1)) + currList.append((size[1]-1,size[0]-1)) + currList.append(((size[1]-1),(size[0]-1)//2)) + + # Add back the background + narray = corresp/2 + narray = np.append(narray,[[1,1]],axis=0) + narray = np.append(narray,[[size[1]-1,1]],axis=0) + narray = np.append(narray,[[(size[1]-1)//2,1]],axis=0) + narray = np.append(narray,[[1,size[0]-1]],axis=0) + narray = np.append(narray,[[1,(size[0]-1)//2]],axis=0) + narray = np.append(narray,[[(size[1]-1)//2,size[0]-1]],axis=0) + narray = np.append(narray,[[size[1]-1,size[0]-1]],axis=0) + narray = np.append(narray,[[(size[1]-1),(size[0]-1)//2]],axis=0) + + return [size,imgList[0],imgList[1],list1,list2,narray] + + + + +# Check if a point is inside a rectangle +def rect_contains(rect, point): + + if point[0] < rect[0]: + return False + elif point[1] < rect[1]: + return False + elif point[0] > rect[2]: + return False + elif point[1] > rect[3]: + return False + return True + +# Write the delaunay triangles into a file +def draw_delaunay(f_w, f_h, subdiv, dictionary1): + + list4 = [] + + triangleList = subdiv.getTriangleList() + r = (0, 0, f_w, f_h) + + for t in triangleList : + pt1 = (int(t[0]), int(t[1])) + pt2 = (int(t[2]), int(t[3])) + pt3 = (int(t[4]), int(t[5])) + + if rect_contains(r, pt1) and rect_contains(r, pt2) and rect_contains(r, pt3) : + list4.append((dictionary1[pt1],dictionary1[pt2],dictionary1[pt3])) + + dictionary1 = {} + return list4 + +def make_delaunay(f_w, f_h, theList, img1, img2): + + # Make a rectangle. + rect = (0, 0, f_w, f_h) + + # Create an instance of Subdiv2D. + subdiv = cv2.Subdiv2D(rect) + + # Make a points list and a searchable dictionary. + theList = theList.tolist() + points = [(int(x[0]),int(x[1])) for x in theList] + dictionary = {x[0]:x[1] for x in list(zip(points, range(76)))} + + # Insert points into subdiv + for p in points : + subdiv.insert(p) + + # Make a delaunay triangulation list. + list4 = draw_delaunay(f_w, f_h, subdiv, dictionary) + + # Return the list. + return list4 + + + +import numpy as np +import cv2 +import sys +import os +import math +from subprocess import Popen, PIPE +from PIL import Image + +# Apply affine transform calculated using srcTri and dstTri to src and +# output an image of size. +def apply_affine_transform(src, srcTri, dstTri, size) : + + # Given a pair of triangles, find the affine transform. 
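+    # cv2.getAffineTransform solves for the 2x3 matrix M with M @ [x, y, 1]^T
+    # mapping each vertex of srcTri onto the matching vertex of dstTri; three
+    # point pairs determine an affine map exactly.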
+ warpMat = cv2.getAffineTransform(np.float32(srcTri), np.float32(dstTri)) + + # Apply the Affine Transform just found to the src image + dst = cv2.warpAffine(src, warpMat, (size[0], size[1]), None, flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101) + + return dst + + +# Warps and alpha blends triangular regions from img1 and img2 to img +def morph_triangle(img1, img2, img, t1, t2, t, alpha) : + + # Find bounding rectangle for each triangle + r1 = cv2.boundingRect(np.float32([t1])) + r2 = cv2.boundingRect(np.float32([t2])) + r = cv2.boundingRect(np.float32([t])) + + # Offset points by left top corner of the respective rectangles + t1Rect = [] + t2Rect = [] + tRect = [] + + for i in range(0, 3): + tRect.append(((t[i][0] - r[0]),(t[i][1] - r[1]))) + t1Rect.append(((t1[i][0] - r1[0]),(t1[i][1] - r1[1]))) + t2Rect.append(((t2[i][0] - r2[0]),(t2[i][1] - r2[1]))) + + # Get mask by filling triangle + mask = np.zeros((r[3], r[2], 3), dtype = np.float32) + cv2.fillConvexPoly(mask, np.int32(tRect), (1.0, 1.0, 1.0), 16, 0) + + # Apply warpImage to small rectangular patches + img1Rect = img1[r1[1]:r1[1] + r1[3], r1[0]:r1[0] + r1[2]] + img2Rect = img2[r2[1]:r2[1] + r2[3], r2[0]:r2[0] + r2[2]] + + size = (r[2], r[3]) + warpImage1 = apply_affine_transform(img1Rect, t1Rect, tRect, size) + warpImage2 = apply_affine_transform(img2Rect, t2Rect, tRect, size) + + # Alpha blend rectangular patches + imgRect = (1.0 - alpha) * warpImage1 + alpha * warpImage2 + + # Copy triangular region of the rectangular patch to the output image + img[r[1]:r[1]+r[3], r[0]:r[0]+r[2]] = img[r[1]:r[1]+r[3], r[0]:r[0]+r[2]] * ( 1 - mask ) + imgRect * mask + + +def generate_morph_sequence(duration, frame_rate, img1, img2, points1, points2, tri_list, size, output): + num_images = int(duration * frame_rate) + p = subprocess.Popen(['ffmpeg', '-y', '-f', 'image2pipe', '-r', str(frame_rate), '-s', str(size[1])+'x'+str(size[0]), '-i', '-', '-c:v', 'libx264', '-crf', '25', '-vf', 'scale=trunc(iw/2)*2:trunc(ih/2)*2', '-pix_fmt', 'yuv420p', output], stdin=subprocess.PIPE) + + for _ in range(10):#(int(frame_rate/3)): + res = Image.fromarray(cv2.cvtColor(np.uint8(img1), cv2.COLOR_BGR2RGB)) + res.save(p.stdin, 'JPEG') + + for j in range(0, num_images): + img1 = np.float32(img1) + img2 = np.float32(img2) + points = [] + alpha = j / (num_images - 1) + + for i in range(0, len(points1)): + x = (1 - alpha) * points1[i][0] + alpha * points2[i][0] + y = (1 - alpha) * points1[i][1] + alpha * points2[i][1] + points.append((x, y)) + + morphed_frame = np.zeros(img1.shape, dtype=img1.dtype) + + for i in range(len(tri_list)): + x = int(tri_list[i][0]) + y = int(tri_list[i][1]) + z = int(tri_list[i][2]) + + t1 = [points1[x], points1[y], points1[z]] + t2 = [points2[x], points2[y], points2[z]] + t = [points[x], points[y], points[z]] + + morph_triangle(img1, img2, morphed_frame, t1, t2, t, alpha) + + res = Image.fromarray(cv2.cvtColor(np.uint8(morphed_frame), cv2.COLOR_BGR2RGB)) + res.save(p.stdin, 'JPEG') + + p.stdin.close() + p.wait() + + + +# def doMorphing(image_paths, duration, frame_rate, output): +# output_files = [] +# to_delete = [] +# for i in range(len(image_paths) - 1): +# img1 = cv2.imread(image_paths[i]) +# img2 = cv2.imread(image_paths[i + 1]) +# size, img1, img2, points1, points2, list3 = generate_face_correspondences(img1, img2) +# tri = make_delaunay(size[1], size[0], list3, img1, img2) +# output_file = f"{output}_{i}.mp4" +# to_delete.append(output_file) +# generate_morph_sequence(duration, frame_rate, img1, img2, points1, points2, 
tri, size, output_file) +# output_files.append(output_file) + +# # Concatenate videos into one +# ffmpeg_command = ['ffmpeg', '-y', '-f', 'concat', '-safe', '0', '-i', 'files.txt', '-c', 'copy', f"{output}_combined.mp4"] +# with open(f'files.txt', 'w') as f: +# for file in output_files: +# f.write(f"file '{file}'\n") +# subprocess.run(ffmpeg_command) +# os.remove(f'files.txt') + +# # Convert the final combined video to a GIF +# gif_command = [ +# 'ffmpeg', '-y', '-i', f"{output}_combined.mp4", '-vf', 'fps=10,scale=320:-1:flags=lanczos,palettegen', f'{output}_palette.png' +# ] +# subprocess.run(gif_command) + +# gif_command = [ +# 'ffmpeg', '-y', '-i', f"{output}_combined.mp4", '-i', f'{output}_palette.png', '-filter_complex', +# 'fps=10,scale=320:-1:flags=lanczos[x];[x][1:v]paletteuse', f'{output}.gif' +# ] +# subprocess.run(gif_command) +# os.remove(f'{output}_palette.png') +# for file_delete in to_delete: +# os.remove(file_delete) + + +def doMorphing(image_paths, duration, frame_rate, output): + output_files = [] + for i in range(len(image_paths) - 1): + img1 = cv2.imread(image_paths[i]) + img2 = cv2.imread(image_paths[i + 1]) + size, img1, img2, points1, points2, list3 = generate_face_correspondences(img1, img2) + tri = make_delaunay(size[1], size[0], list3, img1, img2) + output_file = f"{output}_{i}.mp4" + generate_morph_sequence(duration, frame_rate, img1, img2, points1, points2, tri, size, output_file) + output_files.append(output_file) + + # Concatenate videos into one + ffmpeg_command = ['ffmpeg', '-y', '-f', 'concat', '-safe', '0', '-i', 'files.txt', '-c', 'copy', f"{output}_combined.mp4"] + with open(f'files.txt', 'w') as f: + for file in output_files: + f.write(f"file '{file}'\n") + subprocess.run(ffmpeg_command) + os.remove(f'files.txt') + + # Convert the final combined video to a GIF + gif_command = [ + 'ffmpeg', '-y', '-i', f"{output}_combined.mp4", '-vf', 'fps=10,scale=600:-1:flags=lanczos,palettegen', f'{output}_palette.png' + ] + subprocess.run(gif_command) + + gif_command = [ + 'ffmpeg', '-y', '-i', f"{output}_combined.mp4", '-i', f'{output}_palette.png', '-filter_complex', + 'fps=10,scale=320:-1:flags=lanczos[x];[x][1:v]paletteuse', f'{output}.gif' + ] + subprocess.run(gif_command) + os.remove(f'{output}_palette.png') + + + + + diff --git a/scripts/paths_config.py b/scripts/paths_config.py new file mode 100644 index 0000000000000000000000000000000000000000..0f01a3abc8534e58f6c09ba64dcdd8a5daec4cfb --- /dev/null +++ b/scripts/paths_config.py @@ -0,0 +1,12 @@ +dataset_paths = { + 'celeba_test': '', + 'ffhq': '', +} + +model_paths = { + 'pretrained_psp_encoder': 'pretrained_models/psp_ffhq_encode.pt', + 'ir_se50': 'pretrained_models/model_ir_se50.pth', + 'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt', + 'shape_predictor': 'shape_predictor_68_face_landmarks.dat', + 'age_predictor': 'pretrained_models/dex_age_classifier.pth' +} diff --git a/scripts/psp.py b/scripts/psp.py new file mode 100644 index 0000000000000000000000000000000000000000..7a63048e20c51c0ab422f1aa4cccc827be446aa9 --- /dev/null +++ b/scripts/psp.py @@ -0,0 +1,131 @@ +""" +This file defines the core research contribution +""" +import copy +from argparse import Namespace + +import torch +from torch import nn +import math + +from scripts.paths_config import model_paths +from scripts.encoders import psp_encoders +from scripts.stylegan2.model import Generator + + +class pSp(nn.Module): + + def __init__(self, opts): + super(pSp, self).__init__() + self.set_opts(opts) + self.n_styles = 
int(math.log(self.opts.output_size, 2)) * 2 - 2 + # Define architecture + self.encoder = self.set_encoder() + self.decoder = Generator(self.opts.output_size, 512, 8) + self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) + # Load weights if needed + self.load_weights() + + def set_encoder(self): + return psp_encoders.GradualStyleEncoder(50, 'ir_se', self.n_styles, self.opts) + + def load_weights(self): + if self.opts.checkpoint_path is not None: + print(f'Loading SAM from checkpoint: {self.opts.checkpoint_path}') + ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') + self.encoder.load_state_dict(self.__get_keys(ckpt, 'encoder'), strict=False) + self.decoder.load_state_dict(self.__get_keys(ckpt, 'decoder'), strict=True) + if self.opts.start_from_encoded_w_plus: + self.pretrained_encoder = self.__get_pretrained_psp_encoder() + self.pretrained_encoder.load_state_dict(self.__get_keys(ckpt, 'pretrained_encoder'), strict=True) + self.__load_latent_avg(ckpt) + else: + print('Loading encoders weights from irse50!') + encoder_ckpt = torch.load(model_paths['ir_se50']) + # Transfer the RGB input of the irse50 network to the first 3 input channels of SAM's encoder + if self.opts.input_nc != 3: + shape = encoder_ckpt['input_layer.0.weight'].shape + altered_input_layer = torch.randn(shape[0], self.opts.input_nc, shape[2], shape[3], dtype=torch.float32) + altered_input_layer[:, :3, :, :] = encoder_ckpt['input_layer.0.weight'] + encoder_ckpt['input_layer.0.weight'] = altered_input_layer + self.encoder.load_state_dict(encoder_ckpt, strict=False) + print(f'Loading decoder weights from pretrained path: {self.opts.stylegan_weights}') + ckpt = torch.load(self.opts.stylegan_weights) + self.decoder.load_state_dict(ckpt['g_ema'], strict=True) + self.__load_latent_avg(ckpt, repeat=self.n_styles) + if self.opts.start_from_encoded_w_plus: + self.pretrained_encoder = self.__load_pretrained_psp_encoder() + self.pretrained_encoder.eval() + + def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True, + inject_latent=None, return_latents=False, alpha=None, input_is_full=False): + if input_code: + codes = x + else: + codes = self.encoder(x) + # normalize with respect to the center of an average face + if self.opts.start_from_latent_avg: + codes = codes + self.latent_avg + # normalize with respect to the latent of the encoded image of pretrained pSp encoder + elif self.opts.start_from_encoded_w_plus: + with torch.no_grad(): + encoded_latents = self.pretrained_encoder(x[:, :-1, :, :]) + encoded_latents = encoded_latents + self.latent_avg + codes = codes + encoded_latents + + if latent_mask is not None: + for i in latent_mask: + if inject_latent is not None: + if alpha is not None: + codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i] + else: + codes[:, i] = inject_latent[:, i] + else: + codes[:, i] = 0 + + input_is_latent = (not input_code) or (input_is_full) + images, result_latent = self.decoder([codes], + input_is_latent=input_is_latent, + randomize_noise=randomize_noise, + return_latents=return_latents) + + if resize: + images = self.face_pool(images) + + if return_latents: + return images, result_latent + else: + return images + + def set_opts(self, opts): + self.opts = opts + + def __load_latent_avg(self, ckpt, repeat=None): + if 'latent_avg' in ckpt: + self.latent_avg = ckpt['latent_avg'].to(self.opts.device) + if repeat is not None: + self.latent_avg = self.latent_avg.repeat(repeat, 1) + else: + self.latent_avg = None + + def 
__get_pretrained_psp_encoder(self): + opts_encoder = vars(copy.deepcopy(self.opts)) + opts_encoder['input_nc'] = 3 + opts_encoder = Namespace(**opts_encoder) + encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.n_styles, opts_encoder) + return encoder + + def __load_pretrained_psp_encoder(self): + print(f'Loading pSp encoder from checkpoint: {self.opts.pretrained_psp_path}') + ckpt = torch.load(self.opts.pretrained_psp_path, map_location='cpu') + encoder_ckpt = self.__get_keys(ckpt, name='encoder') + encoder = self.__get_pretrained_psp_encoder() + encoder.load_state_dict(encoder_ckpt, strict=False) + return encoder + + @staticmethod + def __get_keys(d, name): + if 'state_dict' in d: + d = d['state_dict'] + d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} + return d_filt diff --git a/scripts/stylegan2/__init__.py b/scripts/stylegan2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/stylegan2/__pycache__/__init__.cpython-310.pyc b/scripts/stylegan2/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18b64908aaaa8005f8b0a0a7903cc0bf5a660577 Binary files /dev/null and b/scripts/stylegan2/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/stylegan2/__pycache__/model.cpython-310.pyc b/scripts/stylegan2/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3532e4ae039c55f2b0e4048a15ad8080c09ac041 Binary files /dev/null and b/scripts/stylegan2/__pycache__/model.cpython-310.pyc differ diff --git a/scripts/stylegan2/model.py b/scripts/stylegan2/model.py new file mode 100644 index 0000000000000000000000000000000000000000..357bf51a28741de75a6f73f7ba0fe43182f0e66d --- /dev/null +++ b/scripts/stylegan2/model.py @@ -0,0 +1,671 @@ +import math +import random +import torch +from torch import nn +from torch.nn import functional as F + +from scripts.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d + + +class PixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if k.ndim == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Upsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor ** 2) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class Downsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) + + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def 
forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out + + +class EqualConv2d(nn.Module): + def __init__( + self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True + ): + super().__init__() + + self.weight = nn.Parameter( + torch.randn(out_channel, in_channel, kernel_size, kernel_size) + ) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__( + self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' + ) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + out = F.leaky_relu(input, negative_slope=self.negative_slope) + + return out * math.sqrt(2) + + +class ModulatedConv2d(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size ** 2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) + ) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' + f'upsample={self.upsample}, downsample={self.downsample})' + ) + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = 
torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view( + batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view( + batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size + ) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + if noise is None: + batch, _, height, width = image.shape + noise = image.new_empty(batch, 1, height, width).normal_() + + return image + self.weight * noise + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection() + # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + # self.activate = ScaledLeakyReLU(0.2) + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + out = self.noise(out, noise=noise) + # out = out + self.bias + out = self.activate(out) + + return out + + +class ToRGB(nn.Module): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, skip=None): + out = self.conv(input, style) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + + +class Generator(nn.Module): + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + ): + super().__init__() + + self.size = size + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' + ) + ) + + self.style = 
nn.Sequential(*layers) + + self.channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv( + self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel + ) + self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channel = self.channels[4] + + for layer_idx in range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2 ** res, 2 ** res] + self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2 ** i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + ) + ) + + self.convs.append( + StyledConv( + out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel + ) + ) + + self.to_rgbs.append(ToRGB(out_channel, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self): + device = self.input.input.device + + noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) + + return noises + + def mean_latent(self, n_latent): + latent_in = torch.randn( + n_latent, self.style_dim, device=self.input.input.device + ) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + return_features=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + randomize_noise=True, + ): + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) + ] + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append( + truncation_latent + truncation * (style - truncation_latent) + ) + + styles = style_t + + if len(styles) < 2: + inject_index = self.n_latent + + if styles[0].ndim < 3: + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: + latent = styles[0] + + else: + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) + + latent = torch.cat([latent, latent2], 1) + + out = self.input(latent) + out = self.conv1(out, latent[:, 0], noise=noise[0]) + + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs + ): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + + image = skip + + if return_latents: + return image, latent + elif return_features: + return image, out + else: + return image, None + + +class 
ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + ) + ) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer( + in_channel, out_channel, 1, downsample=True, activate=False, bias=False + ) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class Discriminator(nn.Module): + def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + group = min(batch, self.stddev_group) + stddev = out.view(group, -1, self.stddev_feat, channel // self.stddev_feat, height, width) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, width) + out = torch.cat([out, stddev], 1) + + out = self.final_conv(out) + + out = out.view(batch, -1) + out = self.final_linear(out) + + return out diff --git a/scripts/stylegan2/op/__init__.py b/scripts/stylegan2/op/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0918d92285955855be89f00096b888ee5597ce3 --- /dev/null +++ b/scripts/stylegan2/op/__init__.py @@ -0,0 +1,2 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu +from .upfirdn2d import upfirdn2d diff --git a/scripts/stylegan2/op/__pycache__/__init__.cpython-310.pyc b/scripts/stylegan2/op/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6eb44b6b8f9157e49cc4d3a1a32069c7b61c7b0a Binary files /dev/null and b/scripts/stylegan2/op/__pycache__/__init__.cpython-310.pyc differ diff --git a/scripts/stylegan2/op/__pycache__/fused_act.cpython-310.pyc 
b/scripts/stylegan2/op/__pycache__/fused_act.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d3e666769f2f87ca230ce511693f3825ab50303 Binary files /dev/null and b/scripts/stylegan2/op/__pycache__/fused_act.cpython-310.pyc differ diff --git a/scripts/stylegan2/op/__pycache__/upfirdn2d.cpython-310.pyc b/scripts/stylegan2/op/__pycache__/upfirdn2d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bec4f4463b849097f03cda3be6caee4f64c3081b Binary files /dev/null and b/scripts/stylegan2/op/__pycache__/upfirdn2d.cpython-310.pyc differ diff --git a/scripts/stylegan2/op/fused_act.py b/scripts/stylegan2/op/fused_act.py new file mode 100644 index 0000000000000000000000000000000000000000..973a84fffde53668d31397da5fb993bbc95f7be0 --- /dev/null +++ b/scripts/stylegan2/op/fused_act.py @@ -0,0 +1,85 @@ +import os + +import torch +from torch import nn +from torch.autograd import Function +from torch.utils.cpp_extension import load + +module_path = os.path.dirname(__file__) +fused = load( + 'fused', + sources=[ + os.path.join(module_path, 'fused_bias_act.cpp'), + os.path.join(module_path, 'fused_bias_act_kernel.cu'), + ], +) + + +class FusedLeakyReLUFunctionBackward(Function): + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused.fused_bias_act( + grad_output, empty, out, 3, 1, negative_slope, scale + ) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + gradgrad_out = fused.fused_bias_act( + gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale + ) + + return gradgrad_out, None, None, None + + +class FusedLeakyReLUFunction(Function): + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( + grad_output, out, ctx.negative_slope, ctx.scale + ) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/scripts/stylegan2/op/fused_bias_act.cpp b/scripts/stylegan2/op/fused_bias_act.cpp new file mode 100644 index 0000000000000000000000000000000000000000..02be898f970bcc8ea297867fcaa4e71b24b3d949 --- /dev/null +++ b/scripts/stylegan2/op/fused_bias_act.cpp @@ -0,0 +1,21 @@ +#include + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must 
be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} \ No newline at end of file diff --git a/scripts/stylegan2/op/fused_bias_act_kernel.cu b/scripts/stylegan2/op/fused_bias_act_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..c9fa56fea7ede7072dc8925cfb0148f136eb85b8 --- /dev/null +++ b/scripts/stylegan2/op/fused_bias_act_kernel.cu @@ -0,0 +1,99 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +template +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ? 
1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<<>>( + y.data_ptr(), + x.data_ptr(), + b.data_ptr(), + ref.data_ptr(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} \ No newline at end of file diff --git a/scripts/stylegan2/op/upfirdn2d.cpp b/scripts/stylegan2/op/upfirdn2d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d2e633dc896433c205e18bc3e455539192ff968e --- /dev/null +++ b/scripts/stylegan2/op/upfirdn2d.cpp @@ -0,0 +1,23 @@ +#include + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} \ No newline at end of file diff --git a/scripts/stylegan2/op/upfirdn2d.py b/scripts/stylegan2/op/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..e9cb52219689592e2745600abb19fad02740a139 --- /dev/null +++ b/scripts/stylegan2/op/upfirdn2d.py @@ -0,0 +1,184 @@ +import os + +import torch +from torch.autograd import Function +from torch.utils.cpp_extension import load + +module_path = os.path.dirname(__file__) +upfirdn2d_op = load( + 'upfirdn2d', + sources=[ + os.path.join(module_path, 'upfirdn2d.cpp'), + os.path.join(module_path, 'upfirdn2d_kernel.cu'), + ], +) + + +class UpFirDn2dBackward(Function): + @staticmethod + def forward( + ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size + ): + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_op.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_op.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + 
ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view( + ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] + ) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_op.upfirdn2d( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 + ) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + out = UpFirDn2d.apply( + input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) + ) + + return out + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :, + max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + + return out[:, ::down_y, ::down_x, :] \ No newline at end of file diff --git a/scripts/stylegan2/op/upfirdn2d_kernel.cu b/scripts/stylegan2/op/upfirdn2d_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..a88bc7720da6cd54fccd0c4a03dd20fde85c063d --- /dev/null +++ b/scripts/stylegan2/op/upfirdn2d_kernel.cu @@ -0,0 +1,369 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. 
+// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + +template +__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; + int out_y = minor_idx / p.minor_dim; + minor_idx -= out_y * p.minor_dim; + int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; + int major_idx_base = blockIdx.z * p.loop_major; + + if (out_x_base >= p.out_w || out_y >= p.out_h || + major_idx_base >= p.major_dim) { + return; + } + + int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; + int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); + int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; + int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major && major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, out_x = out_x_base; + loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { + int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; + int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); + int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; + int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; + + const scalar_t *x_p = + &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + + minor_idx]; + const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; + int x_px = p.minor_dim; + int k_px = -p.up_x; + int x_py = p.in_w * p.minor_dim; + int k_py = -p.up_y * p.kernel_w; + + scalar_t v = 0.0f; + + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + v += static_cast(*x_p) * static_cast(*k_p); + x_p += x_px; + k_p += k_px; + } + + x_p += x_py - w * x_px; + k_p += k_py - w * k_px; + } + + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } +} + +template +__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | + major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; + tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w 
+ (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major & major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; + loop_x < p.loop_x & tile_out_x < p.out_w; + loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; + in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * + p.minor_dim + + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; + out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + +#pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) +#pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * + sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } + } + } +} + +torch::Tensor upfirdn2d_op(const torch::Tensor &input, + const torch::Tensor &kernel, int up_x, int up_y, + int down_x, int down_y, int pad_x0, int pad_x1, + int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / + p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / + p.down_x; + + auto out = + at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h = -1; + int tile_out_w = -1; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + 
tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 && tile_out_w > 0) { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } else { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 4; + block_size = dim3(4, 32, 1); + grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, + (p.out_w - 1) / (p.loop_x * block_size.y) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 2: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 3: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 4: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 5: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 6: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + default: + upfirdn2d_kernel_large<<>>( + out.data_ptr(), x.data_ptr(), + k.data_ptr(), p); + } + }); + + return out; +} \ No newline at end of file diff --git a/scripts/talker.py b/scripts/talker.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef7b3c18a0bac1077668e723935b6d516a6133d --- /dev/null +++ b/scripts/talker.py @@ -0,0 +1,90 @@ +import torch +import shutil +import torch +from scripts.utils.preprocess import CropAndExtract +from scripts.test_audio2coeff import Audio2Coeff +from scripts.facerender.animate import AnimateFromCoeff +from scripts.generate_batch import get_data +from scripts.generate_facerender_batch import get_facerender_data +import uuid +import os + + + +class sad_talker: + def __init__(self): + self.size = 256 + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.sadtalker_paths = {'checkpoint': './models/SadTalker_V0.0.2_256.safetensors', 'dir_of_BFM_fitting': './scripts/config', 'audio2pose_yaml_path': './scripts/config/auido2pose.yaml', 'audio2exp_yaml_path': './scripts/config/auido2exp.yaml', 'use_safetensor': True, 'mappingnet_checkpoint': './models/mapping_00109-model.pth.tar', 'facerender_yaml': './scripts/config/facerender_still.yaml'} + self.preprocess_model = CropAndExtract(self.sadtalker_paths, self.device) + self.audio_to_coeff = Audio2Coeff(self.sadtalker_paths, self.device) + self.animate_from_coeff = AnimateFromCoeff(self.sadtalker_paths, self.device) + + def genrate_video(self,image_path , audio_path , output_folder , still = True ): + + try: + preprocess = 'full' + + temp_file_uuid = str(uuid.uuid4()) + save_dir = f'./outputs/{temp_file_uuid}' + + first_frame_dir = 
os.path.join(save_dir, 'first_frame_dir') + os.makedirs(first_frame_dir, exist_ok=True) + + first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(image_path, first_frame_dir, preprocess,\ + source_image_flag=True, pic_size=self.size) + + if first_coeff_path is None: + return None + + ref_eyeblink_coeff_path=None + ref_pose_coeff_path=None + + pose_style = 0 + batch_size = 2 + input_yaw_list = None + input_pitch_list = None + input_roll_list = None + background_enhancer = None + enhancer = None + expression_scale = 1. + + + batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path, still=still) + coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path) + + data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, + batch_size, input_yaw_list, input_pitch_list, input_roll_list, + expression_scale=expression_scale, still_mode=still, preprocess=preprocess, size=self.size) + + + + result = self.animate_from_coeff.generate(data, save_dir, image_path, crop_info, \ + enhancer=enhancer, background_enhancer=background_enhancer, preprocess=preprocess, img_size=self.size) + + shutil.move(result,f"{output_folder}/output.mp4" ) + shutil.rmtree(save_dir) + + return True + + except: + + shutil.rmtree(save_dir) + + return False + + def __del__(self): + + self.preprocess_model = None + self.audio_to_coeff = None + self.animate_from_coeff = None + del self.preprocess_model + del self.audio_to_coeff + del self.animate_from_coeff + + torch.cuda.empty_cache() + import gc + gc.collect() + +if __name__ == "__main__": + pass \ No newline at end of file diff --git a/scripts/test_audio2coeff.py b/scripts/test_audio2coeff.py new file mode 100644 index 0000000000000000000000000000000000000000..3a7d75cdac19f70ba527a1c5136b9ffce2c6127e --- /dev/null +++ b/scripts/test_audio2coeff.py @@ -0,0 +1,123 @@ +import os +import torch +import numpy as np +from scipy.io import savemat, loadmat +from yacs.config import CfgNode as CN +from scipy.signal import savgol_filter + +import safetensors +import safetensors.torch + +from scripts.audio2pose_models.audio2pose import Audio2Pose +from scripts.audio2exp_models.networks import SimpleWrapperV2 +from scripts.audio2exp_models.audio2exp import Audio2Exp +from scripts.utils.safetensor_helper import load_x_from_safetensor + +def load_cpk(checkpoint_path, model=None, optimizer=None, device="cpu"): + checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) + if model is not None: + model.load_state_dict(checkpoint['model']) + if optimizer is not None: + optimizer.load_state_dict(checkpoint['optimizer']) + + return checkpoint['epoch'] + +class Audio2Coeff(): + + def __init__(self, sadtalker_path, device): + #load config + fcfg_pose = open(sadtalker_path['audio2pose_yaml_path']) + cfg_pose = CN.load_cfg(fcfg_pose) + cfg_pose.freeze() + fcfg_exp = open(sadtalker_path['audio2exp_yaml_path']) + cfg_exp = CN.load_cfg(fcfg_exp) + cfg_exp.freeze() + + # load audio2pose_model + self.audio2pose_model = Audio2Pose(cfg_pose, None, device=device) + self.audio2pose_model = self.audio2pose_model.to(device) + self.audio2pose_model.eval() + for param in self.audio2pose_model.parameters(): + param.requires_grad = False + + try: + if sadtalker_path['use_safetensor']: + checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint']) + self.audio2pose_model.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2pose')) + else: + 
load_cpk(sadtalker_path['audio2pose_checkpoint'], model=self.audio2pose_model, device=device) + except: + raise Exception("Failed in loading audio2pose_checkpoint") + + # load audio2exp_model + netG = SimpleWrapperV2() + netG = netG.to(device) + for param in netG.parameters(): + netG.requires_grad = False + netG.eval() + try: + if sadtalker_path['use_safetensor']: + checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint']) + netG.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2exp')) + else: + load_cpk(sadtalker_path['audio2exp_checkpoint'], model=netG, device=device) + except: + raise Exception("Failed in loading audio2exp_checkpoint") + self.audio2exp_model = Audio2Exp(netG, cfg_exp, device=device, prepare_training_loss=False) + self.audio2exp_model = self.audio2exp_model.to(device) + for param in self.audio2exp_model.parameters(): + param.requires_grad = False + self.audio2exp_model.eval() + + self.device = device + + def generate(self, batch, coeff_save_dir, pose_style, ref_pose_coeff_path=None): + + with torch.no_grad(): + #test + results_dict_exp= self.audio2exp_model.test(batch) + exp_pred = results_dict_exp['exp_coeff_pred'] #bs T 64 + + #for class_id in range(1): + #class_id = 0#(i+10)%45 + #class_id = random.randint(0,46) #46 styles can be selected + batch['class'] = torch.LongTensor([pose_style]).to(self.device) + results_dict_pose = self.audio2pose_model.test(batch) + pose_pred = results_dict_pose['pose_pred'] #bs T 6 + + pose_len = pose_pred.shape[1] + if pose_len<13: + pose_len = int((pose_len-1)/2)*2+1 + pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), pose_len, 2, axis=1)).to(self.device) + else: + pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), 13, 2, axis=1)).to(self.device) + + coeffs_pred = torch.cat((exp_pred, pose_pred), dim=-1) #bs T 70 + + coeffs_pred_numpy = coeffs_pred[0].clone().detach().cpu().numpy() + + if ref_pose_coeff_path is not None: + coeffs_pred_numpy = self.using_refpose(coeffs_pred_numpy, ref_pose_coeff_path) + + savemat(os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name'])), + {'coeff_3dmm': coeffs_pred_numpy}) + + return os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name'])) + + def using_refpose(self, coeffs_pred_numpy, ref_pose_coeff_path): + num_frames = coeffs_pred_numpy.shape[0] + refpose_coeff_dict = loadmat(ref_pose_coeff_path) + refpose_coeff = refpose_coeff_dict['coeff_3dmm'][:,64:70] + refpose_num_frames = refpose_coeff.shape[0] + if refpose_num_frames 0: + self.num_imgs = 0 + self.images = [] + + def query(self, images): + if self.pool_size == 0: + return images + return_images = [] + for image in images.data: + image = torch.unsqueeze(image, 0) + if self.num_imgs < self.pool_size: + self.num_imgs = self.num_imgs + 1 + self.images.append(image) + return_images.append(image) + else: + p = random.uniform(0, 1) + if p > 0.5: + random_id = random.randint(0, self.pool_size - 1) + tmp = self.images[random_id].clone() + self.images[random_id] = image + return_images.append(tmp) + else: + return_images.append(image) + return_images = Variable(torch.cat(return_images, 0)) + return return_images diff --git a/scripts/util/util.py b/scripts/util/util.py new file mode 100644 index 0000000000000000000000000000000000000000..a4de48b9ecb401f9d0c402185549f0c909e095ef --- /dev/null +++ b/scripts/util/util.py @@ -0,0 +1,120 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
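+#
+# Shared image/tensor helpers used by the training and inference scripts:
+# tensor-to-ndarray conversion, label-map colorization, resizing to
+# multiples of 16 for the networks, and small filesystem utilities.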
+ +from __future__ import print_function + +import os + +import PIL +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from torchvision.utils import make_grid + + +def array2image(ndarray): + return PIL.Image.fromarray(np.uint8(ndarray)).convert('RGB') + + +def tensor_to_ndarray(tensor, nrow=1, padding=0, normalize=True): + grid = make_grid(tensor, nrow, padding, normalize) + return grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + + +def irregular_hole_synthesize(image, mask): + img_np = np.array(image).astype("uint8") + mask_np = np.array(mask).astype("uint8") + mask_np = mask_np / 255 + img_new = img_np * (1 - mask_np) + mask_np * 255 + return PIL.Image.fromarray(img_new.astype("uint8")).convert("RGB") + + +# Converts a Tensor into a Numpy array +# |imtype|: the desired type of the converted numpy array +def tensor2im(image_tensor, imtype=np.uint8, normalize=True): + if isinstance(image_tensor, list): + image_numpy = [] + for i in range(len(image_tensor)): + image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) + return image_numpy + image_numpy = image_tensor.cpu().float().numpy() + if normalize: + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + else: + image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 + image_numpy = np.clip(image_numpy, 0, 255) + if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3: + image_numpy = image_numpy[:, :, 0] + return image_numpy.astype(imtype) + + +# Converts a one-hot tensor into a colorful label map +def tensor2label(label_tensor, n_label, imtype=np.uint8): + if n_label == 0: + return tensor2im(label_tensor, imtype) + label_tensor = label_tensor.cpu().float() + if label_tensor.size()[0] > 1: + label_tensor = label_tensor.max(0, keepdim=True)[1] + label_tensor = Colorize(n_label)(label_tensor) + label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) + return label_numpy.astype(imtype) + + +def scale_tensor(img_tensor, default_scale=256): + _, _, w, h = img_tensor.shape + if w < h: + ow = default_scale + oh = h / w * default_scale + else: + oh = default_scale + ow = w / h * default_scale + + oh = int(round(oh / 16) * 16) + ow = int(round(ow / 16) * 16) + + return F.interpolate(img_tensor, [ow, oh], mode="bilinear") + + +def data_transforms(img, size="full_size", method=Image.BICUBIC): + if size == "full_size": + ow, oh = img.size + h = int(round(oh / 16) * 16) + w = int(round(ow / 16) * 16) + if (h == oh) and (w == ow): + return img + return img.resize((w, h), method) + + elif size == "scale_256": + ow, oh = img.size + pw, ph = ow, oh + if ow < oh: + ow = 256 + oh = ph / pw * 256 + else: + oh = 256 + ow = pw / ph * 256 + + h = int(round(oh / 16) * 16) + w = int(round(ow / 16) * 16) + if (h == ph) and (w == pw): + return img + return img.resize((w, h), method) + + +def save_image(image_numpy, image_path): + image_pil = Image.fromarray(image_numpy) + image_pil.save(image_path) + + +def mkdirs(paths): + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) diff --git a/scripts/util/visualizer.py b/scripts/util/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1cbc8cc2589263383ae0004b7c6862b03efb6622 --- /dev/null +++ b/scripts/util/visualizer.py @@ -0,0 +1,143 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
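+#
+# Training visualizer: mirrors images and losses to TensorBoard when
+# opt.tf_log is set and to a simple HTML gallery when HTML logging is
+# enabled; formatted loss lines are also appended to loss_log.txt.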
+ +import os +import ntpath +import time +from . import util +# from . import html +import scipy.misc + +try: + from StringIO import StringIO # Python 2.7 +except ImportError: + from io import BytesIO # Python 3.x + + +class Visualizer: + def __init__(self, opt): + # self.opt = opt + self.tf_log = opt.tf_log + self.use_html = opt.isTrain and not opt.no_html + self.win_size = opt.display_winsize + self.name = opt.name + if self.tf_log: + import tensorflow as tf + self.tf = tf + self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs') + self.writer = tf.summary.FileWriter(self.log_dir) + + if self.use_html: + self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') + self.img_dir = os.path.join(self.web_dir, 'images') + print('create web directory %s...' % self.web_dir) + util.mkdirs([self.web_dir, self.img_dir]) + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + # |visuals|: dictionary of images to display or save + def display_current_results(self, visuals, epoch, step): + if self.tf_log: # show images in tensorboard output + img_summaries = [] + for label, image_numpy in visuals.items(): + # Write the image to a string + try: + s = StringIO() + except: + s = BytesIO() + scipy.misc.toimage(image_numpy).save(s, format="jpeg") + # Create an Image object + img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], + width=image_numpy.shape[1]) + # Create a Summary value + img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum)) + + # Create and write Summary + summary = self.tf.Summary(value=img_summaries) + self.writer.add_summary(summary, step) + + if self.use_html: # save images to a html file + for label, image_numpy in visuals.items(): + if isinstance(image_numpy, list): + for i in range(len(image_numpy)): + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i)) + util.save_image(image_numpy[i], img_path) + else: + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.jpg' % (epoch, label)) + util.save_image(image_numpy, img_path) + + # update website + webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=30) + for n in range(epoch, 0, -1): + webpage.add_header('epoch [%d]' % n) + ims = [] + txts = [] + links = [] + + for label, image_numpy in visuals.items(): + if isinstance(image_numpy, list): + for i in range(len(image_numpy)): + img_path = 'epoch%.3d_%s_%d.jpg' % (n, label, i) + ims.append(img_path) + txts.append(label + str(i)) + links.append(img_path) + else: + img_path = 'epoch%.3d_%s.jpg' % (n, label) + ims.append(img_path) + txts.append(label) + links.append(img_path) + if len(ims) < 10: + webpage.add_images(ims, txts, links, width=self.win_size) + else: + num = int(round(len(ims) / 2.0)) + webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size) + webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size) + webpage.save() + + # errors: dictionary of error labels and values + def plot_current_errors(self, errors, step): + if self.tf_log: + for tag, value in errors.items(): + summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) + self.writer.add_summary(summary, step) + + # errors: same format as |errors| of plotCurrentErrors + def print_current_errors(self, epoch, i, errors, t, lr): + message = '(epoch: %d, iters: 
%d, time: %.3f lr: %.5f) ' % (epoch, i, t, lr) + for k, v in errors.items(): + if v != 0: + message += '%s: %.3f ' % (k, v) + + print(message) + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) + + def print_save(self, message): + + print(message) + + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) + + # save image to the disk + def save_images(self, webpage, visuals, image_path): + image_dir = webpage.get_image_dir() + short_path = ntpath.basename(image_path[0]) + name = os.path.splitext(short_path)[0] + + webpage.add_header(name) + ims = [] + txts = [] + links = [] + + for label, image_numpy in visuals.items(): + image_name = '%s_%s.jpg' % (name, label) + save_path = os.path.join(image_dir, image_name) + util.save_image(image_numpy, save_path) + + ims.append(image_name) + txts.append(label) + links.append(image_name) + webpage.add_images(ims, txts, links, width=self.win_size) diff --git a/scripts/utils/__pycache__/audio.cpython-310.pyc b/scripts/utils/__pycache__/audio.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..064f4f3bb120980f5666223fa046d804bc54ea66 Binary files /dev/null and b/scripts/utils/__pycache__/audio.cpython-310.pyc differ diff --git a/scripts/utils/__pycache__/croper.cpython-310.pyc b/scripts/utils/__pycache__/croper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8fff9b7fe372955e8b7e28ef56713e8110f82440 Binary files /dev/null and b/scripts/utils/__pycache__/croper.cpython-310.pyc differ diff --git a/scripts/utils/__pycache__/face_enhancer.cpython-310.pyc b/scripts/utils/__pycache__/face_enhancer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..061ec0b2570ab1269a2020c7a09089b01bd8fb8b Binary files /dev/null and b/scripts/utils/__pycache__/face_enhancer.cpython-310.pyc differ diff --git a/scripts/utils/__pycache__/hparams.cpython-310.pyc b/scripts/utils/__pycache__/hparams.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e17e6042d629961fa2d36501dc5d88c3cf9e089e Binary files /dev/null and b/scripts/utils/__pycache__/hparams.cpython-310.pyc differ diff --git a/scripts/utils/__pycache__/init_path.cpython-310.pyc b/scripts/utils/__pycache__/init_path.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73fc9ad4210606985a69ffcbb96773fa93f240b5 Binary files /dev/null and b/scripts/utils/__pycache__/init_path.cpython-310.pyc differ diff --git a/scripts/utils/__pycache__/paste_pic.cpython-310.pyc b/scripts/utils/__pycache__/paste_pic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f84e993f58cc8ee604d44a0dbdfe2d8f5efee3c Binary files /dev/null and b/scripts/utils/__pycache__/paste_pic.cpython-310.pyc differ diff --git a/scripts/utils/__pycache__/preprocess.cpython-310.pyc b/scripts/utils/__pycache__/preprocess.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..138e58f0baae6e6a960283caf29b0d5ce924dcde Binary files /dev/null and b/scripts/utils/__pycache__/preprocess.cpython-310.pyc differ diff --git a/scripts/utils/__pycache__/safetensor_helper.cpython-310.pyc b/scripts/utils/__pycache__/safetensor_helper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d712082709013a35ca90b75073aea903da3f5279 Binary files /dev/null and b/scripts/utils/__pycache__/safetensor_helper.cpython-310.pyc differ diff --git 
a/scripts/utils/__pycache__/videoio.cpython-310.pyc b/scripts/utils/__pycache__/videoio.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7fe5601d88424380b2e1c3be91670cf571a8fa3a
Binary files /dev/null and b/scripts/utils/__pycache__/videoio.cpython-310.pyc differ
diff --git a/scripts/utils/audio.py b/scripts/utils/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..429db42e0069fdcfecd313d58f56dcee43c813f0
--- /dev/null
+++ b/scripts/utils/audio.py
@@ -0,0 +1,136 @@
+import librosa
+import librosa.filters
+import numpy as np
+# import tensorflow as tf
+from scipy import signal
+from scipy.io import wavfile
+from scripts.utils.hparams import hparams as hp
+
+def load_wav(path, sr):
+    return librosa.core.load(path, sr=sr)[0]
+
+def save_wav(wav, path, sr):
+    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
+    # proposed by @dsmiller
+    wavfile.write(path, sr, wav.astype(np.int16))
+
+def save_wavenet_wav(wav, path, sr):
+    # NOTE: librosa.output.write_wav was removed in librosa 0.8; this helper
+    # needs librosa<0.8 (or a port to soundfile.write).
+    librosa.output.write_wav(path, wav, sr=sr)
+
+def preemphasis(wav, k, preemphasize=True):
+    if preemphasize:
+        return signal.lfilter([1, -k], [1], wav)
+    return wav
+
+def inv_preemphasis(wav, k, inv_preemphasize=True):
+    if inv_preemphasize:
+        return signal.lfilter([1], [1, -k], wav)
+    return wav
+
+def get_hop_size():
+    hop_size = hp.hop_size
+    if hop_size is None:
+        assert hp.frame_shift_ms is not None
+        hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
+    return hop_size
+
+def linearspectrogram(wav):
+    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
+    S = _amp_to_db(np.abs(D)) - hp.ref_level_db
+
+    if hp.signal_normalization:
+        return _normalize(S)
+    return S
+
+def melspectrogram(wav):
+    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
+    S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
+
+    if hp.signal_normalization:
+        return _normalize(S)
+    return S
+
+def _lws_processor():
+    import lws
+    return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
+
+def _stft(y):
+    if hp.use_lws:
+        # _lws_processor takes no arguments; it reads its settings from hp directly
+        return _lws_processor().stft(y).T
+    else:
+        return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
+
+##########################################################
+# These are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
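+##########################################################
+# A minimal usage sketch for the two entry points above, assuming a 16 kHz
+# mono file at the hypothetical path "sample.wav". With the default hparams
+# (num_mels=80, n_fft=800, hop_size=200) one hop is 12.5 ms, so the mel
+# output has shape (80, T) with one column per 200 samples:
+if __name__ == "__main__":
+    wav = load_wav("sample.wav", sr=hp.sample_rate)  # float waveform, roughly in [-1, 1]
+    mel = melspectrogram(wav)                        # np.ndarray, shape (80, T)
+    lin = linearspectrogram(wav)                     # np.ndarray, shape (401, T); n_fft // 2 + 1 bins
+    print(mel.shape, lin.shape)
+##########################################################
+# The lws-sensitive padding helpers that the warning above refers to: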
+def num_frames(length, fsize, fshift): + """Compute number of time frames of spectrogram + """ + pad = (fsize - fshift) + if length % fshift == 0: + M = (length + pad * 2 - fsize) // fshift + 1 + else: + M = (length + pad * 2 - fsize) // fshift + 2 + return M + + +def pad_lr(x, fsize, fshift): + """Compute left and right padding + """ + M = num_frames(len(x), fsize, fshift) + pad = (fsize - fshift) + T = len(x) + 2 * pad + r = (M - 1) * fshift + fsize - T + return pad, pad + r +########################################################## +#Librosa correct padding +def librosa_pad_lr(x, fsize, fshift): + return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] + +# Conversions +_mel_basis = None + +def _linear_to_mel(spectogram): + global _mel_basis + if _mel_basis is None: + _mel_basis = _build_mel_basis() + return np.dot(_mel_basis, spectogram) + +def _build_mel_basis(): + assert hp.fmax <= hp.sample_rate // 2 + return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels, + fmin=hp.fmin, fmax=hp.fmax) + +def _amp_to_db(x): + min_level = np.exp(hp.min_level_db / 20 * np.log(10)) + return 20 * np.log10(np.maximum(min_level, x)) + +def _db_to_amp(x): + return np.power(10.0, (x) * 0.05) + +def _normalize(S): + if hp.allow_clipping_in_normalization: + if hp.symmetric_mels: + return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value, + -hp.max_abs_value, hp.max_abs_value) + else: + return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value) + + assert S.max() <= 0 and S.min() - hp.min_level_db >= 0 + if hp.symmetric_mels: + return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value + else: + return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) + +def _denormalize(D): + if hp.allow_clipping_in_normalization: + if hp.symmetric_mels: + return (((np.clip(D, -hp.max_abs_value, + hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + + hp.min_level_db) + else: + return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) + + if hp.symmetric_mels: + return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db) + else: + return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) diff --git a/scripts/utils/croper.py b/scripts/utils/croper.py new file mode 100644 index 0000000000000000000000000000000000000000..aa58719290e727652c8c2a3676cc786973d6c565 --- /dev/null +++ b/scripts/utils/croper.py @@ -0,0 +1,144 @@ +import os +import cv2 +import time +import glob +import argparse +import scipy +import numpy as np +from PIL import Image +import torch +from tqdm import tqdm +from itertools import cycle + +from scripts.face3d.extract_kp_videos_safe import KeypointExtractor +from facexlib.alignment import landmark_98_to_68 + +import numpy as np +from PIL import Image + +class Preprocesser: + def __init__(self, device='cuda'): + self.predictor = KeypointExtractor(device) + + def get_landmark(self, img_np): + """get landmark with dlib + :return: np.array shape=(68, 2) + """ + with torch.no_grad(): + dets = self.predictor.det_net.detect_faces(img_np, 0.97) + + if len(dets) == 0: + return None + det = dets[0] + + img = img_np[int(det[1]):int(det[3]), int(det[0]):int(det[2]), :] + lm = landmark_98_to_68(self.predictor.detector.get_landmarks(img)) # [0] + + #### keypoints to the original location + lm[:,0] += int(det[0]) + lm[:,1] += int(det[1]) 
+
+        return lm
+
+    def align_face(self, img, lm, output_size=1024):
+        """
+        :param img: PIL Image
+        :param lm: np.array of 68 landmarks, shape (68, 2)
+        :return: resized full-image size, crop box, and the face quad box
+        """
+        lm_chin = lm[0: 17]  # left-right
+        lm_eyebrow_left = lm[17: 22]  # left-right
+        lm_eyebrow_right = lm[22: 27]  # left-right
+        lm_nose = lm[27: 31]  # top-down
+        lm_nostrils = lm[31: 36]  # top-down
+        lm_eye_left = lm[36: 42]  # left-clockwise
+        lm_eye_right = lm[42: 48]  # left-clockwise
+        lm_mouth_outer = lm[48: 60]  # left-clockwise
+        lm_mouth_inner = lm[60: 68]  # left-clockwise
+
+        # Calculate auxiliary vectors.
+        eye_left = np.mean(lm_eye_left, axis=0)
+        eye_right = np.mean(lm_eye_right, axis=0)
+        eye_avg = (eye_left + eye_right) * 0.5
+        eye_to_eye = eye_right - eye_left
+        mouth_left = lm_mouth_outer[0]
+        mouth_right = lm_mouth_outer[6]
+        mouth_avg = (mouth_left + mouth_right) * 0.5
+        eye_to_mouth = mouth_avg - eye_avg
+
+        # Choose oriented crop rectangle.
+        x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]  # sum of the eye-to-eye vector and the rotated eye-to-mouth vector
+        x /= np.hypot(*x)  # np.hypot gives the vector length; normalize x to a unit vector
+        x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)  # use the larger of the eye-to-eye and eye-to-mouth distances as the base scale
+        y = np.flipud(x) * [-1, 1]
+        c = eye_avg + eye_to_mouth * 0.1
+        quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])  # crop quadrilateral: offset the face center up/down/left/right to get the four corners
+        qsize = np.hypot(*x) * 2  # side length of the quad, twice the base scale
+
+        # Shrink.
+        # If the computed quad is too large, downscale the image proportionally first.
+        shrink = int(np.floor(qsize / output_size * 0.5))
+        if shrink > 1:
+            rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
+            img = img.resize(rsize, Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
+            quad /= shrink
+            qsize /= shrink
+        else:
+            rsize = (int(np.rint(float(img.size[0]))), int(np.rint(float(img.size[1]))))
+
+        # Crop.
+        border = max(int(np.rint(qsize * 0.1)), 3)
+        crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+                int(np.ceil(max(quad[:, 1]))))
+        crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
+                min(crop[3] + border, img.size[1]))
+        if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+            # img = img.crop(crop)
+            quad -= crop[0:2]
+
+        # Pad.
+        pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+               int(np.ceil(max(quad[:, 1]))))
+        pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
+               max(pad[3] - img.size[1] + border, 0))
+        # if enable_padding and max(pad) > border - 4:
+        #     pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+        #     img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+        #     h, w, _ = img.shape
+        #     y, x, _ = np.ogrid[:h, :w, :1]
+        #     mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
+        #                       1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
+        #     blur = qsize * 0.02
+        #     img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+        #     img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+        #     img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
+        #     quad += pad[:2]
+
+        # Transform.
+        quad = (quad + 0.5).flatten()
+        lx = max(min(quad[0], quad[2]), 0)
+        ly = max(min(quad[1], quad[7]), 0)
+        rx = min(max(quad[4], quad[6]), img.size[0])
+        ry = min(max(quad[3], quad[5]), img.size[1])  # y is bounded by the image height
+
+        # Save aligned image.
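+        # Returned values: rsize is the (w, h) the full image was resized to in
+        # the shrink step, crop is the coarse (x1, y1, x2, y2) box in that resized
+        # image, and [lx, ly, rx, ry] is the tighter face quad inside the crop.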
+        return rsize, crop, [lx, ly, rx, ry]
+
+    def crop(self, img_np_list, still=False, xsize=512):  # the first frame determines the crop for the whole video
+        img_np = img_np_list[0]
+        lm = self.get_landmark(img_np)
+
+        if lm is None:
+            raise ValueError('cannot detect landmarks in the source image')
+        rsize, crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=xsize)
+        clx, cly, crx, cry = crop
+        lx, ly, rx, ry = quad
+        lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+        for _i in range(len(img_np_list)):
+            _inp = img_np_list[_i]
+            _inp = cv2.resize(_inp, (rsize[0], rsize[1]))
+            _inp = _inp[cly:cry, clx:crx]
+            if not still:
+                _inp = _inp[ly:ry, lx:rx]
+            img_np_list[_i] = _inp
+        return img_np_list, crop, quad
+
diff --git a/scripts/utils/face_enhancer.py b/scripts/utils/face_enhancer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f4e9cbe250df0fd0c6fd1abb27888df853520b4
--- /dev/null
+++ b/scripts/utils/face_enhancer.py
@@ -0,0 +1,123 @@
+import os
+import torch
+
+from gfpgan import GFPGANer
+
+from tqdm import tqdm
+
+from scripts.utils.videoio import load_video_to_cv2
+
+import cv2
+
+
+class GeneratorWithLen(object):
+    """ From https://stackoverflow.com/a/7460929 """
+
+    def __init__(self, gen, length):
+        self.gen = gen
+        self.length = length
+
+    def __len__(self):
+        return self.length
+
+    def __iter__(self):
+        return self.gen
+
+def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'):
+    gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
+    return list(gen)
+
+def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'):
+    """ Provide a generator with a __len__ method so that it can be passed to
+    functions that call len()"""
+
+    if os.path.isfile(images):  # handle video to images
+        # TODO: Create a generator version of load_video_to_cv2
+        images = load_video_to_cv2(images)
+
+    gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
+    gen_with_len = GeneratorWithLen(gen, len(images))
+    return gen_with_len
+
+def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'):
+    """ Provide a generator function so that all of the enhanced images don't need
+    to be stored in memory at the same time. This can save tons of RAM compared to
+    the enhancer function. """
+
+    print('face enhancer....')
+    if not isinstance(images, list) and os.path.isfile(images):  # handle video to images
+        images = load_video_to_cv2(images)
+
+    # ------------------------ set up GFPGAN restorer ------------------------
+    if method == 'gfpgan':
+        arch = 'clean'
+        channel_multiplier = 2
+        model_name = 'GFPGANv1.4'
+        url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'
+    elif method == 'RestoreFormer':
+        arch = 'RestoreFormer'
+        channel_multiplier = 2
+        model_name = 'RestoreFormer'
+        url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
+    elif method == 'codeformer':  # TODO:
+        arch = 'CodeFormer'
+        channel_multiplier = 2
+        model_name = 'CodeFormer'
+        url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+    else:
+        raise ValueError(f'Wrong model version {method}.')
+
+
+    # ------------------------ set up background upsampler ------------------------
+    if bg_upsampler == 'realesrgan':
+        if not torch.cuda.is_available():  # CPU
+            import warnings
+            warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. 
' + 'If you really want to use it, please modify the corresponding codes.') + bg_upsampler = None + else: + from basicsr.archs.rrdbnet_arch import RRDBNet + from realesrgan import RealESRGANer + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) + bg_upsampler = RealESRGANer( + scale=2, + model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', + model=model, + tile=400, + tile_pad=10, + pre_pad=0, + half=True) # need to set False in CPU mode + else: + bg_upsampler = None + + # determine model paths + model_path = os.path.join('gfpgan/weights', model_name + '.pth') + + if not os.path.isfile(model_path): + model_path = os.path.join('checkpoints', model_name + '.pth') + + if not os.path.isfile(model_path): + # download pre-trained models from url + model_path = url + + restorer = GFPGANer( + model_path=model_path, + upscale=2, + arch=arch, + channel_multiplier=channel_multiplier, + bg_upsampler=bg_upsampler) + + # ------------------------ restore ------------------------ + for idx in tqdm(range(len(images)), 'Face Enhancer:'): + + img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR) + + # restore faces and background if necessary + cropped_faces, restored_faces, r_img = restorer.enhance( + img, + has_aligned=False, + only_center_face=False, + paste_back=True) + + r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB) + yield r_img diff --git a/scripts/utils/hparams.py b/scripts/utils/hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..743c5c7d5a5a9e686f1ccd6fb3c2fb5cb382d62b --- /dev/null +++ b/scripts/utils/hparams.py @@ -0,0 +1,160 @@ +from glob import glob +import os + +class HParams: + def __init__(self, **kwargs): + self.data = {} + + for key, value in kwargs.items(): + self.data[key] = value + + def __getattr__(self, key): + if key not in self.data: + raise AttributeError("'HParams' object has no attribute %s" % key) + return self.data[key] + + def set_hparam(self, key, value): + self.data[key] = value + + +# Default hyperparameters +hparams = HParams( + num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality + # network + rescale=True, # Whether to rescale audio prior to preprocessing + rescaling_max=0.9, # Rescaling value + + # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction + # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder + # Does not work if n_ffit is not multiple of hop_size!! + use_lws=False, + + n_fft=800, # Extra window size is filled with 0 paddings to match this parameter + hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) + win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) + sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) + + frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) + + # Mel and Linear spectrograms normalization/scaling and clipping + signal_normalization=True, + # Whether to normalize mel spectrograms to some predefined range (following below parameters) + allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True + symmetric_mels=True, + # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, + # faster and cleaner convergence) + max_abs_value=4., + # max absolute value of data. 
If symmetric, data will be [-max, max] else [0, max] (Must not + # be too big to avoid gradient explosion, + # not too small for fast convergence) + # Contribution by @begeekmyfriend + # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude + # levels. Also allows for better G&L phase reconstruction) + preemphasize=True, # whether to apply filter + preemphasis=0.97, # filter coefficient. + + # Limits + min_level_db=-100, + ref_level_db=20, + fmin=55, + # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To + # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) + fmax=7600, # To be increased/reduced depending on data. + + ###################### Our training parameters ################################# + img_size=96, + fps=25, + + batch_size=16, + initial_learning_rate=1e-4, + nepochs=300000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs + num_workers=20, + checkpoint_interval=3000, + eval_interval=3000, + writer_interval=300, + save_optimizer_state=True, + + syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. + syncnet_batch_size=64, + syncnet_lr=1e-4, + syncnet_eval_interval=1000, + syncnet_checkpoint_interval=10000, + + disc_wt=0.07, + disc_initial_learning_rate=1e-4, +) + + + +# Default hyperparameters +hparamsdebug = HParams( + num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality + # network + rescale=True, # Whether to rescale audio prior to preprocessing + rescaling_max=0.9, # Rescaling value + + # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction + # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder + # Does not work if n_ffit is not multiple of hop_size!! + use_lws=False, + + n_fft=800, # Extra window size is filled with 0 paddings to match this parameter + hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) + win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) + sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) + + frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) + + # Mel and Linear spectrograms normalization/scaling and clipping + signal_normalization=True, + # Whether to normalize mel spectrograms to some predefined range (following below parameters) + allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True + symmetric_mels=True, + # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, + # faster and cleaner convergence) + max_abs_value=4., + # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not + # be too big to avoid gradient explosion, + # not too small for fast convergence) + # Contribution by @begeekmyfriend + # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude + # levels. Also allows for better G&L phase reconstruction) + preemphasize=True, # whether to apply filter + preemphasis=0.97, # filter coefficient. + + # Limits + min_level_db=-100, + ref_level_db=20, + fmin=55, + # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To + # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) + fmax=7600, # To be increased/reduced depending on data. 
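+    # Worked example of the frame timing these defaults imply (identical to
+    # the main `hparams` above): one hop is hop_size / sample_rate =
+    # 200 / 16000 = 12.5 ms, one analysis window is win_size / sample_rate =
+    # 800 / 16000 = 50 ms, and each STFT frame carries n_fft // 2 + 1 = 401
+    # frequency bins before the 80-band mel projection. At fps=25 a video
+    # frame spans 40 ms, i.e. 3.2 mel hops.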
+
+    ###################### Our training parameters #################################
+    img_size=96,
+    fps=25,
+
+    batch_size=2,
+    initial_learning_rate=1e-3,
+    nepochs=100000,  ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
+    num_workers=0,
+    checkpoint_interval=10000,
+    eval_interval=10,
+    writer_interval=5,
+    save_optimizer_state=True,
+
+    syncnet_wt=0.0,  # initially zero; set automatically to 0.03 later. Leads to faster convergence.
+    syncnet_batch_size=64,
+    syncnet_lr=1e-4,
+    syncnet_eval_interval=10000,
+    syncnet_checkpoint_interval=10000,
+
+    disc_wt=0.07,
+    disc_initial_learning_rate=1e-4,
+)
+
+
+def hparams_debug_string():
+    values = hparams.data  # HParams has no values() method; its settings live in the .data dict
+    hp = ["  %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
+    return "Hyperparameters:\n" + "\n".join(hp)
diff --git a/scripts/utils/init_path.py b/scripts/utils/init_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f38d11907bd0dc789992062ce7f02d8876c638f
--- /dev/null
+++ b/scripts/utils/init_path.py
@@ -0,0 +1,47 @@
+import os
+import glob
+
+def init_path(checkpoint_dir, config_dir, size=512, old_version=False, preprocess='crop'):
+
+    if old_version:
+        #### load all the checkpoints in the legacy `pth` format
+        sadtalker_paths = {
+            'wav2lip_checkpoint': os.path.join(checkpoint_dir, 'wav2lip.pth'),
+            'audio2pose_checkpoint': os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'),
+            'audio2exp_checkpoint': os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'),
+            'free_view_checkpoint': os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'),
+            'path_of_net_recon_model': os.path.join(checkpoint_dir, 'epoch_20.pth')
+        }
+
+        use_safetensor = False
+    elif len(glob.glob(os.path.join(checkpoint_dir, '*.safetensors'))):
+        print('using safetensor as default')
+        sadtalker_paths = {
+            "checkpoint": os.path.join(checkpoint_dir, 'SadTalker_V0.0.2_' + str(size) + '.safetensors'),
+        }
+        use_safetensor = True
+    else:
+        print("WARNING: the new version of the model is distributed as safetensors and you may need to download it manually. 
We run the old version of the checkpoint this time!") + use_safetensor = False + + sadtalker_paths = { + 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'), + 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'), + 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'), + 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'), + 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth') + } + + sadtalker_paths['dir_of_BFM_fitting'] = os.path.join(config_dir) # , 'BFM_Fitting' + sadtalker_paths['audio2pose_yaml_path'] = os.path.join(config_dir, 'auido2pose.yaml') + sadtalker_paths['audio2exp_yaml_path'] = os.path.join(config_dir, 'auido2exp.yaml') + sadtalker_paths['use_safetensor'] = use_safetensor # os.path.join(config_dir, 'auido2exp.yaml') + + if 'full' in preprocess: + sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00109-model.pth.tar') + sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender_still.yaml') + else: + sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00229-model.pth.tar') + sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender.yaml') + + return sadtalker_paths \ No newline at end of file diff --git a/scripts/utils/model2safetensor.py b/scripts/utils/model2safetensor.py new file mode 100644 index 0000000000000000000000000000000000000000..ee5fd4d9fd0c816039e4153463d0799e80af3bec --- /dev/null +++ b/scripts/utils/model2safetensor.py @@ -0,0 +1,141 @@ +import torch +import yaml +import os + +import safetensors +from safetensors.torch import save_file +from yacs.config import CfgNode as CN +import sys + +sys.path.append('/apdcephfs/private_shadowcun/SadTalker') + +from scripts.face3d.models import networks + +from scripts.facerender.modules.keypoint_detector import HEEstimator, KPDetector +from scripts.facerender.modules.mapping import MappingNet +from scripts.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator + +from scripts.audio2pose_models.audio2pose import Audio2Pose +from scripts.audio2exp_models.networks import SimpleWrapperV2 +from scripts.test_audio2coeff import load_cpk + +size = 256 +############ face vid2vid +config_path = os.path.join('src', 'config', 'facerender.yaml') +current_root_path = '.' 
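+
+# Conversion flow, in brief: every sub-network (face recon, generator, keypoint
+# detector, head-pose estimator, audio2pose, audio2exp) is loaded from its
+# legacy .pth/.tar checkpoint, wrapped in a single SadTalker nn.Module below,
+# and the merged state_dict is written out as one safetensors file. Consumers
+# then split it back by attribute prefix; a sketch of the intended read path,
+# mirroring load_x_from_safetensor in scripts/utils/safetensor_helper.py:
+#
+#     checkpoint = safetensors.torch.load_file('checkpoints/SadTalker_V0.0.2_256.safetensors')
+#     recon_weights = {k.replace('face_3drecon.', ''): v
+#                      for k, v in checkpoint.items() if 'face_3drecon' in k}
+#     net_recon.load_state_dict(recon_weights)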
+ +path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth') +net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='') +checkpoint = torch.load(path_of_net_recon_model, map_location='cpu') +net_recon.load_state_dict(checkpoint['net_recon']) + +with open(config_path) as f: + config = yaml.safe_load(f) + +generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'], + **config['model_params']['common_params']) +kp_extractor = KPDetector(**config['model_params']['kp_detector_params'], + **config['model_params']['common_params']) +he_estimator = HEEstimator(**config['model_params']['he_estimator_params'], + **config['model_params']['common_params']) +mapping = MappingNet(**config['model_params']['mapping_params']) + +def load_cpk_facevid2vid(checkpoint_path, generator=None, discriminator=None, + kp_detector=None, he_estimator=None, optimizer_generator=None, + optimizer_discriminator=None, optimizer_kp_detector=None, + optimizer_he_estimator=None, device="cpu"): + + checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) + if generator is not None: + generator.load_state_dict(checkpoint['generator']) + if kp_detector is not None: + kp_detector.load_state_dict(checkpoint['kp_detector']) + if he_estimator is not None: + he_estimator.load_state_dict(checkpoint['he_estimator']) + if discriminator is not None: + try: + discriminator.load_state_dict(checkpoint['discriminator']) + except: + print ('No discriminator in the state-dict. Dicriminator will be randomly initialized') + if optimizer_generator is not None: + optimizer_generator.load_state_dict(checkpoint['optimizer_generator']) + if optimizer_discriminator is not None: + try: + optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) + except RuntimeError as e: + print ('No discriminator optimizer in the state-dict. 
Optimizer will be not initialized') + if optimizer_kp_detector is not None: + optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector']) + if optimizer_he_estimator is not None: + optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator']) + + return checkpoint['epoch'] + + +def load_cpk_facevid2vid_safetensor(checkpoint_path, generator=None, + kp_detector=None, he_estimator=None, + device="cpu"): + + checkpoint = safetensors.torch.load_file(checkpoint_path) + + if generator is not None: + x_generator = {} + for k,v in checkpoint.items(): + if 'generator' in k: + x_generator[k.replace('generator.', '')] = v + generator.load_state_dict(x_generator) + if kp_detector is not None: + x_generator = {} + for k,v in checkpoint.items(): + if 'kp_extractor' in k: + x_generator[k.replace('kp_extractor.', '')] = v + kp_detector.load_state_dict(x_generator) + if he_estimator is not None: + x_generator = {} + for k,v in checkpoint.items(): + if 'he_estimator' in k: + x_generator[k.replace('he_estimator.', '')] = v + he_estimator.load_state_dict(x_generator) + + return None + +free_view_checkpoint = '/apdcephfs/private_shadowcun/SadTalker/checkpoints/facevid2vid_'+str(size)+'-model.pth.tar' +load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator) + +wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth') + +audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth') +audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml') + +audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth') +audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml') + +fcfg_pose = open(audio2pose_yaml_path) +cfg_pose = CN.load_cfg(fcfg_pose) +cfg_pose.freeze() +audio2pose_model = Audio2Pose(cfg_pose, wav2lip_checkpoint) +audio2pose_model.eval() +load_cpk(audio2pose_checkpoint, model=audio2pose_model, device='cpu') + +# load audio2exp_model +netG = SimpleWrapperV2() +netG.eval() +load_cpk(audio2exp_checkpoint, model=netG, device='cpu') + +class SadTalker(torch.nn.Module): + def __init__(self, kp_extractor, generator, netG, audio2pose, face_3drecon): + super(SadTalker, self).__init__() + self.kp_extractor = kp_extractor + self.generator = generator + self.audio2exp = netG + self.audio2pose = audio2pose + self.face_3drecon = face_3drecon + + +model = SadTalker(kp_extractor, generator, netG, audio2pose_model, net_recon) + +# here, we want to convert it to safetensor +save_file(model.state_dict(), "checkpoints/SadTalker_V0.0.2_"+str(size)+".safetensors") + +### test +load_cpk_facevid2vid_safetensor('checkpoints/SadTalker_V0.0.2_'+str(size)+'.safetensors', kp_detector=kp_extractor, generator=generator, he_estimator=None) \ No newline at end of file diff --git a/scripts/utils/paste_pic.py b/scripts/utils/paste_pic.py new file mode 100644 index 0000000000000000000000000000000000000000..b733f1b2fb727c061a4d96f9c0c87364abd677a7 --- /dev/null +++ b/scripts/utils/paste_pic.py @@ -0,0 +1,69 @@ +import cv2, os +import numpy as np +from tqdm import tqdm +import uuid + +from scripts.utils.videoio import save_video_with_watermark + +def paste_pic(video_path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop=False): + + if not os.path.isfile(pic_path): + raise ValueError('pic_path must be a valid path to video/image file') + elif pic_path.split('.')[-1] in ['jpg', 
'png', 'jpeg']: + # loader for first frame + full_img = cv2.imread(pic_path) + else: + # loader for videos + video_stream = cv2.VideoCapture(pic_path) + fps = video_stream.get(cv2.CAP_PROP_FPS) + full_frames = [] + while 1: + still_reading, frame = video_stream.read() + if not still_reading: + video_stream.release() + break + break + full_img = frame + frame_h = full_img.shape[0] + frame_w = full_img.shape[1] + + video_stream = cv2.VideoCapture(video_path) + fps = video_stream.get(cv2.CAP_PROP_FPS) + crop_frames = [] + while 1: + still_reading, frame = video_stream.read() + if not still_reading: + video_stream.release() + break + crop_frames.append(frame) + + if len(crop_info) != 3: + print("you didn't crop the image") + return + else: + r_w, r_h = crop_info[0] + clx, cly, crx, cry = crop_info[1] + lx, ly, rx, ry = crop_info[2] + lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) + # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx + # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx + + if extended_crop: + oy1, oy2, ox1, ox2 = cly, cry, clx, crx + else: + oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx + + tmp_path = str(uuid.uuid4())+'.mp4' + out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h)) + for crop_frame in tqdm(crop_frames, 'seamlessClone:'): + p = cv2.resize(crop_frame.astype(np.uint8), (ox2-ox1, oy2 - oy1)) + + mask = 255*np.ones(p.shape, p.dtype) + location = ((ox1+ox2) // 2, (oy1+oy2) // 2) + gen_img = cv2.seamlessClone(p, full_img, mask, location, cv2.NORMAL_CLONE) + out_tmp.write(gen_img) + + out_tmp.release() + + save_video_with_watermark(tmp_path, new_audio_path, full_video_path, watermark=False) + os.remove(tmp_path) diff --git a/scripts/utils/preprocess.py b/scripts/utils/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..1a297e385f9fecdc4c5c360d646be0d62c532540 --- /dev/null +++ b/scripts/utils/preprocess.py @@ -0,0 +1,170 @@ +import numpy as np +import cv2, os, sys, torch +from tqdm import tqdm +from PIL import Image + +# 3dmm extraction +import safetensors +import safetensors.torch +from scripts.face3d.util.preprocess import align_img +from scripts.face3d.util.load_mats import load_lm3d +from scripts.face3d.models import networks + +from scipy.io import loadmat, savemat +from scripts.utils.croper import Preprocesser + + +import warnings + +from scripts.utils.safetensor_helper import load_x_from_safetensor +warnings.filterwarnings("ignore") + +def split_coeff(coeffs): + """ + Return: + coeffs_dict -- a dict of torch.tensors + + Parameters: + coeffs -- torch.tensor, size (B, 256) + """ + id_coeffs = coeffs[:, :80] + exp_coeffs = coeffs[:, 80: 144] + tex_coeffs = coeffs[:, 144: 224] + angles = coeffs[:, 224: 227] + gammas = coeffs[:, 227: 254] + translations = coeffs[:, 254:] + return { + 'id': id_coeffs, + 'exp': exp_coeffs, + 'tex': tex_coeffs, + 'angle': angles, + 'gamma': gammas, + 'trans': translations + } + + +class CropAndExtract(): + def __init__(self, sadtalker_path, device): + + self.propress = Preprocesser(device) + self.net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device) + + if sadtalker_path['use_safetensor']: + checkpoint = safetensors.torch.load_file(sadtalker_path['checkpoint']) + self.net_recon.load_state_dict(load_x_from_safetensor(checkpoint, 'face_3drecon')) + else: + checkpoint = torch.load(sadtalker_path['path_of_net_recon_model'], map_location=torch.device(device)) + 
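+            # the legacy .pth checkpoint nests these weights under a 'net_recon'
+            # key, whereas the safetensors branch above stores them flat under a
+            # 'face_3drecon.' prefix and splits them out with load_x_from_safetensor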
self.net_recon.load_state_dict(checkpoint['net_recon']) + + self.net_recon.eval() + self.lm3d_std = load_lm3d(sadtalker_path['dir_of_BFM_fitting']) + self.device = device + + def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_flag=False, pic_size=256): + + pic_name = os.path.splitext(os.path.split(input_path)[-1])[0] + + landmarks_path = os.path.join(save_dir, pic_name+'_landmarks.txt') + coeff_path = os.path.join(save_dir, pic_name+'.mat') + png_path = os.path.join(save_dir, pic_name+'.png') + + #load input + if not os.path.isfile(input_path): + raise ValueError('input_path must be a valid path to video/image file') + elif input_path.split('.')[-1] in ['jpg', 'png', 'jpeg']: + # loader for first frame + full_frames = [cv2.imread(input_path)] + fps = 25 + else: + # loader for videos + video_stream = cv2.VideoCapture(input_path) + fps = video_stream.get(cv2.CAP_PROP_FPS) + full_frames = [] + while 1: + still_reading, frame = video_stream.read() + if not still_reading: + video_stream.release() + break + full_frames.append(frame) + if source_image_flag: + break + + x_full_frames= [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in full_frames] + + #### crop images as the + if 'crop' in crop_or_resize.lower(): # default crop + x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512) + clx, cly, crx, cry = crop + lx, ly, rx, ry = quad + lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) + oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx + crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad) + elif 'full' in crop_or_resize.lower(): + x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512) + clx, cly, crx, cry = crop + lx, ly, rx, ry = quad + lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) + oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx + crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad) + else: # resize mode + oy1, oy2, ox1, ox2 = 0, x_full_frames[0].shape[0], 0, x_full_frames[0].shape[1] + crop_info = ((ox2 - ox1, oy2 - oy1), None, None) + + frames_pil = [Image.fromarray(cv2.resize(frame,(pic_size, pic_size))) for frame in x_full_frames] + if len(frames_pil) == 0: + print('No face is detected in the input file') + return None, None + + # save crop info + for frame in frames_pil: + cv2.imwrite(png_path, cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)) + + # 2. get the landmark according to the detected face. + if not os.path.isfile(landmarks_path): + lm = self.propress.predictor.extract_keypoint(frames_pil, landmarks_path) + else: + print(' Using saved landmarks.') + lm = np.loadtxt(landmarks_path).astype(np.float32) + lm = lm.reshape([len(x_full_frames), -1, 2]) + + if not os.path.isfile(coeff_path): + # load 3dmm paramter generator from Deep3DFaceRecon_pytorch + video_coeffs, full_coeffs = [], [] + for idx in tqdm(range(len(frames_pil)), desc='3DMM Extraction In Video:'): + frame = frames_pil[idx] + W,H = frame.size + lm1 = lm[idx].reshape([-1, 2]) + + if np.mean(lm1) == -1: + lm1 = (self.lm3d_std[:, :2]+1)/2. 
+ lm1 = np.concatenate( + [lm1[:, :1]*W, lm1[:, 1:2]*H], 1 + ) + else: + lm1[:, -1] = H - 1 - lm1[:, -1] + + trans_params, im1, lm1, _ = align_img(frame, lm1, self.lm3d_std) + + trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]).astype(np.float32) + im_t = torch.tensor(np.array(im1)/255., dtype=torch.float32).permute(2, 0, 1).to(self.device).unsqueeze(0) + + with torch.no_grad(): + full_coeff = self.net_recon(im_t) + coeffs = split_coeff(full_coeff) + + pred_coeff = {key:coeffs[key].cpu().numpy() for key in coeffs} + + pred_coeff = np.concatenate([ + pred_coeff['exp'], + pred_coeff['angle'], + pred_coeff['trans'], + trans_params[2:][None], + ], 1) + video_coeffs.append(pred_coeff) + full_coeffs.append(full_coeff.cpu().numpy()) + + semantic_npy = np.array(video_coeffs)[:,0] + + savemat(coeff_path, {'coeff_3dmm': semantic_npy, 'full_3dmm': np.array(full_coeffs)[0]}) + + return coeff_path, png_path, crop_info diff --git a/scripts/utils/safetensor_helper.py b/scripts/utils/safetensor_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..3cdbdd21e4ed656dfe2d31a57360afb3e96480b3 --- /dev/null +++ b/scripts/utils/safetensor_helper.py @@ -0,0 +1,8 @@ + + +def load_x_from_safetensor(checkpoint, key): + x_generator = {} + for k,v in checkpoint.items(): + if key in k: + x_generator[k.replace(key+'.', '')] = v + return x_generator \ No newline at end of file diff --git a/scripts/utils/text2speech.py b/scripts/utils/text2speech.py new file mode 100644 index 0000000000000000000000000000000000000000..a9cd9f588b3dbfd9bc1e4c80c34ff13a74cadadc --- /dev/null +++ b/scripts/utils/text2speech.py @@ -0,0 +1,20 @@ +import os +import tempfile +from TTS.api import TTS + + +class TTSTalker(): + def __init__(self) -> None: + model_name = TTS().list_models()[0] + self.tts = TTS(model_name) + + def test(self, text, language='en'): + + tempf = tempfile.NamedTemporaryFile( + delete = False, + suffix = ('.'+'wav'), + ) + + self.tts.tts_to_file(text, speaker=self.tts.speakers[0], language=language, file_path=tempf.name) + + return tempf.name diff --git a/scripts/utils/videoio.py b/scripts/utils/videoio.py new file mode 100644 index 0000000000000000000000000000000000000000..08bfbdd7d4be97dc17fea4ad7b2733e9eb0ef975 --- /dev/null +++ b/scripts/utils/videoio.py @@ -0,0 +1,41 @@ +import shutil +import uuid + +import os + +import cv2 + +def load_video_to_cv2(input_path): + video_stream = cv2.VideoCapture(input_path) + fps = video_stream.get(cv2.CAP_PROP_FPS) + full_frames = [] + while 1: + still_reading, frame = video_stream.read() + if not still_reading: + video_stream.release() + break + full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + return full_frames + +def save_video_with_watermark(video, audio, save_path, watermark=False): + temp_file = str(uuid.uuid4())+'.mp4' + cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -vcodec copy "%s"' % (video, audio, temp_file) + os.system(cmd) + + if watermark is False: + shutil.move(temp_file, save_path) + else: + # watermark + try: + ##### check if stable-diffusion-webui + import webui + from modules import paths + watarmark_path = paths.script_path+"/extensions/SadTalker/docs/sadtalker_logo.png" + except: + # get the root path of sadtalker. 
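+        # (fallback: when not running as a stable-diffusion-webui extension,
+        # resolve the logo relative to this file, two directories up in docs/)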
+ dir_path = os.path.dirname(os.path.realpath(__file__)) + watarmark_path = dir_path+"/../../docs/sadtalker_logo.png" + + cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -filter_complex "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10" "%s"' % (temp_file, watarmark_path, save_path) + os.system(cmd) + os.remove(temp_file) \ No newline at end of file diff --git a/scripts/weights/facelib/detection_Resnet50_Final.pth b/scripts/weights/facelib/detection_Resnet50_Final.pth new file mode 100644 index 0000000000000000000000000000000000000000..16546738ce0a00a9fd47585e0fc52744d31cc117 --- /dev/null +++ b/scripts/weights/facelib/detection_Resnet50_Final.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d +size 109497761 diff --git a/scripts/weights/facelib/parsing_parsenet.pth b/scripts/weights/facelib/parsing_parsenet.pth new file mode 100644 index 0000000000000000000000000000000000000000..1ac2efc50360a79c9905dbac57d9d99cbfbe863c --- /dev/null +++ b/scripts/weights/facelib/parsing_parsenet.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3d558d8d0e42c20224f13cf5a29c79eba2d59913419f945545d8cf7b72920de2 +size 85331193 diff --git a/subprocess/LivePortrait/.gitignore b/subprocess/LivePortrait/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..1f85f19b8dc076778ff4cf5dcf758628ccf2df8f --- /dev/null +++ b/subprocess/LivePortrait/.gitignore @@ -0,0 +1,21 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +**/__pycache__/ +*.py[cod] +**/*.py[cod] +*$py.class + +# Model weights +**/*.pth +**/*.onnx + +pretrained_weights/*.md +pretrained_weights/docs + +# Ipython notebook +*.ipynb + +# Temporary files or benchmark resources +animations/* +tmp/* +.vscode/launch.json diff --git a/subprocess/LivePortrait/.vscode/settings.json b/subprocess/LivePortrait/.vscode/settings.json new file mode 100644 index 0000000000000000000000000000000000000000..1bca84ccf9fed7936fc93d2704ff4eab6c734728 --- /dev/null +++ b/subprocess/LivePortrait/.vscode/settings.json @@ -0,0 +1,19 @@ +{ + "[python]": { + "editor.tabSize": 4 + }, + "files.eol": "\n", + "files.insertFinalNewline": true, + "files.trimFinalNewlines": true, + "files.trimTrailingWhitespace": true, + "files.exclude": { + "**/.git": true, + "**/.svn": true, + "**/.hg": true, + "**/CVS": true, + "**/.DS_Store": true, + "**/Thumbs.db": true, + "**/*.crswap": true, + "**/__pycache__": true + } +} diff --git a/subprocess/LivePortrait/LICENSE b/subprocess/LivePortrait/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..9e8f5026e9273b98745188ec4bbc8ac05b2b22ef --- /dev/null +++ b/subprocess/LivePortrait/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Kuaishou Visual Generation and Interaction Center + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/subprocess/LivePortrait/animations/IMG_20240410_151931_Bokeh~2--d14.mp4 b/subprocess/LivePortrait/animations/IMG_20240410_151931_Bokeh~2--d14.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..b75d37fdea47db26f14f69f85b8994e114f2cc9c --- /dev/null +++ b/subprocess/LivePortrait/animations/IMG_20240410_151931_Bokeh~2--d14.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f25d31d83ae06f8dc313ae6ef9016269336219f25ffa8765d60920d90f314b7d +size 6107008 diff --git a/subprocess/LivePortrait/animations/IMG_20240410_151931_Bokeh~2--d14_concat.mp4 b/subprocess/LivePortrait/animations/IMG_20240410_151931_Bokeh~2--d14_concat.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..1cfc6a6641f837c1ac035aa47a9d46b99919225a --- /dev/null +++ b/subprocess/LivePortrait/animations/IMG_20240410_151931_Bokeh~2--d14_concat.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:30dde4ba24793ed58398f9ee4d54a5a7d40bee7ed2634f63603a8670cdb66b97 +size 2933593 diff --git a/subprocess/LivePortrait/animations/image3--d6.mp4 b/subprocess/LivePortrait/animations/image3--d6.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7b84139eaae2f30e1cdcadd5dde9b2c35832ee1b Binary files /dev/null and b/subprocess/LivePortrait/animations/image3--d6.mp4 differ diff --git a/subprocess/LivePortrait/animations/image3--d6_concat.mp4 b/subprocess/LivePortrait/animations/image3--d6_concat.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..0b322e7d61b6e968674f38f24ac1382695175fb6 --- /dev/null +++ b/subprocess/LivePortrait/animations/image3--d6_concat.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be7f3b6e664d27974e38fb168e6aef4379e91d50da6e55023f009ce6c00e9e54 +size 5005845 diff --git a/subprocess/LivePortrait/animations/image3--intro.mp4 b/subprocess/LivePortrait/animations/image3--intro.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6e93bae84ae82eadf739f6f5d3b4514a26344ab4 Binary files /dev/null and b/subprocess/LivePortrait/animations/image3--intro.mp4 differ diff --git a/subprocess/LivePortrait/animations/image3--intro_concat.mp4 b/subprocess/LivePortrait/animations/image3--intro_concat.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..4f562366a6893575066de2ee96595752967a54ec --- /dev/null +++ b/subprocess/LivePortrait/animations/image3--intro_concat.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:079eac3a6cf84b04e18723a066448fddffcd71988b1f609aedd14578797384d9 +size 1373427 diff --git a/subprocess/LivePortrait/animations/input--d10.mp4 b/subprocess/LivePortrait/animations/input--d10.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..da3c6bdca05a2ae1e3acfd585dd499c8a7147576 --- /dev/null +++ b/subprocess/LivePortrait/animations/input--d10.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:18c5ebdbfbd40d4a05045acacd6af673962a411672b137bba48f482210d491e6 +size 2421900 diff --git 
a/subprocess/LivePortrait/animations/input--d10_concat.mp4 b/subprocess/LivePortrait/animations/input--d10_concat.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..00ffb3bc6d4903d36eced467d7821b5aa0b17666 --- /dev/null +++ b/subprocess/LivePortrait/animations/input--d10_concat.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c7b8a86d3bcb23090991b567683093097f1c7634e478b3953f972136e6124bb +size 3403717 diff --git a/subprocess/LivePortrait/animations/input--d13.mp4 b/subprocess/LivePortrait/animations/input--d13.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..59b219cd07692d3a39ae4d4a7051e7914f4d6ae6 --- /dev/null +++ b/subprocess/LivePortrait/animations/input--d13.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c7b7df54d0aeb8a75eaa5ddec371ffcd51308e5a42d6a273c0b90f40637e828 +size 1424807 diff --git a/subprocess/LivePortrait/animations/input--d13_concat.mp4 b/subprocess/LivePortrait/animations/input--d13_concat.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6c130bc85a2f03fe6dea60126ed013420ef5ac07 --- /dev/null +++ b/subprocess/LivePortrait/animations/input--d13_concat.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d071a3d6c188a9f36d45d82cebd2cd761c7f6702e9c5e1c887e238261169efd +size 2013707 diff --git a/subprocess/LivePortrait/animations/input--d19.mp4 b/subprocess/LivePortrait/animations/input--d19.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..9ec7fafa20814458f9ba9d2d89babb78c160f2d1 Binary files /dev/null and b/subprocess/LivePortrait/animations/input--d19.mp4 differ diff --git a/subprocess/LivePortrait/animations/input--d19_concat.mp4 b/subprocess/LivePortrait/animations/input--d19_concat.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..094cbc030140592b0d145dd411264a5c42d8c5fa Binary files /dev/null and b/subprocess/LivePortrait/animations/input--d19_concat.mp4 differ diff --git a/subprocess/LivePortrait/animations/input--d3.mp4 b/subprocess/LivePortrait/animations/input--d3.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..59ac6adc061ca88b7bb7572542ff141a4cda23d6 Binary files /dev/null and b/subprocess/LivePortrait/animations/input--d3.mp4 differ diff --git a/subprocess/LivePortrait/animations/input--d3_concat.mp4 b/subprocess/LivePortrait/animations/input--d3_concat.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..56cdedec1e5798ae6c7f200d262435b7bcfad39e --- /dev/null +++ b/subprocess/LivePortrait/animations/input--d3_concat.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:00212de5f9bc729ce91060887668022f48a9507f147b953c53b963680324aa53 +size 1632067 diff --git a/subprocess/LivePortrait/animations/input--d9.mp4 b/subprocess/LivePortrait/animations/input--d9.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7fc5db9f6c9685ffade8320bfc38e3737db0ddb2 --- /dev/null +++ b/subprocess/LivePortrait/animations/input--d9.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c22dc451add7cf70badaae8517774f881a5d900a0909fee15605771b499cf413 +size 3321644 diff --git a/subprocess/LivePortrait/animations/input--d9_concat.mp4 b/subprocess/LivePortrait/animations/input--d9_concat.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..644f0b9675c0ef5d1bcf7be37dc197f460d9466b --- /dev/null +++ b/subprocess/LivePortrait/animations/input--d9_concat.mp4 @@ -0,0 +1,3 
@@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b8a3a13fb284d8805076a6adfff2684728f8b7efe6c3d6591259dd2a0a528ee3 +size 4673474 diff --git a/subprocess/LivePortrait/animations/s1--d13.mp4 b/subprocess/LivePortrait/animations/s1--d13.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..cf1cd0f4719bc216276850e6b0ed84ca13b4aed8 --- /dev/null +++ b/subprocess/LivePortrait/animations/s1--d13.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c2c741ffc7e32a705ee22f17d831f397d2890eb7bb0e7c126b5b514ec2f015a +size 1546609 diff --git a/subprocess/LivePortrait/animations/s1--d13_concat.mp4 b/subprocess/LivePortrait/animations/s1--d13_concat.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e2a2065beeaa38755c96d137da4ae08342db1418 --- /dev/null +++ b/subprocess/LivePortrait/animations/s1--d13_concat.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:943508c9079ac81bda287f468c22dd085872c770a84fafc8fee65ba5636b1932 +size 1723831 diff --git a/subprocess/LivePortrait/animations/s10--d0.mp4 b/subprocess/LivePortrait/animations/s10--d0.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8cd2425b82630e54cd28974dff53fe47f8587402 Binary files /dev/null and b/subprocess/LivePortrait/animations/s10--d0.mp4 differ diff --git a/subprocess/LivePortrait/animations/s10--d0_concat.mp4 b/subprocess/LivePortrait/animations/s10--d0_concat.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..23e648a14be52cf68ecd2117b489a82c0be2b01b Binary files /dev/null and b/subprocess/LivePortrait/animations/s10--d0_concat.mp4 differ diff --git a/subprocess/LivePortrait/animations/s5--d0.mp4 b/subprocess/LivePortrait/animations/s5--d0.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..c68d4fc12a4ca4597875431d902db14b8691f020 Binary files /dev/null and b/subprocess/LivePortrait/animations/s5--d0.mp4 differ diff --git a/subprocess/LivePortrait/animations/s5--d0_concat.mp4 b/subprocess/LivePortrait/animations/s5--d0_concat.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..b00fe6f4721169eff9b0358a3b3be9297e1e4963 Binary files /dev/null and b/subprocess/LivePortrait/animations/s5--d0_concat.mp4 differ diff --git a/subprocess/LivePortrait/animations/s6--d0.mp4 b/subprocess/LivePortrait/animations/s6--d0.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..dbbfbdad86952b35c3ee19584add8ecd2244025d Binary files /dev/null and b/subprocess/LivePortrait/animations/s6--d0.mp4 differ diff --git a/subprocess/LivePortrait/animations/s6--d0_concat.mp4 b/subprocess/LivePortrait/animations/s6--d0_concat.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7f6e876d5cd430d9789810fa98a0525b2214da74 Binary files /dev/null and b/subprocess/LivePortrait/animations/s6--d0_concat.mp4 differ diff --git a/subprocess/LivePortrait/animations/s6--d18.mp4 b/subprocess/LivePortrait/animations/s6--d18.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..c1f66e195e6c2297fa15b7eaa993ac0c58f8bc7e Binary files /dev/null and b/subprocess/LivePortrait/animations/s6--d18.mp4 differ diff --git a/subprocess/LivePortrait/animations/s6--d18_concat.mp4 b/subprocess/LivePortrait/animations/s6--d18_concat.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..89eefd235afa7314bdf4ecfe250d89a01f404887 Binary files /dev/null and b/subprocess/LivePortrait/animations/s6--d18_concat.mp4 differ diff --git 
a/subprocess/LivePortrait/animations/s9--d0.mp4 b/subprocess/LivePortrait/animations/s9--d0.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..04d54d9300084786fa73459a38db5511fe256dbc Binary files /dev/null and b/subprocess/LivePortrait/animations/s9--d0.mp4 differ diff --git a/subprocess/LivePortrait/animations/s9--d0_concat.mp4 b/subprocess/LivePortrait/animations/s9--d0_concat.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..c4bb81b068b5eb553814b5cf52a6e37636f9e870 Binary files /dev/null and b/subprocess/LivePortrait/animations/s9--d0_concat.mp4 differ diff --git a/subprocess/LivePortrait/app.py b/subprocess/LivePortrait/app.py new file mode 100644 index 0000000000000000000000000000000000000000..8e91e6a7bd7ddd428e7a3fe3551c8085d3a1dd38 --- /dev/null +++ b/subprocess/LivePortrait/app.py @@ -0,0 +1,205 @@ +# coding: utf-8 + +""" +The entrance of the Gradio demo +""" +import os +os.environ['CUDA_LAUNCH_BLOCKING'] = '1' + +import tyro +import gradio as gr +import os.path as osp +from src.utils.helper import load_description +from src.gradio_pipeline import GradioPipeline +from src.config.crop_config import CropConfig +from src.config.argument_config import ArgumentConfig +from src.config.inference_config import InferenceConfig + + +def partial_fields(target_class, kwargs): + return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)}) + + +# set tyro theme +tyro.extras.set_accent_color("bright_cyan") +print(ArgumentConfig) +args = tyro.cli(ArgumentConfig) + +# specify configs for inference +inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attributes of args to initialize InferenceConfig +crop_cfg = partial_fields(CropConfig, args.__dict__) # use attributes of args to initialize CropConfig + +gradio_pipeline = GradioPipeline( + inference_cfg=inference_cfg, + crop_cfg=crop_cfg, + args=args +) + + +def gpu_wrapped_execute_video(*args, **kwargs): + return gradio_pipeline.execute_video(*args, **kwargs) + + +def gpu_wrapped_execute_image(*args, **kwargs): + return gradio_pipeline.execute_image(*args, **kwargs) + + +# assets +title_md = "assets/gradio_title.md" +example_portrait_dir = "assets/examples/source" +example_video_dir = "assets/examples/driving" +data_examples = [ + [osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False], + [osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False], + [osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False], + [osp.join(example_portrait_dir, "s5.jpg"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False], + [osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False], + [osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True], +] +#################### interface logic #################### + +# Define components first +eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio") +lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio") +retargeting_input_image = gr.Image(type="filepath") +output_image = gr.Image(type="numpy") +output_image_paste_back = gr.Image(type="numpy") +output_video = gr.Video() +output_video_concat = gr.Video() + +with gr.Blocks(theme=gr.themes.Soft()) as demo: + gr.HTML(load_description(title_md)) +
gr.Markdown(load_description("assets/gradio_description_upload.md")) + with gr.Row(): + with gr.Accordion(open=True, label="Source Portrait"): + image_input = gr.Image(type="filepath") + gr.Examples( + examples=[ + [osp.join(example_portrait_dir, "s9.jpg")], + [osp.join(example_portrait_dir, "s6.jpg")], + [osp.join(example_portrait_dir, "s10.jpg")], + [osp.join(example_portrait_dir, "s5.jpg")], + [osp.join(example_portrait_dir, "s7.jpg")], + [osp.join(example_portrait_dir, "s12.jpg")], + ], + inputs=[image_input], + cache_examples=False, + ) + with gr.Accordion(open=True, label="Driving Video"): + video_input = gr.Video() + gr.Examples( + examples=[ + [osp.join(example_video_dir, "d0.mp4")], + [osp.join(example_video_dir, "d18.mp4")], + [osp.join(example_video_dir, "d19.mp4")], + [osp.join(example_video_dir, "d14.mp4")], + [osp.join(example_video_dir, "d6.mp4")], + ], + inputs=[video_input], + cache_examples=False, + ) + with gr.Row(): + with gr.Accordion(open=False, label="Animation Instructions and Options"): + gr.Markdown(load_description("assets/gradio_description_animation.md")) + with gr.Row(): + flag_relative_input = gr.Checkbox(value=True, label="relative motion") + flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)") + flag_remap_input = gr.Checkbox(value=True, label="paste-back") + flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving video)") + with gr.Row(): + with gr.Column(): + process_button_animation = gr.Button("🚀 Animate", variant="primary") + with gr.Column(): + process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear") + with gr.Row(): + with gr.Column(): + with gr.Accordion(open=True, label="The animated video in the original image space"): + output_video.render() + with gr.Column(): + with gr.Accordion(open=True, label="The animated video"): + output_video_concat.render() + with gr.Row(): + # Examples + gr.Markdown("## You could also choose the examples below by one click ⬇️") + with gr.Row(): + gr.Examples( + examples=data_examples, + fn=gpu_wrapped_execute_video, + inputs=[ + image_input, + video_input, + flag_relative_input, + flag_do_crop_input, + flag_remap_input, + flag_crop_driving_video_input + ], + outputs=[output_image, output_image_paste_back], + examples_per_page=len(data_examples), + cache_examples=False, + ) + gr.Markdown(load_description("assets/gradio_description_retargeting.md"), visible=True) + with gr.Row(visible=True): + eye_retargeting_slider.render() + lip_retargeting_slider.render() + with gr.Row(visible=True): + process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary") + process_button_reset_retargeting = gr.ClearButton( + [ + eye_retargeting_slider, + lip_retargeting_slider, + retargeting_input_image, + output_image, + output_image_paste_back + ], + value="🧹 Clear" + ) + with gr.Row(visible=True): + with gr.Column(): + with gr.Accordion(open=True, label="Retargeting Input"): + retargeting_input_image.render() + gr.Examples( + examples=[ + [osp.join(example_portrait_dir, "s9.jpg")], + [osp.join(example_portrait_dir, "s6.jpg")], + [osp.join(example_portrait_dir, "s10.jpg")], + [osp.join(example_portrait_dir, "s5.jpg")], + [osp.join(example_portrait_dir, "s7.jpg")], + [osp.join(example_portrait_dir, "s12.jpg")], + ], + inputs=[retargeting_input_image], + cache_examples=False, + ) + with gr.Column(): + with gr.Accordion(open=True, label="Retargeting Result"): + output_image.render() + with gr.Column(): + with 
gr.Accordion(open=True, label="Paste-back Result"): + output_image_paste_back.render() + # binding functions for buttons + process_button_retargeting.click( + # fn=gradio_pipeline.execute_image, + fn=gpu_wrapped_execute_image, + inputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image, flag_do_crop_input], + outputs=[output_image, output_image_paste_back], + show_progress=True + ) + process_button_animation.click( + fn=gpu_wrapped_execute_video, + inputs=[ + image_input, + video_input, + flag_relative_input, + flag_do_crop_input, + flag_remap_input, + flag_crop_driving_video_input + ], + outputs=[output_video, output_video_concat], + show_progress=True + ) + + +demo.launch( + server_port=args.server_port, + share=args.share, + server_name=args.server_name +) diff --git a/subprocess/LivePortrait/assets/.gitignore b/subprocess/LivePortrait/assets/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..892dfa4274b60c1629e26719bbd1e462fcce33e8 --- /dev/null +++ b/subprocess/LivePortrait/assets/.gitignore @@ -0,0 +1,2 @@ +examples/driving/*.pkl +examples/driving/*_crop.mp4 diff --git a/subprocess/LivePortrait/assets/docs/changelog/2024-07-10.md b/subprocess/LivePortrait/assets/docs/changelog/2024-07-10.md new file mode 100644 index 0000000000000000000000000000000000000000..8b7fa8880808ee37b43a86851a557570deae12f3 --- /dev/null +++ b/subprocess/LivePortrait/assets/docs/changelog/2024-07-10.md @@ -0,0 +1,17 @@ +## 2024/07/10 + +**First, thank you all for your attention, support, sharing, and contributions to LivePortrait!** ❤️ +The popularity of LivePortrait has exceeded our expectations. If you encounter any issues or other problems and we do not respond promptly, please accept our apologies. We are still actively updating and improving this repository. + +### Updates + +- Audio and video concatenating: If the driving video contains audio, it will automatically be included in the generated video. Additionally, the generated video will maintain the same FPS as the driving video. If you run LivePortrait on Windows, you may need to install the `ffmpeg` executable; see issue [#94](https://github.com/KwaiVGI/LivePortrait/issues/94). + +- Driving video auto-cropping: Implemented automatic cropping for driving videos by tracking facial landmarks and calculating a global cropping box with a 1:1 aspect ratio. Alternatively, you can crop using video editing software or other tools to achieve a 1:1 ratio. Auto-cropping is not enabled by default; you can enable it with `--flag_crop_driving_video`. + +- Template making: Added the ability to create templates to protect privacy. The template is a `.pkl` file that only contains the motions of the driving video. Theoretically, it is impossible to reconstruct the original face from the template. These templates can be used to generate videos without needing the original driving video. By default, the template will be generated and saved as a `.pkl` file with the same name as the driving video. Once generated, you can specify it using the `-d` or `--driving_info` option. + + +### Others + +- If you encounter a black box problem, disable half-precision inference with `--no_flag_use_half_precision`, as reported in issues [#40](https://github.com/KwaiVGI/LivePortrait/issues/40), [#48](https://github.com/KwaiVGI/LivePortrait/issues/48), [#62](https://github.com/KwaiVGI/LivePortrait/issues/62).
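The audio muxing described above happens automatically; for readers who want to reproduce the step by hand, here is a minimal stand-alone sketch using `ffmpeg-python` (pinned in `requirements.txt`). It is illustrative only, not the repository's internal implementation, and it assumes the driving video actually carries an audio track; the paths are the bundled examples:

```python
import ffmpeg

# mux the driving video's audio track into a generated (silent) animation
video = ffmpeg.input('animations/s6--d0.mp4').video
audio = ffmpeg.input('assets/examples/driving/d0.mp4').audio
ffmpeg.output(
    video, audio, 'animations/s6--d0_with_audio.mp4',
    vcodec='copy',   # keep the generated frames untouched
    acodec='aac',    # re-encode audio for broad mp4 compatibility
    shortest=None,   # pass ffmpeg's -shortest flag: stop at the shorter stream
).run()
```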
diff --git a/subprocess/LivePortrait/assets/docs/inference.gif b/subprocess/LivePortrait/assets/docs/inference.gif new file mode 100644 index 0000000000000000000000000000000000000000..7e18022e5245dcb6449df6d190b538d5ca024e06 Binary files /dev/null and b/subprocess/LivePortrait/assets/docs/inference.gif differ diff --git a/subprocess/LivePortrait/assets/docs/showcase.gif b/subprocess/LivePortrait/assets/docs/showcase.gif new file mode 100644 index 0000000000000000000000000000000000000000..fae84c2d3550a37446e482286b70902b21e2e232 --- /dev/null +++ b/subprocess/LivePortrait/assets/docs/showcase.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7bca5f38bfd555bf7c013312d87883afdf39d97fba719ac171c60f897af49e21 +size 6623248 diff --git a/subprocess/LivePortrait/assets/docs/showcase2.gif b/subprocess/LivePortrait/assets/docs/showcase2.gif new file mode 100644 index 0000000000000000000000000000000000000000..29175c0eeb85b9db0ffd61e3e9281dffe3536352 --- /dev/null +++ b/subprocess/LivePortrait/assets/docs/showcase2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb1fffb139681775780b2956e7d0289f55d199c1a3e14ab263887864d4b0d586 +size 2881351 diff --git a/subprocess/LivePortrait/assets/examples/driving/d0.mp4 b/subprocess/LivePortrait/assets/examples/driving/d0.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..92391dd3ff235fc82f29b7cc77fe4a7ce183d934 --- /dev/null +++ b/subprocess/LivePortrait/assets/examples/driving/d0.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63f6f9962e1fdf6e6722172e7a18155204858d5d5ce3b1e0646c150360c33bed +size 2958395 diff --git a/subprocess/LivePortrait/assets/examples/driving/d0.pkl b/subprocess/LivePortrait/assets/examples/driving/d0.pkl new file mode 100644 index 0000000000000000000000000000000000000000..91be76dd7e9ba4f0322f530358e599393f619705 --- /dev/null +++ b/subprocess/LivePortrait/assets/examples/driving/d0.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56217aa88f7b03483f6a04294fe7813e0d8de624e041284d72ce1436766cd860 +size 41087 diff --git a/subprocess/LivePortrait/assets/examples/driving/d1.pkl b/subprocess/LivePortrait/assets/examples/driving/d1.pkl new file mode 100644 index 0000000000000000000000000000000000000000..8e11db176d93c34f7b44aa94487ac0a6715168cb --- /dev/null +++ b/subprocess/LivePortrait/assets/examples/driving/d1.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:16b47d68396e4a5fc0756b4c83827e8fc27c08bc92be1aa04809f741d9db95f9 +size 8599 diff --git a/subprocess/LivePortrait/assets/examples/driving/d10.mp4 b/subprocess/LivePortrait/assets/examples/driving/d10.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..9123c788df48b50227c6aeb9a1d4061510424223 --- /dev/null +++ b/subprocess/LivePortrait/assets/examples/driving/d10.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9074d557e127f4a506bd0def25e2b3182676e917d53982852d33c807a95ef1fb +size 1146372 diff --git a/subprocess/LivePortrait/assets/examples/driving/d11.mp4 b/subprocess/LivePortrait/assets/examples/driving/d11.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..378d00065aaa1d30d6b14be10e0e78188deba152 Binary files /dev/null and b/subprocess/LivePortrait/assets/examples/driving/d11.mp4 differ diff --git a/subprocess/LivePortrait/assets/examples/driving/d12.mp4 b/subprocess/LivePortrait/assets/examples/driving/d12.mp4 new file mode 100644 index 
0000000000000000000000000000000000000000..984922e5c722fa9672dc6c6765bf1183466daf5b Binary files /dev/null and b/subprocess/LivePortrait/assets/examples/driving/d12.mp4 differ diff --git a/subprocess/LivePortrait/assets/examples/driving/d13.mp4 b/subprocess/LivePortrait/assets/examples/driving/d13.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..23b6af6e4afa879a11ec8284bfdb3253739e6b41 --- /dev/null +++ b/subprocess/LivePortrait/assets/examples/driving/d13.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d03e39c28323cde1c5fc6c5629aa83fe6c834fa7c9ed2dac969e1247eaafdb60 +size 2475854 diff --git a/subprocess/LivePortrait/assets/examples/driving/d14.mp4 b/subprocess/LivePortrait/assets/examples/driving/d14.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e4a25d614cae7ae9b0425539da1c24d09d06c7db Binary files /dev/null and b/subprocess/LivePortrait/assets/examples/driving/d14.mp4 differ diff --git a/subprocess/LivePortrait/assets/examples/driving/d18.mp4 b/subprocess/LivePortrait/assets/examples/driving/d18.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..c23ade1841fa5744af3ffdc3a42d52b9227a9d2e Binary files /dev/null and b/subprocess/LivePortrait/assets/examples/driving/d18.mp4 differ diff --git a/subprocess/LivePortrait/assets/examples/driving/d19.mp4 b/subprocess/LivePortrait/assets/examples/driving/d19.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..07562e983f601d00bcc0e388fe60872c046bcbaf Binary files /dev/null and b/subprocess/LivePortrait/assets/examples/driving/d19.mp4 differ diff --git a/subprocess/LivePortrait/assets/examples/driving/d2.pkl b/subprocess/LivePortrait/assets/examples/driving/d2.pkl new file mode 100644 index 0000000000000000000000000000000000000000..6bc7d490b84d9a06436b133e4c1457b6055f8dc1 --- /dev/null +++ b/subprocess/LivePortrait/assets/examples/driving/d2.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:021a0e83d4ae81ab75b49b31a6cf75ac7987c86e02808aced3dd49894512a082 +size 8599 diff --git a/subprocess/LivePortrait/assets/examples/driving/d3.mp4 b/subprocess/LivePortrait/assets/examples/driving/d3.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8b70b6aa3c0e566a4fa3e5959f2d3b916e99b708 --- /dev/null +++ b/subprocess/LivePortrait/assets/examples/driving/d3.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef5c86e49b1b43dcb1449b499eb5a7f0cbae2f78aec08b5598193be1e4257099 +size 1430968 diff --git a/subprocess/LivePortrait/assets/examples/driving/d5.pkl b/subprocess/LivePortrait/assets/examples/driving/d5.pkl new file mode 100644 index 0000000000000000000000000000000000000000..4fde2987728e8e4491600c68c951189c095ef90e --- /dev/null +++ b/subprocess/LivePortrait/assets/examples/driving/d5.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91f2863838a089fe418b22864e7c48ac1f2b9d4513afb033a9d9dd5979a90b8c +size 77776 diff --git a/subprocess/LivePortrait/assets/examples/driving/d6.mp4 b/subprocess/LivePortrait/assets/examples/driving/d6.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..44f351385cef843b21b03fab8c3b10e0c005ec5e --- /dev/null +++ b/subprocess/LivePortrait/assets/examples/driving/d6.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:00e3ea79bbf28cbdc4fbb67ec655d9a0fe876e880ec45af55ae481348d0c0fff +size 1967790 diff --git a/subprocess/LivePortrait/assets/examples/driving/d7.pkl 
b/subprocess/LivePortrait/assets/examples/driving/d7.pkl new file mode 100644 index 0000000000000000000000000000000000000000..be7c9d74d12c19a8460b215da117de737889b5b2 --- /dev/null +++ b/subprocess/LivePortrait/assets/examples/driving/d7.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:84aed70f3dd01ebd818c51fc11762eeff51efeef05b1f15a660b105c4c0748da +size 93496 diff --git a/subprocess/LivePortrait/assets/examples/driving/d8.pkl b/subprocess/LivePortrait/assets/examples/driving/d8.pkl new file mode 100644 index 0000000000000000000000000000000000000000..dc631a4fcbd858afec881f66fd13b19f820d2eab --- /dev/null +++ b/subprocess/LivePortrait/assets/examples/driving/d8.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:099afc34d40794aa733644af76dcd1bc387573c381a82340018df7019a06d68e +size 144334 diff --git a/subprocess/LivePortrait/assets/examples/driving/d9.mp4 b/subprocess/LivePortrait/assets/examples/driving/d9.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7803b3bf5c460a79d94e5cfbedb0de1f52d449d2 --- /dev/null +++ b/subprocess/LivePortrait/assets/examples/driving/d9.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a414aa1d547be35306d692065a2157434bf40a6025ba8e30ce12e5bb322cc33 +size 2257929 diff --git a/subprocess/LivePortrait/assets/examples/source/s0.jpg b/subprocess/LivePortrait/assets/examples/source/s0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ef44c593be38cea30422fff9ed986a8a77889348 Binary files /dev/null and b/subprocess/LivePortrait/assets/examples/source/s0.jpg differ diff --git a/subprocess/LivePortrait/assets/examples/source/s1.jpg b/subprocess/LivePortrait/assets/examples/source/s1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ebacda3519a1452aee239f7e104d2c6ff40beb25 Binary files /dev/null and b/subprocess/LivePortrait/assets/examples/source/s1.jpg differ diff --git a/subprocess/LivePortrait/assets/examples/source/s10.jpg b/subprocess/LivePortrait/assets/examples/source/s10.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ee9616b592f070fbe90a8717da01477e8d4ee01f Binary files /dev/null and b/subprocess/LivePortrait/assets/examples/source/s10.jpg differ diff --git a/subprocess/LivePortrait/assets/examples/source/s11.jpg b/subprocess/LivePortrait/assets/examples/source/s11.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bd2fa2d2867336215012943addd7c7def2a29ccb Binary files /dev/null and b/subprocess/LivePortrait/assets/examples/source/s11.jpg differ diff --git a/subprocess/LivePortrait/assets/examples/source/s12.jpg b/subprocess/LivePortrait/assets/examples/source/s12.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d3d65c1e8e099ec279d730d296875b937f885417 Binary files /dev/null and b/subprocess/LivePortrait/assets/examples/source/s12.jpg differ diff --git a/subprocess/LivePortrait/assets/examples/source/s2.jpg b/subprocess/LivePortrait/assets/examples/source/s2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e851bd20b65c552266a87bb87a9b509e3ea56f7d Binary files /dev/null and b/subprocess/LivePortrait/assets/examples/source/s2.jpg differ diff --git a/subprocess/LivePortrait/assets/examples/source/s3.jpg b/subprocess/LivePortrait/assets/examples/source/s3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9f3ba2a358e5b88450e7466761dff3e983e18e16 Binary files /dev/null and 
b/subprocess/LivePortrait/assets/examples/source/s3.jpg differ diff --git a/subprocess/LivePortrait/assets/examples/source/s4.jpg b/subprocess/LivePortrait/assets/examples/source/s4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..17f611bf942ad168d4e4d03b7e5c42d6650c4be1 Binary files /dev/null and b/subprocess/LivePortrait/assets/examples/source/s4.jpg differ diff --git a/subprocess/LivePortrait/assets/examples/source/s5.jpg b/subprocess/LivePortrait/assets/examples/source/s5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9abad7ef061b93579a373cf141d38710d9b1e32d Binary files /dev/null and b/subprocess/LivePortrait/assets/examples/source/s5.jpg differ diff --git a/subprocess/LivePortrait/assets/examples/source/s6.jpg b/subprocess/LivePortrait/assets/examples/source/s6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..91c13d5f2b48d143ca596566ad10f0a0e5693da4 Binary files /dev/null and b/subprocess/LivePortrait/assets/examples/source/s6.jpg differ diff --git a/subprocess/LivePortrait/assets/examples/source/s7.jpg b/subprocess/LivePortrait/assets/examples/source/s7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cf96f2d5651f7ae0faf08193ecd3df282c5c3b53 Binary files /dev/null and b/subprocess/LivePortrait/assets/examples/source/s7.jpg differ diff --git a/subprocess/LivePortrait/assets/examples/source/s8.jpg b/subprocess/LivePortrait/assets/examples/source/s8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b415ed1d4a4e5cf01e6dc30d6b4ced20814558d5 Binary files /dev/null and b/subprocess/LivePortrait/assets/examples/source/s8.jpg differ diff --git a/subprocess/LivePortrait/assets/examples/source/s9.jpg b/subprocess/LivePortrait/assets/examples/source/s9.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3ef7251ba10bf83356587016b126a52bdbca7b18 Binary files /dev/null and b/subprocess/LivePortrait/assets/examples/source/s9.jpg differ diff --git a/subprocess/LivePortrait/assets/gradio_description_animation.md b/subprocess/LivePortrait/assets/gradio_description_animation.md new file mode 100644 index 0000000000000000000000000000000000000000..cad1ad62bb41113c0d2e75a93581748ff65d384f --- /dev/null +++ b/subprocess/LivePortrait/assets/gradio_description_animation.md @@ -0,0 +1,16 @@ +🔥 To animate the source portrait with the driving video, please follow these steps: +
+1. In the Animation Options section, we recommend enabling the `do crop (source)` option if the face occupies only a small portion of your image. +
+
+2. Press the 🚀 Animate button. Your animated video will appear in the result block after a few moments. +
+
+3. If you want to upload your own driving video, we recommend the following (a minimal 1:1 cropping sketch follows this list): + + - Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping by checking `do crop (driving video)`. + - Focus on the head area, similar to the example videos. + - Minimize shoulder movement. + - Make sure the first frame of the driving video is a frontal face with a **neutral expression**. + +
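A minimal sketch of the manual 1:1 crop, using `opencv-python` (already pinned in `requirements.txt`). This helper is illustrative and not part of the repository; the filenames are placeholders:

```python
import cv2

def center_crop_square(frame):
    # crop a frame to a 1:1 aspect ratio around its center
    h, w = frame.shape[:2]
    side = min(h, w)
    y0, x0 = (h - side) // 2, (w - side) // 2
    return frame[y0:y0 + side, x0:x0 + side]

frame = cv2.imread('first_frame.png')  # placeholder: one extracted video frame
square = cv2.resize(center_crop_square(frame), (512, 512))
cv2.imwrite('first_frame_512.png', square)
```

Note that the built-in auto-cropping tracks facial landmarks rather than the image center, so prefer the checkbox when the face sits off-center.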
diff --git a/subprocess/LivePortrait/assets/gradio_description_retargeting.md b/subprocess/LivePortrait/assets/gradio_description_retargeting.md new file mode 100644 index 0000000000000000000000000000000000000000..4ff1a80d0a025b765f111a8c45a25cd20d5753d9 --- /dev/null +++ b/subprocess/LivePortrait/assets/gradio_description_retargeting.md @@ -0,0 +1,4 @@ +
+ +## Retargeting +🔥 To edit the eyes-open and lip-open ratios of the source portrait, drag the sliders and click the 🚗 Retargeting button. You can run it multiple times. 😊 Set both ratios to 0.8 to see what happens! diff --git a/subprocess/LivePortrait/assets/gradio_description_upload.md b/subprocess/LivePortrait/assets/gradio_description_upload.md new file mode 100644 index 0000000000000000000000000000000000000000..035a6c2332bbd7485a367612e20818cc26dad857 --- /dev/null +++ b/subprocess/LivePortrait/assets/gradio_description_upload.md @@ -0,0 +1,2 @@ +## 🤗 This is the official Gradio demo for **LivePortrait**. +
Please upload a Source Portrait or capture one with a webcam (any aspect ratio), then upload a Driving Video (1:1 aspect ratio, or any aspect ratio with `do crop (driving video)` checked); a quick aspect-ratio check is sketched below.
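A quick way to check a candidate driving video's aspect ratio before uploading; this is an illustrative stand-alone snippet, not part of the demo, and the filename is a placeholder:

```python
import cv2

cap = cv2.VideoCapture('my_driving_video.mp4')  # placeholder path
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
cap.release()

# a 1:1 video can be used directly; otherwise tick `do crop (driving video)`
print('1:1, ready to upload' if w == h else f'{w}x{h}: enable do crop (driving video)')
```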
diff --git a/subprocess/LivePortrait/assets/gradio_title.md b/subprocess/LivePortrait/assets/gradio_title.md new file mode 100644 index 0000000000000000000000000000000000000000..c9bbfc2e89419eaafabfe636e4d3230eb0b5e7b0 --- /dev/null +++ b/subprocess/LivePortrait/assets/gradio_title.md @@ -0,0 +1,11 @@ +
+ LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control
+ Project Page
+ diff --git a/subprocess/LivePortrait/inference.py b/subprocess/LivePortrait/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..8387e7f01657b64430f50603b45557d8ace3304f --- /dev/null +++ b/subprocess/LivePortrait/inference.py @@ -0,0 +1,33 @@ +# coding: utf-8 + +import tyro +from src.config.argument_config import ArgumentConfig +from src.config.inference_config import InferenceConfig +from src.config.crop_config import CropConfig +from src.live_portrait_pipeline import LivePortraitPipeline + + +def partial_fields(target_class, kwargs): + return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)}) + + +def main(): + # set tyro theme + tyro.extras.set_accent_color("bright_cyan") + args = tyro.cli(ArgumentConfig) + + # specify configs for inference + inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attributes of args to initialize InferenceConfig + crop_cfg = partial_fields(CropConfig, args.__dict__) # use attributes of args to initialize CropConfig + + live_portrait_pipeline = LivePortraitPipeline( + inference_cfg=inference_cfg, + crop_cfg=crop_cfg + ) + + # run + live_portrait_pipeline.execute(args) + + +if __name__ == '__main__': + main() diff --git a/subprocess/LivePortrait/pretrained_weights/.gitattributes b/subprocess/LivePortrait/pretrained_weights/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..1240e442c9107184a69cb6f4a3961ba16cbb8a1b --- /dev/null +++ b/subprocess/LivePortrait/pretrained_weights/.gitattributes @@ -0,0 +1,45 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +liveportrait/retargeting_models/stitching_retargeting_module.pth filter=lfs diff=lfs merge=lfs -text +liveportrait/base_models/appearance_feature_extractor.pth filter=lfs diff=lfs merge=lfs -text +liveportrait/base_models/motion_extractor.pth filter=lfs diff=lfs merge=lfs -text +liveportrait/base_models/spade_generator.pth filter=lfs diff=lfs merge=lfs -text +liveportrait/base_models/warping_module.pth filter=lfs diff=lfs merge=lfs
-text +insightface/models/buffalo_l/2d106det.onnx filter=lfs diff=lfs merge=lfs -text +insightface/models/buffalo_l/det_10g.onnx filter=lfs diff=lfs merge=lfs -text +liveportrait/landmark.onnx filter=lfs diff=lfs merge=lfs -text +docs/inference.gif filter=lfs diff=lfs merge=lfs -text +docs/showcase2.gif filter=lfs diff=lfs merge=lfs -text diff --git a/subprocess/LivePortrait/pretrained_weights/.gitignore b/subprocess/LivePortrait/pretrained_weights/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..e646a996685d1a4a487decd6dfb0022954b88b61 --- /dev/null +++ b/subprocess/LivePortrait/pretrained_weights/.gitignore @@ -0,0 +1,18 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +**/__pycache__/ +*.py[cod] +**/*.py[cod] +*$py.class + +# Model weights +#**/*.pth +#**/*.onnx + +# Ipython notebook +*.ipynb + +# Temporary files or benchmark resources +animations/* +tmp/* +gradio_cached_examples/ diff --git a/subprocess/LivePortrait/pretrained_weights/README.md b/subprocess/LivePortrait/pretrained_weights/README.md new file mode 100644 index 0000000000000000000000000000000000000000..06fced1472a40110698cb3b3bebcf334f63e970f --- /dev/null +++ b/subprocess/LivePortrait/pretrained_weights/README.md @@ -0,0 +1,148 @@ +--- +license: mit +--- + +

LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control

+ +
+ Jianzhu Guo 1†  + Dingyun Zhang 1,2  + Xiaoqiang Liu 1  + Zhizhou Zhong 1,3  + Yuan Zhang 1  +
+ +
+ Pengfei Wan 1  + Di Zhang 1  +
+ +
+ 1 Kuaishou Technology  2 University of Science and Technology of China  3 Fudan University  +
+ +
+

+ showcase +
+ 🔥 For more results, visit our homepage 🔥 +

+ + + +## 🔥 Updates +- **`2024/07/04`**: 🔥 We released the initial version of the inference code and models. Continuous updates, stay tuned! +- **`2024/07/04`**: 😊 We released the [homepage](https://liveportrait.github.io) and technical report on [arXiv](https://arxiv.org/pdf/2407.03168). + +## Introduction +This repo, named **LivePortrait**, contains the official PyTorch implementation of our paper [LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control](https://arxiv.org/pdf/2407.03168). +We are actively updating and improving this repository. If you find any bugs or have suggestions, welcome to raise issues or submit pull requests (PR) 💖. + +## 🔥 Getting Started +### 1. Clone the code and prepare the environment +```bash +git clone https://github.com/KwaiVGI/LivePortrait +cd LivePortrait + +# create env using conda +conda create -n LivePortrait python==3.9.18 +conda activate LivePortrait +# install dependencies with pip +pip install -r requirements.txt +``` + +### 2. Download pretrained weights +Download our pretrained LivePortrait weights and face detection models of InsightFace from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). We have packed all weights in one directory 😊. Unzip and place them in `./pretrained_weights` ensuring the directory structure is as follows: +```text +pretrained_weights +├── insightface +│ └── models +│ └── buffalo_l +│ ├── 2d106det.onnx +│ └── det_10g.onnx +└── liveportrait + ├── base_models + │ ├── appearance_feature_extractor.pth + │ ├── motion_extractor.pth + │ ├── spade_generator.pth + │ └── warping_module.pth + ├── landmark.onnx + └── retargeting_models + └── stitching_retargeting_module.pth +``` + +### 3. Inference 🚀 + +```bash +python inference.py +``` + +If the script runs successfully, you will get an output mp4 file named `animations/s6--d0_concat.mp4`. This file includes the following results: driving video, input image, and generated result. + +
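If `inference.py` errors out before producing anything, a common culprit is misplaced weights. A stand-alone sanity check might look like the following (not part of the repository; the paths are copied from the directory tree above):

```python
from pathlib import Path

# expected layout, copied from the directory tree documented above
expected = [
    'insightface/models/buffalo_l/2d106det.onnx',
    'insightface/models/buffalo_l/det_10g.onnx',
    'liveportrait/base_models/appearance_feature_extractor.pth',
    'liveportrait/base_models/motion_extractor.pth',
    'liveportrait/base_models/spade_generator.pth',
    'liveportrait/base_models/warping_module.pth',
    'liveportrait/landmark.onnx',
    'liveportrait/retargeting_models/stitching_retargeting_module.pth',
]
root = Path('pretrained_weights')
missing = [p for p in expected if not (root / p).is_file()]
print('all weights in place' if not missing else f'missing: {missing}')
```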

+ image +

+ +Or, you can change the input by specifying the `-s` and `-d` arguments: + +```bash +python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 + +# or disable pasting back +python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 --no_flag_pasteback + +# see more options +python inference.py -h +``` + +**More interesting results can be found in our [Homepage](https://liveportrait.github.io)** 😊 + +### 4. Gradio interface + +We also provide a Gradio interface for a better experience; just run: + +```bash +python app.py +``` + +### 5. Inference speed evaluation 🚀🚀🚀 +We have also provided a script to evaluate the inference speed of each module: + +```bash +python speed.py +``` + +Below are the results of inferring one frame on an RTX 4090 GPU using the native PyTorch framework with `torch.compile`: + +| Model | Parameters(M) | Model Size(MB) | Inference(ms) | +|-----------------------------------|:-------------:|:--------------:|:-------------:| +| Appearance Feature Extractor | 0.84 | 3.3 | 0.82 | +| Motion Extractor | 28.12 | 108 | 0.84 | +| Spade Generator | 55.37 | 212 | 7.59 | +| Warping Module | 45.53 | 174 | 5.21 | +| Stitching and Retargeting Modules | 0.23 | 2.3 | 0.31 | + +*Note: the listed values for the Stitching and Retargeting Modules represent the combined parameter counts and the total sequential inference time of three MLP networks.* + + +## Acknowledgements +We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) repositories for their open research and contributions.
+ +## Citation 💖 +If you find LivePortrait useful for your research, welcome to 🌟 this repo and cite our work using the following BibTeX: +```bibtex +@article{guo2024live, + title = {LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control}, + author = {Jianzhu Guo and Dingyun Zhang and Xiaoqiang Liu and Zhizhou Zhong and Yuan Zhang and Pengfei Wan and Di Zhang}, + year = {2024}, + journal = {arXiv preprint:2407.03168}, +} +``` diff --git a/subprocess/LivePortrait/pretrained_weights/docs/inference.gif b/subprocess/LivePortrait/pretrained_weights/docs/inference.gif new file mode 100644 index 0000000000000000000000000000000000000000..7e18022e5245dcb6449df6d190b538d5ca024e06 Binary files /dev/null and b/subprocess/LivePortrait/pretrained_weights/docs/inference.gif differ diff --git a/subprocess/LivePortrait/pretrained_weights/docs/showcase2.gif b/subprocess/LivePortrait/pretrained_weights/docs/showcase2.gif new file mode 100644 index 0000000000000000000000000000000000000000..29175c0eeb85b9db0ffd61e3e9281dffe3536352 --- /dev/null +++ b/subprocess/LivePortrait/pretrained_weights/docs/showcase2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb1fffb139681775780b2956e7d0289f55d199c1a3e14ab263887864d4b0d586 +size 2881351 diff --git a/subprocess/LivePortrait/pretrained_weights/insightface/models/buffalo_l/2d106det.onnx b/subprocess/LivePortrait/pretrained_weights/insightface/models/buffalo_l/2d106det.onnx new file mode 100644 index 0000000000000000000000000000000000000000..cdb163d88b5f51396855ebc795e0114322c98b6b --- /dev/null +++ b/subprocess/LivePortrait/pretrained_weights/insightface/models/buffalo_l/2d106det.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf +size 5030888 diff --git a/subprocess/LivePortrait/pretrained_weights/insightface/models/buffalo_l/det_10g.onnx b/subprocess/LivePortrait/pretrained_weights/insightface/models/buffalo_l/det_10g.onnx new file mode 100644 index 0000000000000000000000000000000000000000..aa586e034379fa5ea5babc8aa73d47afcd0fa6c2 --- /dev/null +++ b/subprocess/LivePortrait/pretrained_weights/insightface/models/buffalo_l/det_10g.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91 +size 16923827 diff --git a/subprocess/LivePortrait/pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth b/subprocess/LivePortrait/pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth new file mode 100644 index 0000000000000000000000000000000000000000..f05eb700c3eca1939c9d4e436bd063217eaa4587 --- /dev/null +++ b/subprocess/LivePortrait/pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5279bb8654293dbdf327030b397f107237dd9212fb11dd75b83dfb635211ceb5 +size 3387959 diff --git a/subprocess/LivePortrait/pretrained_weights/liveportrait/base_models/motion_extractor.pth b/subprocess/LivePortrait/pretrained_weights/liveportrait/base_models/motion_extractor.pth new file mode 100644 index 0000000000000000000000000000000000000000..a118cb8e26afc734be9abd4a6ef0163adcbd63b0 --- /dev/null +++ b/subprocess/LivePortrait/pretrained_weights/liveportrait/base_models/motion_extractor.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:251e6a94ad667a1d0c69526d292677165110ef7f0cf0f6d199f0e414e8aa0ca5 +size 
112545506 diff --git a/subprocess/LivePortrait/pretrained_weights/liveportrait/base_models/spade_generator.pth b/subprocess/LivePortrait/pretrained_weights/liveportrait/base_models/spade_generator.pth new file mode 100644 index 0000000000000000000000000000000000000000..0086702b84762790e06c5a4332f36d0857f594fc --- /dev/null +++ b/subprocess/LivePortrait/pretrained_weights/liveportrait/base_models/spade_generator.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4780afc7909a9f84e24c01d73b31a555ef651521a1fe3b2429bd04534d992aee +size 221813590 diff --git a/subprocess/LivePortrait/pretrained_weights/liveportrait/base_models/warping_module.pth b/subprocess/LivePortrait/pretrained_weights/liveportrait/base_models/warping_module.pth new file mode 100644 index 0000000000000000000000000000000000000000..e9d4cd1bcb62e2b654c28e32f66e56d51fb10389 --- /dev/null +++ b/subprocess/LivePortrait/pretrained_weights/liveportrait/base_models/warping_module.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2f61a6f265fe344f14132364859a78bdbbc2068577170693da57fb96d636e282 +size 182180086 diff --git a/subprocess/LivePortrait/pretrained_weights/liveportrait/landmark.onnx b/subprocess/LivePortrait/pretrained_weights/liveportrait/landmark.onnx new file mode 100644 index 0000000000000000000000000000000000000000..48eb59185aa92b6efa2855ce99129d8aff248938 --- /dev/null +++ b/subprocess/LivePortrait/pretrained_weights/liveportrait/landmark.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:31d22a5041326c31f19b78886939a634a5aedcaa5ab8b9b951a1167595d147db +size 114666491 diff --git a/subprocess/LivePortrait/pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth b/subprocess/LivePortrait/pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth new file mode 100644 index 0000000000000000000000000000000000000000..59f0f3830b78b8587f0bd8b9ef8fd3ffdbd9290a --- /dev/null +++ b/subprocess/LivePortrait/pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3652d5a3f95099141a56986aaddec92fadf0a73c87a20fac9a2c07c32b28b611 +size 2393098 diff --git a/subprocess/LivePortrait/readme.md b/subprocess/LivePortrait/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..607637d57ecc99e8f78b8360c15fb4638c42f78d --- /dev/null +++ b/subprocess/LivePortrait/readme.md @@ -0,0 +1,193 @@ +

LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control

+ +
+ Jianzhu Guo 1†  + Dingyun Zhang 1,2  + Xiaoqiang Liu 1  + Zhizhou Zhong 1,3  + Yuan Zhang 1  +
+ +
+ Pengfei Wan 1  + Di Zhang 1  +
+ +
+ 1 Kuaishou Technology  2 University of Science and Technology of China  3 Fudan University  +
+ +
+

+ showcase +
+ 🔥 For more results, visit our homepage 🔥 +

+ + + +## 🔥 Updates +- **`2024/07/10`**: 💪 We support audio and video concatenating, driving video auto-cropping, and template making to protect privacy. More to see [here](assets/docs/changelog/2024-07-10.md). +- **`2024/07/09`**: 🤗 We released the [HuggingFace Space](https://huggingface.co/spaces/KwaiVGI/liveportrait), thanks to the HF team and [Gradio](https://github.com/gradio-app/gradio)! +- **`2024/07/04`**: 😊 We released the initial version of the inference code and models. Continuous updates, stay tuned! +- **`2024/07/04`**: 🔥 We released the [homepage](https://liveportrait.github.io) and technical report on [arXiv](https://arxiv.org/pdf/2407.03168). + + + +## Introduction +This repo, named **LivePortrait**, contains the official PyTorch implementation of our paper [LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control](https://arxiv.org/pdf/2407.03168). +We are actively updating and improving this repository. If you find any bugs or have suggestions, welcome to raise issues or submit pull requests (PR) 💖. + +## 🔥 Getting Started +### 1. Clone the code and prepare the environment +```bash +git clone https://github.com/KwaiVGI/LivePortrait +cd LivePortrait + +# create env using conda +conda create -n LivePortrait python==3.9.18 +conda activate LivePortrait +# install dependencies with pip +pip install -r requirements.txt +``` + +### 2. Download pretrained weights + +Download the pretrained weights from HuggingFace: +```bash +# you may need to run `git lfs install` first +git clone https://huggingface.co/KwaiVGI/liveportrait pretrained_weights +``` + +Or, download all pretrained weights from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). We have packed all weights in one directory 😊. Unzip and place them in `./pretrained_weights` ensuring the directory structure is as follows: +```text +pretrained_weights +├── insightface +│ └── models +│ └── buffalo_l +│ ├── 2d106det.onnx +│ └── det_10g.onnx +└── liveportrait + ├── base_models + │ ├── appearance_feature_extractor.pth + │ ├── motion_extractor.pth + │ ├── spade_generator.pth + │ └── warping_module.pth + ├── landmark.onnx + └── retargeting_models + └── stitching_retargeting_module.pth +``` + +### 3. Inference 🚀 + +#### Fast hands-on +```bash +python inference.py +``` + +If the script runs successfully, you will get an output mp4 file named `animations/s6--d0_concat.mp4`. This file includes the following results: driving video, input image, and generated result. + +

+ image +

+ +Or, you can change the input by specifying the `-s` and `-d` arguments: + +```bash +python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 + +# disable pasting back to run faster +python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 --no_flag_pasteback + +# see more options +python inference.py -h +``` + +#### Driving video auto-cropping + +📕 To use your own driving video, we **recommend**: + - Crop it to a **1:1** aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping by `--flag_crop_driving_video`. + - Focus on the head area, similar to the example videos. + - Minimize shoulder movement. + - Make sure the first frame of the driving video is a frontal face with a **neutral expression**. + +Below is an auto-cropping example using `--flag_crop_driving_video`: +```bash +python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d13.mp4 --flag_crop_driving_video +``` + +If the auto-cropping results are unsatisfactory, you can adjust the scale and offset with the `--scale_crop_video` and `--vy_ratio_crop_video` options, or crop manually. + +#### Template making +You can also use the auto-generated `.pkl` file to speed up inference and **protect privacy**, for example: +```bash +python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d5.pkl +``` + +**Discover more interesting results on our [Homepage](https://liveportrait.github.io)** 😊 + +### 4. Gradio interface 🤗 + +We also provide a Gradio interface for a better experience; just run: + +```bash +python app.py +``` + +You can specify the `--server_port`, `--share`, `--server_name` arguments to suit your needs! + +**Or, try it out effortlessly on [HuggingFace](https://huggingface.co/spaces/KwaiVGI/LivePortrait) 🤗** + +### 5.
Inference speed evaluation 🚀🚀🚀 +We have also provided a script to evaluate the inference speed of each module: + +```bash +python speed.py +``` + +Below are the results of inferring one frame on an RTX 4090 GPU using the native PyTorch framework with `torch.compile`: + +| Model | Parameters(M) | Model Size(MB) | Inference(ms) | +|-----------------------------------|:-------------:|:--------------:|:-------------:| +| Appearance Feature Extractor | 0.84 | 3.3 | 0.82 | +| Motion Extractor | 28.12 | 108 | 0.84 | +| Spade Generator | 55.37 | 212 | 7.59 | +| Warping Module | 45.53 | 174 | 5.21 | +| Stitching and Retargeting Modules | 0.23 | 2.3 | 0.31 | + +*Note: The values for the Stitching and Retargeting Modules represent the combined parameter counts and total inference time of three sequential MLP networks.* + +## Community Resources 🤗 + +Discover the invaluable resources contributed by our community to enhance your LivePortrait experience: + +- [ComfyUI-LivePortraitKJ](https://github.com/kijai/ComfyUI-LivePortraitKJ) by [@kijai](https://github.com/kijai) +- [comfyui-liveportrait](https://github.com/shadowcz007/comfyui-liveportrait) by [@shadowcz007](https://github.com/shadowcz007) +- [LivePortrait hands-on tutorial](https://www.youtube.com/watch?v=uyjSTAOY7yI) by [@AI Search](https://www.youtube.com/@theAIsearch) +- [ComfyUI tutorial](https://www.youtube.com/watch?v=8-IcDDmiUMM) by [@Sebastian Kamph](https://www.youtube.com/@sebastiankamph) +- [LivePortrait In ComfyUI](https://www.youtube.com/watch?v=aFcS31OWMjE) by [@Benji](https://www.youtube.com/@TheFutureThinker) +- [Replicate Playground](https://replicate.com/fofr/live-portrait) and [cog-comfyui](https://github.com/fofr/cog-comfyui) by [@fofr](https://github.com/fofr) + +And many more amazing contributions from our community! + +## Acknowledgements +We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) repositories, for their open research and contributions. 
+ +## Citation 💖 +If you find LivePortrait useful for your research, welcome to 🌟 this repo and cite our work using the following BibTeX: +```bibtex +@article{guo2024liveportrait, + title = {LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control}, + author = {Guo, Jianzhu and Zhang, Dingyun and Liu, Xiaoqiang and Zhong, Zhizhou and Zhang, Yuan and Wan, Pengfei and Zhang, Di}, + journal = {arXiv preprint arXiv:2407.03168}, + year = {2024} +} +``` diff --git a/subprocess/LivePortrait/requirements.txt b/subprocess/LivePortrait/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..b2e1c85102bfaa2873fe00b6ce030e847debee1b --- /dev/null +++ b/subprocess/LivePortrait/requirements.txt @@ -0,0 +1,22 @@ +--extra-index-url https://download.pytorch.org/whl/cu118 +torch==2.3.0 +torchvision==0.18.0 +torchaudio==2.3.0 + +numpy==1.26.4 +pyyaml==6.0.1 +opencv-python==4.10.0.84 +scipy==1.13.1 +imageio==2.34.2 +lmdb==1.4.1 +tqdm==4.66.4 +rich==13.7.1 +ffmpeg-python==0.2.0 +onnxruntime-gpu==1.18.0 +onnx==1.16.1 +scikit-image==0.24.0 +albumentations==1.4.10 +matplotlib==3.9.0 +imageio-ffmpeg==0.5.1 +tyro==0.8.5 +gradio==4.37.1 diff --git a/subprocess/LivePortrait/speed.py b/subprocess/LivePortrait/speed.py new file mode 100644 index 0000000000000000000000000000000000000000..3cad2483a1eae9a7f89a73c961480404083038dd --- /dev/null +++ b/subprocess/LivePortrait/speed.py @@ -0,0 +1,192 @@ +# coding: utf-8 + +""" +Benchmark the inference speed of each module in LivePortrait. + +TODO: heavy GPT style, need to refactor +""" + +import yaml +import torch +import time +import numpy as np +from src.utils.helper import load_model, concat_feat +from src.config.inference_config import InferenceConfig + + +def initialize_inputs(batch_size=1): + """ + Generate random input tensors and move them to GPU + """ + feature_3d = torch.randn(batch_size, 32, 16, 64, 64).cuda().half() + kp_source = torch.randn(batch_size, 21, 3).cuda().half() + kp_driving = torch.randn(batch_size, 21, 3).cuda().half() + source_image = torch.randn(batch_size, 3, 256, 256).cuda().half() + generator_input = torch.randn(batch_size, 256, 64, 64).cuda().half() + eye_close_ratio = torch.randn(batch_size, 3).cuda().half() + lip_close_ratio = torch.randn(batch_size, 2).cuda().half() + feat_stitching = concat_feat(kp_source, kp_driving).half() + feat_eye = concat_feat(kp_source, eye_close_ratio).half() + feat_lip = concat_feat(kp_source, lip_close_ratio).half() + + inputs = { + 'feature_3d': feature_3d, + 'kp_source': kp_source, + 'kp_driving': kp_driving, + 'source_image': source_image, + 'generator_input': generator_input, + 'feat_stitching': feat_stitching, + 'feat_eye': feat_eye, + 'feat_lip': feat_lip + } + + return inputs + + +def load_and_compile_models(cfg, model_config): + """ + Load and compile models for inference + """ + appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device, 'appearance_feature_extractor') + motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device, 'motion_extractor') + warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device, 'warping_module') + spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device, 'spade_generator') + stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device, 'stitching_retargeting_module') + + models_with_params = [ + ('Appearance Feature Extractor', appearance_feature_extractor), + ('Motion Extractor', motion_extractor), + ('Warping Network', 
warping_module), + ('SPADE Decoder', spade_generator) + ] + + compiled_models = {} + for name, model in models_with_params: + model = model.half() + model = torch.compile(model, mode='max-autotune') # Optimize for inference + model.eval() # Switch to evaluation mode + compiled_models[name] = model + + retargeting_models = ['stitching', 'eye', 'lip'] + for retarget in retargeting_models: + module = stitching_retargeting_module[retarget].half() + module = torch.compile(module, mode='max-autotune') # Optimize for inference + module.eval() # Switch to evaluation mode + stitching_retargeting_module[retarget] = module + + return compiled_models, stitching_retargeting_module + + +def warm_up_models(compiled_models, stitching_retargeting_module, inputs): + """ + Warm up models to prepare them for benchmarking + """ + print("Warm up start!") + with torch.no_grad(): + for _ in range(10): + compiled_models['Appearance Feature Extractor'](inputs['source_image']) + compiled_models['Motion Extractor'](inputs['source_image']) + compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source']) + compiled_models['SPADE Decoder'](inputs['generator_input']) # Adjust input as required + stitching_retargeting_module['stitching'](inputs['feat_stitching']) + stitching_retargeting_module['eye'](inputs['feat_eye']) + stitching_retargeting_module['lip'](inputs['feat_lip']) + print("Warm up end!") + + +def measure_inference_times(compiled_models, stitching_retargeting_module, inputs): + """ + Measure inference times for each model + """ + times = {name: [] for name in compiled_models.keys()} + times['Retargeting Models'] = [] + + overall_times = [] + + with torch.no_grad(): + for _ in range(100): + torch.cuda.synchronize() + overall_start = time.time() + + start = time.time() + compiled_models['Appearance Feature Extractor'](inputs['source_image']) + torch.cuda.synchronize() + times['Appearance Feature Extractor'].append(time.time() - start) + + start = time.time() + compiled_models['Motion Extractor'](inputs['source_image']) + torch.cuda.synchronize() + times['Motion Extractor'].append(time.time() - start) + + start = time.time() + compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source']) + torch.cuda.synchronize() + times['Warping Network'].append(time.time() - start) + + start = time.time() + compiled_models['SPADE Decoder'](inputs['generator_input']) # Adjust input as required + torch.cuda.synchronize() + times['SPADE Decoder'].append(time.time() - start) + + start = time.time() + stitching_retargeting_module['stitching'](inputs['feat_stitching']) + stitching_retargeting_module['eye'](inputs['feat_eye']) + stitching_retargeting_module['lip'](inputs['feat_lip']) + torch.cuda.synchronize() + times['Retargeting Models'].append(time.time() - start) + + overall_times.append(time.time() - overall_start) + + return times, overall_times + + +def print_benchmark_results(compiled_models, stitching_retargeting_module, retargeting_models, times, overall_times): + """ + Print benchmark results with average and standard deviation of inference times + """ + average_times = {name: np.mean(times[name]) * 1000 for name in times.keys()} + std_times = {name: np.std(times[name]) * 1000 for name in times.keys()} + + for name, model in compiled_models.items(): + num_params = sum(p.numel() for p in model.parameters()) + num_params_in_millions = num_params / 1e6 + print(f"Number of parameters for {name}: {num_params_in_millions:.2f} M") + + for index, retarget 
in enumerate(retargeting_models): + num_params = sum(p.numel() for p in stitching_retargeting_module[retarget].parameters()) + num_params_in_millions = num_params / 1e6 + print(f"Number of parameters for part_{index} in Stitching and Retargeting Modules: {num_params_in_millions:.2f} M") + + for name, avg_time in average_times.items(): + std_time = std_times[name] + print(f"Average inference time for {name} over 100 runs: {avg_time:.2f} ms (std: {std_time:.2f} ms)") + + +def main(): + """ + Main function to benchmark speed and model parameters + """ + # Sample input tensors + inputs = initialize_inputs() + + # Load configuration + cfg = InferenceConfig(device_id=0) + model_config_path = cfg.models_config + with open(model_config_path, 'r') as file: + model_config = yaml.safe_load(file) + + # Load and compile models + compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config) + + # Warm up models + warm_up_models(compiled_models, stitching_retargeting_module, inputs) + + # Measure inference times + times, overall_times = measure_inference_times(compiled_models, stitching_retargeting_module, inputs) + + # Print benchmark results + print_benchmark_results(compiled_models, stitching_retargeting_module, ['stitching', 'eye', 'lip'], times, overall_times) + + +if __name__ == "__main__": + main() diff --git a/subprocess/LivePortrait/src/__pycache__/gradio_pipeline.cpython-310.pyc b/subprocess/LivePortrait/src/__pycache__/gradio_pipeline.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..813564ec8ca0b90a69b2e8aefe747fb594a1109c Binary files /dev/null and b/subprocess/LivePortrait/src/__pycache__/gradio_pipeline.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/__pycache__/gradio_pipeline.cpython-39.pyc b/subprocess/LivePortrait/src/__pycache__/gradio_pipeline.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33d2acb7162450f596fe784d4607efaf255a9a6c Binary files /dev/null and b/subprocess/LivePortrait/src/__pycache__/gradio_pipeline.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/__pycache__/live_portrait_pipeline.cpython-310.pyc b/subprocess/LivePortrait/src/__pycache__/live_portrait_pipeline.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9c7c5d1d57a2ea5fb2552a49a4639cd8496b15b Binary files /dev/null and b/subprocess/LivePortrait/src/__pycache__/live_portrait_pipeline.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/__pycache__/live_portrait_pipeline.cpython-39.pyc b/subprocess/LivePortrait/src/__pycache__/live_portrait_pipeline.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d60b0a620e4cc0e9c618d9f90da8e0521e605c66 Binary files /dev/null and b/subprocess/LivePortrait/src/__pycache__/live_portrait_pipeline.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/__pycache__/live_portrait_wrapper.cpython-310.pyc b/subprocess/LivePortrait/src/__pycache__/live_portrait_wrapper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe9a0162bc65096c38c2a9fbe4272e6dc0bffbd0 Binary files /dev/null and b/subprocess/LivePortrait/src/__pycache__/live_portrait_wrapper.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/__pycache__/live_portrait_wrapper.cpython-39.pyc b/subprocess/LivePortrait/src/__pycache__/live_portrait_wrapper.cpython-39.pyc new file mode 100644 index 
diff --git a/subprocess/LivePortrait/src/__pycache__/gradio_pipeline.cpython-310.pyc b/subprocess/LivePortrait/src/__pycache__/gradio_pipeline.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..813564ec8ca0b90a69b2e8aefe747fb594a1109c Binary files /dev/null and b/subprocess/LivePortrait/src/__pycache__/gradio_pipeline.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/__pycache__/gradio_pipeline.cpython-39.pyc b/subprocess/LivePortrait/src/__pycache__/gradio_pipeline.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33d2acb7162450f596fe784d4607efaf255a9a6c Binary files /dev/null and b/subprocess/LivePortrait/src/__pycache__/gradio_pipeline.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/__pycache__/live_portrait_pipeline.cpython-310.pyc b/subprocess/LivePortrait/src/__pycache__/live_portrait_pipeline.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9c7c5d1d57a2ea5fb2552a49a4639cd8496b15b Binary files /dev/null and b/subprocess/LivePortrait/src/__pycache__/live_portrait_pipeline.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/__pycache__/live_portrait_pipeline.cpython-39.pyc b/subprocess/LivePortrait/src/__pycache__/live_portrait_pipeline.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d60b0a620e4cc0e9c618d9f90da8e0521e605c66 Binary files /dev/null and b/subprocess/LivePortrait/src/__pycache__/live_portrait_pipeline.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/__pycache__/live_portrait_wrapper.cpython-310.pyc b/subprocess/LivePortrait/src/__pycache__/live_portrait_wrapper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe9a0162bc65096c38c2a9fbe4272e6dc0bffbd0 Binary files /dev/null and b/subprocess/LivePortrait/src/__pycache__/live_portrait_wrapper.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/__pycache__/live_portrait_wrapper.cpython-39.pyc b/subprocess/LivePortrait/src/__pycache__/live_portrait_wrapper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22ed74ba11548f236fceef01a74367129ad40e2d Binary files /dev/null and b/subprocess/LivePortrait/src/__pycache__/live_portrait_wrapper.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/config/__init__.py b/subprocess/LivePortrait/src/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/subprocess/LivePortrait/src/config/__pycache__/__init__.cpython-310.pyc b/subprocess/LivePortrait/src/config/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..455f96f26cfe6d2a81fa09ef0813714d2538b3b0 Binary files /dev/null and b/subprocess/LivePortrait/src/config/__pycache__/__init__.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/config/__pycache__/__init__.cpython-39.pyc b/subprocess/LivePortrait/src/config/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..474f67705bd1e59ad812f03611f384d9431fd70e Binary files /dev/null and b/subprocess/LivePortrait/src/config/__pycache__/__init__.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/config/__pycache__/argument_config.cpython-310.pyc b/subprocess/LivePortrait/src/config/__pycache__/argument_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ecb6ba4656b0c5e8e2ab3f41c4f490a3c7747c8 Binary files /dev/null and b/subprocess/LivePortrait/src/config/__pycache__/argument_config.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/config/__pycache__/argument_config.cpython-39.pyc b/subprocess/LivePortrait/src/config/__pycache__/argument_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee31b9ad4a955b50d14f1d69124c658d7ae25f5b Binary files /dev/null and b/subprocess/LivePortrait/src/config/__pycache__/argument_config.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/config/__pycache__/base_config.cpython-310.pyc b/subprocess/LivePortrait/src/config/__pycache__/base_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8d01e8b98e40320ed3369f8c139030918a4e22f Binary files /dev/null and b/subprocess/LivePortrait/src/config/__pycache__/base_config.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/config/__pycache__/base_config.cpython-39.pyc b/subprocess/LivePortrait/src/config/__pycache__/base_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6463b2addaa81a3ddd3321f254030bf8057d31d7 Binary files /dev/null and b/subprocess/LivePortrait/src/config/__pycache__/base_config.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/config/__pycache__/crop_config.cpython-310.pyc b/subprocess/LivePortrait/src/config/__pycache__/crop_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd36d7ceee25ef8538eda5b34782dea3d6af26bf Binary files /dev/null and b/subprocess/LivePortrait/src/config/__pycache__/crop_config.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/config/__pycache__/crop_config.cpython-39.pyc b/subprocess/LivePortrait/src/config/__pycache__/crop_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5496a1b79fa2cf2b80a5f9aac68f2b257364b11 Binary files /dev/null and b/subprocess/LivePortrait/src/config/__pycache__/crop_config.cpython-39.pyc differ diff --git
a/subprocess/LivePortrait/src/config/__pycache__/inference_config.cpython-310.pyc b/subprocess/LivePortrait/src/config/__pycache__/inference_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d4ad3ce4da41810eb7a9693baa460883451788c Binary files /dev/null and b/subprocess/LivePortrait/src/config/__pycache__/inference_config.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/config/__pycache__/inference_config.cpython-39.pyc b/subprocess/LivePortrait/src/config/__pycache__/inference_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bec8296a8365ee91c83b6aa981acb99a4080404 Binary files /dev/null and b/subprocess/LivePortrait/src/config/__pycache__/inference_config.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/config/argument_config.py b/subprocess/LivePortrait/src/config/argument_config.py new file mode 100644 index 0000000000000000000000000000000000000000..0bbaa201ccc42631aa9b4681c9ccadf8e8379983 --- /dev/null +++ b/subprocess/LivePortrait/src/config/argument_config.py @@ -0,0 +1,47 @@ +# coding: utf-8 + +""" +All user-facing configs +""" + +from dataclasses import dataclass +import tyro +from typing_extensions import Annotated +from typing import Optional +from .base_config import PrintableConfig, make_abs_path + + +@dataclass(repr=False) # use repr from PrintableConfig +class ArgumentConfig(PrintableConfig): + ########## input arguments ########## + source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the source portrait + driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format) + output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video + + ########## inference arguments ########## + flag_use_half_precision: bool = True # whether to use half precision (FP16). If black boxes appear, it might be due to GPU incompatibility; set to False. + flag_crop_driving_video: bool = False # whether to crop the driving video, if the given driving info is a video + device_id: int = 0 # gpu device id + flag_force_cpu: bool = False # force cpu inference, WIP! + flag_lip_zero: bool = True # whether to close the lips before animation; only takes effect when flag_eye_retargeting and flag_lip_retargeting are False + flag_eye_retargeting: bool = False # not recommended to be True, WIP + flag_lip_retargeting: bool = False # not recommended to be True, WIP + flag_stitching: bool = True # recommended True if head movement is small, False if head movement is large + flag_relative_motion: bool = True # whether to use relative motion + flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space + flag_do_crop: bool = True # whether to crop the source portrait to the face-cropping space + flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True + + ########## crop arguments ########## + scale: float = 2.3 # the larger the scale, the smaller the ratio of the face area + vx_ratio: float = 0 # the ratio to move the face left or right in cropping space + vy_ratio: float = -0.125 # the ratio to move the face up or down in cropping space + + scale_crop_video: float = 2.2 # scale factor for cropping video + vx_ratio_crop_video: float = 0. # adjust x offset + vy_ratio_crop_video: float = -0.1 # adjust y offset + + ########## gradio arguments ########## + server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890 # port for gradio server + share: bool = False # whether to share the server publicly + server_name: Optional[str] = "127.0.0.1" # set the local server name, "0.0.0.0" to listen on all interfaces
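Since ArgumentConfig is a tyro-annotated dataclass, the command-line interface comes for free. A minimal sketch of how such a config is typically parsed (the import path is illustrative for this repo layout):

```python
import tyro

from src.config.argument_config import ArgumentConfig  # hypothetical import path

if __name__ == '__main__':
    args = tyro.cli(ArgumentConfig)  # builds -s/-d/-o flags from the annotated fields
    print(args.source_image, args.driving_info, args.output_dir)
```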
diff --git a/subprocess/LivePortrait/src/config/base_config.py b/subprocess/LivePortrait/src/config/base_config.py new file mode 100644 index 0000000000000000000000000000000000000000..216b8be50aecc8af4b9d1d2a9401e034dd7769e4 --- /dev/null +++ b/subprocess/LivePortrait/src/config/base_config.py @@ -0,0 +1,29 @@ +# coding: utf-8 + +""" +pretty-printing config base class +""" + +from __future__ import annotations +import os.path as osp +from typing import Tuple + + +def make_abs_path(fn): + return osp.join(osp.dirname(osp.realpath(__file__)), fn) + + +class PrintableConfig: # pylint: disable=too-few-public-methods + """Printable config defining a readable __repr__""" + + def __repr__(self): + lines = [self.__class__.__name__ + ":"] + for key, val in vars(self).items(): + if isinstance(val, Tuple): + flattened_val = "[" + for item in val: + flattened_val += str(item) + "\n" + flattened_val = flattened_val.rstrip("\n") + val = flattened_val + "]" + lines += f"{key}: {str(val)}".split("\n") + return "\n ".join(lines) diff --git a/subprocess/LivePortrait/src/config/crop_config.py b/subprocess/LivePortrait/src/config/crop_config.py new file mode 100644 index 0000000000000000000000000000000000000000..a4c8b12c4528409dabafcc593f7d86daf7e9c2a4 --- /dev/null +++ b/subprocess/LivePortrait/src/config/crop_config.py @@ -0,0 +1,28 @@ +# coding: utf-8 + +""" +parameters used for cropping faces +""" + +import os.path as osp +from dataclasses import dataclass +from typing import Union, List +from .base_config import PrintableConfig + + +@dataclass(repr=False) # use repr from PrintableConfig +class CropConfig(PrintableConfig): + device_id: int = 0 # gpu device id + flag_force_cpu: bool = False # force cpu inference, WIP + ########## source image cropping option ########## + dsize: int = 512 # crop size + scale: float = 2.5 # scale factor + vx_ratio: float = 0 # vx ratio + vy_ratio: float = -0.125 # vy ratio +up, -down + max_face_num: int = 0 # max face number; 0 means no limit + + ########## driving video auto cropping option ########## + scale_crop_video: float = 2.2 # scale factor for cropping video + vx_ratio_crop_video: float = 0. # adjust x offset + vy_ratio_crop_video: float = -0.1 # adjust y offset + direction: str = 'large-small' # direction of cropping
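The vx_ratio/vy_ratio fields shift the crop window as fractions of the crop size. A rough sketch of how such ratio offsets are usually applied (illustrative only; this is not the repo's actual cropping code, and the sign convention here is assumed):

```python
def shift_crop_center(cx: float, cy: float, crop_size: float,
                      vx_ratio: float, vy_ratio: float) -> tuple[float, float]:
    """Offset a crop-box center by ratios of the crop size (sign convention assumed)."""
    return cx + vx_ratio * crop_size, cy + vy_ratio * crop_size

# e.g. a 512px crop with vy_ratio = -0.125 moves the center by -64px along y
print(shift_crop_center(256.0, 256.0, 512.0, 0.0, -0.125))  # (256.0, 192.0)
```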
diff --git a/subprocess/LivePortrait/src/config/inference_config.py b/subprocess/LivePortrait/src/config/inference_config.py new file mode 100644 index 0000000000000000000000000000000000000000..70eedd81d42a3958e05cb2124ded7a79be5649a3 --- /dev/null +++ b/subprocess/LivePortrait/src/config/inference_config.py @@ -0,0 +1,51 @@ +# coding: utf-8 + +""" +config dataclass used for inference +""" + +import os.path as osp +import cv2 +from numpy import ndarray +from dataclasses import dataclass +from typing import Literal, Tuple +from .base_config import PrintableConfig, make_abs_path + + +@dataclass(repr=False) # use repr from PrintableConfig +class InferenceConfig(PrintableConfig): + # MODEL CONFIG, NOT EXPORTED PARAMS + models_config: str = make_abs_path('./models.yaml') # portrait animation config + checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth') # path to checkpoint of F + checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth') # path to checkpoint of M + checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth') # path to checkpoint of G + checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth') # path to checkpoint of W + checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint of S and R_eyes, R_lip + + # EXPORTED PARAMS + flag_use_half_precision: bool = True + flag_crop_driving_video: bool = False + device_id: int = 0 + flag_lip_zero: bool = True + flag_eye_retargeting: bool = False + flag_lip_retargeting: bool = False + flag_stitching: bool = True + flag_relative_motion: bool = True + flag_pasteback: bool = True + flag_do_crop: bool = True + flag_do_rot: bool = True + flag_force_cpu: bool = False + + # NOT EXPORTED PARAMS + lip_zero_threshold: float = 0.03 # threshold for flag_lip_zero + anchor_frame: int = 0 # TO IMPLEMENT + + input_shape: Tuple[int, int] = (256, 256) # input shape + output_format: Literal['mp4', 'gif'] = 'mp4' # output video format + crf: int = 15 # crf for output video + output_fps: int = 25 # default output fps + + mask_crop: ndarray = cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR) + size_gif: int = 256 # default gif size, TO IMPLEMENT + source_max_dim: int = 1280 # the max dim of height and width of source image + source_division: int = 2 # make sure the height and width of source image can be divided by this number
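InferenceConfig.models_config points at the models.yaml that follows; both the benchmark script earlier in this diff and LivePortraitWrapper load it with PyYAML's safe loader. The pattern in isolation:

```python
import yaml

# path as added in this diff
with open('subprocess/LivePortrait/src/config/models.yaml') as f:
    model_config = yaml.safe_load(f)

params = model_config['model_params']
print(params['motion_extractor_params']['num_kp'])  # 21 implicit keypoints
```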
diff --git a/subprocess/LivePortrait/src/config/models.yaml b/subprocess/LivePortrait/src/config/models.yaml new file mode 100644 index 0000000000000000000000000000000000000000..131d1c65025c31e37af9239e211ea14454128a2e --- /dev/null +++ b/subprocess/LivePortrait/src/config/models.yaml @@ -0,0 +1,43 @@ +model_params: + appearance_feature_extractor_params: # the F in the paper + image_channel: 3 + block_expansion: 64 + num_down_blocks: 2 + max_features: 512 + reshape_channel: 32 + reshape_depth: 16 + num_resblocks: 6 + motion_extractor_params: # the M in the paper + num_kp: 21 + backbone: convnextv2_tiny + warping_module_params: # the W in the paper + num_kp: 21 + block_expansion: 64 + max_features: 512 + num_down_blocks: 2 + reshape_channel: 32 + estimate_occlusion_map: True + dense_motion_params: + block_expansion: 32 + max_features: 1024 + num_blocks: 5 + reshape_depth: 16 + compress: 4 + spade_generator_params: # the G in the paper + upscale: 2 # represents upsample factor 256x256 -> 512x512 + block_expansion: 64 + max_features: 512 + num_down_blocks: 2 + stitching_retargeting_module_params: # the S in the paper + stitching: + input_size: 126 # (21*3)*2 + hidden_sizes: [128, 128, 64] + output_size: 65 # (21*3)+2(tx,ty) + lip: + input_size: 65 # (21*3)+2 + hidden_sizes: [128, 128, 64] + output_size: 63 # (21*3) + eye: + input_size: 66 # (21*3)+3 + hidden_sizes: [256, 256, 128, 128, 64] + output_size: 63 # (21*3) diff --git a/subprocess/LivePortrait/src/gradio_pipeline.py b/subprocess/LivePortrait/src/gradio_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..f7343f7df6b8a6c6815c5af3526ed6dc857a7c0c --- /dev/null +++ b/subprocess/LivePortrait/src/gradio_pipeline.py @@ -0,0 +1,117 @@ +# coding: utf-8 + +""" +Pipeline for gradio +""" +import gradio as gr + +from .config.argument_config import ArgumentConfig +from .live_portrait_pipeline import LivePortraitPipeline +from .utils.io import load_img_online +from .utils.rprint import rlog as log +from .utils.crop import prepare_paste_back, paste_back +from .utils.camera import get_rotation_matrix + + +def update_args(args, user_args): + """update the args according to user inputs + """ + for k, v in user_args.items(): + if hasattr(args, k): + setattr(args, k, v) + return args + + +class GradioPipeline(LivePortraitPipeline): + + def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig): + super().__init__(inference_cfg, crop_cfg) + # self.live_portrait_wrapper = self.live_portrait_wrapper + self.args = args + + def execute_video( + self, + input_image_path, + input_video_path, + flag_relative_input, + flag_do_crop_input, + flag_remap_input, + flag_crop_driving_video_input + ): + """ for video-driven portrait animation + """ + if input_image_path is not None and input_video_path is not None: + args_user = { + 'source_image': input_image_path, + 'driving_info': input_video_path, + 'flag_relative': flag_relative_input, + 'flag_do_crop': flag_do_crop_input, + 'flag_pasteback': flag_remap_input, + 'flag_crop_driving_video': flag_crop_driving_video_input + } + # update config from user input + self.args = update_args(self.args, args_user) + self.live_portrait_wrapper.update_config(self.args.__dict__) + self.cropper.update_config(self.args.__dict__) + # video driven animation + video_path, video_path_concat = self.execute(self.args) + gr.Info("Run successfully!", duration=2) + return video_path, video_path_concat, + else: + raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5) + + def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop=True): + """ for single image retargeting + """ + # disposable feature + f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \ + self.prepare_retargeting(input_image, flag_do_crop) + + if input_eye_ratio is None or input_lip_ratio is None: + raise gr.Error("Invalid ratio input 💥!", duration=5) + else: + inference_cfg = self.live_portrait_wrapper.inference_cfg + x_s_user = x_s_user.to(self.live_portrait_wrapper.device) + f_s_user = f_s_user.to(self.live_portrait_wrapper.device) + # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i) + combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], source_lmk_user) + eyes_delta
= self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor) + # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i) + combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user) + lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor) + num_kp = x_s_user.shape[1] + # default: use x_s + x_d_new = x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3) + # D(W(f_s; x_s, x′_d)) + out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new) + out = self.live_portrait_wrapper.parse_output(out['out'])[0] + out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori) + gr.Info("Run successfully!", duration=2) + return out, out_to_ori_blend + + def prepare_retargeting(self, input_image, flag_do_crop=True): + """ for single image retargeting + """ + if input_image is not None: + # gr.Info("Upload successfully!", duration=2) + inference_cfg = self.live_portrait_wrapper.inference_cfg + ######## process source portrait ######## + img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16) + log(f"Load source image from {input_image}.") + crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg) + if flag_do_crop: + I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256']) + else: + I_s = self.live_portrait_wrapper.prepare_source(img_rgb) + x_s_info = self.live_portrait_wrapper.get_kp_info(I_s) + R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) + ############################################ + f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s) + x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info) + source_lmk_user = crop_info['lmk_crop'] + crop_M_c2o = crop_info['M_c2o'] + mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0])) + return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb + else: + # when press the clear button, go here + raise gr.Error("The retargeting input hasn't been prepared yet 💥!", duration=5) diff --git a/subprocess/LivePortrait/src/live_portrait_pipeline.py b/subprocess/LivePortrait/src/live_portrait_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..a27b52f82758583795a6780f0c83487f55366ac1 --- /dev/null +++ b/subprocess/LivePortrait/src/live_portrait_pipeline.py @@ -0,0 +1,285 @@ +# coding: utf-8 + +""" +Pipeline of LivePortrait +""" + +import torch +torch.backends.cudnn.benchmark = True # disable CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR warning + +import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) +import numpy as np +import os +import os.path as osp +from rich.progress import track + +from .config.argument_config import ArgumentConfig +from .config.inference_config import InferenceConfig +from .config.crop_config import CropConfig +from .utils.cropper import Cropper +from .utils.camera import get_rotation_matrix +from .utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream +from .utils.crop import _transform_img, prepare_paste_back, paste_back +from .utils.io import load_image_rgb, load_driving_info, resize_to_limit, dump, load +from .utils.helper import mkdir, basename, dct2device, is_video, is_template, remove_suffix +from .utils.rprint import rlog as log +# from .utils.viz import viz_lmk +from .live_portrait_wrapper import LivePortraitWrapper + + +def make_abs_path(fn): + 
return osp.join(osp.dirname(osp.realpath(__file__)), fn) + + +class LivePortraitPipeline(object): + + def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig): + self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(inference_cfg=inference_cfg) + self.cropper: Cropper = Cropper(crop_cfg=crop_cfg) + + def execute(self, args: ArgumentConfig): + # for convenience + inf_cfg = self.live_portrait_wrapper.inference_cfg + device = self.live_portrait_wrapper.device + crop_cfg = self.cropper.crop_cfg + + ######## process source portrait ######## + img_rgb = load_image_rgb(args.source_image) + img_rgb = resize_to_limit(img_rgb, inf_cfg.source_max_dim, inf_cfg.source_division) + log(f"Load source image from {args.source_image}") + + crop_info = self.cropper.crop_source_image(img_rgb, crop_cfg) + if crop_info is None: + raise Exception("No face detected in the source image!") + source_lmk = crop_info['lmk_crop'] + img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256'] + + if inf_cfg.flag_do_crop: + I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256) + else: + img_crop_256x256 = cv2.resize(img_rgb, (256, 256)) # force to resize to 256x256 + I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256) + x_s_info = self.live_portrait_wrapper.get_kp_info(I_s) + x_c_s = x_s_info['kp'] + R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) + f_s = self.live_portrait_wrapper.extract_feature_3d(I_s) + x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info) + + flag_lip_zero = inf_cfg.flag_lip_zero # do not overwrite the config value + if flag_lip_zero: + # drive the lip-open scalar to 0 first + c_d_lip_before_animation = [0.] + combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk) + if combined_lip_ratio_tensor_before_animation[0][0] < inf_cfg.lip_zero_threshold: + flag_lip_zero = False + else: + lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation) + ############################################ + + ######## process driving info ######## + flag_load_from_template = is_template(args.driving_info) + driving_rgb_crop_256x256_lst = None + wfp_template = None + + if flag_load_from_template: + # NOTE: loading from a template is fast, but the cropping video is None + log(f"Load from template: {args.driving_info}, NOT the video, so the cropping video and audio are both NULL.", style='bold green') + template_dct = load(args.driving_info) + n_frames = template_dct['n_frames'] + + # set output_fps + output_fps = template_dct.get('output_fps', inf_cfg.output_fps) + log(f'The FPS of template: {output_fps}') + + if args.flag_crop_driving_video: + log("Warning: flag_crop_driving_video is True, but the driving info is a template, so it is ignored.") + + elif osp.exists(args.driving_info) and is_video(args.driving_info): + # load from video file, AND make motion template + log(f"Load video: {args.driving_info}") + if osp.isdir(args.driving_info): + output_fps = inf_cfg.output_fps + else: + output_fps = int(get_fps(args.driving_info)) + log(f'The FPS of {args.driving_info} is: {output_fps}') + + log(f"Load video file (mp4 mov avi etc...): {args.driving_info}") + driving_rgb_lst = load_driving_info(args.driving_info) + + ######## make motion template ######## + log("Start making motion template...") + if inf_cfg.flag_crop_driving_video: + ret = self.cropper.crop_driving_video(driving_rgb_lst) +
log(f'Driving video is cropped, {len(ret["frame_crop_lst"])} frames are processed.') + driving_rgb_crop_lst, driving_lmk_crop_lst = ret['frame_crop_lst'], ret['lmk_crop_lst'] + driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_crop_lst] + else: + driving_lmk_crop_lst = self.cropper.calc_lmks_from_cropped_video(driving_rgb_lst) + driving_rgb_crop_256x256_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] # force to resize to 256x256 + + c_d_eyes_lst, c_d_lip_lst = self.live_portrait_wrapper.calc_driving_ratio(driving_lmk_crop_lst) + # save the motion template + I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_crop_256x256_lst) + template_dct = self.make_motion_template(I_d_lst, c_d_eyes_lst, c_d_lip_lst, output_fps=output_fps) + + wfp_template = remove_suffix(args.driving_info) + '.pkl' + dump(wfp_template, template_dct) + log(f"Dump motion template to {wfp_template}") + + n_frames = I_d_lst.shape[0] + else: + raise Exception(f"{args.driving_info} does not exist or is an unsupported driving info type!") + ######################################### + + ######## prepare for pasteback ######## + I_p_pstbk_lst = None + if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: + mask_ori_float = prepare_paste_back(inf_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0])) + I_p_pstbk_lst = [] + log("Pasteback mask prepared.") + ######################################### + + I_p_lst = [] + R_d_0, x_d_0_info = None, None + + for i in track(range(n_frames), description='🚀Animating...', total=n_frames): + x_d_i_info = template_dct['motion'][i] + x_d_i_info = dct2device(x_d_i_info, device) + R_d_i = x_d_i_info['R_d'] + + if i == 0: + R_d_0 = R_d_i + x_d_0_info = x_d_i_info + + if inf_cfg.flag_relative_motion: + R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s + delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']) + scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale']) + t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t']) + else: + R_new = R_d_i + delta_new = x_d_i_info['exp'] + scale_new = x_s_info['scale'] + t_new = x_d_i_info['t'] + + t_new[..., 2].fill_(0) # zero tz + x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new + + # Algorithm 1: + if not inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting: + # without stitching or retargeting + if flag_lip_zero: + x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3) + else: + pass + elif inf_cfg.flag_stitching and not inf_cfg.flag_eye_retargeting and not inf_cfg.flag_lip_retargeting: + # with stitching and without retargeting + if flag_lip_zero: + x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3) + else: + x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + else: + eyes_delta, lip_delta = None, None + if inf_cfg.flag_eye_retargeting: + c_d_eyes_i = c_d_eyes_lst[i] + combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk) + # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i) + eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor) + if inf_cfg.flag_lip_retargeting: + c_d_lip_i = c_d_lip_lst[i] + combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk) + # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i) + lip_delta =
self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor) + + if inf_cfg.flag_relative_motion: # use x_s + x_d_i_new = x_s + \ + (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \ + (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0) + else: # use x_d,i + x_d_i_new = x_d_i_new + \ + (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \ + (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0) + + if inf_cfg.flag_stitching: + x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + + out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new) + I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0] + I_p_lst.append(I_p_i) + + if inf_cfg.flag_pasteback and inf_cfg.flag_do_crop and inf_cfg.flag_stitching: + # TODO: pasteback is slow; consider optimizing it with multi-threading or the GPU + I_p_pstbk = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori_float) + I_p_pstbk_lst.append(I_p_pstbk) + + mkdir(args.output_dir) + wfp_concat = None + flag_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving_info) + + ######### build final concat result ######### + # driving frame | source image | generation, or source image | generation + frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256, I_p_lst) + wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4') + images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps) + + if flag_has_audio: + # final result with concat + wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat_with_audio.mp4') + add_audio_to_video(wfp_concat, args.driving_info, wfp_concat_with_audio) + os.replace(wfp_concat_with_audio, wfp_concat) + log(f"Replace {wfp_concat} with {wfp_concat_with_audio}") + + # save the driven result + wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4') + if I_p_pstbk_lst is not None and len(I_p_pstbk_lst) > 0: + images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps) + else: + images2video(I_p_lst, wfp=wfp, fps=output_fps) + + ######### build final result ######### + if flag_has_audio: + wfp_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_with_audio.mp4') + add_audio_to_video(wfp, args.driving_info, wfp_with_audio) + os.replace(wfp_with_audio, wfp) + log(f"Replace {wfp} with {wfp_with_audio}") + + # final log + if wfp_template not in (None, ''): + log(f'Animated template: {wfp_template}, you can pass this template path to `-d` next time to skip video cropping and motion-template building, and to protect privacy.', style='bold green') + log(f'Animated video: {wfp}') + log(f'Animated video with concat: {wfp_concat}') + + return wfp, wfp_concat
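The relative-motion branch of execute() composes per-frame driving deltas with the source pose: R_new = (R_d,i · R_d,0^T) · R_s, plus additive expression/translation deltas and a multiplicative scale ratio. The same math as a standalone sketch (mirroring the code above; the info dicts hold 'exp', 'scale', and 't' tensors):

```python
import torch

def relative_motion(x_s_info, x_d_i_info, x_d_0_info, R_s, R_d_i, R_d_0):
    """Compose frame-i driving motion relative to frame 0 with the source pose."""
    R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s                         # rotation delta
    delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])  # expression delta
    scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
    t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])            # translation delta
    return R_new, delta_new, scale_new, t_new
```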
+ + def make_motion_template(self, I_d_lst, c_d_eyes_lst, c_d_lip_lst, **kwargs): + n_frames = I_d_lst.shape[0] + template_dct = { + 'n_frames': n_frames, + 'output_fps': kwargs.get('output_fps', 25), + 'motion': [], + 'c_d_eyes_lst': [], + 'c_d_lip_lst': [], + } + + for i in track(range(n_frames), description='Making motion templates...', total=n_frames): + # collect s_d, R_d, δ_d and t_d for inference + I_d_i = I_d_lst[i] + x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i) + R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll']) + + item_dct = { + 'scale': x_d_i_info['scale'].cpu().numpy().astype(np.float32), + 'R_d': R_d_i.cpu().numpy().astype(np.float32), + 'exp': x_d_i_info['exp'].cpu().numpy().astype(np.float32), + 't': x_d_i_info['t'].cpu().numpy().astype(np.float32), + } + + template_dct['motion'].append(item_dct) + + c_d_eyes = c_d_eyes_lst[i].astype(np.float32) + template_dct['c_d_eyes_lst'].append(c_d_eyes) + + c_d_lip = c_d_lip_lst[i].astype(np.float32) + template_dct['c_d_lip_lst'].append(c_d_lip) + + return template_dct diff --git a/subprocess/LivePortrait/src/live_portrait_wrapper.py b/subprocess/LivePortrait/src/live_portrait_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..8869b952d92041d1da6a546be3b14baa02aa57a5 --- /dev/null +++ b/subprocess/LivePortrait/src/live_portrait_wrapper.py @@ -0,0 +1,311 @@ +# coding: utf-8 + +""" +Wrapper for LivePortrait core functions +""" + +import os.path as osp +import numpy as np +import cv2 +import torch +import yaml + +from .utils.timer import Timer +from .utils.helper import load_model, concat_feat +from .utils.camera import headpose_pred_to_degree, get_rotation_matrix +from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio +from .config.inference_config import InferenceConfig +from .utils.rprint import rlog as log + + +class LivePortraitWrapper(object): + + def __init__(self, inference_cfg: InferenceConfig): + + self.inference_cfg = inference_cfg + self.device_id = inference_cfg.device_id + if inference_cfg.flag_force_cpu: + self.device = 'cpu' + else: + self.device = 'cuda:' + str(self.device_id) + + model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader) + # init F + self.appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, self.device, 'appearance_feature_extractor') + log(f'Load appearance_feature_extractor done.') + # init M + self.motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, self.device, 'motion_extractor') + log(f'Load motion_extractor done.') + # init W + self.warping_module = load_model(inference_cfg.checkpoint_W, model_config, self.device, 'warping_module') + log(f'Load warping_module done.') + # init G + self.spade_generator = load_model(inference_cfg.checkpoint_G, model_config, self.device, 'spade_generator') + log(f'Load spade_generator done.') + # init S and R + if inference_cfg.checkpoint_S is not None and osp.exists(inference_cfg.checkpoint_S): + self.stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, self.device, 'stitching_retargeting_module') + log(f'Load stitching_retargeting_module done.') + else: + self.stitching_retargeting_module = None + + self.timer = Timer() + + def update_config(self, user_args): + for k, v in user_args.items(): + if hasattr(self.inference_cfg, k): + setattr(self.inference_cfg, k, v)
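The prepare_source method that follows normalizes a uint8 HxWx3 image into the 1x3xHxW float tensor the networks expect; the core transformation, in isolation:

```python
import numpy as np
import torch

img = np.zeros((256, 256, 3), dtype=np.uint8)    # stand-in for a 256x256 RGB crop
x = img[np.newaxis].astype(np.float32) / 255.    # HxWx3 -> 1xHxWx3, scaled to [0, 1]
x = torch.from_numpy(x).permute(0, 3, 1, 2)      # 1xHxWx3 -> 1x3xHxW
print(x.shape)                                   # torch.Size([1, 3, 256, 256])
```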
+ + def prepare_source(self, img: np.ndarray) -> torch.Tensor: + """ construct the input as standard + img: HxWx3, uint8, 256x256 + """ + h, w = img.shape[:2] + if h != self.inference_cfg.input_shape[0] or w != self.inference_cfg.input_shape[1]: + x = cv2.resize(img, (self.inference_cfg.input_shape[0], self.inference_cfg.input_shape[1])) + else: + x = img.copy() + + if x.ndim == 3: + x = x[np.newaxis].astype(np.float32) / 255. # HxWx3 -> 1xHxWx3, normalized to 0~1 + elif x.ndim == 4: + x = x.astype(np.float32) / 255. # BxHxWx3, normalized to 0~1 + else: + raise ValueError(f'img ndim should be 3 or 4: {x.ndim}') + x = np.clip(x, 0, 1) # clip to 0~1 + x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW + x = x.to(self.device) + return x + + def prepare_driving_videos(self, imgs) -> torch.Tensor: + """ construct the input as standard + imgs: NxBxHxWx3, uint8 + """ + if isinstance(imgs, list): + _imgs = np.array(imgs)[..., np.newaxis] # TxHxWx3x1 + elif isinstance(imgs, np.ndarray): + _imgs = imgs + else: + raise ValueError(f'imgs type error: {type(imgs)}') + + y = _imgs.astype(np.float32) / 255. + y = np.clip(y, 0, 1) # clip to 0~1 + y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW + y = y.to(self.device) + + return y + + def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor: + """ get the appearance feature of the image by F + x: Bx3xHxW, normalized to 0~1 + """ + with torch.no_grad(): + with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision): + feature_3d = self.appearance_feature_extractor(x) + + return feature_3d.float() + + def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict: + """ get the implicit keypoint information + x: Bx3xHxW, normalized to 0~1 + flag_refine_info: whether to transform the pose to degrees and refine the reshape dimensions + return: A dict containing keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp' + """ + with torch.no_grad(): + with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision): + kp_info = self.motion_extractor(x) + + if self.inference_cfg.flag_use_half_precision: + # float the dict + for k, v in kp_info.items(): + if isinstance(v, torch.Tensor): + kp_info[k] = v.float() + + flag_refine_info: bool = kwargs.get('flag_refine_info', True) + if flag_refine_info: + bs = kp_info['kp'].shape[0] + kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None] # Bx1 + kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None] # Bx1 + kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None] # Bx1 + kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3) # BxNx3 + kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3) # BxNx3 + + return kp_info + + def get_pose_dct(self, kp_info: dict) -> dict: + pose_dct = dict( + pitch=headpose_pred_to_degree(kp_info['pitch']).item(), + yaw=headpose_pred_to_degree(kp_info['yaw']).item(), + roll=headpose_pred_to_degree(kp_info['roll']).item(), + ) + return pose_dct + + def get_fs_and_kp_info(self, source_prepared, driving_first_frame): + + # get the canonical keypoints of source image by M + source_kp_info = self.get_kp_info(source_prepared, flag_refine_info=True) + source_rotation = get_rotation_matrix(source_kp_info['pitch'], source_kp_info['yaw'], source_kp_info['roll']) + + # get the canonical keypoints of first driving frame by M + driving_first_frame_kp_info = self.get_kp_info(driving_first_frame, flag_refine_info=True) + driving_first_frame_rotation = get_rotation_matrix( + driving_first_frame_kp_info['pitch'], + driving_first_frame_kp_info['yaw'], + driving_first_frame_kp_info['roll'] + ) + + # get feature volume by F + source_feature_3d = self.extract_feature_3d(source_prepared) + + return source_kp_info, source_rotation, source_feature_3d, driving_first_frame_kp_info, driving_first_frame_rotation + + def transform_keypoint(self, kp_info: dict): + """ + transform the implicit keypoints with the pose, shift, and expression deformation + kp:
BxNx3 + """ + kp = kp_info['kp'] # (bs, k, 3) + pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll'] + + t, exp = kp_info['t'], kp_info['exp'] + scale = kp_info['scale'] + + pitch = headpose_pred_to_degree(pitch) + yaw = headpose_pred_to_degree(yaw) + roll = headpose_pred_to_degree(roll) + + bs = kp.shape[0] + if kp.ndim == 2: + num_kp = kp.shape[1] // 3 # Bx(num_kpx3) + else: + num_kp = kp.shape[1] # Bxnum_kpx3 + + rot_mat = get_rotation_matrix(pitch, yaw, roll) # (bs, 3, 3) + + # Eqn.2: s * (R * x_c,s + exp) + t + kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3) + kp_transformed *= scale[..., None] # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3) + kp_transformed[:, :, 0:2] += t[:, None, 0:2] # remove z, only apply tx ty + + return kp_transformed + + def retarget_eye(self, kp_source: torch.Tensor, eye_close_ratio: torch.Tensor) -> torch.Tensor: + """ + kp_source: BxNx3 + eye_close_ratio: Bx3 + Return: Bx(3*num_kp+2) + """ + feat_eye = concat_feat(kp_source, eye_close_ratio) + + with torch.no_grad(): + delta = self.stitching_retargeting_module['eye'](feat_eye) + + return delta + + def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor: + """ + kp_source: BxNx3 + lip_close_ratio: Bx2 + """ + feat_lip = concat_feat(kp_source, lip_close_ratio) + + with torch.no_grad(): + delta = self.stitching_retargeting_module['lip'](feat_lip) + + return delta + + def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor: + """ + kp_source: BxNx3 + kp_driving: BxNx3 + Return: Bx(3*num_kp+2) + """ + feat_stiching = concat_feat(kp_source, kp_driving) + + with torch.no_grad(): + delta = self.stitching_retargeting_module['stitching'](feat_stiching) + + return delta + + def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor: + """ conduct the stitching + kp_source: Bxnum_kpx3 + kp_driving: Bxnum_kpx3 + """ + + if self.stitching_retargeting_module is not None: + + bs, num_kp = kp_source.shape[:2] + + kp_driving_new = kp_driving.clone() + delta = self.stitch(kp_source, kp_driving_new) + + delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3) # 1x20x3 + delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2) # 1x1x2 + + kp_driving_new += delta_exp + kp_driving_new[..., :2] += delta_tx_ty + + return kp_driving_new + + return kp_driving + + def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor: + """ get the image after the warping of the implicit keypoints + feature_3d: Bx32x16x64x64, feature volume + kp_source: BxNx3 + kp_driving: BxNx3 + """ + # The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i)) + with torch.no_grad(): + with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision): + # get decoder input + ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving) + # decode + ret_dct['out'] = self.spade_generator(feature=ret_dct['out']) + + # float the dict + if self.inference_cfg.flag_use_half_precision: + for k, v in ret_dct.items(): + if isinstance(v, torch.Tensor): + ret_dct[k] = v.float() + + return ret_dct + + def parse_output(self, out: torch.Tensor) -> np.ndarray: + """ construct the output as standard + return: 1xHxWx3, uint8 + """ + out = np.transpose(out.data.cpu().numpy(), [0, 2, 3, 1]) # 1x3xHxW -> 1xHxWx3 + out = np.clip(out, 0, 1) # clip to 0~1 + out = np.clip(out * 255, 0, 255).astype(np.uint8) # 0~1 -> 0~255 + 
+ return out + + def calc_driving_ratio(self, driving_lmk_lst): + input_eye_ratio_lst = [] + input_lip_ratio_lst = [] + for lmk in driving_lmk_lst: + # for eyes retargeting + input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None])) + # for lip retargeting + input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None])) + return input_eye_ratio_lst, input_lip_ratio_lst + + def calc_combined_eye_ratio(self, c_d_eyes_i, source_lmk): + c_s_eyes = calc_eye_close_ratio(source_lmk[None]) + c_s_eyes_tensor = torch.from_numpy(c_s_eyes).float().to(self.device) + c_d_eyes_i_tensor = torch.Tensor([c_d_eyes_i[0][0]]).reshape(1, 1).to(self.device) + # [c_s,eyes, c_d,eyes,i] + combined_eye_ratio_tensor = torch.cat([c_s_eyes_tensor, c_d_eyes_i_tensor], dim=1) + return combined_eye_ratio_tensor + + def calc_combined_lip_ratio(self, c_d_lip_i, source_lmk): + c_s_lip = calc_lip_close_ratio(source_lmk[None]) + c_s_lip_tensor = torch.from_numpy(c_s_lip).float().to(self.device) + c_d_lip_i_tensor = torch.Tensor([c_d_lip_i[0]]).to(self.device).reshape(1, 1) # 1x1 + # [c_s,lip, c_d,lip,i] + combined_lip_ratio_tensor = torch.cat([c_s_lip_tensor, c_d_lip_i_tensor], dim=1) # 1x2 + return combined_lip_ratio_tensor diff --git a/subprocess/LivePortrait/src/modules/__init__.py b/subprocess/LivePortrait/src/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/subprocess/LivePortrait/src/modules/__pycache__/__init__.cpython-310.pyc b/subprocess/LivePortrait/src/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b9de48b6778d3a65edc3166502f82fc6e9b4c2c Binary files /dev/null and b/subprocess/LivePortrait/src/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/modules/__pycache__/__init__.cpython-39.pyc b/subprocess/LivePortrait/src/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd5f1e3a9fdfc2b0fa357e78b977515b20ca1efc Binary files /dev/null and b/subprocess/LivePortrait/src/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/modules/__pycache__/appearance_feature_extractor.cpython-310.pyc b/subprocess/LivePortrait/src/modules/__pycache__/appearance_feature_extractor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03d71758976c98dcd5b196ad3431ac965edad22f Binary files /dev/null and b/subprocess/LivePortrait/src/modules/__pycache__/appearance_feature_extractor.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/modules/__pycache__/appearance_feature_extractor.cpython-39.pyc b/subprocess/LivePortrait/src/modules/__pycache__/appearance_feature_extractor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c62a9a2580a2d4587c972d8078e2165ca079fd6 Binary files /dev/null and b/subprocess/LivePortrait/src/modules/__pycache__/appearance_feature_extractor.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/modules/__pycache__/convnextv2.cpython-310.pyc b/subprocess/LivePortrait/src/modules/__pycache__/convnextv2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab96ef4fde53e58880d8df0d11e8d458b5cae8d5 Binary files /dev/null and b/subprocess/LivePortrait/src/modules/__pycache__/convnextv2.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/modules/__pycache__/convnextv2.cpython-39.pyc 
b/subprocess/LivePortrait/src/modules/__pycache__/convnextv2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..858f86322482623b06e20f69c4b88c8ae75eb6a4 Binary files /dev/null and b/subprocess/LivePortrait/src/modules/__pycache__/convnextv2.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/modules/__pycache__/dense_motion.cpython-310.pyc b/subprocess/LivePortrait/src/modules/__pycache__/dense_motion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebc20e1307f6a86564c2a630715bbb7204b0a698 Binary files /dev/null and b/subprocess/LivePortrait/src/modules/__pycache__/dense_motion.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/modules/__pycache__/dense_motion.cpython-39.pyc b/subprocess/LivePortrait/src/modules/__pycache__/dense_motion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b99bfd897f44b08208810951c51087c5462de619 Binary files /dev/null and b/subprocess/LivePortrait/src/modules/__pycache__/dense_motion.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/modules/__pycache__/motion_extractor.cpython-310.pyc b/subprocess/LivePortrait/src/modules/__pycache__/motion_extractor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b0d8d85820a8770be6f2ca4ad3e3bc2ac87d050 Binary files /dev/null and b/subprocess/LivePortrait/src/modules/__pycache__/motion_extractor.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/modules/__pycache__/motion_extractor.cpython-39.pyc b/subprocess/LivePortrait/src/modules/__pycache__/motion_extractor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..295edfb20175e947bd484c9ddf68e21ce6fe7f86 Binary files /dev/null and b/subprocess/LivePortrait/src/modules/__pycache__/motion_extractor.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/modules/__pycache__/spade_generator.cpython-310.pyc b/subprocess/LivePortrait/src/modules/__pycache__/spade_generator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a7771089392e07648152d82ff55876e3c17498f Binary files /dev/null and b/subprocess/LivePortrait/src/modules/__pycache__/spade_generator.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/modules/__pycache__/spade_generator.cpython-39.pyc b/subprocess/LivePortrait/src/modules/__pycache__/spade_generator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e05b3789777a193eb79f1d43d8be3bd1e62c957c Binary files /dev/null and b/subprocess/LivePortrait/src/modules/__pycache__/spade_generator.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/modules/__pycache__/stitching_retargeting_network.cpython-310.pyc b/subprocess/LivePortrait/src/modules/__pycache__/stitching_retargeting_network.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a447a381b2e64e284cb2b32b0e4ef95a8fcc8ac1 Binary files /dev/null and b/subprocess/LivePortrait/src/modules/__pycache__/stitching_retargeting_network.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/modules/__pycache__/stitching_retargeting_network.cpython-39.pyc b/subprocess/LivePortrait/src/modules/__pycache__/stitching_retargeting_network.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d203aafea85e559f6c129e1903c93fd1bf9d476 Binary files /dev/null and 
b/subprocess/LivePortrait/src/modules/__pycache__/stitching_retargeting_network.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/modules/__pycache__/util.cpython-310.pyc b/subprocess/LivePortrait/src/modules/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5ef1e3df94df943ffde20545e0341130165d0fe Binary files /dev/null and b/subprocess/LivePortrait/src/modules/__pycache__/util.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/modules/__pycache__/util.cpython-39.pyc b/subprocess/LivePortrait/src/modules/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99af4b944ecc146fe2aedcb10f12257cac713a9a Binary files /dev/null and b/subprocess/LivePortrait/src/modules/__pycache__/util.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/modules/__pycache__/warping_network.cpython-310.pyc b/subprocess/LivePortrait/src/modules/__pycache__/warping_network.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76500301f1babdb5471737bc8370e4425f6b3f0b Binary files /dev/null and b/subprocess/LivePortrait/src/modules/__pycache__/warping_network.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/modules/__pycache__/warping_network.cpython-39.pyc b/subprocess/LivePortrait/src/modules/__pycache__/warping_network.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8152c62aa6a25998f9b438699401acdb81b00d7b Binary files /dev/null and b/subprocess/LivePortrait/src/modules/__pycache__/warping_network.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/modules/appearance_feature_extractor.py b/subprocess/LivePortrait/src/modules/appearance_feature_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..8d89e4f18a2fbe58447f52ab4c5e3f2011a4ec80 --- /dev/null +++ b/subprocess/LivePortrait/src/modules/appearance_feature_extractor.py @@ -0,0 +1,48 @@ +# coding: utf-8 + +""" +Appearance extractor(F) defined in paper, which maps the source image s to a 3D appearance feature volume. 
+""" + +import torch +from torch import nn +from .util import SameBlock2d, DownBlock2d, ResBlock3d + + +class AppearanceFeatureExtractor(nn.Module): + + def __init__(self, image_channel, block_expansion, num_down_blocks, max_features, reshape_channel, reshape_depth, num_resblocks): + super(AppearanceFeatureExtractor, self).__init__() + self.image_channel = image_channel + self.block_expansion = block_expansion + self.num_down_blocks = num_down_blocks + self.max_features = max_features + self.reshape_channel = reshape_channel + self.reshape_depth = reshape_depth + + self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1)) + + down_blocks = [] + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2 ** i)) + out_features = min(max_features, block_expansion * (2 ** (i + 1))) + down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + self.down_blocks = nn.ModuleList(down_blocks) + + self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) + + self.resblocks_3d = torch.nn.Sequential() + for i in range(num_resblocks): + self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) + + def forward(self, source_image): + out = self.first(source_image) # Bx3x256x256 -> Bx64x256x256 + + for i in range(len(self.down_blocks)): + out = self.down_blocks[i](out) + out = self.second(out) + bs, c, h, w = out.shape # ->Bx512x64x64 + + f_s = out.view(bs, self.reshape_channel, self.reshape_depth, h, w) # ->Bx32x16x64x64 + f_s = self.resblocks_3d(f_s) # ->Bx32x16x64x64 + return f_s diff --git a/subprocess/LivePortrait/src/modules/convnextv2.py b/subprocess/LivePortrait/src/modules/convnextv2.py new file mode 100644 index 0000000000000000000000000000000000000000..83ea12662b607854915df8c7abb160b588d330b1 --- /dev/null +++ b/subprocess/LivePortrait/src/modules/convnextv2.py @@ -0,0 +1,149 @@ +# coding: utf-8 + +""" +This moudle is adapted to the ConvNeXtV2 version for the extraction of implicit keypoints, poses, and expression deformation. +""" + +import torch +import torch.nn as nn +# from timm.models.layers import trunc_normal_, DropPath +from .util import LayerNorm, DropPath, trunc_normal_, GRN + +__all__ = ['convnextv2_tiny'] + + +class Block(nn.Module): + """ ConvNeXtV2 Block. + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + """ + + def __init__(self, dim, drop_path=0.): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.grn = GRN(4 * dim) + self.pwconv2 = nn.Linear(4 * dim, dim) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class ConvNeXtV2(nn.Module): + """ ConvNeXt V2 + + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks at each stage. 
diff --git a/subprocess/LivePortrait/src/modules/convnextv2.py b/subprocess/LivePortrait/src/modules/convnextv2.py new file mode 100644 index 0000000000000000000000000000000000000000..83ea12662b607854915df8c7abb160b588d330b1 --- /dev/null +++ b/subprocess/LivePortrait/src/modules/convnextv2.py @@ -0,0 +1,149 @@ +# coding: utf-8 + +""" +This module is adapted from ConvNeXtV2 for the extraction of implicit keypoints, poses, and expression deformation. +""" + +import torch +import torch.nn as nn +# from timm.models.layers import trunc_normal_, DropPath +from .util import LayerNorm, DropPath, trunc_normal_, GRN + +__all__ = ['convnextv2_tiny'] + + +class Block(nn.Module): + """ ConvNeXtV2 Block. + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + """ + + def __init__(self, dim, drop_path=0.): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.grn = GRN(4 * dim) + self.pwconv2 = nn.Linear(4 * dim, dim) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class ConvNeXtV2(nn.Module): + """ ConvNeXt V2 + + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] + dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] + drop_path_rate (float): Stochastic depth rate. Default: 0. + head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. + """ + + def __init__( + self, + in_chans=3, + depths=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + drop_path_rate=0., + **kwargs + ): + super().__init__() + self.depths = depths + self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers + stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first") + ) + self.downsample_layers.append(stem) + for i in range(3): + downsample_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), + nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), + ) + self.downsample_layers.append(downsample_layer) + + self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks + dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + cur = 0 + for i in range(4): + stage = nn.Sequential( + *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])] + ) + self.stages.append(stage) + cur += depths[i] + + self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer + + # NOTE: the output semantic items + num_bins = kwargs.get('num_bins', 66) + num_kp = kwargs.get('num_kp', 24) # the number of implicit keypoints + self.fc_kp = nn.Linear(dims[-1], 3 * num_kp) # implicit keypoints + + # print('dims[-1]: ', dims[-1]) + self.fc_scale = nn.Linear(dims[-1], 1) # scale + self.fc_pitch = nn.Linear(dims[-1], num_bins) # pitch bins + self.fc_yaw = nn.Linear(dims[-1], num_bins) # yaw bins + self.fc_roll = nn.Linear(dims[-1], num_bins) # roll bins + self.fc_t = nn.Linear(dims[-1], 3) # translation + self.fc_exp = nn.Linear(dims[-1], 3 * num_kp) # expression / delta + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + nn.init.constant_(m.bias, 0) + + def forward_features(self, x): + for i in range(4): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) + + def forward(self, x): + x = self.forward_features(x) + + # implicit keypoints + kp = self.fc_kp(x) + + # pose and expression deformation + pitch = self.fc_pitch(x) + yaw = self.fc_yaw(x) + roll = self.fc_roll(x) + t = self.fc_t(x) + exp = self.fc_exp(x) + scale = self.fc_scale(x) + + ret_dct = { + 'pitch': pitch, + 'yaw': yaw, + 'roll': roll, + 't': t, + 'exp': exp, + 'scale': scale, + + 'kp': kp, # canonical keypoint + } + + return ret_dct + + +def convnextv2_tiny(**kwargs): + model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) + return model
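With the models.yaml values (num_kp: 21, so 3 * 21 = 63 outputs for kp and exp), a quick shape check of the backbone's output dict; this sketch assumes the repo's .util module supplies the imported helpers, and the import path is illustrative:

```python
import torch
from src.modules.convnextv2 import convnextv2_tiny  # hypothetical import path

net = convnextv2_tiny(num_kp=21, num_bins=66).eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 256, 256))
print(out['kp'].shape, out['exp'].shape)       # torch.Size([1, 63]) for both
print(out['pitch'].shape, out['scale'].shape)  # torch.Size([1, 66]), torch.Size([1, 1])
```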
__init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, estimate_occlusion_map=True): + super(DenseMotionNetwork, self).__init__() + self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks) # ~60+G + + self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3) # 65G! NOTE: computation cost is large + self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1) # 0.8G + self.norm = nn.BatchNorm3d(compress, affine=True) + self.num_kp = num_kp + self.flag_estimate_occlusion_map = estimate_occlusion_map + + if self.flag_estimate_occlusion_map: + self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3) + else: + self.occlusion = None + + def create_sparse_motions(self, feature, kp_driving, kp_source): + bs, _, d, h, w = feature.shape # (bs, 4, 16, 64, 64) + identity_grid = make_coordinate_grid((d, h, w), ref=kp_source) # (16, 64, 64, 3) + identity_grid = identity_grid.view(1, 1, d, h, w, 3) # (1, 1, d=16, h=64, w=64, 3) + coordinate_grid = identity_grid - kp_driving.view(bs, self.num_kp, 1, 1, 1, 3) + + k = coordinate_grid.shape[1] + + # NOTE: there lacks an one-order flow + driving_to_source = coordinate_grid + kp_source.view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3) + + # adding background feature + identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1) + sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) # (bs, 1+num_kp, d, h, w, 3) + return sparse_motions + + def create_deformed_feature(self, feature, sparse_motions): + bs, _, d, h, w = feature.shape + feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w) + feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w) + sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3) + sparse_deformed = F.grid_sample(feature_repeat, sparse_motions, align_corners=False) + sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w) + + return sparse_deformed + + def create_heatmap_representations(self, feature, kp_driving, kp_source): + spatial_size = feature.shape[3:] # (d=16, h=64, w=64) + gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w) + gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w) + heatmap = gaussian_driving - gaussian_source # (bs, num_kp, d, h, w) + + # adding background feature + zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type()).to(heatmap.device) + heatmap = torch.cat([zeros, heatmap], dim=1) + heatmap = heatmap.unsqueeze(2) # (bs, 1+num_kp, 1, d, h, w) + return heatmap + + def forward(self, feature, kp_driving, kp_source): + bs, _, d, h, w = feature.shape # (bs, 32, 16, 64, 64) + + feature = self.compress(feature) # (bs, 4, 16, 64, 64) + feature = self.norm(feature) # (bs, 4, 16, 64, 64) + feature = F.relu(feature) # (bs, 4, 16, 64, 64) + + out_dict = dict() + + # 1. deform 3d feature + sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source) # (bs, 1+num_kp, d, h, w, 3) + deformed_feature = self.create_deformed_feature(feature, sparse_motion) # (bs, 1+num_kp, c=4, d=16, h=64, w=64) + + # 2. 
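The dense-motion pipeline above (compress -> sparse motions -> deformed features -> heatmaps -> hourglass -> softmax mask) can be smoke-tested in isolation. A hedged sketch, not part of the diff: num_kp=20, compress=4, feature_channel=32 and reshape_depth=16 are inferred from the shape comments in forward() ((1+num_kp)*c = 105), while block_expansion, num_blocks and max_features are illustrative rather than the shipped config.

```python
# Hedged smoke test for DenseMotionNetwork; hyperparameters are illustrative.
import torch
from src.modules.dense_motion import DenseMotionNetwork

net = DenseMotionNetwork(
    block_expansion=32, num_blocks=5, max_features=1024,
    num_kp=20, feature_channel=32, reshape_depth=16, compress=4,
    estimate_occlusion_map=True,
)
feature = torch.randn(1, 32, 16, 64, 64)  # f_s: (bs, c, d, h, w)
kp_s = torch.rand(1, 20, 3) * 2 - 1       # keypoints live in [-1, 1]
kp_d = torch.rand(1, 20, 3) * 2 - 1
with torch.no_grad():
    out = net(feature, kp_driving=kp_d, kp_source=kp_s)
print(out['deformation'].shape)           # torch.Size([1, 16, 64, 64, 3])
print(out['occlusion_map'].shape)         # torch.Size([1, 1, 64, 64])
```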
diff --git a/subprocess/LivePortrait/src/modules/motion_extractor.py b/subprocess/LivePortrait/src/modules/motion_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2982e53c52d9ec1e0bec0453cc05edb51a15d23
--- /dev/null
+++ b/subprocess/LivePortrait/src/modules/motion_extractor.py
@@ -0,0 +1,35 @@
+# coding: utf-8
+
+"""
+Motion extractor (M), which directly predicts the canonical keypoints, head pose and expression deformation of the input image
+"""
+
+from torch import nn
+import torch
+
+from .convnextv2 import convnextv2_tiny
+from .util import filter_state_dict
+
+model_dict = {
+    'convnextv2_tiny': convnextv2_tiny,
+}
+
+
+class MotionExtractor(nn.Module):
+    def __init__(self, **kwargs):
+        super(MotionExtractor, self).__init__()
+
+        # the default backbone is convnextv2_tiny
+        backbone = kwargs.get('backbone', 'convnextv2_tiny')
+        self.detector = model_dict.get(backbone)(**kwargs)
+
+    def load_pretrained(self, init_path: str):
+        if init_path not in (None, ''):
+            state_dict = torch.load(init_path, map_location=lambda storage, loc: storage)['model']
+            state_dict = filter_state_dict(state_dict, remove_name='head')
+            ret = self.detector.load_state_dict(state_dict, strict=False)
+            print(f'Load pretrained model from {init_path}, ret: {ret}')
+
+    def forward(self, x):
+        out = self.detector(x)
+        return out
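Since MotionExtractor forwards all keyword arguments to the backbone, the head sizes can be checked directly. A hedged sketch, not part of the diff; num_kp=21 is illustrative (the backbone default is 24), and the import path assumes the repo root is on PYTHONPATH.

```python
# Hedged smoke test for MotionExtractor with the convnextv2_tiny backbone.
import torch
from src.modules.motion_extractor import MotionExtractor

net = MotionExtractor(backbone='convnextv2_tiny', num_kp=21, num_bins=66)
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    out = net(x)
print(out['kp'].shape)     # torch.Size([1, 63]), i.e. 3 * num_kp
print(out['pitch'].shape)  # torch.Size([1, 66]), one logit per angle bin
```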
+""" + +import torch +from torch import nn +import torch.nn.functional as F +from .util import SPADEResnetBlock + + +class SPADEDecoder(nn.Module): + def __init__(self, upscale=1, max_features=256, block_expansion=64, out_channels=64, num_down_blocks=2): + for i in range(num_down_blocks): + input_channels = min(max_features, block_expansion * (2 ** (i + 1))) + self.upscale = upscale + super().__init__() + norm_G = 'spadespectralinstance' + label_num_channels = input_channels # 256 + + self.fc = nn.Conv2d(input_channels, 2 * input_channels, 3, padding=1) + self.G_middle_0 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) + self.G_middle_1 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) + self.G_middle_2 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) + self.G_middle_3 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) + self.G_middle_4 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) + self.G_middle_5 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) + self.up_0 = SPADEResnetBlock(2 * input_channels, input_channels, norm_G, label_num_channels) + self.up_1 = SPADEResnetBlock(input_channels, out_channels, norm_G, label_num_channels) + self.up = nn.Upsample(scale_factor=2) + + if self.upscale is None or self.upscale <= 1: + self.conv_img = nn.Conv2d(out_channels, 3, 3, padding=1) + else: + self.conv_img = nn.Sequential( + nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1), + nn.PixelShuffle(upscale_factor=2) + ) + + def forward(self, feature): + seg = feature # Bx256x64x64 + x = self.fc(feature) # Bx512x64x64 + x = self.G_middle_0(x, seg) + x = self.G_middle_1(x, seg) + x = self.G_middle_2(x, seg) + x = self.G_middle_3(x, seg) + x = self.G_middle_4(x, seg) + x = self.G_middle_5(x, seg) + + x = self.up(x) # Bx512x64x64 -> Bx512x128x128 + x = self.up_0(x, seg) # Bx512x128x128 -> Bx256x128x128 + x = self.up(x) # Bx256x128x128 -> Bx256x256x256 + x = self.up_1(x, seg) # Bx256x256x256 -> Bx64x256x256 + + x = self.conv_img(F.leaky_relu(x, 2e-1)) # Bx64x256x256 -> Bx3xHxW + x = torch.sigmoid(x) # Bx3xHxW + + return x \ No newline at end of file diff --git a/subprocess/LivePortrait/src/modules/stitching_retargeting_network.py b/subprocess/LivePortrait/src/modules/stitching_retargeting_network.py new file mode 100644 index 0000000000000000000000000000000000000000..5f50b7cf5a21cd71c70a7bbaaa4b6b68b4762ea3 --- /dev/null +++ b/subprocess/LivePortrait/src/modules/stitching_retargeting_network.py @@ -0,0 +1,38 @@ +# coding: utf-8 + +""" +Stitching module(S) and two retargeting modules(R) defined in the paper. + +- The stitching module pastes the animated portrait back into the original image space without pixel misalignment, such as in +the stitching region. + +- The eyes retargeting module is designed to address the issue of incomplete eye closure during cross-id reenactment, especially +when a person with small eyes drives a person with larger eyes. + +- The lip retargeting module is designed similarly to the eye retargeting module, and can also normalize the input by ensuring that +the lips are in a closed state, which facilitates better animation driving. 
+""" +from torch import nn + + +class StitchingRetargetingNetwork(nn.Module): + def __init__(self, input_size, hidden_sizes, output_size): + super(StitchingRetargetingNetwork, self).__init__() + layers = [] + for i in range(len(hidden_sizes)): + if i == 0: + layers.append(nn.Linear(input_size, hidden_sizes[i])) + else: + layers.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i])) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Linear(hidden_sizes[-1], output_size)) + self.mlp = nn.Sequential(*layers) + + def initialize_weights_to_zero(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.weight) + nn.init.zeros_(m.bias) + + def forward(self, x): + return self.mlp(x) diff --git a/subprocess/LivePortrait/src/modules/util.py b/subprocess/LivePortrait/src/modules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..f83980b24372bee38779ceeb3349fca91735e56e --- /dev/null +++ b/subprocess/LivePortrait/src/modules/util.py @@ -0,0 +1,441 @@ +# coding: utf-8 + +""" +This file defines various neural network modules and utility functions, including convolutional and residual blocks, +normalizations, and functions for spatial transformation and tensor manipulation. +""" + +from torch import nn +import torch.nn.functional as F +import torch +import torch.nn.utils.spectral_norm as spectral_norm +import math +import warnings + + +def kp2gaussian(kp, spatial_size, kp_variance): + """ + Transform a keypoint into gaussian like representation + """ + mean = kp + + coordinate_grid = make_coordinate_grid(spatial_size, mean) + number_of_leading_dimensions = len(mean.shape) - 1 + shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape + coordinate_grid = coordinate_grid.view(*shape) + repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1) + coordinate_grid = coordinate_grid.repeat(*repeats) + + # Preprocess kp shape + shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3) + mean = mean.view(*shape) + + mean_sub = (coordinate_grid - mean) + + out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance) + + return out + + +def make_coordinate_grid(spatial_size, ref, **kwargs): + d, h, w = spatial_size + x = torch.arange(w).type(ref.dtype).to(ref.device) + y = torch.arange(h).type(ref.dtype).to(ref.device) + z = torch.arange(d).type(ref.dtype).to(ref.device) + + # NOTE: must be right-down-in + x = (2 * (x / (w - 1)) - 1) # the x axis faces to the right + y = (2 * (y / (h - 1)) - 1) # the y axis faces to the bottom + z = (2 * (z / (d - 1)) - 1) # the z axis faces to the inner + + yy = y.view(1, -1, 1).repeat(d, 1, w) + xx = x.view(1, 1, -1).repeat(d, h, 1) + zz = z.view(-1, 1, 1).repeat(1, h, w) + + meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3) + + return meshed + + +class ConvT2d(nn.Module): + """ + Upsampling block for use in decoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1): + super(ConvT2d, self).__init__() + + self.convT = nn.ConvTranspose2d(in_features, out_features, kernel_size=kernel_size, stride=stride, + padding=padding, output_padding=output_padding) + self.norm = nn.InstanceNorm2d(out_features) + + def forward(self, x): + out = self.convT(x) + out = self.norm(out) + out = F.leaky_relu(out) + return out + + +class ResBlock3d(nn.Module): + """ + Res block, preserve spatial resolution. 
+ """ + + def __init__(self, in_features, kernel_size, padding): + super(ResBlock3d, self).__init__() + self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding) + self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding) + self.norm1 = nn.BatchNorm3d(in_features, affine=True) + self.norm2 = nn.BatchNorm3d(in_features, affine=True) + + def forward(self, x): + out = self.norm1(x) + out = F.relu(out) + out = self.conv1(out) + out = self.norm2(out) + out = F.relu(out) + out = self.conv2(out) + out += x + return out + + +class UpBlock3d(nn.Module): + """ + Upsampling block for use in decoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(UpBlock3d, self).__init__() + + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = nn.BatchNorm3d(out_features, affine=True) + + def forward(self, x): + out = F.interpolate(x, scale_factor=(1, 2, 2)) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + + +class DownBlock2d(nn.Module): + """ + Downsampling block for use in encoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) + self.norm = nn.BatchNorm2d(out_features, affine=True) + self.pool = nn.AvgPool2d(kernel_size=(2, 2)) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.pool(out) + return out + + +class DownBlock3d(nn.Module): + """ + Downsampling block for use in encoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(DownBlock3d, self).__init__() + ''' + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups, stride=(1, 2, 2)) + ''' + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = nn.BatchNorm3d(out_features, affine=True) + self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2)) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.pool(out) + return out + + +class SameBlock2d(nn.Module): + """ + Simple block, preserve spatial resolution. 
+ """ + + def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False): + super(SameBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) + self.norm = nn.BatchNorm2d(out_features, affine=True) + if lrelu: + self.ac = nn.LeakyReLU() + else: + self.ac = nn.ReLU() + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = self.ac(out) + return out + + +class Encoder(nn.Module): + """ + Hourglass Encoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Encoder, self).__init__() + + down_blocks = [] + for i in range(num_blocks): + down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), min(max_features, block_expansion * (2 ** (i + 1))), kernel_size=3, padding=1)) + self.down_blocks = nn.ModuleList(down_blocks) + + def forward(self, x): + outs = [x] + for down_block in self.down_blocks: + outs.append(down_block(outs[-1])) + return outs + + +class Decoder(nn.Module): + """ + Hourglass Decoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Decoder, self).__init__() + + up_blocks = [] + + for i in range(num_blocks)[::-1]: + in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) + out_filters = min(max_features, block_expansion * (2 ** i)) + up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1)) + + self.up_blocks = nn.ModuleList(up_blocks) + self.out_filters = block_expansion + in_features + + self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1) + self.norm = nn.BatchNorm3d(self.out_filters, affine=True) + + def forward(self, x): + out = x.pop() + for up_block in self.up_blocks: + out = up_block(out) + skip = x.pop() + out = torch.cat([out, skip], dim=1) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + + +class Hourglass(nn.Module): + """ + Hourglass architecture. 
+ """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Hourglass, self).__init__() + self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) + self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) + self.out_filters = self.decoder.out_filters + + def forward(self, x): + return self.decoder(self.encoder(x)) + + +class SPADE(nn.Module): + def __init__(self, norm_nc, label_nc): + super().__init__() + + self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) + nhidden = 128 + + self.mlp_shared = nn.Sequential( + nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1), + nn.ReLU()) + self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) + self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) + + def forward(self, x, segmap): + normalized = self.param_free_norm(x) + segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') + actv = self.mlp_shared(segmap) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + out = normalized * (1 + gamma) + beta + return out + + +class SPADEResnetBlock(nn.Module): + def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1): + super().__init__() + # Attributes + self.learned_shortcut = (fin != fout) + fmiddle = min(fin, fout) + self.use_se = use_se + # create conv layers + self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation) + self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation) + if self.learned_shortcut: + self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) + # apply spectral norm if specified + if 'spectral' in norm_G: + self.conv_0 = spectral_norm(self.conv_0) + self.conv_1 = spectral_norm(self.conv_1) + if self.learned_shortcut: + self.conv_s = spectral_norm(self.conv_s) + # define normalization layers + self.norm_0 = SPADE(fin, label_nc) + self.norm_1 = SPADE(fmiddle, label_nc) + if self.learned_shortcut: + self.norm_s = SPADE(fin, label_nc) + + def forward(self, x, seg1): + x_s = self.shortcut(x, seg1) + dx = self.conv_0(self.actvn(self.norm_0(x, seg1))) + dx = self.conv_1(self.actvn(self.norm_1(dx, seg1))) + out = x_s + dx + return out + + def shortcut(self, x, seg1): + if self.learned_shortcut: + x_s = self.conv_s(self.norm_s(x, seg1)) + else: + x_s = x + return x_s + + def actvn(self, x): + return F.leaky_relu(x, 2e-1) + + +def filter_state_dict(state_dict, remove_name='fc'): + new_state_dict = {} + for key in state_dict: + if remove_name in key: + continue + new_state_dict[key] = state_dict[key] + return new_state_dict + + +class GRN(nn.Module): + """ GRN (Global Response Normalization) layer + """ + + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +class LayerNorm(nn.Module): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). 
+ """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def drop_path(x, drop_prob=0., training=False, scale_by_keep=True): + """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + + def __init__(self, drop_prob=None, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/subprocess/LivePortrait/src/modules/warping_network.py b/subprocess/LivePortrait/src/modules/warping_network.py new file mode 100644 index 0000000000000000000000000000000000000000..9191a197055a954272ee8ed86c5e34f3f33f9ad5 --- /dev/null +++ b/subprocess/LivePortrait/src/modules/warping_network.py @@ -0,0 +1,77 @@ +# coding: utf-8 + +""" +Warping field estimator(W) defined in the paper, which generates a warping field using the implicit +keypoint representations x_s and x_d, and employs this flow field to warp the source feature volume f_s. +""" + +from torch import nn +import torch.nn.functional as F +from .util import SameBlock2d +from .dense_motion import DenseMotionNetwork + + +class WarpingNetwork(nn.Module): + def __init__( + self, + num_kp, + block_expansion, + max_features, + num_down_blocks, + reshape_channel, + estimate_occlusion_map=False, + dense_motion_params=None, + **kwargs + ): + super(WarpingNetwork, self).__init__() + + self.upscale = kwargs.get('upscale', 1) + self.flag_use_occlusion_map = kwargs.get('flag_use_occlusion_map', True) + + if dense_motion_params is not None: + self.dense_motion_network = DenseMotionNetwork( + num_kp=num_kp, + feature_channel=reshape_channel, + estimate_occlusion_map=estimate_occlusion_map, + **dense_motion_params + ) + else: + self.dense_motion_network = None + + self.third = SameBlock2d(max_features, block_expansion * (2 ** num_down_blocks), kernel_size=(3, 3), padding=(1, 1), lrelu=True) + self.fourth = nn.Conv2d(in_channels=block_expansion * (2 ** num_down_blocks), out_channels=block_expansion * (2 ** num_down_blocks), kernel_size=1, stride=1) + + self.estimate_occlusion_map = estimate_occlusion_map + + def deform_input(self, inp, deformation): + return F.grid_sample(inp, deformation, align_corners=False) + + def forward(self, feature_3d, kp_driving, kp_source): + if self.dense_motion_network is not None: + # Feature warper, Transforming feature representation according to deformation and occlusion + dense_motion = self.dense_motion_network( + feature=feature_3d, kp_driving=kp_driving, kp_source=kp_source + ) + if 'occlusion_map' in dense_motion: + occlusion_map = dense_motion['occlusion_map'] # Bx1x64x64 + else: + occlusion_map = None + + deformation = dense_motion['deformation'] # Bx16x64x64x3 + out = self.deform_input(feature_3d, deformation) # Bx32x16x64x64 + + bs, c, d, h, w = out.shape # Bx32x16x64x64 + out = out.view(bs, c * d, h, w) # -> Bx512x64x64 + out = self.third(out) # -> Bx256x64x64 + out = self.fourth(out) # -> Bx256x64x64 + + if self.flag_use_occlusion_map and (occlusion_map is not None): + out = out * occlusion_map + + ret_dct = { + 'occlusion_map': occlusion_map, + 'deformation': deformation, + 'out': out, + } + + return ret_dct diff --git a/subprocess/LivePortrait/src/utils/__init__.py b/subprocess/LivePortrait/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/subprocess/LivePortrait/src/utils/__pycache__/__init__.cpython-310.pyc b/subprocess/LivePortrait/src/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..ee1b01adc1301977b13390a8eb96e974a4db2fc6 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/__init__.cpython-39.pyc b/subprocess/LivePortrait/src/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ef2e63a298f508062a9cb4f024075962e1ef335 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/camera.cpython-310.pyc b/subprocess/LivePortrait/src/utils/__pycache__/camera.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7d49cf3d36d2e08cc6c7465b4c519734390d70e Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/camera.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/camera.cpython-39.pyc b/subprocess/LivePortrait/src/utils/__pycache__/camera.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecda11528c7ab793603e4cf657d4d3cac2cea8c1 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/camera.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/crop.cpython-310.pyc b/subprocess/LivePortrait/src/utils/__pycache__/crop.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01fa177b0f9b398a2474ddb1571210bed56ce68c Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/crop.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/crop.cpython-39.pyc b/subprocess/LivePortrait/src/utils/__pycache__/crop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce2d75f950bdb1db08e91a65a0e005375b6f74be Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/crop.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/cropper.cpython-310.pyc b/subprocess/LivePortrait/src/utils/__pycache__/cropper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b853cd4e4ef41571b5d8250cfc4b8698a6da6a32 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/cropper.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/cropper.cpython-39.pyc b/subprocess/LivePortrait/src/utils/__pycache__/cropper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3a0ffbc0f013fa466c59729f9becc3446b58b5a Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/cropper.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/face_analysis_diy.cpython-310.pyc b/subprocess/LivePortrait/src/utils/__pycache__/face_analysis_diy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76812d9f695c0487641d6904c0b817d50b169ad4 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/face_analysis_diy.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/face_analysis_diy.cpython-39.pyc b/subprocess/LivePortrait/src/utils/__pycache__/face_analysis_diy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b1f1dd2538d73b49a49cf400617f48366e607e6 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/face_analysis_diy.cpython-39.pyc differ 
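Stepping back from the compiled-cache churn above: the WarpingNetwork added in warping_network.py earlier chains the dense-motion estimator, the 3D-to-2D reshape and the occlusion gating. A hedged smoke test, not part of the diff; every hyperparameter below is illustrative and chosen only for internal consistency (reshape_channel * reshape_depth = 32 * 16 must equal max_features = 512 so the (bs, c, d, h, w) -> (bs, c*d, h, w) reshape feeds self.third), not the shipped config.

```python
# Hedged smoke test for WarpingNetwork; hyperparameters are illustrative.
import torch
from src.modules.warping_network import WarpingNetwork

net = WarpingNetwork(
    num_kp=20, block_expansion=64, max_features=512, num_down_blocks=3,
    reshape_channel=32, estimate_occlusion_map=True,
    dense_motion_params=dict(block_expansion=32, max_features=1024,
                             num_blocks=5, reshape_depth=16, compress=4),
)
f_s = torch.randn(1, 32, 16, 64, 64)  # source feature volume
kp_s = torch.rand(1, 20, 3) * 2 - 1
kp_d = torch.rand(1, 20, 3) * 2 - 1
with torch.no_grad():
    ret = net(f_s, kp_driving=kp_d, kp_source=kp_s)
print(ret['out'].shape)            # torch.Size([1, 512, 64, 64])
print(ret['occlusion_map'].shape)  # torch.Size([1, 1, 64, 64])
```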
diff --git a/subprocess/LivePortrait/src/utils/__pycache__/helper.cpython-310.pyc b/subprocess/LivePortrait/src/utils/__pycache__/helper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c31fe754bb833000e2dfa6434e396f0fcfa2736 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/helper.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/helper.cpython-39.pyc b/subprocess/LivePortrait/src/utils/__pycache__/helper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..461d2acb1e1f3bb0635655fd3e15a9cabe52c04b Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/helper.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/io.cpython-310.pyc b/subprocess/LivePortrait/src/utils/__pycache__/io.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d84ef682801ad4aeeff54a49eb08c4bd800a5b0d Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/io.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/io.cpython-39.pyc b/subprocess/LivePortrait/src/utils/__pycache__/io.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db0e5ddbb944eb44b92b4fc5022aba16e6eddef4 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/io.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/landmark_runner.cpython-310.pyc b/subprocess/LivePortrait/src/utils/__pycache__/landmark_runner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62165531a6125bbc11da885cde30b874b2274ffd Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/landmark_runner.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/landmark_runner.cpython-39.pyc b/subprocess/LivePortrait/src/utils/__pycache__/landmark_runner.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4731971977b04a80821b7f4cd5b36d7e993fbec4 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/landmark_runner.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/retargeting_utils.cpython-310.pyc b/subprocess/LivePortrait/src/utils/__pycache__/retargeting_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..897354aea89c2a7ec5c140d4651f8ca416d27985 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/retargeting_utils.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/retargeting_utils.cpython-39.pyc b/subprocess/LivePortrait/src/utils/__pycache__/retargeting_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..135a2c881e60f185c121c64cf374f5c8883d526d Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/retargeting_utils.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/rprint.cpython-310.pyc b/subprocess/LivePortrait/src/utils/__pycache__/rprint.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edc64d81b11fe94116ab90cc42de6a9ff307ea10 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/rprint.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/rprint.cpython-39.pyc b/subprocess/LivePortrait/src/utils/__pycache__/rprint.cpython-39.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..f93e79edf8ed20c2a783c3edb9eeec846e51a175 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/rprint.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/timer.cpython-310.pyc b/subprocess/LivePortrait/src/utils/__pycache__/timer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47ffdeb755cc0f5bd8b1f7fb4067f720caad936a Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/timer.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/timer.cpython-39.pyc b/subprocess/LivePortrait/src/utils/__pycache__/timer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae03ab9fec2329866fa89f3f7f47ed6048cf7c8c Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/timer.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/video.cpython-310.pyc b/subprocess/LivePortrait/src/utils/__pycache__/video.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..998449fe0a48928a7bf1176e00767476b56f8207 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/video.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/__pycache__/video.cpython-39.pyc b/subprocess/LivePortrait/src/utils/__pycache__/video.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc2bafdb9a766783fddd7302ee19e28f3af2b239 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/__pycache__/video.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/camera.py b/subprocess/LivePortrait/src/utils/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..a3dd942697e1f00a96dc3efc75b883d98b52e525 --- /dev/null +++ b/subprocess/LivePortrait/src/utils/camera.py @@ -0,0 +1,73 @@ +# coding: utf-8 + +""" +functions for processing and transforming 3D facial keypoints +""" + +import numpy as np +import torch +import torch.nn.functional as F + +PI = np.pi + + +def headpose_pred_to_degree(pred): + """ + pred: (bs, 66) or (bs, 1) or others + """ + if pred.ndim > 1 and pred.shape[1] == 66: + # NOTE: note that the average is modified to 97.5 + device = pred.device + idx_tensor = [idx for idx in range(0, 66)] + idx_tensor = torch.FloatTensor(idx_tensor).to(device) + pred = F.softmax(pred, dim=1) + degree = torch.sum(pred*idx_tensor, axis=1) * 3 - 97.5 + + return degree + + return pred + + +def get_rotation_matrix(pitch_, yaw_, roll_): + """ the input is in degree + """ + # transform to radian + pitch = pitch_ / 180 * PI + yaw = yaw_ / 180 * PI + roll = roll_ / 180 * PI + + device = pitch.device + + if pitch.ndim == 1: + pitch = pitch.unsqueeze(1) + if yaw.ndim == 1: + yaw = yaw.unsqueeze(1) + if roll.ndim == 1: + roll = roll.unsqueeze(1) + + # calculate the euler matrix + bs = pitch.shape[0] + ones = torch.ones([bs, 1]).to(device) + zeros = torch.zeros([bs, 1]).to(device) + x, y, z = pitch, yaw, roll + + rot_x = torch.cat([ + ones, zeros, zeros, + zeros, torch.cos(x), -torch.sin(x), + zeros, torch.sin(x), torch.cos(x) + ], dim=1).reshape([bs, 3, 3]) + + rot_y = torch.cat([ + torch.cos(y), zeros, torch.sin(y), + zeros, ones, zeros, + -torch.sin(y), zeros, torch.cos(y) + ], dim=1).reshape([bs, 3, 3]) + + rot_z = torch.cat([ + torch.cos(z), -torch.sin(z), zeros, + torch.sin(z), torch.cos(z), zeros, + zeros, zeros, ones + ], dim=1).reshape([bs, 3, 3]) + + 
rot = rot_z @ rot_y @ rot_x
+    return rot.permute(0, 2, 1)  # transpose
diff --git a/subprocess/LivePortrait/src/utils/crop.py b/subprocess/LivePortrait/src/utils/crop.py
new file mode 100644
index 0000000000000000000000000000000000000000..065b9f0f9f25be8444b7c9bfca45652f80f5685b
--- /dev/null
+++ b/subprocess/LivePortrait/src/utils/crop.py
@@ -0,0 +1,398 @@
+# coding: utf-8
+
+"""
+cropping functions and the related preprocessing utilities
+"""
+
+import numpy as np
+import os.path as osp
+from math import sin, cos, acos, degrees
+import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)  # NOTE: enforce single thread
+from .rprint import rprint as print
+
+DTYPE = np.float32
+CV2_INTERP = cv2.INTER_LINEAR
+
+def make_abs_path(fn):
+    return osp.join(osp.dirname(osp.realpath(__file__)), fn)
+
+def _transform_img(img, M, dsize, flags=CV2_INTERP, borderMode=None):
+    """ apply a similarity or affine transform to the image; no border handling is performed
+    img:
+    M: 2x3 matrix or 3x3 matrix
+    dsize: target shape (width, height)
+    """
+    if isinstance(dsize, tuple) or isinstance(dsize, list):
+        _dsize = tuple(dsize)
+    else:
+        _dsize = (dsize, dsize)
+
+    if borderMode is not None:
+        return cv2.warpAffine(img, M[:2, :], dsize=_dsize, flags=flags, borderMode=borderMode, borderValue=(0, 0, 0))
+    else:
+        return cv2.warpAffine(img, M[:2, :], dsize=_dsize, flags=flags)
+
+
+def _transform_pts(pts, M):
+    """ apply a similarity or affine transform to the points
+    pts: Nx2 ndarray
+    M: 2x3 matrix or 3x3 matrix
+    return: Nx2
+    """
+    return pts @ M[:2, :2].T + M[:2, 2]
+
+
+def parse_pt2_from_pt101(pt101, use_lip=True):
+    """
+    parse the two anchor points from the 101 landmarks; the axis they define cancels head roll
+    """
+    # the former version used the eye centers directly, which was not robust; now use interpolation
+    pt_left_eye = np.mean(pt101[[39, 42, 45, 48]], axis=0)  # left eye center
+    pt_right_eye = np.mean(pt101[[51, 54, 57, 60]], axis=0)  # right eye center
+
+    if use_lip:
+        # use the lip
+        pt_center_eye = (pt_left_eye + pt_right_eye) / 2
+        pt_center_lip = (pt101[75] + pt101[81]) / 2
+        pt2 = np.stack([pt_center_eye, pt_center_lip], axis=0)
+    else:
+        pt2 = np.stack([pt_left_eye, pt_right_eye], axis=0)
+    return pt2
+
+
+def parse_pt2_from_pt106(pt106, use_lip=True):
+    """
+    parse the two anchor points from the 106 landmarks; the axis they define cancels head roll
+    """
+    pt_left_eye = np.mean(pt106[[33, 35, 40, 39]], axis=0)  # left eye center
+    pt_right_eye = np.mean(pt106[[87, 89, 94, 93]], axis=0)  # right eye center
+
+    if use_lip:
+        # use the lip
+        pt_center_eye = (pt_left_eye + pt_right_eye) / 2
+        pt_center_lip = (pt106[52] + pt106[61]) / 2
+        pt2 = np.stack([pt_center_eye, pt_center_lip], axis=0)
+    else:
+        pt2 = np.stack([pt_left_eye, pt_right_eye], axis=0)
+    return pt2
+
+
+def parse_pt2_from_pt203(pt203, use_lip=True):
+    """
+    parse the two anchor points from the 203 landmarks; the axis they define cancels head roll
+    """
+    pt_left_eye = np.mean(pt203[[0, 6, 12, 18]], axis=0)  # left eye center
+    pt_right_eye = np.mean(pt203[[24, 30, 36, 42]], axis=0)  # right eye center
+    if use_lip:
+        # use the lip
+        pt_center_eye = (pt_left_eye + pt_right_eye) / 2
+        pt_center_lip = (pt203[48] + pt203[66]) / 2
+        pt2 = np.stack([pt_center_eye, pt_center_lip], axis=0)
+    else:
+        pt2 = np.stack([pt_left_eye, pt_right_eye], axis=0)
+    return pt2
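The two-point parse is easiest to see on the 5-point layout handled by parse_pt2_from_pt5 and the dispatcher parse_pt2_from_pt_x below: with use_lip=True the result is the eye midpoint and the lip midpoint, the axis used downstream to cancel head roll. A small, hedged illustration with made-up coordinates, not part of the diff:

```python
# Hedged illustration of the two-point parse on a 5-point layout:
# left eye, right eye, nose, left lip corner, right lip corner.
import numpy as np
from src.utils.crop import parse_pt2_from_pt_x

pt5 = np.array([[80., 100.], [120., 100.],   # eyes
                [100., 120.],                 # nose
                [90., 140.], [110., 140.]],   # lip corners
               dtype=np.float32)
pt2 = parse_pt2_from_pt_x(pt5, use_lip=True)
print(pt2)  # [[100. 100.]   eye midpoint
            #  [100. 140.]]  lip midpoint; this vertical axis cancels roll
```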
+def parse_pt2_from_pt68(pt68, use_lip=True):
+    """
+    parse the two anchor points from the 68 landmarks; the axis they define cancels head roll
+    """
+    lm_idx = np.array([31, 37, 40, 43, 46, 49, 55], dtype=np.int32) - 1
+    if use_lip:
+        pt5 = np.stack([
+            np.mean(pt68[lm_idx[[1, 2]], :], 0),  # left eye
+            np.mean(pt68[lm_idx[[3, 4]], :], 0),  # right eye
+            pt68[lm_idx[0], :],  # nose
+            pt68[lm_idx[5], :],  # lip
+            pt68[lm_idx[6], :]  # lip
+        ], axis=0)
+
+        pt2 = np.stack([
+            (pt5[0] + pt5[1]) / 2,
+            (pt5[3] + pt5[4]) / 2
+        ], axis=0)
+    else:
+        pt2 = np.stack([
+            np.mean(pt68[lm_idx[[1, 2]], :], 0),  # left eye
+            np.mean(pt68[lm_idx[[3, 4]], :], 0),  # right eye
+        ], axis=0)
+
+    return pt2
+
+
+def parse_pt2_from_pt5(pt5, use_lip=True):
+    """
+    parse the two anchor points from the 5 landmarks; the axis they define cancels head roll
+    """
+    if use_lip:
+        pt2 = np.stack([
+            (pt5[0] + pt5[1]) / 2,
+            (pt5[3] + pt5[4]) / 2
+        ], axis=0)
+    else:
+        pt2 = np.stack([
+            pt5[0],
+            pt5[1]
+        ], axis=0)
+    return pt2
+
+
+def parse_pt2_from_pt_x(pts, use_lip=True):
+    if pts.shape[0] == 101:
+        pt2 = parse_pt2_from_pt101(pts, use_lip=use_lip)
+    elif pts.shape[0] == 106:
+        pt2 = parse_pt2_from_pt106(pts, use_lip=use_lip)
+    elif pts.shape[0] == 68:
+        pt2 = parse_pt2_from_pt68(pts, use_lip=use_lip)
+    elif pts.shape[0] == 5:
+        pt2 = parse_pt2_from_pt5(pts, use_lip=use_lip)
+    elif pts.shape[0] == 203:
+        pt2 = parse_pt2_from_pt203(pts, use_lip=use_lip)
+    elif pts.shape[0] > 101:
+        # take the first 101 points
+        pt2 = parse_pt2_from_pt101(pts[:101], use_lip=use_lip)
+    else:
+        raise Exception(f'Unknown shape: {pts.shape}')
+
+    if not use_lip:
+        # NOTE: to be compatible with the code below, rotate pt2 by 90 degrees clockwise manually
+        v = pt2[1] - pt2[0]
+        pt2[1, 0] = pt2[0, 0] - v[1]
+        pt2[1, 1] = pt2[0, 1] + v[0]
+
+    return pt2
+
+
+def parse_rect_from_landmark(
+    pts,
+    scale=1.5,
+    need_square=True,
+    vx_ratio=0,
+    vy_ratio=0,
+    use_deg_flag=False,
+    **kwargs
+):
+    """parse center, size and angle from 101/68/5/x landmarks
+    vx_ratio: the offset ratio along the pupil-axis x-axis, multiplied by size
+    vy_ratio: the offset ratio along the pupil-axis y-axis, multiplied by size, which is used to contain more forehead area
+
+    the landmark layout is inferred from pts.shape
+    """
+    pt2 = parse_pt2_from_pt_x(pts, use_lip=kwargs.get('use_lip', True))
+
+    uy = pt2[1] - pt2[0]
+    l = np.linalg.norm(uy)
+    if l <= 1e-3:
+        uy = np.array([0, 1], dtype=DTYPE)
+    else:
+        uy /= l
+    ux = np.array((uy[1], -uy[0]), dtype=DTYPE)
+
+    # the rotation angle of the x-axis; clockwise is positive, counterclockwise is negative (image coordinate system)
+    # print(uy)
+    # print(ux)
+    angle = acos(ux[0])
+    if ux[1] < 0:
+        angle = -angle
+
+    # rotation matrix
+    M = np.array([ux, uy])
+
+    # calculate the size that contains the rotated bbox, and the center
+    center0 = np.mean(pts, axis=0)
+    rpts = (pts - center0) @ M.T  # (M @ P.T).T = P @ M.T
+    lt_pt = np.min(rpts, axis=0)
+    rb_pt = np.max(rpts, axis=0)
+    center1 = (lt_pt + rb_pt) / 2
+
+    size = rb_pt - lt_pt
+    if need_square:
+        m = max(size[0], size[1])
+        size[0] = m
+        size[1] = m
+
+    size *= scale  # scale size
+    center = center0 + ux * center1[0] + uy * center1[1]  # counterclockwise rotation, equivalent to M.T @ center1.T
+    center = center + ux * (vx_ratio * size) + uy * \
+        (vy_ratio * size)  # considering the offset in the vx and vy directions
+
+    if use_deg_flag:
+        angle = degrees(angle)
+
+    return center, size, angle
+
+
+def parse_bbox_from_landmark(pts, **kwargs):
+    center, size, angle = parse_rect_from_landmark(pts, **kwargs)
+    cx, cy = center
+    w, h = size
+
+    # calculate the vertex positions before rotation
+    bbox = np.array([
+        [cx-w/2, cy-h/2],  # left, top
+        [cx+w/2, cy-h/2],
+        [cx+w/2, cy+h/2],  # right, bottom
+        [cx-w/2, cy+h/2]
+    ], dtype=DTYPE)
+
+    # construct the rotation matrix
+    bbox_rot = bbox.copy()
+    R = np.array([
+        [np.cos(angle), -np.sin(angle)],
+        [np.sin(angle), np.cos(angle)]
+    ], dtype=DTYPE)
+
+    # take the offset of each vertex from the rotation center, rotate it, then add the center coordinates back
+    bbox_rot = (bbox_rot - center) @ R.T + center
+
+    return {
+        'center': center,  # (2,)
+        'size': size,  # (2,)
+        'angle': angle,  # rad, counterclockwise
+        'bbox': bbox,  # 4x2
+        'bbox_rot': bbox_rot,  # 4x2
+    }
+
+
+def crop_image_by_bbox(img, bbox, lmk=None, dsize=512, angle=None, flag_rot=False, **kwargs):
+    left, top, right, bot = bbox
+    if int(right - left) != int(bot - top):
+        print(f'right-left {right-left} != bot-top {bot-top}')
+    size = right - left
+
+    src_center = np.array([(left + right) / 2, (top + bot) / 2], dtype=DTYPE)
+    tgt_center = np.array([dsize / 2, dsize / 2], dtype=DTYPE)
+
+    s = dsize / size  # scale
+    if flag_rot and angle is not None:
+        costheta, sintheta = cos(angle), sin(angle)
+        cx, cy = src_center[0], src_center[1]  # ori center
+        tcx, tcy = tgt_center[0], tgt_center[1]  # target center
+        # similarity transform with the rotation folded in
+        M_o2c = np.array(
+            [[s * costheta, s * sintheta, tcx - s * (costheta * cx + sintheta * cy)],
+             [-s * sintheta, s * costheta, tcy - s * (-sintheta * cx + costheta * cy)]],
+            dtype=DTYPE
+        )
+    else:
+        M_o2c = np.array(
+            [[s, 0, tgt_center[0] - s * src_center[0]],
+             [0, s, tgt_center[1] - s * src_center[1]]],
+            dtype=DTYPE
+        )
+
+    # if flag_rot and angle is None:
+    #     print('angle is None, but flag_rotate is True', style="bold yellow")
+
+    img_crop = _transform_img(img, M_o2c, dsize=dsize, borderMode=kwargs.get('borderMode', None))
+    lmk_crop = _transform_pts(lmk, M_o2c) if lmk is not None else None
+
+    M_o2c = np.vstack([M_o2c, np.array([0, 0, 1], dtype=DTYPE)])
+    M_c2o = np.linalg.inv(M_o2c)
+
+    # cv2.imwrite('crop.jpg', img_crop)
+
+    return {
+        'img_crop': img_crop,
+        'lmk_crop': lmk_crop,
+        'M_o2c': M_o2c,
+        'M_c2o': M_c2o,
+    }
+
+
+def _estimate_similar_transform_from_pts(
+    pts,
+    dsize,
+    scale=1.5,
+    vx_ratio=0,
+    vy_ratio=-0.1,
+    flag_do_rot=True,
+    **kwargs
+):
+    """ estimate the affine matrix from sparse points that maps the original image to the cropped image;
+    the inverse maps the cropped image back to the original image
+    pts: landmarks, 101 or 68 points or other points, Nx2
+    scale: the larger the scale factor, the smaller the face ratio
+    vx_ratio: x shift
+    vy_ratio: y shift; the smaller the y shift, the lower the face region
+    flag_do_rot: if True, correct for roll
+    """
+    center, size, angle = parse_rect_from_landmark(
+        pts, scale=scale, vx_ratio=vx_ratio, vy_ratio=vy_ratio,
+        use_lip=kwargs.get('use_lip', True)
+    )
+
+    s = dsize / size[0]  # scale
+    tgt_center = np.array([dsize / 2, dsize / 2], dtype=DTYPE)  # center of dsize
+
+    if flag_do_rot:
+        costheta, sintheta = cos(angle), sin(angle)
+        cx, cy = center[0], center[1]  # ori center
+        tcx, tcy = tgt_center[0], tgt_center[1]  # target center
+        # similarity transform with the rotation folded in
+        M_INV = np.array(
+            [[s * costheta, s * sintheta, tcx - s * (costheta * cx + sintheta * cy)],
+             [-s * sintheta, s * costheta, tcy - s * (-sintheta * cx + costheta * cy)]],
+            dtype=DTYPE
+        )
+    else:
+        M_INV = np.array(
+            [[s, 0, tgt_center[0] - s * center[0]],
+             [0, s, tgt_center[1] - s * center[1]]],
+            dtype=DTYPE
+        )
+
+    M_INV_H = np.vstack([M_INV, np.array([0, 0, 1])])
+    M = np.linalg.inv(M_INV_H)
+
+    # M_INV maps the original image to the cropped image, M maps the cropped image back to the original image
+    return M_INV, M[:2, ...]
+
+
+def crop_image(img, pts: np.ndarray, **kwargs):
+    dsize = kwargs.get('dsize', 224)
+    scale = kwargs.get('scale', 1.5)  # 1.5 | 1.6
+    vy_ratio = kwargs.get('vy_ratio', -0.1)  # -0.0625 | -0.1
+
+    M_INV, _ = _estimate_similar_transform_from_pts(
+        pts,
+        dsize=dsize,
+        scale=scale,
+        vx_ratio=kwargs.get('vx_ratio', 0),  # NOTE: forward vx_ratio; it was previously accepted but silently dropped
+        vy_ratio=vy_ratio,
+        flag_do_rot=kwargs.get('flag_do_rot', True),
+    )
+
+    img_crop = _transform_img(img, M_INV, dsize)  # origin to crop
+    pt_crop = _transform_pts(pts, M_INV)
+
+    M_o2c = np.vstack([M_INV, np.array([0, 0, 1], dtype=DTYPE)])
+    M_c2o = np.linalg.inv(M_o2c)
+
+    ret_dct = {
+        'M_o2c': M_o2c,  # from the original image to the cropped image, 3x3
+        'M_c2o': M_c2o,  # from the cropped image to the original image, 3x3
+        'img_crop': img_crop,  # the cropped image
+        'pt_crop': pt_crop,  # the landmarks of the cropped image
+    }
+
+    return ret_dct
+
+def average_bbox_lst(bbox_lst):
+    if len(bbox_lst) == 0:
+        return None
+    bbox_arr = np.array(bbox_lst)
+    return np.mean(bbox_arr, axis=0).tolist()
+
+def prepare_paste_back(mask_crop, crop_M_c2o, dsize):
+    """prepare the mask for the later image paste-back
+    """
+    mask_ori = _transform_img(mask_crop, crop_M_c2o, dsize)
+    mask_ori = mask_ori.astype(np.float32) / 255.
+    return mask_ori
+
+def paste_back(img_crop, M_c2o, img_ori, mask_ori):
+    """paste the cropped image back into the original image
+    """
+    dsize = (img_ori.shape[1], img_ori.shape[0])
+    result = _transform_img(img_crop, M_c2o, dsize=dsize)
+    result = np.clip(mask_ori * result + (1 - mask_ori) * img_ori, 0, 255).astype(np.uint8)
+    return result
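crop.py's pieces compose into a round trip: crop_image produces the crop plus the 3x3 matrices M_o2c / M_c2o, prepare_paste_back warps a crop-space mask into the original frame, and paste_back blends the (possibly edited) crop back. A hedged sketch, not part of the diff; the image and the 106-point landmarks are random stand-ins (in the real pipeline they come from the detector), and scale=2.3 is illustrative.

```python
# Hedged sketch of the crop -> paste-back round trip with placeholder data.
import numpy as np
from src.utils.crop import crop_image, prepare_paste_back, paste_back

img = np.random.randint(0, 255, (720, 1280, 3), dtype=np.uint8)
lmk = np.random.rand(106, 2) * np.array([1280.0, 720.0])  # placeholder landmarks

ret = crop_image(img, lmk, dsize=512, scale=2.3, vy_ratio=-0.125)
crop = ret['img_crop']                              # 512x512 face crop

mask = np.full((512, 512, 3), 255, dtype=np.uint8)  # all-ones paste mask
mask_ori = prepare_paste_back(mask, ret['M_c2o'], dsize=(img.shape[1], img.shape[0]))
restored = paste_back(crop, ret['M_c2o'], img, mask_ori)
assert restored.shape == img.shape                  # blended back at original size
```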
diff --git a/subprocess/LivePortrait/src/utils/cropper.py b/subprocess/LivePortrait/src/utils/cropper.py
new file mode 100644
index 0000000000000000000000000000000000000000..81fe74d4cfe822b8a008e0dab88ec7fb75a52558
--- /dev/null
+++ b/subprocess/LivePortrait/src/utils/cropper.py
@@ -0,0 +1,183 @@
+# coding: utf-8
+
+import numpy as np
+import os.path as osp
+from typing import List, Union, Tuple
+from dataclasses import dataclass, field
+import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
+
+from ..config.crop_config import CropConfig
+from .landmark_runner import LandmarkRunner
+from .face_analysis_diy import FaceAnalysisDIY
+from .crop import crop_image, crop_image_by_bbox, parse_bbox_from_landmark, average_bbox_lst
+from .rprint import rlog as log
+from .io import contiguous
+
+
+def make_abs_path(fn):
+    return osp.join(osp.dirname(osp.realpath(__file__)), fn)
+
+
+@dataclass
+class Trajectory:
+    start: int = -1  # start frame
+    end: int = -1  # end frame
+    lmk_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)  # lmk list
+    bbox_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)  # bbox list
+
+    frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)  # frame list
+
+    lmk_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)  # lmk list
+    frame_rgb_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)  # frame crop list
+
+
+class Cropper(object):
+    def __init__(self, **kwargs) -> None:
+        device_id = kwargs.get('device_id', 0)
+        flag_force_cpu = kwargs.get('flag_force_cpu', False)
+        if flag_force_cpu:
+            device = 'cpu'
+            face_analysis_wrapper_provider = ['CPUExecutionProvider']
+        else:
+            device = 'cuda'
+            face_analysis_wrapper_provider = ["CUDAExecutionProvider"]
+        self.landmark_runner = LandmarkRunner(
+            ckpt_path=make_abs_path('../../pretrained_weights/liveportrait/landmark.onnx'),
+            onnx_provider=device,
+            device_id=device_id
+        )
+        self.landmark_runner.warmup()
+
+        self.face_analysis_wrapper = FaceAnalysisDIY(
+            name='buffalo_l',
+            root=make_abs_path('../../pretrained_weights/insightface'),
+            providers=face_analysis_wrapper_provider
+        )
+        self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512))
+        self.face_analysis_wrapper.warmup()
+
+        self.crop_cfg: CropConfig = kwargs.get('crop_cfg', None)
+
+    def update_config(self, user_args):
+        for k, v in user_args.items():
+            if hasattr(self.crop_cfg, k):
+                setattr(self.crop_cfg, k, v)
+
+    def crop_source_image(self, img_rgb_: np.ndarray, crop_cfg: CropConfig):
+        # crop the source image and gather the necessary information
+        img_rgb = img_rgb_.copy()  # copy it
+
+        img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
+        src_face = self.face_analysis_wrapper.get(
+            img_bgr,
+            flag_do_landmark_2d_106=True,
+            direction=crop_cfg.direction,
+            max_face_num=crop_cfg.max_face_num,
+        )
+
+        if len(src_face) == 0:
+            log('No face detected in the source image.')
+            return None
+        elif len(src_face) > 1:
+            log(f'More than one face detected in the image, only pick one face by rule {crop_cfg.direction}.')
+
+        # NOTE: temporarily pick only the first face; multi-face support may be added in the future
+        src_face = src_face[0]
+        lmk = src_face.landmark_2d_106  # this is the 106-point landmark set from insightface
+
+        # crop the face
+        ret_dct = crop_image(
+            img_rgb,  # ndarray
+            lmk,  # 106x2 or Nx2
+            dsize=crop_cfg.dsize,
+            scale=crop_cfg.scale,
+            vx_ratio=crop_cfg.vx_ratio,
+            vy_ratio=crop_cfg.vy_ratio,
+        )
+
+        lmk = self.landmark_runner.run(img_rgb, lmk)
+        ret_dct['lmk_crop'] = lmk
+
+        # update a 256x256 version for network input
+        ret_dct['img_crop_256x256'] = cv2.resize(ret_dct['img_crop'], (256, 256), interpolation=cv2.INTER_AREA)
+        ret_dct['lmk_crop_256x256'] = ret_dct['lmk_crop'] * 256 / crop_cfg.dsize
+
+        return ret_dct
+
+    def crop_driving_video(self, driving_rgb_lst, **kwargs):
+        """Tracking-based landmark alignment and cropping"""
+        trajectory = Trajectory()
+        direction = kwargs.get('direction', 'large-small')
+        for idx, frame_rgb in enumerate(driving_rgb_lst):
+            if idx == 0 or trajectory.start == -1:
+                src_face = self.face_analysis_wrapper.get(
+                    contiguous(frame_rgb[..., ::-1]),
+                    flag_do_landmark_2d_106=True,
+                    direction=direction
+                )
+                if len(src_face) == 0:
+                    log(f'No face detected in the frame #{idx}')
+                    continue
+                elif len(src_face) > 1:
+                    log(f'More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.')
+                src_face = src_face[0]
+                lmk = src_face.landmark_2d_106
+                lmk = self.landmark_runner.run(frame_rgb, lmk)
+                trajectory.start, trajectory.end = idx, idx
+            else:
+                lmk = self.landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1])
+                trajectory.end = idx
+
+            trajectory.lmk_lst.append(lmk)
+            # NOTE: pass the horizontal offset as vx_ratio so it actually reaches parse_rect_from_landmark;
+            # the former vx_ratio_crop_video keyword was silently swallowed by **kwargs
+            ret_bbox = parse_bbox_from_landmark(lmk, scale=self.crop_cfg.scale_crop_video, vx_ratio=self.crop_cfg.vx_ratio_crop_video, vy_ratio=self.crop_cfg.vy_ratio_crop_video)['bbox']
+            bbox = [ret_bbox[0, 0], ret_bbox[0, 1], ret_bbox[2, 0], ret_bbox[2, 1]]  # 4,
+            trajectory.bbox_lst.append(bbox)  # bbox
+            trajectory.frame_rgb_lst.append(frame_rgb)
+
+        global_bbox = average_bbox_lst(trajectory.bbox_lst)
+
+        for idx, (frame_rgb, lmk) in enumerate(zip(trajectory.frame_rgb_lst, trajectory.lmk_lst)):
+            ret_dct = crop_image_by_bbox(
+                frame_rgb,
+                global_bbox,
+                lmk=lmk,
+                dsize=kwargs.get('dsize', 512),
+                flag_rot=False,
+                borderValue=(0, 0, 0),
+            )
+            trajectory.frame_rgb_crop_lst.append(ret_dct['img_crop'])
+            trajectory.lmk_crop_lst.append(ret_dct['lmk_crop'])
+
+        return {
+            'frame_crop_lst': trajectory.frame_rgb_crop_lst,
'lmk_crop_lst': trajectory.lmk_crop_lst, + } + + def calc_lmks_from_cropped_video(self, driving_rgb_crop_lst, **kwargs): + """Tracking based landmarks/alignment""" + trajectory = Trajectory() + direction = kwargs.get('direction', 'large-small') + + for idx, frame_rgb_crop in enumerate(driving_rgb_crop_lst): + if idx == 0 or trajectory.start == -1: + src_face = self.face_analysis_wrapper.get( + contiguous(frame_rgb_crop[..., ::-1]), # convert to BGR + flag_do_landmark_2d_106=True, + direction=direction + ) + if len(src_face) == 0: + log(f'No face detected in the frame #{idx}') + raise Exception(f'No face detected in the frame #{idx}') + elif len(src_face) > 1: + log(f'More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.') + src_face = src_face[0] + lmk = src_face.landmark_2d_106 + lmk = self.landmark_runner.run(frame_rgb_crop, lmk) + trajectory.start, trajectory.end = idx, idx + else: + lmk = self.landmark_runner.run(frame_rgb_crop, trajectory.lmk_lst[-1]) + trajectory.end = idx + + trajectory.lmk_lst.append(lmk) + return trajectory.lmk_lst diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/__init__.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1680083da47850b31da10803c7d255e67dda619a --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/__init__.py @@ -0,0 +1,20 @@ +# coding: utf-8 +# pylint: disable=wrong-import-position +"""InsightFace: A Face Analysis Toolkit.""" +from __future__ import absolute_import + +try: + #import mxnet as mx + import onnxruntime +except ImportError: + raise ImportError( + "Unable to import dependency onnxruntime. " + ) + +__version__ = '0.7.3' + +from . import model_zoo +from . import utils +from . import app +from . 
diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/__init__.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1680083da47850b31da10803c7d255e67dda619a --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/__init__.py @@ -0,0 +1,20 @@ +# coding: utf-8 +# pylint: disable=wrong-import-position +"""InsightFace: A Face Analysis Toolkit.""" +from __future__ import absolute_import + +try: + #import mxnet as mx + import onnxruntime +except ImportError: + raise ImportError( + "Unable to import dependency onnxruntime. " + ) + +__version__ = '0.7.3' + +from . import model_zoo +from . import utils +from . import app +from . import data + diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/__pycache__/__init__.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96ddc76c2a1c2c28d227153505b3d68689adfd12 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/__pycache__/__init__.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/__pycache__/__init__.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91c8c652d2e12d21dfaeb5258071ac6f9d816109 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/__pycache__/__init__.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__init__.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc574616885290489798bac5c682e7aaa65a5dad --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__init__.py @@ -0,0 +1 @@ +from .face_analysis import * diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__pycache__/__init__.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e0b6b8af4332244a67a738907a1cf2fd76de817 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__pycache__/__init__.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__pycache__/__init__.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a172f74cdb808a6e78124d2a4c8b092ee50576c5 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__pycache__/__init__.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__pycache__/common.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__pycache__/common.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4496283ca1ea4d0d6d785520378d7f5682ecf827 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__pycache__/common.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__pycache__/common.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__pycache__/common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d6c828aa6386f1cb02133c6050110b60258566c Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__pycache__/common.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__pycache__/face_analysis.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__pycache__/face_analysis.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f9368d2e9c43c28e8d750678aacb70b814e5822 Binary files /dev/null and 
b/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__pycache__/face_analysis.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__pycache__/face_analysis.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__pycache__/face_analysis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73923994f78aaedf18d55454e47da4aedbdaa0ab Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/app/__pycache__/face_analysis.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/app/common.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/app/common.py new file mode 100644 index 0000000000000000000000000000000000000000..82ca987aeede35510b3aef72b4edf2390ad84e65 --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/app/common.py @@ -0,0 +1,49 @@ +import numpy as np +from numpy.linalg import norm as l2norm +#from easydict import EasyDict + +class Face(dict): + + def __init__(self, d=None, **kwargs): + if d is None: + d = {} + if kwargs: + d.update(**kwargs) + for k, v in d.items(): + setattr(self, k, v) + # Class attributes + #for k in self.__class__.__dict__.keys(): + # if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'): + # setattr(self, k, getattr(self, k)) + + def __setattr__(self, name, value): + if isinstance(value, (list, tuple)): + value = [self.__class__(x) + if isinstance(x, dict) else x for x in value] + elif isinstance(value, dict) and not isinstance(value, self.__class__): + value = self.__class__(value) + super(Face, self).__setattr__(name, value) + super(Face, self).__setitem__(name, value) + + __setitem__ = __setattr__ + + def __getattr__(self, name): + return None + + @property + def embedding_norm(self): + if self.embedding is None: + return None + return l2norm(self.embedding) + + @property + def normed_embedding(self): + if self.embedding is None: + return None + return self.embedding / self.embedding_norm + + @property + def sex(self): + if self.gender is None: + return None + return 'M' if self.gender==1 else 'F' diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/app/face_analysis.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/app/face_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..aa5128b3f5e02c2c19e7df195cc1c1e7fcf36c4d --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/app/face_analysis.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +# @Organization : insightface.ai +# @Author : Jia Guo +# @Time : 2021-05-04 +# @Function : + + +from __future__ import division + +import glob +import os.path as osp + +import numpy as np +import onnxruntime +from numpy.linalg import norm + +from ..model_zoo import model_zoo +from ..utils import ensure_available +from .common import Face + + +DEFAULT_MP_NAME = 'buffalo_l' +__all__ = ['FaceAnalysis'] + +class FaceAnalysis: + def __init__(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None, **kwargs): + onnxruntime.set_default_logger_severity(3) + self.models = {} + self.model_dir = ensure_available('models', name, root=root) + onnx_files = glob.glob(osp.join(self.model_dir, '*.onnx')) + onnx_files = sorted(onnx_files) + for onnx_file in onnx_files: + model = model_zoo.get_model(onnx_file, **kwargs) + if model is None: + print('model not recognized:', onnx_file) + elif allowed_modules 
is not None and model.taskname not in allowed_modules: + print('model ignore:', onnx_file, model.taskname) + del model + elif model.taskname not in self.models and (allowed_modules is None or model.taskname in allowed_modules): + # print('find model:', onnx_file, model.taskname, model.input_shape, model.input_mean, model.input_std) + self.models[model.taskname] = model + else: + print('duplicated model task type, ignore:', onnx_file, model.taskname) + del model + assert 'detection' in self.models + self.det_model = self.models['detection'] + + + def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)): + self.det_thresh = det_thresh + assert det_size is not None + # print('set det-size:', det_size) + self.det_size = det_size + for taskname, model in self.models.items(): + if taskname=='detection': + model.prepare(ctx_id, input_size=det_size, det_thresh=det_thresh) + else: + model.prepare(ctx_id) + + def get(self, img, max_num=0): + bboxes, kpss = self.det_model.detect(img, + max_num=max_num, + metric='default') + if bboxes.shape[0] == 0: + return [] + ret = [] + for i in range(bboxes.shape[0]): + bbox = bboxes[i, 0:4] + det_score = bboxes[i, 4] + kps = None + if kpss is not None: + kps = kpss[i] + face = Face(bbox=bbox, kps=kps, det_score=det_score) + for taskname, model in self.models.items(): + if taskname=='detection': + continue + model.get(img, face) + ret.append(face) + return ret + + def draw_on(self, img, faces): + import cv2 + dimg = img.copy() + for i in range(len(faces)): + face = faces[i] + box = face.bbox.astype(int) # np.int was removed in NumPy 1.24; use the builtin + color = (0, 0, 255) + cv2.rectangle(dimg, (box[0], box[1]), (box[2], box[3]), color, 2) + if face.kps is not None: + kps = face.kps.astype(int) + #print(landmark.shape) + for l in range(kps.shape[0]): + color = (0, 0, 255) + if l == 0 or l == 3: + color = (0, 255, 0) + cv2.circle(dimg, (kps[l][0], kps[l][1]), 1, color, + 2) + if face.gender is not None and face.age is not None: + cv2.putText(dimg,'%s,%d'%(face.sex,face.age), (box[0]-1, box[1]-4),cv2.FONT_HERSHEY_COMPLEX,0.7,(0,255,0),1) + + #for key, value in face.items(): + # if key.startswith('landmark_3d'): + # print(key, value.shape) + # print(value[0:10,:]) + # lmk = np.round(value).astype(np.int) + # for l in range(lmk.shape[0]): + # color = (255, 0, 0) + # cv2.circle(dimg, (lmk[l][0], lmk[l][1]), 1, color, + # 2) + return dimg
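`FaceAnalysis` above chains whatever task models it finds in the model directory: the detector runs first, then each remaining model annotates the returned `Face` objects. A minimal usage sketch against the upstream insightface package, which this vendored copy mirrors; it assumes the 'buffalo_l' model pack is available (insightface downloads it on first use) and the image path is illustrative:

```python
import cv2
from insightface.app import FaceAnalysis

app = FaceAnalysis(name='buffalo_l')        # loads detection + auxiliary models
app.prepare(ctx_id=-1, det_size=(640, 640)) # ctx_id < 0 selects the CPU provider

img = cv2.imread('t1.jpg')                  # illustrative input path
faces = app.get(img)                        # list of Face objects
for face in faces:
    print(face.bbox, face.det_score, face.sex, face.age)

vis = app.draw_on(img, faces)               # boxes, keypoints, gender/age labels
cv2.imwrite('t1_output.jpg', vis)
```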
diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__init__.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..665c59ec99b6ebf12822015e0350969c7903e243 --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__init__.py @@ -0,0 +1,2 @@ +from .image import get_image +from .pickle_object import get_object diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__pycache__/__init__.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0458745f1071c91320c4855a74331f554eabdf6a Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__pycache__/__init__.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__pycache__/__init__.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf4c1e17e7bf3261fa3a3f81a5482f2d07d8a4f8 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__pycache__/__init__.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__pycache__/image.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__pycache__/image.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef2dae902d902c7a118522cfe43ac2665dedb175 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__pycache__/image.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__pycache__/image.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__pycache__/image.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ea67ece2b1d00fecbd29172fc37087e600da86c Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__pycache__/image.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__pycache__/pickle_object.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__pycache__/pickle_object.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25cc94712dcd59ba2e37557545864d9b242650cd Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__pycache__/pickle_object.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__pycache__/pickle_object.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__pycache__/pickle_object.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cbc22d0119eba6fd12c4b9aa0275bfc6c5d0688 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/__pycache__/pickle_object.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/data/image.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/image.py new file mode 100644 index 0000000000000000000000000000000000000000..6d32c4bcb1b13d33bcb0d840cf7b8c08d183b3ea --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/image.py @@ -0,0 +1,27 @@ +import cv2 +import os +import os.path as osp +from pathlib import Path + +class ImageCache: + data = {} + +def get_image(name, to_rgb=False): + key = (name, to_rgb) + if key in ImageCache.data: + return ImageCache.data[key] + images_dir = osp.join(Path(__file__).parent.absolute(), 'images') + ext_names = ['.jpg', '.png', '.jpeg'] + image_file = None + for ext_name in ext_names: + _image_file = osp.join(images_dir, "%s%s"%(name, ext_name)) + if osp.exists(_image_file): + image_file = _image_file + break + assert image_file is not None, '%s not found'%name + img = cv2.imread(image_file) + if to_rgb: + img = img[:,:,::-1] + ImageCache.data[key] = img + return img + diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/data/images/Tom_Hanks_54745.png b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/images/Tom_Hanks_54745.png new file mode 100644 index 0000000000000000000000000000000000000000..906315d13fa29bb3a5ded3e162592f2c7f041b23 Binary files /dev/null and 
b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/images/Tom_Hanks_54745.png differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/data/images/mask_black.jpg b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/images/mask_black.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0eab0df555c23f1e033537fe39f3c0c8303dd369 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/images/mask_black.jpg differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/data/images/mask_blue.jpg b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/images/mask_blue.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f71336b9a0d3038ebd84e6995ebfbe54946fcbb4 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/images/mask_blue.jpg differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/data/images/mask_green.jpg b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/images/mask_green.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ac2ad55f4fc580c915dfa4c157ca3bfc84e453f4 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/images/mask_green.jpg differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/data/images/mask_white.jpg b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/images/mask_white.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2148ab2d09fdee6e3f59315470e98ecfc54339e4 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/images/mask_white.jpg differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/data/images/t1.jpg b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/images/t1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0d1d64a59675c9590fd12429db647eb169cecff8 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/images/t1.jpg differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/data/objects/meanshape_68.pkl b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/objects/meanshape_68.pkl new file mode 100644 index 0000000000000000000000000000000000000000..d5297e9e8ea5574298ddd287b058252e03aa18c1 --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/objects/meanshape_68.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:39ffecf84ba73f0d0d7e49380833ba88713c9fcdec51df4f7ac45a48b8f4cc51 +size 974 diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/data/pickle_object.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/pickle_object.py new file mode 100644 index 0000000000000000000000000000000000000000..fbd87030ea15e1d01af1cd4cff1be2bc54cc82dd --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/pickle_object.py @@ -0,0 +1,17 @@ +import cv2 +import os +import os.path as osp +from pathlib import Path +import pickle + +def get_object(name): + objects_dir = osp.join(Path(__file__).parent.absolute(), 'objects') + if not name.endswith('.pkl'): + name = name+".pkl" + filepath = osp.join(objects_dir, name) + if not osp.exists(filepath): + return None + with open(filepath, 'rb') as f: + obj = pickle.load(f) + return obj + diff --git 
a/subprocess/LivePortrait/src/utils/dependencies/insightface/data/rec_builder.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/rec_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..e02abc969da2f882639326f5bad3c7e8d08c1fde --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/data/rec_builder.py @@ -0,0 +1,71 @@ +import pickle +import numpy as np +import os +import os.path as osp +import sys +import mxnet as mx + + +class RecBuilder(): + def __init__(self, path, image_size=(112, 112)): + self.path = path + self.image_size = image_size + self.widx = 0 + self.wlabel = 0 + self.max_label = -1 + assert not osp.exists(path), '%s exists' % path + os.makedirs(path) + self.writer = mx.recordio.MXIndexedRecordIO(os.path.join(path, 'train.idx'), + os.path.join(path, 'train.rec'), + 'w') + self.meta = [] + + def add(self, imgs): + #!!! img should be BGR!!!! + #assert label >= 0 + #assert label > self.last_label + assert len(imgs) > 0 + label = self.wlabel + for img in imgs: + idx = self.widx + image_meta = {'image_index': idx, 'image_classes': [label]} + header = mx.recordio.IRHeader(0, label, idx, 0) + if isinstance(img, np.ndarray): + s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg') + else: + s = mx.recordio.pack(header, img) + self.writer.write_idx(idx, s) + self.meta.append(image_meta) + self.widx += 1 + self.max_label = label + self.wlabel += 1 + + + def add_image(self, img, label): + #!!! img should be BGR!!!! + #assert label >= 0 + #assert label > self.last_label + idx = self.widx + header = mx.recordio.IRHeader(0, label, idx, 0) + if isinstance(label, list): + idlabel = label[0] + else: + idlabel = label + image_meta = {'image_index': idx, 'image_classes': [idlabel]} + if isinstance(img, np.ndarray): + s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg') + else: + s = mx.recordio.pack(header, img) + self.writer.write_idx(idx, s) + self.meta.append(image_meta) + self.widx += 1 + self.max_label = max(self.max_label, idlabel) + + def close(self): + with open(osp.join(self.path, 'train.meta'), 'wb') as pfile: + pickle.dump(self.meta, pfile, protocol=pickle.HIGHEST_PROTOCOL) + print('stat:', self.widx, self.wlabel) + with open(os.path.join(self.path, 'property'), 'w') as f: + f.write("%d,%d,%d\n" % (self.max_label+1, self.image_size[0], self.image_size[1])) + f.write("%d\n" % (self.widx)) + diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__init__.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..225623d6142c968b4040f391039bfab88bdd1b2a --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__init__.py @@ -0,0 +1,6 @@ +from .model_zoo import get_model +from .arcface_onnx import ArcFaceONNX +from .retinaface import RetinaFace +from .scrfd import SCRFD +from .landmark import Landmark +from .attribute import Attribute diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/__init__.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84439bd0027534962db5138e5e42e614939dadca Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/__init__.cpython-310.pyc differ diff --git 
a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/__init__.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2617284bfb240b557f2b2e0a4d24b0e10b36faf0 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/__init__.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/arcface_onnx.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/arcface_onnx.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cebee26f3f5c4877a0ab23ee61387e3f37a7e94 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/arcface_onnx.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/arcface_onnx.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/arcface_onnx.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c7c56f8a154a5f74fa01375aa62dbc7c2256a32 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/arcface_onnx.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/attribute.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/attribute.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91d7bf7b98786a0a2e577e3ef758d25657514ed7 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/attribute.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/attribute.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/attribute.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95e2f8ae781b977e96a6f3768e64eaef111669d0 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/attribute.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/inswapper.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/inswapper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f05ba7999457cdc9ef70f626a276f9778f9f17b Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/inswapper.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/inswapper.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/inswapper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6694ae931894c1d67953c92dfbc9a2c8a6c7982 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/inswapper.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/landmark.cpython-310.pyc 
b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/landmark.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25d27f4ad8d657d4a805e473f7176a563331d2cc Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/landmark.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/landmark.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/landmark.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3a28c3cef8a60ab8279b70588f04bd2331e946e Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/landmark.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/model_zoo.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/model_zoo.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e52cf2788a260a75b3881f16500b20325aa9d18 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/model_zoo.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/model_zoo.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/model_zoo.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a9ac67ef0bbb3290f54ca05dff8f04efbf01a14 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/model_zoo.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/retinaface.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/retinaface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f72df6676855d2770e1bc2634b3910a3925d0e02 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/retinaface.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/retinaface.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/retinaface.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b49c05d2a8261864a66cb47f98e37cdcf51fceb Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/retinaface.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/scrfd.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/scrfd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2e5ff0293a82b16eaaaab1f34aa41411cb6ad44 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/scrfd.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/scrfd.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/scrfd.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..e1d0c7ce2ef48607db7c212d97ca474d8738b643 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/__pycache__/scrfd.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/arcface_onnx.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/arcface_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..b537ce2ee15d4a1834d54e185f34e336aab30a77 --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/arcface_onnx.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +# @Organization : insightface.ai +# @Author : Jia Guo +# @Time : 2021-05-04 +# @Function : + +from __future__ import division +import numpy as np +import cv2 +import onnx +import onnxruntime +from ..utils import face_align + +__all__ = [ + 'ArcFaceONNX', +] + + +class ArcFaceONNX: + def __init__(self, model_file=None, session=None): + assert model_file is not None + self.model_file = model_file + self.session = session + self.taskname = 'recognition' + find_sub = False + find_mul = False + model = onnx.load(self.model_file) + graph = model.graph + for nid, node in enumerate(graph.node[:8]): + #print(nid, node.name) + if node.name.startswith('Sub') or node.name.startswith('_minus'): + find_sub = True + if node.name.startswith('Mul') or node.name.startswith('_mul'): + find_mul = True + if find_sub and find_mul: + #mxnet arcface model + input_mean = 0.0 + input_std = 1.0 + else: + input_mean = 127.5 + input_std = 127.5 + self.input_mean = input_mean + self.input_std = input_std + #print('input mean and std:', self.input_mean, self.input_std) + if self.session is None: + self.session = onnxruntime.InferenceSession(self.model_file, None) + input_cfg = self.session.get_inputs()[0] + input_shape = input_cfg.shape + input_name = input_cfg.name + self.input_size = tuple(input_shape[2:4][::-1]) + self.input_shape = input_shape + outputs = self.session.get_outputs() + output_names = [] + for out in outputs: + output_names.append(out.name) + self.input_name = input_name + self.output_names = output_names + assert len(self.output_names)==1 + self.output_shape = outputs[0].shape + + def prepare(self, ctx_id, **kwargs): + if ctx_id<0: + self.session.set_providers(['CPUExecutionProvider']) + + def get(self, img, face): + aimg = face_align.norm_crop(img, landmark=face.kps, image_size=self.input_size[0]) + face.embedding = self.get_feat(aimg).flatten() + return face.embedding + + def compute_sim(self, feat1, feat2): + from numpy.linalg import norm + feat1 = feat1.ravel() + feat2 = feat2.ravel() + sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2)) + return sim + + def get_feat(self, imgs): + if not isinstance(imgs, list): + imgs = [imgs] + input_size = self.input_size + + blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size, + (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + net_out = self.session.run(self.output_names, {self.input_name: blob})[0] + return net_out + + def forward(self, batch_data): + blob = (batch_data - self.input_mean) / self.input_std + net_out = self.session.run(self.output_names, {self.input_name: blob})[0] + return net_out + + diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/attribute.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/attribute.py new file mode 100644 index 
0000000000000000000000000000000000000000..40c34de3f0995499448cf5779004cc1e5f3564fb --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/attribute.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# @Organization : insightface.ai +# @Author : Jia Guo +# @Time : 2021-06-19 +# @Function : + +from __future__ import division +import numpy as np +import cv2 +import onnx +import onnxruntime +from ..utils import face_align + +__all__ = [ + 'Attribute', +] + + +class Attribute: + def __init__(self, model_file=None, session=None): + assert model_file is not None + self.model_file = model_file + self.session = session + find_sub = False + find_mul = False + model = onnx.load(self.model_file) + graph = model.graph + for nid, node in enumerate(graph.node[:8]): + #print(nid, node.name) + if node.name.startswith('Sub') or node.name.startswith('_minus'): + find_sub = True + if node.name.startswith('Mul') or node.name.startswith('_mul'): + find_mul = True + if nid<3 and node.name=='bn_data': + find_sub = True + find_mul = True + if find_sub and find_mul: + #mxnet arcface model + input_mean = 0.0 + input_std = 1.0 + else: + input_mean = 127.5 + input_std = 128.0 + self.input_mean = input_mean + self.input_std = input_std + #print('input mean and std:', model_file, self.input_mean, self.input_std) + if self.session is None: + self.session = onnxruntime.InferenceSession(self.model_file, None) + input_cfg = self.session.get_inputs()[0] + input_shape = input_cfg.shape + input_name = input_cfg.name + self.input_size = tuple(input_shape[2:4][::-1]) + self.input_shape = input_shape + outputs = self.session.get_outputs() + output_names = [] + for out in outputs: + output_names.append(out.name) + self.input_name = input_name + self.output_names = output_names + assert len(self.output_names)==1 + output_shape = outputs[0].shape + #print('init output_shape:', output_shape) + if output_shape[1]==3: + self.taskname = 'genderage' + else: + self.taskname = 'attribute_%d'%output_shape[1] + + def prepare(self, ctx_id, **kwargs): + if ctx_id<0: + self.session.set_providers(['CPUExecutionProvider']) + + def get(self, img, face): + bbox = face.bbox + w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1]) + center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2 + rotate = 0 + _scale = self.input_size[0] / (max(w, h)*1.5) + #print('param:', img.shape, bbox, center, self.input_size, _scale, rotate) + aimg, M = face_align.transform(img, center, self.input_size[0], _scale, rotate) + input_size = tuple(aimg.shape[0:2][::-1]) + #assert input_size==self.input_size + blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + pred = self.session.run(self.output_names, {self.input_name : blob})[0][0] + if self.taskname=='genderage': + assert len(pred)==3 + gender = np.argmax(pred[:2]) + age = int(np.round(pred[2]*100)) + face['gender'] = gender + face['age'] = age + return gender, age + else: + return pred + + diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/inswapper.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/inswapper.py new file mode 100644 index 0000000000000000000000000000000000000000..f321c627ee66cceddcab98b561b997441dd4f768 --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/inswapper.py @@ -0,0 +1,114 @@ +import time +import numpy as np +import onnxruntime +import cv2 +import onnx +from onnx import numpy_helper +from ..utils 
import face_align + + + + +class INSwapper(): + def __init__(self, model_file=None, session=None): + self.model_file = model_file + self.session = session + model = onnx.load(self.model_file) + graph = model.graph + self.emap = numpy_helper.to_array(graph.initializer[-1]) + self.input_mean = 0.0 + self.input_std = 255.0 + #print('input mean and std:', model_file, self.input_mean, self.input_std) + if self.session is None: + self.session = onnxruntime.InferenceSession(self.model_file, None) + inputs = self.session.get_inputs() + self.input_names = [] + for inp in inputs: + self.input_names.append(inp.name) + outputs = self.session.get_outputs() + output_names = [] + for out in outputs: + output_names.append(out.name) + self.output_names = output_names + assert len(self.output_names)==1 + output_shape = outputs[0].shape + input_cfg = inputs[0] + input_shape = input_cfg.shape + self.input_shape = input_shape + # print('inswapper-shape:', self.input_shape) + self.input_size = tuple(input_shape[2:4][::-1]) + + def forward(self, img, latent): + img = (img - self.input_mean) / self.input_std + pred = self.session.run(self.output_names, {self.input_names[0]: img, self.input_names[1]: latent})[0] + return pred + + def get(self, img, target_face, source_face, paste_back=True): + face_mask = np.zeros((img.shape[0], img.shape[1]), np.uint8) + cv2.fillPoly(face_mask, np.array([target_face.landmark_2d_106[[1,9,10,11,12,13,14,15,16,2,3,4,5,6,7,8,0,24,23,22,21,20,19,18,32,31,30,29,28,27,26,25,17,101,105,104,103,51,49,48,43]].astype('int64')]), 1) + aimg, M = face_align.norm_crop2(img, target_face.kps, self.input_size[0]) + blob = cv2.dnn.blobFromImage(aimg, 1.0 / self.input_std, self.input_size, + (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + latent = source_face.normed_embedding.reshape((1,-1)) + latent = np.dot(latent, self.emap) + latent /= np.linalg.norm(latent) + pred = self.session.run(self.output_names, {self.input_names[0]: blob, self.input_names[1]: latent})[0] + #print(latent.shape, latent.dtype, pred.shape) + img_fake = pred.transpose((0,2,3,1))[0] + bgr_fake = np.clip(255 * img_fake, 0, 255).astype(np.uint8)[:,:,::-1] + if not paste_back: + return bgr_fake, M + else: + target_img = img + fake_diff = bgr_fake.astype(np.float32) - aimg.astype(np.float32) + fake_diff = np.abs(fake_diff).mean(axis=2) + fake_diff[:2,:] = 0 + fake_diff[-2:,:] = 0 + fake_diff[:,:2] = 0 + fake_diff[:,-2:] = 0 + IM = cv2.invertAffineTransform(M) + img_white = np.full((aimg.shape[0],aimg.shape[1]), 255, dtype=np.float32) + bgr_fake = cv2.warpAffine(bgr_fake, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) + img_white = cv2.warpAffine(img_white, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) + fake_diff = cv2.warpAffine(fake_diff, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) + img_white[img_white>20] = 255 + fthresh = 10 + # binarize the difference map around the threshold + fake_diff[fake_diff < fthresh] = 0 + fake_diff[fake_diff >= fthresh] = 255 + img_mask = img_white + mask_h_inds, mask_w_inds = np.where(img_mask==255) + mask_h = np.max(mask_h_inds) - np.min(mask_h_inds) + mask_w = np.max(mask_w_inds) - np.min(mask_w_inds) + mask_size = int(np.sqrt(mask_h*mask_w)) + k = max(mask_size//10, 10) + #k = max(mask_size//20, 6) + #k = 6 + kernel = np.ones((k,k),np.uint8) + img_mask = cv2.erode(img_mask,kernel,iterations = 1) + kernel = np.ones((2,2),np.uint8) + fake_diff = cv2.dilate(fake_diff,kernel,iterations = 1) + + face_mask = cv2.erode(face_mask,np.ones((11,11),np.uint8),iterations = 1) + fake_diff[face_mask==1] = 255 + + k = 
max(mask_size//20, 5) + #k = 3 + #k = 3 + kernel_size = (k, k) + blur_size = tuple(2*i+1 for i in kernel_size) + img_mask = cv2.GaussianBlur(img_mask, blur_size, 0) + k = 5 + kernel_size = (k, k) + blur_size = tuple(2*i+1 for i in kernel_size) + fake_diff = cv2.blur(fake_diff, (11,11), 0) + ##fake_diff = cv2.GaussianBlur(fake_diff, blur_size, 0) + # print('blur_size: ', blur_size) + # fake_diff = cv2.blur(fake_diff, (21, 21), 0) # blur_size + img_mask /= 255 + fake_diff /= 255 + # img_mask = fake_diff + img_mask = img_mask*fake_diff + img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1]) + fake_merged = img_mask * bgr_fake + (1-img_mask) * target_img.astype(np.float32) + fake_merged = fake_merged.astype(np.uint8) + return fake_merged
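The paste-back branch above reduces to one compositing identity: a feathered mask m in [0, 1] blends the generated crop into the original frame as out = m*fake + (1-m)*target. A self-contained sketch of that blend on synthetic images (no face model involved; the shapes and values are illustrative):

```python
import cv2
import numpy as np

h, w = 128, 128
target = np.full((h, w, 3), 40, dtype=np.float32)   # stand-in original frame
fake = np.full((h, w, 3), 200, dtype=np.float32)    # stand-in generated patch

# Hard mask over the paste region, feathered with a Gaussian blur,
# mirroring the erode/GaussianBlur steps in INSwapper.get above.
mask = np.zeros((h, w), dtype=np.float32)
mask[32:96, 32:96] = 255
mask = cv2.GaussianBlur(mask, (21, 21), 0) / 255.0
mask = mask[..., None]                              # HxWx1 for broadcasting

merged = (mask * fake + (1.0 - mask) * target).astype(np.uint8)
print(merged[64, 64], merged[0, 0])                 # blended center, untouched corner
```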
diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/landmark.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/landmark.py new file mode 100644 index 0000000000000000000000000000000000000000..598b4b29a2d0674d8bb25b681f921c61460d101c --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/landmark.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +# @Organization : insightface.ai +# @Author : Jia Guo +# @Time : 2021-05-04 +# @Function : + +from __future__ import division +import numpy as np +import cv2 +import onnx +import onnxruntime +from ..utils import face_align +from ..utils import transform +from ..data import get_object + +__all__ = [ + 'Landmark', +] + + +class Landmark: + def __init__(self, model_file=None, session=None): + assert model_file is not None + self.model_file = model_file + self.session = session + find_sub = False + find_mul = False + model = onnx.load(self.model_file) + graph = model.graph + for nid, node in enumerate(graph.node[:8]): + #print(nid, node.name) + if node.name.startswith('Sub') or node.name.startswith('_minus'): + find_sub = True + if node.name.startswith('Mul') or node.name.startswith('_mul'): + find_mul = True + if nid<3 and node.name=='bn_data': + find_sub = True + find_mul = True + if find_sub and find_mul: + #mxnet arcface model + input_mean = 0.0 + input_std = 1.0 + else: + input_mean = 127.5 + input_std = 128.0 + self.input_mean = input_mean + self.input_std = input_std + #print('input mean and std:', model_file, self.input_mean, self.input_std) + if self.session is None: + self.session = onnxruntime.InferenceSession(self.model_file, None) + input_cfg = self.session.get_inputs()[0] + input_shape = input_cfg.shape + input_name = input_cfg.name + self.input_size = tuple(input_shape[2:4][::-1]) + self.input_shape = input_shape + outputs = self.session.get_outputs() + output_names = [] + for out in outputs: + output_names.append(out.name) + self.input_name = input_name + self.output_names = output_names + assert len(self.output_names)==1 + output_shape = outputs[0].shape + self.require_pose = False + #print('init output_shape:', output_shape) + if output_shape[1]==3309: + self.lmk_dim = 3 + self.lmk_num = 68 + self.mean_lmk = get_object('meanshape_68.pkl') + self.require_pose = True + else: + self.lmk_dim = 2 + self.lmk_num = output_shape[1]//self.lmk_dim + self.taskname = 'landmark_%dd_%d'%(self.lmk_dim, self.lmk_num) + + def prepare(self, ctx_id, **kwargs): + if ctx_id<0: + self.session.set_providers(['CPUExecutionProvider']) + + def get(self, img, face): + bbox = face.bbox + w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1]) + center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2 + rotate = 0 + _scale = self.input_size[0] / (max(w, h)*1.5) + #print('param:', img.shape, bbox, center, self.input_size, _scale, rotate) + aimg, M = face_align.transform(img, center, self.input_size[0], _scale, rotate) + input_size = tuple(aimg.shape[0:2][::-1]) + #assert input_size==self.input_size + blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + pred = self.session.run(self.output_names, {self.input_name : blob})[0][0] + if pred.shape[0] >= 3000: + pred = pred.reshape((-1, 3)) + else: + pred = pred.reshape((-1, 2)) + if self.lmk_num < pred.shape[0]: + pred = pred[self.lmk_num*-1:,:] + pred[:, 0:2] += 1 + pred[:, 0:2] *= (self.input_size[0] // 2) + if pred.shape[1] == 3: + pred[:, 2] *= (self.input_size[0] // 2) + + IM = cv2.invertAffineTransform(M) + pred = face_align.trans_points(pred, IM) + face[self.taskname] = pred + if self.require_pose: + P = transform.estimate_affine_matrix_3d23d(self.mean_lmk, pred) + s, R, t = transform.P2sRt(P) + rx, ry, rz = transform.matrix2angle(R) + pose = np.array( [rx, ry, rz], dtype=np.float32 ) + face['pose'] = pose #pitch, yaw, roll + return pred + + diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/model_store.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/model_store.py new file mode 100644 index 0000000000000000000000000000000000000000..50bb85d314f5b7a0ea8211d2cd21186e32791592 --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/model_store.py @@ -0,0 +1,103 @@ +""" +This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/model_store.py +""" +from __future__ import print_function + +__all__ = ['get_model_file'] +import os +import zipfile +import glob + +from ..utils import download, check_sha1 + +_model_sha1 = { + name: checksum + for checksum, name in [ + ('95be21b58e29e9c1237f229dae534bd854009ce0', 'arcface_r100_v1'), + ('', 'arcface_mfn_v1'), + ('39fd1e087a2a2ed70a154ac01fecaa86c315d01b', 'retinaface_r50_v1'), + ('2c9de8116d1f448fd1d4661f90308faae34c990a', 'retinaface_mnet025_v1'), + ('0db1d07921d005e6c9a5b38e059452fc5645e5a4', 'retinaface_mnet025_v2'), + ('7dd8111652b7aac2490c5dcddeb268e53ac643e6', 'genderage_v1'), + ] +} + +base_repo_url = 'https://insightface.ai/files/' +_url_format = '{repo_url}models/{file_name}.zip' + + +def short_hash(name): + if name not in _model_sha1: + raise ValueError( + 'Pretrained model for {name} is not available.'.format(name=name)) + return _model_sha1[name][:8] + + +def find_params_file(dir_path): + if not os.path.exists(dir_path): + return None + paths = glob.glob("%s/*.params" % dir_path) + if len(paths) == 0: + return None + paths = sorted(paths) + return paths[-1] + + +def get_model_file(name, root=os.path.join('~', '.insightface', 'models')): + r"""Return the location of a pretrained model on the local file system. + + This function will download from the online model zoo when the model cannot be found or has a mismatch. + The root directory will be created if it doesn't exist. + + Parameters + ---------- + name : str + Name of the model. + root : str, default '~/.insightface/models' + Location for keeping the model parameters. + + Returns + ------- + file_path + Path to the requested pretrained model file. 
+ """ + + file_name = name + root = os.path.expanduser(root) + dir_path = os.path.join(root, name) + file_path = find_params_file(dir_path) + #file_path = os.path.join(root, file_name + '.params') + sha1_hash = _model_sha1[name] + if file_path is not None: + if check_sha1(file_path, sha1_hash): + return file_path + else: + print( + 'Mismatch in the content of model file detected. Downloading again.' + ) + else: + print('Model file is not found. Downloading.') + + if not os.path.exists(root): + os.makedirs(root) + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + zip_file_path = os.path.join(root, file_name + '.zip') + repo_url = base_repo_url + if repo_url[-1] != '/': + repo_url = repo_url + '/' + download(_url_format.format(repo_url=repo_url, file_name=file_name), + path=zip_file_path, + overwrite=True) + with zipfile.ZipFile(zip_file_path) as zf: + zf.extractall(dir_path) + os.remove(zip_file_path) + file_path = find_params_file(dir_path) + + if check_sha1(file_path, sha1_hash): + return file_path + else: + raise ValueError( + 'Downloaded file has different hash. Please try again.') + diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/model_zoo.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/model_zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..d8366e2a5461d5d6688f23e102a40944330084a4 --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/model_zoo.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# @Organization : insightface.ai +# @Author : Jia Guo +# @Time : 2021-05-04 +# @Function : + +import os +import os.path as osp +import glob +import onnxruntime +from .arcface_onnx import * +from .retinaface import * +#from .scrfd import * +from .landmark import * +from .attribute import Attribute +from .inswapper import INSwapper +from ..utils import download_onnx + +__all__ = ['get_model'] + + +class PickableInferenceSession(onnxruntime.InferenceSession): + # This is a wrapper to make the current InferenceSession class pickable. 
+ def __init__(self, model_path, **kwargs): + super().__init__(model_path, **kwargs) + self.model_path = model_path + + def __getstate__(self): + return {'model_path': self.model_path} + + def __setstate__(self, values): + model_path = values['model_path'] + self.__init__(model_path) + +class ModelRouter: + def __init__(self, onnx_file): + self.onnx_file = onnx_file + + def get_model(self, **kwargs): + session = PickableInferenceSession(self.onnx_file, **kwargs) + # print(f'Applied providers: {session._providers}, with options: {session._provider_options}') + inputs = session.get_inputs() + input_cfg = inputs[0] + input_shape = input_cfg.shape + outputs = session.get_outputs() + + if len(outputs)>=5: + return RetinaFace(model_file=self.onnx_file, session=session) + elif input_shape[2]==192 and input_shape[3]==192: + return Landmark(model_file=self.onnx_file, session=session) + elif input_shape[2]==96 and input_shape[3]==96: + return Attribute(model_file=self.onnx_file, session=session) + elif len(inputs)==2 and input_shape[2]==128 and input_shape[3]==128: + return INSwapper(model_file=self.onnx_file, session=session) + elif input_shape[2]==input_shape[3] and input_shape[2]>=112 and input_shape[2]%16==0: + return ArcFaceONNX(model_file=self.onnx_file, session=session) + else: + #raise RuntimeError('error on model routing') + return None + +def find_onnx_file(dir_path): + if not os.path.exists(dir_path): + return None + paths = glob.glob("%s/*.onnx" % dir_path) + if len(paths) == 0: + return None + paths = sorted(paths) + return paths[-1] + +def get_default_providers(): + return ['CUDAExecutionProvider', 'CPUExecutionProvider'] + +def get_default_provider_options(): + return None + +def get_model(name, **kwargs): + root = kwargs.get('root', '~/.insightface') + root = os.path.expanduser(root) + model_root = osp.join(root, 'models') + allow_download = kwargs.get('download', False) + download_zip = kwargs.get('download_zip', False) + if not name.endswith('.onnx'): + model_dir = os.path.join(model_root, name) + model_file = find_onnx_file(model_dir) + if model_file is None: + return None + else: + model_file = name + if not osp.exists(model_file) and allow_download: + model_file = download_onnx('models', model_file, root=root, download_zip=download_zip) + assert osp.exists(model_file), 'model_file %s should exist'%model_file + assert osp.isfile(model_file), 'model_file %s should be a file'%model_file + router = ModelRouter(model_file) + providers = kwargs.get('providers', get_default_providers()) + provider_options = kwargs.get('provider_options', get_default_provider_options()) + model = router.get_model(providers=providers, provider_options=provider_options) + return model
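`PickableInferenceSession` above makes an ONNX Runtime session survive pickling by storing only the model path in its state and rebuilding the live session on unpickling. The same `__getstate__`/`__setstate__` pattern works for any object wrapping an unpicklable resource; a generic, dependency-free sketch (the class and path here are illustrative):

```python
import pickle
import threading

class ReloadableHandle:
    """Keeps only a path in its pickled state; rebuilds the live resource on load."""
    def __init__(self, path):
        self.path = path
        self._lock = threading.Lock()   # stand-in for an unpicklable resource

    def __getstate__(self):
        return {'path': self.path}      # drop the live resource from the state

    def __setstate__(self, state):
        self.__init__(state['path'])    # recreate it from the stored path

h = pickle.loads(pickle.dumps(ReloadableHandle('model.onnx')))
print(h.path, type(h._lock).__name__)   # path restored, lock recreated
```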
diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/retinaface.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/retinaface.py new file mode 100644 index 0000000000000000000000000000000000000000..fc4ad91ed70688b38503127137e928dc7e5433e1 --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/retinaface.py @@ -0,0 +1,301 @@ +# -*- coding: utf-8 -*- +# @Organization : insightface.ai +# @Author : Jia Guo +# @Time : 2021-09-18 +# @Function : + +from __future__ import division +import datetime +import numpy as np +import onnx +import onnxruntime +import os +import os.path as osp +import cv2 +import sys + +def softmax(z): + assert len(z.shape) == 2 + s = np.max(z, axis=1) + s = s[:, np.newaxis] # necessary step to do broadcasting + e_x = np.exp(z - s) + div = np.sum(e_x, axis=1) + div = div[:, np.newaxis] # ditto + return e_x / div + +def distance2bbox(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded bboxes. + """ + x1 = points[:, 0] - distance[:, 0] + y1 = points[:, 1] - distance[:, 1] + x2 = points[:, 0] + distance[:, 2] + y2 = points[:, 1] + distance[:, 3] + if max_shape is not None: + x1 = x1.clamp(min=0, max=max_shape[1]) + y1 = y1.clamp(min=0, max=max_shape[0]) + x2 = x2.clamp(min=0, max=max_shape[1]) + y2 = y2.clamp(min=0, max=max_shape[0]) + return np.stack([x1, y1, x2, y2], axis=-1) + +def distance2kps(points, distance, max_shape=None): + """Decode distance prediction to keypoints. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Offsets from the given points to the keypoints. + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded keypoints. + """ + preds = [] + for i in range(0, distance.shape[1], 2): + px = points[:, i%2] + distance[:, i] + py = points[:, i%2+1] + distance[:, i+1] + if max_shape is not None: + px = px.clamp(min=0, max=max_shape[1]) + py = py.clamp(min=0, max=max_shape[0]) + preds.append(px) + preds.append(py) + return np.stack(preds, axis=-1) + +class RetinaFace: + def __init__(self, model_file=None, session=None): + import onnxruntime + self.model_file = model_file + self.session = session + self.taskname = 'detection' + if self.session is None: + assert self.model_file is not None + assert osp.exists(self.model_file) + self.session = onnxruntime.InferenceSession(self.model_file, None) + self.center_cache = {} + self.nms_thresh = 0.4 + self.det_thresh = 0.5 + self._init_vars() + + def _init_vars(self): + input_cfg = self.session.get_inputs()[0] + input_shape = input_cfg.shape + #print(input_shape) + if isinstance(input_shape[2], str): + self.input_size = None + else: + self.input_size = tuple(input_shape[2:4][::-1]) + #print('image_size:', self.image_size) + input_name = input_cfg.name + self.input_shape = input_shape + outputs = self.session.get_outputs() + output_names = [] + for o in outputs: + output_names.append(o.name) + self.input_name = input_name + self.output_names = output_names + self.input_mean = 127.5 + self.input_std = 128.0 + #print(self.output_names) + #assert len(outputs)==10 or len(outputs)==15 + self.use_kps = False + self._anchor_ratio = 1.0 + self._num_anchors = 1 + if len(outputs)==6: + self.fmc = 3 + self._feat_stride_fpn = [8, 16, 32] + self._num_anchors = 2 + elif len(outputs)==9: + self.fmc = 3 + self._feat_stride_fpn = [8, 16, 32] + self._num_anchors = 2 + self.use_kps = True + elif len(outputs)==10: + self.fmc = 5 + self._feat_stride_fpn = [8, 16, 32, 64, 128] + self._num_anchors = 1 + elif len(outputs)==15: + self.fmc = 5 + self._feat_stride_fpn = [8, 16, 32, 64, 128] + self._num_anchors = 1 + self.use_kps = True + + def prepare(self, ctx_id, **kwargs): + if ctx_id<0: + self.session.set_providers(['CPUExecutionProvider']) + nms_thresh = kwargs.get('nms_thresh', None) + if nms_thresh is not None: + self.nms_thresh = nms_thresh + det_thresh = kwargs.get('det_thresh', None) + if det_thresh is not None: + self.det_thresh = det_thresh + input_size = kwargs.get('input_size', None) + if input_size is not None: + if self.input_size is not None: + 
print('warning: det_size is already set in detection model, ignore') + else: + self.input_size = input_size + + def forward(self, img, threshold): + scores_list = [] + bboxes_list = [] + kpss_list = [] + input_size = tuple(img.shape[0:2][::-1]) + blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + net_outs = self.session.run(self.output_names, {self.input_name : blob}) + + input_height = blob.shape[2] + input_width = blob.shape[3] + fmc = self.fmc + for idx, stride in enumerate(self._feat_stride_fpn): + scores = net_outs[idx] + bbox_preds = net_outs[idx+fmc] + bbox_preds = bbox_preds * stride + if self.use_kps: + kps_preds = net_outs[idx+fmc*2] * stride + height = input_height // stride + width = input_width // stride + K = height * width + key = (height, width, stride) + if key in self.center_cache: + anchor_centers = self.center_cache[key] + else: + #solution-1, c style: + #anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 ) + #for i in range(height): + # anchor_centers[i, :, 1] = i + #for i in range(width): + # anchor_centers[:, i, 0] = i + + #solution-2: + #ax = np.arange(width, dtype=np.float32) + #ay = np.arange(height, dtype=np.float32) + #xv, yv = np.meshgrid(np.arange(width), np.arange(height)) + #anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32) + + #solution-3: + anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32) + #print(anchor_centers.shape) + + anchor_centers = (anchor_centers * stride).reshape( (-1, 2) ) + if self._num_anchors>1: + anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) ) + if len(self.center_cache)<100: + self.center_cache[key] = anchor_centers + + pos_inds = np.where(scores>=threshold)[0] + bboxes = distance2bbox(anchor_centers, bbox_preds) + pos_scores = scores[pos_inds] + pos_bboxes = bboxes[pos_inds] + scores_list.append(pos_scores) + bboxes_list.append(pos_bboxes) + if self.use_kps: + kpss = distance2kps(anchor_centers, kps_preds) + #kpss = kps_preds + kpss = kpss.reshape( (kpss.shape[0], -1, 2) ) + pos_kpss = kpss[pos_inds] + kpss_list.append(pos_kpss) + return scores_list, bboxes_list, kpss_list + + def detect(self, img, input_size = None, max_num=0, metric='default'): + assert input_size is not None or self.input_size is not None + input_size = self.input_size if input_size is None else input_size + + im_ratio = float(img.shape[0]) / img.shape[1] + model_ratio = float(input_size[1]) / input_size[0] + if im_ratio>model_ratio: + new_height = input_size[1] + new_width = int(new_height / im_ratio) + else: + new_width = input_size[0] + new_height = int(new_width * im_ratio) + det_scale = float(new_height) / img.shape[0] + resized_img = cv2.resize(img, (new_width, new_height)) + det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 ) + det_img[:new_height, :new_width, :] = resized_img + + scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh) + + scores = np.vstack(scores_list) + scores_ravel = scores.ravel() + order = scores_ravel.argsort()[::-1] + bboxes = np.vstack(bboxes_list) / det_scale + if self.use_kps: + kpss = np.vstack(kpss_list) / det_scale + pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False) + pre_det = pre_det[order, :] + keep = self.nms(pre_det) + det = pre_det[keep, :] + if self.use_kps: + kpss = kpss[order,:,:] + kpss = kpss[keep,:,:] + else: + kpss = None + if max_num > 0 and det.shape[0] > max_num: + area = 
(det[:, 2] - det[:, 0]) * (det[:, 3] - + det[:, 1]) + img_center = img.shape[0] // 2, img.shape[1] // 2 + offsets = np.vstack([ + (det[:, 0] + det[:, 2]) / 2 - img_center[1], + (det[:, 1] + det[:, 3]) / 2 - img_center[0] + ]) + offset_dist_squared = np.sum(np.power(offsets, 2.0), 0) + if metric=='max': + values = area + else: + values = area - offset_dist_squared * 2.0 # some extra weight on the centering + bindex = np.argsort( + values)[::-1] + bindex = bindex[0:max_num] + det = det[bindex, :] + if kpss is not None: + kpss = kpss[bindex, :] + return det, kpss + + def nms(self, dets): + thresh = self.nms_thresh + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep + +def get_retinaface(name, download=False, root='~/.insightface/models', **kwargs): + if not download: + assert os.path.exists(name) + return RetinaFace(name) + else: + from .model_store import get_model_file + _file = get_model_file("retinaface_%s" % name, root=root) + return RetinaFace(_file) + + diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/scrfd.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/scrfd.py new file mode 100644 index 0000000000000000000000000000000000000000..674db4bba761157592dfb95c5d1638da1099f89c --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/model_zoo/scrfd.py @@ -0,0 +1,348 @@ +# -*- coding: utf-8 -*- +# @Organization : insightface.ai +# @Author : Jia Guo +# @Time : 2021-05-04 +# @Function : + +from __future__ import division +import datetime +import numpy as np +import onnx +import onnxruntime +import os +import os.path as osp +import cv2 +import sys + +def softmax(z): + assert len(z.shape) == 2 + s = np.max(z, axis=1) + s = s[:, np.newaxis] # necessary step to do broadcasting + e_x = np.exp(z - s) + div = np.sum(e_x, axis=1) + div = div[:, np.newaxis] # ditto + return e_x / div + +def distance2bbox(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded bboxes. + """ + x1 = points[:, 0] - distance[:, 0] + y1 = points[:, 1] - distance[:, 1] + x2 = points[:, 0] + distance[:, 2] + y2 = points[:, 1] + distance[:, 3] + if max_shape is not None: + x1 = x1.clamp(min=0, max=max_shape[1]) + y1 = y1.clamp(min=0, max=max_shape[0]) + x2 = x2.clamp(min=0, max=max_shape[1]) + y2 = y2.clamp(min=0, max=max_shape[0]) + return np.stack([x1, y1, x2, y2], axis=-1) + +def distance2kps(points, distance, max_shape=None): + """Decode distance prediction to keypoints. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Offsets from the given points to the + keypoint coordinates, interleaved as (x, y) pairs. + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded keypoints. 
+ """ + preds = [] + for i in range(0, distance.shape[1], 2): + px = points[:, i%2] + distance[:, i] + py = points[:, i%2+1] + distance[:, i+1] + if max_shape is not None: + px = px.clamp(min=0, max=max_shape[1]) + py = py.clamp(min=0, max=max_shape[0]) + preds.append(px) + preds.append(py) + return np.stack(preds, axis=-1) + +class SCRFD: + def __init__(self, model_file=None, session=None): + import onnxruntime + self.model_file = model_file + self.session = session + self.taskname = 'detection' + self.batched = False + if self.session is None: + assert self.model_file is not None + assert osp.exists(self.model_file) + self.session = onnxruntime.InferenceSession(self.model_file, None) + self.center_cache = {} + self.nms_thresh = 0.4 + self.det_thresh = 0.5 + self._init_vars() + + def _init_vars(self): + input_cfg = self.session.get_inputs()[0] + input_shape = input_cfg.shape + #print(input_shape) + if isinstance(input_shape[2], str): + self.input_size = None + else: + self.input_size = tuple(input_shape[2:4][::-1]) + #print('image_size:', self.image_size) + input_name = input_cfg.name + self.input_shape = input_shape + outputs = self.session.get_outputs() + if len(outputs[0].shape) == 3: + self.batched = True + output_names = [] + for o in outputs: + output_names.append(o.name) + self.input_name = input_name + self.output_names = output_names + self.input_mean = 127.5 + self.input_std = 128.0 + #print(self.output_names) + #assert len(outputs)==10 or len(outputs)==15 + self.use_kps = False + self._anchor_ratio = 1.0 + self._num_anchors = 1 + if len(outputs)==6: + self.fmc = 3 + self._feat_stride_fpn = [8, 16, 32] + self._num_anchors = 2 + elif len(outputs)==9: + self.fmc = 3 + self._feat_stride_fpn = [8, 16, 32] + self._num_anchors = 2 + self.use_kps = True + elif len(outputs)==10: + self.fmc = 5 + self._feat_stride_fpn = [8, 16, 32, 64, 128] + self._num_anchors = 1 + elif len(outputs)==15: + self.fmc = 5 + self._feat_stride_fpn = [8, 16, 32, 64, 128] + self._num_anchors = 1 + self.use_kps = True + + def prepare(self, ctx_id, **kwargs): + if ctx_id<0: + self.session.set_providers(['CPUExecutionProvider']) + nms_thresh = kwargs.get('nms_thresh', None) + if nms_thresh is not None: + self.nms_thresh = nms_thresh + det_thresh = kwargs.get('det_thresh', None) + if det_thresh is not None: + self.det_thresh = det_thresh + input_size = kwargs.get('input_size', None) + if input_size is not None: + if self.input_size is not None: + print('warning: det_size is already set in scrfd model, ignore') + else: + self.input_size = input_size + + def forward(self, img, threshold): + scores_list = [] + bboxes_list = [] + kpss_list = [] + input_size = tuple(img.shape[0:2][::-1]) + blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + net_outs = self.session.run(self.output_names, {self.input_name : blob}) + + input_height = blob.shape[2] + input_width = blob.shape[3] + fmc = self.fmc + for idx, stride in enumerate(self._feat_stride_fpn): + # If model support batch dim, take first output + if self.batched: + scores = net_outs[idx][0] + bbox_preds = net_outs[idx + fmc][0] + bbox_preds = bbox_preds * stride + if self.use_kps: + kps_preds = net_outs[idx + fmc * 2][0] * stride + # If model doesn't support batching take output as is + else: + scores = net_outs[idx] + bbox_preds = net_outs[idx + fmc] + bbox_preds = bbox_preds * stride + if self.use_kps: + kps_preds = net_outs[idx + fmc * 2] * stride + + height = input_height // stride + 
width = input_width // stride + K = height * width + key = (height, width, stride) + if key in self.center_cache: + anchor_centers = self.center_cache[key] + else: + #solution-1, c style: + #anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 ) + #for i in range(height): + # anchor_centers[i, :, 1] = i + #for i in range(width): + # anchor_centers[:, i, 0] = i + + #solution-2: + #ax = np.arange(width, dtype=np.float32) + #ay = np.arange(height, dtype=np.float32) + #xv, yv = np.meshgrid(np.arange(width), np.arange(height)) + #anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32) + + #solution-3: + anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32) + #print(anchor_centers.shape) + + anchor_centers = (anchor_centers * stride).reshape( (-1, 2) ) + if self._num_anchors>1: + anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) ) + if len(self.center_cache)<100: + self.center_cache[key] = anchor_centers + + pos_inds = np.where(scores>=threshold)[0] + bboxes = distance2bbox(anchor_centers, bbox_preds) + pos_scores = scores[pos_inds] + pos_bboxes = bboxes[pos_inds] + scores_list.append(pos_scores) + bboxes_list.append(pos_bboxes) + if self.use_kps: + kpss = distance2kps(anchor_centers, kps_preds) + #kpss = kps_preds + kpss = kpss.reshape( (kpss.shape[0], -1, 2) ) + pos_kpss = kpss[pos_inds] + kpss_list.append(pos_kpss) + return scores_list, bboxes_list, kpss_list + + def detect(self, img, input_size = None, max_num=0, metric='default'): + assert input_size is not None or self.input_size is not None + input_size = self.input_size if input_size is None else input_size + + im_ratio = float(img.shape[0]) / img.shape[1] + model_ratio = float(input_size[1]) / input_size[0] + if im_ratio>model_ratio: + new_height = input_size[1] + new_width = int(new_height / im_ratio) + else: + new_width = input_size[0] + new_height = int(new_width * im_ratio) + det_scale = float(new_height) / img.shape[0] + resized_img = cv2.resize(img, (new_width, new_height)) + det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 ) + det_img[:new_height, :new_width, :] = resized_img + + scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh) + + scores = np.vstack(scores_list) + scores_ravel = scores.ravel() + order = scores_ravel.argsort()[::-1] + bboxes = np.vstack(bboxes_list) / det_scale + if self.use_kps: + kpss = np.vstack(kpss_list) / det_scale + pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False) + pre_det = pre_det[order, :] + keep = self.nms(pre_det) + det = pre_det[keep, :] + if self.use_kps: + kpss = kpss[order,:,:] + kpss = kpss[keep,:,:] + else: + kpss = None + if max_num > 0 and det.shape[0] > max_num: + area = (det[:, 2] - det[:, 0]) * (det[:, 3] - + det[:, 1]) + img_center = img.shape[0] // 2, img.shape[1] // 2 + offsets = np.vstack([ + (det[:, 0] + det[:, 2]) / 2 - img_center[1], + (det[:, 1] + det[:, 3]) / 2 - img_center[0] + ]) + offset_dist_squared = np.sum(np.power(offsets, 2.0), 0) + if metric=='max': + values = area + else: + values = area - offset_dist_squared * 2.0 # some extra weight on the centering + bindex = np.argsort( + values)[::-1] # some extra weight on the centering + bindex = bindex[0:max_num] + det = det[bindex, :] + if kpss is not None: + kpss = kpss[bindex, :] + return det, kpss + + def nms(self, dets): + thresh = self.nms_thresh + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 
- y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep + +def get_scrfd(name, download=False, root='~/.insightface/models', **kwargs): + if not download: + assert os.path.exists(name) + return SCRFD(name) + else: + from .model_store import get_model_file + _file = get_model_file("scrfd_%s" % name, root=root) + return SCRFD(_file) + + +def scrfd_2p5gkps(**kwargs): + return get_scrfd("2p5gkps", download=True, **kwargs) + + +if __name__ == '__main__': + import glob + detector = SCRFD(model_file='./det.onnx') + detector.prepare(-1) + img_paths = ['tests/data/t1.jpg'] + for img_path in img_paths: + img = cv2.imread(img_path) + + for _ in range(1): + ta = datetime.datetime.now() + #bboxes, kpss = detector.detect(img, input_size = (640, 640)) + bboxes, kpss = detector.detect(img) # detection threshold comes from self.det_thresh + tb = datetime.datetime.now() + print('all cost:', (tb-ta).total_seconds()*1000) + print(img_path, bboxes.shape) + if kpss is not None: + print(kpss.shape) + for i in range(bboxes.shape[0]): + bbox = bboxes[i] + x1,y1,x2,y2,score = bbox.astype(int) + cv2.rectangle(img, (x1,y1) , (x2,y2) , (255,0,0) , 2) + if kpss is not None: + kps = kpss[i] + for kp in kps: + kp = kp.astype(int) + cv2.circle(img, tuple(kp) , 1, (0,0,255) , 2) + filename = img_path.split('/')[-1] + print('output:', filename) + cv2.imwrite('./outputs/%s'%filename, img) + diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__init__.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6960431b1bd6db38890e391c4c94dd2182f2e1fd --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__init__.py @@ -0,0 +1,6 @@ +from __future__ import absolute_import + +from .storage import download, ensure_available, download_onnx +from .filesystem import get_model_dir +from .filesystem import makedirs, try_import_dali +from .constant import * diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/__init__.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9a1d8421a7b2c978abd367b0c52706ca7f4e5d3 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/__init__.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..185cb017258b0287e0a95382a0164d7a3250bbeb Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/constant.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/constant.cpython-310.pyc 
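# Illustrative check of the greedy NMS above, with hypothetical boxes (x1, y1, x2, y2, score):
#   dets = [[10, 10, 50, 50, 0.9], [12, 12, 48, 48, 0.8], [100, 100, 140, 140, 0.7]]
#   IoU(det0, det1) = 1369 / (1681 + 1369 - 1369) ~ 0.81 > nms_thresh (0.4) -> det1 suppressed
#   det0 and det2 do not overlap at all, so det2 survives: keep == [0, 2]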
new file mode 100644 index 0000000000000000000000000000000000000000..0300a80d3e6daea94e2c9538747c6ea9378bf844 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/constant.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/constant.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/constant.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3e3b1220b48dfa53bb3671ad110fdcfc91b15fd Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/constant.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/download.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/download.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab9c27366a94ab842c44304bb216519ac968d0cb Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/download.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/download.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/download.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f866af2607732caf05d2fbb52a429fcd77731a9b Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/download.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/face_align.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/face_align.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a25852709b126462ea4c28e2f9fdc4bb62ec2ff Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/face_align.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/face_align.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/face_align.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb81e06d9ebaa89f1d026b25a1ddf92d30043d96 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/face_align.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/filesystem.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/filesystem.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b13be309628204bc50e87bd600b97634a8251b0f Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/filesystem.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/filesystem.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/filesystem.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec57b384c26f75d02e5386006f129c2bd91be0d2 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/filesystem.cpython-39.pyc differ diff --git 
a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/storage.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/storage.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c22244b0bbc68ddbd04d0a20b15da638ce0644d3 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/storage.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/storage.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/storage.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dccb21e83c0ae7af60442a86c6d4336ca0ed8954 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/storage.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/transform.cpython-310.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44a6ba41e7b44f55bfbea5bb0a72c9c554edb1f8 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/transform.cpython-310.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/transform.cpython-39.pyc b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/transform.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78ea73f1c3e7d8641de4e095c628b9c003c5a095 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/__pycache__/transform.cpython-39.pyc differ diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/constant.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/constant.py new file mode 100644 index 0000000000000000000000000000000000000000..8860ff077ae7227235591edfc84c0cdc227a6432 --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/constant.py @@ -0,0 +1,3 @@ + +DEFAULT_MP_NAME = 'buffalo_l' + diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/download.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/download.py new file mode 100644 index 0000000000000000000000000000000000000000..5cda84dede45b81dcd99161d87792b6c409fa279 --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/download.py @@ -0,0 +1,95 @@ +""" +This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/utils/download.py +""" +import os +import hashlib +import requests +from tqdm import tqdm + + +def check_sha1(filename, sha1_hash): + """Check whether the sha1 hash of the file content matches the expected hash. + Parameters + ---------- + filename : str + Path to the file. + sha1_hash : str + Expected sha1 hash in hexadecimal digits. + Returns + ------- + bool + Whether the file content matches the expected hash. 
+ """ + sha1 = hashlib.sha1() + with open(filename, 'rb') as f: + while True: + data = f.read(1048576) + if not data: + break + sha1.update(data) + + sha1_file = sha1.hexdigest() + l = min(len(sha1_file), len(sha1_hash)) + return sha1.hexdigest()[0:l] == sha1_hash[0:l] + + +def download_file(url, path=None, overwrite=False, sha1_hash=None): + """Download an given URL + Parameters + ---------- + url : str + URL to download + path : str, optional + Destination path to store downloaded file. By default stores to the + current directory with same name as in url. + overwrite : bool, optional + Whether to overwrite destination file if already exists. + sha1_hash : str, optional + Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified + but doesn't match. + Returns + ------- + str + The file path of the downloaded file. + """ + if path is None: + fname = url.split('/')[-1] + else: + path = os.path.expanduser(path) + if os.path.isdir(path): + fname = os.path.join(path, url.split('/')[-1]) + else: + fname = path + + if overwrite or not os.path.exists(fname) or ( + sha1_hash and not check_sha1(fname, sha1_hash)): + dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) + if not os.path.exists(dirname): + os.makedirs(dirname) + + print('Downloading %s from %s...' % (fname, url)) + r = requests.get(url, stream=True) + if r.status_code != 200: + raise RuntimeError("Failed downloading url %s" % url) + total_length = r.headers.get('content-length') + with open(fname, 'wb') as f: + if total_length is None: # no content length header + for chunk in r.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + else: + total_length = int(total_length) + for chunk in tqdm(r.iter_content(chunk_size=1024), + total=int(total_length / 1024. + 0.5), + unit='KB', + unit_scale=False, + dynamic_ncols=True): + f.write(chunk) + + if sha1_hash and not check_sha1(fname, sha1_hash): + raise UserWarning('File {} is downloaded but the content hash does not match. ' \ + 'The repo may be outdated or download may be incomplete. 
' \ + 'If the "repo_url" is overridden, consider switching to ' \ + 'the default repo.'.format(fname)) + + return fname diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/face_align.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/face_align.py new file mode 100644 index 0000000000000000000000000000000000000000..226628b39cf743947df230feffbb97bf5c585e1d --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/face_align.py @@ -0,0 +1,103 @@ +import cv2 +import numpy as np +from skimage import transform as trans + + +arcface_dst = np.array( + [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], + [41.5493, 92.3655], [70.7299, 92.2041]], + dtype=np.float32) + +def estimate_norm(lmk, image_size=112,mode='arcface'): + assert lmk.shape == (5, 2) + assert image_size%112==0 or image_size%128==0 + if image_size%112==0: + ratio = float(image_size)/112.0 + diff_x = 0 + else: + ratio = float(image_size)/128.0 + diff_x = 8.0*ratio + dst = arcface_dst * ratio + dst[:,0] += diff_x + tform = trans.SimilarityTransform() + tform.estimate(lmk, dst) + M = tform.params[0:2, :] + return M + +def norm_crop(img, landmark, image_size=112, mode='arcface'): + M = estimate_norm(landmark, image_size, mode) + warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) + return warped + +def norm_crop2(img, landmark, image_size=112, mode='arcface'): + M = estimate_norm(landmark, image_size, mode) + warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) + return warped, M + +def square_crop(im, S): + if im.shape[0] > im.shape[1]: + height = S + width = int(float(im.shape[1]) / im.shape[0] * S) + scale = float(S) / im.shape[0] + else: + width = S + height = int(float(im.shape[0]) / im.shape[1] * S) + scale = float(S) / im.shape[1] + resized_im = cv2.resize(im, (width, height)) + det_im = np.zeros((S, S, 3), dtype=np.uint8) + det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im + return det_im, scale + + +def transform(data, center, output_size, scale, rotation): + scale_ratio = scale + rot = float(rotation) * np.pi / 180.0 + #translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio) + t1 = trans.SimilarityTransform(scale=scale_ratio) + cx = center[0] * scale_ratio + cy = center[1] * scale_ratio + t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) + t3 = trans.SimilarityTransform(rotation=rot) + t4 = trans.SimilarityTransform(translation=(output_size / 2, + output_size / 2)) + t = t1 + t2 + t3 + t4 + M = t.params[0:2] + cropped = cv2.warpAffine(data, + M, (output_size, output_size), + borderValue=0.0) + return cropped, M + + +def trans_points2d(pts, M): + new_pts = np.zeros(shape=pts.shape, dtype=np.float32) + for i in range(pts.shape[0]): + pt = pts[i] + new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) + new_pt = np.dot(M, new_pt) + #print('new_pt', new_pt.shape, new_pt) + new_pts[i] = new_pt[0:2] + + return new_pts + + +def trans_points3d(pts, M): + scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1]) + #print(scale) + new_pts = np.zeros(shape=pts.shape, dtype=np.float32) + for i in range(pts.shape[0]): + pt = pts[i] + new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) + new_pt = np.dot(M, new_pt) + #print('new_pt', new_pt.shape, new_pt) + new_pts[i][0:2] = new_pt[0:2] + new_pts[i][2] = pts[i][2] * scale + + return new_pts + + +def trans_points(pts, M): + if pts.shape[1] == 2: + return trans_points2d(pts, M) + else: + return 
trans_points3d(pts, M) + diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/filesystem.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/filesystem.py new file mode 100644 index 0000000000000000000000000000000000000000..01e3851975bdcbbf7f5eeb7e68e70a36dc040535 --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/filesystem.py @@ -0,0 +1,157 @@ +""" +This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/utils/filesystem.py +""" +import os +import os.path as osp +import errno + + +def get_model_dir(name, root='~/.insightface'): + root = os.path.expanduser(root) + model_dir = osp.join(root, 'models', name) + return model_dir + +def makedirs(path): + """Create directory recursively if not exists. + Similar to `makedir -p`, you can skip checking existence before this function. + + Parameters + ---------- + path : str + Path of the desired dir + """ + try: + os.makedirs(path) + except OSError as exc: + if exc.errno != errno.EEXIST: + raise + + +def try_import(package, message=None): + """Try import specified package, with custom message support. + + Parameters + ---------- + package : str + The name of the targeting package. + message : str, default is None + If not None, this function will raise customized error message when import error is found. + + + Returns + ------- + module if found, raise ImportError otherwise + + """ + try: + return __import__(package) + except ImportError as e: + if not message: + raise e + raise ImportError(message) + + +def try_import_cv2(): + """Try import cv2 at runtime. + + Returns + ------- + cv2 module if found. Raise ImportError otherwise + + """ + msg = "cv2 is required, you can install by package manager, e.g. 'apt-get', \ + or `pip install opencv-python --user` (note that this is unofficial PYPI package)." + + return try_import('cv2', msg) + + +def try_import_mmcv(): + """Try import mmcv at runtime. + + Returns + ------- + mmcv module if found. Raise ImportError otherwise + + """ + msg = "mmcv is required, you can install by first `pip install Cython --user` \ + and then `pip install mmcv --user` (note that this is unofficial PYPI package)." + + return try_import('mmcv', msg) + + +def try_import_rarfile(): + """Try import rarfile at runtime. + + Returns + ------- + rarfile module if found. Raise ImportError otherwise + + """ + msg = "rarfile is required, you can install by first `sudo apt-get install unrar` \ + and then `pip install rarfile --user` (note that this is unofficial PYPI package)." + + return try_import('rarfile', msg) + + +def import_try_install(package, extern_url=None): + """Try import the specified package. + If the package not installed, try use pip to install and import if success. + + Parameters + ---------- + package : str + The name of the package trying to import. + extern_url : str or None, optional + The external url if package is not hosted on PyPI. + For example, you can install a package using: + "pip install git+http://github.com/user/repo/tarball/master/egginfo=xxx". + In this case, you can pass the url to the extern_url. + + Returns + ------- + + The imported python module. 
+ + """ + try: + return __import__(package) + except ImportError: + try: + from pip import main as pipmain + except ImportError: + from pip._internal import main as pipmain + + # trying to install package + url = package if extern_url is None else extern_url + pipmain(['install', '--user', + url]) # will raise SystemExit Error if fails + + # trying to load again + try: + return __import__(package) + except ImportError: + import sys + import site + user_site = site.getusersitepackages() + if user_site not in sys.path: + sys.path.append(user_site) + return __import__(package) + return __import__(package) + + +def try_import_dali(): + """Try import NVIDIA DALI at runtime. + """ + try: + dali = __import__('nvidia.dali', fromlist=['pipeline', 'ops', 'types']) + dali.Pipeline = dali.pipeline.Pipeline + except ImportError: + + class dali: + class Pipeline: + def __init__(self): + raise NotImplementedError( + "DALI not found, please check if you installed it correctly." + ) + + return dali diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/storage.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/storage.py new file mode 100644 index 0000000000000000000000000000000000000000..5bf37e2d17b28dee2a8839484778815f87fc4a9c --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/storage.py @@ -0,0 +1,52 @@ + +import os +import os.path as osp +import zipfile +from .download import download_file + +BASE_REPO_URL = 'https://github.com/deepinsight/insightface/releases/download/v0.7' + +def download(sub_dir, name, force=False, root='~/.insightface'): + _root = os.path.expanduser(root) + dir_path = os.path.join(_root, sub_dir, name) + if osp.exists(dir_path) and not force: + return dir_path + print('download_path:', dir_path) + zip_file_path = os.path.join(_root, sub_dir, name + '.zip') + model_url = "%s/%s.zip"%(BASE_REPO_URL, name) + download_file(model_url, + path=zip_file_path, + overwrite=True) + if not os.path.exists(dir_path): + os.makedirs(dir_path) + with zipfile.ZipFile(zip_file_path) as zf: + zf.extractall(dir_path) + #os.remove(zip_file_path) + return dir_path + +def ensure_available(sub_dir, name, root='~/.insightface'): + return download(sub_dir, name, force=False, root=root) + +def download_onnx(sub_dir, model_file, force=False, root='~/.insightface', download_zip=False): + _root = os.path.expanduser(root) + model_root = osp.join(_root, sub_dir) + new_model_file = osp.join(model_root, model_file) + if osp.exists(new_model_file) and not force: + return new_model_file + if not osp.exists(model_root): + os.makedirs(model_root) + print('download_path:', new_model_file) + if not download_zip: + model_url = "%s/%s"%(BASE_REPO_URL, model_file) + download_file(model_url, + path=new_model_file, + overwrite=True) + else: + model_url = "%s/%s.zip"%(BASE_REPO_URL, model_file) + zip_file_path = new_model_file+".zip" + download_file(model_url, + path=zip_file_path, + overwrite=True) + with zipfile.ZipFile(zip_file_path) as zf: + zf.extractall(model_root) + return new_model_file diff --git a/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/transform.py b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..06531d257b694211a0b9a09c9d741b9b2ff53bfe --- /dev/null +++ b/subprocess/LivePortrait/src/utils/dependencies/insightface/utils/transform.py @@ -0,0 +1,116 @@ +import cv2 +import math +import numpy as np +from skimage import 
transform as trans + + +def transform(data, center, output_size, scale, rotation): + scale_ratio = scale + rot = float(rotation) * np.pi / 180.0 + #translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio) + t1 = trans.SimilarityTransform(scale=scale_ratio) + cx = center[0] * scale_ratio + cy = center[1] * scale_ratio + t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) + t3 = trans.SimilarityTransform(rotation=rot) + t4 = trans.SimilarityTransform(translation=(output_size / 2, + output_size / 2)) + t = t1 + t2 + t3 + t4 + M = t.params[0:2] + cropped = cv2.warpAffine(data, + M, (output_size, output_size), + borderValue=0.0) + return cropped, M + + +def trans_points2d(pts, M): + new_pts = np.zeros(shape=pts.shape, dtype=np.float32) + for i in range(pts.shape[0]): + pt = pts[i] + new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) + new_pt = np.dot(M, new_pt) + #print('new_pt', new_pt.shape, new_pt) + new_pts[i] = new_pt[0:2] + + return new_pts + + +def trans_points3d(pts, M): + scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1]) + #print(scale) + new_pts = np.zeros(shape=pts.shape, dtype=np.float32) + for i in range(pts.shape[0]): + pt = pts[i] + new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) + new_pt = np.dot(M, new_pt) + #print('new_pt', new_pt.shape, new_pt) + new_pts[i][0:2] = new_pt[0:2] + new_pts[i][2] = pts[i][2] * scale + + return new_pts + + +def trans_points(pts, M): + if pts.shape[1] == 2: + return trans_points2d(pts, M) + else: + return trans_points3d(pts, M) + +def estimate_affine_matrix_3d23d(X, Y): + ''' Using least-squares solution + Args: + X: [n, 3]. 3d points(fixed) + Y: [n, 3]. corresponding 3d points(moving). Y = PX + Returns: + P_Affine: (3, 4). Affine camera matrix (the third row is [0, 0, 0, 1]). + ''' + X_homo = np.hstack((X, np.ones([X.shape[0],1]))) #n x 4 + P = np.linalg.lstsq(X_homo, Y)[0].T # Affine matrix. 3 x 4 + return P + +def P2sRt(P): + ''' decompositing camera matrix P + Args: + P: (3, 4). Affine Camera Matrix. + Returns: + s: scale factor. + R: (3, 3). rotation matrix. + t: (3,). translation. + ''' + t = P[:, 3] + R1 = P[0:1, :3] + R2 = P[1:2, :3] + s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2.0 + r1 = R1/np.linalg.norm(R1) + r2 = R2/np.linalg.norm(R2) + r3 = np.cross(r1, r2) + + R = np.concatenate((r1, r2, r3), 0) + return s, R, t + +def matrix2angle(R): + ''' get three Euler angles from Rotation Matrix + Args: + R: (3,3). 
rotation matrix + Returns: + x: pitch + y: yaw + z: roll + ''' + sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0]) + + singular = sy < 1e-6 + + if not singular : + x = math.atan2(R[2,1] , R[2,2]) + y = math.atan2(-R[2,0], sy) + z = math.atan2(R[1,0], R[0,0]) + else : + x = math.atan2(-R[1,2], R[1,1]) + y = math.atan2(-R[2,0], sy) + z = 0 + + # rx, ry, rz = np.rad2deg(x), np.rad2deg(y), np.rad2deg(z) + rx, ry, rz = x*180/np.pi, y*180/np.pi, z*180/np.pi + return rx, ry, rz + diff --git a/subprocess/LivePortrait/src/utils/face_analysis_diy.py b/subprocess/LivePortrait/src/utils/face_analysis_diy.py new file mode 100644 index 0000000000000000000000000000000000000000..f13a659134216958da3c7273aabf3b0f96fb320d --- /dev/null +++ b/subprocess/LivePortrait/src/utils/face_analysis_diy.py @@ -0,0 +1,79 @@ +# coding: utf-8 + +""" +face detectoin and alignment using InsightFace +""" + +import numpy as np +from .rprint import rlog as log +from .dependencies.insightface.app import FaceAnalysis +from .dependencies.insightface.app.common import Face +from .timer import Timer + + +def sort_by_direction(faces, direction: str = 'large-small', face_center=None): + if len(faces) <= 0: + return faces + + if direction == 'left-right': + return sorted(faces, key=lambda face: face['bbox'][0]) + if direction == 'right-left': + return sorted(faces, key=lambda face: face['bbox'][0], reverse=True) + if direction == 'top-bottom': + return sorted(faces, key=lambda face: face['bbox'][1]) + if direction == 'bottom-top': + return sorted(faces, key=lambda face: face['bbox'][1], reverse=True) + if direction == 'small-large': + return sorted(faces, key=lambda face: (face['bbox'][2] - face['bbox'][0]) * (face['bbox'][3] - face['bbox'][1])) + if direction == 'large-small': + return sorted(faces, key=lambda face: (face['bbox'][2] - face['bbox'][0]) * (face['bbox'][3] - face['bbox'][1]), reverse=True) + if direction == 'distance-from-retarget-face': + return sorted(faces, key=lambda face: (((face['bbox'][2]+face['bbox'][0])/2-face_center[0])**2+((face['bbox'][3]+face['bbox'][1])/2-face_center[1])**2)**0.5) + return faces + + +class FaceAnalysisDIY(FaceAnalysis): + def __init__(self, name='buffalo_l', root='~/.insightface', allowed_modules=None, **kwargs): + super().__init__(name=name, root=root, allowed_modules=allowed_modules, **kwargs) + + self.timer = Timer() + + def get(self, img_bgr, **kwargs): + max_num = kwargs.get('max_face_num', 0) # the number of the detected faces, 0 means no limit + flag_do_landmark_2d_106 = kwargs.get('flag_do_landmark_2d_106', True) # whether to do 106-point detection + direction = kwargs.get('direction', 'large-small') # sorting direction + face_center = None + + bboxes, kpss = self.det_model.detect(img_bgr, max_num=max_num, metric='default') + if bboxes.shape[0] == 0: + return [] + ret = [] + for i in range(bboxes.shape[0]): + bbox = bboxes[i, 0:4] + det_score = bboxes[i, 4] + kps = None + if kpss is not None: + kps = kpss[i] + face = Face(bbox=bbox, kps=kps, det_score=det_score) + for taskname, model in self.models.items(): + if taskname == 'detection': + continue + + if (not flag_do_landmark_2d_106) and taskname == 'landmark_2d_106': + continue + + # print(f'taskname: {taskname}') + model.get(img_bgr, face) + ret.append(face) + + ret = sort_by_direction(ret, direction, face_center) + return ret + + def warmup(self): + self.timer.tic() + + img_bgr = np.zeros((512, 512, 3), dtype=np.uint8) + self.get(img_bgr) + + elapse = self.timer.toc() + log(f'FaceAnalysisDIY warmup time: {elapse:.3f}s') diff 
--git a/subprocess/LivePortrait/src/utils/helper.py b/subprocess/LivePortrait/src/utils/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..0e2af94e137b6447c88ec4df3c7c2c1b1bd94b8a --- /dev/null +++ b/subprocess/LivePortrait/src/utils/helper.py @@ -0,0 +1,145 @@ +# coding: utf-8 + +""" +utility functions and classes to handle feature extraction and model loading +""" + +import os +import os.path as osp +import torch +from collections import OrderedDict + +from ..modules.spade_generator import SPADEDecoder +from ..modules.warping_network import WarpingNetwork +from ..modules.motion_extractor import MotionExtractor +from ..modules.appearance_feature_extractor import AppearanceFeatureExtractor +from ..modules.stitching_retargeting_network import StitchingRetargetingNetwork + + +def suffix(filename): + """a.jpg -> jpg""" + pos = filename.rfind(".") + if pos == -1: + return "" + return filename[pos + 1:] + + +def prefix(filename): + """a.jpg -> a""" + pos = filename.rfind(".") + if pos == -1: + return filename + return filename[:pos] + + +def basename(filename): + """a/b/c.jpg -> c""" + return prefix(osp.basename(filename)) + + +def remove_suffix(filepath): + """a/b/c.jpg -> a/b/c""" + return osp.join(osp.dirname(filepath), basename(filepath)) + + +def is_video(file_path): + if file_path.lower().endswith((".mp4", ".mov", ".avi", ".webm")) or osp.isdir(file_path): + return True + return False + + +def is_template(file_path): + if file_path.endswith(".pkl"): + return True + return False + + +def mkdir(d, log=False): + # return self-assined `d`, for one line code + if not osp.exists(d): + os.makedirs(d, exist_ok=True) + if log: + print(f"Make dir: {d}") + return d + + +def squeeze_tensor_to_numpy(tensor): + out = tensor.data.squeeze(0).cpu().numpy() + return out + + +def dct2device(dct: dict, device): + for key in dct: + dct[key] = torch.tensor(dct[key]).to(device) + return dct + + +def concat_feat(kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor: + """ + kp_source: (bs, k, 3) + kp_driving: (bs, k, 3) + Return: (bs, 2k*3) + """ + bs_src = kp_source.shape[0] + bs_dri = kp_driving.shape[0] + assert bs_src == bs_dri, 'batch size must be equal' + + feat = torch.cat([kp_source.view(bs_src, -1), kp_driving.view(bs_dri, -1)], dim=1) + return feat + + +def remove_ddp_dumplicate_key(state_dict): + state_dict_new = OrderedDict() + for key in state_dict.keys(): + state_dict_new[key.replace('module.', '')] = state_dict[key] + return state_dict_new + + +def load_model(ckpt_path, model_config, device, model_type): + model_params = model_config['model_params'][f'{model_type}_params'] + + if model_type == 'appearance_feature_extractor': + model = AppearanceFeatureExtractor(**model_params).to(device) + elif model_type == 'motion_extractor': + model = MotionExtractor(**model_params).to(device) + elif model_type == 'warping_module': + model = WarpingNetwork(**model_params).to(device) + elif model_type == 'spade_generator': + model = SPADEDecoder(**model_params).to(device) + elif model_type == 'stitching_retargeting_module': + # Special handling for stitching and retargeting module + config = model_config['model_params']['stitching_retargeting_module_params'] + checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage) + + stitcher = StitchingRetargetingNetwork(**config.get('stitching')) + stitcher.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_shoulder'])) + stitcher = stitcher.to(device) + stitcher.eval() + + retargetor_lip = 
StitchingRetargetingNetwork(**config.get('lip')) + retargetor_lip.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_mouth'])) + retargetor_lip = retargetor_lip.to(device) + retargetor_lip.eval() + + retargetor_eye = StitchingRetargetingNetwork(**config.get('eye')) + retargetor_eye.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_eye'])) + retargetor_eye = retargetor_eye.to(device) + retargetor_eye.eval() + + return { + 'stitching': stitcher, + 'lip': retargetor_lip, + 'eye': retargetor_eye + } + else: + raise ValueError(f"Unknown model type: {model_type}") + + model.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage)) + model.eval() + return model + + +def load_description(fp): + with open(fp, 'r', encoding='utf-8') as f: + content = f.read() + return content diff --git a/subprocess/LivePortrait/src/utils/io.py b/subprocess/LivePortrait/src/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..28c2d99f09421fc9eb1f6475419cb1c6e6dcd028 --- /dev/null +++ b/subprocess/LivePortrait/src/utils/io.py @@ -0,0 +1,125 @@ +# coding: utf-8 + +import os +from glob import glob +import os.path as osp +import imageio +import numpy as np +import pickle +import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) + +from .helper import mkdir, suffix + + +def load_image_rgb(image_path: str): + if not osp.exists(image_path): + raise FileNotFoundError(f"Image not found: {image_path}") + img = cv2.imread(image_path, cv2.IMREAD_COLOR) + return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + +def load_driving_info(driving_info): + driving_video_ori = [] + + def load_images_from_directory(directory): + image_paths = sorted(glob(osp.join(directory, '*.png')) + glob(osp.join(directory, '*.jpg'))) + return [load_image_rgb(im_path) for im_path in image_paths] + + def load_images_from_video(file_path): + reader = imageio.get_reader(file_path, "ffmpeg") + return [image for _, image in enumerate(reader)] + + if osp.isdir(driving_info): + driving_video_ori = load_images_from_directory(driving_info) + elif osp.isfile(driving_info): + driving_video_ori = load_images_from_video(driving_info) + + return driving_video_ori + + +def contiguous(obj): + if not obj.flags.c_contiguous: + obj = obj.copy(order="C") + return obj + + +def resize_to_limit(img: np.ndarray, max_dim=1920, division=2): + """ + ajust the size of the image so that the maximum dimension does not exceed max_dim, and the width and the height of the image are multiples of n. + :param img: the image to be processed. + :param max_dim: the maximum dimension constraint. + :param n: the number that needs to be multiples of. + :return: the adjusted image. 
+ """ + h, w = img.shape[:2] + + # ajust the size of the image according to the maximum dimension + if max_dim > 0 and max(h, w) > max_dim: + if h > w: + new_h = max_dim + new_w = int(w * (max_dim / h)) + else: + new_w = max_dim + new_h = int(h * (max_dim / w)) + img = cv2.resize(img, (new_w, new_h)) + + # ensure that the image dimensions are multiples of n + division = max(division, 1) + new_h = img.shape[0] - (img.shape[0] % division) + new_w = img.shape[1] - (img.shape[1] % division) + + if new_h == 0 or new_w == 0: + # when the width or height is less than n, no need to process + return img + + if new_h != img.shape[0] or new_w != img.shape[1]: + img = img[:new_h, :new_w] + + return img + + +def load_img_online(obj, mode="bgr", **kwargs): + max_dim = kwargs.get("max_dim", 1920) + n = kwargs.get("n", 2) + if isinstance(obj, str): + if mode.lower() == "gray": + img = cv2.imread(obj, cv2.IMREAD_GRAYSCALE) + else: + img = cv2.imread(obj, cv2.IMREAD_COLOR) + else: + img = obj + + # Resize image to satisfy constraints + img = resize_to_limit(img, max_dim=max_dim, division=n) + + if mode.lower() == "bgr": + return contiguous(img) + elif mode.lower() == "rgb": + return contiguous(img[..., ::-1]) + else: + raise Exception(f"Unknown mode {mode}") + + +def load(fp): + suffix_ = suffix(fp) + + if suffix_ == "npy": + return np.load(fp) + elif suffix_ == "pkl": + return pickle.load(open(fp, "rb")) + else: + raise Exception(f"Unknown type: {suffix}") + + +def dump(wfp, obj): + wd = osp.split(wfp)[0] + if wd != "" and not osp.exists(wd): + mkdir(wd) + + _suffix = suffix(wfp) + if _suffix == "npy": + np.save(wfp, obj) + elif _suffix == "pkl": + pickle.dump(obj, open(wfp, "wb")) + else: + raise Exception("Unknown type: {}".format(_suffix)) diff --git a/subprocess/LivePortrait/src/utils/landmark_runner.py b/subprocess/LivePortrait/src/utils/landmark_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..7680a2c4a65ebe7f4dadbafc4a35603ab9f90be6 --- /dev/null +++ b/subprocess/LivePortrait/src/utils/landmark_runner.py @@ -0,0 +1,89 @@ +# coding: utf-8 + +import os.path as osp +import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) +import torch +import numpy as np +import onnxruntime +from .timer import Timer +from .rprint import rlog +from .crop import crop_image, _transform_pts + + +def make_abs_path(fn): + return osp.join(osp.dirname(osp.realpath(__file__)), fn) + + +def to_ndarray(obj): + if isinstance(obj, torch.Tensor): + return obj.cpu().numpy() + elif isinstance(obj, np.ndarray): + return obj + else: + return np.array(obj) + + +class LandmarkRunner(object): + """landmark runner""" + + def __init__(self, **kwargs): + ckpt_path = kwargs.get('ckpt_path') + onnx_provider = kwargs.get('onnx_provider', 'cuda') # 默认用cuda + device_id = kwargs.get('device_id', 0) + self.dsize = kwargs.get('dsize', 224) + self.timer = Timer() + + if onnx_provider.lower() == 'cuda': + self.session = onnxruntime.InferenceSession( + ckpt_path, providers=[ + ('CUDAExecutionProvider', {'device_id': device_id}) + ] + ) + else: + opts = onnxruntime.SessionOptions() + opts.intra_op_num_threads = 4 # 默认线程数为 4 + self.session = onnxruntime.InferenceSession( + ckpt_path, providers=['CPUExecutionProvider'], + sess_options=opts + ) + + def _run(self, inp): + out = self.session.run(None, {'input': inp}) + return out + + def run(self, img_rgb: np.ndarray, lmk=None): + if lmk is not None: + crop_dct = crop_image(img_rgb, lmk, dsize=self.dsize, scale=1.5, vy_ratio=-0.1) + img_crop_rgb = crop_dct['img_crop'] + 
else: + # NOTE: force resize to 224x224, NOT RECOMMEND! + img_crop_rgb = cv2.resize(img_rgb, (self.dsize, self.dsize)) + scale = max(img_rgb.shape[:2]) / self.dsize + crop_dct = { + 'M_c2o': np.array([ + [scale, 0., 0.], + [0., scale, 0.], + [0., 0., 1.], + ], dtype=np.float32), + } + + inp = (img_crop_rgb.astype(np.float32) / 255.).transpose(2, 0, 1)[None, ...] # HxWx3 (BGR) -> 1x3xHxW (RGB!) + + out_lst = self._run(inp) + out_pts = out_lst[2] + + # 2d landmarks 203 points + lmk = to_ndarray(out_pts[0]).reshape(-1, 2) * self.dsize # scale to 0-224 + lmk = _transform_pts(lmk, M=crop_dct['M_c2o']) + + return lmk + + def warmup(self): + self.timer.tic() + + dummy_image = np.zeros((1, 3, self.dsize, self.dsize), dtype=np.float32) + + _ = self._run(dummy_image) + + elapse = self.timer.toc() + rlog(f'LandmarkRunner warmup time: {elapse:.3f}s') diff --git a/subprocess/LivePortrait/src/utils/resources/mask_template.png b/subprocess/LivePortrait/src/utils/resources/mask_template.png new file mode 100644 index 0000000000000000000000000000000000000000..bca6ca5977ba820d0d2c05b3793c6231cc82e715 Binary files /dev/null and b/subprocess/LivePortrait/src/utils/resources/mask_template.png differ diff --git a/subprocess/LivePortrait/src/utils/retargeting_utils.py b/subprocess/LivePortrait/src/utils/retargeting_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ae2e5f52effe8107503586c9f5a24f39dfdbbbcf --- /dev/null +++ b/subprocess/LivePortrait/src/utils/retargeting_utils.py @@ -0,0 +1,24 @@ + +""" +Functions to compute distance ratios between specific pairs of facial landmarks +""" + +import numpy as np + + +def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray: + return (np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True) / + (np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True) + eps)) + + +def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -> np.ndarray: + lefteye_close_ratio = calculate_distance_ratio(lmk, 6, 18, 0, 12) + righteye_close_ratio = calculate_distance_ratio(lmk, 30, 42, 24, 36) + if target_eye_ratio is not None: + return np.concatenate([lefteye_close_ratio, righteye_close_ratio, target_eye_ratio], axis=1) + else: + return np.concatenate([lefteye_close_ratio, righteye_close_ratio], axis=1) + + +def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray: + return calculate_distance_ratio(lmk, 90, 102, 48, 66) diff --git a/subprocess/LivePortrait/src/utils/rprint.py b/subprocess/LivePortrait/src/utils/rprint.py new file mode 100644 index 0000000000000000000000000000000000000000..c43a42f9855bbb019725e6c2b6c6c50e6fa4d0c5 --- /dev/null +++ b/subprocess/LivePortrait/src/utils/rprint.py @@ -0,0 +1,16 @@ +# coding: utf-8 + +""" +custom print and log functions +""" + +__all__ = ['rprint', 'rlog'] + +try: + from rich.console import Console + console = Console() + rprint = console.print + rlog = console.log +except: + rprint = print + rlog = print diff --git a/subprocess/LivePortrait/src/utils/timer.py b/subprocess/LivePortrait/src/utils/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..3570fa45d3ff36376471b82a5b3c02efe46eed98 --- /dev/null +++ b/subprocess/LivePortrait/src/utils/timer.py @@ -0,0 +1,29 @@ +# coding: utf-8 + +""" +tools to measure elapsed time +""" + +import time + +class Timer(object): + """A simple timer.""" + + def __init__(self): + self.total_time = 0. + self.calls = 0 + self.start_time = 0. + self.diff = 0. 
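# Hypothetical usage sketch for this Timer (tic/toc are defined just below):
#   t = Timer()
#   t.tic()
#   run_inference()              # placeholder for any timed workload
#   print(f'{t.toc():.3f}s')     # toc returns the seconds elapsed since tic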
+ + def tic(self): + # using time.time instead of time.clock because time time.clock + # does not normalize for multithreading + self.start_time = time.time() + + def toc(self, average=True): + self.diff = time.time() - self.start_time + return self.diff + + def clear(self): + self.start_time = 0. + self.diff = 0. diff --git a/subprocess/LivePortrait/src/utils/video.py b/subprocess/LivePortrait/src/utils/video.py new file mode 100644 index 0000000000000000000000000000000000000000..a69841238f27f0259a67f0290eb5807dd2087efa --- /dev/null +++ b/subprocess/LivePortrait/src/utils/video.py @@ -0,0 +1,212 @@ +# coding: utf-8 + +""" +functions for processing video +""" + +import os.path as osp +import numpy as np +import subprocess +import imageio +import cv2 + +from .rprint import rlog as log + +# try: +# import ffmpeg +# except ImportError as e: +# log(f'Try to install ffmpeg by: pip install ffmpeg-python==0.2.0', style='bold red') +# raise(e) + +from rich.progress import track +from .helper import prefix +from .rprint import rprint as print + + + +def exec_cmd(cmd): + subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + + +def images2video(images, wfp, **kwargs): + fps = kwargs.get('fps', 30) + video_format = kwargs.get('format', 'mp4') # default is mp4 format + codec = kwargs.get('codec', 'libx264') # default is libx264 encoding + quality = kwargs.get('quality') # video quality + pixelformat = kwargs.get('pixelformat', 'yuv420p') # video pixel format + image_mode = kwargs.get('image_mode', 'rgb') + macro_block_size = kwargs.get('macro_block_size', 2) + ffmpeg_params = ['-crf', str(kwargs.get('crf', 18))] + + writer = imageio.get_writer( + wfp, fps=fps, format=video_format, + codec=codec, quality=quality, ffmpeg_params=ffmpeg_params, pixelformat=pixelformat, macro_block_size=macro_block_size + ) + + n = len(images) + for i in track(range(n), description='Writing', transient=True): + if image_mode.lower() == 'bgr': + writer.append_data(images[i][..., ::-1]) + else: + writer.append_data(images[i]) + + writer.close() + + +def video2gif(video_fp, fps=30, size=256): + if osp.exists(video_fp): + d = osp.split(video_fp)[0] + fn = prefix(osp.basename(video_fp)) + palette_wfp = osp.join(d, 'palette.png') + gif_wfp = osp.join(d, f'{fn}.gif') + # generate the palette + cmd = f'ffmpeg -i {video_fp} -vf "fps={fps},scale={size}:-1:flags=lanczos,palettegen" {palette_wfp} -y' + exec_cmd(cmd) + # use the palette to generate the gif + cmd = f'ffmpeg -i {video_fp} -i {palette_wfp} -filter_complex "fps={fps},scale={size}:-1:flags=lanczos[x];[x][1:v]paletteuse" {gif_wfp} -y' + exec_cmd(cmd) + else: + print(f'video_fp: {video_fp} not exists!') + + +def merge_audio_video(video_fp, audio_fp, wfp): + if osp.exists(video_fp) and osp.exists(audio_fp): + cmd = f'ffmpeg -i {video_fp} -i {audio_fp} -c:v copy -c:a aac {wfp} -y' + exec_cmd(cmd) + print(f'merge {video_fp} and {audio_fp} to {wfp}') + else: + print(f'video_fp: {video_fp} or audio_fp: {audio_fp} not exists!') + + +def blend(img: np.ndarray, mask: np.ndarray, background_color=(255, 255, 255)): + mask_float = mask.astype(np.float32) / 255. 
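# Note for clarity: the line above rescales the 8-bit mask to [0, 1]; the lines
# below are plain alpha compositing, e.g. a mask value of 128 yields roughly
# 0.5 * img + 0.5 * background_color at that pixel.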
+ background_color = np.array(background_color).reshape([1, 1, 3]) + bg = np.ones_like(img) * background_color + img = np.clip(mask_float * img + (1 - mask_float) * bg, 0, 255).astype(np.uint8) + return img + + +def concat_frames(driving_image_lst, source_image, I_p_lst): + # TODO: add more concat style, e.g., left-down corner driving + out_lst = [] + h, w, _ = I_p_lst[0].shape + + for idx, _ in track(enumerate(I_p_lst), total=len(I_p_lst), description='Concatenating result...'): + I_p = I_p_lst[idx] + source_image_resized = cv2.resize(source_image, (w, h)) + + if driving_image_lst is None: + out = np.hstack((source_image_resized, I_p)) + else: + driving_image = driving_image_lst[idx] + driving_image_resized = cv2.resize(driving_image, (w, h)) + out = np.hstack((driving_image_resized, source_image_resized, I_p)) + + out_lst.append(out) + return out_lst + + +class VideoWriter: + def __init__(self, **kwargs): + self.fps = kwargs.get('fps', 30) + self.wfp = kwargs.get('wfp', 'video.mp4') + self.video_format = kwargs.get('format', 'mp4') + self.codec = kwargs.get('codec', 'libx264') + self.quality = kwargs.get('quality') + self.pixelformat = kwargs.get('pixelformat', 'yuv420p') + self.image_mode = kwargs.get('image_mode', 'rgb') + self.ffmpeg_params = kwargs.get('ffmpeg_params') + + self.writer = imageio.get_writer( + self.wfp, fps=self.fps, format=self.video_format, + codec=self.codec, quality=self.quality, + ffmpeg_params=self.ffmpeg_params, pixelformat=self.pixelformat + ) + + def write(self, image): + if self.image_mode.lower() == 'bgr': + self.writer.append_data(image[..., ::-1]) + else: + self.writer.append_data(image) + + def close(self): + if self.writer is not None: + self.writer.close() + + +def change_video_fps(input_file, output_file, fps=20, codec='libx264', crf=5): + cmd = f"ffmpeg -i {input_file} -c:v {codec} -crf {crf} -r {fps} {output_file} -y" + exec_cmd(cmd) + + +def get_fps(filepath, default_fps=25): + try: + fps = cv2.VideoCapture(filepath).get(cv2.CAP_PROP_FPS) + + if fps in (0, None): + fps = default_fps + except Exception as e: + print(e) + fps = default_fps + + return fps + + +def has_audio_stream(video_path: str) -> bool: + """ + Check if the video file contains an audio stream. 
+ + :param video_path: Path to the video file + :return: True if the video contains an audio stream, False otherwise + """ + if osp.isdir(video_path): + return False + + cmd = [ + 'ffprobe', + '-v', 'error', + '-select_streams', 'a', + '-show_entries', 'stream=codec_type', + '-of', 'default=noprint_wrappers=1:nokey=1', + video_path + ] + + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + log(f"Error occurred while probing video: {result.stderr}") + return False + + # Check if there is any output from ffprobe command + return bool(result.stdout.strip()) + + +def add_audio_to_video(silent_video_path: str, audio_video_path: str, output_video_path: str): + cmd = [ + 'ffmpeg', + '-y', + '-i', silent_video_path, + '-i', audio_video_path, + '-map', '0:v', + '-map', '1:a', + '-c:v', 'copy', + '-shortest', + output_video_path + ] + + try: + exec_cmd(' '.join(cmd)) + log(f"Video with audio generated successfully: {output_video_path}") + except subprocess.CalledProcessError as e: + log(f"Error occurred: {e}") + + +def bb_intersection_over_union(boxA, boxB): + xA = max(boxA[0], boxB[0]) + yA = max(boxA[1], boxB[1]) + xB = min(boxA[2], boxB[2]) + yB = min(boxA[3], boxB[3]) + interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1) + boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1) + boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1) + iou = interArea / float(boxAArea + boxBArea - interArea) + return iou diff --git a/subprocess/LivePortrait/src/utils/viz.py b/subprocess/LivePortrait/src/utils/viz.py new file mode 100644 index 0000000000000000000000000000000000000000..59443cbf207f3395bee241f63c7acb95b9402530 --- /dev/null +++ b/subprocess/LivePortrait/src/utils/viz.py @@ -0,0 +1,19 @@ +# coding: utf-8 + +import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) + + +def viz_lmk(img_, vps, **kwargs): + """Visualize landmark points.""" + lineType = kwargs.get("lineType", cv2.LINE_8) # cv2.LINE_AA + img_for_viz = img_.copy() + for pt in vps: + cv2.circle( + img_for_viz, + (int(pt[0]), int(pt[1])), + radius=kwargs.get("radius", 1), + color=(0, 255, 0), + thickness=kwargs.get("thickness", 1), + lineType=lineType, + ) + return img_for_viz diff --git a/subprocess/LivePortrait/uploads/d6.mp4 b/subprocess/LivePortrait/uploads/d6.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..44f351385cef843b21b03fab8c3b10e0c005ec5e --- /dev/null +++ b/subprocess/LivePortrait/uploads/d6.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:00e3ea79bbf28cbdc4fbb67ec655d9a0fe876e880ec45af55ae481348d0c0fff +size 1967790 diff --git a/subprocess/LivePortrait/uploads/image(3).png b/subprocess/LivePortrait/uploads/image(3).png new file mode 100644 index 0000000000000000000000000000000000000000..bb0ee94c02976ada9d7c4857109d338bfa9b52d7 Binary files /dev/null and b/subprocess/LivePortrait/uploads/image(3).png differ diff --git a/subprocess/LivePortrait/uploads/intro.mp4 b/subprocess/LivePortrait/uploads/intro.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2d220c12b25a494a9d25ceb1d7f5e6f1413f8b6d --- /dev/null +++ b/subprocess/LivePortrait/uploads/intro.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a984facdcb2c83b29cb6bee2888fd65f1d624d2a234125c83368f381ab4a0ceb +size 4368949 diff --git a/subprocess/port.db b/subprocess/port.db new file mode 100644 index 0000000000000000000000000000000000000000..ae3ccc2bd090134988f23e7aa346ab7d5ff937fe Binary files /dev/null and b/subprocess/port.db differ
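# Illustrative standalone check (a minimal sketch assuming nothing beyond the
# stdlib; it mirrors the inclusive-pixel IoU convention of
# bb_intersection_over_union in video.py above):
def _iou(boxA, boxB):
    # corners of the intersection rectangle
    xA, yA = max(boxA[0], boxB[0]), max(boxA[1], boxB[1])
    xB, yB = min(boxA[2], boxB[2]), min(boxA[3], boxB[3])
    # the +1 terms count pixel edges inclusively, matching the function above
    inter = max(0, xB - xA + 1) * max(0, yB - yA + 1)
    areaA = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    areaB = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
    return inter / float(areaA + areaB - inter)

# inter = 6 * 6 = 36, union = 121 + 121 - 36 = 206, so IoU ~ 0.175
assert abs(_iou([0, 0, 10, 10], [5, 5, 15, 15]) - 36 / 206.0) < 1e-9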