diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b8ee441f1ff658fa2bc9037d198fb8f6c80b7bcd --- /dev/null +++ b/.gitignore @@ -0,0 +1,22 @@ +__pycache__/ +build/ +*.egg-info/ +*.so +*.mp4 +*.pth + +data_utils/face_tracking/3DMM/* +data_utils/face_parsing/79999_iter.pth + +*.pyc +.vscode +output* +build +gridencoder/gridencoder.egg-info +diff_rasterization/diff_rast.egg-info +diff_rasterization/dist +tensorboard_3d +screenshots + +data/* +!*.gitkeep \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..1bd74a0608dd0e96a8bba37973875869a27c65bb --- /dev/null +++ b/.gitmodules @@ -0,0 +1,6 @@ +[submodule "submodules/simple-knn"] + path = submodules/simple-knn + url = https://gitlab.inria.fr/bkerbl/simple-knn.git +[submodule "submodules/diff-gaussian-rasterization"] + path = submodules/diff-gaussian-rasterization + url = https://github.com/ashawkey/diff-gaussian-rasterization.git diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a6d2215bba7dcfe7bd36a313d0a41ffd4e47bbd1 --- /dev/null +++ b/README.md @@ -0,0 +1,153 @@ +# TalkingGaussian: Structure-Persistent 3D Talking Head Synthesis via Gaussian Splatting + +This is the official repository for our ECCV 2024 paper **TalkingGaussian: Structure-Persistent 3D Talking Head Synthesis via Gaussian Splatting**. + +[Paper](https://arxiv.org/abs/2404.15264) | [Project](https://fictionarry.github.io/TalkingGaussian/) | [Video](https://youtu.be/c5VG7HkDs8I) + +![image](./assets/main.png) + + +## Installation + +Tested on Ubuntu 18.04, CUDA 11.3, PyTorch 1.12.1 + +``` +git clone git@github.com:Fictionarry/TalkingGaussian.git --recursive + +conda env create --file environment.yml +conda activate talking_gaussian +pip install "git+https://github.com/facebookresearch/pytorch3d.git" +pip install tensorflow-gpu==2.8.0 +``` + +If encounter installation problem from the `diff-gaussian-rasterization` or `gridencoder`, please refer to [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting) and [torch-ngp](https://github.com/ashawkey/torch-ngp). + +### Preparation + +- Prepare face-parsing model and the 3DMM model for head pose estimation. + + ```bash + bash scripts/prepare.sh + ``` + +- Download 3DMM model from [Basel Face Model 2009](https://faces.dmi.unibas.ch/bfm/main.php?nav=1-1-0&id=details): + + ```bash + # 1. copy 01_MorphableModel.mat to data_util/face_tracking/3DMM/ + # 2. run following + cd data_utils/face_tracking + python convert_BFM.py + ``` + +- Prepare the environment for [EasyPortrait](https://github.com/hukenovs/easyportrait): + + ```bash + # prepare mmcv + conda activate talking_gaussian + pip install -U openmim + mim install mmcv-full==1.7.1 + + # download model weight + cd data_utils/easyportrait + wget "https://n-ws-620xz-pd11.s3pd11.sbercloud.ru/b-ws-620xz-pd11-jux/easyportrait/experiments/models/fpn-fp-512.pth" + ``` + +## Usage + +### Important Notice + +- This code is provided for research purposes only. The author makes no warranties, express or implied, as to the accuracy, completeness, or fitness for a particular purpose of the code. Use this code at your own risk. + +- The author explicitly prohibits the use of this code for any malicious or illegal activities. By using this code, you agree to comply with all applicable laws and regulations, and you agree not to use it to harm others or to perform any actions that would be considered unethical or illegal. + +- The author will not be responsible for any damages, losses, or issues that arise from the use of this code. + +- Users are encouraged to use this code responsibly and ethically. + +### Video Dataset +[Here](https://drive.google.com/drive/folders/1E_8W805lioIznqbkvTQHWWi5IFXUG7Er?usp=drive_link) we provide two video clips used in our experiments, which are captured from YouTube. Please respect the original content creators' rights and comply with YouTube’s copyright policies in the usage. + +Other used videos can be found from [GeneFace](https://github.com/yerfor/GeneFace) and [AD-NeRF](https://github.com/YudongGuo/AD-NeRF). + + +### Pre-processing Training Video + +* Put training video under `data//.mp4`. + + The video **must be 25FPS, with all frames containing the talking person**. + The resolution should be about 512x512, and duration about 1-5 min. + +* Run script to process the video. + + ```bash + python data_utils/process.py data//.mp4 + ``` + +* Obtain Action Units + + Run `FeatureExtraction` in [OpenFace](https://github.com/TadasBaltrusaitis/OpenFace), rename and move the output CSV file to `data//au.csv`. + +* Generate tooth masks + + ```bash + export PYTHONPATH=./data_utils/easyportrait + python ./data_utils/easyportrait/create_teeth_mask.py ./data/ + ``` + +### Audio Pre-process + +In our paper, we use DeepSpeech features for evaluation. + +* DeepSpeech + + ```bash + python data_utils/deepspeech_features/extract_ds_features.py --input data/.wav # saved to data/.npy + ``` + +- HuBERT + + Similar to ER-NeRF, HuBERT is also available. Recommended for situations if the audio is not in English. + + Specify `--audio_extractor hubert` when training and testing. + + ``` + python data_utils/hubert.py --wav data/.wav # save to data/_hu.npy + ``` + +### Train + +```bash +# If resources are sufficient, partially parallel is available to speed up the training. See the script. +bash scripts/train_xx.sh data/ output/ +``` + +### Test + +```bash +# saved to output//test/ours_None/renders +python synthesize_fuse.py -S data/ -M output/ --eval +``` + +### Inference with target audio + +```bash +python synthesize_fuse.py -S data/ -M output/ --use_train --audio .npy +``` + +## Citation + +Consider citing as below if you find this repository helpful to your project: + +``` +@article{li2024talkinggaussian, + title={TalkingGaussian: Structure-Persistent 3D Talking Head Synthesis via Gaussian Splatting}, + author={Jiahe Li and Jiawei Zhang and Xiao Bai and Jin Zheng and Xin Ning and Jun Zhou and Lin Gu}, + journal={arXiv preprint arXiv:2404.15264}, + year={2024} +} +``` + + +## Acknowledgement + +This code is developed on [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting) with [simple-knn](https://gitlab.inria.fr/bkerbl/simple-knn), and a modified [diff-gaussian-rasterization](https://github.com/ashawkey/diff-gaussian-rasterization). Partial codes are from [RAD-NeRF](https://github.com/ashawkey/RAD-NeRF), [DFRF](https://github.com/sstzal/DFRF), [GeneFace](https://github.com/yerfor/GeneFace), and [AD-NeRF](https://github.com/YudongGuo/AD-NeRF). Teeth mask is from [EasyPortrait](https://github.com/hukenovs/easyportrait). Thanks for these great projects! diff --git a/arguments/__init__.py b/arguments/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f44f8d6d2123def3341eddd2fadc7fa7dc9a73bd --- /dev/null +++ b/arguments/__init__.py @@ -0,0 +1,118 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from argparse import ArgumentParser, Namespace +import sys +import os + +class GroupParams: + pass + +class ParamGroup: + def __init__(self, parser: ArgumentParser, name : str, fill_none = False): + group = parser.add_argument_group(name) + for key, value in vars(self).items(): + shorthand = False + if key.startswith("_"): + shorthand = True + key = key[1:] + t = type(value) + value = value if not fill_none else None + if shorthand: + if t == bool: + group.add_argument("--" + key, ("-" + key[0:1]), ("-" + key[0:1].upper()), default=value, action="store_true") + else: + group.add_argument("--" + key, ("-" + key[0:1]), ("-" + key[0:1].upper()), default=value, type=t) + else: + if t == bool: + group.add_argument("--" + key, default=value, action="store_true") + else: + group.add_argument("--" + key, default=value, type=t) + + def extract(self, args): + group = GroupParams() + for arg in vars(args).items(): + if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): + setattr(group, arg[0], arg[1]) + return group + +class ModelParams(ParamGroup): + def __init__(self, parser, sentinel=False): + self.sh_degree = 2 + self._source_path = "" + self._model_path = "" + self._images = "images" + self._resolution = -1 + self._white_background = False + self.data_device = "cpu" + self.eval = False + self.audio = "" + self.init_num = 10_000 + self.audio_extractor = "deepspeech" + super().__init__(parser, "Loading Parameters", sentinel) + + def extract(self, args): + g = super().extract(args) + g.source_path = os.path.abspath(g.source_path) + + return g + +class PipelineParams(ParamGroup): + def __init__(self, parser): + self.convert_SHs_python = False + self.compute_cov3D_python = False + self.debug = False + super().__init__(parser, "Pipeline Parameters") + +class OptimizationParams(ParamGroup): + def __init__(self, parser): + self.iterations = 50_000 + self.position_lr_init = 0.00016 + self.position_lr_final = 0.0000016 + self.position_lr_delay_mult = 0.01 + self.position_lr_max_steps = 45_000 + self.feature_lr = 0.0025 + self.opacity_lr = 0.05 + self.scaling_lr = 0.003 + self.rotation_lr = 0.001 + self.percent_dense = 0.005 + self.lambda_dssim = 0.2 + self.densification_interval = 100 + self.opacity_reset_interval = 3000 + self.densify_from_iter = 500 + + + self.densify_until_iter = 45_000 + self.densify_grad_threshold = 0.0002 + self.random_background = False + super().__init__(parser, "Optimization Parameters") + +def get_combined_args(parser : ArgumentParser): + cmdlne_string = sys.argv[1:] + cfgfile_string = "Namespace()" + args_cmdline = parser.parse_args(cmdlne_string) + + try: + cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") + print("Looking for config file in", cfgfilepath) + with open(cfgfilepath) as cfg_file: + print("Config file found: {}".format(cfgfilepath)) + cfgfile_string = cfg_file.read() + except TypeError: + print("Config file not found at") + pass + args_cfgfile = eval(cfgfile_string) + + merged_dict = vars(args_cfgfile).copy() + for k,v in vars(args_cmdline).items(): + if v != None: + merged_dict[k] = v + return Namespace(**merged_dict) diff --git a/assets/main.png b/assets/main.png new file mode 100644 index 0000000000000000000000000000000000000000..e24b412fe32a4944684c45b89eceefb8c622494f Binary files /dev/null and b/assets/main.png differ diff --git a/data/.gitkeep b/data/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/data_utils/deepspeech_features/README.md b/data_utils/deepspeech_features/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c9f6c6be242e85f36f9bfb8b673b610506032277 --- /dev/null +++ b/data_utils/deepspeech_features/README.md @@ -0,0 +1,20 @@ +# Routines for DeepSpeech features processing +Several routines for [DeepSpeech](https://github.com/mozilla/DeepSpeech) features processing, like speech features generation for [VOCA](https://github.com/TimoBolkart/voca) model. + +## Installation + +``` +pip3 install -r requirements.txt +``` + +## Usage + +Generate wav files: +``` +python3 extract_wav.py --in-video= +``` + +Generate files with DeepSpeech features: +``` +python3 extract_ds_features.py --input= +``` diff --git a/data_utils/deepspeech_features/deepspeech_features.py b/data_utils/deepspeech_features/deepspeech_features.py new file mode 100644 index 0000000000000000000000000000000000000000..787a6ea806369e0547055816b6cd4c52e0e20487 --- /dev/null +++ b/data_utils/deepspeech_features/deepspeech_features.py @@ -0,0 +1,274 @@ +""" + DeepSpeech features processing routines. + NB: Based on VOCA code. See the corresponding license restrictions. +""" + +__all__ = ['conv_audios_to_deepspeech'] + +import numpy as np +import warnings +import resampy +from scipy.io import wavfile +from python_speech_features import mfcc +import tensorflow.compat.v1 as tf +tf.disable_v2_behavior() + +def conv_audios_to_deepspeech(audios, + out_files, + num_frames_info, + deepspeech_pb_path, + audio_window_size=1, + audio_window_stride=1): + """ + Convert list of audio files into files with DeepSpeech features. + + Parameters + ---------- + audios : list of str or list of None + Paths to input audio files. + out_files : list of str + Paths to output files with DeepSpeech features. + num_frames_info : list of int + List of numbers of frames. + deepspeech_pb_path : str + Path to DeepSpeech 0.1.0 frozen model. + audio_window_size : int, default 16 + Audio window size. + audio_window_stride : int, default 1 + Audio window stride. + """ + graph, logits_ph, input_node_ph, input_lengths_ph = prepare_deepspeech_net( + deepspeech_pb_path) + + with tf.compat.v1.Session(graph=graph) as sess: + for audio_file_path, out_file_path, num_frames in zip(audios, out_files, num_frames_info): + print(audio_file_path) + print(out_file_path) + audio_sample_rate, audio = wavfile.read(audio_file_path) + if audio.ndim != 1: + warnings.warn( + "Audio has multiple channels, the first channel is used") + audio = audio[:, 0] + ds_features = pure_conv_audio_to_deepspeech( + audio=audio, + audio_sample_rate=audio_sample_rate, + audio_window_size=audio_window_size, + audio_window_stride=audio_window_stride, + num_frames=num_frames, + net_fn=lambda x: sess.run( + logits_ph, + feed_dict={ + input_node_ph: x[np.newaxis, ...], + input_lengths_ph: [x.shape[0]]})) + + net_output = ds_features.reshape(-1, 29) + win_size = 16 + zero_pad = np.zeros((int(win_size / 2), net_output.shape[1])) + net_output = np.concatenate( + (zero_pad, net_output, zero_pad), axis=0) + windows = [] + for window_index in range(0, net_output.shape[0] - win_size, 2): + windows.append( + net_output[window_index:window_index + win_size]) + print(np.array(windows).shape) + np.save(out_file_path, np.array(windows)) + + +def prepare_deepspeech_net(deepspeech_pb_path): + """ + Load and prepare DeepSpeech network. + + Parameters + ---------- + deepspeech_pb_path : str + Path to DeepSpeech 0.1.0 frozen model. + + Returns + ------- + graph : obj + ThensorFlow graph. + logits_ph : obj + ThensorFlow placeholder for `logits`. + input_node_ph : obj + ThensorFlow placeholder for `input_node`. + input_lengths_ph : obj + ThensorFlow placeholder for `input_lengths`. + """ + # Load graph and place_holders: + with tf.io.gfile.GFile(deepspeech_pb_path, "rb") as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + + graph = tf.compat.v1.get_default_graph() + tf.import_graph_def(graph_def, name="deepspeech") + logits_ph = graph.get_tensor_by_name("deepspeech/logits:0") + input_node_ph = graph.get_tensor_by_name("deepspeech/input_node:0") + input_lengths_ph = graph.get_tensor_by_name("deepspeech/input_lengths:0") + + return graph, logits_ph, input_node_ph, input_lengths_ph + + +def pure_conv_audio_to_deepspeech(audio, + audio_sample_rate, + audio_window_size, + audio_window_stride, + num_frames, + net_fn): + """ + Core routine for converting audion into DeepSpeech features. + + Parameters + ---------- + audio : np.array + Audio data. + audio_sample_rate : int + Audio sample rate. + audio_window_size : int + Audio window size. + audio_window_stride : int + Audio window stride. + num_frames : int or None + Numbers of frames. + net_fn : func + Function for DeepSpeech model call. + + Returns + ------- + np.array + DeepSpeech features. + """ + target_sample_rate = 16000 + if audio_sample_rate != target_sample_rate: + resampled_audio = resampy.resample( + x=audio.astype(np.float), + sr_orig=audio_sample_rate, + sr_new=target_sample_rate) + else: + resampled_audio = audio.astype(np.float32) + input_vector = conv_audio_to_deepspeech_input_vector( + audio=resampled_audio.astype(np.int16), + sample_rate=target_sample_rate, + num_cepstrum=26, + num_context=9) + + network_output = net_fn(input_vector) + # print(network_output.shape) + + deepspeech_fps = 50 + video_fps = 50 # Change this option if video fps is different + audio_len_s = float(audio.shape[0]) / audio_sample_rate + if num_frames is None: + num_frames = int(round(audio_len_s * video_fps)) + else: + video_fps = num_frames / audio_len_s + network_output = interpolate_features( + features=network_output[:, 0], + input_rate=deepspeech_fps, + output_rate=video_fps, + output_len=num_frames) + + # Make windows: + zero_pad = np.zeros((int(audio_window_size / 2), network_output.shape[1])) + network_output = np.concatenate( + (zero_pad, network_output, zero_pad), axis=0) + windows = [] + for window_index in range(0, network_output.shape[0] - audio_window_size, audio_window_stride): + windows.append( + network_output[window_index:window_index + audio_window_size]) + + return np.array(windows) + + +def conv_audio_to_deepspeech_input_vector(audio, + sample_rate, + num_cepstrum, + num_context): + """ + Convert audio raw data into DeepSpeech input vector. + + Parameters + ---------- + audio : np.array + Audio data. + audio_sample_rate : int + Audio sample rate. + num_cepstrum : int + Number of cepstrum. + num_context : int + Number of context. + + Returns + ------- + np.array + DeepSpeech input vector. + """ + # Get mfcc coefficients: + features = mfcc( + signal=audio, + samplerate=sample_rate, + numcep=num_cepstrum) + + # We only keep every second feature (BiRNN stride = 2): + features = features[::2] + + # One stride per time step in the input: + num_strides = len(features) + + # Add empty initial and final contexts: + empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype) + features = np.concatenate((empty_context, features, empty_context)) + + # Create a view into the array with overlapping strides of size + # numcontext (past) + 1 (present) + numcontext (future): + window_size = 2 * num_context + 1 + train_inputs = np.lib.stride_tricks.as_strided( + features, + shape=(num_strides, window_size, num_cepstrum), + strides=(features.strides[0], + features.strides[0], features.strides[1]), + writeable=False) + + # Flatten the second and third dimensions: + train_inputs = np.reshape(train_inputs, [num_strides, -1]) + + train_inputs = np.copy(train_inputs) + train_inputs = (train_inputs - np.mean(train_inputs)) / \ + np.std(train_inputs) + + return train_inputs + + +def interpolate_features(features, + input_rate, + output_rate, + output_len): + """ + Interpolate DeepSpeech features. + + Parameters + ---------- + features : np.array + DeepSpeech features. + input_rate : int + input rate (FPS). + output_rate : int + Output rate (FPS). + output_len : int + Output data length. + + Returns + ------- + np.array + Interpolated data. + """ + input_len = features.shape[0] + num_features = features.shape[1] + input_timestamps = np.arange(input_len) / float(input_rate) + output_timestamps = np.arange(output_len) / float(output_rate) + output_features = np.zeros((output_len, num_features)) + for feature_idx in range(num_features): + output_features[:, feature_idx] = np.interp( + x=output_timestamps, + xp=input_timestamps, + fp=features[:, feature_idx]) + return output_features diff --git a/data_utils/deepspeech_features/deepspeech_store.py b/data_utils/deepspeech_features/deepspeech_store.py new file mode 100644 index 0000000000000000000000000000000000000000..4c2f603e88fe644eb1025e3088de3b09341ad9fd --- /dev/null +++ b/data_utils/deepspeech_features/deepspeech_store.py @@ -0,0 +1,172 @@ +""" + Routines for loading DeepSpeech model. +""" + +__all__ = ['get_deepspeech_model_file'] + +import os +import zipfile +import logging +import hashlib + + +deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features' + + +def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")): + """ + Return location for the pretrained on local file system. This function will download from online model zoo when + model cannot be found or has mismatch. The root directory will be created if it doesn't exist. + + Parameters + ---------- + local_model_store_dir_path : str, default $TENSORFLOW_HOME/models + Location for keeping the model parameters. + + Returns + ------- + file_path + Path to the requested pretrained model file. + """ + sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e" + file_name = "deepspeech-0_1_0-b90017e8.pb" + local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path) + file_path = os.path.join(local_model_store_dir_path, file_name) + if os.path.exists(file_path): + if _check_sha1(file_path, sha1_hash): + return file_path + else: + logging.warning("Mismatch in the content of model file detected. Downloading again.") + else: + logging.info("Model file not found. Downloading to {}.".format(file_path)) + + if not os.path.exists(local_model_store_dir_path): + os.makedirs(local_model_store_dir_path) + + zip_file_path = file_path + ".zip" + _download( + url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format( + repo_url=deepspeech_features_repo_url, + repo_release_tag="v0.0.1", + file_name=file_name), + path=zip_file_path, + overwrite=True) + with zipfile.ZipFile(zip_file_path) as zf: + zf.extractall(local_model_store_dir_path) + os.remove(zip_file_path) + + if _check_sha1(file_path, sha1_hash): + return file_path + else: + raise ValueError("Downloaded file has different hash. Please try again.") + + +def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True): + """ + Download an given URL + + Parameters + ---------- + url : str + URL to download + path : str, optional + Destination path to store downloaded file. By default stores to the + current directory with same name as in url. + overwrite : bool, optional + Whether to overwrite destination file if already exists. + sha1_hash : str, optional + Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified + but doesn't match. + retries : integer, default 5 + The number of times to attempt the download in case of failure or non 200 return codes + verify_ssl : bool, default True + Verify SSL certificates. + + Returns + ------- + str + The file path of the downloaded file. + """ + import warnings + try: + import requests + except ImportError: + class requests_failed_to_import(object): + pass + requests = requests_failed_to_import + + if path is None: + fname = url.split("/")[-1] + # Empty filenames are invalid + assert fname, "Can't construct file-name from this URL. Please set the `path` option manually." + else: + path = os.path.expanduser(path) + if os.path.isdir(path): + fname = os.path.join(path, url.split("/")[-1]) + else: + fname = path + assert retries >= 0, "Number of retries should be at least 0" + + if not verify_ssl: + warnings.warn( + "Unverified HTTPS request is being made (verify_ssl=False). " + "Adding certificate verification is strongly advised.") + + if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)): + dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) + if not os.path.exists(dirname): + os.makedirs(dirname) + while retries + 1 > 0: + # Disable pyling too broad Exception + # pylint: disable=W0703 + try: + print("Downloading {} from {}...".format(fname, url)) + r = requests.get(url, stream=True, verify=verify_ssl) + if r.status_code != 200: + raise RuntimeError("Failed downloading url {}".format(url)) + with open(fname, "wb") as f: + for chunk in r.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if sha1_hash and not _check_sha1(fname, sha1_hash): + raise UserWarning("File {} is downloaded but the content hash does not match." + " The repo may be outdated or download may be incomplete. " + "If the `repo_url` is overridden, consider switching to " + "the default repo.".format(fname)) + break + except Exception as e: + retries -= 1 + if retries <= 0: + raise e + else: + print("download failed, retrying, {} attempt{} left" + .format(retries, "s" if retries > 1 else "")) + + return fname + + +def _check_sha1(filename, sha1_hash): + """ + Check whether the sha1 hash of the file content matches the expected hash. + + Parameters + ---------- + filename : str + Path to the file. + sha1_hash : str + Expected sha1 hash in hexadecimal digits. + + Returns + ------- + bool + Whether the file content matches the expected hash. + """ + sha1 = hashlib.sha1() + with open(filename, "rb") as f: + while True: + data = f.read(1048576) + if not data: + break + sha1.update(data) + + return sha1.hexdigest() == sha1_hash diff --git a/data_utils/deepspeech_features/extract_ds_features.py b/data_utils/deepspeech_features/extract_ds_features.py new file mode 100644 index 0000000000000000000000000000000000000000..4063017f9d899a152d0aaf55346f71787e2d72e4 --- /dev/null +++ b/data_utils/deepspeech_features/extract_ds_features.py @@ -0,0 +1,130 @@ +""" + Script for extracting DeepSpeech features from audio file. +""" + +import os +import argparse +import numpy as np +import pandas as pd +from deepspeech_store import get_deepspeech_model_file +from deepspeech_features import conv_audios_to_deepspeech + + +def parse_args(): + """ + Create python script parameters. + Returns + ------- + ArgumentParser + Resulted args. + """ + parser = argparse.ArgumentParser( + description="Extract DeepSpeech features from audio file", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--input", + type=str, + required=True, + help="path to input audio file or directory") + parser.add_argument( + "--output", + type=str, + help="path to output file with DeepSpeech features") + parser.add_argument( + "--deepspeech", + type=str, + help="path to DeepSpeech 0.1.0 frozen model") + parser.add_argument( + "--metainfo", + type=str, + help="path to file with meta-information") + + args = parser.parse_args() + return args + + +def extract_features(in_audios, + out_files, + deepspeech_pb_path, + metainfo_file_path=None): + """ + Real extract audio from video file. + Parameters + ---------- + in_audios : list of str + Paths to input audio files. + out_files : list of str + Paths to output files with DeepSpeech features. + deepspeech_pb_path : str + Path to DeepSpeech 0.1.0 frozen model. + metainfo_file_path : str, default None + Path to file with meta-information. + """ + if metainfo_file_path is None: + num_frames_info = [None] * len(in_audios) + else: + train_df = pd.read_csv( + metainfo_file_path, + sep="\t", + index_col=False, + dtype={"Id": np.int, "File": np.unicode, "Count": np.int}) + num_frames_info = train_df["Count"].values + assert (len(num_frames_info) == len(in_audios)) + + for i, in_audio in enumerate(in_audios): + if not out_files[i]: + file_stem, _ = os.path.splitext(in_audio) + out_files[i] = file_stem + ".npy" + #print(out_files[i]) + conv_audios_to_deepspeech( + audios=in_audios, + out_files=out_files, + num_frames_info=num_frames_info, + deepspeech_pb_path=deepspeech_pb_path) + + +def main(): + """ + Main body of script. + """ + args = parse_args() + in_audio = os.path.expanduser(args.input) + if not os.path.exists(in_audio): + raise Exception("Input file/directory doesn't exist: {}".format(in_audio)) + deepspeech_pb_path = args.deepspeech + #add + deepspeech_pb_path = True + args.deepspeech = '~/.tensorflow/models/deepspeech-0_1_0-b90017e8.pb' + if deepspeech_pb_path is None: + deepspeech_pb_path = "" + if deepspeech_pb_path: + deepspeech_pb_path = os.path.expanduser(args.deepspeech) + if not os.path.exists(deepspeech_pb_path): + deepspeech_pb_path = get_deepspeech_model_file() + if os.path.isfile(in_audio): + extract_features( + in_audios=[in_audio], + out_files=[args.output], + deepspeech_pb_path=deepspeech_pb_path, + metainfo_file_path=args.metainfo) + else: + audio_file_paths = [] + for file_name in os.listdir(in_audio): + if not os.path.isfile(os.path.join(in_audio, file_name)): + continue + _, file_ext = os.path.splitext(file_name) + if file_ext.lower() == ".wav": + audio_file_path = os.path.join(in_audio, file_name) + audio_file_paths.append(audio_file_path) + audio_file_paths = sorted(audio_file_paths) + out_file_paths = [""] * len(audio_file_paths) + extract_features( + in_audios=audio_file_paths, + out_files=out_file_paths, + deepspeech_pb_path=deepspeech_pb_path, + metainfo_file_path=args.metainfo) + + +if __name__ == "__main__": + main() + diff --git a/data_utils/deepspeech_features/extract_wav.py b/data_utils/deepspeech_features/extract_wav.py new file mode 100644 index 0000000000000000000000000000000000000000..5f39e8b0e231762518c20cd88dee532624059584 --- /dev/null +++ b/data_utils/deepspeech_features/extract_wav.py @@ -0,0 +1,87 @@ +""" + Script for extracting audio (16-bit, mono, 22000 Hz) from video file. +""" + +import os +import argparse +import subprocess + + +def parse_args(): + """ + Create python script parameters. + + Returns + ------- + ArgumentParser + Resulted args. + """ + parser = argparse.ArgumentParser( + description="Extract audio from video file", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--in-video", + type=str, + required=True, + help="path to input video file or directory") + parser.add_argument( + "--out-audio", + type=str, + help="path to output audio file") + + args = parser.parse_args() + return args + + +def extract_audio(in_video, + out_audio): + """ + Real extract audio from video file. + + Parameters + ---------- + in_video : str + Path to input video file. + out_audio : str + Path to output audio file. + """ + if not out_audio: + file_stem, _ = os.path.splitext(in_video) + out_audio = file_stem + ".wav" + # command1 = "ffmpeg -i {in_video} -vn -acodec copy {aac_audio}" + # command2 = "ffmpeg -i {aac_audio} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}" + # command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}" + command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 16000 {out_audio}" + subprocess.call([command.format(in_video=in_video, out_audio=out_audio)], shell=True) + + +def main(): + """ + Main body of script. + """ + args = parse_args() + in_video = os.path.expanduser(args.in_video) + if not os.path.exists(in_video): + raise Exception("Input file/directory doesn't exist: {}".format(in_video)) + if os.path.isfile(in_video): + extract_audio( + in_video=in_video, + out_audio=args.out_audio) + else: + video_file_paths = [] + for file_name in os.listdir(in_video): + if not os.path.isfile(os.path.join(in_video, file_name)): + continue + _, file_ext = os.path.splitext(file_name) + if file_ext.lower() in (".mp4", ".mkv", ".avi"): + video_file_path = os.path.join(in_video, file_name) + video_file_paths.append(video_file_path) + video_file_paths = sorted(video_file_paths) + for video_file_path in video_file_paths: + extract_audio( + in_video=video_file_path, + out_audio="") + + +if __name__ == "__main__": + main() diff --git a/data_utils/deepspeech_features/fea_win.py b/data_utils/deepspeech_features/fea_win.py new file mode 100644 index 0000000000000000000000000000000000000000..4f9c666309c8eb2f029d1ac99095d0c15a270549 --- /dev/null +++ b/data_utils/deepspeech_features/fea_win.py @@ -0,0 +1,11 @@ +import numpy as np + +net_output = np.load('french.ds.npy').reshape(-1, 29) +win_size = 16 +zero_pad = np.zeros((int(win_size / 2), net_output.shape[1])) +net_output = np.concatenate((zero_pad, net_output, zero_pad), axis=0) +windows = [] +for window_index in range(0, net_output.shape[0] - win_size, 2): + windows.append(net_output[window_index:window_index + win_size]) +print(np.array(windows).shape) +np.save('aud_french.npy', np.array(windows)) diff --git a/data_utils/easyportrait/create_teeth_mask.py b/data_utils/easyportrait/create_teeth_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..4707bb4ea042164f4564d5cd3bb01f4e3e7b30f3 --- /dev/null +++ b/data_utils/easyportrait/create_teeth_mask.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from argparse import ArgumentParser + +from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot + +import os +import glob +from tqdm import tqdm +import numpy as np + +def main(): + parser = ArgumentParser() + parser.add_argument('datset', help='Image file') + parser.add_argument('--config', default="./data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fpn-fp/fpn-fp.py", help='Config file') + parser.add_argument('--checkpoint', default="./data_utils/easyportrait/fpn-fp-512.pth", help='Checkpoint file') + + args = parser.parse_args() + + # build the model from a config file and a checkpoint file + model = init_segmentor(args.config, args.checkpoint, device='cuda:0') + + # test a single image + dataset_path = os.path.join(args.datset, 'ori_imgs') + out_path = os.path.join(args.datset, 'teeth_mask') + os.makedirs(out_path, exist_ok=True) + + for file in tqdm(glob.glob(os.path.join(dataset_path, '*.jpg'))): + result = inference_segmentor(model, file) + result[0][result[0]!=7] = 0 + np.save(file.replace('jpg', 'npy').replace('ori_imgs', 'teeth_mask'), result[0].astype(np.bool_)) + + +if __name__ == '__main__': + main() diff --git a/data_utils/easyportrait/local_configs/__base__/datasets/easyportrait_1024x1024.py b/data_utils/easyportrait/local_configs/__base__/datasets/easyportrait_1024x1024.py new file mode 100644 index 0000000000000000000000000000000000000000..eca3c193390ade09a3127602f1bce2507bd90eec --- /dev/null +++ b/data_utils/easyportrait/local_configs/__base__/datasets/easyportrait_1024x1024.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'EasyPortraitDataset' +data_root = 'path/to/data/EasyPortrait' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Pad', size=(1920, 1920), pad_val=0, seg_pad_val=255), + dict(type='Resize', img_scale=(1024, 1024)), + + # We don't use RandomFlip, but need it in the code to fix error: https://github.com/open-mmlab/mmsegmentation/issues/231 + dict(type='RandomFlip', prob=0.0), + dict(type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=9), + dict(type='Normalize', **img_norm_cfg), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1024, 1024), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/train', + ann_dir='annotations/train', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/val', + ann_dir='annotations/val', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/test', + ann_dir='annotations/test', + pipeline=test_pipeline)) \ No newline at end of file diff --git a/data_utils/easyportrait/local_configs/__base__/datasets/easyportrait_384x384.py b/data_utils/easyportrait/local_configs/__base__/datasets/easyportrait_384x384.py new file mode 100644 index 0000000000000000000000000000000000000000..f1aef5a64b61ff0def1df8995b7ae2ba035651a4 --- /dev/null +++ b/data_utils/easyportrait/local_configs/__base__/datasets/easyportrait_384x384.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'EasyPortraitDataset' +data_root = 'path/to/data/EasyPortrait' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Pad', size=(1920, 1920), pad_val=0, seg_pad_val=255), + dict(type='Resize', img_scale=(384, 384)), + + # We don't use RandomFlip, but need it in the code to fix error: https://github.com/open-mmlab/mmsegmentation/issues/231 + dict(type='RandomFlip', prob=0.0), + dict(type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=9), + dict(type='Normalize', **img_norm_cfg), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/train', + ann_dir='annotations/train', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/val', + ann_dir='annotations/val', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/test', + ann_dir='annotations/test', + pipeline=test_pipeline)) \ No newline at end of file diff --git a/data_utils/easyportrait/local_configs/__base__/datasets/easyportrait_512x512.py b/data_utils/easyportrait/local_configs/__base__/datasets/easyportrait_512x512.py new file mode 100644 index 0000000000000000000000000000000000000000..09cca6d85c68eb4e354955a49392b8282245f808 --- /dev/null +++ b/data_utils/easyportrait/local_configs/__base__/datasets/easyportrait_512x512.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'EasyPortraitDataset' +data_root = 'path/to/data/EasyPortrait' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Pad', size=(1920, 1920), pad_val=0, seg_pad_val=255), + dict(type='Resize', img_scale=(512, 512)), + + # We don't use RandomFlip, but need it in the code to fix error: https://github.com/open-mmlab/mmsegmentation/issues/231 + dict(type='RandomFlip', prob=0.0), + dict(type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=9), + dict(type='Normalize', **img_norm_cfg), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(512, 512), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/train', + ann_dir='annotations/train', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/val', + ann_dir='annotations/val', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/test', + ann_dir='annotations/test', + pipeline=test_pipeline)) \ No newline at end of file diff --git a/data_utils/easyportrait/local_configs/__base__/default_runtime.py b/data_utils/easyportrait/local_configs/__base__/default_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..5f8d0f50d06280f5729eda62fcde14258b16edae --- /dev/null +++ b/data_utils/easyportrait/local_configs/__base__/default_runtime.py @@ -0,0 +1,14 @@ +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook', by_epoch=False), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True \ No newline at end of file diff --git a/data_utils/easyportrait/local_configs/__base__/models/bisenetv2.py b/data_utils/easyportrait/local_configs/__base__/models/bisenetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..ea96fbf1ef7bad0b9eae8cfc3af98c1560b45b39 --- /dev/null +++ b/data_utils/easyportrait/local_configs/__base__/models/bisenetv2.py @@ -0,0 +1,80 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='BiSeNetV2', + detail_channels=(64, 64, 128), + semantic_channels=(16, 32, 64, 128), + semantic_expansion_ratio=6, + bga_channels=128, + out_indices=(0, 1, 2, 3, 4), + init_cfg=None, + align_corners=False), + decode_head=dict( + type='FCNHead', + in_channels=128, + in_index=0, + channels=1024, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=[ + dict( + type='FCNHead', + in_channels=16, + channels=16, + num_convs=2, + num_classes=19, + in_index=1, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='FCNHead', + in_channels=32, + channels=64, + num_convs=2, + num_classes=19, + in_index=2, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='FCNHead', + in_channels=64, + channels=256, + num_convs=2, + num_classes=19, + in_index=3, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='FCNHead', + in_channels=128, + channels=1024, + num_convs=2, + num_classes=19, + in_index=4, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + ], + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) \ No newline at end of file diff --git a/data_utils/easyportrait/local_configs/__base__/models/fcn_resnet50.py b/data_utils/easyportrait/local_configs/__base__/models/fcn_resnet50.py new file mode 100644 index 0000000000000000000000000000000000000000..9b3a8de65c45be9fa2708976cd6c971c87202577 --- /dev/null +++ b/data_utils/easyportrait/local_configs/__base__/models/fcn_resnet50.py @@ -0,0 +1,45 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='FCNHead', + in_channels=2048, + in_index=3, + channels=512, + num_convs=2, + concat_input=True, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) \ No newline at end of file diff --git a/data_utils/easyportrait/local_configs/__base__/models/fpn_resnet50.py b/data_utils/easyportrait/local_configs/__base__/models/fpn_resnet50.py new file mode 100644 index 0000000000000000000000000000000000000000..30f7c1c5ecde0b3152e177625262c7afa26def4b --- /dev/null +++ b/data_utils/easyportrait/local_configs/__base__/models/fpn_resnet50.py @@ -0,0 +1,36 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 1, 1), + strides=(1, 2, 2, 2), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=4), + decode_head=dict( + type='FPNHead', + in_channels=[256, 256, 256, 256], + in_index=[0, 1, 2, 3], + feature_strides=[4, 8, 16, 32], + channels=128, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) \ No newline at end of file diff --git a/data_utils/easyportrait/local_configs/__base__/models/lraspp.py b/data_utils/easyportrait/local_configs/__base__/models/lraspp.py new file mode 100644 index 0000000000000000000000000000000000000000..350362759dfc6b0ed73854d3cf8d49209989de1d --- /dev/null +++ b/data_utils/easyportrait/local_configs/__base__/models/lraspp.py @@ -0,0 +1,25 @@ +# model settings +norm_cfg = dict(type='SyncBN', eps=0.001, requires_grad=True) +model = dict( + type='EncoderDecoder', + backbone=dict( + type='MobileNetV3', + arch='large', + out_indices=(1, 3, 16), + norm_cfg=norm_cfg), + decode_head=dict( + type='LRASPPHead', + in_channels=(16, 24, 960), + in_index=(0, 1, 2), + channels=128, + input_transform='multiple_select', + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU'), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) \ No newline at end of file diff --git a/data_utils/easyportrait/local_configs/__base__/models/segformer.py b/data_utils/easyportrait/local_configs/__base__/models/segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..838b7c89054ba5ce416027b7a692b643b803948d --- /dev/null +++ b/data_utils/easyportrait/local_configs/__base__/models/segformer.py @@ -0,0 +1,34 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='MixVisionTransformer', + in_channels=3, + embed_dims=32, + num_stages=4, + num_layers=[2, 2, 2, 2], + num_heads=[1, 2, 5, 8], + patch_sizes=[7, 3, 3, 3], + sr_ratios=[8, 4, 2, 1], + out_indices=(0, 1, 2, 3), + mlp_ratio=4, + qkv_bias=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1), + decode_head=dict( + type='SegformerHead', + in_channels=[32, 64, 160, 256], + in_index=[0, 1, 2, 3], + channels=256, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) \ No newline at end of file diff --git a/data_utils/easyportrait/local_configs/__base__/schedules/schedule_10k_adamw.py b/data_utils/easyportrait/local_configs/__base__/schedules/schedule_10k_adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..73ba72104950b1c6dbe070433efef6782a2d881a --- /dev/null +++ b/data_utils/easyportrait/local_configs/__base__/schedules/schedule_10k_adamw.py @@ -0,0 +1,11 @@ +# optimizer +optimizer = dict(type='AdamW', lr=0.0002, weight_decay=0.0001) +optimizer_config = dict() + +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=False) + +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=10000) +checkpoint_config = dict(by_epoch=False, interval=2000) +evaluation = dict(interval=2000, metric='mIoU') \ No newline at end of file diff --git a/data_utils/easyportrait/local_configs/__base__/schedules/schedule_160k_adamw.py b/data_utils/easyportrait/local_configs/__base__/schedules/schedule_160k_adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..729269a7f4caaa46aa14e1d1abee3c5045160e73 --- /dev/null +++ b/data_utils/easyportrait/local_configs/__base__/schedules/schedule_160k_adamw.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='AdamW', lr=0.0002, weight_decay=0.0001) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=160000) +checkpoint_config = dict(by_epoch=False, interval=4000) +evaluation = dict(interval=4000, metric='mIoU') \ No newline at end of file diff --git a/data_utils/easyportrait/local_configs/__base__/schedules/schedule_20k_adamw.py b/data_utils/easyportrait/local_configs/__base__/schedules/schedule_20k_adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..479bf486beb47fdd9265e4c8eb3121e00badd316 --- /dev/null +++ b/data_utils/easyportrait/local_configs/__base__/schedules/schedule_20k_adamw.py @@ -0,0 +1,11 @@ +# optimizer +optimizer = dict(type='AdamW', lr=0.0002, weight_decay=0.0001) +optimizer_config = dict() + +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=False) + +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=20000) +checkpoint_config = dict(by_epoch=False, interval=2000) +evaluation = dict(interval=2000, metric='mIoU') \ No newline at end of file diff --git a/data_utils/easyportrait/local_configs/__base__/schedules/schedule_40k_adamw.py b/data_utils/easyportrait/local_configs/__base__/schedules/schedule_40k_adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..f577039bb3ed728c526dad58fac1ce8a57fdb363 --- /dev/null +++ b/data_utils/easyportrait/local_configs/__base__/schedules/schedule_40k_adamw.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='AdamW', lr=0.0002, weight_decay=0.0001) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=40000) +checkpoint_config = dict(by_epoch=False, interval=4000) +evaluation = dict(interval=4000, metric='mIoU') \ No newline at end of file diff --git a/data_utils/easyportrait/local_configs/__base__/schedules/schedule_80k_adamw.py b/data_utils/easyportrait/local_configs/__base__/schedules/schedule_80k_adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..e8e5f40e0f0e0e8ee67dc19b664b6fbb81be39cf --- /dev/null +++ b/data_utils/easyportrait/local_configs/__base__/schedules/schedule_80k_adamw.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='AdamW', lr=0.0002, weight_decay=0.0001) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=80000) +checkpoint_config = dict(by_epoch=False, interval=4000) +evaluation = dict(interval=4000, metric='mIoU') \ No newline at end of file diff --git a/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/bisenet-fp/bisenetv2-fp.py b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/bisenet-fp/bisenetv2-fp.py new file mode 100644 index 0000000000000000000000000000000000000000..3ab497640a1b8b70e2bcd40671583473bab55879 --- /dev/null +++ b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/bisenet-fp/bisenetv2-fp.py @@ -0,0 +1,221 @@ +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='BiSeNetV2', + detail_channels=(64, 64, 128), + semantic_channels=(16, 32, 64, 128), + semantic_expansion_ratio=6, + bga_channels=128, + out_indices=(0, 1, 2, 3, 4), + init_cfg=None, + align_corners=False), + decode_head=dict( + type='FCNHead', + in_channels=128, + in_index=0, + channels=1024, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000)), + auxiliary_head=[ + dict( + type='FCNHead', + in_channels=16, + channels=16, + num_convs=2, + num_classes=8, + in_index=1, + norm_cfg=dict(type='SyncBN', requires_grad=True), + concat_input=False, + align_corners=False, + sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000), + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='FCNHead', + in_channels=32, + channels=64, + num_convs=2, + num_classes=8, + in_index=2, + norm_cfg=dict(type='SyncBN', requires_grad=True), + concat_input=False, + align_corners=False, + sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000), + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='FCNHead', + in_channels=64, + channels=256, + num_convs=2, + num_classes=8, + in_index=3, + norm_cfg=dict(type='SyncBN', requires_grad=True), + concat_input=False, + align_corners=False, + sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000), + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='FCNHead', + in_channels=128, + channels=1024, + num_convs=2, + num_classes=8, + in_index=4, + norm_cfg=dict(type='SyncBN', requires_grad=True), + concat_input=False, + align_corners=False, + sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000), + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) + ], + train_cfg=dict(), + test_cfg=dict(mode='whole')) +dataset_type = 'EasyPortraitFPDataset' +data_root = '/home/jovyan/datasets/wacv_24/' +img_norm_cfg = dict( + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + train=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/train', + ann_dir='easyportrait_384/annotations_fp/train', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) + ]), + val=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/val', + ann_dir='easyportrait_384/annotations_fp/val', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + test=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/test', + ann_dir='easyportrait_384/annotations_fp/test', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + samples_per_gpu=32, + workers_per_gpu=8) +log_config = dict( + interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)]) +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True +optimizer = dict(type='AdamW', lr=0.05, weight_decay=0.0001) +optimizer_config = dict() +lr_config = dict( + policy='poly', + power=0.9, + min_lr=0.0, + by_epoch=True, + warmup='linear', + warmup_iters=1000) +default_hooks = dict(stop=dict(type='EarlyStoppingHook', monitor='mIoU')) +runner = dict(type='EpochBasedRunner', max_epochs=100) +checkpoint_config = dict(by_epoch=True, interval=100) +evaluation = dict(interval=1, metric='mIoU', save_best='mIoU') +work_dir = 'work_dirs/petrova/bisenet-fp' +gpu_ids = [0] +auto_resume = False diff --git a/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/bisenet-ps/bisenetv2-ps.py b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/bisenet-ps/bisenetv2-ps.py new file mode 100644 index 0000000000000000000000000000000000000000..683fe976db02d0fc86afedc1325f51a761b8e738 --- /dev/null +++ b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/bisenet-ps/bisenetv2-ps.py @@ -0,0 +1,218 @@ +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='BiSeNetV2', + detail_channels=(64, 64, 128), + semantic_channels=(16, 32, 64, 128), + semantic_expansion_ratio=6, + bga_channels=128, + out_indices=(0, 1, 2, 3, 4), + init_cfg=None, + align_corners=False), + decode_head=dict( + type='FCNHead', + in_channels=128, + in_index=0, + channels=1024, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000)), + auxiliary_head=[ + dict( + type='FCNHead', + in_channels=16, + channels=16, + num_convs=2, + num_classes=2, + in_index=1, + norm_cfg=dict(type='SyncBN', requires_grad=True), + concat_input=False, + align_corners=False, + sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000), + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='FCNHead', + in_channels=32, + channels=64, + num_convs=2, + num_classes=2, + in_index=2, + norm_cfg=dict(type='SyncBN', requires_grad=True), + concat_input=False, + align_corners=False, + sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000), + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='FCNHead', + in_channels=64, + channels=256, + num_convs=2, + num_classes=2, + in_index=3, + norm_cfg=dict(type='SyncBN', requires_grad=True), + concat_input=False, + align_corners=False, + sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000), + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='FCNHead', + in_channels=128, + channels=1024, + num_convs=2, + num_classes=2, + in_index=4, + norm_cfg=dict(type='SyncBN', requires_grad=True), + concat_input=False, + align_corners=False, + sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000), + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) + ], + train_cfg=dict(), + test_cfg=dict(mode='whole')) +dataset_type = 'EasyPortraitPSDataset' +data_root = '/home/jovyan/datasets/wacv_24/' +img_norm_cfg = dict( + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + train=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/train', + ann_dir='easyportrait_384/annotations_ps/train', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) + ]), + val=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/val', + ann_dir='easyportrait_384/annotations_ps/val', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + test=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/test', + ann_dir='easyportrait_384/annotations_ps/test', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + samples_per_gpu=32, + workers_per_gpu=8) +log_config = dict( + interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)]) +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True +optimizer = dict(type='AdamW', lr=0.05, weight_decay=0.0001) +optimizer_config = dict() +lr_config = dict( + policy='poly', + power=0.9, + min_lr=0.0, + by_epoch=True, + warmup='linear', + warmup_iters=1000) +default_hooks = dict(stop=dict(type='EarlyStoppingHook', monitor='mIoU')) +runner = dict(type='EpochBasedRunner', max_epochs=100) +checkpoint_config = dict(by_epoch=True, interval=100) +evaluation = dict(interval=1, metric='mIoU', save_best='mIoU') +work_dir = 'work_dirs/petrova/bisenet-ps/' +gpu_ids = [0] +auto_resume = False diff --git a/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/danet-fp/danet-fp.py b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/danet-fp/danet-fp.py new file mode 100644 index 0000000000000000000000000000000000000000..9af597abee02d8f8da0d4300fe8c4261008f68d7 --- /dev/null +++ b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/danet-fp/danet-fp.py @@ -0,0 +1,174 @@ +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=dict(type='SyncBN', requires_grad=True), + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='DAHead', + in_channels=2048, + in_index=3, + channels=512, + pam_channels=64, + dropout_ratio=0.1, + num_classes=8, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=8, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + train_cfg=dict(), + test_cfg=dict(mode='whole')) +dataset_type = 'EasyPortraitFPDataset' +data_root = '/home/jovyan/datasets/wacv_24/' +img_norm_cfg = dict( + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + train=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/train', + ann_dir='easyportrait_384/annotations_fp/train', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) + ]), + val=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/val', + ann_dir='easyportrait_384/annotations_fp/val', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + test=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/test', + ann_dir='easyportrait_384/annotations_fp/test', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + samples_per_gpu=32, + workers_per_gpu=8) +log_config = dict( + interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)]) +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True +optimizer = dict(type='AdamW', lr=0.0002, weight_decay=0.0001) +optimizer_config = dict() +lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=True) +default_hooks = dict(stop=dict(type='EarlyStoppingHook', monitor='mIoU')) +runner = dict(type='EpochBasedRunner', max_epochs=100) +checkpoint_config = dict(by_epoch=True, interval=100) +evaluation = dict(interval=1, metric='mIoU', save_best='mIoU') +work_dir = 'work_dirs/petrova/danet-fp' +gpu_ids = [0] +auto_resume = False diff --git a/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/danet-ps/danet-ps.py b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/danet-ps/danet-ps.py new file mode 100644 index 0000000000000000000000000000000000000000..7485d440e310c6e850a1b2ce656a85ff8cbd4db8 --- /dev/null +++ b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/danet-ps/danet-ps.py @@ -0,0 +1,171 @@ +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=dict(type='SyncBN', requires_grad=True), + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='DAHead', + in_channels=2048, + in_index=3, + channels=512, + pam_channels=64, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + train_cfg=dict(), + test_cfg=dict(mode='whole')) +dataset_type = 'EasyPortraitPSDataset' +data_root = '/home/jovyan/datasets/wacv_24/' +img_norm_cfg = dict( + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + train=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/train', + ann_dir='easyportrait_384/annotations_ps/train', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) + ]), + val=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/val', + ann_dir='easyportrait_384/annotations_ps/val', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + test=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/test', + ann_dir='easyportrait_384/annotations_ps/test', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + samples_per_gpu=32, + workers_per_gpu=8) +log_config = dict( + interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)]) +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True +optimizer = dict(type='AdamW', lr=0.0002, weight_decay=0.0001) +optimizer_config = dict() +lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=True) +default_hooks = dict(stop=dict(type='EarlyStoppingHook', monitor='mIoU')) +runner = dict(type='EpochBasedRunner', max_epochs=100) +checkpoint_config = dict(by_epoch=True, interval=100) +evaluation = dict(interval=1, metric='mIoU', save_best='mIoU') +work_dir = 'work_dirs/petrova/danet-ps' +gpu_ids = [0] +auto_resume = False diff --git a/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/deeplab-fp/deeplabv3-fp.py b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/deeplab-fp/deeplabv3-fp.py new file mode 100644 index 0000000000000000000000000000000000000000..555beecf1c7e806ef6d21681f5bb75015a2322bf --- /dev/null +++ b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/deeplab-fp/deeplabv3-fp.py @@ -0,0 +1,174 @@ +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=dict(type='SyncBN', requires_grad=True), + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='ASPPHead', + in_channels=2048, + in_index=3, + channels=512, + dilations=(1, 12, 24, 36), + dropout_ratio=0.1, + num_classes=8, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=8, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + train_cfg=dict(), + test_cfg=dict(mode='whole')) +dataset_type = 'EasyPortraitFPDataset' +data_root = '/home/jovyan/datasets/wacv_24/' +img_norm_cfg = dict( + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + train=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/train', + ann_dir='easyportrait_384/annotations_fp/train', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) + ]), + val=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/val', + ann_dir='easyportrait_384/annotations_fp/val', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + test=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/test', + ann_dir='easyportrait_384/annotations_fp/test', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + samples_per_gpu=32, + workers_per_gpu=8) +log_config = dict( + interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)]) +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True +optimizer = dict(type='AdamW', lr=0.0002, weight_decay=0.0001) +optimizer_config = dict() +lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=True) +default_hooks = dict(stop=dict(type='EarlyStoppingHook', monitor='mIoU')) +runner = dict(type='EpochBasedRunner', max_epochs=100) +checkpoint_config = dict(by_epoch=True, interval=100) +evaluation = dict(interval=1, metric='mIoU', save_best='mIoU') +work_dir = 'work_dirs/petrova/deeplabv3-fp' +gpu_ids = [0] +auto_resume = False diff --git a/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/deeplab-ps/deeplabv3-ps.py b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/deeplab-ps/deeplabv3-ps.py new file mode 100644 index 0000000000000000000000000000000000000000..a3304b4c5c51def54f58278dba4985614502c4ca --- /dev/null +++ b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/deeplab-ps/deeplabv3-ps.py @@ -0,0 +1,171 @@ +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=dict(type='SyncBN', requires_grad=True), + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='ASPPHead', + in_channels=2048, + in_index=3, + channels=512, + dilations=(1, 12, 24, 36), + dropout_ratio=0.1, + num_classes=2, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + train_cfg=dict(), + test_cfg=dict(mode='whole')) +dataset_type = 'EasyPortraitPSDataset' +data_root = '/home/jovyan/datasets/wacv_24/' +img_norm_cfg = dict( + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + train=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/train', + ann_dir='easyportrait_384/annotations_ps/train', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) + ]), + val=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/val', + ann_dir='easyportrait_384/annotations_ps/val', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + test=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/test', + ann_dir='easyportrait_384/annotations_ps/test', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + samples_per_gpu=32, + workers_per_gpu=8) +log_config = dict( + interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)]) +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True +optimizer = dict(type='AdamW', lr=0.0002, weight_decay=0.0001) +optimizer_config = dict() +lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=True) +default_hooks = dict(stop=dict(type='EarlyStoppingHook', monitor='mIoU')) +runner = dict(type='EpochBasedRunner', max_epochs=100) +checkpoint_config = dict(by_epoch=True, interval=100) +evaluation = dict(interval=1, metric='mIoU', save_best='mIoU') +work_dir = 'work_dirs/petrova/deeplabv3-ps' +gpu_ids = [0] +auto_resume = False diff --git a/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fastscnn-fp/fastscnn-fp.py b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fastscnn-fp/fastscnn-fp.py new file mode 100644 index 0000000000000000000000000000000000000000..bfe1b2f19d8cb2b601a73dbfcac05618d7a75012 --- /dev/null +++ b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fastscnn-fp/fastscnn-fp.py @@ -0,0 +1,165 @@ +norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01) +model = dict( + type='EncoderDecoder', + backbone=dict( + type='FastSCNN', + downsample_dw_channels=(32, 48), + global_in_channels=64, + global_block_channels=(64, 96, 128), + global_block_strides=(2, 2, 1), + global_out_channels=128, + higher_in_channels=64, + lower_in_channels=128, + fusion_out_channels=128, + out_indices=(0, 1, 2), + norm_cfg=dict(type='SyncBN', requires_grad=True, momentum=0.01), + align_corners=False), + decode_head=dict( + type='DepthwiseSeparableFCNHead', + in_channels=128, + channels=128, + concat_input=False, + num_classes=8, + in_index=-1, + norm_cfg=dict(type='SyncBN', requires_grad=True, momentum=0.01), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1)), + auxiliary_head=[ + dict(type='FCNHead', in_channels=128, channels=32, num_classes=8), + dict(type='FCNHead', in_channels=128, channels=32, num_classes=8) + ], + train_cfg=dict(), + test_cfg=dict(mode='whole')) +dataset_type = 'EasyPortraitFPDataset' +data_root = '/home/jovyan/datasets/wacv_24/' +img_norm_cfg = dict( + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + train=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/train', + ann_dir='easyportrait_384/annotations_fp/train', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) + ]), + val=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/val', + ann_dir='easyportrait_384/annotations_fp/val', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + test=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/test', + ann_dir='easyportrait_384/annotations_fp/test', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + samples_per_gpu=32, + workers_per_gpu=8) +log_config = dict( + interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)]) +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True +optimizer = dict(type='SGD', lr=0.12, weight_decay=4e-05, momentum=0.9) +optimizer_config = dict() +lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=True) +default_hooks = dict(stop=dict(type='EarlyStoppingHook', monitor='mIoU')) +runner = dict(type='EpochBasedRunner', max_epochs=100) +checkpoint_config = dict(by_epoch=True, interval=100) +evaluation = dict(interval=1, metric='mIoU', save_best='mIoU') +work_dir = 'work_dirs/petrova/fast_scnn-fp' +gpu_ids = [0] +auto_resume = False diff --git a/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fastscnn-ps/fastscnn-ps.py b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fastscnn-ps/fastscnn-ps.py new file mode 100644 index 0000000000000000000000000000000000000000..664ee867cdf078fc2efa727085b20d419e89f4a9 --- /dev/null +++ b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fastscnn-ps/fastscnn-ps.py @@ -0,0 +1,162 @@ +norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01) +model = dict( + type='EncoderDecoder', + backbone=dict( + type='FastSCNN', + downsample_dw_channels=(32, 48), + global_in_channels=64, + global_block_channels=(64, 96, 128), + global_block_strides=(2, 2, 1), + global_out_channels=128, + higher_in_channels=64, + lower_in_channels=128, + fusion_out_channels=128, + out_indices=(0, 1, 2), + norm_cfg=dict(type='SyncBN', requires_grad=True, momentum=0.01), + align_corners=False), + decode_head=dict( + type='DepthwiseSeparableFCNHead', + in_channels=128, + channels=128, + concat_input=False, + num_classes=2, + in_index=-1, + norm_cfg=dict(type='SyncBN', requires_grad=True, momentum=0.01), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1)), + auxiliary_head=[ + dict(type='FCNHead', in_channels=128, channels=32, num_classes=2), + dict(type='FCNHead', in_channels=128, channels=32, num_classes=2) + ], + train_cfg=dict(), + test_cfg=dict(mode='whole')) +dataset_type = 'EasyPortraitPSDataset' +data_root = '/home/jovyan/datasets/wacv_24/' +img_norm_cfg = dict( + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + train=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/train', + ann_dir='easyportrait_384/annotations_ps/train', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) + ]), + val=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/val', + ann_dir='easyportrait_384/annotations_ps/val', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + test=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/test', + ann_dir='easyportrait_384/annotations_ps/test', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + samples_per_gpu=32, + workers_per_gpu=8) +log_config = dict( + interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)]) +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True +optimizer = dict(type='SGD', lr=0.12, weight_decay=4e-05, momentum=0.9) +optimizer_config = dict() +lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=True) +default_hooks = dict(stop=dict(type='EarlyStoppingHook', monitor='mIoU')) +runner = dict(type='EpochBasedRunner', max_epochs=100) +checkpoint_config = dict(by_epoch=True, interval=100) +evaluation = dict(interval=1, metric='mIoU', save_best='mIoU') +work_dir = 'work_dirs/petrova/fast_scnn-ps' +gpu_ids = [0] +auto_resume = False diff --git a/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fcn-fp/fcn-fp.py b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fcn-fp/fcn-fp.py new file mode 100644 index 0000000000000000000000000000000000000000..7d0ca180de039bb25e4aeba0cbefe0841bdfe900 --- /dev/null +++ b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fcn-fp/fcn-fp.py @@ -0,0 +1,187 @@ +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='mmcls://mobilenet_v2', + backbone=dict( + type='MobileNetV2', + widen_factor=1.0, + strides=(1, 2, 2, 1, 1, 1, 1), + dilations=(1, 1, 1, 2, 2, 4, 4), + out_indices=(1, 2, 4, 6), + norm_cfg=dict(type='SyncBN', requires_grad=True)), + decode_head=dict( + type='FCNHead', + in_channels=320, + in_index=3, + channels=512, + num_convs=2, + concat_input=True, + dropout_ratio=0.1, + num_classes=8, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=96, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=8, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + train_cfg=dict(), + test_cfg=dict(mode='whole')) +dataset_type = 'EasyPortraitFPDataset' +data_root = '/home/jovyan/datasets/wacv_24/' +img_norm_cfg = dict( + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + train=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/train', + ann_dir='easyportrait_384/annotations_fp/train', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) + ]), + val=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/val', + ann_dir='easyportrait_384/annotations_fp/val', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + test=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/test', + ann_dir='easyportrait_384/annotations_fp/test', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + samples_per_gpu=32, + workers_per_gpu=8) +log_config = dict( + interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)]) +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True +optimizer = dict( + type='AdamW', + lr=6e-05, + betas=(0.9, 0.999), + weight_decay=0.01, + paramwise_cfg=dict( + custom_keys=dict( + pos_block=dict(decay_mult=0.0), + norm=dict(decay_mult=0.0), + head=dict(lr_mult=10.0)))) +optimizer_config = dict() +lr_config = dict( + policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-06, + power=1.0, + min_lr=0.0, + by_epoch=False) +default_hooks = dict(stop=dict(type='EarlyStoppingHook', monitor='mIoU')) +runner = dict(type='EpochBasedRunner', max_epochs=100) +checkpoint_config = dict(by_epoch=True, interval=100) +evaluation = dict(interval=1, metric='mIoU', save_best='mIoU') +work_dir = 'work_dirs/petrova/fcn-fp' +gpu_ids = [0] +auto_resume = False diff --git a/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fcn-ps/fcn-ps.py b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fcn-ps/fcn-ps.py new file mode 100644 index 0000000000000000000000000000000000000000..a9222d26578e96fb2f5ca2c2cdadc402f5eebab2 --- /dev/null +++ b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fcn-ps/fcn-ps.py @@ -0,0 +1,184 @@ +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='mmcls://mobilenet_v2', + backbone=dict( + type='MobileNetV2', + widen_factor=1.0, + strides=(1, 2, 2, 1, 1, 1, 1), + dilations=(1, 1, 1, 2, 2, 4, 4), + out_indices=(1, 2, 4, 6), + norm_cfg=dict(type='SyncBN', requires_grad=True)), + decode_head=dict( + type='FCNHead', + in_channels=320, + in_index=3, + channels=512, + num_convs=2, + concat_input=True, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=96, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + train_cfg=dict(), + test_cfg=dict(mode='whole')) +dataset_type = 'EasyPortraitPSDataset' +data_root = '/home/jovyan/datasets/wacv_24/' +img_norm_cfg = dict( + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + train=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/train', + ann_dir='easyportrait_384/annotations_ps/train', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) + ]), + val=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/val', + ann_dir='easyportrait_384/annotations_ps/val', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + test=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/test', + ann_dir='easyportrait_384/annotations_ps/test', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + samples_per_gpu=32, + workers_per_gpu=8) +log_config = dict( + interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)]) +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True +optimizer = dict( + type='AdamW', + lr=6e-05, + betas=(0.9, 0.999), + weight_decay=0.01, + paramwise_cfg=dict( + custom_keys=dict( + pos_block=dict(decay_mult=0.0), + norm=dict(decay_mult=0.0), + head=dict(lr_mult=10.0)))) +optimizer_config = dict() +lr_config = dict( + policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-06, + power=1.0, + min_lr=0.0, + by_epoch=False) +default_hooks = dict(stop=dict(type='EarlyStoppingHook', monitor='mIoU')) +runner = dict(type='EpochBasedRunner', max_epochs=100) +checkpoint_config = dict(by_epoch=True, interval=100) +evaluation = dict(interval=1, metric='mIoU', save_best='mIoU') +work_dir = 'work_dirs/petrova/fcn-ps' +gpu_ids = [0] +auto_resume = False diff --git a/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fpn-fp/fpn-fp.py b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fpn-fp/fpn-fp.py new file mode 100644 index 0000000000000000000000000000000000000000..1e4a08e2d8ca3eecc0887d64124dc54fd215ad55 --- /dev/null +++ b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fpn-fp/fpn-fp.py @@ -0,0 +1,182 @@ +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 1, 1), + strides=(1, 2, 2, 2), + norm_cfg=dict(type='SyncBN', requires_grad=True), + norm_eval=False, + style='pytorch', + contract_dilation=True), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=4), + decode_head=dict( + type='FPNHead', + in_channels=[256, 256, 256, 256], + in_index=[0, 1, 2, 3], + feature_strides=[4, 8, 16, 32], + channels=128, + dropout_ratio=0.1, + num_classes=8, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + train_cfg=dict(), + test_cfg=dict(mode='whole')) +dataset_type = 'EasyPortraitFPDataset' +data_root = '/home/jovyan/datasets/wacv_24/' +img_norm_cfg = dict( + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + train=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/train', + ann_dir='easyportrait_384/annotations_fp/train', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) + ]), + val=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/val', + ann_dir='easyportrait_384/annotations_fp/val', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + test=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/test', + ann_dir='easyportrait_384/annotations_fp/test', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + samples_per_gpu=32, + workers_per_gpu=8) +log_config = dict( + interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)]) +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True +optimizer = dict( + type='AdamW', + lr=6e-05, + betas=(0.9, 0.999), + weight_decay=0.01, + paramwise_cfg=dict( + custom_keys=dict( + pos_block=dict(decay_mult=0.0), + norm=dict(decay_mult=0.0), + head=dict(lr_mult=10.0)))) +optimizer_config = dict() +lr_config = dict( + policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-06, + power=1.0, + min_lr=0.0, + by_epoch=False) +default_hooks = dict(stop=dict(type='EarlyStoppingHook', monitor='mIoU')) +runner = dict(type='EpochBasedRunner', max_epochs=100) +checkpoint_config = dict(by_epoch=True, interval=100) +evaluation = dict(interval=1, metric='mIoU', save_best='mIoU') +work_dir = 'work_dirs/petrova/fpn-fp' +gpu_ids = [0] +auto_resume = False diff --git a/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fpn-ps/fpn-ps.py b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fpn-ps/fpn-ps.py new file mode 100644 index 0000000000000000000000000000000000000000..a01a5f5bfb5ea0667d3ec1e23805d7c84459e831 --- /dev/null +++ b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fpn-ps/fpn-ps.py @@ -0,0 +1,179 @@ +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 1, 1), + strides=(1, 2, 2, 2), + norm_cfg=dict(type='SyncBN', requires_grad=True), + norm_eval=False, + style='pytorch', + contract_dilation=True), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=4), + decode_head=dict( + type='FPNHead', + in_channels=[256, 256, 256, 256], + in_index=[0, 1, 2, 3], + feature_strides=[4, 8, 16, 32], + channels=128, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + train_cfg=dict(), + test_cfg=dict(mode='whole')) +dataset_type = 'EasyPortraitPSDataset' +data_root = '/home/jovyan/datasets/wacv_24/' +img_norm_cfg = dict( + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + train=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/train', + ann_dir='easyportrait_384/annotations_ps/train', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) + ]), + val=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/val', + ann_dir='easyportrait_384/annotations_ps/val', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + test=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/test', + ann_dir='easyportrait_384/annotations_ps/test', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + samples_per_gpu=32, + workers_per_gpu=8) +log_config = dict( + interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)]) +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True +optimizer = dict( + type='AdamW', + lr=6e-05, + betas=(0.9, 0.999), + weight_decay=0.01, + paramwise_cfg=dict( + custom_keys=dict( + pos_block=dict(decay_mult=0.0), + norm=dict(decay_mult=0.0), + head=dict(lr_mult=10.0)))) +optimizer_config = dict() +lr_config = dict( + policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-06, + power=1.0, + min_lr=0.0, + by_epoch=False) +default_hooks = dict(stop=dict(type='EarlyStoppingHook', monitor='mIoU')) +runner = dict(type='EpochBasedRunner', max_epochs=100) +checkpoint_config = dict(by_epoch=True, interval=100) +evaluation = dict(interval=1, metric='mIoU', save_best='mIoU') +work_dir = 'work_dirs/petrova/fpn-ps' +gpu_ids = [0] +auto_resume = False diff --git a/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/segformer-fp/segformer-fp.py b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/segformer-fp/segformer-fp.py new file mode 100644 index 0000000000000000000000000000000000000000..f20545a13364c3a1a1989187db176abe1ae14a98 --- /dev/null +++ b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/segformer-fp/segformer-fp.py @@ -0,0 +1,182 @@ +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained= + 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b0_20220624-7e0fe6dd.pth', + backbone=dict( + type='MixVisionTransformer', + in_channels=3, + embed_dims=32, + num_stages=4, + num_layers=[2, 2, 2, 2], + num_heads=[1, 2, 5, 8], + patch_sizes=[7, 3, 3, 3], + sr_ratios=[8, 4, 2, 1], + out_indices=(0, 1, 2, 3), + mlp_ratio=4, + qkv_bias=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1), + decode_head=dict( + type='SegformerHead', + in_channels=[32, 64, 160, 256], + in_index=[0, 1, 2, 3], + channels=256, + dropout_ratio=0.1, + num_classes=8, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + train_cfg=dict(), + test_cfg=dict(mode='whole')) +dataset_type = 'EasyPortraitFPDataset' +data_root = '/home/jovyan/datasets/wacv_24/' +img_norm_cfg = dict( + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + train=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/train', + ann_dir='easyportrait_384/annotations_fp/train', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) + ]), + val=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/val', + ann_dir='easyportrait_384/annotations_fp/val', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + test=dict( + type='EasyPortraitFPDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'skin', 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'), + img_dir='easyportrait_384/images/test', + ann_dir='easyportrait_384/annotations_fp/test', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + samples_per_gpu=32, + workers_per_gpu=8) +log_config = dict( + interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)]) +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True +optimizer = dict( + type='AdamW', + lr=6e-05, + betas=(0.9, 0.999), + weight_decay=0.01, + paramwise_cfg=dict( + custom_keys=dict( + pos_block=dict(decay_mult=0.0), + norm=dict(decay_mult=0.0), + head=dict(lr_mult=10.0)))) +optimizer_config = dict() +lr_config = dict( + policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-06, + power=1.0, + min_lr=0.0, + by_epoch=False) +default_hooks = dict(stop=dict(type='EarlyStoppingHook', monitor='mIoU')) +runner = dict(type='EpochBasedRunner', max_epochs=100) +checkpoint_config = dict(by_epoch=True, interval=100) +evaluation = dict(interval=1, metric='mIoU', save_best='mIoU') +checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b0_20220624-7e0fe6dd.pth' +work_dir = 'work_dirs/petrova/segformer-fp' +gpu_ids = [0] +auto_resume = False diff --git a/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/segformer-ps/segformer-ps.py b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/segformer-ps/segformer-ps.py new file mode 100644 index 0000000000000000000000000000000000000000..172b63a845b9a5d2e964732fc606d8a45809f482 --- /dev/null +++ b/data_utils/easyportrait/local_configs/easyportrait_experiments_v2/segformer-ps/segformer-ps.py @@ -0,0 +1,179 @@ +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained= + 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b0_20220624-7e0fe6dd.pth', + backbone=dict( + type='MixVisionTransformer', + in_channels=3, + embed_dims=32, + num_stages=4, + num_layers=[2, 2, 2, 2], + num_heads=[1, 2, 5, 8], + patch_sizes=[7, 3, 3, 3], + sr_ratios=[8, 4, 2, 1], + out_indices=(0, 1, 2, 3), + mlp_ratio=4, + qkv_bias=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1), + decode_head=dict( + type='SegformerHead', + in_channels=[32, 64, 160, 256], + in_index=[0, 1, 2, 3], + channels=256, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + train_cfg=dict(), + test_cfg=dict(mode='whole')) +dataset_type = 'EasyPortraitPSDataset' +data_root = '/home/jovyan/datasets/wacv_24/' +img_norm_cfg = dict( + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + train=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/train', + ann_dir='easyportrait_384/annotations_ps/train', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomFlip', prob=0.0), + dict( + type='PhotoMetricDistortion', + brightness_delta=16, + contrast_range=(0.5, 1.0), + saturation_range=(0.5, 1.0), + hue_delta=5), + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) + ]), + val=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/val', + ann_dir='easyportrait_384/annotations_ps/val', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + test=dict( + type='EasyPortraitPSDataset', + data_root='/home/jovyan/datasets/wacv_24/', + classes=('background', 'person'), + img_dir='easyportrait_384/images/test', + ann_dir='easyportrait_384/annotations_ps/test', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(384, 384), + flip=False, + transforms=[ + dict( + type='Normalize', + mean=[143.55267075, 132.96705975, 126.94924335], + std=[60.2625333, 60.32740275, 59.30988645], + to_rgb=True), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ]), + samples_per_gpu=32, + workers_per_gpu=8) +log_config = dict( + interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)]) +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True +optimizer = dict( + type='AdamW', + lr=6e-05, + betas=(0.9, 0.999), + weight_decay=0.01, + paramwise_cfg=dict( + custom_keys=dict( + pos_block=dict(decay_mult=0.0), + norm=dict(decay_mult=0.0), + head=dict(lr_mult=10.0)))) +optimizer_config = dict() +lr_config = dict( + policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-06, + power=1.0, + min_lr=0.0, + by_epoch=False) +default_hooks = dict(stop=dict(type='EarlyStoppingHook', monitor='mIoU')) +runner = dict(type='EpochBasedRunner', max_epochs=100) +checkpoint_config = dict(by_epoch=True, interval=100) +evaluation = dict(interval=1, metric='mIoU', save_best='mIoU') +checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b0_20220624-7e0fe6dd.pth' +work_dir = 'work_dirs/petrova/segformer-ps' +gpu_ids = [0] +auto_resume = False diff --git a/data_utils/easyportrait/mmseg/.mim/configs b/data_utils/easyportrait/mmseg/.mim/configs new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/data_utils/easyportrait/mmseg/.mim/tools b/data_utils/easyportrait/mmseg/.mim/tools new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/data_utils/easyportrait/mmseg/__init__.py b/data_utils/easyportrait/mmseg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c28bf4e5e0648190cf586de3d987a8ba64cd26ab --- /dev/null +++ b/data_utils/easyportrait/mmseg/__init__.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import mmcv +from packaging.version import parse + +from .version import __version__, version_info + +MMCV_MIN = '1.3.13' +MMCV_MAX = '1.8.0' + + +def digit_version(version_str: str, length: int = 4): + """Convert a version string into a tuple of integers. + + This method is usually used for comparing two versions. For pre-release + versions: alpha < beta < rc. + + Args: + version_str (str): The version string. + length (int): The maximum number of version levels. Default: 4. + + Returns: + tuple[int]: The version info in digits (integers). + """ + version = parse(version_str) + assert version.release, f'failed to parse version {version_str}' + release = list(version.release) + release = release[:length] + if len(release) < length: + release = release + [0] * (length - len(release)) + if version.is_prerelease: + mapping = {'a': -3, 'b': -2, 'rc': -1} + val = -4 + # version.pre can be None + if version.pre: + if version.pre[0] not in mapping: + warnings.warn(f'unknown prerelease version {version.pre[0]}, ' + 'version checking may go wrong') + else: + val = mapping[version.pre[0]] + release.extend([val, version.pre[-1]]) + else: + release.extend([val, 0]) + + elif version.is_postrelease: + release.extend([1, version.post]) + else: + release.extend([0, 0]) + return tuple(release) + + +mmcv_min_version = digit_version(MMCV_MIN) +mmcv_max_version = digit_version(MMCV_MAX) +mmcv_version = digit_version(mmcv.__version__) + + +assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \ + f'MMCV=={mmcv.__version__} is used but incompatible. ' \ + f'Please install mmcv>={mmcv_min_version}, <{mmcv_max_version}.' + +__all__ = ['__version__', 'version_info', 'digit_version'] diff --git a/data_utils/easyportrait/mmseg/apis/__init__.py b/data_utils/easyportrait/mmseg/apis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6881805330b61eecc632ac7e93d94cf83dab6cc --- /dev/null +++ b/data_utils/easyportrait/mmseg/apis/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .inference import inference_segmentor, init_segmentor, show_result_pyplot +from .test import multi_gpu_test, single_gpu_test +from .train import (get_root_logger, init_random_seed, set_random_seed, + train_segmentor) + +__all__ = [ + 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor', + 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test', + 'show_result_pyplot', 'init_random_seed' +] diff --git a/data_utils/easyportrait/mmseg/apis/inference.py b/data_utils/easyportrait/mmseg/apis/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..5bbe66634ebc8d7678f1f7d56c05393beaa8d7b7 --- /dev/null +++ b/data_utils/easyportrait/mmseg/apis/inference.py @@ -0,0 +1,145 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import matplotlib.pyplot as plt +import mmcv +import torch +from mmcv.parallel import collate, scatter +from mmcv.runner import load_checkpoint + +from mmseg.datasets.pipelines import Compose +from mmseg.models import build_segmentor + + +def init_segmentor(config, checkpoint=None, device='cuda:0'): + """Initialize a segmentor from config file. + + Args: + config (str or :obj:`mmcv.Config`): Config file path or the config + object. + checkpoint (str, optional): Checkpoint path. If left as None, the model + will not load any weights. + device (str, optional) CPU/CUDA device option. Default 'cuda:0'. + Use 'cpu' for loading model on CPU. + Returns: + nn.Module: The constructed segmentor. + """ + if isinstance(config, str): + config = mmcv.Config.fromfile(config) + elif not isinstance(config, mmcv.Config): + raise TypeError('config must be a filename or Config object, ' + 'but got {}'.format(type(config))) + config.model.pretrained = None + config.model.train_cfg = None + model = build_segmentor(config.model, test_cfg=config.get('test_cfg')) + if checkpoint is not None: + checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') + model.CLASSES = checkpoint['meta']['CLASSES'] + model.PALETTE = checkpoint['meta']['PALETTE'] + model.cfg = config # save the config in the model for convenience + model.to(device) + model.eval() + return model + + +class LoadImage: + """A simple pipeline to load image.""" + + def __call__(self, results): + """Call function to load images into results. + + Args: + results (dict): A result dict contains the file name + of the image to be read. + + Returns: + dict: ``results`` will be returned containing loaded image. + """ + + if isinstance(results['img'], str): + results['filename'] = results['img'] + results['ori_filename'] = results['img'] + else: + results['filename'] = None + results['ori_filename'] = None + img = mmcv.imread(results['img']) + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + return results + + +def inference_segmentor(model, imgs): + """Inference image(s) with the segmentor. + + Args: + model (nn.Module): The loaded segmentor. + imgs (str/ndarray or list[str/ndarray]): Either image files or loaded + images. + + Returns: + (list[Tensor]): The segmentation result. + """ + cfg = model.cfg + device = next(model.parameters()).device # model device + # build the data pipeline + test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] + test_pipeline = Compose(test_pipeline) + # prepare data + data = [] + imgs = imgs if isinstance(imgs, list) else [imgs] + for img in imgs: + img_data = dict(img=img) + img_data = test_pipeline(img_data) + data.append(img_data) + data = collate(data, samples_per_gpu=len(imgs)) + if next(model.parameters()).is_cuda: + # scatter to specified GPU + data = scatter(data, [device])[0] + else: + data['img_metas'] = [i.data[0] for i in data['img_metas']] + + # forward the model + with torch.no_grad(): + result = model(return_loss=False, rescale=True, **data) + return result + + +def show_result_pyplot(model, + img, + result, + palette=None, + fig_size=(15, 10), + opacity=0.5, + title='', + block=True, + out_file=None): + """Visualize the segmentation results on the image. + + Args: + model (nn.Module): The loaded segmentor. + img (str or np.ndarray): Image filename or loaded image. + result (list): The segmentation result. + palette (list[list[int]]] | None): The palette of segmentation + map. If None is given, random palette will be generated. + Default: None + fig_size (tuple): Figure size of the pyplot figure. + opacity(float): Opacity of painted segmentation map. + Default 0.5. + Must be in (0, 1] range. + title (str): The title of pyplot figure. + Default is ''. + block (bool): Whether to block the pyplot figure. + Default is True. + out_file (str or None): The path to write the image. + Default: None. + """ + if hasattr(model, 'module'): + model = model.module + img = model.show_result( + img, result, palette=palette, show=False, opacity=opacity) + plt.figure(figsize=fig_size) + plt.imshow(mmcv.bgr2rgb(img)) + plt.title(title) + plt.tight_layout() + plt.show(block=block) + if out_file is not None: + mmcv.imwrite(img, out_file) diff --git a/data_utils/easyportrait/mmseg/apis/test.py b/data_utils/easyportrait/mmseg/apis/test.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4fcc97975741697b7c99c32f66c47b6206f1a6 --- /dev/null +++ b/data_utils/easyportrait/mmseg/apis/test.py @@ -0,0 +1,233 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import tempfile +import warnings + +import mmcv +import numpy as np +import torch +from mmcv.engine import collect_results_cpu, collect_results_gpu +from mmcv.image import tensor2imgs +from mmcv.runner import get_dist_info + + +def np2tmp(array, temp_file_name=None, tmpdir=None): + """Save ndarray to local numpy file. + + Args: + array (ndarray): Ndarray to save. + temp_file_name (str): Numpy file name. If 'temp_file_name=None', this + function will generate a file name with tempfile.NamedTemporaryFile + to save ndarray. Default: None. + tmpdir (str): Temporary directory to save Ndarray files. Default: None. + Returns: + str: The numpy file name. + """ + + if temp_file_name is None: + temp_file_name = tempfile.NamedTemporaryFile( + suffix='.npy', delete=False, dir=tmpdir).name + np.save(temp_file_name, array) + return temp_file_name + + +def single_gpu_test(model, + data_loader, + show=False, + out_dir=None, + efficient_test=False, + opacity=0.5, + pre_eval=False, + format_only=False, + format_args={}): + """Test with single GPU by progressive mode. + + Args: + model (nn.Module): Model to be tested. + data_loader (utils.data.Dataloader): Pytorch data loader. + show (bool): Whether show results during inference. Default: False. + out_dir (str, optional): If specified, the results will be dumped into + the directory to save output results. + efficient_test (bool): Whether save the results as local numpy files to + save CPU memory during evaluation. Mutually exclusive with + pre_eval and format_results. Default: False. + opacity(float): Opacity of painted segmentation map. + Default 0.5. + Must be in (0, 1] range. + pre_eval (bool): Use dataset.pre_eval() function to generate + pre_results for metric evaluation. Mutually exclusive with + efficient_test and format_results. Default: False. + format_only (bool): Only format result for results commit. + Mutually exclusive with pre_eval and efficient_test. + Default: False. + format_args (dict): The args for format_results. Default: {}. + Returns: + list: list of evaluation pre-results or list of save file names. + """ + if efficient_test: + warnings.warn( + 'DeprecationWarning: ``efficient_test`` will be deprecated, the ' + 'evaluation is CPU memory friendly with pre_eval=True') + mmcv.mkdir_or_exist('.efficient_test') + # when none of them is set true, return segmentation results as + # a list of np.array. + assert [efficient_test, pre_eval, format_only].count(True) <= 1, \ + '``efficient_test``, ``pre_eval`` and ``format_only`` are mutually ' \ + 'exclusive, only one of them could be true .' + + model.eval() + results = [] + dataset = data_loader.dataset + prog_bar = mmcv.ProgressBar(len(dataset)) + # The pipeline about how the data_loader retrieval samples from dataset: + # sampler -> batch_sampler -> indices + # The indices are passed to dataset_fetcher to get data from dataset. + # data_fetcher -> collate_fn(dataset[index]) -> data_sample + # we use batch_sampler to get correct data idx + loader_indices = data_loader.batch_sampler + + for batch_indices, data in zip(loader_indices, data_loader): + with torch.no_grad(): + result = model(return_loss=False, **data) + + if show or out_dir: + img_tensor = data['img'][0] + img_metas = data['img_metas'][0].data[0] + imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg']) + assert len(imgs) == len(img_metas) + + for img, img_meta in zip(imgs, img_metas): + h, w, _ = img_meta['img_shape'] + img_show = img[:h, :w, :] + + ori_h, ori_w = img_meta['ori_shape'][:-1] + img_show = mmcv.imresize(img_show, (ori_w, ori_h)) + + if out_dir: + out_file = osp.join(out_dir, img_meta['ori_filename']) + else: + out_file = None + + model.module.show_result( + img_show, + result, + palette=dataset.PALETTE, + show=show, + out_file=out_file, + opacity=opacity) + + if efficient_test: + result = [np2tmp(_, tmpdir='.efficient_test') for _ in result] + + if format_only: + result = dataset.format_results( + result, indices=batch_indices, **format_args) + if pre_eval: + # TODO: adapt samples_per_gpu > 1. + # only samples_per_gpu=1 valid now + result = dataset.pre_eval(result, indices=batch_indices) + results.extend(result) + else: + results.extend(result) + + batch_size = len(result) + for _ in range(batch_size): + prog_bar.update() + + return results + + +def multi_gpu_test(model, + data_loader, + tmpdir=None, + gpu_collect=False, + efficient_test=False, + pre_eval=False, + format_only=False, + format_args={}): + """Test model with multiple gpus by progressive mode. + + This method tests model with multiple gpus and collects the results + under two different modes: gpu and cpu modes. By setting 'gpu_collect=True' + it encodes results to gpu tensors and use gpu communication for results + collection. On cpu mode it saves the results on different gpus to 'tmpdir' + and collects them by the rank 0 worker. + + Args: + model (nn.Module): Model to be tested. + data_loader (utils.data.Dataloader): Pytorch data loader. + tmpdir (str): Path of directory to save the temporary results from + different gpus under cpu mode. The same path is used for efficient + test. Default: None. + gpu_collect (bool): Option to use either gpu or cpu to collect results. + Default: False. + efficient_test (bool): Whether save the results as local numpy files to + save CPU memory during evaluation. Mutually exclusive with + pre_eval and format_results. Default: False. + pre_eval (bool): Use dataset.pre_eval() function to generate + pre_results for metric evaluation. Mutually exclusive with + efficient_test and format_results. Default: False. + format_only (bool): Only format result for results commit. + Mutually exclusive with pre_eval and efficient_test. + Default: False. + format_args (dict): The args for format_results. Default: {}. + + Returns: + list: list of evaluation pre-results or list of save file names. + """ + if efficient_test: + warnings.warn( + 'DeprecationWarning: ``efficient_test`` will be deprecated, the ' + 'evaluation is CPU memory friendly with pre_eval=True') + mmcv.mkdir_or_exist('.efficient_test') + # when none of them is set true, return segmentation results as + # a list of np.array. + assert [efficient_test, pre_eval, format_only].count(True) <= 1, \ + '``efficient_test``, ``pre_eval`` and ``format_only`` are mutually ' \ + 'exclusive, only one of them could be true .' + + model.eval() + results = [] + dataset = data_loader.dataset + # The pipeline about how the data_loader retrieval samples from dataset: + # sampler -> batch_sampler -> indices + # The indices are passed to dataset_fetcher to get data from dataset. + # data_fetcher -> collate_fn(dataset[index]) -> data_sample + # we use batch_sampler to get correct data idx + + # batch_sampler based on DistributedSampler, the indices only point to data + # samples of related machine. + loader_indices = data_loader.batch_sampler + + rank, world_size = get_dist_info() + if rank == 0: + prog_bar = mmcv.ProgressBar(len(dataset)) + + for batch_indices, data in zip(loader_indices, data_loader): + with torch.no_grad(): + result = model(return_loss=False, rescale=True, **data) + + if efficient_test: + result = [np2tmp(_, tmpdir='.efficient_test') for _ in result] + + if format_only: + result = dataset.format_results( + result, indices=batch_indices, **format_args) + if pre_eval: + # TODO: adapt samples_per_gpu > 1. + # only samples_per_gpu=1 valid now + result = dataset.pre_eval(result, indices=batch_indices) + + results.extend(result) + + if rank == 0: + batch_size = len(result) * world_size + for _ in range(batch_size): + prog_bar.update() + + # collect results from all ranks + if gpu_collect: + results = collect_results_gpu(results, len(dataset)) + else: + results = collect_results_cpu(results, len(dataset), tmpdir) + return results diff --git a/data_utils/easyportrait/mmseg/apis/train.py b/data_utils/easyportrait/mmseg/apis/train.py new file mode 100644 index 0000000000000000000000000000000000000000..be8e422b319216fc0ae2744658836c509a66ead0 --- /dev/null +++ b/data_utils/easyportrait/mmseg/apis/train.py @@ -0,0 +1,194 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import random +import warnings + +import mmcv +import numpy as np +import torch +import torch.distributed as dist +from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner, + build_runner, get_dist_info) +from mmcv.utils import build_from_cfg + +from mmseg import digit_version +from mmseg.core import DistEvalHook, EvalHook, build_optimizer +from mmseg.datasets import build_dataloader, build_dataset +from mmseg.utils import (build_ddp, build_dp, find_latest_checkpoint, + get_root_logger) + + +def init_random_seed(seed=None, device='cuda'): + """Initialize random seed. + + If the seed is not set, the seed will be automatically randomized, + and then broadcast to all processes to prevent some potential bugs. + Args: + seed (int, Optional): The seed. Default to None. + device (str): The device where the seed will be put on. + Default to 'cuda'. + Returns: + int: Seed to be used. + """ + if seed is not None: + return seed + + # Make sure all ranks share the same random seed to prevent + # some potential bugs. Please refer to + # https://github.com/open-mmlab/mmdetection/issues/6339 + rank, world_size = get_dist_info() + seed = np.random.randint(2**31) + if world_size == 1: + return seed + + if rank == 0: + random_num = torch.tensor(seed, dtype=torch.int32, device=device) + else: + random_num = torch.tensor(0, dtype=torch.int32, device=device) + dist.broadcast(random_num, src=0) + return random_num.item() + + +def set_random_seed(seed, deterministic=False): + """Set random seed. + + Args: + seed (int): Seed to be used. + deterministic (bool): Whether to set the deterministic option for + CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` + to True and `torch.backends.cudnn.benchmark` to False. + Default: False. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if deterministic: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def train_segmentor(model, + dataset, + cfg, + distributed=False, + validate=False, + timestamp=None, + meta=None): + """Launch segmentor training.""" + logger = get_root_logger(cfg.log_level) + + # prepare data loaders + dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] + # The default loader config + loader_cfg = dict( + # cfg.gpus will be ignored if distributed + num_gpus=len(cfg.gpu_ids), + dist=distributed, + seed=cfg.seed, + drop_last=True) + # The overall dataloader settings + loader_cfg.update({ + k: v + for k, v in cfg.data.items() if k not in [ + 'train', 'val', 'test', 'train_dataloader', 'val_dataloader', + 'test_dataloader' + ] + }) + + # The specific dataloader settings + train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})} + data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset] + + # put model on devices + if distributed: + find_unused_parameters = cfg.get('find_unused_parameters', False) + # Sets the `find_unused_parameters` parameter in + # DDP wrapper + model = build_ddp( + model, + cfg.device, + device_ids=[int(os.environ['LOCAL_RANK'])], + broadcast_buffers=False, + find_unused_parameters=find_unused_parameters) + else: + if not torch.cuda.is_available(): + assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \ + 'Please use MMCV >= 1.4.4 for CPU training!' + model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids) + + # build runner + optimizer = build_optimizer(model, cfg.optimizer) + + if cfg.get('runner') is None: + cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters} + warnings.warn( + 'config is now expected to have a `runner` section, ' + 'please set `runner` in your config.', UserWarning) + + runner = build_runner( + cfg.runner, + default_args=dict( + model=model, + batch_processor=None, + optimizer=optimizer, + work_dir=cfg.work_dir, + logger=logger, + meta=meta)) + + # register hooks + runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config, + cfg.checkpoint_config, cfg.log_config, + cfg.get('momentum_config', None)) + if distributed: + # when distributed training by epoch, using`DistSamplerSeedHook` to set + # the different seed to distributed sampler for each epoch, it will + # shuffle dataset at each epoch and avoid overfitting. + if isinstance(runner, EpochBasedRunner): + runner.register_hook(DistSamplerSeedHook()) + + # an ugly walkaround to make the .log and .log.json filenames the same + runner.timestamp = timestamp + + # register eval hooks + if validate: + val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) + # The specific dataloader settings + val_loader_cfg = { + **loader_cfg, + 'samples_per_gpu': 1, + 'shuffle': False, # Not shuffle by default + **cfg.data.get('val_dataloader', {}), + } + val_dataloader = build_dataloader(val_dataset, **val_loader_cfg) + eval_cfg = cfg.get('evaluation', {}) + eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' + eval_hook = DistEvalHook if distributed else EvalHook + # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the + # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'. + runner.register_hook( + eval_hook(val_dataloader, **eval_cfg), priority='LOW') + + # user-defined hooks + if cfg.get('custom_hooks', None): + custom_hooks = cfg.custom_hooks + assert isinstance(custom_hooks, list), \ + f'custom_hooks expect list type, but got {type(custom_hooks)}' + for hook_cfg in cfg.custom_hooks: + assert isinstance(hook_cfg, dict), \ + 'Each item in custom_hooks expects dict type, but got ' \ + f'{type(hook_cfg)}' + hook_cfg = hook_cfg.copy() + priority = hook_cfg.pop('priority', 'NORMAL') + hook = build_from_cfg(hook_cfg, HOOKS) + runner.register_hook(hook, priority=priority) + + if cfg.resume_from is None and cfg.get('auto_resume'): + resume_from = find_latest_checkpoint(cfg.work_dir) + if resume_from is not None: + cfg.resume_from = resume_from + if cfg.resume_from: + runner.resume(cfg.resume_from) + elif cfg.load_from: + runner.load_checkpoint(cfg.load_from) + runner.run(data_loaders, cfg.workflow) diff --git a/data_utils/easyportrait/mmseg/core/__init__.py b/data_utils/easyportrait/mmseg/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..82f2422d543d36cd78eef833ca751fb88be52a92 --- /dev/null +++ b/data_utils/easyportrait/mmseg/core/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .builder import (OPTIMIZER_BUILDERS, build_optimizer, + build_optimizer_constructor) +from .evaluation import * # noqa: F401, F403 +from .hook import * # noqa: F401, F403 +from .optimizers import * # noqa: F401, F403 +from .seg import * # noqa: F401, F403 +from .utils import * # noqa: F401, F403 + +__all__ = [ + 'OPTIMIZER_BUILDERS', 'build_optimizer', 'build_optimizer_constructor' +] diff --git a/data_utils/easyportrait/mmseg/core/builder.py b/data_utils/easyportrait/mmseg/core/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..406dd9b4b7027e9c2254b0d18cf0c80a7161912b --- /dev/null +++ b/data_utils/easyportrait/mmseg/core/builder.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS +from mmcv.utils import Registry, build_from_cfg + +OPTIMIZER_BUILDERS = Registry( + 'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS) + + +def build_optimizer_constructor(cfg): + constructor_type = cfg.get('type') + if constructor_type in OPTIMIZER_BUILDERS: + return build_from_cfg(cfg, OPTIMIZER_BUILDERS) + elif constructor_type in MMCV_OPTIMIZER_BUILDERS: + return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS) + else: + raise KeyError(f'{constructor_type} is not registered ' + 'in the optimizer builder registry.') + + +def build_optimizer(model, cfg): + optimizer_cfg = copy.deepcopy(cfg) + constructor_type = optimizer_cfg.pop('constructor', + 'DefaultOptimizerConstructor') + paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None) + optim_constructor = build_optimizer_constructor( + dict( + type=constructor_type, + optimizer_cfg=optimizer_cfg, + paramwise_cfg=paramwise_cfg)) + optimizer = optim_constructor(model) + return optimizer diff --git a/data_utils/easyportrait/mmseg/core/evaluation/__init__.py b/data_utils/easyportrait/mmseg/core/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d16d17e54222f006e32cd6b9e6ca323e3738f03 --- /dev/null +++ b/data_utils/easyportrait/mmseg/core/evaluation/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .class_names import get_classes, get_palette +from .eval_hooks import DistEvalHook, EvalHook +from .metrics import (eval_metrics, intersect_and_union, mean_dice, + mean_fscore, mean_iou, pre_eval_to_metrics) + +__all__ = [ + 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore', + 'eval_metrics', 'get_classes', 'get_palette', 'pre_eval_to_metrics', + 'intersect_and_union' +] diff --git a/data_utils/easyportrait/mmseg/core/evaluation/class_names.py b/data_utils/easyportrait/mmseg/core/evaluation/class_names.py new file mode 100644 index 0000000000000000000000000000000000000000..8f85958191a95802f656b8274ab6bbfb630ba657 --- /dev/null +++ b/data_utils/easyportrait/mmseg/core/evaluation/class_names.py @@ -0,0 +1,342 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv + + +def cityscapes_classes(): + """Cityscapes class names for external use.""" + return [ + 'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', + 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', + 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', + 'bicycle' + ] + + +def ade_classes(): + """ADE20K class names for external use.""" + return [ + 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', + 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', + 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', + 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', + 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', + 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', + 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', + 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', + 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', + 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', + 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', + 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', + 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', + 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', + 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', + 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', + 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', + 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', + 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', + 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', + 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', + 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', + 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', + 'clock', 'flag' + ] + + +def voc_classes(): + """Pascal VOC class names for external use.""" + return [ + 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', + 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', + 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', + 'tvmonitor' + ] + + +def cocostuff_classes(): + """CocoStuff class names for external use.""" + return [ + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', + 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', + 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', + 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', + 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', + 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', + 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', + 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', + 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', + 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', + 'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet', + 'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile', + 'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain', + 'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble', + 'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower', + 'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel', + 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal', + 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', 'paper', + 'pavement', 'pillow', 'plant-other', 'plastic', 'platform', + 'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof', + 'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper', + 'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other', + 'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable', + 'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel', + 'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops', + 'window-blind', 'window-other', 'wood' + ] + + +def loveda_classes(): + """LoveDA class names for external use.""" + return [ + 'background', 'building', 'road', 'water', 'barren', 'forest', + 'agricultural' + ] + + +def potsdam_classes(): + """Potsdam class names for external use.""" + return [ + 'impervious_surface', 'building', 'low_vegetation', 'tree', 'car', + 'clutter' + ] + + +def vaihingen_classes(): + """Vaihingen class names for external use.""" + return [ + 'impervious_surface', 'building', 'low_vegetation', 'tree', 'car', + 'clutter' + ] + + +def isaid_classes(): + """iSAID class names for external use.""" + return [ + 'background', 'ship', 'store_tank', 'baseball_diamond', 'tennis_court', + 'basketball_court', 'Ground_Track_Field', 'Bridge', 'Large_Vehicle', + 'Small_Vehicle', 'Helicopter', 'Swimming_pool', 'Roundabout', + 'Soccer_ball_field', 'plane', 'Harbor' + ] + + +def stare_classes(): + """stare class names for external use.""" + return ['background', 'vessel'] + + +def occludedface_classes(): + """occludedface class names for external use.""" + return ['background', 'face'] + + +def easy_portrait_classes(): + """Easy Portrait class names for external use.""" + return ['background', 'person', 'skin', + 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth'] + + +def cityscapes_palette(): + """Cityscapes palette for external use.""" + return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], + [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], + [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], + [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], + [0, 0, 230], [119, 11, 32]] + + +def ade_palette(): + """ADE20K palette for external use.""" + return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], + [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], + [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], + [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], + [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], + [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], + [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], + [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], + [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], + [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], + [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], + [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], + [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], + [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], + [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], + [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], + [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], + [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], + [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], + [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], + [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], + [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], + [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], + [102, 255, 0], [92, 0, 255]] + + +def voc_palette(): + """Pascal VOC palette for external use.""" + return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], + [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], + [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], + [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], + [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] + + +def cocostuff_palette(): + """CocoStuff palette for external use.""" + return [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], + [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64], + [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], + [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192], + [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192], + [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], + [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160], [0, 32, 0], + [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0], + [192, 128, 32], [128, 96, 128], [0, 0, 128], [64, 0, 32], + [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128], + [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64], + [192, 0, 32], [128, 96, 0], [128, 0, 192], [0, 128, 32], + [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0], + [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64], + [128, 128, 32], [192, 32, 128], [0, 64, 192], [0, 0, 32], + [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128], + [128, 192, 192], [0, 0, 160], [192, 160, 128], [128, 192, 0], + [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96], + [64, 160, 0], [0, 64, 0], [192, 128, 224], [64, 32, 0], + [0, 192, 128], [64, 128, 224], [192, 160, 0], [0, 192, 0], + [192, 128, 96], [192, 96, 128], [0, 64, 128], [64, 0, 96], + [64, 224, 128], [128, 64, 0], [192, 0, 224], [64, 96, 128], + [128, 192, 128], [64, 0, 224], [192, 224, 128], [128, 192, 64], + [192, 0, 96], [192, 96, 0], [128, 64, 192], [0, 128, 96], + [0, 224, 0], [64, 64, 64], [128, 128, 224], [0, 96, 0], + [64, 192, 192], [0, 128, 224], [128, 224, 0], [64, 192, 64], + [128, 128, 96], [128, 32, 128], [64, 0, 192], [0, 64, 96], + [0, 160, 128], [192, 0, 64], [128, 64, 224], [0, 32, 128], + [192, 128, 192], [0, 64, 224], [128, 160, 128], [192, 128, 0], + [128, 64, 32], [128, 32, 64], [192, 0, 128], [64, 192, 32], + [0, 160, 64], [64, 0, 0], [192, 192, 160], [0, 32, 64], + [64, 128, 128], [64, 192, 160], [128, 160, 64], [64, 128, 0], + [192, 192, 32], [128, 96, 192], [64, 0, 128], [64, 64, 32], + [0, 224, 192], [192, 0, 0], [192, 64, 160], [0, 96, 192], + [192, 128, 128], [64, 64, 160], [128, 224, 192], [192, 128, 64], + [192, 64, 32], [128, 96, 64], [192, 0, 192], [0, 192, 32], + [64, 224, 64], [64, 0, 64], [128, 192, 160], [64, 96, 64], + [64, 128, 192], [0, 192, 160], [192, 224, 64], [64, 128, 64], + [128, 192, 32], [192, 32, 192], [64, 64, 192], [0, 64, 32], + [64, 160, 192], [192, 64, 64], [128, 64, 160], [64, 32, 192], + [192, 192, 192], [0, 64, 160], [192, 160, 192], [192, 192, 0], + [128, 64, 96], [192, 32, 64], [192, 64, 128], [64, 192, 96], + [64, 160, 64], [64, 64, 0]] + + +def loveda_palette(): + """LoveDA palette for external use.""" + return [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255], + [159, 129, 183], [0, 255, 0], [255, 195, 128]] + + +def potsdam_palette(): + """Potsdam palette for external use.""" + return [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], + [255, 255, 0], [255, 0, 0]] + + +def vaihingen_palette(): + """Vaihingen palette for external use.""" + return [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], + [255, 255, 0], [255, 0, 0]] + + +def isaid_palette(): + """iSAID palette for external use.""" + return [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127], + [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, + 127], [0, 0, 127], + [0, 0, 191], [0, 0, 255], [0, 191, 127], [0, 127, 191], + [0, 127, 255], [0, 100, 155]] + + +def stare_palette(): + """STARE palette for external use.""" + return [[120, 120, 120], [6, 230, 230]] + + +def occludedface_palette(): + """occludedface palette for external use.""" + return [[0, 0, 0], [128, 0, 0]] + + +def easy_portrait_palette(): + """Easy Portrait palette for external use.""" + return [[0, 0, 0], [223, 87, 188], [160, 221, 255], + [130, 106, 237], [200, 121, 255], [255, 183, 255], + [0, 144, 193], [113, 137, 255], [230, 232, 230]] + + +dataset_aliases = { + 'cityscapes': ['cityscapes'], + 'ade': ['ade', 'ade20k'], + 'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'], + 'loveda': ['loveda'], + 'potsdam': ['potsdam'], + 'vaihingen': ['vaihingen'], + 'cocostuff': [ + 'cocostuff', 'cocostuff10k', 'cocostuff164k', 'coco-stuff', + 'coco-stuff10k', 'coco-stuff164k', 'coco_stuff', 'coco_stuff10k', + 'coco_stuff164k' + ], + 'isaid': ['isaid', 'iSAID'], + 'stare': ['stare', 'STARE'], + 'occludedface': ['occludedface'], + 'easy_portrait': ['easy_portrait'] +} + + +def get_classes(dataset): + """Get class names of a dataset.""" + alias2name = {} + for name, aliases in dataset_aliases.items(): + for alias in aliases: + alias2name[alias] = name + + if mmcv.is_str(dataset): + if dataset in alias2name: + labels = eval(alias2name[dataset] + '_classes()') + else: + raise ValueError(f'Unrecognized dataset: {dataset}') + else: + raise TypeError(f'dataset must a str, but got {type(dataset)}') + return labels + + +def get_palette(dataset): + """Get class palette (RGB) of a dataset.""" + alias2name = {} + for name, aliases in dataset_aliases.items(): + for alias in aliases: + alias2name[alias] = name + + if mmcv.is_str(dataset): + if dataset in alias2name: + labels = eval(alias2name[dataset] + '_palette()') + else: + raise ValueError(f'Unrecognized dataset: {dataset}') + else: + raise TypeError(f'dataset must a str, but got {type(dataset)}') + return labels diff --git a/data_utils/easyportrait/mmseg/core/evaluation/eval_hooks.py b/data_utils/easyportrait/mmseg/core/evaluation/eval_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..8f2be570ace0a7148a0aab136c83278896c30610 --- /dev/null +++ b/data_utils/easyportrait/mmseg/core/evaluation/eval_hooks.py @@ -0,0 +1,132 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import warnings + +import torch.distributed as dist +from mmcv.runner import DistEvalHook as _DistEvalHook +from mmcv.runner import EvalHook as _EvalHook +from torch.nn.modules.batchnorm import _BatchNorm + + +class EvalHook(_EvalHook): + """Single GPU EvalHook, with efficient test support. + + Args: + by_epoch (bool): Determine perform evaluation by epoch or by iteration. + If set to True, it will perform by epoch. Otherwise, by iteration. + Default: False. + efficient_test (bool): Whether save the results as local numpy files to + save CPU memory during evaluation. Default: False. + pre_eval (bool): Whether to use progressive mode to evaluate model. + Default: False. + Returns: + list: The prediction results. + """ + + greater_keys = ['mIoU', 'mAcc', 'aAcc'] + + def __init__(self, + *args, + by_epoch=False, + efficient_test=False, + pre_eval=False, + **kwargs): + super().__init__(*args, by_epoch=by_epoch, **kwargs) + self.pre_eval = pre_eval + self.latest_results = None + + if efficient_test: + warnings.warn( + 'DeprecationWarning: ``efficient_test`` for evaluation hook ' + 'is deprecated, the evaluation hook is CPU memory friendly ' + 'with ``pre_eval=True`` as argument for ``single_gpu_test()`` ' + 'function') + + def _do_evaluate(self, runner): + """perform evaluation and save ckpt.""" + if not self._should_evaluate(runner): + return + + from mmseg.apis import single_gpu_test + results = single_gpu_test( + runner.model, self.dataloader, show=False, pre_eval=self.pre_eval) + self.latest_results = results + runner.log_buffer.clear() + runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) + key_score = self.evaluate(runner, results) + if self.save_best: + self._save_ckpt(runner, key_score) + + +class DistEvalHook(_DistEvalHook): + """Distributed EvalHook, with efficient test support. + + Args: + by_epoch (bool): Determine perform evaluation by epoch or by iteration. + If set to True, it will perform by epoch. Otherwise, by iteration. + Default: False. + efficient_test (bool): Whether save the results as local numpy files to + save CPU memory during evaluation. Default: False. + pre_eval (bool): Whether to use progressive mode to evaluate model. + Default: False. + Returns: + list: The prediction results. + """ + + greater_keys = ['mIoU', 'mAcc', 'aAcc'] + + def __init__(self, + *args, + by_epoch=False, + efficient_test=False, + pre_eval=False, + **kwargs): + super().__init__(*args, by_epoch=by_epoch, **kwargs) + self.pre_eval = pre_eval + self.latest_results = None + if efficient_test: + warnings.warn( + 'DeprecationWarning: ``efficient_test`` for evaluation hook ' + 'is deprecated, the evaluation hook is CPU memory friendly ' + 'with ``pre_eval=True`` as argument for ``multi_gpu_test()`` ' + 'function') + + def _do_evaluate(self, runner): + """perform evaluation and save ckpt.""" + # Synchronization of BatchNorm's buffer (running_mean + # and running_var) is not supported in the DDP of pytorch, + # which may cause the inconsistent performance of models in + # different ranks, so we broadcast BatchNorm's buffers + # of rank 0 to other ranks to avoid this. + if self.broadcast_bn_buffer: + model = runner.model + for name, module in model.named_modules(): + if isinstance(module, + _BatchNorm) and module.track_running_stats: + dist.broadcast(module.running_var, 0) + dist.broadcast(module.running_mean, 0) + + if not self._should_evaluate(runner): + return + + tmpdir = self.tmpdir + if tmpdir is None: + tmpdir = osp.join(runner.work_dir, '.eval_hook') + + from mmseg.apis import multi_gpu_test + results = multi_gpu_test( + runner.model, + self.dataloader, + tmpdir=tmpdir, + gpu_collect=self.gpu_collect, + pre_eval=self.pre_eval) + self.latest_results = results + runner.log_buffer.clear() + + if runner.rank == 0: + print('\n') + runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) + key_score = self.evaluate(runner, results) + + if self.save_best: + self._save_ckpt(runner, key_score) diff --git a/data_utils/easyportrait/mmseg/core/evaluation/metrics.py b/data_utils/easyportrait/mmseg/core/evaluation/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..81c2b4dcb83fc186a6b7f6805056728cc9279cc4 --- /dev/null +++ b/data_utils/easyportrait/mmseg/core/evaluation/metrics.py @@ -0,0 +1,396 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict + +import mmcv +import numpy as np +import torch + + +def f_score(precision, recall, beta=1): + """calculate the f-score value. + + Args: + precision (float | torch.Tensor): The precision value. + recall (float | torch.Tensor): The recall value. + beta (int): Determines the weight of recall in the combined score. + Default: False. + + Returns: + [torch.tensor]: The f-score value. + """ + score = (1 + beta**2) * (precision * recall) / ( + (beta**2 * precision) + recall) + return score + + +def intersect_and_union(pred_label, + label, + num_classes, + ignore_index, + label_map=dict(), + reduce_zero_label=False): + """Calculate intersection and Union. + + Args: + pred_label (ndarray | str): Prediction segmentation map + or predict result filename. + label (ndarray | str): Ground truth segmentation map + or label filename. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + label_map (dict): Mapping old labels to new labels. The parameter will + work only when label is str. Default: dict(). + reduce_zero_label (bool): Whether ignore zero label. The parameter will + work only when label is str. Default: False. + + Returns: + torch.Tensor: The intersection of prediction and ground truth + histogram on all classes. + torch.Tensor: The union of prediction and ground truth histogram on + all classes. + torch.Tensor: The prediction histogram on all classes. + torch.Tensor: The ground truth histogram on all classes. + """ + + if isinstance(pred_label, str): + pred_label = torch.from_numpy(np.load(pred_label)) + else: + pred_label = torch.from_numpy((pred_label)) + + if isinstance(label, str): + label = torch.from_numpy( + mmcv.imread(label, flag='unchanged', backend='pillow')) + else: + label = torch.from_numpy(label) + + if reduce_zero_label: + label[label == 0] = 255 + label = label - 1 + label[label == 254] = 255 + if label_map is not None: + label_copy = label.clone() + for old_id, new_id in label_map.items(): + label[label_copy == old_id] = new_id + + mask = (label != ignore_index) + pred_label = pred_label[mask] + label = label[mask] + + intersect = pred_label[pred_label == label] + area_intersect = torch.histc( + intersect.float(), bins=(num_classes), min=0, max=num_classes - 1) + area_pred_label = torch.histc( + pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1) + area_label = torch.histc( + label.float(), bins=(num_classes), min=0, max=num_classes - 1) + area_union = area_pred_label + area_label - area_intersect + return area_intersect, area_union, area_pred_label, area_label + + +def total_intersect_and_union(results, + gt_seg_maps, + num_classes, + ignore_index, + label_map=dict(), + reduce_zero_label=False): + """Calculate Total Intersection and Union. + + Args: + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str] | Iterables): list of ground + truth segmentation maps or list of label filenames. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Whether ignore zero label. Default: False. + + Returns: + ndarray: The intersection of prediction and ground truth histogram + on all classes. + ndarray: The union of prediction and ground truth histogram on all + classes. + ndarray: The prediction histogram on all classes. + ndarray: The ground truth histogram on all classes. + """ + total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64) + total_area_union = torch.zeros((num_classes, ), dtype=torch.float64) + total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64) + total_area_label = torch.zeros((num_classes, ), dtype=torch.float64) + for result, gt_seg_map in zip(results, gt_seg_maps): + area_intersect, area_union, area_pred_label, area_label = \ + intersect_and_union( + result, gt_seg_map, num_classes, ignore_index, + label_map, reduce_zero_label) + total_area_intersect += area_intersect + total_area_union += area_union + total_area_pred_label += area_pred_label + total_area_label += area_label + return total_area_intersect, total_area_union, total_area_pred_label, \ + total_area_label + + +def mean_iou(results, + gt_seg_maps, + num_classes, + ignore_index, + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False): + """Calculate Mean Intersection and Union (mIoU) + + Args: + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str]): list of ground truth + segmentation maps or list of label filenames. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Whether ignore zero label. Default: False. + + Returns: + dict[str, float | ndarray]: + float: Overall accuracy on all images. + ndarray: Per category accuracy, shape (num_classes, ). + ndarray: Per category IoU, shape (num_classes, ). + """ + iou_result = eval_metrics( + results=results, + gt_seg_maps=gt_seg_maps, + num_classes=num_classes, + ignore_index=ignore_index, + metrics=['mIoU'], + nan_to_num=nan_to_num, + label_map=label_map, + reduce_zero_label=reduce_zero_label) + return iou_result + + +def mean_dice(results, + gt_seg_maps, + num_classes, + ignore_index, + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False): + """Calculate Mean Dice (mDice) + + Args: + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str]): list of ground truth + segmentation maps or list of label filenames. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Whether ignore zero label. Default: False. + + Returns: + dict[str, float | ndarray]: Default metrics. + float: Overall accuracy on all images. + ndarray: Per category accuracy, shape (num_classes, ). + ndarray: Per category dice, shape (num_classes, ). + """ + + dice_result = eval_metrics( + results=results, + gt_seg_maps=gt_seg_maps, + num_classes=num_classes, + ignore_index=ignore_index, + metrics=['mDice'], + nan_to_num=nan_to_num, + label_map=label_map, + reduce_zero_label=reduce_zero_label) + return dice_result + + +def mean_fscore(results, + gt_seg_maps, + num_classes, + ignore_index, + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False, + beta=1): + """Calculate Mean F-Score (mFscore) + + Args: + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str]): list of ground truth + segmentation maps or list of label filenames. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Whether ignore zero label. Default: False. + beta (int): Determines the weight of recall in the combined score. + Default: False. + + + Returns: + dict[str, float | ndarray]: Default metrics. + float: Overall accuracy on all images. + ndarray: Per category recall, shape (num_classes, ). + ndarray: Per category precision, shape (num_classes, ). + ndarray: Per category f-score, shape (num_classes, ). + """ + fscore_result = eval_metrics( + results=results, + gt_seg_maps=gt_seg_maps, + num_classes=num_classes, + ignore_index=ignore_index, + metrics=['mFscore'], + nan_to_num=nan_to_num, + label_map=label_map, + reduce_zero_label=reduce_zero_label, + beta=beta) + return fscore_result + + +def eval_metrics(results, + gt_seg_maps, + num_classes, + ignore_index, + metrics=['mIoU'], + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False, + beta=1): + """Calculate evaluation metrics + Args: + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str] | Iterables): list of ground + truth segmentation maps or list of label filenames. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Whether ignore zero label. Default: False. + Returns: + float: Overall accuracy on all images. + ndarray: Per category accuracy, shape (num_classes, ). + ndarray: Per category evaluation metrics, shape (num_classes, ). + """ + + total_area_intersect, total_area_union, total_area_pred_label, \ + total_area_label = total_intersect_and_union( + results, gt_seg_maps, num_classes, ignore_index, label_map, + reduce_zero_label) + ret_metrics = total_area_to_metrics(total_area_intersect, total_area_union, + total_area_pred_label, + total_area_label, metrics, nan_to_num, + beta) + + return ret_metrics + + +def pre_eval_to_metrics(pre_eval_results, + metrics=['mIoU'], + nan_to_num=None, + beta=1): + """Convert pre-eval results to metrics. + + Args: + pre_eval_results (list[tuple[torch.Tensor]]): per image eval results + for computing evaluation metric + metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + Returns: + float: Overall accuracy on all images. + ndarray: Per category accuracy, shape (num_classes, ). + ndarray: Per category evaluation metrics, shape (num_classes, ). + """ + + # convert list of tuples to tuple of lists, e.g. + # [(A_1, B_1, C_1, D_1), ..., (A_n, B_n, C_n, D_n)] to + # ([A_1, ..., A_n], ..., [D_1, ..., D_n]) + pre_eval_results = tuple(zip(*pre_eval_results)) + assert len(pre_eval_results) == 4 + + total_area_intersect = sum(pre_eval_results[0]) + total_area_union = sum(pre_eval_results[1]) + total_area_pred_label = sum(pre_eval_results[2]) + total_area_label = sum(pre_eval_results[3]) + + ret_metrics = total_area_to_metrics(total_area_intersect, total_area_union, + total_area_pred_label, + total_area_label, metrics, nan_to_num, + beta) + + return ret_metrics + + +def total_area_to_metrics(total_area_intersect, + total_area_union, + total_area_pred_label, + total_area_label, + metrics=['mIoU'], + nan_to_num=None, + beta=1): + """Calculate evaluation metrics + Args: + total_area_intersect (ndarray): The intersection of prediction and + ground truth histogram on all classes. + total_area_union (ndarray): The union of prediction and ground truth + histogram on all classes. + total_area_pred_label (ndarray): The prediction histogram on all + classes. + total_area_label (ndarray): The ground truth histogram on all classes. + metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + Returns: + float: Overall accuracy on all images. + ndarray: Per category accuracy, shape (num_classes, ). + ndarray: Per category evaluation metrics, shape (num_classes, ). + """ + if isinstance(metrics, str): + metrics = [metrics] + allowed_metrics = ['mIoU', 'mDice', 'mFscore'] + if not set(metrics).issubset(set(allowed_metrics)): + raise KeyError('metrics {} is not supported'.format(metrics)) + + all_acc = total_area_intersect.sum() / total_area_label.sum() + ret_metrics = OrderedDict({'aAcc': all_acc}) + for metric in metrics: + if metric == 'mIoU': + iou = total_area_intersect / total_area_union + acc = total_area_intersect / total_area_label + ret_metrics['IoU'] = iou + ret_metrics['Acc'] = acc + elif metric == 'mDice': + dice = 2 * total_area_intersect / ( + total_area_pred_label + total_area_label) + acc = total_area_intersect / total_area_label + ret_metrics['Dice'] = dice + ret_metrics['Acc'] = acc + elif metric == 'mFscore': + precision = total_area_intersect / total_area_pred_label + recall = total_area_intersect / total_area_label + f_value = torch.tensor( + [f_score(x[0], x[1], beta) for x in zip(precision, recall)]) + ret_metrics['Fscore'] = f_value + ret_metrics['Precision'] = precision + ret_metrics['Recall'] = recall + + ret_metrics = { + metric: value.numpy() + for metric, value in ret_metrics.items() + } + if nan_to_num is not None: + ret_metrics = OrderedDict({ + metric: np.nan_to_num(metric_value, nan=nan_to_num) + for metric, metric_value in ret_metrics.items() + }) + return ret_metrics diff --git a/data_utils/easyportrait/mmseg/core/hook/__init__.py b/data_utils/easyportrait/mmseg/core/hook/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02fe93dd597713154bd10c77fb6ed12b89e83dac --- /dev/null +++ b/data_utils/easyportrait/mmseg/core/hook/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .wandblogger_hook import MMSegWandbHook + +__all__ = ['MMSegWandbHook'] diff --git a/data_utils/easyportrait/mmseg/core/hook/wandblogger_hook.py b/data_utils/easyportrait/mmseg/core/hook/wandblogger_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..b35c526051d01f8838bb1227ba6ab2fddc372123 --- /dev/null +++ b/data_utils/easyportrait/mmseg/core/hook/wandblogger_hook.py @@ -0,0 +1,370 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +import mmcv +import numpy as np +from mmcv.runner import HOOKS +from mmcv.runner.dist_utils import master_only +from mmcv.runner.hooks.checkpoint import CheckpointHook +from mmcv.runner.hooks.logger.wandb import WandbLoggerHook + +from mmseg.core import DistEvalHook, EvalHook + + +@HOOKS.register_module() +class MMSegWandbHook(WandbLoggerHook): + """Enhanced Wandb logger hook for MMSegmentation. + + Comparing with the :cls:`mmcv.runner.WandbLoggerHook`, this hook can not + only automatically log all the metrics but also log the following extra + information - saves model checkpoints as W&B Artifact, and + logs model prediction as interactive W&B Tables. + + - Metrics: The MMSegWandbHook will automatically log training + and validation metrics along with system metrics (CPU/GPU). + + - Checkpointing: If `log_checkpoint` is True, the checkpoint saved at + every checkpoint interval will be saved as W&B Artifacts. + This depends on the : class:`mmcv.runner.CheckpointHook` whose priority + is higher than this hook. Please refer to + https://docs.wandb.ai/guides/artifacts/model-versioning + to learn more about model versioning with W&B Artifacts. + + - Checkpoint Metadata: If evaluation results are available for a given + checkpoint artifact, it will have a metadata associated with it. + The metadata contains the evaluation metrics computed on validation + data with that checkpoint along with the current epoch. It depends + on `EvalHook` whose priority is more than MMSegWandbHook. + + - Evaluation: At every evaluation interval, the `MMSegWandbHook` logs the + model prediction as interactive W&B Tables. The number of samples + logged is given by `num_eval_images`. Currently, the `MMSegWandbHook` + logs the predicted segmentation masks along with the ground truth at + every evaluation interval. This depends on the `EvalHook` whose + priority is more than `MMSegWandbHook`. Also note that the data is just + logged once and subsequent evaluation tables uses reference to the + logged data to save memory usage. Please refer to + https://docs.wandb.ai/guides/data-vis to learn more about W&B Tables. + + ``` + Example: + log_config = dict( + ... + hooks=[ + ..., + dict(type='MMSegWandbHook', + init_kwargs={ + 'entity': "YOUR_ENTITY", + 'project': "YOUR_PROJECT_NAME" + }, + interval=50, + log_checkpoint=True, + log_checkpoint_metadata=True, + num_eval_images=100, + bbox_score_thr=0.3) + ]) + ``` + + Args: + init_kwargs (dict): A dict passed to wandb.init to initialize + a W&B run. Please refer to https://docs.wandb.ai/ref/python/init + for possible key-value pairs. + interval (int): Logging interval (every k iterations). + Default 10. + log_checkpoint (bool): Save the checkpoint at every checkpoint interval + as W&B Artifacts. Use this for model versioning where each version + is a checkpoint. + Default: False + log_checkpoint_metadata (bool): Log the evaluation metrics computed + on the validation data with the checkpoint, along with current + epoch as a metadata to that checkpoint. + Default: True + num_eval_images (int): Number of validation images to be logged. + Default: 100 + """ + + def __init__(self, + init_kwargs=None, + interval=50, + log_checkpoint=False, + log_checkpoint_metadata=False, + num_eval_images=100, + **kwargs): + super(MMSegWandbHook, self).__init__(init_kwargs, interval, **kwargs) + + self.log_checkpoint = log_checkpoint + self.log_checkpoint_metadata = ( + log_checkpoint and log_checkpoint_metadata) + self.num_eval_images = num_eval_images + self.log_evaluation = (num_eval_images > 0) + self.ckpt_hook: CheckpointHook = None + self.eval_hook: EvalHook = None + self.test_fn = None + + @master_only + def before_run(self, runner): + super(MMSegWandbHook, self).before_run(runner) + + # Check if EvalHook and CheckpointHook are available. + for hook in runner.hooks: + if isinstance(hook, CheckpointHook): + self.ckpt_hook = hook + if isinstance(hook, EvalHook): + from mmseg.apis import single_gpu_test + self.eval_hook = hook + self.test_fn = single_gpu_test + if isinstance(hook, DistEvalHook): + from mmseg.apis import multi_gpu_test + self.eval_hook = hook + self.test_fn = multi_gpu_test + + # Check conditions to log checkpoint + if self.log_checkpoint: + if self.ckpt_hook is None: + self.log_checkpoint = False + self.log_checkpoint_metadata = False + runner.logger.warning( + 'To log checkpoint in MMSegWandbHook, `CheckpointHook` is' + 'required, please check hooks in the runner.') + else: + self.ckpt_interval = self.ckpt_hook.interval + + # Check conditions to log evaluation + if self.log_evaluation or self.log_checkpoint_metadata: + if self.eval_hook is None: + self.log_evaluation = False + self.log_checkpoint_metadata = False + runner.logger.warning( + 'To log evaluation or checkpoint metadata in ' + 'MMSegWandbHook, `EvalHook` or `DistEvalHook` in mmseg ' + 'is required, please check whether the validation ' + 'is enabled.') + else: + self.eval_interval = self.eval_hook.interval + self.val_dataset = self.eval_hook.dataloader.dataset + # Determine the number of samples to be logged. + if self.num_eval_images > len(self.val_dataset): + self.num_eval_images = len(self.val_dataset) + runner.logger.warning( + f'The num_eval_images ({self.num_eval_images}) is ' + 'greater than the total number of validation samples ' + f'({len(self.val_dataset)}). The complete validation ' + 'dataset will be logged.') + + # Check conditions to log checkpoint metadata + if self.log_checkpoint_metadata: + assert self.ckpt_interval % self.eval_interval == 0, \ + 'To log checkpoint metadata in MMSegWandbHook, the interval ' \ + f'of checkpoint saving ({self.ckpt_interval}) should be ' \ + 'divisible by the interval of evaluation ' \ + f'({self.eval_interval}).' + + # Initialize evaluation table + if self.log_evaluation: + # Initialize data table + self._init_data_table() + # Add data to the data table + self._add_ground_truth(runner) + # Log ground truth data + self._log_data_table() + + # for the reason of this double-layered structure, refer to + # https://github.com/open-mmlab/mmdetection/issues/8145#issuecomment-1345343076 + def after_train_iter(self, runner): + if self.get_mode(runner) == 'train': + # An ugly patch. The iter-based eval hook will call the + # `after_train_iter` method of all logger hooks before evaluation. + # Use this trick to skip that call. + # Don't call super method at first, it will clear the log_buffer + return super(MMSegWandbHook, self).after_train_iter(runner) + else: + super(MMSegWandbHook, self).after_train_iter(runner) + self._after_train_iter(runner) + + @master_only + def _after_train_iter(self, runner): + if self.by_epoch: + return + + # Save checkpoint and metadata + if (self.log_checkpoint + and self.every_n_iters(runner, self.ckpt_interval) + or (self.ckpt_hook.save_last and self.is_last_iter(runner))): + if self.log_checkpoint_metadata and self.eval_hook: + metadata = { + 'iter': runner.iter + 1, + **self._get_eval_results() + } + else: + metadata = None + aliases = [f'iter_{runner.iter+1}', 'latest'] + model_path = osp.join(self.ckpt_hook.out_dir, + f'iter_{runner.iter+1}.pth') + self._log_ckpt_as_artifact(model_path, aliases, metadata) + + # Save prediction table + if self.log_evaluation and self.eval_hook._should_evaluate(runner): + # Currently the results of eval_hook is not reused by wandb, so + # wandb will run evaluation again internally. We will consider + # refactoring this function afterwards + results = self.test_fn(runner.model, self.eval_hook.dataloader) + # Initialize evaluation table + self._init_pred_table() + # Log predictions + self._log_predictions(results, runner) + # Log the table + self._log_eval_table(runner.iter + 1) + + @master_only + def after_run(self, runner): + self.wandb.finish() + + def _log_ckpt_as_artifact(self, model_path, aliases, metadata=None): + """Log model checkpoint as W&B Artifact. + + Args: + model_path (str): Path of the checkpoint to log. + aliases (list): List of the aliases associated with this artifact. + metadata (dict, optional): Metadata associated with this artifact. + """ + model_artifact = self.wandb.Artifact( + f'run_{self.wandb.run.id}_model', type='model', metadata=metadata) + model_artifact.add_file(model_path) + self.wandb.log_artifact(model_artifact, aliases=aliases) + + def _get_eval_results(self): + """Get model evaluation results.""" + results = self.eval_hook.latest_results + eval_results = self.val_dataset.evaluate( + results, logger='silent', **self.eval_hook.eval_kwargs) + return eval_results + + def _init_data_table(self): + """Initialize the W&B Tables for validation data.""" + columns = ['image_name', 'image'] + self.data_table = self.wandb.Table(columns=columns) + + def _init_pred_table(self): + """Initialize the W&B Tables for model evaluation.""" + columns = ['image_name', 'ground_truth', 'prediction'] + self.eval_table = self.wandb.Table(columns=columns) + + def _add_ground_truth(self, runner): + # Get image loading pipeline + from mmseg.datasets.pipelines import LoadImageFromFile + img_loader = None + for t in self.val_dataset.pipeline.transforms: + if isinstance(t, LoadImageFromFile): + img_loader = t + + if img_loader is None: + self.log_evaluation = False + runner.logger.warning( + 'LoadImageFromFile is required to add images ' + 'to W&B Tables.') + return + + # Select the images to be logged. + self.eval_image_indexs = np.arange(len(self.val_dataset)) + # Set seed so that same validation set is logged each time. + np.random.seed(42) + np.random.shuffle(self.eval_image_indexs) + self.eval_image_indexs = self.eval_image_indexs[:self.num_eval_images] + + classes = self.val_dataset.CLASSES + self.class_id_to_label = {id: name for id, name in enumerate(classes)} + self.class_set = self.wandb.Classes([{ + 'id': id, + 'name': name + } for id, name in self.class_id_to_label.items()]) + + for idx in self.eval_image_indexs: + img_info = self.val_dataset.img_infos[idx] + image_name = img_info['filename'] + + # Get image and convert from BGR to RGB + img_meta = img_loader( + dict(img_info=img_info, img_prefix=self.val_dataset.img_dir)) + image = mmcv.bgr2rgb(img_meta['img']) + + # Get segmentation mask + seg_mask = self.val_dataset.get_gt_seg_map_by_idx(idx) + # Dict of masks to be logged. + wandb_masks = None + if seg_mask.ndim == 2: + wandb_masks = { + 'ground_truth': { + 'mask_data': seg_mask, + 'class_labels': self.class_id_to_label + } + } + + # Log a row to the data table. + self.data_table.add_data( + image_name, + self.wandb.Image( + image, masks=wandb_masks, classes=self.class_set)) + else: + runner.logger.warning( + f'The segmentation mask is {seg_mask.ndim}D which ' + 'is not supported by W&B.') + self.log_evaluation = False + return + + def _log_predictions(self, results, runner): + table_idxs = self.data_table_ref.get_index() + assert len(table_idxs) == len(self.eval_image_indexs) + assert len(results) == len(self.val_dataset) + + for ndx, eval_image_index in enumerate(self.eval_image_indexs): + # Get the result + pred_mask = results[eval_image_index] + + if pred_mask.ndim == 2: + wandb_masks = { + 'prediction': { + 'mask_data': pred_mask, + 'class_labels': self.class_id_to_label + } + } + + # Log a row to the data table. + self.eval_table.add_data( + self.data_table_ref.data[ndx][0], + self.data_table_ref.data[ndx][1], + self.wandb.Image( + self.data_table_ref.data[ndx][1], + masks=wandb_masks, + classes=self.class_set)) + else: + runner.logger.warning( + 'The predictio segmentation mask is ' + f'{pred_mask.ndim}D which is not supported by W&B.') + self.log_evaluation = False + return + + def _log_data_table(self): + """Log the W&B Tables for validation data as artifact and calls + `use_artifact` on it so that the evaluation table can use the reference + of already uploaded images. + + This allows the data to be uploaded just once. + """ + data_artifact = self.wandb.Artifact('val', type='dataset') + data_artifact.add(self.data_table, 'val_data') + + self.wandb.run.use_artifact(data_artifact) + data_artifact.wait() + + self.data_table_ref = data_artifact.get('val_data') + + def _log_eval_table(self, iter): + """Log the W&B Tables for model evaluation. + + The table will be logged multiple times creating new version. Use this + to compare models at different intervals interactively. + """ + pred_artifact = self.wandb.Artifact( + f'run_{self.wandb.run.id}_pred', type='evaluation') + pred_artifact.add(self.eval_table, 'eval_data') + self.wandb.run.log_artifact(pred_artifact) diff --git a/data_utils/easyportrait/mmseg/core/optimizers/__init__.py b/data_utils/easyportrait/mmseg/core/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4fbf4ecfcd4d1f0834322e2964b55d9637c844ba --- /dev/null +++ b/data_utils/easyportrait/mmseg/core/optimizers/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .layer_decay_optimizer_constructor import ( + LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor) + +__all__ = [ + 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor' +] diff --git a/data_utils/easyportrait/mmseg/core/optimizers/layer_decay_optimizer_constructor.py b/data_utils/easyportrait/mmseg/core/optimizers/layer_decay_optimizer_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..31d145c17ccc4abe064ec202ea2ac2ec68d95db3 --- /dev/null +++ b/data_utils/easyportrait/mmseg/core/optimizers/layer_decay_optimizer_constructor.py @@ -0,0 +1,209 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import warnings + +from mmcv.runner import DefaultOptimizerConstructor, get_dist_info + +from mmseg.utils import get_root_logger +from ..builder import OPTIMIZER_BUILDERS + + +def get_layer_id_for_convnext(var_name, max_layer_id): + """Get the layer id to set the different learning rates in ``layer_wise`` + decay_type. + + Args: + var_name (str): The key of the model. + max_layer_id (int): Maximum number of backbone layers. + + Returns: + int: The id number corresponding to different learning rate in + ``LearningRateDecayOptimizerConstructor``. + """ + + if var_name in ('backbone.cls_token', 'backbone.mask_token', + 'backbone.pos_embed'): + return 0 + elif var_name.startswith('backbone.downsample_layers'): + stage_id = int(var_name.split('.')[2]) + if stage_id == 0: + layer_id = 0 + elif stage_id == 1: + layer_id = 2 + elif stage_id == 2: + layer_id = 3 + elif stage_id == 3: + layer_id = max_layer_id + return layer_id + elif var_name.startswith('backbone.stages'): + stage_id = int(var_name.split('.')[2]) + block_id = int(var_name.split('.')[3]) + if stage_id == 0: + layer_id = 1 + elif stage_id == 1: + layer_id = 2 + elif stage_id == 2: + layer_id = 3 + block_id // 3 + elif stage_id == 3: + layer_id = max_layer_id + return layer_id + else: + return max_layer_id + 1 + + +def get_stage_id_for_convnext(var_name, max_stage_id): + """Get the stage id to set the different learning rates in ``stage_wise`` + decay_type. + + Args: + var_name (str): The key of the model. + max_stage_id (int): Maximum number of backbone layers. + + Returns: + int: The id number corresponding to different learning rate in + ``LearningRateDecayOptimizerConstructor``. + """ + + if var_name in ('backbone.cls_token', 'backbone.mask_token', + 'backbone.pos_embed'): + return 0 + elif var_name.startswith('backbone.downsample_layers'): + return 0 + elif var_name.startswith('backbone.stages'): + stage_id = int(var_name.split('.')[2]) + return stage_id + 1 + else: + return max_stage_id - 1 + + +def get_layer_id_for_vit(var_name, max_layer_id): + """Get the layer id to set the different learning rates. + + Args: + var_name (str): The key of the model. + num_max_layer (int): Maximum number of backbone layers. + + Returns: + int: Returns the layer id of the key. + """ + + if var_name in ('backbone.cls_token', 'backbone.mask_token', + 'backbone.pos_embed'): + return 0 + elif var_name.startswith('backbone.patch_embed'): + return 0 + elif var_name.startswith('backbone.layers'): + layer_id = int(var_name.split('.')[2]) + return layer_id + 1 + else: + return max_layer_id - 1 + + +@OPTIMIZER_BUILDERS.register_module() +class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor): + """Different learning rates are set for different layers of backbone. + + Note: Currently, this optimizer constructor is built for ConvNeXt, + BEiT and MAE. + """ + + def add_params(self, params, module, **kwargs): + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + + Args: + params (list[dict]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + """ + logger = get_root_logger() + + parameter_groups = {} + logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}') + num_layers = self.paramwise_cfg.get('num_layers') + 2 + decay_rate = self.paramwise_cfg.get('decay_rate') + decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise') + logger.info('Build LearningRateDecayOptimizerConstructor ' + f'{decay_type} {decay_rate} - {num_layers}') + weight_decay = self.base_wd + for name, param in module.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith('.bias') or name in ( + 'pos_embed', 'cls_token'): + group_name = 'no_decay' + this_weight_decay = 0. + else: + group_name = 'decay' + this_weight_decay = weight_decay + if 'layer_wise' in decay_type: + if 'ConvNeXt' in module.backbone.__class__.__name__: + layer_id = get_layer_id_for_convnext( + name, self.paramwise_cfg.get('num_layers')) + logger.info(f'set param {name} as id {layer_id}') + elif 'BEiT' in module.backbone.__class__.__name__ or \ + 'MAE' in module.backbone.__class__.__name__ or \ + 'VisionTransformer' in module.backbone.__class__.__name__: + layer_id = get_layer_id_for_vit(name, num_layers) + logger.info(f'set param {name} as id {layer_id}') + else: + raise NotImplementedError() + elif decay_type == 'stage_wise': + if 'ConvNeXt' in module.backbone.__class__.__name__: + layer_id = get_stage_id_for_convnext(name, num_layers) + logger.info(f'set param {name} as id {layer_id}') + else: + raise NotImplementedError() + group_name = f'layer_{layer_id}_{group_name}' + + if group_name not in parameter_groups: + scale = decay_rate**(num_layers - layer_id - 1) + + parameter_groups[group_name] = { + 'weight_decay': this_weight_decay, + 'params': [], + 'param_names': [], + 'lr_scale': scale, + 'group_name': group_name, + 'lr': scale * self.base_lr, + } + + parameter_groups[group_name]['params'].append(param) + parameter_groups[group_name]['param_names'].append(name) + rank, _ = get_dist_info() + if rank == 0: + to_display = {} + for key in parameter_groups: + to_display[key] = { + 'param_names': parameter_groups[key]['param_names'], + 'lr_scale': parameter_groups[key]['lr_scale'], + 'lr': parameter_groups[key]['lr'], + 'weight_decay': parameter_groups[key]['weight_decay'], + } + logger.info(f'Param groups = {json.dumps(to_display, indent=2)}') + params.extend(parameter_groups.values()) + + +@OPTIMIZER_BUILDERS.register_module() +class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor): + """Different learning rates are set for different layers of backbone. + + Note: Currently, this optimizer constructor is built for BEiT, + and it will be deprecated. + Please use ``LearningRateDecayOptimizerConstructor`` instead. + """ + + def __init__(self, optimizer_cfg, paramwise_cfg): + warnings.warn('DeprecationWarning: Original ' + 'LayerDecayOptimizerConstructor of BEiT ' + 'will be deprecated. Please use ' + 'LearningRateDecayOptimizerConstructor instead, ' + 'and set decay_type = layer_wise_vit in paramwise_cfg.') + paramwise_cfg.update({'decay_type': 'layer_wise_vit'}) + warnings.warn('DeprecationWarning: Layer_decay_rate will ' + 'be deleted, please use decay_rate instead.') + paramwise_cfg['decay_rate'] = paramwise_cfg.pop('layer_decay_rate') + super(LayerDecayOptimizerConstructor, + self).__init__(optimizer_cfg, paramwise_cfg) diff --git a/data_utils/easyportrait/mmseg/core/seg/__init__.py b/data_utils/easyportrait/mmseg/core/seg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5206b96be6f87e99e8ae820bdd788444f4d255d9 --- /dev/null +++ b/data_utils/easyportrait/mmseg/core/seg/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .builder import build_pixel_sampler +from .sampler import BasePixelSampler, OHEMPixelSampler + +__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler'] diff --git a/data_utils/easyportrait/mmseg/core/seg/builder.py b/data_utils/easyportrait/mmseg/core/seg/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..1cecd347bffb6ab289f27e0f9bbab91c3a5d4bd8 --- /dev/null +++ b/data_utils/easyportrait/mmseg/core/seg/builder.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.utils import Registry, build_from_cfg + +PIXEL_SAMPLERS = Registry('pixel sampler') + + +def build_pixel_sampler(cfg, **default_args): + """Build pixel sampler for segmentation map.""" + return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) diff --git a/data_utils/easyportrait/mmseg/core/seg/sampler/__init__.py b/data_utils/easyportrait/mmseg/core/seg/sampler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5a7648564adbfcdb3e66f640e4f9c61de6e215e1 --- /dev/null +++ b/data_utils/easyportrait/mmseg/core/seg/sampler/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base_pixel_sampler import BasePixelSampler +from .ohem_pixel_sampler import OHEMPixelSampler + +__all__ = ['BasePixelSampler', 'OHEMPixelSampler'] diff --git a/data_utils/easyportrait/mmseg/core/seg/sampler/base_pixel_sampler.py b/data_utils/easyportrait/mmseg/core/seg/sampler/base_pixel_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..03672cd478a2e464cc734ae92686c86f219da0a9 --- /dev/null +++ b/data_utils/easyportrait/mmseg/core/seg/sampler/base_pixel_sampler.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + + +class BasePixelSampler(metaclass=ABCMeta): + """Base class of pixel sampler.""" + + def __init__(self, **kwargs): + pass + + @abstractmethod + def sample(self, seg_logit, seg_label): + """Placeholder for sample function.""" diff --git a/data_utils/easyportrait/mmseg/core/seg/sampler/ohem_pixel_sampler.py b/data_utils/easyportrait/mmseg/core/seg/sampler/ohem_pixel_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..833a28768cd0bfddfc7ab59d3ba3cbe892b2fbb5 --- /dev/null +++ b/data_utils/easyportrait/mmseg/core/seg/sampler/ohem_pixel_sampler.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..builder import PIXEL_SAMPLERS +from .base_pixel_sampler import BasePixelSampler + + +@PIXEL_SAMPLERS.register_module() +class OHEMPixelSampler(BasePixelSampler): + """Online Hard Example Mining Sampler for segmentation. + + Args: + context (nn.Module): The context of sampler, subclass of + :obj:`BaseDecodeHead`. + thresh (float, optional): The threshold for hard example selection. + Below which, are prediction with low confidence. If not + specified, the hard examples will be pixels of top ``min_kept`` + loss. Default: None. + min_kept (int, optional): The minimum number of predictions to keep. + Default: 100000. + """ + + def __init__(self, context, thresh=None, min_kept=100000): + super(OHEMPixelSampler, self).__init__() + self.context = context + assert min_kept > 1 + self.thresh = thresh + self.min_kept = min_kept + + def sample(self, seg_logit, seg_label): + """Sample pixels that have high loss or with low prediction confidence. + + Args: + seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W) + seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W) + + Returns: + torch.Tensor: segmentation weight, shape (N, H, W) + """ + with torch.no_grad(): + assert seg_logit.shape[2:] == seg_label.shape[2:] + assert seg_label.shape[1] == 1 + seg_label = seg_label.squeeze(1).long() + batch_kept = self.min_kept * seg_label.size(0) + valid_mask = seg_label != self.context.ignore_index + seg_weight = seg_logit.new_zeros(size=seg_label.size()) + valid_seg_weight = seg_weight[valid_mask] + if self.thresh is not None: + seg_prob = F.softmax(seg_logit, dim=1) + + tmp_seg_label = seg_label.clone().unsqueeze(1) + tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0 + seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1) + sort_prob, sort_indices = seg_prob[valid_mask].sort() + + if sort_prob.numel() > 0: + min_threshold = sort_prob[min(batch_kept, + sort_prob.numel() - 1)] + else: + min_threshold = 0.0 + threshold = max(min_threshold, self.thresh) + valid_seg_weight[seg_prob[valid_mask] < threshold] = 1. + else: + if not isinstance(self.context.loss_decode, nn.ModuleList): + losses_decode = [self.context.loss_decode] + else: + losses_decode = self.context.loss_decode + losses = 0.0 + for loss_module in losses_decode: + losses += loss_module( + seg_logit, + seg_label, + weight=None, + ignore_index=self.context.ignore_index, + reduction_override='none') + + # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa + _, sort_indices = losses[valid_mask].sort(descending=True) + valid_seg_weight[sort_indices[:batch_kept]] = 1. + + seg_weight[valid_mask] = valid_seg_weight + + return seg_weight diff --git a/data_utils/easyportrait/mmseg/core/utils/__init__.py b/data_utils/easyportrait/mmseg/core/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..28882893a53a78dcb7063e51b07273d30dd1c19f --- /dev/null +++ b/data_utils/easyportrait/mmseg/core/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .dist_util import check_dist_init, sync_random_seed +from .misc import add_prefix + +__all__ = ['add_prefix', 'check_dist_init', 'sync_random_seed'] diff --git a/data_utils/easyportrait/mmseg/core/utils/dist_util.py b/data_utils/easyportrait/mmseg/core/utils/dist_util.py new file mode 100644 index 0000000000000000000000000000000000000000..b3288519d0a8785db12c00da9d48e51de5ce3ba1 --- /dev/null +++ b/data_utils/easyportrait/mmseg/core/utils/dist_util.py @@ -0,0 +1,46 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch +import torch.distributed as dist +from mmcv.runner import get_dist_info + + +def check_dist_init(): + return dist.is_available() and dist.is_initialized() + + +def sync_random_seed(seed=None, device='cuda'): + """Make sure different ranks share the same seed. All workers must call + this function, otherwise it will deadlock. This method is generally used in + `DistributedSampler`, because the seed should be identical across all + processes in the distributed group. + + In distributed sampling, different ranks should sample non-overlapped + data in the dataset. Therefore, this function is used to make sure that + each rank shuffles the data indices in the same order based + on the same seed. Then different ranks could use different indices + to select non-overlapped data from the same data list. + + Args: + seed (int, Optional): The seed. Default to None. + device (str): The device where the seed will be put on. + Default to 'cuda'. + Returns: + int: Seed to be used. + """ + + if seed is None: + seed = np.random.randint(2**31) + assert isinstance(seed, int) + + rank, world_size = get_dist_info() + + if world_size == 1: + return seed + + if rank == 0: + random_num = torch.tensor(seed, dtype=torch.int32, device=device) + else: + random_num = torch.tensor(0, dtype=torch.int32, device=device) + dist.broadcast(random_num, src=0) + return random_num.item() diff --git a/data_utils/easyportrait/mmseg/core/utils/misc.py b/data_utils/easyportrait/mmseg/core/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..282bb8d9698ebd19876d849f1e3fc2ee23e2d40d --- /dev/null +++ b/data_utils/easyportrait/mmseg/core/utils/misc.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. + """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f'{prefix}.{name}'] = value + + return outputs diff --git a/data_utils/easyportrait/mmseg/datasets/__init__.py b/data_utils/easyportrait/mmseg/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a74ef4a1b4e38c9b6335aa9db6988036cbb0125c --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/__init__.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .ade import ADE20KDataset +from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset +from .chase_db1 import ChaseDB1Dataset +from .cityscapes import CityscapesDataset +from .coco_stuff import COCOStuffDataset +from .custom import CustomDataset +from .dark_zurich import DarkZurichDataset +from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset, + RepeatDataset) +from .drive import DRIVEDataset +from .face import FaceOccludedDataset +from .hrf import HRFDataset +from .imagenets import (ImageNetSDataset, LoadImageNetSAnnotations, + LoadImageNetSImageFromFile) +from .isaid import iSAIDDataset +from .isprs import ISPRSDataset +from .loveda import LoveDADataset +from .night_driving import NightDrivingDataset +from .pascal_context import PascalContextDataset, PascalContextDataset59 +from .potsdam import PotsdamDataset +from .stare import STAREDataset +from .voc import PascalVOCDataset +from .easy_portrait import EasyPortraitDataset +from .lapa import LaPaDataset +from .easy_portrait_face_parsing import EasyPortraitFPDataset, EasyPortraitFPDatasetCross +from .easy_portrait_portrait_segmentation import EasyPortraitPSDataset + +__all__ = [ + 'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset', + 'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset', + 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset', + 'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset', + 'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset', + 'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset', + 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset', 'FaceOccludedDataset', + 'ImageNetSDataset', 'LoadImageNetSAnnotations', + 'LoadImageNetSImageFromFile', 'EasyPortraitDataset', 'LaPaDataset', + 'EasyPortraitFPDataset', 'EasyPortraitPSDataset', 'EasyPortraitFPDatasetCross', +] diff --git a/data_utils/easyportrait/mmseg/datasets/ade.py b/data_utils/easyportrait/mmseg/datasets/ade.py new file mode 100644 index 0000000000000000000000000000000000000000..db94cebd3bbaed1dfee0f9a80f5a164a862de84f --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/ade.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +import mmcv +import numpy as np +from PIL import Image + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class ADE20KDataset(CustomDataset): + """ADE20K dataset. + + In segmentation map annotation for ADE20K, 0 stands for background, which + is not included in 150 categories. ``reduce_zero_label`` is fixed to True. + The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to + '.png'. + """ + CLASSES = ( + 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', + 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', + 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', + 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', + 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', + 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', + 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', + 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', + 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', + 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', + 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', + 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', + 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', + 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', + 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', + 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', + 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', + 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', + 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', + 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', + 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', + 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', + 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', + 'clock', 'flag') + + PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], + [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], + [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], + [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], + [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], + [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], + [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], + [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], + [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], + [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], + [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], + [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], + [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], + [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], + [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], + [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], + [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], + [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], + [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], + [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], + [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], + [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], + [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], + [102, 255, 0], [92, 0, 255]] + + def __init__(self, **kwargs): + super(ADE20KDataset, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + reduce_zero_label=True, + **kwargs) + + def results2img(self, results, imgfile_prefix, to_label_id, indices=None): + """Write the segmentation results to images. + + Args: + results (list[ndarray]): Testing results of the + dataset. + imgfile_prefix (str): The filename prefix of the png files. + If the prefix is "somepath/xxx", + the png files will be named "somepath/xxx.png". + to_label_id (bool): whether convert output to label_id for + submission. + indices (list[int], optional): Indices of input results, if not + set, all the indices of the dataset will be used. + Default: None. + + Returns: + list[str: str]: result txt files which contains corresponding + semantic segmentation images. + """ + if indices is None: + indices = list(range(len(self))) + + mmcv.mkdir_or_exist(imgfile_prefix) + result_files = [] + for result, idx in zip(results, indices): + + filename = self.img_infos[idx]['filename'] + basename = osp.splitext(osp.basename(filename))[0] + + png_filename = osp.join(imgfile_prefix, f'{basename}.png') + + # The index range of official requirement is from 0 to 150. + # But the index range of output is from 0 to 149. + # That is because we set reduce_zero_label=True. + result = result + 1 + + output = Image.fromarray(result.astype(np.uint8)) + output.save(png_filename) + result_files.append(png_filename) + + return result_files + + def format_results(self, + results, + imgfile_prefix, + to_label_id=True, + indices=None): + """Format the results into dir (standard format for ade20k evaluation). + + Args: + results (list): Testing results of the dataset. + imgfile_prefix (str | None): The prefix of images files. It + includes the file path and the prefix of filename, e.g., + "a/b/prefix". + to_label_id (bool): whether convert output to label_id for + submission. Default: False + indices (list[int], optional): Indices of input results, if not + set, all the indices of the dataset will be used. + Default: None. + + Returns: + tuple: (result_files, tmp_dir), result_files is a list containing + the image paths, tmp_dir is the temporal directory created + for saving json/png files when img_prefix is not specified. + """ + + if indices is None: + indices = list(range(len(self))) + + assert isinstance(results, list), 'results must be a list.' + assert isinstance(indices, list), 'indices must be a list.' + + result_files = self.results2img(results, imgfile_prefix, to_label_id, + indices) + return result_files diff --git a/data_utils/easyportrait/mmseg/datasets/builder.py b/data_utils/easyportrait/mmseg/datasets/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..49ee63373a2bf5445855550fdfc71b105a91a076 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/builder.py @@ -0,0 +1,196 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import platform +import random +from functools import partial + +import numpy as np +import torch +from mmcv.parallel import collate +from mmcv.runner import get_dist_info +from mmcv.utils import Registry, build_from_cfg, digit_version +from torch.utils.data import DataLoader, IterableDataset + +from .samplers import DistributedSampler + +if platform.system() != 'Windows': + # https://github.com/pytorch/pytorch/issues/973 + import resource + rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) + base_soft_limit = rlimit[0] + hard_limit = rlimit[1] + soft_limit = min(max(4096, base_soft_limit), hard_limit) + resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) + +DATASETS = Registry('dataset') +PIPELINES = Registry('pipeline') + + +def _concat_dataset(cfg, default_args=None): + """Build :obj:`ConcatDataset by.""" + from .dataset_wrappers import ConcatDataset + img_dir = cfg['img_dir'] + ann_dir = cfg.get('ann_dir', None) + split = cfg.get('split', None) + # pop 'separate_eval' since it is not a valid key for common datasets. + separate_eval = cfg.pop('separate_eval', True) + num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1 + if ann_dir is not None: + num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1 + else: + num_ann_dir = 0 + if split is not None: + num_split = len(split) if isinstance(split, (list, tuple)) else 1 + else: + num_split = 0 + if num_img_dir > 1: + assert num_img_dir == num_ann_dir or num_ann_dir == 0 + assert num_img_dir == num_split or num_split == 0 + else: + assert num_split == num_ann_dir or num_ann_dir <= 1 + num_dset = max(num_split, num_img_dir) + + datasets = [] + for i in range(num_dset): + data_cfg = copy.deepcopy(cfg) + if isinstance(img_dir, (list, tuple)): + data_cfg['img_dir'] = img_dir[i] + if isinstance(ann_dir, (list, tuple)): + data_cfg['ann_dir'] = ann_dir[i] + if isinstance(split, (list, tuple)): + data_cfg['split'] = split[i] + datasets.append(build_dataset(data_cfg, default_args)) + + return ConcatDataset(datasets, separate_eval) + + +def build_dataset(cfg, default_args=None): + """Build datasets.""" + from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset, + RepeatDataset) + if isinstance(cfg, (list, tuple)): + dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg]) + elif cfg['type'] == 'RepeatDataset': + dataset = RepeatDataset( + build_dataset(cfg['dataset'], default_args), cfg['times']) + elif cfg['type'] == 'MultiImageMixDataset': + cp_cfg = copy.deepcopy(cfg) + cp_cfg['dataset'] = build_dataset(cp_cfg['dataset']) + cp_cfg.pop('type') + dataset = MultiImageMixDataset(**cp_cfg) + elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance( + cfg.get('split', None), (list, tuple)): + dataset = _concat_dataset(cfg, default_args) + else: + dataset = build_from_cfg(cfg, DATASETS, default_args) + + return dataset + + +def build_dataloader(dataset, + samples_per_gpu, + workers_per_gpu, + num_gpus=1, + dist=True, + shuffle=True, + seed=None, + drop_last=False, + pin_memory=True, + persistent_workers=True, + **kwargs): + """Build PyTorch DataLoader. + + In distributed training, each GPU/process has a dataloader. + In non-distributed training, there is only one dataloader for all GPUs. + + Args: + dataset (Dataset): A PyTorch dataset. + samples_per_gpu (int): Number of training samples on each GPU, i.e., + batch size of each GPU. + workers_per_gpu (int): How many subprocesses to use for data loading + for each GPU. + num_gpus (int): Number of GPUs. Only used in non-distributed training. + dist (bool): Distributed training/test or not. Default: True. + shuffle (bool): Whether to shuffle the data at every epoch. + Default: True. + seed (int | None): Seed to be used. Default: None. + drop_last (bool): Whether to drop the last incomplete batch in epoch. + Default: False + pin_memory (bool): Whether to use pin_memory in DataLoader. + Default: True + persistent_workers (bool): If True, the data loader will not shutdown + the worker processes after a dataset has been consumed once. + This allows to maintain the workers Dataset instances alive. + The argument also has effect in PyTorch>=1.7.0. + Default: True + kwargs: any keyword argument to be used to initialize DataLoader + + Returns: + DataLoader: A PyTorch dataloader. + """ + rank, world_size = get_dist_info() + if dist and not isinstance(dataset, IterableDataset): + sampler = DistributedSampler( + dataset, world_size, rank, shuffle=shuffle, seed=seed) + shuffle = False + batch_size = samples_per_gpu + num_workers = workers_per_gpu + elif dist: + sampler = None + shuffle = False + batch_size = samples_per_gpu + num_workers = workers_per_gpu + else: + sampler = None + batch_size = num_gpus * samples_per_gpu + num_workers = num_gpus * workers_per_gpu + + init_fn = partial( + worker_init_fn, num_workers=num_workers, rank=rank, + seed=seed) if seed is not None else None + + if digit_version(torch.__version__) >= digit_version('1.8.0'): + data_loader = DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), + pin_memory=pin_memory, + shuffle=shuffle, + worker_init_fn=init_fn, + drop_last=drop_last, + persistent_workers=persistent_workers, + **kwargs) + else: + data_loader = DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), + pin_memory=pin_memory, + shuffle=shuffle, + worker_init_fn=init_fn, + drop_last=drop_last, + **kwargs) + + return data_loader + + +def worker_init_fn(worker_id, num_workers, rank, seed): + """Worker init func for dataloader. + + The seed of each worker equals to num_worker * rank + worker_id + user_seed + + Args: + worker_id (int): Worker id. + num_workers (int): Number of workers. + rank (int): The rank of current process. + seed (int): The random seed to use. + """ + + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) + torch.manual_seed(worker_seed) diff --git a/data_utils/easyportrait/mmseg/datasets/chase_db1.py b/data_utils/easyportrait/mmseg/datasets/chase_db1.py new file mode 100644 index 0000000000000000000000000000000000000000..5cdc8d8d41c6869876fb13d83c1de0d51d4c6e7c --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/chase_db1.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class ChaseDB1Dataset(CustomDataset): + """Chase_db1 dataset. + + In segmentation map annotation for Chase_db1, 0 stands for background, + which is included in 2 categories. ``reduce_zero_label`` is fixed to False. + The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '_1stHO.png'. + """ + + CLASSES = ('background', 'vessel') + + PALETTE = [[120, 120, 120], [6, 230, 230]] + + def __init__(self, **kwargs): + super(ChaseDB1Dataset, self).__init__( + img_suffix='.png', + seg_map_suffix='_1stHO.png', + reduce_zero_label=False, + **kwargs) + assert self.file_client.exists(self.img_dir) diff --git a/data_utils/easyportrait/mmseg/datasets/cityscapes.py b/data_utils/easyportrait/mmseg/datasets/cityscapes.py new file mode 100644 index 0000000000000000000000000000000000000000..ed633d00db33789541284df0d2ec3187d4dd01a3 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/cityscapes.py @@ -0,0 +1,214 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +import mmcv +import numpy as np +from mmcv.utils import print_log +from PIL import Image + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class CityscapesDataset(CustomDataset): + """Cityscapes dataset. + + The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is + fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset. + """ + + CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', + 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', + 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', + 'bicycle') + + PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], + [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], + [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], + [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], + [0, 80, 100], [0, 0, 230], [119, 11, 32]] + + def __init__(self, + img_suffix='_leftImg8bit.png', + seg_map_suffix='_gtFine_labelTrainIds.png', + **kwargs): + super(CityscapesDataset, self).__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) + + @staticmethod + def _convert_to_label_id(result): + """Convert trainId to id for cityscapes.""" + if isinstance(result, str): + result = np.load(result) + import cityscapesscripts.helpers.labels as CSLabels + result_copy = result.copy() + for trainId, label in CSLabels.trainId2label.items(): + result_copy[result == trainId] = label.id + + return result_copy + + def results2img(self, results, imgfile_prefix, to_label_id, indices=None): + """Write the segmentation results to images. + + Args: + results (list[ndarray]): Testing results of the + dataset. + imgfile_prefix (str): The filename prefix of the png files. + If the prefix is "somepath/xxx", + the png files will be named "somepath/xxx.png". + to_label_id (bool): whether convert output to label_id for + submission. + indices (list[int], optional): Indices of input results, + if not set, all the indices of the dataset will be used. + Default: None. + + Returns: + list[str: str]: result txt files which contains corresponding + semantic segmentation images. + """ + if indices is None: + indices = list(range(len(self))) + + mmcv.mkdir_or_exist(imgfile_prefix) + result_files = [] + for result, idx in zip(results, indices): + if to_label_id: + result = self._convert_to_label_id(result) + filename = self.img_infos[idx]['filename'] + basename = osp.splitext(osp.basename(filename))[0] + + png_filename = osp.join(imgfile_prefix, f'{basename}.png') + + output = Image.fromarray(result.astype(np.uint8)).convert('P') + import cityscapesscripts.helpers.labels as CSLabels + palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8) + for label_id, label in CSLabels.id2label.items(): + palette[label_id] = label.color + + output.putpalette(palette) + output.save(png_filename) + result_files.append(png_filename) + + return result_files + + def format_results(self, + results, + imgfile_prefix, + to_label_id=True, + indices=None): + """Format the results into dir (standard format for Cityscapes + evaluation). + + Args: + results (list): Testing results of the dataset. + imgfile_prefix (str): The prefix of images files. It + includes the file path and the prefix of filename, e.g., + "a/b/prefix". + to_label_id (bool): whether convert output to label_id for + submission. Default: False + indices (list[int], optional): Indices of input results, + if not set, all the indices of the dataset will be used. + Default: None. + + Returns: + tuple: (result_files, tmp_dir), result_files is a list containing + the image paths, tmp_dir is the temporal directory created + for saving json/png files when img_prefix is not specified. + """ + if indices is None: + indices = list(range(len(self))) + + assert isinstance(results, list), 'results must be a list.' + assert isinstance(indices, list), 'indices must be a list.' + + result_files = self.results2img(results, imgfile_prefix, to_label_id, + indices) + + return result_files + + def evaluate(self, + results, + metric='mIoU', + logger=None, + imgfile_prefix=None): + """Evaluation in Cityscapes/default protocol. + + Args: + results (list): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. + logger (logging.Logger | None | str): Logger used for printing + related information during evaluation. Default: None. + imgfile_prefix (str | None): The prefix of output image file, + for cityscapes evaluation only. It includes the file path and + the prefix of filename, e.g., "a/b/prefix". + If results are evaluated with cityscapes protocol, it would be + the prefix of output png files. The output files would be + png images under folder "a/b/prefix/xxx.png", where "xxx" is + the image name of cityscapes. If not specified, a temp file + will be created for evaluation. + Default: None. + + Returns: + dict[str, float]: Cityscapes/default metrics. + """ + + eval_results = dict() + metrics = metric.copy() if isinstance(metric, list) else [metric] + if 'cityscapes' in metrics: + eval_results.update( + self._evaluate_cityscapes(results, logger, imgfile_prefix)) + metrics.remove('cityscapes') + if len(metrics) > 0: + eval_results.update( + super(CityscapesDataset, + self).evaluate(results, metrics, logger)) + + return eval_results + + def _evaluate_cityscapes(self, results, logger, imgfile_prefix): + """Evaluation in Cityscapes protocol. + + Args: + results (list): Testing results of the dataset. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + imgfile_prefix (str | None): The prefix of output image file + + Returns: + dict[str: float]: Cityscapes evaluation results. + """ + try: + import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa + except ImportError: + raise ImportError('Please run "pip install cityscapesscripts" to ' + 'install cityscapesscripts first.') + msg = 'Evaluating in Cityscapes style' + if logger is None: + msg = '\n' + msg + print_log(msg, logger=logger) + + result_dir = imgfile_prefix + + eval_results = dict() + print_log(f'Evaluating results under {result_dir} ...', logger=logger) + + CSEval.args.evalInstLevelScore = True + CSEval.args.predictionPath = osp.abspath(result_dir) + CSEval.args.evalPixelAccuracy = True + CSEval.args.JSONOutput = False + + seg_map_list = [] + pred_list = [] + + # when evaluating with official cityscapesscripts, + # **_gtFine_labelIds.png is used + for seg_map in mmcv.scandir( + self.ann_dir, 'gtFine_labelIds.png', recursive=True): + seg_map_list.append(osp.join(self.ann_dir, seg_map)) + pred_list.append(CSEval.getPrediction(CSEval.args, seg_map)) + + eval_results.update( + CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args)) + + return eval_results diff --git a/data_utils/easyportrait/mmseg/datasets/coco_stuff.py b/data_utils/easyportrait/mmseg/datasets/coco_stuff.py new file mode 100644 index 0000000000000000000000000000000000000000..24d089556599a5696c50fd0115077fbc83413061 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/coco_stuff.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class COCOStuffDataset(CustomDataset): + """COCO-Stuff dataset. + + In segmentation map annotation for COCO-Stuff, Train-IDs of the 10k version + are from 1 to 171, where 0 is the ignore index, and Train-ID of COCO Stuff + 164k is from 0 to 170, where 255 is the ignore index. So, they are all 171 + semantic categories. ``reduce_zero_label`` is set to True and False for the + 10k and 164k versions, respectively. The ``img_suffix`` is fixed to '.jpg', + and ``seg_map_suffix`` is fixed to '.png'. + """ + CLASSES = ( + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', + 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', + 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', + 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', + 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', + 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', + 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', + 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', + 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', + 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', + 'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet', + 'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile', + 'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain', + 'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble', + 'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', + 'flower', 'fog', 'food-other', 'fruit', 'furniture-other', 'grass', + 'gravel', 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', + 'metal', 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', + 'paper', 'pavement', 'pillow', 'plant-other', 'plastic', 'platform', + 'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof', + 'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper', + 'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other', + 'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable', + 'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel', + 'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops', + 'window-blind', 'window-other', 'wood') + + PALETTE = [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], + [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64], + [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], + [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192], + [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192], + [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], + [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160], + [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0], + [0, 128, 0], [192, 128, 32], [128, 96, 128], [0, 0, 128], + [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160], + [0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128], + [128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192], + [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160], + [64, 96, 0], [0, 128, 192], [0, 128, 160], [192, 224, 0], + [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192], + [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160], + [64, 32, 128], [128, 192, 192], [0, 0, 160], [192, 160, 128], + [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128], + [64, 128, 96], [64, 160, 0], [0, 64, 0], [192, 128, 224], + [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0], + [0, 192, 0], [192, 128, 96], [192, 96, 128], [0, 64, 128], + [64, 0, 96], [64, 224, 128], [128, 64, 0], [192, 0, 224], + [64, 96, 128], [128, 192, 128], [64, 0, 224], [192, 224, 128], + [128, 192, 64], [192, 0, 96], [192, 96, 0], [128, 64, 192], + [0, 128, 96], [0, 224, 0], [64, 64, 64], [128, 128, 224], + [0, 96, 0], [64, 192, 192], [0, 128, 224], [128, 224, 0], + [64, 192, 64], [128, 128, 96], [128, 32, 128], [64, 0, 192], + [0, 64, 96], [0, 160, 128], [192, 0, 64], [128, 64, 224], + [0, 32, 128], [192, 128, 192], [0, 64, 224], [128, 160, 128], + [192, 128, 0], [128, 64, 32], [128, 32, 64], [192, 0, 128], + [64, 192, 32], [0, 160, 64], [64, 0, 0], [192, 192, 160], + [0, 32, 64], [64, 128, 128], [64, 192, 160], [128, 160, 64], + [64, 128, 0], [192, 192, 32], [128, 96, 192], [64, 0, 128], + [64, 64, 32], [0, 224, 192], [192, 0, 0], [192, 64, 160], + [0, 96, 192], [192, 128, 128], [64, 64, 160], [128, 224, 192], + [192, 128, 64], [192, 64, 32], [128, 96, 64], [192, 0, 192], + [0, 192, 32], [64, 224, 64], [64, 0, 64], [128, 192, 160], + [64, 96, 64], [64, 128, 192], [0, 192, 160], [192, 224, 64], + [64, 128, 64], [128, 192, 32], [192, 32, 192], [64, 64, 192], + [0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160], + [64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192], + [192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128], + [64, 192, 96], [64, 160, 64], [64, 64, 0]] + + def __init__(self, **kwargs): + super(COCOStuffDataset, self).__init__( + img_suffix='.jpg', seg_map_suffix='_labelTrainIds.png', **kwargs) diff --git a/data_utils/easyportrait/mmseg/datasets/custom.py b/data_utils/easyportrait/mmseg/datasets/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..57a20a265aa4ee960197115f8fbc0ccd78dc0717 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/custom.py @@ -0,0 +1,489 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import warnings +from collections import OrderedDict + +import mmcv +import numpy as np +from mmcv.utils import print_log +from prettytable import PrettyTable +from torch.utils.data import Dataset + +from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics +from mmseg.utils import get_root_logger +from .builder import DATASETS +from .pipelines import Compose, LoadAnnotations + + +@DATASETS.register_module() +class CustomDataset(Dataset): + """Custom dataset for semantic segmentation. An example of file structure + is as followed. + + .. code-block:: none + + ├── data + │ ├── my_dataset + │ │ ├── img_dir + │ │ │ ├── train + │ │ │ │ ├── xxx{img_suffix} + │ │ │ │ ├── yyy{img_suffix} + │ │ │ │ ├── zzz{img_suffix} + │ │ │ ├── val + │ │ ├── ann_dir + │ │ │ ├── train + │ │ │ │ ├── xxx{seg_map_suffix} + │ │ │ │ ├── yyy{seg_map_suffix} + │ │ │ │ ├── zzz{seg_map_suffix} + │ │ │ ├── val + + The img/gt_semantic_seg pair of CustomDataset should be of the same + except suffix. A valid img/gt_semantic_seg filename pair should be like + ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included + in the suffix). If split is given, then ``xxx`` is specified in txt file. + Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded. + Please refer to ``docs/en/tutorials/new_dataset.md`` for more details. + + + Args: + pipeline (list[dict]): Processing pipeline + img_dir (str): Path to image directory + img_suffix (str): Suffix of images. Default: '.jpg' + ann_dir (str, optional): Path to annotation directory. Default: None + seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' + split (str, optional): Split txt file. If split is specified, only + file with suffix in the splits will be loaded. Otherwise, all + images in img_dir/ann_dir will be loaded. Default: None + data_root (str, optional): Data root for img_dir/ann_dir. Default: + None. + test_mode (bool): If test_mode=True, gt wouldn't be loaded. + ignore_index (int): The label index to be ignored. Default: 255 + reduce_zero_label (bool): Whether to mark label zero as ignored. + Default: False + classes (str | Sequence[str], optional): Specify classes to load. + If is None, ``cls.CLASSES`` will be used. Default: None. + palette (Sequence[Sequence[int]]] | np.ndarray | None): + The palette of segmentation map. If None is given, and + self.PALETTE is None, random palette will be generated. + Default: None + gt_seg_map_loader_cfg (dict): build LoadAnnotations to load gt for + evaluation, load from disk by default. Default: ``dict()``. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + """ + + CLASSES = None + + PALETTE = None + + def __init__(self, + pipeline, + img_dir, + img_suffix='.jpg', + ann_dir=None, + seg_map_suffix='.png', + split=None, + data_root=None, + test_mode=False, + ignore_index=255, + reduce_zero_label=False, + classes=None, + palette=None, + gt_seg_map_loader_cfg=dict(), + file_client_args=dict(backend='disk')): + self.pipeline = Compose(pipeline) + self.img_dir = img_dir + self.img_suffix = img_suffix + self.ann_dir = ann_dir + self.seg_map_suffix = seg_map_suffix + self.split = split + self.data_root = data_root + self.test_mode = test_mode + self.ignore_index = ignore_index + self.reduce_zero_label = reduce_zero_label + self.label_map = None + self.CLASSES, self.PALETTE = self.get_classes_and_palette( + classes, palette) + self.gt_seg_map_loader = LoadAnnotations( + reduce_zero_label=reduce_zero_label, **gt_seg_map_loader_cfg) + + self.file_client_args = file_client_args + self.file_client = mmcv.FileClient.infer_client(self.file_client_args) + + if test_mode: + assert self.CLASSES is not None, \ + '`cls.CLASSES` or `classes` should be specified when testing' + + # join paths if data_root is specified + if self.data_root is not None: + if not osp.isabs(self.img_dir): + self.img_dir = osp.join(self.data_root, self.img_dir) + if not (self.ann_dir is None or osp.isabs(self.ann_dir)): + self.ann_dir = osp.join(self.data_root, self.ann_dir) + if not (self.split is None or osp.isabs(self.split)): + self.split = osp.join(self.data_root, self.split) + + # load annotations + self.img_infos = self.load_annotations(self.img_dir, self.img_suffix, + self.ann_dir, + self.seg_map_suffix, self.split) + + def __len__(self): + """Total number of samples of data.""" + return len(self.img_infos) + + def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix, + split): + """Load annotation from directory. + + Args: + img_dir (str): Path to image directory + img_suffix (str): Suffix of images. + ann_dir (str|None): Path to annotation directory. + seg_map_suffix (str|None): Suffix of segmentation maps. + split (str|None): Split txt file. If split is specified, only file + with suffix in the splits will be loaded. Otherwise, all images + in img_dir/ann_dir will be loaded. Default: None + + Returns: + list[dict]: All image info of dataset. + """ + + img_infos = [] + if split is not None: + lines = mmcv.list_from_file( + split, file_client_args=self.file_client_args) + for line in lines: + img_name = line.strip() + img_info = dict(filename=img_name + img_suffix) + if ann_dir is not None: + seg_map = img_name + seg_map_suffix + img_info['ann'] = dict(seg_map=seg_map) + img_infos.append(img_info) + else: + for img in self.file_client.list_dir_or_file( + dir_path=img_dir, + list_dir=False, + suffix=img_suffix, + recursive=True): + img_info = dict(filename=img) + if ann_dir is not None: + seg_map = img.replace(img_suffix, seg_map_suffix) + img_info['ann'] = dict(seg_map=seg_map) + img_infos.append(img_info) + img_infos = sorted(img_infos, key=lambda x: x['filename']) + + print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger()) + return img_infos + + def get_ann_info(self, idx): + """Get annotation by index. + + Args: + idx (int): Index of data. + + Returns: + dict: Annotation info of specified index. + """ + + return self.img_infos[idx]['ann'] + + def pre_pipeline(self, results): + """Prepare results dict for pipeline.""" + results['seg_fields'] = [] + results['img_prefix'] = self.img_dir + results['seg_prefix'] = self.ann_dir + if self.custom_classes: + results['label_map'] = self.label_map + + def __getitem__(self, idx): + """Get training/test data after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Training/test data (with annotation if `test_mode` is set + False). + """ + + if self.test_mode: + return self.prepare_test_img(idx) + else: + return self.prepare_train_img(idx) + + def prepare_train_img(self, idx): + """Get training data and annotations after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Training data and annotation after pipeline with new keys + introduced by pipeline. + """ + + img_info = self.img_infos[idx] + ann_info = self.get_ann_info(idx) + results = dict(img_info=img_info, ann_info=ann_info) + self.pre_pipeline(results) + return self.pipeline(results) + + def prepare_test_img(self, idx): + """Get testing data after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Testing data after pipeline with new keys introduced by + pipeline. + """ + + img_info = self.img_infos[idx] + results = dict(img_info=img_info) + self.pre_pipeline(results) + return self.pipeline(results) + + def format_results(self, results, imgfile_prefix, indices=None, **kwargs): + """Place holder to format result to dataset specific output.""" + raise NotImplementedError + + def get_gt_seg_map_by_idx(self, index): + """Get one ground truth segmentation map for evaluation.""" + ann_info = self.get_ann_info(index) + results = dict(ann_info=ann_info) + self.pre_pipeline(results) + self.gt_seg_map_loader(results) + return results['gt_semantic_seg'] + + def get_gt_seg_maps(self, efficient_test=None): + """Get ground truth segmentation maps for evaluation.""" + if efficient_test is not None: + warnings.warn( + 'DeprecationWarning: ``efficient_test`` has been deprecated ' + 'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory ' + 'friendly by default. ') + + for idx in range(len(self)): + ann_info = self.get_ann_info(idx) + results = dict(ann_info=ann_info) + self.pre_pipeline(results) + self.gt_seg_map_loader(results) + yield results['gt_semantic_seg'] + + def pre_eval(self, preds, indices): + """Collect eval result from each iteration. + + Args: + preds (list[torch.Tensor] | torch.Tensor): the segmentation logit + after argmax, shape (N, H, W). + indices (list[int] | int): the prediction related ground truth + indices. + + Returns: + list[torch.Tensor]: (area_intersect, area_union, area_prediction, + area_ground_truth). + """ + # In order to compat with batch inference + if not isinstance(indices, list): + indices = [indices] + if not isinstance(preds, list): + preds = [preds] + + pre_eval_results = [] + + for pred, index in zip(preds, indices): + seg_map = self.get_gt_seg_map_by_idx(index) + pre_eval_results.append( + intersect_and_union( + pred, + seg_map, + len(self.CLASSES), + self.ignore_index, + # as the label map has already been applied and zero label + # has already been reduced by get_gt_seg_map_by_idx() i.e. + # LoadAnnotations.__call__(), these operations should not + # be duplicated. See the following issues/PRs: + # https://github.com/open-mmlab/mmsegmentation/issues/1415 + # https://github.com/open-mmlab/mmsegmentation/pull/1417 + # https://github.com/open-mmlab/mmsegmentation/pull/2504 + # for more details + label_map=dict(), + reduce_zero_label=False)) + + return pre_eval_results + + def get_classes_and_palette(self, classes=None, palette=None): + """Get class names of current dataset. + + Args: + classes (Sequence[str] | str | None): If classes is None, use + default CLASSES defined by builtin dataset. If classes is a + string, take it as a file name. The file contains the name of + classes where each line contains one class name. If classes is + a tuple or list, override the CLASSES defined by the dataset. + palette (Sequence[Sequence[int]]] | np.ndarray | None): + The palette of segmentation map. If None is given, random + palette will be generated. Default: None + """ + if classes is None: + self.custom_classes = False + return self.CLASSES, self.PALETTE + + self.custom_classes = True + if isinstance(classes, str): + # take it as a file path + class_names = mmcv.list_from_file(classes) + elif isinstance(classes, (tuple, list)): + class_names = classes + else: + raise ValueError(f'Unsupported type {type(classes)} of classes.') + + if self.CLASSES: + if not set(class_names).issubset(self.CLASSES): + raise ValueError('classes is not a subset of CLASSES.') + + # dictionary, its keys are the old label ids and its values + # are the new label ids. + # used for changing pixel labels in load_annotations. + self.label_map = {} + for i, c in enumerate(self.CLASSES): + if c not in class_names: + self.label_map[i] = 255 + else: + self.label_map[i] = class_names.index(c) + + palette = self.get_palette_for_custom_classes(class_names, palette) + + return class_names, palette + + def get_palette_for_custom_classes(self, class_names, palette=None): + + if self.label_map is not None: + # return subset of palette + palette = [] + for old_id, new_id in sorted( + self.label_map.items(), key=lambda x: x[1]): + if new_id != 255: + palette.append(self.PALETTE[old_id]) + palette = type(self.PALETTE)(palette) + + elif palette is None: + if self.PALETTE is None: + # Get random state before set seed, and restore + # random state later. + # It will prevent loss of randomness, as the palette + # may be different in each iteration if not specified. + # See: https://github.com/open-mmlab/mmdetection/issues/5844 + state = np.random.get_state() + np.random.seed(42) + # random palette + palette = np.random.randint(0, 255, size=(len(class_names), 3)) + np.random.set_state(state) + else: + palette = self.PALETTE + + return palette + + def evaluate(self, + results, + metric='mIoU', + logger=None, + gt_seg_maps=None, + **kwargs): + """Evaluate the dataset. + + Args: + results (list[tuple[torch.Tensor]] | list[str]): per image pre_eval + results or predict segmentation map for computing evaluation + metric. + metric (str | list[str]): Metrics to be evaluated. 'mIoU', + 'mDice' and 'mFscore' are supported. + logger (logging.Logger | None | str): Logger used for printing + related information during evaluation. Default: None. + gt_seg_maps (generator[ndarray]): Custom gt seg maps as input, + used in ConcatDataset + + Returns: + dict[str, float]: Default metrics. + """ + if isinstance(metric, str): + metric = [metric] + allowed_metrics = ['mIoU', 'mDice', 'mFscore'] + if not set(metric).issubset(set(allowed_metrics)): + raise KeyError('metric {} is not supported'.format(metric)) + + eval_results = {} + # test a list of files + if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of( + results, str): + if gt_seg_maps is None: + gt_seg_maps = self.get_gt_seg_maps() + num_classes = len(self.CLASSES) + ret_metrics = eval_metrics( + results, + gt_seg_maps, + num_classes, + self.ignore_index, + metric, + label_map=dict(), + reduce_zero_label=False) + # test a list of pre_eval_results + else: + ret_metrics = pre_eval_to_metrics(results, metric) + + # Because dataset.CLASSES is required for per-eval. + if self.CLASSES is None: + class_names = tuple(range(num_classes)) + else: + class_names = self.CLASSES + + # summary table + ret_metrics_summary = OrderedDict({ + ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2) + for ret_metric, ret_metric_value in ret_metrics.items() + }) + + # each class table + ret_metrics.pop('aAcc', None) + ret_metrics_class = OrderedDict({ + ret_metric: np.round(ret_metric_value * 100, 2) + for ret_metric, ret_metric_value in ret_metrics.items() + }) + ret_metrics_class.update({'Class': class_names}) + ret_metrics_class.move_to_end('Class', last=False) + + # for logger + class_table_data = PrettyTable() + for key, val in ret_metrics_class.items(): + class_table_data.add_column(key, val) + + summary_table_data = PrettyTable() + for key, val in ret_metrics_summary.items(): + if key == 'aAcc': + summary_table_data.add_column(key, [val]) + else: + summary_table_data.add_column('m' + key, [val]) + + print_log('per class results:', logger) + print_log('\n' + class_table_data.get_string(), logger=logger) + print_log('Summary:', logger) + print_log('\n' + summary_table_data.get_string(), logger=logger) + + # each metric dict + for key, value in ret_metrics_summary.items(): + if key == 'aAcc': + eval_results[key] = value / 100.0 + else: + eval_results['m' + key] = value / 100.0 + + ret_metrics_class.pop('Class', None) + for key, value in ret_metrics_class.items(): + eval_results.update({ + key + '.' + str(name): value[idx] / 100.0 + for idx, name in enumerate(class_names) + }) + + return eval_results diff --git a/data_utils/easyportrait/mmseg/datasets/dark_zurich.py b/data_utils/easyportrait/mmseg/datasets/dark_zurich.py new file mode 100644 index 0000000000000000000000000000000000000000..0b6fda6e9e27d9ccb1b6fb2e22aa4340e635a4da --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/dark_zurich.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .builder import DATASETS +from .cityscapes import CityscapesDataset + + +@DATASETS.register_module() +class DarkZurichDataset(CityscapesDataset): + """DarkZurichDataset dataset.""" + + def __init__(self, **kwargs): + super().__init__( + img_suffix='_rgb_anon.png', + seg_map_suffix='_gt_labelTrainIds.png', + **kwargs) diff --git a/data_utils/easyportrait/mmseg/datasets/dataset_wrappers.py b/data_utils/easyportrait/mmseg/datasets/dataset_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..1fb089f9f287f841d0a99f67ab840f28175c87ec --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/dataset_wrappers.py @@ -0,0 +1,277 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import bisect +import collections +import copy +from itertools import chain + +import mmcv +import numpy as np +from mmcv.utils import build_from_cfg, print_log +from torch.utils.data.dataset import ConcatDataset as _ConcatDataset + +from .builder import DATASETS, PIPELINES +from .cityscapes import CityscapesDataset + + +@DATASETS.register_module() +class ConcatDataset(_ConcatDataset): + """A wrapper of concatenated dataset. + + Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but + support evaluation and formatting results + + Args: + datasets (list[:obj:`Dataset`]): A list of datasets. + separate_eval (bool): Whether to evaluate the concatenated + dataset results separately, Defaults to True. + """ + + def __init__(self, datasets, separate_eval=True): + super(ConcatDataset, self).__init__(datasets) + self.CLASSES = datasets[0].CLASSES + self.PALETTE = datasets[0].PALETTE + self.separate_eval = separate_eval + assert separate_eval in [True, False], \ + f'separate_eval can only be True or False,' \ + f'but get {separate_eval}' + if any([isinstance(ds, CityscapesDataset) for ds in datasets]): + raise NotImplementedError( + 'Evaluating ConcatDataset containing CityscapesDataset' + 'is not supported!') + + def evaluate(self, results, logger=None, **kwargs): + """Evaluate the results. + + Args: + results (list[tuple[torch.Tensor]] | list[str]]): per image + pre_eval results or predict segmentation map for + computing evaluation metric. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + + Returns: + dict[str: float]: evaluate results of the total dataset + or each separate + dataset if `self.separate_eval=True`. + """ + assert len(results) == self.cumulative_sizes[-1], \ + ('Dataset and results have different sizes: ' + f'{self.cumulative_sizes[-1]} v.s. {len(results)}') + + # Check whether all the datasets support evaluation + for dataset in self.datasets: + assert hasattr(dataset, 'evaluate'), \ + f'{type(dataset)} does not implement evaluate function' + + if self.separate_eval: + dataset_idx = -1 + total_eval_results = dict() + for size, dataset in zip(self.cumulative_sizes, self.datasets): + start_idx = 0 if dataset_idx == -1 else \ + self.cumulative_sizes[dataset_idx] + end_idx = self.cumulative_sizes[dataset_idx + 1] + + results_per_dataset = results[start_idx:end_idx] + print_log( + f'\nEvaluateing {dataset.img_dir} with ' + f'{len(results_per_dataset)} images now', + logger=logger) + + eval_results_per_dataset = dataset.evaluate( + results_per_dataset, logger=logger, **kwargs) + dataset_idx += 1 + for k, v in eval_results_per_dataset.items(): + total_eval_results.update({f'{dataset_idx}_{k}': v}) + + return total_eval_results + + if len(set([type(ds) for ds in self.datasets])) != 1: + raise NotImplementedError( + 'All the datasets should have same types when ' + 'self.separate_eval=False') + else: + if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of( + results, str): + # merge the generators of gt_seg_maps + gt_seg_maps = chain( + *[dataset.get_gt_seg_maps() for dataset in self.datasets]) + else: + # if the results are `pre_eval` results, + # we do not need gt_seg_maps to evaluate + gt_seg_maps = None + eval_results = self.datasets[0].evaluate( + results, gt_seg_maps=gt_seg_maps, logger=logger, **kwargs) + return eval_results + + def get_dataset_idx_and_sample_idx(self, indice): + """Return dataset and sample index when given an indice of + ConcatDataset. + + Args: + indice (int): indice of sample in ConcatDataset + + Returns: + int: the index of sub dataset the sample belong to + int: the index of sample in its corresponding subset + """ + if indice < 0: + if -indice > len(self): + raise ValueError( + 'absolute value of index should not exceed dataset length') + indice = len(self) + indice + dataset_idx = bisect.bisect_right(self.cumulative_sizes, indice) + if dataset_idx == 0: + sample_idx = indice + else: + sample_idx = indice - self.cumulative_sizes[dataset_idx - 1] + return dataset_idx, sample_idx + + def format_results(self, results, imgfile_prefix, indices=None, **kwargs): + """format result for every sample of ConcatDataset.""" + if indices is None: + indices = list(range(len(self))) + + assert isinstance(results, list), 'results must be a list.' + assert isinstance(indices, list), 'indices must be a list.' + + ret_res = [] + for i, indice in enumerate(indices): + dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx( + indice) + res = self.datasets[dataset_idx].format_results( + [results[i]], + imgfile_prefix + f'/{dataset_idx}', + indices=[sample_idx], + **kwargs) + ret_res.append(res) + return sum(ret_res, []) + + def pre_eval(self, preds, indices): + """do pre eval for every sample of ConcatDataset.""" + # In order to compat with batch inference + if not isinstance(indices, list): + indices = [indices] + if not isinstance(preds, list): + preds = [preds] + ret_res = [] + for i, indice in enumerate(indices): + dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx( + indice) + res = self.datasets[dataset_idx].pre_eval(preds[i], sample_idx) + ret_res.append(res) + return sum(ret_res, []) + + +@DATASETS.register_module() +class RepeatDataset(object): + """A wrapper of repeated dataset. + + The length of repeated dataset will be `times` larger than the original + dataset. This is useful when the data loading time is long but the dataset + is small. Using RepeatDataset can reduce the data loading time between + epochs. + + Args: + dataset (:obj:`Dataset`): The dataset to be repeated. + times (int): Repeat times. + """ + + def __init__(self, dataset, times): + self.dataset = dataset + self.times = times + self.CLASSES = dataset.CLASSES + self.PALETTE = dataset.PALETTE + self._ori_len = len(self.dataset) + + def __getitem__(self, idx): + """Get item from original dataset.""" + return self.dataset[idx % self._ori_len] + + def __len__(self): + """The length is multiplied by ``times``""" + return self.times * self._ori_len + + +@DATASETS.register_module() +class MultiImageMixDataset: + """A wrapper of multiple images mixed dataset. + + Suitable for training on multiple images mixed data augmentation like + mosaic and mixup. For the augmentation pipeline of mixed image data, + the `get_indexes` method needs to be provided to obtain the image + indexes, and you can set `skip_flags` to change the pipeline running + process. + + + Args: + dataset (:obj:`CustomDataset`): The dataset to be mixed. + pipeline (Sequence[dict]): Sequence of transform object or + config dict to be composed. + skip_type_keys (list[str], optional): Sequence of type string to + be skip pipeline. Default to None. + """ + + def __init__(self, dataset, pipeline, skip_type_keys=None): + assert isinstance(pipeline, collections.abc.Sequence) + if skip_type_keys is not None: + assert all([ + isinstance(skip_type_key, str) + for skip_type_key in skip_type_keys + ]) + self._skip_type_keys = skip_type_keys + + self.pipeline = [] + self.pipeline_types = [] + for transform in pipeline: + if isinstance(transform, dict): + self.pipeline_types.append(transform['type']) + transform = build_from_cfg(transform, PIPELINES) + self.pipeline.append(transform) + else: + raise TypeError('pipeline must be a dict') + + self.dataset = dataset + self.CLASSES = dataset.CLASSES + self.PALETTE = dataset.PALETTE + self.num_samples = len(dataset) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + results = copy.deepcopy(self.dataset[idx]) + for (transform, transform_type) in zip(self.pipeline, + self.pipeline_types): + if self._skip_type_keys is not None and \ + transform_type in self._skip_type_keys: + continue + + if hasattr(transform, 'get_indexes'): + indexes = transform.get_indexes(self.dataset) + if not isinstance(indexes, collections.abc.Sequence): + indexes = [indexes] + mix_results = [ + copy.deepcopy(self.dataset[index]) for index in indexes + ] + results['mix_results'] = mix_results + + results = transform(results) + + if 'mix_results' in results: + results.pop('mix_results') + + return results + + def update_skip_type_keys(self, skip_type_keys): + """Update skip_type_keys. + + It is called by an external hook. + + Args: + skip_type_keys (list[str], optional): Sequence of type + string to be skip pipeline. + """ + assert all([ + isinstance(skip_type_key, str) for skip_type_key in skip_type_keys + ]) + self._skip_type_keys = skip_type_keys diff --git a/data_utils/easyportrait/mmseg/datasets/drive.py b/data_utils/easyportrait/mmseg/datasets/drive.py new file mode 100644 index 0000000000000000000000000000000000000000..d44fb0da7123eeed7a8e0b932d12e41305a33f22 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/drive.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class DRIVEDataset(CustomDataset): + """DRIVE dataset. + + In segmentation map annotation for DRIVE, 0 stands for background, which is + included in 2 categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '_manual1.png'. + """ + + CLASSES = ('background', 'vessel') + + PALETTE = [[120, 120, 120], [6, 230, 230]] + + def __init__(self, **kwargs): + super(DRIVEDataset, self).__init__( + img_suffix='.png', + seg_map_suffix='_manual1.png', + reduce_zero_label=False, + **kwargs) + assert self.file_client.exists(self.img_dir) diff --git a/data_utils/easyportrait/mmseg/datasets/easy_portrait.py b/data_utils/easyportrait/mmseg/datasets/easy_portrait.py new file mode 100644 index 0000000000000000000000000000000000000000..e65f1760a7d0e525974d8b1c69dbbefbbfe8c782 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/easy_portrait.py @@ -0,0 +1,35 @@ +import os.path as osp + +import mmcv +import numpy as np +from PIL import Image + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class EasyPortraitDataset(CustomDataset): + """EasyPortrait dataset. + + In segmentation map annotation for EasyPortrait, 0 stands for background, + which is included in 9 categories. ``reduce_zero_label`` is fixed to False. + The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to + '.png'. + """ + + CLASSES = ('background', 'person', 'skin', + 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth') + + PALETTE = [[0, 0, 0], [223, 87, 188], [160, 221, 255], + [130, 106, 237], [200, 121, 255], [255, 183, 255], + [0, 144, 193], [113, 137, 255], [230, 232, 230]] + + def __init__(self, **kwargs): + super(EasyPortraitDataset, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) + assert self.file_client.exists(self.img_dir) \ No newline at end of file diff --git a/data_utils/easyportrait/mmseg/datasets/easy_portrait_face_parsing.py b/data_utils/easyportrait/mmseg/datasets/easy_portrait_face_parsing.py new file mode 100644 index 0000000000000000000000000000000000000000..4249940fda3b9311b19d6bd9214f5539065ad8fa --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/easy_portrait_face_parsing.py @@ -0,0 +1,58 @@ +import os.path as osp + +import mmcv +import numpy as np +from PIL import Image + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class EasyPortraitFPDataset(CustomDataset): + """EasyPortraitFPDataset dataset. + + In segmentation map annotation for EasyPortrait, 0 stands for background, + which is included in 9 categories. ``reduce_zero_label`` is fixed to False. + The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to + '.png'. + """ + + CLASSES = ('background', 'skin', + 'left brow', 'right brow', 'left eye', + 'right eye', 'lips', 'teeth') + + PALETTE = [[0, 0, 0], [160, 221, 255], + [130, 106, 237], [200, 121, 255], [255, 183, 255], + [0, 144, 193], [113, 137, 255], [230, 232, 230]] + + def __init__(self, **kwargs): + super(EasyPortraitFPDataset, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) + assert self.file_client.exists(self.img_dir) + +@DATASETS.register_module() +class EasyPortraitFPDatasetCross(CustomDataset): + """EasyPortraitFPDatasetCross dataset. + + In segmentation map annotation for EasyPortrait, 0 stands for background, + which is included in 9 categories. ``reduce_zero_label`` is fixed to False. + The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to + '.png'. + """ + + CLASSES = ('background', 'left brow', 'right brow', 'left eye', 'right eye', 'lips') + PALETTE = [[0, 0, 0], [160, 221, 255], + [130, 106, 237], [200, 121, 255], [255, 183, 255], + [0, 144, 193]] + + def __init__(self, **kwargs): + super(EasyPortraitFPDatasetCross, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) + assert self.file_client.exists(self.img_dir) \ No newline at end of file diff --git a/data_utils/easyportrait/mmseg/datasets/easy_portrait_portrait_segmentation.py b/data_utils/easyportrait/mmseg/datasets/easy_portrait_portrait_segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..cb36502ba944b850b2d317937ad2d92712159c2d --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/easy_portrait_portrait_segmentation.py @@ -0,0 +1,31 @@ +import os.path as osp + +import mmcv +import numpy as np +from PIL import Image + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class EasyPortraitPSDataset(CustomDataset): + """EasyPortrait dataset. + + In segmentation map annotation for EasyPortrait, 0 stands for background, + which is included in 9 categories. ``reduce_zero_label`` is fixed to False. + The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to + '.png'. + """ + + CLASSES = ('background', 'person') + + PALETTE = [[0, 0, 0], [160, 221, 255]] + + def __init__(self, **kwargs): + super(EasyPortraitPSDataset, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) + assert self.file_client.exists(self.img_dir) \ No newline at end of file diff --git a/data_utils/easyportrait/mmseg/datasets/face.py b/data_utils/easyportrait/mmseg/datasets/face.py new file mode 100644 index 0000000000000000000000000000000000000000..cbc2345b09952e97b0b60ad6fa584d04ab4b6c6b --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/face.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class FaceOccludedDataset(CustomDataset): + """Face Occluded dataset. + + Args: + split (str): Split txt file for Pascal VOC. + """ + + CLASSES = ('background', 'face') + + PALETTE = [[0, 0, 0], [128, 0, 0]] + + def __init__(self, split, **kwargs): + super(FaceOccludedDataset, self).__init__( + img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs) + assert osp.exists(self.img_dir) and self.split is not None diff --git a/data_utils/easyportrait/mmseg/datasets/hrf.py b/data_utils/easyportrait/mmseg/datasets/hrf.py new file mode 100644 index 0000000000000000000000000000000000000000..cf3ea8d79c2a62dd4e9a2e84af35ec4ea1879091 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/hrf.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class HRFDataset(CustomDataset): + """HRF dataset. + + In segmentation map annotation for HRF, 0 stands for background, which is + included in 2 categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '.png'. + """ + + CLASSES = ('background', 'vessel') + + PALETTE = [[120, 120, 120], [6, 230, 230]] + + def __init__(self, **kwargs): + super(HRFDataset, self).__init__( + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) + assert self.file_client.exists(self.img_dir) diff --git a/data_utils/easyportrait/mmseg/datasets/imagenets.py b/data_utils/easyportrait/mmseg/datasets/imagenets.py new file mode 100644 index 0000000000000000000000000000000000000000..77fbb388d0d39c5ca6a2aff2b211b903254645a4 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/imagenets.py @@ -0,0 +1,1004 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +import mmcv +import numpy as np +from PIL import Image + +from mmseg.core import intersect_and_union +from mmseg.datasets.pipelines import LoadAnnotations, LoadImageFromFile +from .builder import DATASETS, PIPELINES +from .custom import CustomDataset + + +@PIPELINES.register_module() +class LoadImageNetSImageFromFile(LoadImageFromFile): + """Load an image from the ImageNetS dataset. + + To avoid out of memory, images that are too large will + be downsampled to the scale of 1000. + + Args: + downsample_large_image (bool): Whether to downsample the large images. + False may cause out of memory. + Defaults to True. + """ + + def __init__(self, downsample_large_image=True, **kwargs): + super().__init__(**kwargs) + self.downsample_large_image = downsample_large_image + + def __call__(self, results): + """Call functions to load image and get image meta information. + + Args: + results (dict): Result dict from :obj:`mmseg.CustomDataset`. + + Returns: + dict: The dict contains loaded image and meta information. + """ + results = super().__call__(results) + if not self.downsample_large_image: + return results + + # Images that are too large + # (H * W > 1000 * 100, + # these images are included in ImageNetSDataset.LARGES) + # will be downsampled to 1000 along the longer side. + H, W = results['img_shape'][:2] + if H * W > pow(1000, 2): + if H > W: + target_size = (int(1000 * W / H), 1000) + else: + target_size = (1000, int(1000 * H / W)) + + results['img'] = mmcv.imresize( + results['img'], size=target_size, interpolation='bilinear') + if self.to_float32: + results['img'] = results['img'].astype(np.float32) + + results['img_shape'] = results['img'].shape + results['ori_shape'] = results['img'].shape + # Set initial values for default meta_keys + results['pad_shape'] = results['img'].shape + return results + + +@PIPELINES.register_module() +class LoadImageNetSAnnotations(LoadAnnotations): + """Load annotations for the ImageNetS dataset. The annotations in + ImageNet-S are saved as RGB images. + + The annotations with format of RGB should be + converted to the format of Gray as R + G * 256. + """ + + def __call__(self, results): + """Call function to load multiple types annotations. + + Args: + results (dict): Result dict from :obj:`mmseg.CustomDataset`. + + Returns: + dict: The dict contains loaded semantic segmentation annotations. + """ + results = super().__call__(results) + + # The annotations in ImageNet-S are saved as RGB images, + # due to 919 > 255 (upper bound of gray images). + + # For training, + # the annotations with format of RGB should be + # converted to the format of Gray as R + G * 256. + results['gt_semantic_seg'] = \ + results['gt_semantic_seg'][:, :, 1] * 256 + \ + results['gt_semantic_seg'][:, :, 2] + results['gt_semantic_seg'] = results['gt_semantic_seg'].astype( + np.int32) + return results + + +@DATASETS.register_module() +class ImageNetSDataset(CustomDataset): + """ImageNet-S dataset. + + In segmentation map annotation for ImageNet-S, 0 stands for others, which + is not included in 50/300/919 categories. ``ignore_index`` is fixed to + 1000. The ``img_suffix`` is fixed to '.JPEG' and ``seg_map_suffix`` is + fixed to '.png'. + """ + CLASSES50 = ('others', 'goldfish', 'tiger shark', 'goldfinch', 'tree frog', + 'kuvasz', 'red fox', 'siamese cat', 'american black bear', + 'ladybug', 'sulphur butterfly', 'wood rabbit', 'hamster', + 'wild boar', 'gibbon', 'african elephant', 'giant panda', + 'airliner', 'ashcan', 'ballpoint', 'beach wagon', 'boathouse', + 'bullet train', 'cellular telephone', 'chest', 'clog', + 'container ship', 'digital watch', 'dining table', + 'golf ball', 'grand piano', 'iron', 'lab coat', 'mixing bowl', + 'motor scooter', 'padlock', 'park bench', 'purse', + 'streetcar', 'table lamp', 'television', 'toilet seat', + 'umbrella', 'vase', 'water bottle', 'water tower', 'yawl', + 'street sign', 'lemon', 'carbonara', 'agaric') + CLASSES300 = ( + 'others', 'tench', 'goldfish', 'tiger shark', 'hammerhead', + 'electric ray', 'ostrich', 'goldfinch', 'house finch', + 'indigo bunting', 'kite', 'common newt', 'axolotl', 'tree frog', + 'tailed frog', 'mud turtle', 'banded gecko', 'american chameleon', + 'whiptail', 'african chameleon', 'komodo dragon', 'american alligator', + 'triceratops', 'thunder snake', 'ringneck snake', 'king snake', + 'rock python', 'horned viper', 'harvestman', 'scorpion', + 'garden spider', 'tick', 'african grey', 'lorikeet', + 'red-breasted merganser', 'wallaby', 'koala', 'jellyfish', + 'sea anemone', 'conch', 'fiddler crab', 'american lobster', + 'spiny lobster', 'isopod', 'bittern', 'crane', 'limpkin', 'bustard', + 'albatross', 'toy terrier', 'afghan hound', 'bluetick', 'borzoi', + 'irish wolfhound', 'whippet', 'ibizan hound', 'staffordshire ' + 'bullterrier', 'border terrier', 'yorkshire terrier', + 'lakeland terrier', 'giant schnauzer', 'standard schnauzer', + 'scotch terrier', 'lhasa', 'english setter', 'clumber', + 'english springer', 'welsh springer spaniel', 'kuvasz', 'kelpie', + 'doberman', 'miniature pinscher', 'malamute', 'pug', 'leonberg', + 'great pyrenees', 'samoyed', 'brabancon griffon', 'cardigan', 'coyote', + 'red fox', 'kit fox', 'grey fox', 'persian cat', 'siamese cat', + 'cougar', 'lynx', 'tiger', 'american black bear', 'sloth bear', + 'ladybug', 'leaf beetle', 'weevil', 'bee', 'cicada', 'leafhopper', + 'damselfly', 'ringlet', 'cabbage butterfly', 'sulphur butterfly', + 'sea cucumber', 'wood rabbit', 'hare', 'hamster', 'wild boar', + 'hippopotamus', 'bighorn', 'ibex', 'badger', 'three-toed sloth', + 'orangutan', 'gibbon', 'colobus', 'spider monkey', 'squirrel monkey', + 'madagascar cat', 'indian elephant', 'african elephant', 'giant panda', + 'barracouta', 'eel', 'coho', 'academic gown', 'accordion', 'airliner', + 'ambulance', 'analog clock', 'ashcan', 'backpack', 'balloon', + 'ballpoint', 'barbell', 'barn', 'bassoon', 'bath towel', 'beach wagon', + 'bicycle-built-for-two', 'binoculars', 'boathouse', 'bonnet', + 'bookcase', 'bow', 'brass', 'breastplate', 'bullet train', 'cannon', + 'can opener', "carpenter's kit", 'cassette', 'cellular telephone', + 'chain saw', 'chest', 'china cabinet', 'clog', 'combination lock', + 'container ship', 'corkscrew', 'crate', 'crock pot', 'digital watch', + 'dining table', 'dishwasher', 'doormat', 'dutch oven', 'electric fan', + 'electric locomotive', 'envelope', 'file', 'folding chair', + 'football helmet', 'freight car', 'french horn', 'fur coat', + 'garbage truck', 'goblet', 'golf ball', 'grand piano', 'half track', + 'hamper', 'hard disc', 'harmonica', 'harvester', 'hook', + 'horizontal bar', 'horse cart', 'iron', "jack-o'-lantern", 'lab coat', + 'ladle', 'letter opener', 'liner', 'mailbox', 'megalith', + 'military uniform', 'milk can', 'mixing bowl', 'monastery', 'mortar', + 'mosquito net', 'motor scooter', 'mountain bike', 'mountain tent', + 'mousetrap', 'necklace', 'nipple', 'ocarina', 'padlock', 'palace', + 'parallel bars', 'park bench', 'pedestal', 'pencil sharpener', + 'pickelhaube', 'pillow', 'planetarium', 'plastic bag', + 'polaroid camera', 'pole', 'pot', 'purse', 'quilt', 'radiator', + 'radio', 'radio telescope', 'rain barrel', 'reflex camera', + 'refrigerator', 'rifle', 'rocking chair', 'rubber eraser', 'rule', + 'running shoe', 'sewing machine', 'shield', 'shoji', 'ski', 'ski mask', + 'slot', 'soap dispenser', 'soccer ball', 'sock', 'soup bowl', + 'space heater', 'spider web', 'spindle', 'sports car', + 'steel arch bridge', 'stethoscope', 'streetcar', 'submarine', + 'swimming trunks', 'syringe', 'table lamp', 'tank', 'teddy', + 'television', 'throne', 'tile roof', 'toilet seat', 'trench coat', + 'trimaran', 'typewriter keyboard', 'umbrella', 'vase', 'volleyball', + 'wardrobe', 'warplane', 'washer', 'water bottle', 'water tower', + 'whiskey jug', 'wig', 'wine bottle', 'wok', 'wreck', 'yawl', 'yurt', + 'street sign', 'traffic light', 'consomme', 'ice cream', 'bagel', + 'cheeseburger', 'hotdog', 'mashed potato', 'spaghetti squash', + 'bell pepper', 'cardoon', 'granny smith', 'strawberry', 'lemon', + 'carbonara', 'burrito', 'cup', 'coral reef', "yellow lady's slipper", + 'buckeye', 'agaric', 'gyromitra', 'earthstar', 'bolete') + CLASSES919 = ( + 'others', 'house finch', 'stupa', 'agaric', 'hen-of-the-woods', + 'wild boar', 'kit fox', 'desk', 'beaker', 'spindle', 'lipstick', + 'cardoon', 'ringneck snake', 'daisy', 'sturgeon', 'scorpion', + 'pelican', 'bustard', 'rock crab', 'rock beauty', 'minivan', 'menu', + 'thunder snake', 'zebra', 'partridge', 'lacewing', 'starfish', + 'italian greyhound', 'marmot', 'cardigan', 'plate', 'ballpoint', + 'chesapeake bay retriever', 'pirate', 'potpie', 'keeshond', 'dhole', + 'waffle iron', 'cab', 'american egret', 'colobus', 'radio telescope', + 'gordon setter', 'mousetrap', 'overskirt', 'hamster', 'wine bottle', + 'bluetick', 'macaque', 'bullfrog', 'junco', 'tusker', 'scuba diver', + 'pool table', 'samoyed', 'mailbox', 'purse', 'monastery', 'bathtub', + 'window screen', 'african crocodile', 'traffic light', 'tow truck', + 'radio', 'recreational vehicle', 'grey whale', 'crayfish', + 'rottweiler', 'racer', 'whistle', 'pencil box', 'barometer', + 'cabbage butterfly', 'sloth bear', 'rhinoceros beetle', 'guillotine', + 'rocking chair', 'sports car', 'bouvier des flandres', 'border collie', + 'fiddler crab', 'slot', 'go-kart', 'cocker spaniel', 'plate rack', + 'common newt', 'tile roof', 'marimba', 'moped', 'terrapin', 'oxcart', + 'lionfish', 'bassinet', 'rain barrel', 'american black bear', 'goose', + 'half track', 'kite', 'microphone', 'shield', 'mexican hairless', + 'measuring cup', 'bubble', 'platypus', 'saint bernard', 'police van', + 'vase', 'lhasa', 'wardrobe', 'teapot', 'hummingbird', 'revolver', + 'jinrikisha', 'mailbag', 'red-breasted merganser', 'assault rifle', + 'loudspeaker', 'fig', 'american lobster', 'can opener', 'arctic fox', + 'broccoli', 'long-horned beetle', 'television', 'airship', + 'black stork', 'marmoset', 'panpipe', 'drumstick', 'knee pad', + 'lotion', 'french loaf', 'throne', 'jeep', 'jersey', 'tiger cat', + 'cliff', 'sealyham terrier', 'strawberry', 'minibus', 'goldfinch', + 'goblet', 'burrito', 'harp', 'tractor', 'cornet', 'leopard', 'fly', + 'fireboat', 'bolete', 'barber chair', 'consomme', 'tripod', + 'breastplate', 'pineapple', 'wok', 'totem pole', 'alligator lizard', + 'common iguana', 'digital clock', 'bighorn', 'siamese cat', 'bobsled', + 'irish setter', 'zucchini', 'crock pot', 'loggerhead', + 'irish wolfhound', 'nipple', 'rubber eraser', 'impala', 'barbell', + 'snow leopard', 'siberian husky', 'necklace', 'manhole cover', + 'electric fan', 'hippopotamus', 'entlebucher', 'prison', 'doberman', + 'ruffed grouse', 'coyote', 'toaster', 'puffer', 'black swan', + 'schipperke', 'file', 'prairie chicken', 'hourglass', + 'greater swiss mountain dog', 'pajama', 'ear', 'pedestal', 'viaduct', + 'shoji', 'snowplow', 'puck', 'gyromitra', 'birdhouse', 'flatworm', + 'pier', 'coral reef', 'pot', 'mortar', 'polaroid camera', + 'passenger car', 'barracouta', 'banded gecko', + 'black-and-tan coonhound', 'safe', 'ski', 'torch', 'green lizard', + 'volleyball', 'brambling', 'solar dish', 'lawn mower', 'swing', + 'hyena', 'staffordshire bullterrier', 'screw', 'toilet tissue', + 'velvet', 'scale', 'stopwatch', 'sock', 'koala', 'garbage truck', + 'spider monkey', 'afghan hound', 'chain', 'upright', 'flagpole', + 'tree frog', 'cuirass', 'chest', 'groenendael', 'christmas stocking', + 'lakeland terrier', 'perfume', 'neck brace', 'lab coat', 'carbonara', + 'porcupine', 'shower curtain', 'slug', 'pitcher', + 'flat-coated retriever', 'pekinese', 'oscilloscope', 'church', 'lynx', + 'cowboy hat', 'table lamp', 'pug', 'crate', 'water buffalo', + 'labrador retriever', 'weimaraner', 'giant schnauzer', 'stove', + 'sea urchin', 'banjo', 'tiger', 'miniskirt', 'eft', + 'european gallinule', 'vending machine', 'miniature schnauzer', + 'maypole', 'bull mastiff', 'hoopskirt', 'coffeepot', 'four-poster', + 'safety pin', 'monarch', 'beer glass', 'grasshopper', 'head cabbage', + 'parking meter', 'bonnet', 'chiffonier', 'great dane', 'spider web', + 'electric locomotive', 'scotch terrier', 'australian terrier', + 'honeycomb', 'leafhopper', 'beer bottle', 'mud turtle', 'lifeboat', + 'cassette', "potter's wheel", 'oystercatcher', 'space heater', + 'coral fungus', 'sunglass', 'quail', 'triumphal arch', 'collie', + 'walker hound', 'bucket', 'bee', 'komodo dragon', 'dugong', 'gibbon', + 'trailer truck', 'king crab', 'cheetah', 'rifle', 'stingray', 'bison', + 'ipod', 'modem', 'box turtle', 'motor scooter', 'container ship', + 'vestment', 'dingo', 'radiator', 'giant panda', 'nail', 'sea slug', + 'indigo bunting', 'trimaran', 'jacamar', 'chimpanzee', 'comic book', + 'odometer', 'dishwasher', 'bolo tie', 'barn', 'paddlewheel', + 'appenzeller', 'great white shark', 'green snake', 'jackfruit', + 'llama', 'whippet', 'hay', 'leaf beetle', 'sombrero', 'ram', + 'washbasin', 'cup', 'wall clock', 'acorn squash', 'spotted salamander', + 'boston bull', 'border terrier', 'doormat', 'cicada', 'kimono', + 'hand blower', 'ox', 'meerkat', 'space shuttle', 'african hunting dog', + 'violin', 'artichoke', 'toucan', 'bulbul', 'coucal', 'red wolf', + 'seat belt', 'bicycle-built-for-two', 'bow tie', 'pretzel', + 'bedlington terrier', 'albatross', 'punching bag', 'cocktail shaker', + 'diamondback', 'corn', 'ant', 'mountain bike', 'walking stick', + 'standard schnauzer', 'power drill', 'cardigan', 'accordion', + 'wire-haired fox terrier', 'streetcar', 'beach wagon', 'ibizan hound', + 'hair spray', 'car mirror', 'mountain tent', 'trench coat', + 'studio couch', 'pomeranian', 'dough', 'corkscrew', 'broom', + 'parachute', 'band aid', 'water tower', 'teddy', 'fire engine', + 'hornbill', 'hotdog', 'theater curtain', 'crane', 'malinois', 'lion', + 'african elephant', 'handkerchief', 'caldron', 'shopping basket', + 'gown', 'wolf spider', 'vizsla', 'electric ray', 'freight car', + 'pembroke', 'feather boa', 'wallet', 'agama', 'hard disc', 'stretcher', + 'sorrel', 'trilobite', 'basset', 'vulture', 'tarantula', 'hermit crab', + 'king snake', 'robin', 'bernese mountain dog', 'ski mask', + 'fountain pen', 'combination lock', 'yurt', 'clumber', 'park bench', + 'baboon', 'kuvasz', 'centipede', 'tabby', 'steam locomotive', 'badger', + 'irish water spaniel', 'picket fence', 'gong', 'canoe', + 'swimming trunks', 'submarine', 'echidna', 'bib', 'refrigerator', + 'hammer', 'lemon', 'admiral', 'chihuahua', 'basenji', 'pinwheel', + 'golfcart', 'bullet train', 'crib', 'muzzle', 'eggnog', + 'old english sheepdog', 'tray', 'tiger beetle', 'electric guitar', + 'peacock', 'soup bowl', 'wallaby', 'abacus', 'dalmatian', 'harvester', + 'aircraft carrier', 'snowmobile', 'welsh springer spaniel', + 'affenpinscher', 'oboe', 'cassette player', 'pencil sharpener', + 'japanese spaniel', 'plunger', 'black widow', 'norfolk terrier', + 'reflex camera', 'ice bear', 'redbone', 'mongoose', 'warthog', + 'arabian camel', 'bittern', 'mixing bowl', 'tailed frog', 'scabbard', + 'castle', 'curly-coated retriever', 'garden spider', 'folding chair', + 'mouse', 'prayer rug', 'red fox', 'toy terrier', 'leonberg', + 'lycaenid', 'poncho', 'goldfish', 'red-backed sandpiper', 'holster', + 'hair slide', 'coho', 'komondor', 'macaw', 'maltese dog', 'megalith', + 'sarong', 'green mamba', 'sea lion', 'water ouzel', 'bulletproof vest', + 'sulphur-crested cockatoo', 'scottish deerhound', 'steel arch bridge', + 'catamaran', 'brittany spaniel', 'redshank', 'otter', + 'brabancon griffon', 'balloon', 'rule', 'planetarium', 'trombone', + 'mitten', 'abaya', 'crash helmet', 'milk can', 'hartebeest', + 'windsor tie', 'irish terrier', 'african chameleon', 'matchstick', + 'water bottle', 'cloak', 'ground beetle', 'ashcan', 'crane', + 'gila monster', 'unicycle', 'gazelle', 'wombat', 'brain coral', + 'projector', 'custard apple', 'proboscis monkey', 'tibetan mastiff', + 'mosque', 'plastic bag', 'backpack', 'drum', 'norwich terrier', + 'pizza', 'carton', 'plane', 'gorilla', 'jigsaw puzzle', 'forklift', + 'isopod', 'otterhound', 'vacuum', 'european fire salamander', 'apron', + 'langur', 'boxer', 'african grey', 'ice lolly', 'toilet seat', + 'golf ball', 'titi', 'drake', 'ostrich', 'magnetic compass', + 'great pyrenees', 'rhodesian ridgeback', 'buckeye', 'dungeness crab', + 'toy poodle', 'ptarmigan', 'amphibian', 'monitor', 'school bus', + 'schooner', 'spatula', 'weevil', 'speedboat', 'sundial', 'borzoi', + 'bassoon', 'bath towel', 'pill bottle', 'acorn', 'tick', 'briard', + 'thimble', 'brass', 'white wolf', 'boathouse', 'yawl', + 'miniature pinscher', 'barn spider', 'jean', 'water snake', 'dishrag', + 'yorkshire terrier', 'hammerhead', 'typewriter keyboard', 'papillon', + 'ocarina', 'washer', 'standard poodle', 'china cabinet', 'steel drum', + 'swab', 'mobile home', 'german short-haired pointer', 'saluki', + 'bee eater', 'rock python', 'vine snake', 'kelpie', 'harmonica', + 'military uniform', 'reel', 'thatch', 'maraca', 'tricycle', + 'sidewinder', 'parallel bars', 'banana', 'flute', 'paintbrush', + 'sleeping bag', "yellow lady's slipper", 'three-toed sloth', + 'white stork', 'notebook', 'weasel', 'tiger shark', 'football helmet', + 'madagascar cat', 'dowitcher', 'wreck', 'king penguin', 'lighter', + 'timber wolf', 'racket', 'digital watch', 'liner', 'hen', + 'suspension bridge', 'pillow', "carpenter's kit", 'butternut squash', + 'sandal', 'sussex spaniel', 'hip', 'american staffordshire terrier', + 'flamingo', 'analog clock', 'black and gold garden spider', + 'sea cucumber', 'indian elephant', 'syringe', 'lens cap', 'missile', + 'cougar', 'diaper', 'chambered nautilus', 'garter snake', + 'anemone fish', 'organ', 'limousine', 'horse cart', 'jaguar', + 'frilled lizard', 'crutch', 'sea anemone', 'guenon', 'meat loaf', + 'slide rule', 'saltshaker', 'pomegranate', 'acoustic guitar', + 'shopping cart', 'drilling platform', 'nematode', 'chickadee', + 'academic gown', 'candle', 'norwegian elkhound', 'armadillo', + 'horizontal bar', 'orangutan', 'obelisk', 'stone wall', 'cannon', + 'rugby ball', 'ping-pong ball', 'window shade', 'trolleybus', + 'ice cream', 'pop bottle', 'cock', 'harvestman', 'leatherback turtle', + 'killer whale', 'spaghetti squash', 'chain saw', 'stinkhorn', + 'espresso maker', 'loafer', 'bagel', 'ballplayer', 'skunk', + 'chainlink fence', 'earthstar', 'whiptail', 'barrel', + 'kerry blue terrier', 'triceratops', 'chow', 'grey fox', 'sax', + 'binoculars', 'ladybug', 'silky terrier', 'gas pump', 'cradle', + 'whiskey jug', 'french bulldog', 'eskimo dog', 'hog', 'hognose snake', + 'pickup', 'indian cobra', 'hand-held computer', 'printer', 'pole', + 'bald eagle', 'american alligator', 'dumbbell', 'umbrella', 'mink', + 'shower cap', 'tank', 'quill', 'fox squirrel', 'ambulance', + 'lesser panda', 'frying pan', 'letter opener', 'hook', 'strainer', + 'pick', 'dragonfly', 'gar', 'piggy bank', 'envelope', 'stole', 'ibex', + 'american chameleon', 'bearskin', 'microwave', 'petri dish', + 'wood rabbit', 'beacon', 'dung beetle', 'warplane', 'ruddy turnstone', + 'knot', 'fur coat', 'hamper', 'beagle', 'ringlet', 'mask', + 'persian cat', 'cellular telephone', 'american coot', 'apiary', + 'shovel', 'coffee mug', 'sewing machine', 'spoonbill', 'padlock', + 'bell pepper', 'great grey owl', 'squirrel monkey', + 'sulphur butterfly', 'scoreboard', 'bow', 'malamute', 'siamang', + 'snail', 'remote control', 'sea snake', 'loupe', 'model t', + 'english setter', 'dining table', 'face powder', 'tench', + "jack-o'-lantern", 'croquet ball', 'water jug', 'airedale', 'airliner', + 'guinea pig', 'hare', 'damselfly', 'thresher', 'limpkin', 'buckle', + 'english springer', 'boa constrictor', 'french horn', + 'black-footed ferret', 'shetland sheepdog', 'capuchin', 'cheeseburger', + 'miniature poodle', 'spotlight', 'wooden spoon', + 'west highland white terrier', 'wig', 'running shoe', 'cowboy boot', + 'brown bear', 'iron', 'brassiere', 'magpie', 'gondola', 'grand piano', + 'granny smith', 'mashed potato', 'german shepherd', 'stethoscope', + 'cauliflower', 'soccer ball', 'pay-phone', 'jellyfish', 'cairn', + 'polecat', 'trifle', 'photocopier', 'shih-tzu', 'orange', 'guacamole', + 'hatchet', 'cello', 'egyptian cat', 'basketball', 'moving van', + 'mortarboard', 'dial telephone', 'street sign', 'oil filter', 'beaver', + 'spiny lobster', 'chime', 'bookcase', 'chiton', 'black grouse', 'jay', + 'axolotl', 'oxygen mask', 'cricket', 'worm fence', 'indri', + 'cockroach', 'mushroom', 'dandie dinmont', 'tennis ball', + 'howler monkey', 'rapeseed', 'tibetan terrier', 'newfoundland', + 'dutch oven', 'paddle', 'joystick', 'golden retriever', + 'blenheim spaniel', 'mantis', 'soft-coated wheaten terrier', + 'little blue heron', 'convertible', 'bloodhound', 'palace', + 'medicine chest', 'english foxhound', 'cleaver', 'sweatshirt', + 'mosquito net', 'soap dispenser', 'ladle', 'screwdriver', + 'fire screen', 'binder', 'suit', 'barrow', 'clog', 'cucumber', + 'baseball', 'lorikeet', 'conch', 'quilt', 'eel', 'horned viper', + 'night snake', 'angora', 'pickelhaube', 'gasmask', 'patas') + + # Some too large images are downsampled in LoadImageNetSImageFromFile. + # These images should be upsampled back in results2img. + LARGES = { + '00022800': [1225, 900], + '00037230': [2082, 2522], + '00011749': [1000, 1303], + '00040173': [1280, 960], + '00027045': [1880, 1330], + '00019424': [2304, 3072], + '00015496': [1728, 2304], + '00025715': [1083, 1624], + '00008260': [1400, 1400], + '00047233': [850, 1540], + '00043667': [2066, 1635], + '00024274': [1920, 2560], + '00028437': [1920, 2560], + '00018910': [1536, 2048], + '00046074': [1600, 1164], + '00021215': [1024, 1540], + '00034174': [960, 1362], + '00007361': [960, 1280], + '00030207': [1512, 1016], + '00015637': [1600, 1200], + '00013665': [2100, 1500], + '00028501': [1200, 852], + '00047237': [1624, 1182], + '00026950': [1200, 1600], + '00041704': [1920, 2560], + '00027074': [1200, 1600], + '00016473': [1200, 1200], + '00012206': [2448, 3264], + '00019622': [960, 1280], + '00008728': [2806, 750], + '00027712': [1128, 1700], + '00007195': [1290, 1824], + '00002942': [2560, 1920], + '00037032': [1954, 2613], + '00018543': [1067, 1600], + '00041570': [1536, 2048], + '00004422': [1728, 2304], + '00044827': [800, 1280], + '00046674': [1200, 1600], + '00017711': [1200, 1600], + '00048488': [1889, 2834], + '00000706': [1501, 2001], + '00032736': [1200, 1600], + '00024348': [1536, 2048], + '00023430': [1051, 1600], + '00030496': [1350, 900], + '00026543': [1280, 960], + '00010969': [2560, 1920], + '00025272': [1294, 1559], + '00019950': [1536, 1024], + '00004466': [1182, 1722], + '00029917': [3072, 2304], + '00014683': [1145, 1600], + '00013084': [1281, 2301], + '00039792': [1760, 1034], + '00046246': [2448, 3264], + '00004280': [984, 1440], + '00009435': [1127, 1502], + '00012860': [1673, 2500], + '00016702': [1444, 1000], + '00011278': [2048, 3072], + '00048174': [1605, 2062], + '00035451': [1225, 1636], + '00024769': [1200, 900], + '00032797': [1251, 1664], + '00027924': [1453, 1697], + '00010965': [1536, 2048], + '00020735': [1200, 1600], + '00027789': [853, 1280], + '00015113': [1324, 1999], + '00037571': [1251, 1586], + '00030120': [1536, 2048], + '00044219': [2448, 3264], + '00024604': [1535, 1955], + '00010926': [1200, 900], + '00017509': [1536, 2048], + '00042373': [924, 1104], + '00037066': [1536, 2048], + '00025494': [1880, 1060], + '00028610': [1377, 2204], + '00007196': [1202, 1600], + '00030788': [2592, 1944], + '00046865': [1920, 2560], + '00027141': [1600, 1200], + '00023215': [1200, 1600], + '00000218': [1439, 1652], + '00048126': [1516, 927], + '00030408': [1600, 2400], + '00038582': [1600, 1200], + '00046959': [1304, 900], + '00016988': [1242, 1656], + '00017201': [1629, 1377], + '00017658': [1000, 1035], + '00002766': [1495, 2383], + '00038573': [1600, 1071], + '00042297': [1200, 1200], + '00010564': [995, 1234], + '00001189': [1600, 1200], + '00007018': [1858, 2370], + '00043554': [1200, 1600], + '00000746': [1200, 1600], + '00001386': [960, 1280], + '00029975': [1600, 1200], + '00016221': [2877, 2089], + '00003152': [1200, 1600], + '00002552': [1200, 1600], + '00009402': [1125, 1500], + '00040672': [960, 1280], + '00024540': [960, 1280], + '00049770': [1457, 1589], + '00014533': [841, 1261], + '00006228': [1417, 1063], + '00034688': [1354, 2032], + '00032897': [1071, 1600], + '00024356': [2043, 3066], + '00019656': [1318, 1984], + '00035802': [2288, 2001], + '00017499': [1502, 1162], + '00046898': [1200, 1600], + '00040883': [1024, 1280], + '00031353': [1544, 1188], + '00028419': [1600, 1200], + '00048897': [2304, 3072], + '00040683': [1296, 1728], + '00042406': [848, 1200], + '00036007': [900, 1200], + '00010515': [1688, 1387], + '00048409': [5005, 3646], + '00032654': [1200, 1600], + '00037955': [1200, 1600], + '00038471': [3072, 2048], + '00036201': [913, 1328], + '00038619': [1728, 2304], + '00038165': [926, 2503], + '00033240': [1061, 1158], + '00023086': [1200, 1600], + '00041385': [1200, 1600], + '00014066': [2304, 3072], + '00049973': [1211, 1261], + '00043188': [2000, 3000], + '00047186': [1535, 1417], + '00046975': [1560, 2431], + '00034402': [1776, 2700], + '00017033': [1392, 1630], + '00041068': [1280, 960], + '00011024': [1317, 900], + '00048035': [1800, 1200], + '00033286': [994, 1500], + '00016613': [1152, 1536], + '00044160': [888, 1200], + '00021138': [902, 1128], + '00022300': [798, 1293], + '00034300': [1920, 2560], + '00008603': [1661, 1160], + '00045173': [2312, 903], + '00048616': [960, 1280], + '00048317': [3872, 2592], + '00045470': [1920, 1800], + '00043934': [1667, 2500], + '00010699': [2240, 1488], + '00030550': [1200, 1600], + '00010516': [1704, 2272], + '00001779': [1536, 2048], + '00018389': [1084, 1433], + '00013889': [3072, 2304], + '00022440': [2112, 2816], + '00024005': [2592, 1944], + '00046620': [960, 1280], + '00035227': [960, 1280], + '00033636': [1110, 1973], + '00003624': [1165, 1600], + '00033400': [1200, 1600], + '00013891': [1200, 1600], + '00022593': [1472, 1456], + '00009546': [1936, 2592], + '00022022': [1182, 1740], + '00022982': [1200, 1600], + '00039569': [1600, 1067], + '00009276': [930, 1240], + '00026777': [960, 1280], + '00047680': [1425, 882], + '00040785': [853, 1280], + '00002037': [1944, 2592], + '00005813': [1098, 987], + '00018328': [1128, 1242], + '00022318': [1500, 1694], + '00026654': [790, 1285], + '00012895': [1600, 1067], + '00007882': [980, 1024], + '00043771': [1008, 1043], + '00032990': [3621, 2539], + '00034094': [1175, 1600], + '00034302': [1463, 1134], + '00025021': [1503, 1520], + '00000771': [900, 1200], + '00025149': [1600, 1200], + '00005211': [1063, 1600], + '00049544': [1063, 1417], + '00025378': [1800, 2400], + '00024287': [1200, 1600], + '00013550': [2448, 3264], + '00008076': [1200, 1600], + '00039536': [1000, 1500], + '00020331': [1024, 1280], + '00002623': [1050, 1400], + '00031071': [873, 1320], + '00025266': [1024, 1536], + '00015109': [1213, 1600], + '00027390': [1200, 1600], + '00018894': [1584, 901], + '00049009': [900, 1203], + '00026671': [1201, 1601], + '00018668': [1024, 990], + '00016942': [1024, 1024], + '00046430': [1944, 3456], + '00033261': [1341, 1644], + '00017363': [2304, 2898], + '00045935': [2112, 2816], + '00027084': [900, 1200], + '00037716': [1611, 981], + '00030879': [1200, 1600], + '00027539': [1534, 1024], + '00030052': [1280, 852], + '00011015': [2808, 2060], + '00037004': [1920, 2560], + '00044012': [2240, 1680], + '00049818': [1704, 2272], + '00003541': [1200, 1600], + '00000520': [2448, 3264], + '00028331': [3264, 2448], + '00030244': [1200, 1600], + '00039079': [1600, 1200], + '00033432': [1600, 1200], + '00010533': [1200, 1600], + '00005916': [899, 1200], + '00038903': [1052, 1592], + '00025169': [1895, 850], + '00049042': [1200, 1600], + '00021828': [1280, 988], + '00013420': [3648, 2736], + '00045201': [1381, 1440], + '00021857': [776, 1296], + '00048810': [1168, 1263], + '00047860': [2592, 3888], + '00046960': [2304, 3072], + '00039357': [1200, 1600], + '00019620': [1536, 2048], + '00026710': [1944, 2592], + '00021277': [1079, 1151], + '00028387': [1128, 1585], + '00028796': [990, 1320], + '00035149': [1064, 1600], + '00020182': [1843, 1707], + '00018286': [2592, 1944], + '00035658': [1488, 1984], + '00008180': [1024, 1633], + '00018740': [1200, 1600], + '00044356': [1536, 2048], + '00038857': [1252, 1676], + '00035014': [1200, 1600], + '00044824': [1200, 1600], + '00009912': [1200, 1600], + '00014572': [2400, 1800], + '00001585': [1600, 1067], + '00047704': [1200, 1600], + '00038537': [920, 1200], + '00027941': [2200, 3000], + '00028526': [2592, 1944], + '00042353': [1280, 1024], + '00043409': [2000, 1500], + '00002209': [2592, 1944], + '00040841': [1613, 1974], + '00038889': [900, 1200], + '00046941': [1200, 1600], + '00014029': [846, 1269], + '00023091': [900, 1200], + '00036184': [877, 1350], + '00006165': [1200, 1600], + '00033991': [868, 2034], + '00035078': [1680, 2240], + '00045681': [1467, 1134], + '00043867': [1200, 1600], + '00003586': [1200, 1600], + '00039024': [1283, 2400], + '00048990': [1200, 1200], + '00044334': [960, 1280], + '00020939': [960, 1280], + '00031529': [1302, 1590], + '00014867': [2112, 2816], + '00034239': [1536, 2048], + '00031845': [1200, 1600], + '00045721': [1536, 2048], + '00025336': [1441, 1931], + '00040323': [900, 1152], + '00009133': [876, 1247], + '00033687': [2357, 3657], + '00038351': [1306, 1200], + '00022618': [1060, 1192], + '00001626': [777, 1329], + '00039137': [1071, 1600], + '00034896': [1426, 1590], + '00048502': [1187, 1837], + '00048077': [1712, 2288], + '00026239': [1200, 1600], + '00032687': [857, 1280], + '00006639': [1498, 780], + '00037738': [2112, 2816], + '00035760': [1123, 1447], + '00004897': [1083, 1393], + '00012141': [3584, 2016], + '00016278': [3234, 2281], + '00006661': [1787, 3276], + '00033040': [1200, 1800], + '00009881': [960, 1280], + '00008240': [2592, 1944], + '00023506': [960, 1280], + '00046982': [1693, 2480], + '00049632': [2310, 1638], + '00005473': [960, 1280], + '00013491': [2000, 3008], + '00005581': [1593, 1200], + '00005196': [1417, 2133], + '00049433': [1207, 1600], + '00012323': [1200, 1800], + '00021883': [1600, 2400], + '00031877': [2448, 3264], + '00046428': [1200, 1600], + '00000725': [881, 1463], + '00044936': [894, 1344], + '00012054': [3040, 4048], + '00025447': [900, 1200], + '00005290': [1520, 2272], + '00023326': [984, 1312], + '00047891': [1067, 1600], + '00026115': [1067, 1600], + '00010051': [1062, 1275], + '00005999': [1123, 1600], + '00021752': [1071, 1600], + '00041559': [1200, 1600], + '00025931': [836, 1410], + '00009327': [2848, 4288], + '00029735': [1905, 1373], + '00012922': [1024, 1547], + '00042259': [1548, 1024], + '00024949': [1050, 956], + '00014669': [900, 1200], + '00028028': [1170, 1730], + '00003183': [1152, 1535], + '00039304': [1050, 1680], + '00014939': [1904, 1240], + '00048366': [1600, 1200], + '00022406': [3264, 2448], + '00033363': [1125, 1500], + '00041230': [1125, 1500], + '00044222': [2105, 2472], + '00021950': [1200, 1200], + '00028475': [2691, 3515], + '00002149': [900, 1600], + '00033356': [1080, 1920], + '00041158': [960, 1280], + '00029672': [1536, 2048], + '00045816': [1023, 1153], + '00020471': [2076, 2716], + '00012398': [1067, 1600], + '00017884': [2048, 3072], + '00025132': [1200, 1600], + '00042429': [1362, 1980], + '00021285': [1127, 1200], + '00045113': [2792, 2528], + '00047915': [1200, 891], + '00009481': [1097, 924], + '00025448': [1760, 2400], + '00033911': [1759, 2197], + '00044684': [1200, 1600], + '00033754': [2304, 1728], + '00002733': [1536, 2048], + '00027371': [936, 1128], + '00019941': [685, 1591], + '00028479': [1944, 2592], + '00018451': [1028, 1028], + '00024067': [1000, 1352], + '00016524': [1704, 2272], + '00048926': [1944, 2592], + '00020992': [1024, 1280], + '00044576': [1024, 1280], + '00031796': [960, 1280], + '00043540': [2448, 3264], + '00049250': [1056, 1408], + '00030602': [2592, 3872], + '00046571': [1118, 1336], + '00024908': [1442, 1012], + '00018903': [3072, 2304], + '00032370': [1944, 2592], + '00043445': [1050, 1680], + '00030791': [2228, 3168], + '00046866': [2057, 3072], + '00047293': [1800, 2400], + '00024853': [1296, 1936], + '00014344': [1125, 1500], + '00041327': [960, 1280], + '00017867': [2592, 3872], + '00037615': [1664, 2496], + '00011247': [1605, 2934], + '00034664': [2304, 1728], + '00013733': [1024, 1280], + '00009125': [1200, 1600], + '00035163': [1654, 1233], + '00017537': [1200, 1600], + '00043423': [1536, 2048], + '00035755': [1154, 900], + '00021712': [1600, 1200], + '00000597': [2792, 1908], + '00033579': [882, 1181], + '00035830': [2112, 2816], + '00005917': [920, 1380], + '00029722': [2736, 3648], + '00039979': [1200, 1600], + '00040854': [1606, 2400], + '00039884': [2848, 4288], + '00003508': [1128, 1488], + '00019862': [1200, 1600], + '00041813': [1226, 1160], + '00007121': [985, 1072], + '00013315': [883, 1199], + '00049822': [922, 1382], + '00027622': [1434, 1680], + '00047689': [1536, 2048], + '00017415': [1491, 2283], + '00023713': [927, 1287], + '00001632': [1200, 1600], + '00033104': [1200, 1600], + '00017643': [1002, 1200], + '00038396': [1330, 1999], + '00027614': [2166, 2048], + '00025962': [1600, 1200], + '00015915': [1067, 1600], + '00008940': [1942, 2744], + '00012468': [2000, 2000], + '00046953': [828, 1442], + '00002084': [1067, 1600], + '00040245': [2657, 1898], + '00023718': [900, 1440], + '00022770': [924, 1280], + '00028957': [960, 1280], + '00001054': [2048, 3072], + '00040541': [1369, 1809], + '00024869': [960, 1280], + '00037655': [900, 1440], + '00037200': [2171, 2575], + '00037390': [1394, 1237], + '00025318': [1054, 1024], + '00021634': [1800, 2400], + '00044217': [1003, 1024], + '00014877': [1200, 1600], + '00029504': [1224, 1632], + '00016422': [960, 1280], + '00028015': [1944, 2592], + '00006235': [967, 1291], + '00045909': [2272, 1704] + } + + def __init__(self, subset=919, **kwargs): + + assert subset in (50, 300, 919), \ + 'ImageNet-S has three subsets, i.e., '\ + 'ImageNet-S50, ImageNet-S300 and ImageNet-S919.' + if subset == 50: + self.CLASSES = self.CLASSES50 + elif subset == 300: + self.CLASSES = self.CLASSES300 + else: + self.CLASSES = self.CLASSES919 + + super(ImageNetSDataset, self).__init__( + img_suffix='.JPEG', + seg_map_suffix='.png', + reduce_zero_label=False, + ignore_index=1000, + **kwargs) + + self.subset = subset + gt_seg_map_loader_cfg = kwargs.get('gt_seg_map_loader_cfg', None) + self.gt_seg_map_loader = LoadImageNetSAnnotations( + ) if gt_seg_map_loader_cfg is None else LoadImageNetSAnnotations( + **gt_seg_map_loader_cfg) + + def pre_eval(self, preds, indices): + """Collect eval result for ImageNet-S. In LoadImageNetSImageFromFile, + the too large images have been downsampled. Here the preds should be + upsampled back after argmax. + + Args: + preds (list[torch.Tensor] | torch.Tensor): the segmentation logit + after argmax, shape (N, H, W). + indices (list[int] | int): the prediction related ground truth + indices. + + Returns: + list[torch.Tensor]: (area_intersect, area_union, area_prediction, + area_ground_truth). + """ + # In order to compat with batch inference + if not isinstance(indices, list): + indices = [indices] + if not isinstance(preds, list): + preds = [preds] + + pre_eval_results = [] + + for pred, index in zip(preds, indices): + seg_map = self.get_gt_seg_map_by_idx(index) + pred = mmcv.imresize( + pred, + size=(seg_map.shape[1], seg_map.shape[0]), + interpolation='nearest') + pre_eval_results.append( + intersect_and_union( + pred, + seg_map, + len(self.CLASSES), + self.ignore_index, + # as the labels has been converted when dataset initialized + # in `get_palette_for_custom_classes ` this `label_map` + # should be `dict()`, see + # https://github.com/open-mmlab/mmsegmentation/issues/1415 + # for more ditails + label_map=dict(), + reduce_zero_label=self.reduce_zero_label)) + + return pre_eval_results + + def results2img(self, results, imgfile_prefix, to_label_id, indices=None): + """Write the segmentation results to images for ImageNetS. The results + should be converted as RGB images due to 919 (>256) categroies. In + LoadImageNetSImageFromFile, the too large images have been downsampled. + Here the results should be upsampled back after argmax. + + Args: + results (list[ndarray]): Testing results of the + dataset. + imgfile_prefix (str): The filename prefix of the png files. + If the prefix is "somepath/xxx", + the png files will be named "somepath/xxx.png". + to_label_id (bool): whether convert output to label_id for + submission. + indices (list[int], optional): Indices of input results, if not + set, all the indices of the dataset will be used. + Default: None. + + Returns: + list[str: str]: result txt files which contains corresponding + semantic segmentation images. + """ + if indices is None: + indices = list(range(len(self))) + + result_files = [] + for result, idx in zip(results, indices): + + filename = self.img_infos[idx]['filename'] + + directory = filename.split('/')[-2] + basename = osp.splitext(osp.basename(filename))[0] + + png_filename = osp.join(imgfile_prefix, directory, + f'{basename}.png') + + # The index range of output is from 0 to 919/300/50. + result_rgb = np.zeros(shape=(result.shape[0], result.shape[1], 3)) + result_rgb[:, :, 0] = result % 256 + result_rgb[:, :, 1] = result // 256 + + if basename.split('_')[2] in self.LARGES.keys(): + result_rgb = mmcv.imresize( + result_rgb, + size=(self.LARGES[basename.split('_')[2]][1], + self.LARGES[basename.split('_')[2]][0]), + interpolation='nearest') + + mmcv.mkdir_or_exist(osp.join(imgfile_prefix, directory)) + output = Image.fromarray(result_rgb.astype(np.uint8)) + output.save(png_filename) + result_files.append(png_filename) + + return result_files + + def format_results(self, + results, + imgfile_prefix, + to_label_id=True, + indices=None): + """Format the results into dir (standard format for ImageNetS + evaluation). + + Args: + results (list): Testing results of the dataset. + imgfile_prefix (str | None): The prefix of images files. It + includes the file path and the prefix of filename, e.g., + "a/b/prefix". + to_label_id (bool): whether convert output to label_id for + submission. Default: False + indices (list[int], optional): Indices of input results, if not + set, all the indices of the dataset will be used. + Default: None. + + Returns: + tuple: (result_files, tmp_dir), result_files is a list containing + the image paths, tmp_dir is the temporal directory created + for saving json/png files when img_prefix is not specified. + """ + + if indices is None: + indices = list(range(len(self))) + + assert isinstance(results, list), 'results must be a list.' + assert isinstance(indices, list), 'indices must be a list.' + + result_files = self.results2img(results, imgfile_prefix, to_label_id, + indices) + + return result_files diff --git a/data_utils/easyportrait/mmseg/datasets/isaid.py b/data_utils/easyportrait/mmseg/datasets/isaid.py new file mode 100644 index 0000000000000000000000000000000000000000..db24f937650634066b69eb2288c506faf6479078 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/isaid.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import mmcv +from mmcv.utils import print_log + +from ..utils import get_root_logger +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class iSAIDDataset(CustomDataset): + """ iSAID: A Large-scale Dataset for Instance Segmentation in Aerial Images + In segmentation map annotation for iSAID dataset, which is included + in 16 categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '_manual1.png'. + """ + + CLASSES = ('background', 'ship', 'store_tank', 'baseball_diamond', + 'tennis_court', 'basketball_court', 'Ground_Track_Field', + 'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter', + 'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane', + 'Harbor') + + PALETTE = [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127], + [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127], + [0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 191, 127], + [0, 127, 191], [0, 127, 255], [0, 100, 155]] + + def __init__(self, **kwargs): + super(iSAIDDataset, self).__init__( + img_suffix='.png', + seg_map_suffix='.png', + ignore_index=255, + **kwargs) + assert self.file_client.exists(self.img_dir) + + def load_annotations(self, + img_dir, + img_suffix, + ann_dir, + seg_map_suffix=None, + split=None): + """Load annotation from directory. + + Args: + img_dir (str): Path to image directory + img_suffix (str): Suffix of images. + ann_dir (str|None): Path to annotation directory. + seg_map_suffix (str|None): Suffix of segmentation maps. + split (str|None): Split txt file. If split is specified, only file + with suffix in the splits will be loaded. Otherwise, all images + in img_dir/ann_dir will be loaded. Default: None + + Returns: + list[dict]: All image info of dataset. + """ + + img_infos = [] + if split is not None: + with open(split) as f: + for line in f: + name = line.strip() + img_info = dict(filename=name + img_suffix) + if ann_dir is not None: + ann_name = name + '_instance_color_RGB' + seg_map = ann_name + seg_map_suffix + img_info['ann'] = dict(seg_map=seg_map) + img_infos.append(img_info) + else: + for img in mmcv.scandir(img_dir, img_suffix, recursive=True): + img_info = dict(filename=img) + if ann_dir is not None: + seg_img = img + seg_map = seg_img.replace( + img_suffix, '_instance_color_RGB' + seg_map_suffix) + img_info['ann'] = dict(seg_map=seg_map) + img_infos.append(img_info) + + print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger()) + return img_infos diff --git a/data_utils/easyportrait/mmseg/datasets/isprs.py b/data_utils/easyportrait/mmseg/datasets/isprs.py new file mode 100644 index 0000000000000000000000000000000000000000..5f23e1a9b61942361ed0811052700f85e5d0a1a0 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/isprs.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class ISPRSDataset(CustomDataset): + """ISPRS dataset. + + In segmentation map annotation for LoveDA, 0 is the ignore index. + ``reduce_zero_label`` should be set to True. The ``img_suffix`` and + ``seg_map_suffix`` are both fixed to '.png'. + """ + CLASSES = ('impervious_surface', 'building', 'low_vegetation', 'tree', + 'car', 'clutter') + + PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], + [255, 255, 0], [255, 0, 0]] + + def __init__(self, **kwargs): + super(ISPRSDataset, self).__init__( + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=True, + **kwargs) diff --git a/data_utils/easyportrait/mmseg/datasets/lapa.py b/data_utils/easyportrait/mmseg/datasets/lapa.py new file mode 100644 index 0000000000000000000000000000000000000000..614503561efd1b8ae2db0e96da27fbffad1ad012 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/lapa.py @@ -0,0 +1,36 @@ +import os.path as osp + +import mmcv +import numpy as np +from PIL import Image + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class LaPaDataset(CustomDataset): + """EasyPortrait dataset. + + In segmentation map annotation for LaPa, 0 stands for background, + which is included in 11 categories. ``reduce_zero_label`` is fixed to False. + The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to + '.png'. + """ + + CLASSES = ('background', 'skin', 'left eyebrow', + 'right eyebrow', 'left eye', 'right eye', + 'nose', 'upper lip', 'inner mouth', 'lower lip', 'hair') + + PALETTE = [[0, 0, 0], [0, 153, 255], [102, 255, 153], + [0, 204, 153], [255, 255, 102], [255, 255, 204], + [255, 153, 0], [255, 102, 255], [102, 0, 51], + [255, 204, 255], [255, 0, 102]] + + def __init__(self, **kwargs): + super(EasyPortraitDataset, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) + assert self.file_client.exists(self.img_dir) \ No newline at end of file diff --git a/data_utils/easyportrait/mmseg/datasets/loveda.py b/data_utils/easyportrait/mmseg/datasets/loveda.py new file mode 100644 index 0000000000000000000000000000000000000000..90d654f625a0e5ea381c0fdb1ebb4cc6c921498f --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/loveda.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +import mmcv +import numpy as np +from PIL import Image + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class LoveDADataset(CustomDataset): + """LoveDA dataset. + + In segmentation map annotation for LoveDA, 0 is the ignore index. + ``reduce_zero_label`` should be set to True. The ``img_suffix`` and + ``seg_map_suffix`` are both fixed to '.png'. + """ + CLASSES = ('background', 'building', 'road', 'water', 'barren', 'forest', + 'agricultural') + + PALETTE = [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255], + [159, 129, 183], [0, 255, 0], [255, 195, 128]] + + def __init__(self, **kwargs): + super(LoveDADataset, self).__init__( + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=True, + **kwargs) + + def results2img(self, results, imgfile_prefix, indices=None): + """Write the segmentation results to images. + + Args: + results (list[ndarray]): Testing results of the + dataset. + imgfile_prefix (str): The filename prefix of the png files. + If the prefix is "somepath/xxx", + the png files will be named "somepath/xxx.png". + indices (list[int], optional): Indices of input results, if not + set, all the indices of the dataset will be used. + Default: None. + + Returns: + list[str: str]: result txt files which contains corresponding + semantic segmentation images. + """ + + mmcv.mkdir_or_exist(imgfile_prefix) + result_files = [] + for result, idx in zip(results, indices): + + filename = self.img_infos[idx]['filename'] + basename = osp.splitext(osp.basename(filename))[0] + + png_filename = osp.join(imgfile_prefix, f'{basename}.png') + + # The index range of official requirement is from 0 to 6. + output = Image.fromarray(result.astype(np.uint8)) + output.save(png_filename) + result_files.append(png_filename) + + return result_files + + def format_results(self, results, imgfile_prefix, indices=None): + """Format the results into dir (standard format for LoveDA evaluation). + + Args: + results (list): Testing results of the dataset. + imgfile_prefix (str): The prefix of images files. It + includes the file path and the prefix of filename, e.g., + "a/b/prefix". + indices (list[int], optional): Indices of input results, + if not set, all the indices of the dataset will be used. + Default: None. + + Returns: + tuple: (result_files, tmp_dir), result_files is a list containing + the image paths, tmp_dir is the temporal directory created + for saving json/png files when img_prefix is not specified. + """ + if indices is None: + indices = list(range(len(self))) + + assert isinstance(results, list), 'results must be a list.' + assert isinstance(indices, list), 'indices must be a list.' + + result_files = self.results2img(results, imgfile_prefix, indices) + + return result_files diff --git a/data_utils/easyportrait/mmseg/datasets/night_driving.py b/data_utils/easyportrait/mmseg/datasets/night_driving.py new file mode 100644 index 0000000000000000000000000000000000000000..6620586e3f11f56690982568c7761dfa7b3dbf50 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/night_driving.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .builder import DATASETS +from .cityscapes import CityscapesDataset + + +@DATASETS.register_module() +class NightDrivingDataset(CityscapesDataset): + """NightDrivingDataset dataset.""" + + def __init__(self, **kwargs): + super().__init__( + img_suffix='_leftImg8bit.png', + seg_map_suffix='_gtCoarse_labelTrainIds.png', + **kwargs) diff --git a/data_utils/easyportrait/mmseg/datasets/pascal_context.py b/data_utils/easyportrait/mmseg/datasets/pascal_context.py new file mode 100644 index 0000000000000000000000000000000000000000..20285d8f5af2839e6b669c5acfc648bd072ed9ee --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/pascal_context.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class PascalContextDataset(CustomDataset): + """PascalContext dataset. + + In segmentation map annotation for PascalContext, 0 stands for background, + which is included in 60 categories. ``reduce_zero_label`` is fixed to + False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is + fixed to '.png'. + + Args: + split (str): Split txt file for PascalContext. + """ + + CLASSES = ('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench', + 'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus', + 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', + 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence', + 'floor', 'flower', 'food', 'grass', 'ground', 'horse', + 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person', + 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep', + 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', + 'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water', + 'window', 'wood') + + PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]] + + def __init__(self, split, **kwargs): + super(PascalContextDataset, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + split=split, + reduce_zero_label=False, + **kwargs) + assert self.file_client.exists(self.img_dir) and self.split is not None + + +@DATASETS.register_module() +class PascalContextDataset59(CustomDataset): + """PascalContext dataset. + + In segmentation map annotation for PascalContext59, background is not + included in 59 categories. ``reduce_zero_label`` is fixed to True. + The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed + to '.png'. + + Args: + split (str): Split txt file for PascalContext. + """ + + CLASSES = ('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', + 'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet', + 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow', + 'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower', + 'food', 'grass', 'ground', 'horse', 'keyboard', 'light', + 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform', + 'pottedplant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk', + 'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train', + 'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', 'wood') + + PALETTE = [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], + [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230], + [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], + [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140], + [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200], + [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71], + [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], + [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6], + [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], + [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8], + [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], + [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140], + [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0], + [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0], + [0, 235, 255], [0, 173, 255], [31, 0, 255]] + + def __init__(self, split, **kwargs): + super(PascalContextDataset59, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + split=split, + reduce_zero_label=True, + **kwargs) + assert self.file_client.exists(self.img_dir) and self.split is not None diff --git a/data_utils/easyportrait/mmseg/datasets/pipelines/__init__.py b/data_utils/easyportrait/mmseg/datasets/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8256a6fe2f03a381ee62a0271411d5102caf8c43 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/pipelines/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .compose import Compose +from .formatting import (Collect, ImageToTensor, ToDataContainer, ToTensor, + Transpose, to_tensor) +from .loading import LoadAnnotations, LoadImageFromFile +from .test_time_aug import MultiScaleFlipAug +from .transforms import (CLAHE, AdjustGamma, Normalize, Pad, + PhotoMetricDistortion, RandomCrop, RandomCutOut, + RandomFlip, RandomMosaic, RandomRotate, Rerange, + Resize, RGB2Gray, SegRescale) + +__all__ = [ + 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', + 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', + 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', + 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', + 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut', + 'RandomMosaic' +] diff --git a/data_utils/easyportrait/mmseg/datasets/pipelines/compose.py b/data_utils/easyportrait/mmseg/datasets/pipelines/compose.py new file mode 100644 index 0000000000000000000000000000000000000000..30280c1332abc253434ae4e88271d73de2690ecb --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/pipelines/compose.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import collections + +from mmcv.utils import build_from_cfg + +from ..builder import PIPELINES + + +@PIPELINES.register_module() +class Compose(object): + """Compose multiple transforms sequentially. + + Args: + transforms (Sequence[dict | callable]): Sequence of transform object or + config dict to be composed. + """ + + def __init__(self, transforms): + assert isinstance(transforms, collections.abc.Sequence) + self.transforms = [] + for transform in transforms: + if isinstance(transform, dict): + transform = build_from_cfg(transform, PIPELINES) + self.transforms.append(transform) + elif callable(transform): + self.transforms.append(transform) + else: + raise TypeError('transform must be callable or a dict') + + def __call__(self, data): + """Call function to apply transforms sequentially. + + Args: + data (dict): A result dict contains the data to transform. + + Returns: + dict: Transformed data. + """ + + for t in self.transforms: + data = t(data) + if data is None: + return None + return data + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += f' {t}' + format_string += '\n)' + return format_string diff --git a/data_utils/easyportrait/mmseg/datasets/pipelines/formating.py b/data_utils/easyportrait/mmseg/datasets/pipelines/formating.py new file mode 100644 index 0000000000000000000000000000000000000000..f6e53bfebe3e76412600361da01c36cb440bafd8 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/pipelines/formating.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# flake8: noqa +import warnings + +from .formatting import * + +warnings.warn('DeprecationWarning: mmseg.datasets.pipelines.formating will be ' + 'deprecated in 2021, please replace it with ' + 'mmseg.datasets.pipelines.formatting.') diff --git a/data_utils/easyportrait/mmseg/datasets/pipelines/formatting.py b/data_utils/easyportrait/mmseg/datasets/pipelines/formatting.py new file mode 100644 index 0000000000000000000000000000000000000000..049652e684f4009fd079e2f75cf78bbacd130276 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/pipelines/formatting.py @@ -0,0 +1,289 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections.abc import Sequence + +import mmcv +import numpy as np +import torch +from mmcv.parallel import DataContainer as DC + +from ..builder import PIPELINES + + +def to_tensor(data): + """Convert objects of various python types to :obj:`torch.Tensor`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + + Args: + data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to + be converted. + """ + + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not mmcv.is_str(data): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError(f'type {type(data)} cannot be converted to tensor.') + + +@PIPELINES.register_module() +class ToTensor(object): + """Convert some results to :obj:`torch.Tensor` by given keys. + + Args: + keys (Sequence[str]): Keys that need to be converted to Tensor. + """ + + def __init__(self, keys): + self.keys = keys + + def __call__(self, results): + """Call function to convert data in results to :obj:`torch.Tensor`. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data converted + to :obj:`torch.Tensor`. + """ + + for key in self.keys: + results[key] = to_tensor(results[key]) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' + + +@PIPELINES.register_module() +class ImageToTensor(object): + """Convert image to :obj:`torch.Tensor` by given keys. + + The dimension order of input image is (H, W, C). The pipeline will convert + it to (C, H, W). If only 2 dimension (H, W) is given, the output would be + (1, H, W). + + Args: + keys (Sequence[str]): Key of images to be converted to Tensor. + """ + + def __init__(self, keys): + self.keys = keys + + def __call__(self, results): + """Call function to convert image in results to :obj:`torch.Tensor` and + transpose the channel order. + + Args: + results (dict): Result dict contains the image data to convert. + + Returns: + dict: The result dict contains the image converted + to :obj:`torch.Tensor` and transposed to (C, H, W) order. + """ + + for key in self.keys: + img = results[key] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + results[key] = to_tensor(img.transpose(2, 0, 1)) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' + + +@PIPELINES.register_module() +class Transpose(object): + """Transpose some results by given keys. + + Args: + keys (Sequence[str]): Keys of results to be transposed. + order (Sequence[int]): Order of transpose. + """ + + def __init__(self, keys, order): + self.keys = keys + self.order = order + + def __call__(self, results): + """Call function to convert image in results to :obj:`torch.Tensor` and + transpose the channel order. + + Args: + results (dict): Result dict contains the image data to convert. + + Returns: + dict: The result dict contains the image converted + to :obj:`torch.Tensor` and transposed to (C, H, W) order. + """ + + for key in self.keys: + results[key] = results[key].transpose(self.order) + return results + + def __repr__(self): + return self.__class__.__name__ + \ + f'(keys={self.keys}, order={self.order})' + + +@PIPELINES.register_module() +class ToDataContainer(object): + """Convert results to :obj:`mmcv.DataContainer` by given fields. + + Args: + fields (Sequence[dict]): Each field is a dict like + ``dict(key='xxx', **kwargs)``. The ``key`` in result will + be converted to :obj:`mmcv.DataContainer` with ``**kwargs``. + Default: ``(dict(key='img', stack=True), + dict(key='gt_semantic_seg'))``. + """ + + def __init__(self, + fields=(dict(key='img', + stack=True), dict(key='gt_semantic_seg'))): + self.fields = fields + + def __call__(self, results): + """Call function to convert data in results to + :obj:`mmcv.DataContainer`. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data converted to + :obj:`mmcv.DataContainer`. + """ + + for field in self.fields: + field = field.copy() + key = field.pop('key') + results[key] = DC(results[key], **field) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(fields={self.fields})' + + +@PIPELINES.register_module() +class DefaultFormatBundle(object): + """Default formatting bundle. + + It simplifies the pipeline of formatting common fields, including "img" + and "gt_semantic_seg". These fields are formatted as follows. + + - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True) + - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, + (3)to DataContainer (stack=True) + """ + + def __call__(self, results): + """Call function to transform and format common fields in results. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data that is formatted with + default bundle. + """ + + if 'img' in results: + img = results['img'] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + img = np.ascontiguousarray(img.transpose(2, 0, 1)) + results['img'] = DC(to_tensor(img), stack=True) + if 'gt_semantic_seg' in results: + # convert to long + results['gt_semantic_seg'] = DC( + to_tensor(results['gt_semantic_seg'][None, + ...].astype(np.int64)), + stack=True) + return results + + def __repr__(self): + return self.__class__.__name__ + + +@PIPELINES.register_module() +class Collect(object): + """Collect data from the loader relevant to the specific task. + + This is usually the last stage of the data loader pipeline. Typically keys + is set to some subset of "img", "gt_semantic_seg". + + The "img_meta" item is always populated. The contents of the "img_meta" + dictionary depends on "meta_keys". By default this includes: + + - "img_shape": shape of the image input to the network as a tuple + (h, w, c). Note that images may be zero padded on the bottom/right + if the batch tensor is larger than this shape. + + - "scale_factor": a float indicating the preprocessing scale + + - "flip": a boolean indicating if image flip transform was used + + - "filename": path to the image file + + - "ori_shape": original shape of the image as a tuple (h, w, c) + + - "pad_shape": image shape after padding + + - "img_norm_cfg": a dict of normalization information: + - mean - per channel mean subtraction + - std - per channel std divisor + - to_rgb - bool indicating if bgr was converted to rgb + + Args: + keys (Sequence[str]): Keys of results to be collected in ``data``. + meta_keys (Sequence[str], optional): Meta keys to be converted to + ``mmcv.DataContainer`` and collected in ``data[img_metas]``. + Default: (``filename``, ``ori_filename``, ``ori_shape``, + ``img_shape``, ``pad_shape``, ``scale_factor``, ``flip``, + ``flip_direction``, ``img_norm_cfg``) + """ + + def __init__(self, + keys, + meta_keys=('filename', 'ori_filename', 'ori_shape', + 'img_shape', 'pad_shape', 'scale_factor', 'flip', + 'flip_direction', 'img_norm_cfg')): + self.keys = keys + self.meta_keys = meta_keys + + def __call__(self, results): + """Call function to collect keys in results. The keys in ``meta_keys`` + will be converted to :obj:mmcv.DataContainer. + + Args: + results (dict): Result dict contains the data to collect. + + Returns: + dict: The result dict contains the following keys + - keys in``self.keys`` + - ``img_metas`` + """ + + data = {} + img_meta = {} + for key in self.meta_keys: + img_meta[key] = results.get(key, None) + data['img_metas'] = DC(img_meta, cpu_only=True) + for key in self.keys: + data[key] = results[key] + return data + + def __repr__(self): + return self.__class__.__name__ + \ + f'(keys={self.keys}, meta_keys={self.meta_keys})' diff --git a/data_utils/easyportrait/mmseg/datasets/pipelines/loading.py b/data_utils/easyportrait/mmseg/datasets/pipelines/loading.py new file mode 100644 index 0000000000000000000000000000000000000000..6ccaf7da85d86f5b7f703a4cd468f5454d9822a8 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/pipelines/loading.py @@ -0,0 +1,158 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +import mmcv +import numpy as np + +from ..builder import PIPELINES + + +@PIPELINES.register_module() +class LoadImageFromFile(object): + """Load an image from file. + + Required keys are "img_prefix" and "img_info" (a dict that must contain the + key "filename"). Added or updated keys are "filename", "img", "img_shape", + "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), + "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + color_type (str): The flag argument for :func:`mmcv.imfrombytes`. + Defaults to 'color'. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default: + 'cv2' + """ + + def __init__(self, + to_float32=False, + color_type='color', + file_client_args=dict(backend='disk'), + imdecode_backend='cv2'): + self.to_float32 = to_float32 + self.color_type = color_type + self.file_client_args = file_client_args.copy() + self.file_client = None + self.imdecode_backend = imdecode_backend + + def __call__(self, results): + """Call functions to load image and get image meta information. + + Args: + results (dict): Result dict from :obj:`mmseg.CustomDataset`. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + if self.file_client is None: + self.file_client = mmcv.FileClient(**self.file_client_args) + + if results.get('img_prefix') is not None: + filename = osp.join(results['img_prefix'], + results['img_info']['filename']) + else: + filename = results['img_info']['filename'] + img_bytes = self.file_client.get(filename) + img = mmcv.imfrombytes( + img_bytes, flag=self.color_type, backend=self.imdecode_backend) + if self.to_float32: + img = img.astype(np.float32) + + results['filename'] = filename + results['ori_filename'] = results['img_info']['filename'] + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + # Set initial values for default meta_keys + results['pad_shape'] = img.shape + results['scale_factor'] = 1.0 + num_channels = 1 if len(img.shape) < 3 else img.shape[2] + results['img_norm_cfg'] = dict( + mean=np.zeros(num_channels, dtype=np.float32), + std=np.ones(num_channels, dtype=np.float32), + to_rgb=False) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(to_float32={self.to_float32},' + repr_str += f"color_type='{self.color_type}'," + repr_str += f"imdecode_backend='{self.imdecode_backend}')" + return repr_str + + +@PIPELINES.register_module() +class LoadAnnotations(object): + """Load annotations for semantic segmentation. + + Args: + reduce_zero_label (bool): Whether reduce all label value by 1. + Usually used for datasets where 0 is background label. + Default: False. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default: + 'pillow' + """ + + def __init__(self, + reduce_zero_label=False, + file_client_args=dict(backend='disk'), + imdecode_backend='pillow'): + self.reduce_zero_label = reduce_zero_label + self.file_client_args = file_client_args.copy() + self.file_client = None + self.imdecode_backend = imdecode_backend + + def __call__(self, results): + """Call function to load multiple types annotations. + + Args: + results (dict): Result dict from :obj:`mmseg.CustomDataset`. + + Returns: + dict: The dict contains loaded semantic segmentation annotations. + """ + + if self.file_client is None: + self.file_client = mmcv.FileClient(**self.file_client_args) + + if results.get('seg_prefix', None) is not None: + filename = osp.join(results['seg_prefix'], + results['ann_info']['seg_map']) + else: + filename = results['ann_info']['seg_map'] + img_bytes = self.file_client.get(filename) + gt_semantic_seg = mmcv.imfrombytes( + img_bytes, flag='unchanged', + backend=self.imdecode_backend).squeeze().astype(np.uint8) + # reduce zero_label + if self.reduce_zero_label: + # avoid using underflow conversion + gt_semantic_seg[gt_semantic_seg == 0] = 255 + gt_semantic_seg = gt_semantic_seg - 1 + gt_semantic_seg[gt_semantic_seg == 254] = 255 + # modify if custom classes + if results.get('label_map', None) is not None: + # Add deep copy to solve bug of repeatedly + # replace `gt_semantic_seg`, which is reported in + # https://github.com/open-mmlab/mmsegmentation/pull/1445/ + gt_semantic_seg_copy = gt_semantic_seg.copy() + for old_id, new_id in results['label_map'].items(): + gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id + results['gt_semantic_seg'] = gt_semantic_seg + results['seg_fields'].append('gt_semantic_seg') + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(reduce_zero_label={self.reduce_zero_label},' + repr_str += f"imdecode_backend='{self.imdecode_backend}')" + return repr_str diff --git a/data_utils/easyportrait/mmseg/datasets/pipelines/test_time_aug.py b/data_utils/easyportrait/mmseg/datasets/pipelines/test_time_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..49640879f24167ef7cc9aa63ce880b2da81742df --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/pipelines/test_time_aug.py @@ -0,0 +1,142 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import mmcv + +from ..builder import PIPELINES +from .compose import Compose + + +@PIPELINES.register_module() +class MultiScaleFlipAug(object): + """Test-time augmentation with multiple scales and flipping. + + An example configuration is as followed: + + .. code-block:: + + img_scale=(2048, 1024), + img_ratios=[0.5, 1.0], + flip=True, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ] + + After MultiScaleFLipAug with above configuration, the results are wrapped + into lists of the same length as followed: + + .. code-block:: + + dict( + img=[...], + img_shape=[...], + scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)] + flip=[False, True, False, True] + ... + ) + + Args: + transforms (list[dict]): Transforms to apply in each augmentation. + img_scale (None | tuple | list[tuple]): Images scales for resizing. + img_ratios (float | list[float]): Image ratios for resizing + flip (bool): Whether apply flip augmentation. Default: False. + flip_direction (str | list[str]): Flip augmentation directions, + options are "horizontal" and "vertical". If flip_direction is list, + multiple flip augmentations will be applied. + It has no effect when flip == False. Default: "horizontal". + """ + + def __init__(self, + transforms, + img_scale, + img_ratios=None, + flip=False, + flip_direction='horizontal'): + if flip: + trans_index = { + key['type']: index + for index, key in enumerate(transforms) + } + if 'RandomFlip' in trans_index and 'Pad' in trans_index: + assert trans_index['RandomFlip'] < trans_index['Pad'], \ + 'Pad must be executed after RandomFlip when flip is True' + self.transforms = Compose(transforms) + if img_ratios is not None: + img_ratios = img_ratios if isinstance(img_ratios, + list) else [img_ratios] + assert mmcv.is_list_of(img_ratios, float) + if img_scale is None: + # mode 1: given img_scale=None and a range of image ratio + self.img_scale = None + assert mmcv.is_list_of(img_ratios, float) + elif isinstance(img_scale, tuple) and mmcv.is_list_of( + img_ratios, float): + assert len(img_scale) == 2 + # mode 2: given a scale and a range of image ratio + self.img_scale = [(int(img_scale[0] * ratio), + int(img_scale[1] * ratio)) + for ratio in img_ratios] + else: + # mode 3: given multiple scales + self.img_scale = img_scale if isinstance(img_scale, + list) else [img_scale] + assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None + self.flip = flip + self.img_ratios = img_ratios + self.flip_direction = flip_direction if isinstance( + flip_direction, list) else [flip_direction] + assert mmcv.is_list_of(self.flip_direction, str) + if not self.flip and self.flip_direction != ['horizontal']: + warnings.warn( + 'flip_direction has no effect when flip is set to False') + if (self.flip + and not any([t['type'] == 'RandomFlip' for t in transforms])): + warnings.warn( + 'flip has no effect when RandomFlip is not in transforms') + + def __call__(self, results): + """Call function to apply test time augment transforms on results. + + Args: + results (dict): Result dict contains the data to transform. + + Returns: + dict[str: list]: The augmented data, where each value is wrapped + into a list. + """ + + aug_data = [] + if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float): + h, w = results['img'].shape[:2] + img_scale = [(int(w * ratio), int(h * ratio)) + for ratio in self.img_ratios] + else: + img_scale = self.img_scale + flip_aug = [False, True] if self.flip else [False] + for scale in img_scale: + for flip in flip_aug: + for direction in self.flip_direction: + _results = results.copy() + _results['scale'] = scale + _results['flip'] = flip + _results['flip_direction'] = direction + data = self.transforms(_results) + aug_data.append(data) + # list of dict to dict of list + aug_data_dict = {key: [] for key in aug_data[0]} + for data in aug_data: + for key, val in data.items(): + aug_data_dict[key].append(val) + return aug_data_dict + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(transforms={self.transforms}, ' + repr_str += f'img_scale={self.img_scale}, flip={self.flip})' + repr_str += f'flip_direction={self.flip_direction}' + return repr_str diff --git a/data_utils/easyportrait/mmseg/datasets/pipelines/transforms.py b/data_utils/easyportrait/mmseg/datasets/pipelines/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..5673b646fa654bcba39ea897d37a4e7371b5c77f --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/pipelines/transforms.py @@ -0,0 +1,1335 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import mmcv +import numpy as np +from mmcv.utils import deprecated_api_warning, is_tuple_of +from numpy import random + +from ..builder import PIPELINES + + +@PIPELINES.register_module() +class ResizeToMultiple(object): + """Resize images & seg to multiple of divisor. + + Args: + size_divisor (int): images and gt seg maps need to resize to multiple + of size_divisor. Default: 32. + interpolation (str, optional): The interpolation mode of image resize. + Default: None + """ + + def __init__(self, size_divisor=32, interpolation=None): + self.size_divisor = size_divisor + self.interpolation = interpolation + + def __call__(self, results): + """Call function to resize images, semantic segmentation map to + multiple of size divisor. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img_shape', 'pad_shape' keys are updated. + """ + # Align image to multiple of size divisor. + img = results['img'] + img = mmcv.imresize_to_multiple( + img, + self.size_divisor, + scale_factor=1, + interpolation=self.interpolation + if self.interpolation else 'bilinear') + + results['img'] = img + results['img_shape'] = img.shape + results['pad_shape'] = img.shape + + # Align segmentation map to multiple of size divisor. + for key in results.get('seg_fields', []): + gt_seg = results[key] + gt_seg = mmcv.imresize_to_multiple( + gt_seg, + self.size_divisor, + scale_factor=1, + interpolation='nearest') + results[key] = gt_seg + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(size_divisor={self.size_divisor}, ' + f'interpolation={self.interpolation})') + return repr_str + + +@PIPELINES.register_module() +class Resize(object): + """Resize images & seg. + + This transform resizes the input image to some scale. If the input dict + contains the key "scale", then the scale in the input dict is used, + otherwise the specified scale in the init method is used. + + ``img_scale`` can be None, a tuple (single-scale) or a list of tuple + (multi-scale). There are 4 multiscale modes: + + - ``ratio_range is not None``: + 1. When img_scale is None, img_scale is the shape of image in results + (img_scale = results['img'].shape[:2]) and the image is resized based + on the original size. (mode 1) + 2. When img_scale is a tuple (single-scale), randomly sample a ratio from + the ratio range and multiply it with the image scale. (mode 2) + + - ``ratio_range is None and multiscale_mode == "range"``: randomly sample a + scale from the a range. (mode 3) + + - ``ratio_range is None and multiscale_mode == "value"``: randomly sample a + scale from multiple scales. (mode 4) + + Args: + img_scale (tuple or list[tuple]): Images scales for resizing. + Default:None. + multiscale_mode (str): Either "range" or "value". + Default: 'range' + ratio_range (tuple[float]): (min_ratio, max_ratio). + Default: None + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. Default: True + min_size (int, optional): The minimum size for input and the shape + of the image and seg map will not be less than ``min_size``. + As the shape of model input is fixed like 'SETR' and 'BEiT'. + Following the setting in these models, resized images must be + bigger than the crop size in ``slide_inference``. Default: None + """ + + def __init__(self, + img_scale=None, + multiscale_mode='range', + ratio_range=None, + keep_ratio=True, + min_size=None): + if img_scale is None: + self.img_scale = None + else: + if isinstance(img_scale, list): + self.img_scale = img_scale + else: + self.img_scale = [img_scale] + assert mmcv.is_list_of(self.img_scale, tuple) + + if ratio_range is not None: + # mode 1: given img_scale=None and a range of image ratio + # mode 2: given a scale and a range of image ratio + assert self.img_scale is None or len(self.img_scale) == 1 + else: + # mode 3 and 4: given multiple scales or a range of scales + assert multiscale_mode in ['value', 'range'] + + self.multiscale_mode = multiscale_mode + self.ratio_range = ratio_range + self.keep_ratio = keep_ratio + self.min_size = min_size + + @staticmethod + def random_select(img_scales): + """Randomly select an img_scale from given candidates. + + Args: + img_scales (list[tuple]): Images scales for selection. + + Returns: + (tuple, int): Returns a tuple ``(img_scale, scale_dix)``, + where ``img_scale`` is the selected image scale and + ``scale_idx`` is the selected index in the given candidates. + """ + + assert mmcv.is_list_of(img_scales, tuple) + scale_idx = np.random.randint(len(img_scales)) + img_scale = img_scales[scale_idx] + return img_scale, scale_idx + + @staticmethod + def random_sample(img_scales): + """Randomly sample an img_scale when ``multiscale_mode=='range'``. + + Args: + img_scales (list[tuple]): Images scale range for sampling. + There must be two tuples in img_scales, which specify the lower + and upper bound of image scales. + + Returns: + (tuple, None): Returns a tuple ``(img_scale, None)``, where + ``img_scale`` is sampled scale and None is just a placeholder + to be consistent with :func:`random_select`. + """ + + assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 + img_scale_long = [max(s) for s in img_scales] + img_scale_short = [min(s) for s in img_scales] + long_edge = np.random.randint( + min(img_scale_long), + max(img_scale_long) + 1) + short_edge = np.random.randint( + min(img_scale_short), + max(img_scale_short) + 1) + img_scale = (long_edge, short_edge) + return img_scale, None + + @staticmethod + def random_sample_ratio(img_scale, ratio_range): + """Randomly sample an img_scale when ``ratio_range`` is specified. + + A ratio will be randomly sampled from the range specified by + ``ratio_range``. Then it would be multiplied with ``img_scale`` to + generate sampled scale. + + Args: + img_scale (tuple): Images scale base to multiply with ratio. + ratio_range (tuple[float]): The minimum and maximum ratio to scale + the ``img_scale``. + + Returns: + (tuple, None): Returns a tuple ``(scale, None)``, where + ``scale`` is sampled ratio multiplied with ``img_scale`` and + None is just a placeholder to be consistent with + :func:`random_select`. + """ + + assert isinstance(img_scale, tuple) and len(img_scale) == 2 + min_ratio, max_ratio = ratio_range + assert min_ratio <= max_ratio + ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio + scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) + return scale, None + + def _random_scale(self, results): + """Randomly sample an img_scale according to ``ratio_range`` and + ``multiscale_mode``. + + If ``ratio_range`` is specified, a ratio will be sampled and be + multiplied with ``img_scale``. + If multiple scales are specified by ``img_scale``, a scale will be + sampled according to ``multiscale_mode``. + Otherwise, single scale will be used. + + Args: + results (dict): Result dict from :obj:`dataset`. + + Returns: + dict: Two new keys 'scale` and 'scale_idx` are added into + ``results``, which would be used by subsequent pipelines. + """ + + if self.ratio_range is not None: + if self.img_scale is None: + h, w = results['img'].shape[:2] + scale, scale_idx = self.random_sample_ratio((w, h), + self.ratio_range) + else: + scale, scale_idx = self.random_sample_ratio( + self.img_scale[0], self.ratio_range) + elif len(self.img_scale) == 1: + scale, scale_idx = self.img_scale[0], 0 + elif self.multiscale_mode == 'range': + scale, scale_idx = self.random_sample(self.img_scale) + elif self.multiscale_mode == 'value': + scale, scale_idx = self.random_select(self.img_scale) + else: + raise NotImplementedError + + results['scale'] = scale + results['scale_idx'] = scale_idx + + def _resize_img(self, results): + """Resize images with ``results['scale']``.""" + if self.keep_ratio: + if self.min_size is not None: + # TODO: Now 'min_size' is an 'int' which means the minimum + # shape of images is (min_size, min_size, 3). 'min_size' + # with tuple type will be supported, i.e. the width and + # height are not equal. + if min(results['scale']) < self.min_size: + new_short = self.min_size + else: + new_short = min(results['scale']) + + h, w = results['img'].shape[:2] + if h > w: + new_h, new_w = new_short * h / w, new_short + else: + new_h, new_w = new_short, new_short * w / h + results['scale'] = (new_h, new_w) + + img, scale_factor = mmcv.imrescale( + results['img'], results['scale'], return_scale=True) + # the w_scale and h_scale has minor difference + # a real fix should be done in the mmcv.imrescale in the future + new_h, new_w = img.shape[:2] + h, w = results['img'].shape[:2] + w_scale = new_w / w + h_scale = new_h / h + else: + img, w_scale, h_scale = mmcv.imresize( + results['img'], results['scale'], return_scale=True) + scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], + dtype=np.float32) + results['img'] = img + results['img_shape'] = img.shape + results['pad_shape'] = img.shape # in case that there is no padding + results['scale_factor'] = scale_factor + results['keep_ratio'] = self.keep_ratio + + def _resize_seg(self, results): + """Resize semantic segmentation map with ``results['scale']``.""" + for key in results.get('seg_fields', []): + if self.keep_ratio: + gt_seg = mmcv.imrescale( + results[key], results['scale'], interpolation='nearest') + else: + gt_seg = mmcv.imresize( + results[key], results['scale'], interpolation='nearest') + results[key] = gt_seg + + def __call__(self, results): + """Call function to resize images, bounding boxes, masks, semantic + segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', + 'keep_ratio' keys are added into result dict. + """ + + if 'scale' not in results: + self._random_scale(results) + self._resize_img(results) + self._resize_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(img_scale={self.img_scale}, ' + f'multiscale_mode={self.multiscale_mode}, ' + f'ratio_range={self.ratio_range}, ' + f'keep_ratio={self.keep_ratio})') + return repr_str + + +@PIPELINES.register_module() +class RandomFlip(object): + """Flip the image & seg. + + If the input dict contains the key "flip", then the flag will be used, + otherwise it will be randomly decided by a ratio specified in the init + method. + + Args: + prob (float, optional): The flipping probability. Default: None. + direction(str, optional): The flipping direction. Options are + 'horizontal' and 'vertical'. Default: 'horizontal'. + """ + + @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip') + def __init__(self, prob=None, direction='horizontal'): + self.prob = prob + self.direction = direction + if prob is not None: + assert prob >= 0 and prob <= 1 + assert direction in ['horizontal', 'vertical'] + + def __call__(self, results): + """Call function to flip bounding boxes, masks, semantic segmentation + maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Flipped results, 'flip', 'flip_direction' keys are added into + result dict. + """ + + if 'flip' not in results: + flip = True if np.random.rand() < self.prob else False + results['flip'] = flip + if 'flip_direction' not in results: + results['flip_direction'] = self.direction + if results['flip']: + # flip image + results['img'] = mmcv.imflip( + results['img'], direction=results['flip_direction']) + + # flip segs + for key in results.get('seg_fields', []): + # use copy() to make numpy stride positive + results[key] = mmcv.imflip( + results[key], direction=results['flip_direction']).copy() + return results + + def __repr__(self): + return self.__class__.__name__ + f'(prob={self.prob})' + + +@PIPELINES.register_module() +class Pad(object): + """Pad the image & mask. + + There are two padding modes: (1) pad to a fixed size and (2) pad to the + minimum size that is divisible by some number. + Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor", + + Args: + size (tuple, optional): Fixed padding size. + size_divisor (int, optional): The divisor of padded size. + pad_val (float, optional): Padding value. Default: 0. + seg_pad_val (float, optional): Padding value of segmentation map. + Default: 255. + """ + + def __init__(self, + size=None, + size_divisor=None, + pad_val=0, + seg_pad_val=255): + self.size = size + self.size_divisor = size_divisor + self.pad_val = pad_val + self.seg_pad_val = seg_pad_val + # only one of size and size_divisor should be valid + assert size is not None or size_divisor is not None + assert size is None or size_divisor is None + + def _pad_img(self, results): + """Pad images according to ``self.size``.""" + if self.size is not None: + padded_img = mmcv.impad( + results['img'], shape=self.size, pad_val=self.pad_val) + elif self.size_divisor is not None: + padded_img = mmcv.impad_to_multiple( + results['img'], self.size_divisor, pad_val=self.pad_val) + results['img'] = padded_img + results['pad_shape'] = padded_img.shape + results['pad_fixed_size'] = self.size + results['pad_size_divisor'] = self.size_divisor + + def _pad_seg(self, results): + """Pad masks according to ``results['pad_shape']``.""" + for key in results.get('seg_fields', []): + results[key] = mmcv.impad( + results[key], + shape=results['pad_shape'][:2], + pad_val=self.seg_pad_val) + + def __call__(self, results): + """Call function to pad images, masks, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Updated result dict. + """ + + self._pad_img(results) + self._pad_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(size={self.size}, size_divisor={self.size_divisor}, ' \ + f'pad_val={self.pad_val})' + return repr_str + + +@PIPELINES.register_module() +class Normalize(object): + """Normalize the image. + + Added key is "img_norm_cfg". + + Args: + mean (sequence): Mean values of 3 channels. + std (sequence): Std values of 3 channels. + to_rgb (bool): Whether to convert the image from BGR to RGB, + default is true. + """ + + def __init__(self, mean, std, to_rgb=True): + self.mean = np.array(mean, dtype=np.float32) + self.std = np.array(std, dtype=np.float32) + self.to_rgb = to_rgb + + def __call__(self, results): + """Call function to normalize images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Normalized results, 'img_norm_cfg' key is added into + result dict. + """ + + results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std, + self.to_rgb) + results['img_norm_cfg'] = dict( + mean=self.mean, std=self.std, to_rgb=self.to_rgb) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(mean={self.mean}, std={self.std}, to_rgb=' \ + f'{self.to_rgb})' + return repr_str + + +@PIPELINES.register_module() +class Rerange(object): + """Rerange the image pixel value. + + Args: + min_value (float or int): Minimum value of the reranged image. + Default: 0. + max_value (float or int): Maximum value of the reranged image. + Default: 255. + """ + + def __init__(self, min_value=0, max_value=255): + assert isinstance(min_value, float) or isinstance(min_value, int) + assert isinstance(max_value, float) or isinstance(max_value, int) + assert min_value < max_value + self.min_value = min_value + self.max_value = max_value + + def __call__(self, results): + """Call function to rerange images. + + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Reranged results. + """ + + img = results['img'] + img_min_value = np.min(img) + img_max_value = np.max(img) + + assert img_min_value < img_max_value + # rerange to [0, 1] + img = (img - img_min_value) / (img_max_value - img_min_value) + # rerange to [min_value, max_value] + img = img * (self.max_value - self.min_value) + self.min_value + results['img'] = img + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(min_value={self.min_value}, max_value={self.max_value})' + return repr_str + + +@PIPELINES.register_module() +class CLAHE(object): + """Use CLAHE method to process the image. + + See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J]. + Graphics Gems, 1994:474-485.` for more information. + + Args: + clip_limit (float): Threshold for contrast limiting. Default: 40.0. + tile_grid_size (tuple[int]): Size of grid for histogram equalization. + Input image will be divided into equally sized rectangular tiles. + It defines the number of tiles in row and column. Default: (8, 8). + """ + + def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)): + assert isinstance(clip_limit, (float, int)) + self.clip_limit = clip_limit + assert is_tuple_of(tile_grid_size, int) + assert len(tile_grid_size) == 2 + self.tile_grid_size = tile_grid_size + + def __call__(self, results): + """Call function to Use CLAHE method process images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. + """ + + for i in range(results['img'].shape[2]): + results['img'][:, :, i] = mmcv.clahe( + np.array(results['img'][:, :, i], dtype=np.uint8), + self.clip_limit, self.tile_grid_size) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(clip_limit={self.clip_limit}, '\ + f'tile_grid_size={self.tile_grid_size})' + return repr_str + + +@PIPELINES.register_module() +class RandomCrop(object): + """Random crop the image & seg. + + Args: + crop_size (tuple): Expected size after cropping, (h, w). + cat_max_ratio (float): The maximum ratio that single category could + occupy. + """ + + def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255): + assert crop_size[0] > 0 and crop_size[1] > 0 + self.crop_size = crop_size + self.cat_max_ratio = cat_max_ratio + self.ignore_index = ignore_index + + def get_crop_bbox(self, img): + """Randomly get a crop bounding box.""" + margin_h = max(img.shape[0] - self.crop_size[0], 0) + margin_w = max(img.shape[1] - self.crop_size[1], 0) + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1] + + return crop_y1, crop_y2, crop_x1, crop_x2 + + def crop(self, img, crop_bbox): + """Crop from ``img``""" + crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox + img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] + return img + + def __call__(self, results): + """Call function to randomly crop images, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'img_shape' key in result dict is + updated according to crop size. + """ + + img = results['img'] + crop_bbox = self.get_crop_bbox(img) + if self.cat_max_ratio < 1.: + # Repeat 10 times + for _ in range(10): + seg_temp = self.crop(results['gt_semantic_seg'], crop_bbox) + labels, cnt = np.unique(seg_temp, return_counts=True) + cnt = cnt[labels != self.ignore_index] + if len(cnt) > 1 and np.max(cnt) / np.sum( + cnt) < self.cat_max_ratio: + break + crop_bbox = self.get_crop_bbox(img) + + # crop the image + img = self.crop(img, crop_bbox) + img_shape = img.shape + results['img'] = img + results['img_shape'] = img_shape + + # crop semantic seg + for key in results.get('seg_fields', []): + results[key] = self.crop(results[key], crop_bbox) + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(crop_size={self.crop_size})' + + +@PIPELINES.register_module() +class RandomRotate(object): + """Rotate the image & seg. + + Args: + prob (float): The rotation probability. + degree (float, tuple[float]): Range of degrees to select from. If + degree is a number instead of tuple like (min, max), + the range of degree will be (``-degree``, ``+degree``) + pad_val (float, optional): Padding value of image. Default: 0. + seg_pad_val (float, optional): Padding value of segmentation map. + Default: 255. + center (tuple[float], optional): Center point (w, h) of the rotation in + the source image. If not specified, the center of the image will be + used. Default: None. + auto_bound (bool): Whether to adjust the image size to cover the whole + rotated image. Default: False + """ + + def __init__(self, + prob, + degree, + pad_val=0, + seg_pad_val=255, + center=None, + auto_bound=False): + self.prob = prob + assert prob >= 0 and prob <= 1 + if isinstance(degree, (float, int)): + assert degree > 0, f'degree {degree} should be positive' + self.degree = (-degree, degree) + else: + self.degree = degree + assert len(self.degree) == 2, f'degree {self.degree} should be a ' \ + f'tuple of (min, max)' + self.pal_val = pad_val + self.seg_pad_val = seg_pad_val + self.center = center + self.auto_bound = auto_bound + + def __call__(self, results): + """Call function to rotate image, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Rotated results. + """ + + rotate = True if np.random.rand() < self.prob else False + degree = np.random.uniform(min(*self.degree), max(*self.degree)) + if rotate: + # rotate image + results['img'] = mmcv.imrotate( + results['img'], + angle=degree, + border_value=self.pal_val, + center=self.center, + auto_bound=self.auto_bound) + + # rotate segs + for key in results.get('seg_fields', []): + results[key] = mmcv.imrotate( + results[key], + angle=degree, + border_value=self.seg_pad_val, + center=self.center, + auto_bound=self.auto_bound, + interpolation='nearest') + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' \ + f'degree={self.degree}, ' \ + f'pad_val={self.pal_val}, ' \ + f'seg_pad_val={self.seg_pad_val}, ' \ + f'center={self.center}, ' \ + f'auto_bound={self.auto_bound})' + return repr_str + + +@PIPELINES.register_module() +class RGB2Gray(object): + """Convert RGB image to grayscale image. + + This transform calculate the weighted mean of input image channels with + ``weights`` and then expand the channels to ``out_channels``. When + ``out_channels`` is None, the number of output channels is the same as + input channels. + + Args: + out_channels (int): Expected number of output channels after + transforming. Default: None. + weights (tuple[float]): The weights to calculate the weighted mean. + Default: (0.299, 0.587, 0.114). + """ + + def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)): + assert out_channels is None or out_channels > 0 + self.out_channels = out_channels + assert isinstance(weights, tuple) + for item in weights: + assert isinstance(item, (float, int)) + self.weights = weights + + def __call__(self, results): + """Call function to convert RGB image to grayscale image. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with grayscale image. + """ + img = results['img'] + assert len(img.shape) == 3 + assert img.shape[2] == len(self.weights) + weights = np.array(self.weights).reshape((1, 1, -1)) + img = (img * weights).sum(2, keepdims=True) + if self.out_channels is None: + img = img.repeat(weights.shape[2], axis=2) + else: + img = img.repeat(self.out_channels, axis=2) + + results['img'] = img + results['img_shape'] = img.shape + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(out_channels={self.out_channels}, ' \ + f'weights={self.weights})' + return repr_str + + +@PIPELINES.register_module() +class AdjustGamma(object): + """Using gamma correction to process the image. + + Args: + gamma (float or int): Gamma value used in gamma correction. + Default: 1.0. + """ + + def __init__(self, gamma=1.0): + assert isinstance(gamma, float) or isinstance(gamma, int) + assert gamma > 0 + self.gamma = gamma + inv_gamma = 1.0 / gamma + self.table = np.array([(i / 255.0)**inv_gamma * 255 + for i in np.arange(256)]).astype('uint8') + + def __call__(self, results): + """Call function to process the image with gamma correction. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. + """ + + results['img'] = mmcv.lut_transform( + np.array(results['img'], dtype=np.uint8), self.table) + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(gamma={self.gamma})' + + +@PIPELINES.register_module() +class SegRescale(object): + """Rescale semantic segmentation maps. + + Args: + scale_factor (float): The scale factor of the final output. + """ + + def __init__(self, scale_factor=1): + self.scale_factor = scale_factor + + def __call__(self, results): + """Call function to scale the semantic segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with semantic segmentation map scaled. + """ + for key in results.get('seg_fields', []): + if self.scale_factor != 1: + results[key] = mmcv.imrescale( + results[key], self.scale_factor, interpolation='nearest') + return results + + def __repr__(self): + return self.__class__.__name__ + f'(scale_factor={self.scale_factor})' + + +@PIPELINES.register_module() +class PhotoMetricDistortion(object): + """Apply photometric distortion to image sequentially, every transformation + is applied with a probability of 0.5. The position of random contrast is in + second or second to last. + + 1. random brightness + 2. random contrast (mode 0) + 3. convert color from BGR to HSV + 4. random saturation + 5. random hue + 6. convert color from HSV to BGR + 7. random contrast (mode 1) + + Args: + brightness_delta (int): delta of brightness. + contrast_range (tuple): range of contrast. + saturation_range (tuple): range of saturation. + hue_delta (int): delta of hue. + """ + + def __init__(self, + brightness_delta=32, + contrast_range=(0.5, 1.5), + saturation_range=(0.5, 1.5), + hue_delta=18): + self.brightness_delta = brightness_delta + self.contrast_lower, self.contrast_upper = contrast_range + self.saturation_lower, self.saturation_upper = saturation_range + self.hue_delta = hue_delta + + def convert(self, img, alpha=1, beta=0): + """Multiple with alpha and add beat with clip.""" + img = img.astype(np.float32) * alpha + beta + img = np.clip(img, 0, 255) + return img.astype(np.uint8) + + def brightness(self, img): + """Brightness distortion.""" + if random.randint(2): + return self.convert( + img, + beta=random.uniform(-self.brightness_delta, + self.brightness_delta)) + return img + + def contrast(self, img): + """Contrast distortion.""" + if random.randint(2): + return self.convert( + img, + alpha=random.uniform(self.contrast_lower, self.contrast_upper)) + return img + + def saturation(self, img): + """Saturation distortion.""" + if random.randint(2): + img = mmcv.bgr2hsv(img) + img[:, :, 1] = self.convert( + img[:, :, 1], + alpha=random.uniform(self.saturation_lower, + self.saturation_upper)) + img = mmcv.hsv2bgr(img) + return img + + def hue(self, img): + """Hue distortion.""" + if random.randint(2): + img = mmcv.bgr2hsv(img) + img[:, :, + 0] = (img[:, :, 0].astype(int) + + random.randint(-self.hue_delta, self.hue_delta)) % 180 + img = mmcv.hsv2bgr(img) + return img + + def __call__(self, results): + """Call function to perform photometric distortion on images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with images distorted. + """ + + img = results['img'] + # random brightness + img = self.brightness(img) + + # mode == 0 --> do random contrast first + # mode == 1 --> do random contrast last + mode = random.randint(2) + if mode == 1: + img = self.contrast(img) + + # random saturation + img = self.saturation(img) + + # random hue + img = self.hue(img) + + # random contrast + if mode == 0: + img = self.contrast(img) + + results['img'] = img + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(brightness_delta={self.brightness_delta}, ' + f'contrast_range=({self.contrast_lower}, ' + f'{self.contrast_upper}), ' + f'saturation_range=({self.saturation_lower}, ' + f'{self.saturation_upper}), ' + f'hue_delta={self.hue_delta})') + return repr_str + + +@PIPELINES.register_module() +class RandomCutOut(object): + """CutOut operation. + + Randomly drop some regions of image used in + `Cutout `_. + Args: + prob (float): cutout probability. + n_holes (int | tuple[int, int]): Number of regions to be dropped. + If it is given as a list, number of holes will be randomly + selected from the closed interval [`n_holes[0]`, `n_holes[1]`]. + cutout_shape (tuple[int, int] | list[tuple[int, int]]): The candidate + shape of dropped regions. It can be `tuple[int, int]` to use a + fixed cutout shape, or `list[tuple[int, int]]` to randomly choose + shape from the list. + cutout_ratio (tuple[float, float] | list[tuple[float, float]]): The + candidate ratio of dropped regions. It can be `tuple[float, float]` + to use a fixed ratio or `list[tuple[float, float]]` to randomly + choose ratio from the list. Please note that `cutout_shape` + and `cutout_ratio` cannot be both given at the same time. + fill_in (tuple[float, float, float] | tuple[int, int, int]): The value + of pixel to fill in the dropped regions. Default: (0, 0, 0). + seg_fill_in (int): The labels of pixel to fill in the dropped regions. + If seg_fill_in is None, skip. Default: None. + """ + + def __init__(self, + prob, + n_holes, + cutout_shape=None, + cutout_ratio=None, + fill_in=(0, 0, 0), + seg_fill_in=None): + + assert 0 <= prob and prob <= 1 + assert (cutout_shape is None) ^ (cutout_ratio is None), \ + 'Either cutout_shape or cutout_ratio should be specified.' + assert (isinstance(cutout_shape, (list, tuple)) + or isinstance(cutout_ratio, (list, tuple))) + if isinstance(n_holes, tuple): + assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1] + else: + n_holes = (n_holes, n_holes) + if seg_fill_in is not None: + assert (isinstance(seg_fill_in, int) and 0 <= seg_fill_in + and seg_fill_in <= 255) + self.prob = prob + self.n_holes = n_holes + self.fill_in = fill_in + self.seg_fill_in = seg_fill_in + self.with_ratio = cutout_ratio is not None + self.candidates = cutout_ratio if self.with_ratio else cutout_shape + if not isinstance(self.candidates, list): + self.candidates = [self.candidates] + + def __call__(self, results): + """Call function to drop some regions of image.""" + cutout = True if np.random.rand() < self.prob else False + if cutout: + h, w, c = results['img'].shape + n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1) + for _ in range(n_holes): + x1 = np.random.randint(0, w) + y1 = np.random.randint(0, h) + index = np.random.randint(0, len(self.candidates)) + if not self.with_ratio: + cutout_w, cutout_h = self.candidates[index] + else: + cutout_w = int(self.candidates[index][0] * w) + cutout_h = int(self.candidates[index][1] * h) + + x2 = np.clip(x1 + cutout_w, 0, w) + y2 = np.clip(y1 + cutout_h, 0, h) + results['img'][y1:y2, x1:x2, :] = self.fill_in + + if self.seg_fill_in is not None: + for key in results.get('seg_fields', []): + results[key][y1:y2, x1:x2] = self.seg_fill_in + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'n_holes={self.n_holes}, ' + repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio + else f'cutout_shape={self.candidates}, ') + repr_str += f'fill_in={self.fill_in}, ' + repr_str += f'seg_fill_in={self.seg_fill_in})' + return repr_str + + +@PIPELINES.register_module() +class RandomMosaic(object): + """Mosaic augmentation. Given 4 images, mosaic transform combines them into + one output image. The output image is composed of the parts from each sub- + image. + + .. code:: text + + mosaic transform + center_x + +------------------------------+ + | pad | pad | + | +-----------+ | + | | | | + | | image1 |--------+ | + | | | | | + | | | image2 | | + center_y |----+-------------+-----------| + | | cropped | | + |pad | image3 | image4 | + | | | | + +----|-------------+-----------+ + | | + +-------------+ + + The mosaic transform steps are as follows: + 1. Choose the mosaic center as the intersections of 4 images + 2. Get the left top image according to the index, and randomly + sample another 3 images from the custom dataset. + 3. Sub image will be cropped if image is larger than mosaic patch + + Args: + prob (float): mosaic probability. + img_scale (Sequence[int]): Image size after mosaic pipeline of + a single image. The size of the output image is four times + that of a single image. The output image comprises 4 single images. + Default: (640, 640). + center_ratio_range (Sequence[float]): Center ratio range of mosaic + output. Default: (0.5, 1.5). + pad_val (int): Pad value. Default: 0. + seg_pad_val (int): Pad value of segmentation map. Default: 255. + """ + + def __init__(self, + prob, + img_scale=(640, 640), + center_ratio_range=(0.5, 1.5), + pad_val=0, + seg_pad_val=255): + assert 0 <= prob and prob <= 1 + assert isinstance(img_scale, tuple) + self.prob = prob + self.img_scale = img_scale + self.center_ratio_range = center_ratio_range + self.pad_val = pad_val + self.seg_pad_val = seg_pad_val + + def __call__(self, results): + """Call function to make a mosaic of image. + + Args: + results (dict): Result dict. + + Returns: + dict: Result dict with mosaic transformed. + """ + mosaic = True if np.random.rand() < self.prob else False + if mosaic: + results = self._mosaic_transform_img(results) + results = self._mosaic_transform_seg(results) + return results + + def get_indexes(self, dataset): + """Call function to collect indexes. + + Args: + dataset (:obj:`MultiImageMixDataset`): The dataset. + + Returns: + list: indexes. + """ + + indexes = [random.randint(0, len(dataset)) for _ in range(3)] + return indexes + + def _mosaic_transform_img(self, results): + """Mosaic transform function. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. + """ + + assert 'mix_results' in results + if len(results['img'].shape) == 3: + mosaic_img = np.full( + (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2), 3), + self.pad_val, + dtype=results['img'].dtype) + else: + mosaic_img = np.full( + (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)), + self.pad_val, + dtype=results['img'].dtype) + + # mosaic center x, y + self.center_x = int( + random.uniform(*self.center_ratio_range) * self.img_scale[1]) + self.center_y = int( + random.uniform(*self.center_ratio_range) * self.img_scale[0]) + center_position = (self.center_x, self.center_y) + + loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right') + for i, loc in enumerate(loc_strs): + if loc == 'top_left': + result_patch = copy.deepcopy(results) + else: + result_patch = copy.deepcopy(results['mix_results'][i - 1]) + + img_i = result_patch['img'] + h_i, w_i = img_i.shape[:2] + # keep_ratio resize + scale_ratio_i = min(self.img_scale[0] / h_i, + self.img_scale[1] / w_i) + img_i = mmcv.imresize( + img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i))) + + # compute the combine parameters + paste_coord, crop_coord = self._mosaic_combine( + loc, center_position, img_i.shape[:2][::-1]) + x1_p, y1_p, x2_p, y2_p = paste_coord + x1_c, y1_c, x2_c, y2_c = crop_coord + + # crop and paste image + mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c] + + results['img'] = mosaic_img + results['img_shape'] = mosaic_img.shape + results['ori_shape'] = mosaic_img.shape + + return results + + def _mosaic_transform_seg(self, results): + """Mosaic transform function for label annotations. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. + """ + + assert 'mix_results' in results + for key in results.get('seg_fields', []): + mosaic_seg = np.full( + (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)), + self.seg_pad_val, + dtype=results[key].dtype) + + # mosaic center x, y + center_position = (self.center_x, self.center_y) + + loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right') + for i, loc in enumerate(loc_strs): + if loc == 'top_left': + result_patch = copy.deepcopy(results) + else: + result_patch = copy.deepcopy(results['mix_results'][i - 1]) + + gt_seg_i = result_patch[key] + h_i, w_i = gt_seg_i.shape[:2] + # keep_ratio resize + scale_ratio_i = min(self.img_scale[0] / h_i, + self.img_scale[1] / w_i) + gt_seg_i = mmcv.imresize( + gt_seg_i, + (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)), + interpolation='nearest') + + # compute the combine parameters + paste_coord, crop_coord = self._mosaic_combine( + loc, center_position, gt_seg_i.shape[:2][::-1]) + x1_p, y1_p, x2_p, y2_p = paste_coord + x1_c, y1_c, x2_c, y2_c = crop_coord + + # crop and paste image + mosaic_seg[y1_p:y2_p, x1_p:x2_p] = gt_seg_i[y1_c:y2_c, + x1_c:x2_c] + + results[key] = mosaic_seg + + return results + + def _mosaic_combine(self, loc, center_position_xy, img_shape_wh): + """Calculate global coordinate of mosaic image and local coordinate of + cropped sub-image. + + Args: + loc (str): Index for the sub-image, loc in ('top_left', + 'top_right', 'bottom_left', 'bottom_right'). + center_position_xy (Sequence[float]): Mixing center for 4 images, + (x, y). + img_shape_wh (Sequence[int]): Width and height of sub-image + + Returns: + tuple[tuple[float]]: Corresponding coordinate of pasting and + cropping + - paste_coord (tuple): paste corner coordinate in mosaic image. + - crop_coord (tuple): crop corner coordinate in mosaic image. + """ + + assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right') + if loc == 'top_left': + # index0 to top left part of image + x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \ + max(center_position_xy[1] - img_shape_wh[1], 0), \ + center_position_xy[0], \ + center_position_xy[1] + crop_coord = img_shape_wh[0] - (x2 - x1), img_shape_wh[1] - ( + y2 - y1), img_shape_wh[0], img_shape_wh[1] + + elif loc == 'top_right': + # index1 to top right part of image + x1, y1, x2, y2 = center_position_xy[0], \ + max(center_position_xy[1] - img_shape_wh[1], 0), \ + min(center_position_xy[0] + img_shape_wh[0], + self.img_scale[1] * 2), \ + center_position_xy[1] + crop_coord = 0, img_shape_wh[1] - (y2 - y1), min( + img_shape_wh[0], x2 - x1), img_shape_wh[1] + + elif loc == 'bottom_left': + # index2 to bottom left part of image + x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \ + center_position_xy[1], \ + center_position_xy[0], \ + min(self.img_scale[0] * 2, center_position_xy[1] + + img_shape_wh[1]) + crop_coord = img_shape_wh[0] - (x2 - x1), 0, img_shape_wh[0], min( + y2 - y1, img_shape_wh[1]) + + else: + # index3 to bottom right part of image + x1, y1, x2, y2 = center_position_xy[0], \ + center_position_xy[1], \ + min(center_position_xy[0] + img_shape_wh[0], + self.img_scale[1] * 2), \ + min(self.img_scale[0] * 2, center_position_xy[1] + + img_shape_wh[1]) + crop_coord = 0, 0, min(img_shape_wh[0], + x2 - x1), min(y2 - y1, img_shape_wh[1]) + + paste_coord = x1, y1, x2, y2 + return paste_coord, crop_coord + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'img_scale={self.img_scale}, ' + repr_str += f'center_ratio_range={self.center_ratio_range}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'seg_pad_val={self.pad_val})' + return repr_str diff --git a/data_utils/easyportrait/mmseg/datasets/potsdam.py b/data_utils/easyportrait/mmseg/datasets/potsdam.py new file mode 100644 index 0000000000000000000000000000000000000000..2986b8faa02f9633295879a791ed6202e2f29919 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/potsdam.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class PotsdamDataset(CustomDataset): + """ISPRS Potsdam dataset. + + In segmentation map annotation for Potsdam dataset, 0 is the ignore index. + ``reduce_zero_label`` should be set to True. The ``img_suffix`` and + ``seg_map_suffix`` are both fixed to '.png'. + """ + CLASSES = ('impervious_surface', 'building', 'low_vegetation', 'tree', + 'car', 'clutter') + + PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], + [255, 255, 0], [255, 0, 0]] + + def __init__(self, **kwargs): + super(PotsdamDataset, self).__init__( + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=True, + **kwargs) diff --git a/data_utils/easyportrait/mmseg/datasets/samplers/__init__.py b/data_utils/easyportrait/mmseg/datasets/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da09effaf20fefe1a102277672b98db7d884f002 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/samplers/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .distributed_sampler import DistributedSampler + +__all__ = ['DistributedSampler'] diff --git a/data_utils/easyportrait/mmseg/datasets/samplers/distributed_sampler.py b/data_utils/easyportrait/mmseg/datasets/samplers/distributed_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..4f9bf357973e408c8e0c3c32847af9a2a18a7740 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/samplers/distributed_sampler.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from __future__ import division +from typing import Iterator, Optional + +import torch +from torch.utils.data import Dataset +from torch.utils.data import DistributedSampler as _DistributedSampler + +from mmseg.core.utils import sync_random_seed +from mmseg.utils import get_device + + +class DistributedSampler(_DistributedSampler): + """DistributedSampler inheriting from + `torch.utils.data.DistributedSampler`. + + Args: + datasets (Dataset): the dataset will be loaded. + num_replicas (int, optional): Number of processes participating in + distributed training. By default, world_size is retrieved from the + current distributed group. + rank (int, optional): Rank of the current process within num_replicas. + By default, rank is retrieved from the current distributed group. + shuffle (bool): If True (default), sampler will shuffle the indices. + seed (int): random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Default: ``0``. + """ + + def __init__(self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed=0) -> None: + super().__init__( + dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) + + # In distributed sampling, different ranks should sample + # non-overlapped data in the dataset. Therefore, this function + # is used to make sure that each rank shuffles the data indices + # in the same order based on the same seed. Then different ranks + # could use different indices to select non-overlapped data from the + # same data list. + device = get_device() + self.seed = sync_random_seed(seed, device) + + def __iter__(self) -> Iterator: + """ + Yields: + Iterator: iterator of indices for rank. + """ + # deterministically shuffle based on epoch + if self.shuffle: + g = torch.Generator() + # When :attr:`shuffle=True`, this ensures all replicas + # use a different random ordering for each epoch. + # Otherwise, the next iteration of this sampler will + # yield the same ordering. + g.manual_seed(self.epoch + self.seed) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = torch.arange(len(self.dataset)).tolist() + + # add extra samples to make it evenly divisible + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) diff --git a/data_utils/easyportrait/mmseg/datasets/stare.py b/data_utils/easyportrait/mmseg/datasets/stare.py new file mode 100644 index 0000000000000000000000000000000000000000..a24d1d9570b710e70e3437f4c33a1c84299c6313 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/stare.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class STAREDataset(CustomDataset): + """STARE dataset. + + In segmentation map annotation for STARE, 0 stands for background, which is + included in 2 categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '.ah.png'. + """ + + CLASSES = ('background', 'vessel') + + PALETTE = [[120, 120, 120], [6, 230, 230]] + + def __init__(self, **kwargs): + super(STAREDataset, self).__init__( + img_suffix='.png', + seg_map_suffix='.ah.png', + reduce_zero_label=False, + **kwargs) + assert osp.exists(self.img_dir) diff --git a/data_utils/easyportrait/mmseg/datasets/voc.py b/data_utils/easyportrait/mmseg/datasets/voc.py new file mode 100644 index 0000000000000000000000000000000000000000..3cec9e3505e5931868869ccebf43263e8883d660 --- /dev/null +++ b/data_utils/easyportrait/mmseg/datasets/voc.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class PascalVOCDataset(CustomDataset): + """Pascal VOC dataset. + + Args: + split (str): Split txt file for Pascal VOC. + """ + + CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', + 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', + 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', + 'train', 'tvmonitor') + + PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], + [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], + [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], + [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], + [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] + + def __init__(self, split, **kwargs): + super(PascalVOCDataset, self).__init__( + img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs) + assert osp.exists(self.img_dir) and self.split is not None diff --git a/data_utils/easyportrait/mmseg/models/__init__.py b/data_utils/easyportrait/mmseg/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..87d8108e3f1977bf4830fa83ad7498081d2a9a51 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .backbones import * # noqa: F401,F403 +from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone, + build_head, build_loss, build_segmentor) +from .decode_heads import * # noqa: F401,F403 +from .losses import * # noqa: F401,F403 +from .necks import * # noqa: F401,F403 +from .segmentors import * # noqa: F401,F403 + +__all__ = [ + 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone', + 'build_head', 'build_loss', 'build_segmentor' +] diff --git a/data_utils/easyportrait/mmseg/models/backbones/__init__.py b/data_utils/easyportrait/mmseg/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10cc9976adf092a7748f9877fc64ba4ecb1ef7da --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .beit import BEiT +from .bisenetv1 import BiSeNetV1 +from .bisenetv2 import BiSeNetV2 +from .cgnet import CGNet +from .erfnet import ERFNet +from .fast_scnn import FastSCNN +from .hrnet import HRNet +from .icnet import ICNet +from .mae import MAE +from .mit import MixVisionTransformer +from .mobilenet_v2 import MobileNetV2 +from .mobilenet_v3 import MobileNetV3 +from .mscan import MSCAN +from .resnest import ResNeSt +from .resnet import ResNet, ResNetV1c, ResNetV1d +from .resnext import ResNeXt +from .stdc import STDCContextPathNet, STDCNet +from .swin import SwinTransformer +from .timm_backbone import TIMMBackbone +from .twins import PCPVT, SVT +from .unet import UNet +from .vit import VisionTransformer + +__all__ = [ + 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN', + 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3', + 'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', + 'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT', + 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'MSCAN' +] \ No newline at end of file diff --git a/data_utils/easyportrait/mmseg/models/backbones/beit.py b/data_utils/easyportrait/mmseg/models/backbones/beit.py new file mode 100644 index 0000000000000000000000000000000000000000..fade60137dff4a5ab7f16d2d10b829b78232e656 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/beit.py @@ -0,0 +1,559 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init, + trunc_normal_) +from mmcv.runner import BaseModule, ModuleList, _load_checkpoint +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.utils import _pair as to_2tuple + +from mmseg.utils import get_root_logger +from ..builder import BACKBONES +from ..utils import PatchEmbed +from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer + +try: + from scipy import interpolate +except ImportError: + interpolate = None + + +class BEiTAttention(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): The height and width of the window. + bias (bool): The option to add leanable bias for q, k, v. If bias is + True, it will add leanable bias. If bias is 'qv_bias', it will only + add leanable bias for q, v. If bias is False, it will not add bias + for q, k, v. Default to 'qv_bias'. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float): Dropout ratio of output. Default: 0. + init_cfg (dict | None, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + bias='qv_bias', + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + init_cfg=None, + **kwargs): + super().__init__(init_cfg=init_cfg) + self.embed_dims = embed_dims + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.bias = bias + self.scale = qk_scale or head_embed_dims**-0.5 + + qkv_bias = bias + if bias == 'qv_bias': + self._init_qv_bias() + qkv_bias = False + + self.window_size = window_size + self._init_rel_pos_embedding() + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + + def _init_qv_bias(self): + self.q_bias = nn.Parameter(torch.zeros(self.embed_dims)) + self.v_bias = nn.Parameter(torch.zeros(self.embed_dims)) + + def _init_rel_pos_embedding(self): + Wh, Ww = self.window_size + # cls to token & token 2 cls & cls to cls + self.num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3 + # relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH) + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, self.num_heads)) + + # get pair-wise relative position index for + # each token inside the window + coords_h = torch.arange(Wh) + coords_w = torch.arange(Ww) + # coords shape is (2, Wh, Ww) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + # coords_flatten shape is (2, Wh*Ww) + coords_flatten = torch.flatten(coords, 1) + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :]) + # relative_coords shape is (Wh*Ww, Wh*Ww, 2) + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + # shift to start from 0 + relative_coords[:, :, 0] += Wh - 1 + relative_coords[:, :, 1] += Ww - 1 + relative_coords[:, :, 0] *= 2 * Ww - 1 + relative_position_index = torch.zeros( + size=(Wh * Ww + 1, ) * 2, dtype=relative_coords.dtype) + # relative_position_index shape is (Wh*Ww, Wh*Ww) + relative_position_index[1:, 1:] = relative_coords.sum(-1) + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer('relative_position_index', + relative_position_index) + + def init_weights(self): + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x): + """ + Args: + x (tensor): input features with shape of (num_windows*B, N, C). + """ + B, N, C = x.shape + + if self.bias == 'qv_bias': + k_bias = torch.zeros_like(self.v_bias, requires_grad=False) + qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + else: + qkv = self.qkv(x) + + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + if self.relative_position_bias_table is not None: + Wh = self.window_size[0] + Ww = self.window_size[1] + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + Wh * Ww + 1, Wh * Ww + 1, -1) + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class BEiTTransformerEncoderLayer(VisionTransformerEncoderLayer): + """Implements one encoder layer in Vision Transformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0. + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + bias (bool): The option to add leanable bias for q, k, v. If bias is + True, it will add leanable bias. If bias is 'qv_bias', it will only + add leanable bias for q, v. If bias is False, it will not add bias + for q, k, v. Default to 'qv_bias'. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + window_size (tuple[int], optional): The height and width of the window. + Default: None. + init_values (float, optional): Initialize the values of BEiTAttention + and FFN with learnable scaling. Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + bias='qv_bias', + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + window_size=None, + attn_cfg=dict(), + ffn_cfg=dict(add_identity=False), + init_values=None): + attn_cfg.update(dict(window_size=window_size, qk_scale=None)) + + super(BEiTTransformerEncoderLayer, self).__init__( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=feedforward_channels, + attn_drop_rate=attn_drop_rate, + drop_path_rate=0., + drop_rate=0., + num_fcs=num_fcs, + qkv_bias=bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + attn_cfg=attn_cfg, + ffn_cfg=ffn_cfg) + + # NOTE: drop path for stochastic depth, we shall see if + # this is better than dropout here + dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate) + self.drop_path = build_dropout( + dropout_layer) if dropout_layer else nn.Identity() + self.gamma_1 = nn.Parameter( + init_values * torch.ones((embed_dims)), requires_grad=True) + self.gamma_2 = nn.Parameter( + init_values * torch.ones((embed_dims)), requires_grad=True) + + def build_attn(self, attn_cfg): + self.attn = BEiTAttention(**attn_cfg) + + def forward(self, x): + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) + x = x + self.drop_path(self.gamma_2 * self.ffn(self.norm2(x))) + return x + + +@BACKBONES.register_module() +class BEiT(BaseModule): + """BERT Pre-Training of Image Transformers. + + Args: + img_size (int | tuple): Input image size. Default: 224. + patch_size (int): The patch size. Default: 16. + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): Embedding dimension. Default: 768. + num_layers (int): Depth of transformer. Default: 12. + num_heads (int): Number of attention heads. Default: 12. + mlp_ratio (int): Ratio of mlp hidden dim to embedding dim. + Default: 4. + out_indices (list | tuple | int): Output from which stages. + Default: -1. + qv_bias (bool): Enable bias for qv if True. Default: True. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): Stochastic depth rate. Default 0.0. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + patch_norm (bool): Whether to add a norm in PatchEmbed Block. + Default: False. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Default: False. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + pretrained (str, optional): Model pretrained path. Default: None. + init_values (float): Initialize the values of BEiTAttention and FFN + with learnable scaling. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_channels=3, + embed_dims=768, + num_layers=12, + num_heads=12, + mlp_ratio=4, + out_indices=-1, + qv_bias=True, + attn_drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + patch_norm=False, + final_norm=False, + num_fcs=2, + norm_eval=False, + pretrained=None, + init_values=0.1, + init_cfg=None): + super(BEiT, self).__init__(init_cfg=init_cfg) + if isinstance(img_size, int): + img_size = to_2tuple(img_size) + elif isinstance(img_size, tuple): + if len(img_size) == 1: + img_size = to_2tuple(img_size[0]) + assert len(img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(img_size)}' + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + + self.in_channels = in_channels + self.img_size = img_size + self.patch_size = patch_size + self.norm_eval = norm_eval + self.pretrained = pretrained + self.num_layers = num_layers + self.embed_dims = embed_dims + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.attn_drop_rate = attn_drop_rate + self.drop_path_rate = drop_path_rate + self.num_fcs = num_fcs + self.qv_bias = qv_bias + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.patch_norm = patch_norm + self.init_values = init_values + self.window_size = (img_size[0] // patch_size, + img_size[1] // patch_size) + self.patch_shape = self.window_size + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + + self._build_patch_embedding() + self._build_layers() + + if isinstance(out_indices, int): + if out_indices == -1: + out_indices = num_layers - 1 + self.out_indices = [out_indices] + elif isinstance(out_indices, list) or isinstance(out_indices, tuple): + self.out_indices = out_indices + else: + raise TypeError('out_indices must be type of int, list or tuple') + + self.final_norm = final_norm + if final_norm: + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, embed_dims, postfix=1) + self.add_module(self.norm1_name, norm1) + + def _build_patch_embedding(self): + """Build patch embedding layer.""" + self.patch_embed = PatchEmbed( + in_channels=self.in_channels, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=self.patch_size, + stride=self.patch_size, + padding=0, + norm_cfg=self.norm_cfg if self.patch_norm else None, + init_cfg=None) + + def _build_layers(self): + """Build transformer encoding layers.""" + + dpr = [ + x.item() + for x in torch.linspace(0, self.drop_path_rate, self.num_layers) + ] + self.layers = ModuleList() + for i in range(self.num_layers): + self.layers.append( + BEiTTransformerEncoderLayer( + embed_dims=self.embed_dims, + num_heads=self.num_heads, + feedforward_channels=self.mlp_ratio * self.embed_dims, + attn_drop_rate=self.attn_drop_rate, + drop_path_rate=dpr[i], + num_fcs=self.num_fcs, + bias='qv_bias' if self.qv_bias else False, + act_cfg=self.act_cfg, + norm_cfg=self.norm_cfg, + window_size=self.window_size, + init_values=self.init_values)) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + def _geometric_sequence_interpolation(self, src_size, dst_size, sequence, + num): + """Get new sequence via geometric sequence interpolation. + + Args: + src_size (int): Pos_embedding size in pre-trained model. + dst_size (int): Pos_embedding size in the current model. + sequence (tensor): The relative position bias of the pretrain + model after removing the extra tokens. + num (int): Number of attention heads. + Returns: + new_sequence (tensor): Geometric sequence interpolate the + pre-trained relative position bias to the size of + the current model. + """ + + def geometric_progression(a, r, n): + return a * (1.0 - r**n) / (1.0 - r) + + # Here is a binary function. + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_size // 2) + if gp > dst_size // 2: + right = q + else: + left = q + # The position of each interpolated point is determined + # by the ratio obtained by dichotomy. + dis = [] + cur = 1 + for i in range(src_size // 2): + dis.append(cur) + cur += q**(i + 1) + r_ids = [-_ for _ in reversed(dis)] + x = r_ids + [0] + dis + y = r_ids + [0] + dis + t = dst_size // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + # Interpolation functions are being executed and called. + new_sequence = [] + for i in range(num): + z = sequence[:, i].view(src_size, src_size).float().numpy() + f = interpolate.interp2d(x, y, z, kind='cubic') + new_sequence.append( + torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(sequence)) + new_sequence = torch.cat(new_sequence, dim=-1) + return new_sequence + + def resize_rel_pos_embed(self, checkpoint): + """Resize relative pos_embed weights. + + This function is modified from + https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/checkpoint.py. # noqa: E501 + Copyright (c) Microsoft Corporation + Licensed under the MIT License + Args: + checkpoint (dict): Key and value of the pretrain model. + Returns: + state_dict (dict): Interpolate the relative pos_embed weights + in the pre-train model to the current model size. + """ + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + all_keys = list(state_dict.keys()) + for key in all_keys: + if 'relative_position_index' in key: + state_dict.pop(key) + # In order to keep the center of pos_bias as consistent as + # possible after interpolation, and vice versa in the edge + # area, the geometric sequence interpolation method is adopted. + if 'relative_position_bias_table' in key: + rel_pos_bias = state_dict[key] + src_num_pos, num_attn_heads = rel_pos_bias.size() + dst_num_pos, _ = self.state_dict()[key].size() + dst_patch_shape = self.patch_shape + if dst_patch_shape[0] != dst_patch_shape[1]: + raise NotImplementedError() + # Count the number of extra tokens. + num_extra_tokens = dst_num_pos - ( + dst_patch_shape[0] * 2 - 1) * ( + dst_patch_shape[1] * 2 - 1) + src_size = int((src_num_pos - num_extra_tokens)**0.5) + dst_size = int((dst_num_pos - num_extra_tokens)**0.5) + if src_size != dst_size: + extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + new_rel_pos_bias = self._geometric_sequence_interpolation( + src_size, dst_size, rel_pos_bias, num_attn_heads) + new_rel_pos_bias = torch.cat( + (new_rel_pos_bias, extra_tokens), dim=0) + state_dict[key] = new_rel_pos_bias + + return state_dict + + def init_weights(self): + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + if (isinstance(self.init_cfg, dict) + and self.init_cfg.get('type') == 'Pretrained'): + logger = get_root_logger() + checkpoint = _load_checkpoint( + self.init_cfg['checkpoint'], logger=logger, map_location='cpu') + state_dict = self.resize_rel_pos_embed(checkpoint) + self.load_state_dict(state_dict, False) + elif self.init_cfg is not None: + super(BEiT, self).init_weights() + else: + # We only implement the 'jax_impl' initialization implemented at + # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 + # Copyright 2019 Ross Wightman + # Licensed under the Apache License, Version 2.0 (the "License") + trunc_normal_(self.cls_token, std=.02) + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + if 'ffn' in n: + nn.init.normal_(m.bias, mean=0., std=1e-6) + else: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) + + def forward(self, inputs): + B = inputs.shape[0] + + x, hw_shape = self.patch_embed(inputs) + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i == len(self.layers) - 1: + if self.final_norm: + x = self.norm1(x) + if i in self.out_indices: + # Remove class token and reshape token for decoder head + out = x[:, 1:] + B, _, C = out.shape + out = out.reshape(B, hw_shape[0], hw_shape[1], + C).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return tuple(outs) + + def train(self, mode=True): + super(BEiT, self).train(mode) + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.LayerNorm): + m.eval() diff --git a/data_utils/easyportrait/mmseg/models/backbones/bisenetv1.py b/data_utils/easyportrait/mmseg/models/backbones/bisenetv1.py new file mode 100644 index 0000000000000000000000000000000000000000..4beb7b394d307bf34abab44507ecb92c9bb6cdb1 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/bisenetv1.py @@ -0,0 +1,332 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmcv.runner import BaseModule + +from mmseg.ops import resize +from ..builder import BACKBONES, build_backbone + + +class SpatialPath(BaseModule): + """Spatial Path to preserve the spatial size of the original input image + and encode affluent spatial information. + + Args: + in_channels(int): The number of channels of input + image. Default: 3. + num_channels (Tuple[int]): The number of channels of + each layers in Spatial Path. + Default: (64, 64, 64, 128). + Returns: + x (torch.Tensor): Feature map for Feature Fusion Module. + """ + + def __init__(self, + in_channels=3, + num_channels=(64, 64, 64, 128), + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(SpatialPath, self).__init__(init_cfg=init_cfg) + assert len(num_channels) == 4, 'Length of input channels \ + of Spatial Path must be 4!' + + self.layers = [] + for i in range(len(num_channels)): + layer_name = f'layer{i + 1}' + self.layers.append(layer_name) + if i == 0: + self.add_module( + layer_name, + ConvModule( + in_channels=in_channels, + out_channels=num_channels[i], + kernel_size=7, + stride=2, + padding=3, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + elif i == len(num_channels) - 1: + self.add_module( + layer_name, + ConvModule( + in_channels=num_channels[i - 1], + out_channels=num_channels[i], + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + else: + self.add_module( + layer_name, + ConvModule( + in_channels=num_channels[i - 1], + out_channels=num_channels[i], + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, x): + for i, layer_name in enumerate(self.layers): + layer_stage = getattr(self, layer_name) + x = layer_stage(x) + return x + + +class AttentionRefinementModule(BaseModule): + """Attention Refinement Module (ARM) to refine the features of each stage. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + Returns: + x_out (torch.Tensor): Feature map of Attention Refinement Module. + """ + + def __init__(self, + in_channels, + out_channel, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(AttentionRefinementModule, self).__init__(init_cfg=init_cfg) + self.conv_layer = ConvModule( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.atten_conv_layer = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + ConvModule( + in_channels=out_channel, + out_channels=out_channel, + kernel_size=1, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), nn.Sigmoid()) + + def forward(self, x): + x = self.conv_layer(x) + x_atten = self.atten_conv_layer(x) + x_out = x * x_atten + return x_out + + +class ContextPath(BaseModule): + """Context Path to provide sufficient receptive field. + + Args: + backbone_cfg:(dict): Config of backbone of + Context Path. + context_channels (Tuple[int]): The number of channel numbers + of various modules in Context Path. + Default: (128, 256, 512). + align_corners (bool, optional): The align_corners argument of + resize operation. Default: False. + Returns: + x_16_up, x_32_up (torch.Tensor, torch.Tensor): Two feature maps + undergoing upsampling from 1/16 and 1/32 downsampling + feature maps. These two feature maps are used for Feature + Fusion Module and Auxiliary Head. + """ + + def __init__(self, + backbone_cfg, + context_channels=(128, 256, 512), + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(ContextPath, self).__init__(init_cfg=init_cfg) + assert len(context_channels) == 3, 'Length of input channels \ + of Context Path must be 3!' + + self.backbone = build_backbone(backbone_cfg) + + self.align_corners = align_corners + self.arm16 = AttentionRefinementModule(context_channels[1], + context_channels[0]) + self.arm32 = AttentionRefinementModule(context_channels[2], + context_channels[0]) + self.conv_head32 = ConvModule( + in_channels=context_channels[0], + out_channels=context_channels[0], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv_head16 = ConvModule( + in_channels=context_channels[0], + out_channels=context_channels[0], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.gap_conv = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + ConvModule( + in_channels=context_channels[2], + out_channels=context_channels[0], + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, x): + x_4, x_8, x_16, x_32 = self.backbone(x) + x_gap = self.gap_conv(x_32) + + x_32_arm = self.arm32(x_32) + x_32_sum = x_32_arm + x_gap + x_32_up = resize(input=x_32_sum, size=x_16.shape[2:], mode='nearest') + x_32_up = self.conv_head32(x_32_up) + + x_16_arm = self.arm16(x_16) + x_16_sum = x_16_arm + x_32_up + x_16_up = resize(input=x_16_sum, size=x_8.shape[2:], mode='nearest') + x_16_up = self.conv_head16(x_16_up) + + return x_16_up, x_32_up + + +class FeatureFusionModule(BaseModule): + """Feature Fusion Module to fuse low level output feature of Spatial Path + and high level output feature of Context Path. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + Returns: + x_out (torch.Tensor): Feature map of Feature Fusion Module. + """ + + def __init__(self, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(FeatureFusionModule, self).__init__(init_cfg=init_cfg) + self.conv1 = ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + self.conv_atten = nn.Sequential( + ConvModule( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), nn.Sigmoid()) + + def forward(self, x_sp, x_cp): + x_concat = torch.cat([x_sp, x_cp], dim=1) + x_fuse = self.conv1(x_concat) + x_atten = self.gap(x_fuse) + # Note: No BN and more 1x1 conv in paper. + x_atten = self.conv_atten(x_atten) + x_atten = x_fuse * x_atten + x_out = x_atten + x_fuse + return x_out + + +@BACKBONES.register_module() +class BiSeNetV1(BaseModule): + """BiSeNetV1 backbone. + + This backbone is the implementation of `BiSeNet: Bilateral + Segmentation Network for Real-time Semantic + Segmentation `_. + + Args: + backbone_cfg:(dict): Config of backbone of + Context Path. + in_channels (int): The number of channels of input + image. Default: 3. + spatial_channels (Tuple[int]): Size of channel numbers of + various layers in Spatial Path. + Default: (64, 64, 64, 128). + context_channels (Tuple[int]): Size of channel numbers of + various modules in Context Path. + Default: (128, 256, 512). + out_indices (Tuple[int] | int, optional): Output from which stages. + Default: (0, 1, 2). + align_corners (bool, optional): The align_corners argument of + resize operation in Bilateral Guided Aggregation Layer. + Default: False. + out_channels(int): The number of channels of output. + It must be the same with `in_channels` of decode_head. + Default: 256. + """ + + def __init__(self, + backbone_cfg, + in_channels=3, + spatial_channels=(64, 64, 64, 128), + context_channels=(128, 256, 512), + out_indices=(0, 1, 2), + align_corners=False, + out_channels=256, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU'), + init_cfg=None): + + super(BiSeNetV1, self).__init__(init_cfg=init_cfg) + assert len(spatial_channels) == 4, 'Length of input channels \ + of Spatial Path must be 4!' + + assert len(context_channels) == 3, 'Length of input channels \ + of Context Path must be 3!' + + self.out_indices = out_indices + self.align_corners = align_corners + self.context_path = ContextPath(backbone_cfg, context_channels, + self.align_corners) + self.spatial_path = SpatialPath(in_channels, spatial_channels) + self.ffm = FeatureFusionModule(context_channels[1], out_channels) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + def forward(self, x): + # stole refactoring code from Coin Cheung, thanks + x_context8, x_context16 = self.context_path(x) + x_spatial = self.spatial_path(x) + x_fuse = self.ffm(x_spatial, x_context8) + + outs = [x_fuse, x_context8, x_context16] + outs = [outs[i] for i in self.out_indices] + return tuple(outs) diff --git a/data_utils/easyportrait/mmseg/models/backbones/bisenetv2.py b/data_utils/easyportrait/mmseg/models/backbones/bisenetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..d908b321caa4acacd71c997f16b6ac99b4256434 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/bisenetv2.py @@ -0,0 +1,622 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, + build_activation_layer, build_norm_layer) +from mmcv.runner import BaseModule + +from mmseg.ops import resize +from ..builder import BACKBONES + + +class DetailBranch(BaseModule): + """Detail Branch with wide channels and shallow layers to capture low-level + details and generate high-resolution feature representation. + + Args: + detail_channels (Tuple[int]): Size of channel numbers of each stage + in Detail Branch, in paper it has 3 stages. + Default: (64, 64, 128). + in_channels (int): Number of channels of input image. Default: 3. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + x (torch.Tensor): Feature map of Detail Branch. + """ + + def __init__(self, + detail_channels=(64, 64, 128), + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(DetailBranch, self).__init__(init_cfg=init_cfg) + detail_branch = [] + for i in range(len(detail_channels)): + if i == 0: + detail_branch.append( + nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=detail_channels[i], + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=detail_channels[i], + out_channels=detail_channels[i], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg))) + else: + detail_branch.append( + nn.Sequential( + ConvModule( + in_channels=detail_channels[i - 1], + out_channels=detail_channels[i], + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=detail_channels[i], + out_channels=detail_channels[i], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=detail_channels[i], + out_channels=detail_channels[i], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg))) + self.detail_branch = nn.ModuleList(detail_branch) + + def forward(self, x): + for stage in self.detail_branch: + x = stage(x) + return x + + +class StemBlock(BaseModule): + """Stem Block at the beginning of Semantic Branch. + + Args: + in_channels (int): Number of input channels. + Default: 3. + out_channels (int): Number of output channels. + Default: 16. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + x (torch.Tensor): First feature map in Semantic Branch. + """ + + def __init__(self, + in_channels=3, + out_channels=16, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(StemBlock, self).__init__(init_cfg=init_cfg) + + self.conv_first = ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.convs = nn.Sequential( + ConvModule( + in_channels=out_channels, + out_channels=out_channels // 2, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=out_channels // 2, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.pool = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1, ceil_mode=False) + self.fuse_last = ConvModule( + in_channels=out_channels * 2, + out_channels=out_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x): + x = self.conv_first(x) + x_left = self.convs(x) + x_right = self.pool(x) + x = self.fuse_last(torch.cat([x_left, x_right], dim=1)) + return x + + +class GELayer(BaseModule): + """Gather-and-Expansion Layer. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + exp_ratio (int): Expansion ratio for middle channels. + Default: 6. + stride (int): Stride of GELayer. Default: 1 + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + x (torch.Tensor): Intermediate feature map in + Semantic Branch. + """ + + def __init__(self, + in_channels, + out_channels, + exp_ratio=6, + stride=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(GELayer, self).__init__(init_cfg=init_cfg) + mid_channel = in_channels * exp_ratio + self.conv1 = ConvModule( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if stride == 1: + self.dwconv = nn.Sequential( + # ReLU in ConvModule not shown in paper + ConvModule( + in_channels=in_channels, + out_channels=mid_channel, + kernel_size=3, + stride=stride, + padding=1, + groups=in_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.shortcut = None + else: + self.dwconv = nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=mid_channel, + kernel_size=3, + stride=stride, + padding=1, + groups=in_channels, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + # ReLU in ConvModule not shown in paper + ConvModule( + in_channels=mid_channel, + out_channels=mid_channel, + kernel_size=3, + stride=1, + padding=1, + groups=mid_channel, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ) + self.shortcut = nn.Sequential( + DepthwiseSeparableConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + padding=1, + dw_norm_cfg=norm_cfg, + dw_act_cfg=None, + pw_norm_cfg=norm_cfg, + pw_act_cfg=None, + )) + + self.conv2 = nn.Sequential( + ConvModule( + in_channels=mid_channel, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None, + )) + + self.act = build_activation_layer(act_cfg) + + def forward(self, x): + identity = x + x = self.conv1(x) + x = self.dwconv(x) + x = self.conv2(x) + if self.shortcut is not None: + shortcut = self.shortcut(identity) + x = x + shortcut + else: + x = x + identity + x = self.act(x) + return x + + +class CEBlock(BaseModule): + """Context Embedding Block for large receptive filed in Semantic Branch. + + Args: + in_channels (int): Number of input channels. + Default: 3. + out_channels (int): Number of output channels. + Default: 16. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + x (torch.Tensor): Last feature map in Semantic Branch. + """ + + def __init__(self, + in_channels=3, + out_channels=16, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(CEBlock, self).__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.gap = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + build_norm_layer(norm_cfg, self.in_channels)[1]) + self.conv_gap = ConvModule( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + # Note: in paper here is naive conv2d, no bn-relu + self.conv_last = ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x): + identity = x + x = self.gap(x) + x = self.conv_gap(x) + x = identity + x + x = self.conv_last(x) + return x + + +class SemanticBranch(BaseModule): + """Semantic Branch which is lightweight with narrow channels and deep + layers to obtain high-level semantic context. + + Args: + semantic_channels(Tuple[int]): Size of channel numbers of + various stages in Semantic Branch. + Default: (16, 32, 64, 128). + in_channels (int): Number of channels of input image. Default: 3. + exp_ratio (int): Expansion ratio for middle channels. + Default: 6. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + semantic_outs (List[torch.Tensor]): List of several feature maps + for auxiliary heads (Booster) and Bilateral + Guided Aggregation Layer. + """ + + def __init__(self, + semantic_channels=(16, 32, 64, 128), + in_channels=3, + exp_ratio=6, + init_cfg=None): + super(SemanticBranch, self).__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.semantic_channels = semantic_channels + self.semantic_stages = [] + for i in range(len(semantic_channels)): + stage_name = f'stage{i + 1}' + self.semantic_stages.append(stage_name) + if i == 0: + self.add_module( + stage_name, + StemBlock(self.in_channels, semantic_channels[i])) + elif i == (len(semantic_channels) - 1): + self.add_module( + stage_name, + nn.Sequential( + GELayer(semantic_channels[i - 1], semantic_channels[i], + exp_ratio, 2), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1))) + else: + self.add_module( + stage_name, + nn.Sequential( + GELayer(semantic_channels[i - 1], semantic_channels[i], + exp_ratio, 2), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1))) + + self.add_module(f'stage{len(semantic_channels)}_CEBlock', + CEBlock(semantic_channels[-1], semantic_channels[-1])) + self.semantic_stages.append(f'stage{len(semantic_channels)}_CEBlock') + + def forward(self, x): + semantic_outs = [] + for stage_name in self.semantic_stages: + semantic_stage = getattr(self, stage_name) + x = semantic_stage(x) + semantic_outs.append(x) + return semantic_outs + + +class BGALayer(BaseModule): + """Bilateral Guided Aggregation Layer to fuse the complementary information + from both Detail Branch and Semantic Branch. + + Args: + out_channels (int): Number of output channels. + Default: 128. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + output (torch.Tensor): Output feature map for Segment heads. + """ + + def __init__(self, + out_channels=128, + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(BGALayer, self).__init__(init_cfg=init_cfg) + self.out_channels = out_channels + self.align_corners = align_corners + self.detail_dwconv = nn.Sequential( + DepthwiseSeparableConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + dw_norm_cfg=norm_cfg, + dw_act_cfg=None, + pw_norm_cfg=None, + pw_act_cfg=None, + )) + self.detail_down = nn.Sequential( + ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)) + self.semantic_conv = nn.Sequential( + ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None)) + self.semantic_dwconv = nn.Sequential( + DepthwiseSeparableConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + dw_norm_cfg=norm_cfg, + dw_act_cfg=None, + pw_norm_cfg=None, + pw_act_cfg=None, + )) + self.conv = ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + inplace=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + + def forward(self, x_d, x_s): + detail_dwconv = self.detail_dwconv(x_d) + detail_down = self.detail_down(x_d) + semantic_conv = self.semantic_conv(x_s) + semantic_dwconv = self.semantic_dwconv(x_s) + semantic_conv = resize( + input=semantic_conv, + size=detail_dwconv.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + fuse_1 = detail_dwconv * torch.sigmoid(semantic_conv) + fuse_2 = detail_down * torch.sigmoid(semantic_dwconv) + fuse_2 = resize( + input=fuse_2, + size=fuse_1.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + output = self.conv(fuse_1 + fuse_2) + return output + + +@BACKBONES.register_module() +class BiSeNetV2(BaseModule): + """BiSeNetV2: Bilateral Network with Guided Aggregation for + Real-time Semantic Segmentation. + + This backbone is the implementation of + `BiSeNetV2 `_. + + Args: + in_channels (int): Number of channel of input image. Default: 3. + detail_channels (Tuple[int], optional): Channels of each stage + in Detail Branch. Default: (64, 64, 128). + semantic_channels (Tuple[int], optional): Channels of each stage + in Semantic Branch. Default: (16, 32, 64, 128). + See Table 1 and Figure 3 of paper for more details. + semantic_expansion_ratio (int, optional): The expansion factor + expanding channel number of middle channels in Semantic Branch. + Default: 6. + bga_channels (int, optional): Number of middle channels in + Bilateral Guided Aggregation Layer. Default: 128. + out_indices (Tuple[int] | int, optional): Output from which stages. + Default: (0, 1, 2, 3, 4). + align_corners (bool, optional): The align_corners argument of + resize operation in Bilateral Guided Aggregation Layer. + Default: False. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels=3, + detail_channels=(64, 64, 128), + semantic_channels=(16, 32, 64, 128), + semantic_expansion_ratio=6, + bga_channels=128, + out_indices=(0, 1, 2, 3, 4), + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + if init_cfg is None: + init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ] + super(BiSeNetV2, self).__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_indices = out_indices + self.detail_channels = detail_channels + self.semantic_channels = semantic_channels + self.semantic_expansion_ratio = semantic_expansion_ratio + self.bga_channels = bga_channels + self.align_corners = align_corners + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.detail = DetailBranch(self.detail_channels, self.in_channels) + self.semantic = SemanticBranch(self.semantic_channels, + self.in_channels, + self.semantic_expansion_ratio) + self.bga = BGALayer(self.bga_channels, self.align_corners) + + def forward(self, x): + # stole refactoring code from Coin Cheung, thanks + x_detail = self.detail(x) + x_semantic_lst = self.semantic(x) + x_head = self.bga(x_detail, x_semantic_lst[-1]) + outs = [x_head] + x_semantic_lst[:-1] + outs = [outs[i] for i in self.out_indices] + return tuple(outs) diff --git a/data_utils/easyportrait/mmseg/models/backbones/cgnet.py b/data_utils/easyportrait/mmseg/models/backbones/cgnet.py new file mode 100644 index 0000000000000000000000000000000000000000..168194c106ed19fd46891de7ee8214a4db88f3c7 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/cgnet.py @@ -0,0 +1,372 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer +from mmcv.runner import BaseModule +from mmcv.utils.parrots_wrapper import _BatchNorm + +from ..builder import BACKBONES + + +class GlobalContextExtractor(nn.Module): + """Global Context Extractor for CGNet. + + This class is employed to refine the joint feature of both local feature + and surrounding context. + + Args: + channel (int): Number of input feature channels. + reduction (int): Reductions for global context extractor. Default: 16. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, channel, reduction=16, with_cp=False): + super(GlobalContextExtractor, self).__init__() + self.channel = channel + self.reduction = reduction + assert reduction >= 1 and channel >= reduction + self.with_cp = with_cp + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), nn.Sigmoid()) + + def forward(self, x): + + def _inner_forward(x): + num_batch, num_channel = x.size()[:2] + y = self.avg_pool(x).view(num_batch, num_channel) + y = self.fc(y).view(num_batch, num_channel, 1, 1) + return x * y + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class ContextGuidedBlock(nn.Module): + """Context Guided Block for CGNet. + + This class consists of four components: local feature extractor, + surrounding feature extractor, joint feature extractor and global + context extractor. + + Args: + in_channels (int): Number of input feature channels. + out_channels (int): Number of output feature channels. + dilation (int): Dilation rate for surrounding context extractor. + Default: 2. + reduction (int): Reduction for global context extractor. Default: 16. + skip_connect (bool): Add input to output or not. Default: True. + downsample (bool): Downsample the input to 1/2 or not. Default: False. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='PReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + in_channels, + out_channels, + dilation=2, + reduction=16, + skip_connect=True, + downsample=False, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='PReLU'), + with_cp=False): + super(ContextGuidedBlock, self).__init__() + self.with_cp = with_cp + self.downsample = downsample + + channels = out_channels if downsample else out_channels // 2 + if 'type' in act_cfg and act_cfg['type'] == 'PReLU': + act_cfg['num_parameters'] = channels + kernel_size = 3 if downsample else 1 + stride = 2 if downsample else 1 + padding = (kernel_size - 1) // 2 + + self.conv1x1 = ConvModule( + in_channels, + channels, + kernel_size, + stride, + padding, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.f_loc = build_conv_layer( + conv_cfg, + channels, + channels, + kernel_size=3, + padding=1, + groups=channels, + bias=False) + self.f_sur = build_conv_layer( + conv_cfg, + channels, + channels, + kernel_size=3, + padding=dilation, + groups=channels, + dilation=dilation, + bias=False) + + self.bn = build_norm_layer(norm_cfg, 2 * channels)[1] + self.activate = nn.PReLU(2 * channels) + + if downsample: + self.bottleneck = build_conv_layer( + conv_cfg, + 2 * channels, + out_channels, + kernel_size=1, + bias=False) + + self.skip_connect = skip_connect and not downsample + self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp) + + def forward(self, x): + + def _inner_forward(x): + out = self.conv1x1(x) + loc = self.f_loc(out) + sur = self.f_sur(out) + + joi_feat = torch.cat([loc, sur], 1) # the joint feature + joi_feat = self.bn(joi_feat) + joi_feat = self.activate(joi_feat) + if self.downsample: + joi_feat = self.bottleneck(joi_feat) # channel = out_channels + # f_glo is employed to refine the joint feature + out = self.f_glo(joi_feat) + + if self.skip_connect: + return x + out + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class InputInjection(nn.Module): + """Downsampling module for CGNet.""" + + def __init__(self, num_downsampling): + super(InputInjection, self).__init__() + self.pool = nn.ModuleList() + for i in range(num_downsampling): + self.pool.append(nn.AvgPool2d(3, stride=2, padding=1)) + + def forward(self, x): + for pool in self.pool: + x = pool(x) + return x + + +@BACKBONES.register_module() +class CGNet(BaseModule): + """CGNet backbone. + + This backbone is the implementation of `A Light-weight Context Guided + Network for Semantic Segmentation `_. + + Args: + in_channels (int): Number of input image channels. Normally 3. + num_channels (tuple[int]): Numbers of feature channels at each stages. + Default: (32, 64, 128). + num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2. + Default: (3, 21). + dilations (tuple[int]): Dilation rate for surrounding context + extractors at stage 1 and stage 2. Default: (2, 4). + reductions (tuple[int]): Reductions for global context extractors at + stage 1 and stage 2. Default: (8, 16). + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='PReLU'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels=3, + num_channels=(32, 64, 128), + num_blocks=(3, 21), + dilations=(2, 4), + reductions=(8, 16), + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='PReLU'), + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None): + + super(CGNet, self).__init__(init_cfg) + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer=['Conv2d', 'Linear']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']), + dict(type='Constant', val=0, layer='PReLU') + ] + else: + raise TypeError('pretrained must be a str or None') + + self.in_channels = in_channels + self.num_channels = num_channels + assert isinstance(self.num_channels, tuple) and len( + self.num_channels) == 3 + self.num_blocks = num_blocks + assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2 + self.dilations = dilations + assert isinstance(self.dilations, tuple) and len(self.dilations) == 2 + self.reductions = reductions + assert isinstance(self.reductions, tuple) and len(self.reductions) == 2 + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + if 'type' in self.act_cfg and self.act_cfg['type'] == 'PReLU': + self.act_cfg['num_parameters'] = num_channels[0] + self.norm_eval = norm_eval + self.with_cp = with_cp + + cur_channels = in_channels + self.stem = nn.ModuleList() + for i in range(3): + self.stem.append( + ConvModule( + cur_channels, + num_channels[0], + 3, + 2 if i == 0 else 1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + cur_channels = num_channels[0] + + self.inject_2x = InputInjection(1) # down-sample for Input, factor=2 + self.inject_4x = InputInjection(2) # down-sample for Input, factor=4 + + cur_channels += in_channels + self.norm_prelu_0 = nn.Sequential( + build_norm_layer(norm_cfg, cur_channels)[1], + nn.PReLU(cur_channels)) + + # stage 1 + self.level1 = nn.ModuleList() + for i in range(num_blocks[0]): + self.level1.append( + ContextGuidedBlock( + cur_channels if i == 0 else num_channels[1], + num_channels[1], + dilations[0], + reductions[0], + downsample=(i == 0), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + with_cp=with_cp)) # CG block + + cur_channels = 2 * num_channels[1] + in_channels + self.norm_prelu_1 = nn.Sequential( + build_norm_layer(norm_cfg, cur_channels)[1], + nn.PReLU(cur_channels)) + + # stage 2 + self.level2 = nn.ModuleList() + for i in range(num_blocks[1]): + self.level2.append( + ContextGuidedBlock( + cur_channels if i == 0 else num_channels[2], + num_channels[2], + dilations[1], + reductions[1], + downsample=(i == 0), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + with_cp=with_cp)) # CG block + + cur_channels = 2 * num_channels[2] + self.norm_prelu_2 = nn.Sequential( + build_norm_layer(norm_cfg, cur_channels)[1], + nn.PReLU(cur_channels)) + + def forward(self, x): + output = [] + + # stage 0 + inp_2x = self.inject_2x(x) + inp_4x = self.inject_4x(x) + for layer in self.stem: + x = layer(x) + x = self.norm_prelu_0(torch.cat([x, inp_2x], 1)) + output.append(x) + + # stage 1 + for i, layer in enumerate(self.level1): + x = layer(x) + if i == 0: + down1 = x + x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1)) + output.append(x) + + # stage 2 + for i, layer in enumerate(self.level2): + x = layer(x) + if i == 0: + down2 = x + x = self.norm_prelu_2(torch.cat([down2, x], 1)) + output.append(x) + + return output + + def train(self, mode=True): + """Convert the model into training mode will keeping the normalization + layer freezed.""" + super(CGNet, self).train(mode) + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/data_utils/easyportrait/mmseg/models/backbones/erfnet.py b/data_utils/easyportrait/mmseg/models/backbones/erfnet.py new file mode 100644 index 0000000000000000000000000000000000000000..8921c18f3cc30bf8104ae31f53fb70470e12d30b --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/erfnet.py @@ -0,0 +1,329 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer +from mmcv.runner import BaseModule + +from mmseg.ops import resize +from ..builder import BACKBONES + + +class DownsamplerBlock(BaseModule): + """Downsampler block of ERFNet. + + This module is a little different from basical ConvModule. + The features from Conv and MaxPool layers are + concatenated before BatchNorm. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-3), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(DownsamplerBlock, self).__init__(init_cfg=init_cfg) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.conv = build_conv_layer( + self.conv_cfg, + in_channels, + out_channels - in_channels, + kernel_size=3, + stride=2, + padding=1) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.bn = build_norm_layer(self.norm_cfg, out_channels)[1] + self.act = build_activation_layer(self.act_cfg) + + def forward(self, input): + conv_out = self.conv(input) + pool_out = self.pool(input) + pool_out = resize( + input=pool_out, + size=conv_out.size()[2:], + mode='bilinear', + align_corners=False) + output = torch.cat([conv_out, pool_out], 1) + output = self.bn(output) + output = self.act(output) + return output + + +class NonBottleneck1d(BaseModule): + """Non-bottleneck block of ERFNet. + + Args: + channels (int): Number of channels in Non-bottleneck block. + drop_rate (float): Probability of an element to be zeroed. + Default 0. + dilation (int): Dilation rate for last two conv layers. + Default 1. + num_conv_layer (int): Number of 3x1 and 1x3 convolution layers. + Default 2. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + channels, + drop_rate=0, + dilation=1, + num_conv_layer=2, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-3), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(NonBottleneck1d, self).__init__(init_cfg=init_cfg) + + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.act = build_activation_layer(self.act_cfg) + + self.convs_layers = nn.ModuleList() + for conv_layer in range(num_conv_layer): + first_conv_padding = (1, 0) if conv_layer == 0 else (dilation, 0) + first_conv_dilation = 1 if conv_layer == 0 else (dilation, 1) + second_conv_padding = (0, 1) if conv_layer == 0 else (0, dilation) + second_conv_dilation = 1 if conv_layer == 0 else (1, dilation) + + self.convs_layers.append( + build_conv_layer( + self.conv_cfg, + channels, + channels, + kernel_size=(3, 1), + stride=1, + padding=first_conv_padding, + bias=True, + dilation=first_conv_dilation)) + self.convs_layers.append(self.act) + self.convs_layers.append( + build_conv_layer( + self.conv_cfg, + channels, + channels, + kernel_size=(1, 3), + stride=1, + padding=second_conv_padding, + bias=True, + dilation=second_conv_dilation)) + self.convs_layers.append( + build_norm_layer(self.norm_cfg, channels)[1]) + if conv_layer == 0: + self.convs_layers.append(self.act) + else: + self.convs_layers.append(nn.Dropout(p=drop_rate)) + + def forward(self, input): + output = input + for conv in self.convs_layers: + output = conv(output) + output = self.act(output + input) + return output + + +class UpsamplerBlock(BaseModule): + """Upsampler block of ERFNet. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-3), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(UpsamplerBlock, self).__init__(init_cfg=init_cfg) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.conv = nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + bias=True) + self.bn = build_norm_layer(self.norm_cfg, out_channels)[1] + self.act = build_activation_layer(self.act_cfg) + + def forward(self, input): + output = self.conv(input) + output = self.bn(output) + output = self.act(output) + return output + + +@BACKBONES.register_module() +class ERFNet(BaseModule): + """ERFNet backbone. + + This backbone is the implementation of `ERFNet: Efficient Residual + Factorized ConvNet for Real-time SemanticSegmentation + `_. + + Args: + in_channels (int): The number of channels of input + image. Default: 3. + enc_downsample_channels (Tuple[int]): Size of channel + numbers of various Downsampler block in encoder. + Default: (16, 64, 128). + enc_stage_non_bottlenecks (Tuple[int]): Number of stages of + Non-bottleneck block in encoder. + Default: (5, 8). + enc_non_bottleneck_dilations (Tuple[int]): Dilation rate of each + stage of Non-bottleneck block of encoder. + Default: (2, 4, 8, 16). + enc_non_bottleneck_channels (Tuple[int]): Size of channel + numbers of various Non-bottleneck block in encoder. + Default: (64, 128). + dec_upsample_channels (Tuple[int]): Size of channel numbers of + various Deconvolution block in decoder. + Default: (64, 16). + dec_stages_non_bottleneck (Tuple[int]): Number of stages of + Non-bottleneck block in decoder. + Default: (2, 2). + dec_non_bottleneck_channels (Tuple[int]): Size of channel + numbers of various Non-bottleneck block in decoder. + Default: (64, 16). + drop_rate (float): Probability of an element to be zeroed. + Default 0.1. + """ + + def __init__(self, + in_channels=3, + enc_downsample_channels=(16, 64, 128), + enc_stage_non_bottlenecks=(5, 8), + enc_non_bottleneck_dilations=(2, 4, 8, 16), + enc_non_bottleneck_channels=(64, 128), + dec_upsample_channels=(64, 16), + dec_stages_non_bottleneck=(2, 2), + dec_non_bottleneck_channels=(64, 16), + dropout_ratio=0.1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU'), + init_cfg=None): + + super(ERFNet, self).__init__(init_cfg=init_cfg) + assert len(enc_downsample_channels) \ + == len(dec_upsample_channels)+1, 'Number of downsample\ + block of encoder does not \ + match number of upsample block of decoder!' + assert len(enc_downsample_channels) \ + == len(enc_stage_non_bottlenecks)+1, 'Number of \ + downsample block of encoder does not match \ + number of Non-bottleneck block of encoder!' + assert len(enc_downsample_channels) \ + == len(enc_non_bottleneck_channels)+1, 'Number of \ + downsample block of encoder does not match \ + number of channels of Non-bottleneck block of encoder!' + assert enc_stage_non_bottlenecks[-1] \ + % len(enc_non_bottleneck_dilations) == 0, 'Number of \ + Non-bottleneck block of encoder does not match \ + number of Non-bottleneck block of encoder!' + assert len(dec_upsample_channels) \ + == len(dec_stages_non_bottleneck), 'Number of \ + upsample block of decoder does not match \ + number of Non-bottleneck block of decoder!' + assert len(dec_stages_non_bottleneck) \ + == len(dec_non_bottleneck_channels), 'Number of \ + Non-bottleneck block of decoder does not match \ + number of channels of Non-bottleneck block of decoder!' + + self.in_channels = in_channels + self.enc_downsample_channels = enc_downsample_channels + self.enc_stage_non_bottlenecks = enc_stage_non_bottlenecks + self.enc_non_bottleneck_dilations = enc_non_bottleneck_dilations + self.enc_non_bottleneck_channels = enc_non_bottleneck_channels + self.dec_upsample_channels = dec_upsample_channels + self.dec_stages_non_bottleneck = dec_stages_non_bottleneck + self.dec_non_bottleneck_channels = dec_non_bottleneck_channels + self.dropout_ratio = dropout_ratio + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.encoder.append( + DownsamplerBlock(self.in_channels, enc_downsample_channels[0])) + + for i in range(len(enc_downsample_channels) - 1): + self.encoder.append( + DownsamplerBlock(enc_downsample_channels[i], + enc_downsample_channels[i + 1])) + # Last part of encoder is some dilated NonBottleneck1d blocks. + if i == len(enc_downsample_channels) - 2: + iteration_times = int(enc_stage_non_bottlenecks[-1] / + len(enc_non_bottleneck_dilations)) + for j in range(iteration_times): + for k in range(len(enc_non_bottleneck_dilations)): + self.encoder.append( + NonBottleneck1d(enc_downsample_channels[-1], + self.dropout_ratio, + enc_non_bottleneck_dilations[k])) + else: + for j in range(enc_stage_non_bottlenecks[i]): + self.encoder.append( + NonBottleneck1d(enc_downsample_channels[i + 1], + self.dropout_ratio)) + + for i in range(len(dec_upsample_channels)): + if i == 0: + self.decoder.append( + UpsamplerBlock(enc_downsample_channels[-1], + dec_non_bottleneck_channels[i])) + else: + self.decoder.append( + UpsamplerBlock(dec_non_bottleneck_channels[i - 1], + dec_non_bottleneck_channels[i])) + for j in range(dec_stages_non_bottleneck[i]): + self.decoder.append( + NonBottleneck1d(dec_non_bottleneck_channels[i])) + + def forward(self, x): + for enc in self.encoder: + x = enc(x) + for dec in self.decoder: + x = dec(x) + return [x] diff --git a/data_utils/easyportrait/mmseg/models/backbones/fast_scnn.py b/data_utils/easyportrait/mmseg/models/backbones/fast_scnn.py new file mode 100644 index 0000000000000000000000000000000000000000..cbfbcaf4f3bf508098a281cffa87f94bd1bdd47c --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/fast_scnn.py @@ -0,0 +1,409 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule +from mmcv.runner import BaseModule + +from mmseg.models.decode_heads.psp_head import PPM +from mmseg.ops import resize +from ..builder import BACKBONES +from ..utils import InvertedResidual + + +class LearningToDownsample(nn.Module): + """Learning to downsample module. + + Args: + in_channels (int): Number of input channels. + dw_channels (tuple[int]): Number of output channels of the first and + the second depthwise conv (dwconv) layers. + out_channels (int): Number of output channels of the whole + 'learning to downsample' module. + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + act_cfg (dict): Config of activation layers. Default: + dict(type='ReLU') + dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config + of depthwise ConvModule. If it is 'default', it will be the same + as `act_cfg`. Default: None. + """ + + def __init__(self, + in_channels, + dw_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + dw_act_cfg=None): + super(LearningToDownsample, self).__init__() + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.dw_act_cfg = dw_act_cfg + dw_channels1 = dw_channels[0] + dw_channels2 = dw_channels[1] + + self.conv = ConvModule( + in_channels, + dw_channels1, + 3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.dsconv1 = DepthwiseSeparableConvModule( + dw_channels1, + dw_channels2, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + dw_act_cfg=self.dw_act_cfg) + + self.dsconv2 = DepthwiseSeparableConvModule( + dw_channels2, + out_channels, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + dw_act_cfg=self.dw_act_cfg) + + def forward(self, x): + x = self.conv(x) + x = self.dsconv1(x) + x = self.dsconv2(x) + return x + + +class GlobalFeatureExtractor(nn.Module): + """Global feature extractor module. + + Args: + in_channels (int): Number of input channels of the GFE module. + Default: 64 + block_channels (tuple[int]): Tuple of ints. Each int specifies the + number of output channels of each Inverted Residual module. + Default: (64, 96, 128) + out_channels(int): Number of output channels of the GFE module. + Default: 128 + expand_ratio (int): Adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + Default: 6 + num_blocks (tuple[int]): Tuple of ints. Each int specifies the + number of times each Inverted Residual module is repeated. + The repeated Inverted Residual modules are called a 'group'. + Default: (3, 3, 3) + strides (tuple[int]): Tuple of ints. Each int specifies + the downsampling factor of each 'group'. + Default: (2, 2, 1) + pool_scales (tuple[int]): Tuple of ints. Each int specifies + the parameter required in 'global average pooling' within PPM. + Default: (1, 2, 3, 6) + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + act_cfg (dict): Config of activation layers. Default: + dict(type='ReLU') + align_corners (bool): align_corners argument of F.interpolate. + Default: False + """ + + def __init__(self, + in_channels=64, + block_channels=(64, 96, 128), + out_channels=128, + expand_ratio=6, + num_blocks=(3, 3, 3), + strides=(2, 2, 1), + pool_scales=(1, 2, 3, 6), + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False): + super(GlobalFeatureExtractor, self).__init__() + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + assert len(block_channels) == len(num_blocks) == 3 + self.bottleneck1 = self._make_layer(in_channels, block_channels[0], + num_blocks[0], strides[0], + expand_ratio) + self.bottleneck2 = self._make_layer(block_channels[0], + block_channels[1], num_blocks[1], + strides[1], expand_ratio) + self.bottleneck3 = self._make_layer(block_channels[1], + block_channels[2], num_blocks[2], + strides[2], expand_ratio) + self.ppm = PPM( + pool_scales, + block_channels[2], + block_channels[2] // 4, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=align_corners) + + self.out = ConvModule( + block_channels[2] * 2, + out_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def _make_layer(self, + in_channels, + out_channels, + blocks, + stride=1, + expand_ratio=6): + layers = [ + InvertedResidual( + in_channels, + out_channels, + stride, + expand_ratio, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + ] + for i in range(1, blocks): + layers.append( + InvertedResidual( + out_channels, + out_channels, + 1, + expand_ratio, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + return nn.Sequential(*layers) + + def forward(self, x): + x = self.bottleneck1(x) + x = self.bottleneck2(x) + x = self.bottleneck3(x) + x = torch.cat([x, *self.ppm(x)], dim=1) + x = self.out(x) + return x + + +class FeatureFusionModule(nn.Module): + """Feature fusion module. + + Args: + higher_in_channels (int): Number of input channels of the + higher-resolution branch. + lower_in_channels (int): Number of input channels of the + lower-resolution branch. + out_channels (int): Number of output channels. + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + dwconv_act_cfg (dict): Config of activation layers in 3x3 conv. + Default: dict(type='ReLU'). + conv_act_cfg (dict): Config of activation layers in the two 1x1 conv. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + """ + + def __init__(self, + higher_in_channels, + lower_in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dwconv_act_cfg=dict(type='ReLU'), + conv_act_cfg=None, + align_corners=False): + super(FeatureFusionModule, self).__init__() + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.dwconv_act_cfg = dwconv_act_cfg + self.conv_act_cfg = conv_act_cfg + self.align_corners = align_corners + self.dwconv = ConvModule( + lower_in_channels, + out_channels, + 3, + padding=1, + groups=out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.dwconv_act_cfg) + self.conv_lower_res = ConvModule( + out_channels, + out_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.conv_act_cfg) + + self.conv_higher_res = ConvModule( + higher_in_channels, + out_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.conv_act_cfg) + + self.relu = nn.ReLU(True) + + def forward(self, higher_res_feature, lower_res_feature): + lower_res_feature = resize( + lower_res_feature, + size=higher_res_feature.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + lower_res_feature = self.dwconv(lower_res_feature) + lower_res_feature = self.conv_lower_res(lower_res_feature) + + higher_res_feature = self.conv_higher_res(higher_res_feature) + out = higher_res_feature + lower_res_feature + return self.relu(out) + + +@BACKBONES.register_module() +class FastSCNN(BaseModule): + """Fast-SCNN Backbone. + + This backbone is the implementation of `Fast-SCNN: Fast Semantic + Segmentation Network `_. + + Args: + in_channels (int): Number of input image channels. Default: 3. + downsample_dw_channels (tuple[int]): Number of output channels after + the first conv layer & the second conv layer in + Learning-To-Downsample (LTD) module. + Default: (32, 48). + global_in_channels (int): Number of input channels of + Global Feature Extractor(GFE). + Equal to number of output channels of LTD. + Default: 64. + global_block_channels (tuple[int]): Tuple of integers that describe + the output channels for each of the MobileNet-v2 bottleneck + residual blocks in GFE. + Default: (64, 96, 128). + global_block_strides (tuple[int]): Tuple of integers + that describe the strides (downsampling factors) for each of the + MobileNet-v2 bottleneck residual blocks in GFE. + Default: (2, 2, 1). + global_out_channels (int): Number of output channels of GFE. + Default: 128. + higher_in_channels (int): Number of input channels of the higher + resolution branch in FFM. + Equal to global_in_channels. + Default: 64. + lower_in_channels (int): Number of input channels of the lower + resolution branch in FFM. + Equal to global_out_channels. + Default: 128. + fusion_out_channels (int): Number of output channels of FFM. + Default: 128. + out_indices (tuple): Tuple of indices of list + [higher_res_features, lower_res_features, fusion_output]. + Often set to (0,1,2) to enable aux. heads. + Default: (0, 1, 2). + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + act_cfg (dict): Config of activation layers. Default: + dict(type='ReLU') + align_corners (bool): align_corners argument of F.interpolate. + Default: False + dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config + of depthwise ConvModule. If it is 'default', it will be the same + as `act_cfg`. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels=3, + downsample_dw_channels=(32, 48), + global_in_channels=64, + global_block_channels=(64, 96, 128), + global_block_strides=(2, 2, 1), + global_out_channels=128, + higher_in_channels=64, + lower_in_channels=128, + fusion_out_channels=128, + out_indices=(0, 1, 2), + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False, + dw_act_cfg=None, + init_cfg=None): + + super(FastSCNN, self).__init__(init_cfg) + + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ] + + if global_in_channels != higher_in_channels: + raise AssertionError('Global Input Channels must be the same \ + with Higher Input Channels!') + elif global_out_channels != lower_in_channels: + raise AssertionError('Global Output Channels must be the same \ + with Lower Input Channels!') + + self.in_channels = in_channels + self.downsample_dw_channels1 = downsample_dw_channels[0] + self.downsample_dw_channels2 = downsample_dw_channels[1] + self.global_in_channels = global_in_channels + self.global_block_channels = global_block_channels + self.global_block_strides = global_block_strides + self.global_out_channels = global_out_channels + self.higher_in_channels = higher_in_channels + self.lower_in_channels = lower_in_channels + self.fusion_out_channels = fusion_out_channels + self.out_indices = out_indices + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.align_corners = align_corners + self.learning_to_downsample = LearningToDownsample( + in_channels, + downsample_dw_channels, + global_in_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + dw_act_cfg=dw_act_cfg) + self.global_feature_extractor = GlobalFeatureExtractor( + global_in_channels, + global_block_channels, + global_out_channels, + strides=self.global_block_strides, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + self.feature_fusion = FeatureFusionModule( + higher_in_channels, + lower_in_channels, + fusion_out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + dwconv_act_cfg=self.act_cfg, + align_corners=self.align_corners) + + def forward(self, x): + higher_res_features = self.learning_to_downsample(x) + lower_res_features = self.global_feature_extractor(higher_res_features) + fusion_output = self.feature_fusion(higher_res_features, + lower_res_features) + + outs = [higher_res_features, lower_res_features, fusion_output] + outs = [outs[i] for i in self.out_indices] + return tuple(outs) diff --git a/data_utils/easyportrait/mmseg/models/backbones/hrnet.py b/data_utils/easyportrait/mmseg/models/backbones/hrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..90feadcf62798f7b7a5e9fdddd0e2202687cc113 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/hrnet.py @@ -0,0 +1,642 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmcv.runner import BaseModule, ModuleList, Sequential +from mmcv.utils.parrots_wrapper import _BatchNorm + +from mmseg.ops import Upsample, resize +from ..builder import BACKBONES +from .resnet import BasicBlock, Bottleneck + + +class HRModule(BaseModule): + """High-Resolution Module for HRNet. + + In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange + is in this module. + """ + + def __init__(self, + num_branches, + blocks, + num_blocks, + in_channels, + num_channels, + multiscale_output=True, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + block_init_cfg=None, + init_cfg=None): + super(HRModule, self).__init__(init_cfg) + self.block_init_cfg = block_init_cfg + self._check_branches(num_branches, num_blocks, in_channels, + num_channels) + + self.in_channels = in_channels + self.num_branches = num_branches + + self.multiscale_output = multiscale_output + self.norm_cfg = norm_cfg + self.conv_cfg = conv_cfg + self.with_cp = with_cp + self.branches = self._make_branches(num_branches, blocks, num_blocks, + num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=False) + + def _check_branches(self, num_branches, num_blocks, in_channels, + num_channels): + """Check branches configuration.""" + if num_branches != len(num_blocks): + error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \ + f'{len(num_blocks)})' + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \ + f'{len(num_channels)})' + raise ValueError(error_msg) + + if num_branches != len(in_channels): + error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \ + f'{len(in_channels)})' + raise ValueError(error_msg) + + def _make_one_branch(self, + branch_index, + block, + num_blocks, + num_channels, + stride=1): + """Build one branch.""" + downsample = None + if stride != 1 or \ + self.in_channels[branch_index] != \ + num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + build_conv_layer( + self.conv_cfg, + self.in_channels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer(self.norm_cfg, num_channels[branch_index] * + block.expansion)[1]) + + layers = [] + layers.append( + block( + self.in_channels[branch_index], + num_channels[branch_index], + stride, + downsample=downsample, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=self.block_init_cfg)) + self.in_channels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append( + block( + self.in_channels[branch_index], + num_channels[branch_index], + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=self.block_init_cfg)) + + return Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + """Build multiple branch.""" + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return ModuleList(branches) + + def _make_fuse_layers(self): + """Build fuse layer.""" + if self.num_branches == 1: + return None + + num_branches = self.num_branches + in_channels = self.in_channels + fuse_layers = [] + num_out_branches = num_branches if self.multiscale_output else 1 + for i in range(num_out_branches): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=1, + stride=1, + padding=0, + bias=False), + build_norm_layer(self.norm_cfg, in_channels[i])[1], + # we set align_corners=False for HRNet + Upsample( + scale_factor=2**(j - i), + mode='bilinear', + align_corners=False))) + elif j == i: + fuse_layer.append(None) + else: + conv_downsamples = [] + for k in range(i - j): + if k == i - j - 1: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[i])[1])) + else: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[j], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[j])[1], + nn.ReLU(inplace=False))) + fuse_layer.append(nn.Sequential(*conv_downsamples)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def forward(self, x): + """Forward function.""" + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = 0 + for j in range(self.num_branches): + if i == j: + y += x[j] + elif j > i: + y = y + resize( + self.fuse_layers[i][j](x[j]), + size=x[i].shape[2:], + mode='bilinear', + align_corners=False) + else: + y += self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + return x_fuse + + +@BACKBONES.register_module() +class HRNet(BaseModule): + """HRNet backbone. + + This backbone is the implementation of `High-Resolution Representations + for Labeling Pixels and Regions `_. + + Args: + extra (dict): Detailed configuration for each stage of HRNet. + There must be 4 stages, the configuration for each stage must have + 5 keys: + + - num_modules (int): The number of HRModule in this stage. + - num_branches (int): The number of branches in the HRModule. + - block (str): The type of convolution block. + - num_blocks (tuple): The number of blocks in each branch. + The length must be equal to num_branches. + - num_channels (tuple): The number of channels in each branch. + The length must be equal to num_branches. + in_channels (int): Number of input image channels. Normally 3. + conv_cfg (dict): Dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Use `BN` by default. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: False. + multiscale_output (bool): Whether to output multi-level features + produced by multiple branches. If False, only the first level + feature will be output. Default: True. + pretrained (str, optional): Model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Example: + >>> from mmseg.models import HRNet + >>> import torch + >>> extra = dict( + >>> stage1=dict( + >>> num_modules=1, + >>> num_branches=1, + >>> block='BOTTLENECK', + >>> num_blocks=(4, ), + >>> num_channels=(64, )), + >>> stage2=dict( + >>> num_modules=1, + >>> num_branches=2, + >>> block='BASIC', + >>> num_blocks=(4, 4), + >>> num_channels=(32, 64)), + >>> stage3=dict( + >>> num_modules=4, + >>> num_branches=3, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4), + >>> num_channels=(32, 64, 128)), + >>> stage4=dict( + >>> num_modules=3, + >>> num_branches=4, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4, 4), + >>> num_channels=(32, 64, 128, 256))) + >>> self = HRNet(extra, in_channels=1) + >>> self.eval() + >>> inputs = torch.rand(1, 1, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 32, 8, 8) + (1, 64, 4, 4) + (1, 128, 2, 2) + (1, 256, 1, 1) + """ + + blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} + + def __init__(self, + extra, + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + with_cp=False, + frozen_stages=-1, + zero_init_residual=False, + multiscale_output=True, + pretrained=None, + init_cfg=None): + super(HRNet, self).__init__(init_cfg) + + self.pretrained = pretrained + self.zero_init_residual = zero_init_residual + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + # Assert configurations of 4 stages are in extra + assert 'stage1' in extra and 'stage2' in extra \ + and 'stage3' in extra and 'stage4' in extra + # Assert whether the length of `num_blocks` and `num_channels` are + # equal to `num_branches` + for i in range(4): + cfg = extra[f'stage{i + 1}'] + assert len(cfg['num_blocks']) == cfg['num_branches'] and \ + len(cfg['num_channels']) == cfg['num_branches'] + + self.extra = extra + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + self.frozen_stages = frozen_stages + + # stem net + self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1) + self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2) + + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + 64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + self.conv_cfg, + 64, + 64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.relu = nn.ReLU(inplace=True) + + # stage 1 + self.stage1_cfg = self.extra['stage1'] + num_channels = self.stage1_cfg['num_channels'][0] + block_type = self.stage1_cfg['block'] + num_blocks = self.stage1_cfg['num_blocks'][0] + + block = self.blocks_dict[block_type] + stage1_out_channels = num_channels * block.expansion + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + + # stage 2 + self.stage2_cfg = self.extra['stage2'] + num_channels = self.stage2_cfg['num_channels'] + block_type = self.stage2_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [channel * block.expansion for channel in num_channels] + self.transition1 = self._make_transition_layer([stage1_out_channels], + num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + # stage 3 + self.stage3_cfg = self.extra['stage3'] + num_channels = self.stage3_cfg['num_channels'] + block_type = self.stage3_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [channel * block.expansion for channel in num_channels] + self.transition2 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + # stage 4 + self.stage4_cfg = self.extra['stage4'] + num_channels = self.stage4_cfg['num_channels'] + block_type = self.stage4_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [channel * block.expansion for channel in num_channels] + self.transition3 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multiscale_output=multiscale_output) + + self._freeze_stages() + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: the normalization layer named "norm2" """ + return getattr(self, self.norm2_name) + + def _make_transition_layer(self, num_channels_pre_layer, + num_channels_cur_layer): + """Make transition layer.""" + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + num_channels_cur_layer[i])[1], + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv_downsamples = [] + for j in range(i + 1 - num_branches_pre): + in_channels = num_channels_pre_layer[-1] + out_channels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else in_channels + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, out_channels)[1], + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv_downsamples)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + """Make each layer.""" + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + build_conv_layer( + self.conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer(self.norm_cfg, planes * block.expansion)[1]) + + layers = [] + block_init_cfg = None + if self.pretrained is None and not hasattr( + self, 'init_cfg') and self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm3')) + + layers.append( + block( + inplanes, + planes, + stride, + downsample=downsample, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=block_init_cfg)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block( + inplanes, + planes, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=block_init_cfg)) + + return Sequential(*layers) + + def _make_stage(self, layer_config, in_channels, multiscale_output=True): + """Make each stage.""" + num_modules = layer_config['num_modules'] + num_branches = layer_config['num_branches'] + num_blocks = layer_config['num_blocks'] + num_channels = layer_config['num_channels'] + block = self.blocks_dict[layer_config['block']] + + hr_modules = [] + block_init_cfg = None + if self.pretrained is None and not hasattr( + self, 'init_cfg') and self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm3')) + + for i in range(num_modules): + # multi_scale_output is only used for the last module + if not multiscale_output and i == num_modules - 1: + reset_multiscale_output = False + else: + reset_multiscale_output = True + + hr_modules.append( + HRModule( + num_branches, + block, + num_blocks, + in_channels, + num_channels, + reset_multiscale_output, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + block_init_cfg=block_init_cfg)) + + return Sequential(*hr_modules), in_channels + + def _freeze_stages(self): + """Freeze stages param and norm stats.""" + if self.frozen_stages >= 0: + + self.norm1.eval() + self.norm2.eval() + for m in [self.conv1, self.norm1, self.conv2, self.norm2]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + if i == 1: + m = getattr(self, f'layer{i}') + t = getattr(self, f'transition{i}') + elif i == 4: + m = getattr(self, f'stage{i}') + else: + m = getattr(self, f'stage{i}') + t = getattr(self, f'transition{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + t.eval() + for param in t.parameters(): + param.requires_grad = False + + def forward(self, x): + """Forward function.""" + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['num_branches']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['num_branches']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['num_branches']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage4(x_list) + + return y_list + + def train(self, mode=True): + """Convert the model into training mode will keeping the normalization + layer freezed.""" + super(HRNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/data_utils/easyportrait/mmseg/models/backbones/icnet.py b/data_utils/easyportrait/mmseg/models/backbones/icnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6faaeab01c26eca979ad47a29c07dd935124da2d --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/icnet.py @@ -0,0 +1,166 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmcv.runner import BaseModule + +from mmseg.ops import resize +from ..builder import BACKBONES, build_backbone +from ..decode_heads.psp_head import PPM + + +@BACKBONES.register_module() +class ICNet(BaseModule): + """ICNet for Real-Time Semantic Segmentation on High-Resolution Images. + + This backbone is the implementation of + `ICNet `_. + + Args: + backbone_cfg (dict): Config dict to build backbone. Usually it is + ResNet but it can also be other backbones. + in_channels (int): The number of input image channels. Default: 3. + layer_channels (Sequence[int]): The numbers of feature channels at + layer 2 and layer 4 in ResNet. It can also be other backbones. + Default: (512, 2048). + light_branch_middle_channels (int): The number of channels of the + middle layer in light branch. Default: 32. + psp_out_channels (int): The number of channels of the output of PSP + module. Default: 512. + out_channels (Sequence[int]): The numbers of output feature channels + at each branches. Default: (64, 256, 256). + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. Default: (1, 2, 3, 6). + conv_cfg (dict): Dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN'). + act_cfg (dict): Dictionary to construct and config act layer. + Default: dict(type='ReLU'). + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + backbone_cfg, + in_channels=3, + layer_channels=(512, 2048), + light_branch_middle_channels=32, + psp_out_channels=512, + out_channels=(64, 256, 256), + pool_scales=(1, 2, 3, 6), + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU'), + align_corners=False, + init_cfg=None): + if backbone_cfg is None: + raise TypeError('backbone_cfg must be passed from config file!') + if init_cfg is None: + init_cfg = [ + dict(type='Kaiming', mode='fan_out', layer='Conv2d'), + dict(type='Constant', val=1, layer='_BatchNorm'), + dict(type='Normal', mean=0.01, layer='Linear') + ] + super(ICNet, self).__init__(init_cfg=init_cfg) + self.align_corners = align_corners + self.backbone = build_backbone(backbone_cfg) + + # Note: Default `ceil_mode` is false in nn.MaxPool2d, set + # `ceil_mode=True` to keep information in the corner of feature map. + self.backbone.maxpool = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1, ceil_mode=True) + + self.psp_modules = PPM( + pool_scales=pool_scales, + in_channels=layer_channels[1], + channels=psp_out_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + align_corners=align_corners) + + self.psp_bottleneck = ConvModule( + layer_channels[1] + len(pool_scales) * psp_out_channels, + psp_out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.conv_sub1 = nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=light_branch_middle_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg), + ConvModule( + in_channels=light_branch_middle_channels, + out_channels=light_branch_middle_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg), + ConvModule( + in_channels=light_branch_middle_channels, + out_channels=out_channels[0], + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg)) + + self.conv_sub2 = ConvModule( + layer_channels[0], + out_channels[1], + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg) + + self.conv_sub4 = ConvModule( + psp_out_channels, + out_channels[2], + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg) + + def forward(self, x): + output = [] + + # sub 1 + output.append(self.conv_sub1(x)) + + # sub 2 + x = resize( + x, + scale_factor=0.5, + mode='bilinear', + align_corners=self.align_corners) + x = self.backbone.stem(x) + x = self.backbone.maxpool(x) + x = self.backbone.layer1(x) + x = self.backbone.layer2(x) + output.append(self.conv_sub2(x)) + + # sub 4 + x = resize( + x, + scale_factor=0.5, + mode='bilinear', + align_corners=self.align_corners) + x = self.backbone.layer3(x) + x = self.backbone.layer4(x) + psp_outs = self.psp_modules(x) + [x] + psp_outs = torch.cat(psp_outs, dim=1) + x = self.psp_bottleneck(psp_outs) + + output.append(self.conv_sub4(x)) + + return output diff --git a/data_utils/easyportrait/mmseg/models/backbones/mae.py b/data_utils/easyportrait/mmseg/models/backbones/mae.py new file mode 100644 index 0000000000000000000000000000000000000000..d3e8754bd2bb187529d5d28ac0c6d9e66b9afb5c --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/mae.py @@ -0,0 +1,261 @@ +# Copyright (c) OpenMMLab. All rights reserved.import math +import math + +import torch +import torch.nn as nn +from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init, + trunc_normal_) +from mmcv.runner import ModuleList, _load_checkpoint +from torch.nn.modules.batchnorm import _BatchNorm + +from mmseg.utils import get_root_logger +from ..builder import BACKBONES +from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer + + +class MAEAttention(BEiTAttention): + """Multi-head self-attention with relative position bias used in MAE. + + This module is different from ``BEiTAttention`` by initializing the + relative bias table with zeros. + """ + + def init_weights(self): + """Initialize relative position bias with zeros.""" + + # As MAE initializes relative position bias as zeros and this class + # inherited from BEiT which initializes relative position bias + # with `trunc_normal`, `init_weights` here does + # nothing and just passes directly + + pass + + +class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer): + """Implements one encoder layer in Vision Transformer. + + This module is different from ``BEiTTransformerEncoderLayer`` by replacing + ``BEiTAttention`` with ``MAEAttention``. + """ + + def build_attn(self, attn_cfg): + self.attn = MAEAttention(**attn_cfg) + + +@BACKBONES.register_module() +class MAE(BEiT): + """VisionTransformer with support for patch. + + Args: + img_size (int | tuple): Input image size. Default: 224. + patch_size (int): The patch size. Default: 16. + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): embedding dimension. Default: 768. + num_layers (int): depth of transformer. Default: 12. + num_heads (int): number of attention heads. Default: 12. + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + out_indices (list | tuple | int): Output from which stages. + Default: -1. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): stochastic depth rate. Default 0.0. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + patch_norm (bool): Whether to add a norm in PatchEmbed Block. + Default: False. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Default: False. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + pretrained (str, optional): model pretrained path. Default: None. + init_values (float): Initialize the values of Attention and FFN + with learnable scaling. Defaults to 0.1. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_channels=3, + embed_dims=768, + num_layers=12, + num_heads=12, + mlp_ratio=4, + out_indices=-1, + attn_drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + patch_norm=False, + final_norm=False, + num_fcs=2, + norm_eval=False, + pretrained=None, + init_values=0.1, + init_cfg=None): + super(MAE, self).__init__( + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dims=embed_dims, + num_layers=num_layers, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + out_indices=out_indices, + qv_bias=False, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + patch_norm=patch_norm, + final_norm=final_norm, + num_fcs=num_fcs, + norm_eval=norm_eval, + pretrained=pretrained, + init_values=init_values, + init_cfg=init_cfg) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + + self.num_patches = self.patch_shape[0] * self.patch_shape[1] + self.pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches + 1, embed_dims)) + + def _build_layers(self): + dpr = [ + x.item() + for x in torch.linspace(0, self.drop_path_rate, self.num_layers) + ] + self.layers = ModuleList() + for i in range(self.num_layers): + self.layers.append( + MAETransformerEncoderLayer( + embed_dims=self.embed_dims, + num_heads=self.num_heads, + feedforward_channels=self.mlp_ratio * self.embed_dims, + attn_drop_rate=self.attn_drop_rate, + drop_path_rate=dpr[i], + num_fcs=self.num_fcs, + bias=True, + act_cfg=self.act_cfg, + norm_cfg=self.norm_cfg, + window_size=self.patch_shape, + init_values=self.init_values)) + + def fix_init_weight(self): + """Rescale the initialization according to layer id. + + This function is copied from https://github.com/microsoft/unilm/blob/master/beit/modeling_pretrain.py. # noqa: E501 + Copyright (c) Microsoft Corporation + Licensed under the MIT License + """ + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.layers): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.ffn.layers[1].weight.data, layer_id + 1) + + def init_weights(self): + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + self.fix_init_weight() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg.get('type') == 'Pretrained'): + logger = get_root_logger() + checkpoint = _load_checkpoint( + self.init_cfg['checkpoint'], logger=logger, map_location='cpu') + state_dict = self.resize_rel_pos_embed(checkpoint) + state_dict = self.resize_abs_pos_embed(state_dict) + self.load_state_dict(state_dict, False) + elif self.init_cfg is not None: + super(MAE, self).init_weights() + else: + # We only implement the 'jax_impl' initialization implemented at + # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 + # Copyright 2019 Ross Wightman + # Licensed under the Apache License, Version 2.0 (the "License") + trunc_normal_(self.cls_token, std=.02) + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + if 'ffn' in n: + nn.init.normal_(m.bias, mean=0., std=1e-6) + else: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) + + def resize_abs_pos_embed(self, state_dict): + if 'pos_embed' in state_dict: + pos_embed_checkpoint = state_dict['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches + # height (== width) for the checkpoint position embedding + orig_size = int( + (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5) + # height (== width) for the new position embedding + new_size = int(self.num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, + embedding_size).permute( + 0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, + size=(new_size, new_size), + mode='bicubic', + align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict['pos_embed'] = new_pos_embed + return state_dict + + def forward(self, inputs): + B = inputs.shape[0] + + x, hw_shape = self.patch_embed(inputs) + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i == len(self.layers) - 1: + if self.final_norm: + x = self.norm1(x) + if i in self.out_indices: + out = x[:, 1:] + B, _, C = out.shape + out = out.reshape(B, hw_shape[0], hw_shape[1], + C).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return tuple(outs) diff --git a/data_utils/easyportrait/mmseg/models/backbones/mit.py b/data_utils/easyportrait/mmseg/models/backbones/mit.py new file mode 100644 index 0000000000000000000000000000000000000000..4417cf113383c8576ee11c56d908d2fd2b639219 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/mit.py @@ -0,0 +1,450 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import warnings + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import MultiheadAttention +from mmcv.cnn.utils.weight_init import (constant_init, normal_init, + trunc_normal_init) +from mmcv.runner import BaseModule, ModuleList, Sequential + +from ..builder import BACKBONES +from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw + + +class MixFFN(BaseModule): + """An implementation of MixFFN of Segformer. + + The differences between MixFFN & FFN: + 1. Use 1X1 Conv to replace Linear layer. + 2. Introduce 3X3 Conv to encode positional information. + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. Defaults: 256. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 1024. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='ReLU') + ffn_drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + feedforward_channels, + act_cfg=dict(type='GELU'), + ffn_drop=0., + dropout_layer=None, + init_cfg=None): + super(MixFFN, self).__init__(init_cfg) + + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.act_cfg = act_cfg + self.activate = build_activation_layer(act_cfg) + + in_channels = embed_dims + fc1 = Conv2d( + in_channels=in_channels, + out_channels=feedforward_channels, + kernel_size=1, + stride=1, + bias=True) + # 3x3 depth wise conv to provide positional encode information + pe_conv = Conv2d( + in_channels=feedforward_channels, + out_channels=feedforward_channels, + kernel_size=3, + stride=1, + padding=(3 - 1) // 2, + bias=True, + groups=feedforward_channels) + fc2 = Conv2d( + in_channels=feedforward_channels, + out_channels=in_channels, + kernel_size=1, + stride=1, + bias=True) + drop = nn.Dropout(ffn_drop) + layers = [fc1, pe_conv, self.activate, drop, fc2, drop] + self.layers = Sequential(*layers) + self.dropout_layer = build_dropout( + dropout_layer) if dropout_layer else torch.nn.Identity() + + def forward(self, x, hw_shape, identity=None): + out = nlc_to_nchw(x, hw_shape) + out = self.layers(out) + out = nchw_to_nlc(out) + if identity is None: + identity = x + return identity + self.dropout_layer(out) + + +class EfficientMultiheadAttention(MultiheadAttention): + """An implementation of Efficient Multi-head Attention of Segformer. + + This module is modified from MultiheadAttention which is a module from + mmcv.cnn.bricks.transformer. + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. + Default: 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. Default: None. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default: False. + qkv_bias (bool): enable bias for qkv if True. Default True. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head + Attention of Segformer. Default: 1. + """ + + def __init__(self, + embed_dims, + num_heads, + attn_drop=0., + proj_drop=0., + dropout_layer=None, + init_cfg=None, + batch_first=True, + qkv_bias=False, + norm_cfg=dict(type='LN'), + sr_ratio=1): + super().__init__( + embed_dims, + num_heads, + attn_drop, + proj_drop, + dropout_layer=dropout_layer, + init_cfg=init_cfg, + batch_first=batch_first, + bias=qkv_bias) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = Conv2d( + in_channels=embed_dims, + out_channels=embed_dims, + kernel_size=sr_ratio, + stride=sr_ratio) + # The ret[0] of build_norm_layer is norm name. + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + + # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa + from mmseg import digit_version, mmcv_version + if mmcv_version < digit_version('1.3.17'): + warnings.warn('The legacy version of forward function in' + 'EfficientMultiheadAttention is deprecated in' + 'mmcv>=1.3.17 and will no longer support in the' + 'future. Please upgrade your mmcv.') + self.forward = self.legacy_forward + + def forward(self, x, hw_shape, identity=None): + + x_q = x + if self.sr_ratio > 1: + x_kv = nlc_to_nchw(x, hw_shape) + x_kv = self.sr(x_kv) + x_kv = nchw_to_nlc(x_kv) + x_kv = self.norm(x_kv) + else: + x_kv = x + + if identity is None: + identity = x_q + + # Because the dataflow('key', 'query', 'value') of + # ``torch.nn.MultiheadAttention`` is (num_query, batch, + # embed_dims), We should adjust the shape of dataflow from + # batch_first (batch, num_query, embed_dims) to num_query_first + # (num_query ,batch, embed_dims), and recover ``attn_output`` + # from num_query_first to batch_first. + if self.batch_first: + x_q = x_q.transpose(0, 1) + x_kv = x_kv.transpose(0, 1) + + out = self.attn(query=x_q, key=x_kv, value=x_kv)[0] + + if self.batch_first: + out = out.transpose(0, 1) + + return identity + self.dropout_layer(self.proj_drop(out)) + + def legacy_forward(self, x, hw_shape, identity=None): + """multi head attention forward in mmcv version < 1.3.17.""" + + x_q = x + if self.sr_ratio > 1: + x_kv = nlc_to_nchw(x, hw_shape) + x_kv = self.sr(x_kv) + x_kv = nchw_to_nlc(x_kv) + x_kv = self.norm(x_kv) + else: + x_kv = x + + if identity is None: + identity = x_q + + # `need_weights=True` will let nn.MultiHeadAttention + # `return attn_output, attn_output_weights.sum(dim=1) / num_heads` + # The `attn_output_weights.sum(dim=1)` may cause cuda error. So, we set + # `need_weights=False` to ignore `attn_output_weights.sum(dim=1)`. + # This issue - `https://github.com/pytorch/pytorch/issues/37583` report + # the error that large scale tensor sum operation may cause cuda error. + out = self.attn(query=x_q, key=x_kv, value=x_kv, need_weights=False)[0] + + return identity + self.dropout_layer(self.proj_drop(out)) + + +class TransformerEncoderLayer(BaseModule): + """Implements one encoder layer in Segformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed. + after the feed forward layer. Default 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0. + drop_path_rate (float): stochastic depth rate. Default 0.0. + qkv_bias (bool): enable bias for qkv if True. + Default: True. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default: False. + init_cfg (dict, optional): Initialization config dict. + Default:None. + sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head + Attention of Segformer. Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + batch_first=True, + sr_ratio=1, + with_cp=False): + super(TransformerEncoderLayer, self).__init__() + + # The ret[0] of build_norm_layer is norm name. + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + + self.attn = EfficientMultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + batch_first=batch_first, + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + sr_ratio=sr_ratio) + + # The ret[0] of build_norm_layer is norm name. + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + + self.ffn = MixFFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg) + + self.with_cp = with_cp + + def forward(self, x, hw_shape): + + def _inner_forward(x): + x = self.attn(self.norm1(x), hw_shape, identity=x) + x = self.ffn(self.norm2(x), hw_shape, identity=x) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +@BACKBONES.register_module() +class MixVisionTransformer(BaseModule): + """The backbone of Segformer. + + This backbone is the implementation of `SegFormer: Simple and + Efficient Design for Semantic Segmentation with + Transformers `_. + Args: + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): Embedding dimension. Default: 768. + num_stags (int): The num of stages. Default: 4. + num_layers (Sequence[int]): The layer number of each transformer encode + layer. Default: [3, 4, 6, 3]. + num_heads (Sequence[int]): The attention heads of each transformer + encode layer. Default: [1, 2, 4, 8]. + patch_sizes (Sequence[int]): The patch_size of each overlapped patch + embedding. Default: [7, 3, 3, 3]. + strides (Sequence[int]): The stride of each overlapped patch embedding. + Default: [4, 2, 2, 2]. + sr_ratios (Sequence[int]): The spatial reduction rate of each + transformer encode layer. Default: [8, 4, 2, 1]. + out_indices (Sequence[int] | int): Output from which stages. + Default: (0, 1, 2, 3). + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + qkv_bias (bool): Enable bias for qkv if True. Default: True. + drop_rate (float): Probability of an element to be zeroed. + Default 0.0 + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): stochastic depth rate. Default 0.0 + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + pretrained (str, optional): model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + in_channels=3, + embed_dims=64, + num_stages=4, + num_layers=[3, 4, 6, 3], + num_heads=[1, 2, 4, 8], + patch_sizes=[7, 3, 3, 3], + strides=[4, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + out_indices=(0, 1, 2, 3), + mlp_ratio=4, + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN', eps=1e-6), + pretrained=None, + init_cfg=None, + with_cp=False): + super(MixVisionTransformer, self).__init__(init_cfg=init_cfg) + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + + self.embed_dims = embed_dims + self.num_stages = num_stages + self.num_layers = num_layers + self.num_heads = num_heads + self.patch_sizes = patch_sizes + self.strides = strides + self.sr_ratios = sr_ratios + self.with_cp = with_cp + assert num_stages == len(num_layers) == len(num_heads) \ + == len(patch_sizes) == len(strides) == len(sr_ratios) + + self.out_indices = out_indices + assert max(out_indices) < self.num_stages + + # transformer encoder + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(num_layers)) + ] # stochastic num_layer decay rule + + cur = 0 + self.layers = ModuleList() + for i, num_layer in enumerate(num_layers): + embed_dims_i = embed_dims * num_heads[i] + patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims_i, + kernel_size=patch_sizes[i], + stride=strides[i], + padding=patch_sizes[i] // 2, + norm_cfg=norm_cfg) + layer = ModuleList([ + TransformerEncoderLayer( + embed_dims=embed_dims_i, + num_heads=num_heads[i], + feedforward_channels=mlp_ratio * embed_dims_i, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[cur + idx], + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + sr_ratio=sr_ratios[i]) for idx in range(num_layer) + ]) + in_channels = embed_dims_i + # The ret[0] of build_norm_layer is norm name. + norm = build_norm_layer(norm_cfg, embed_dims_i)[1] + self.layers.append(ModuleList([patch_embed, layer, norm])) + cur += num_layer + + def init_weights(self): + if self.init_cfg is None: + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, nn.LayerNorm): + constant_init(m, val=1.0, bias=0.) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init( + m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) + else: + super(MixVisionTransformer, self).init_weights() + + def forward(self, x): + outs = [] + + for i, layer in enumerate(self.layers): + x, hw_shape = layer[0](x) + for block in layer[1]: + x = block(x, hw_shape) + x = layer[2](x) + x = nlc_to_nchw(x, hw_shape) + if i in self.out_indices: + outs.append(x) + + return outs diff --git a/data_utils/easyportrait/mmseg/models/backbones/mobilenet_v2.py b/data_utils/easyportrait/mmseg/models/backbones/mobilenet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb9c6cd01fb83b330a6067415a0519bcb448e1c --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/mobilenet_v2.py @@ -0,0 +1,197 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmcv.runner import BaseModule +from torch.nn.modules.batchnorm import _BatchNorm + +from ..builder import BACKBONES +from ..utils import InvertedResidual, make_divisible + + +@BACKBONES.register_module() +class MobileNetV2(BaseModule): + """MobileNetV2 backbone. + + This backbone is the implementation of + `MobileNetV2: Inverted Residuals and Linear Bottlenecks + `_. + + Args: + widen_factor (float): Width multiplier, multiply number of + channels in each layer by this amount. Default: 1.0. + strides (Sequence[int], optional): Strides of the first block of each + layer. If not specified, default config in ``arch_setting`` will + be used. + dilations (Sequence[int]): Dilation of each layer. + out_indices (None or Sequence[int]): Output from which stages. + Default: (7, ). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + # Parameters to build layers. 3 parameters are needed to construct a + # layer, from left to right: expand_ratio, channel, num_blocks. + arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4], + [6, 96, 3], [6, 160, 3], [6, 320, 1]] + + def __init__(self, + widen_factor=1., + strides=(1, 2, 2, 2, 1, 2, 1), + dilations=(1, 1, 1, 1, 1, 1, 1), + out_indices=(1, 2, 4, 6), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None): + super(MobileNetV2, self).__init__(init_cfg) + + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + self.widen_factor = widen_factor + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == len(self.arch_settings) + self.out_indices = out_indices + for index in out_indices: + if index not in range(0, 7): + raise ValueError('the item in out_indices must in ' + f'range(0, 7). But received {index}') + + if frozen_stages not in range(-1, 7): + raise ValueError('frozen_stages must be in range(-1, 7). ' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.in_channels = make_divisible(32 * widen_factor, 8) + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.layers = [] + + for i, layer_cfg in enumerate(self.arch_settings): + expand_ratio, channel, num_blocks = layer_cfg + stride = self.strides[i] + dilation = self.dilations[i] + out_channels = make_divisible(channel * widen_factor, 8) + inverted_res_layer = self.make_layer( + out_channels=out_channels, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + expand_ratio=expand_ratio) + layer_name = f'layer{i + 1}' + self.add_module(layer_name, inverted_res_layer) + self.layers.append(layer_name) + + def make_layer(self, out_channels, num_blocks, stride, dilation, + expand_ratio): + """Stack InvertedResidual blocks to build a layer for MobileNetV2. + + Args: + out_channels (int): out_channels of block. + num_blocks (int): Number of blocks. + stride (int): Stride of the first block. + dilation (int): Dilation of the first block. + expand_ratio (int): Expand the number of channels of the + hidden layer in InvertedResidual by this ratio. + """ + layers = [] + for i in range(num_blocks): + layers.append( + InvertedResidual( + self.in_channels, + out_channels, + stride if i == 0 else 1, + expand_ratio=expand_ratio, + dilation=dilation if i == 0 else 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + + if len(outs) == 1: + return outs[0] + else: + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(MobileNetV2, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/data_utils/easyportrait/mmseg/models/backbones/mobilenet_v3.py b/data_utils/easyportrait/mmseg/models/backbones/mobilenet_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..dd3d6eb176bbd8c77d25730744daeef6c02435bc --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/mobilenet_v3.py @@ -0,0 +1,267 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import mmcv +from mmcv.cnn import ConvModule +from mmcv.cnn.bricks import Conv2dAdaptivePadding +from mmcv.runner import BaseModule +from torch.nn.modules.batchnorm import _BatchNorm + +from ..builder import BACKBONES +from ..utils import InvertedResidualV3 as InvertedResidual + + +@BACKBONES.register_module() +class MobileNetV3(BaseModule): + """MobileNetV3 backbone. + + This backbone is the improved implementation of `Searching for MobileNetV3 + `_. + + Args: + arch (str): Architecture of mobilnetv3, from {'small', 'large'}. + Default: 'small'. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + out_indices (tuple[int]): Output from which layer. + Default: (0, 1, 12). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. + Default: False. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + # Parameters to build each block: + # [kernel size, mid channels, out channels, with_se, act type, stride] + arch_settings = { + 'small': [[3, 16, 16, True, 'ReLU', 2], # block0 layer1 os=4 + [3, 72, 24, False, 'ReLU', 2], # block1 layer2 os=8 + [3, 88, 24, False, 'ReLU', 1], + [5, 96, 40, True, 'HSwish', 2], # block2 layer4 os=16 + [5, 240, 40, True, 'HSwish', 1], + [5, 240, 40, True, 'HSwish', 1], + [5, 120, 48, True, 'HSwish', 1], # block3 layer7 os=16 + [5, 144, 48, True, 'HSwish', 1], + [5, 288, 96, True, 'HSwish', 2], # block4 layer9 os=32 + [5, 576, 96, True, 'HSwish', 1], + [5, 576, 96, True, 'HSwish', 1]], + 'large': [[3, 16, 16, False, 'ReLU', 1], # block0 layer1 os=2 + [3, 64, 24, False, 'ReLU', 2], # block1 layer2 os=4 + [3, 72, 24, False, 'ReLU', 1], + [5, 72, 40, True, 'ReLU', 2], # block2 layer4 os=8 + [5, 120, 40, True, 'ReLU', 1], + [5, 120, 40, True, 'ReLU', 1], + [3, 240, 80, False, 'HSwish', 2], # block3 layer7 os=16 + [3, 200, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 480, 112, True, 'HSwish', 1], # block4 layer11 os=16 + [3, 672, 112, True, 'HSwish', 1], + [5, 672, 160, True, 'HSwish', 2], # block5 layer13 os=32 + [5, 960, 160, True, 'HSwish', 1], + [5, 960, 160, True, 'HSwish', 1]] + } # yapf: disable + + def __init__(self, + arch='small', + conv_cfg=None, + norm_cfg=dict(type='BN'), + out_indices=(0, 1, 12), + frozen_stages=-1, + reduction_factor=1, + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None): + super(MobileNetV3, self).__init__(init_cfg) + + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + assert arch in self.arch_settings + assert isinstance(reduction_factor, int) and reduction_factor > 0 + assert mmcv.is_tuple_of(out_indices, int) + for index in out_indices: + if index not in range(0, len(self.arch_settings[arch]) + 2): + raise ValueError( + 'the item in out_indices must in ' + f'range(0, {len(self.arch_settings[arch])+2}). ' + f'But received {index}') + + if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2): + raise ValueError('frozen_stages must be in range(-1, ' + f'{len(self.arch_settings[arch])+2}). ' + f'But received {frozen_stages}') + self.arch = arch + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.reduction_factor = reduction_factor + self.norm_eval = norm_eval + self.with_cp = with_cp + self.layers = self._make_layer() + + def _make_layer(self): + layers = [] + + # build the first layer (layer0) + in_channels = 16 + layer = ConvModule( + in_channels=3, + out_channels=in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=dict(type='Conv2dAdaptivePadding'), + norm_cfg=self.norm_cfg, + act_cfg=dict(type='HSwish')) + self.add_module('layer0', layer) + layers.append('layer0') + + layer_setting = self.arch_settings[self.arch] + for i, params in enumerate(layer_setting): + (kernel_size, mid_channels, out_channels, with_se, act, + stride) = params + + if self.arch == 'large' and i >= 12 or self.arch == 'small' and \ + i >= 8: + mid_channels = mid_channels // self.reduction_factor + out_channels = out_channels // self.reduction_factor + + if with_se: + se_cfg = dict( + channels=mid_channels, + ratio=4, + act_cfg=(dict(type='ReLU'), + dict(type='HSigmoid', bias=3.0, divisor=6.0))) + else: + se_cfg = None + + layer = InvertedResidual( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + se_cfg=se_cfg, + with_expand_conv=(in_channels != mid_channels), + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type=act), + with_cp=self.with_cp) + in_channels = out_channels + layer_name = 'layer{}'.format(i + 1) + self.add_module(layer_name, layer) + layers.append(layer_name) + + # build the last layer + # block5 layer12 os=32 for small model + # block6 layer16 os=32 for large model + layer = ConvModule( + in_channels=in_channels, + out_channels=576 if self.arch == 'small' else 960, + kernel_size=1, + stride=1, + dilation=4, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type='HSwish')) + layer_name = 'layer{}'.format(len(layer_setting) + 1) + self.add_module(layer_name, layer) + layers.append(layer_name) + + # next, convert backbone MobileNetV3 to a semantic segmentation version + if self.arch == 'small': + self.layer4.depthwise_conv.conv.stride = (1, 1) + self.layer9.depthwise_conv.conv.stride = (1, 1) + for i in range(4, len(layers)): + layer = getattr(self, layers[i]) + if isinstance(layer, InvertedResidual): + modified_module = layer.depthwise_conv.conv + else: + modified_module = layer.conv + + if i < 9: + modified_module.dilation = (2, 2) + pad = 2 + else: + modified_module.dilation = (4, 4) + pad = 4 + + if not isinstance(modified_module, Conv2dAdaptivePadding): + # Adjust padding + pad *= (modified_module.kernel_size[0] - 1) // 2 + modified_module.padding = (pad, pad) + else: + self.layer7.depthwise_conv.conv.stride = (1, 1) + self.layer13.depthwise_conv.conv.stride = (1, 1) + for i in range(7, len(layers)): + layer = getattr(self, layers[i]) + if isinstance(layer, InvertedResidual): + modified_module = layer.depthwise_conv.conv + else: + modified_module = layer.conv + + if i < 13: + modified_module.dilation = (2, 2) + pad = 2 + else: + modified_module.dilation = (4, 4) + pad = 4 + + if not isinstance(modified_module, Conv2dAdaptivePadding): + # Adjust padding + pad *= (modified_module.kernel_size[0] - 1) // 2 + modified_module.padding = (pad, pad) + + return layers + + def forward(self, x): + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + return outs + + def _freeze_stages(self): + for i in range(self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(MobileNetV3, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/data_utils/easyportrait/mmseg/models/backbones/mscan.py b/data_utils/easyportrait/mmseg/models/backbones/mscan.py new file mode 100644 index 0000000000000000000000000000000000000000..28284fb3f5d14c03d4f39eae0e18536d9756ad82 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/mscan.py @@ -0,0 +1,461 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Originally from https://github.com/visual-attention-network/segnext +# Licensed under the Apache License, Version 2.0 (the "License") +import math +import warnings + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmcv.cnn.utils.weight_init import (constant_init, normal_init, + trunc_normal_init) +from mmcv.runner import BaseModule + +from mmseg.models.builder import BACKBONES + + +class Mlp(BaseModule): + """Multi Layer Perceptron (MLP) Module. + Args: + in_features (int): The dimension of input features. + hidden_features (int): The dimension of hidden features. + Defaults: None. + out_features (int): The dimension of output features. + Defaults: None. + act_cfg (dict): Config dict for activation layer in block. + Default: dict(type='GELU'). + drop (float): The number of dropout rate in MLP block. + Defaults: 0.0. + """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_cfg=dict(type='GELU'), + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.dwconv = nn.Conv2d( + hidden_features, + hidden_features, + 3, + 1, + 1, + bias=True, + groups=hidden_features) + self.act = build_activation_layer(act_cfg) + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.drop = nn.Dropout(drop) + + def forward(self, x): + """Forward function.""" + + x = self.fc1(x) + + x = self.dwconv(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + + return x + + +class StemConv(BaseModule): + """Stem Block at the beginning of Semantic Branch. + Args: + in_channels (int): The dimension of input channels. + out_channels (int): The dimension of output channels. + act_cfg (dict): Config dict for activation layer in block. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Defaults: dict(type='SyncBN', requires_grad=True). + """ + + def __init__(self, + in_channels, + out_channels, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='SyncBN', requires_grad=True)): + super(StemConv, self).__init__() + + self.proj = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels // 2, + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1)), + build_norm_layer(norm_cfg, out_channels // 2)[1], + build_activation_layer(act_cfg), + nn.Conv2d( + out_channels // 2, + out_channels, + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1)), + build_norm_layer(norm_cfg, out_channels)[1], + ) + + def forward(self, x): + """Forward function.""" + + x = self.proj(x) + _, _, H, W = x.size() + x = x.flatten(2).transpose(1, 2) + return x, H, W + + +class MSCAAttention(BaseModule): + """Attention Module in Multi-Scale Convolutional Attention Module (MSCA). + Args: + channels (int): The dimension of channels. + kernel_sizes (list): The size of attention + kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]]. + paddings (list): The number of + corresponding padding value in attention module. + Defaults: [2, [0, 3], [0, 5], [0, 10]]. + """ + + def __init__(self, + channels, + kernel_sizes=[5, [1, 7], [1, 11], [1, 21]], + paddings=[2, [0, 3], [0, 5], [0, 10]]): + super().__init__() + self.conv0 = nn.Conv2d( + channels, + channels, + kernel_size=kernel_sizes[0], + padding=paddings[0], + groups=channels) + for i, (kernel_size, + padding) in enumerate(zip(kernel_sizes[1:], paddings[1:])): + kernel_size_ = [kernel_size, kernel_size[::-1]] + padding_ = [padding, padding[::-1]] + conv_name = [f'conv{i}_1', f'conv{i}_2'] + for i_kernel, i_pad, i_conv in zip(kernel_size_, padding_, + conv_name): + self.add_module( + i_conv, + nn.Conv2d( + channels, + channels, + tuple(i_kernel), + padding=i_pad, + groups=channels)) + self.conv3 = nn.Conv2d(channels, channels, 1) + + def forward(self, x): + """Forward function.""" + + u = x.clone() + + attn = self.conv0(x) + + # Multi-Scale Feature extraction + attn_0 = self.conv0_1(attn) + attn_0 = self.conv0_2(attn_0) + + attn_1 = self.conv1_1(attn) + attn_1 = self.conv1_2(attn_1) + + attn_2 = self.conv2_1(attn) + attn_2 = self.conv2_2(attn_2) + + attn = attn + attn_0 + attn_1 + attn_2 + # Channel Mixing + attn = self.conv3(attn) + + # Convolutional Attention + x = attn * u + + return x + + +class MSCASpatialAttention(BaseModule): + """Spatial Attention Module in Multi-Scale Convolutional Attention Module + (MSCA). + Args: + in_channels (int): The dimension of channels. + attention_kernel_sizes (list): The size of attention + kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]]. + attention_kernel_paddings (list): The number of + corresponding padding value in attention module. + Defaults: [2, [0, 3], [0, 5], [0, 10]]. + act_cfg (dict): Config dict for activation layer in block. + Default: dict(type='GELU'). + """ + + def __init__(self, + in_channels, + attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]], + attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]], + act_cfg=dict(type='GELU')): + super().__init__() + self.proj_1 = nn.Conv2d(in_channels, in_channels, 1) + self.activation = build_activation_layer(act_cfg) + self.spatial_gating_unit = MSCAAttention(in_channels, + attention_kernel_sizes, + attention_kernel_paddings) + self.proj_2 = nn.Conv2d(in_channels, in_channels, 1) + + def forward(self, x): + """Forward function.""" + + shorcut = x.clone() + x = self.proj_1(x) + x = self.activation(x) + x = self.spatial_gating_unit(x) + x = self.proj_2(x) + x = x + shorcut + return x + + +class MSCABlock(BaseModule): + """Basic Multi-Scale Convolutional Attention Block. It leverage the large- + kernel attention (LKA) mechanism to build both channel and spatial + attention. In each branch, it uses two depth-wise strip convolutions to + approximate standard depth-wise convolutions with large kernels. The kernel + size for each branch is set to 7, 11, and 21, respectively. + Args: + channels (int): The dimension of channels. + attention_kernel_sizes (list): The size of attention + kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]]. + attention_kernel_paddings (list): The number of + corresponding padding value in attention module. + Defaults: [2, [0, 3], [0, 5], [0, 10]]. + mlp_ratio (float): The ratio of multiple input dimension to + calculate hidden feature in MLP layer. Defaults: 4.0. + drop (float): The number of dropout rate in MLP block. + Defaults: 0.0. + drop_path (float): The ratio of drop paths. + Defaults: 0.0. + act_cfg (dict): Config dict for activation layer in block. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Defaults: dict(type='SyncBN', requires_grad=True). + """ + + def __init__(self, + channels, + attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]], + attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]], + mlp_ratio=4., + drop=0., + drop_path=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='SyncBN', requires_grad=True)): + super().__init__() + self.norm1 = build_norm_layer(norm_cfg, channels)[1] + self.attn = MSCASpatialAttention(channels, attention_kernel_sizes, + attention_kernel_paddings, act_cfg) + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = build_norm_layer(norm_cfg, channels)[1] + mlp_hidden_channels = int(channels * mlp_ratio) + self.mlp = Mlp( + in_features=channels, + hidden_features=mlp_hidden_channels, + act_cfg=act_cfg, + drop=drop) + layer_scale_init_value = 1e-2 + self.layer_scale_1 = nn.Parameter( + layer_scale_init_value * torch.ones((channels)), + requires_grad=True) + self.layer_scale_2 = nn.Parameter( + layer_scale_init_value * torch.ones((channels)), + requires_grad=True) + + def forward(self, x, H, W): + """Forward function.""" + + B, N, C = x.shape + x = x.permute(0, 2, 1).view(B, C, H, W) + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * + self.attn(self.norm1(x))) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * + self.mlp(self.norm2(x))) + x = x.view(B, C, N).permute(0, 2, 1) + return x + + +class OverlapPatchEmbed(BaseModule): + """Image to Patch Embedding. + Args: + patch_size (int): The patch size. + Defaults: 7. + stride (int): Stride of the convolutional layer. + Default: 4. + in_channels (int): The number of input channels. + Defaults: 3. + embed_dims (int): The dimensions of embedding. + Defaults: 768. + norm_cfg (dict): Config dict for normalization layer. + Defaults: dict(type='SyncBN', requires_grad=True). + """ + + def __init__(self, + patch_size=7, + stride=4, + in_channels=3, + embed_dim=768, + norm_cfg=dict(type='SyncBN', requires_grad=True)): + super().__init__() + + self.proj = nn.Conv2d( + in_channels, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=patch_size // 2) + self.norm = build_norm_layer(norm_cfg, embed_dim)[1] + + def forward(self, x): + """Forward function.""" + + x = self.proj(x) + _, _, H, W = x.shape + x = self.norm(x) + + x = x.flatten(2).transpose(1, 2) + + return x, H, W + + +@BACKBONES.register_module() +class MSCAN(BaseModule): + """SegNeXt Multi-Scale Convolutional Attention Network (MCSAN) backbone. + This backbone is the implementation of `SegNeXt: Rethinking + Convolutional Attention Design for Semantic + Segmentation `_. + Inspiration from https://github.com/visual-attention-network/segnext. + Args: + in_channels (int): The number of input channels. Defaults: 3. + embed_dims (list[int]): Embedding dimension. + Defaults: [64, 128, 256, 512]. + mlp_ratios (list[int]): Ratio of mlp hidden dim to embedding dim. + Defaults: [4, 4, 4, 4]. + drop_rate (float): Dropout rate. Defaults: 0. + drop_path_rate (float): Stochastic depth rate. Defaults: 0. + depths (list[int]): Depths of each Swin Transformer stage. + Default: [3, 4, 6, 3]. + num_stages (int): MSCAN stages. Default: 4. + attention_kernel_sizes (list): Size of attention kernel in + Attention Module (Figure 2(b) of original paper). + Defaults: [5, [1, 7], [1, 11], [1, 21]]. + attention_kernel_paddings (list): Size of attention paddings + in Attention Module (Figure 2(b) of original paper). + Defaults: [2, [0, 3], [0, 5], [0, 10]]. + norm_cfg (dict): Config of norm layers. + Defaults: dict(type='SyncBN', requires_grad=True). + pretrained (str, optional): model pretrained path. + Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels=3, + embed_dims=[64, 128, 256, 512], + mlp_ratios=[4, 4, 4, 4], + drop_rate=0., + drop_path_rate=0., + depths=[3, 4, 6, 3], + num_stages=4, + attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]], + attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]], + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='SyncBN', requires_grad=True), + pretrained=None, + init_cfg=None): + super(MSCAN, self).__init__(init_cfg=init_cfg) + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + + self.depths = depths + self.num_stages = num_stages + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + cur = 0 + + for i in range(num_stages): + if i == 0: + patch_embed = StemConv(3, embed_dims[0], norm_cfg=norm_cfg) + else: + patch_embed = OverlapPatchEmbed( + patch_size=7 if i == 0 else 3, + stride=4 if i == 0 else 2, + in_channels=in_channels if i == 0 else embed_dims[i - 1], + embed_dim=embed_dims[i], + norm_cfg=norm_cfg) + + block = nn.ModuleList([ + MSCABlock( + channels=embed_dims[i], + attention_kernel_sizes=attention_kernel_sizes, + attention_kernel_paddings=attention_kernel_paddings, + mlp_ratio=mlp_ratios[i], + drop=drop_rate, + drop_path=dpr[cur + j], + act_cfg=act_cfg, + norm_cfg=norm_cfg) for j in range(depths[i]) + ]) + norm = nn.LayerNorm(embed_dims[i]) + cur += depths[i] + + setattr(self, f'patch_embed{i + 1}', patch_embed) + setattr(self, f'block{i + 1}', block) + setattr(self, f'norm{i + 1}', norm) + + def init_weights(self): + """Initialize modules of MSCAN.""" + + print('init cfg', self.init_cfg) + if self.init_cfg is None: + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, nn.LayerNorm): + constant_init(m, val=1.0, bias=0.) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init( + m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) + else: + super(MSCAN, self).init_weights() + + def forward(self, x): + """Forward function.""" + + B = x.shape[0] + outs = [] + + for i in range(self.num_stages): + patch_embed = getattr(self, f'patch_embed{i + 1}') + block = getattr(self, f'block{i + 1}') + norm = getattr(self, f'norm{i + 1}') + x, H, W = patch_embed(x) + for blk in block: + x = blk(x, H, W) + x = norm(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + return outs \ No newline at end of file diff --git a/data_utils/easyportrait/mmseg/models/backbones/resnest.py b/data_utils/easyportrait/mmseg/models/backbones/resnest.py new file mode 100644 index 0000000000000000000000000000000000000000..91952c2caf37c1ca5c72ed845030d219d064b290 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/resnest.py @@ -0,0 +1,318 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer + +from ..builder import BACKBONES +from ..utils import ResLayer +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNetV1d + + +class RSoftmax(nn.Module): + """Radix Softmax module in ``SplitAttentionConv2d``. + + Args: + radix (int): Radix of input. + groups (int): Groups of input. + """ + + def __init__(self, radix, groups): + super().__init__() + self.radix = radix + self.groups = groups + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttentionConv2d(nn.Module): + """Split-Attention Conv2d in ResNeSt. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int | tuple[int]): Same as nn.Conv2d. + stride (int | tuple[int]): Same as nn.Conv2d. + padding (int | tuple[int]): Same as nn.Conv2d. + dilation (int | tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels. Default: 4. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. Default: None. + dcn (dict): Config dict for DCN. Default: None. + """ + + def __init__(self, + in_channels, + channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + radix=2, + reduction_factor=4, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None): + super(SplitAttentionConv2d, self).__init__() + inter_channels = max(in_channels * radix // reduction_factor, 32) + self.radix = radix + self.groups = groups + self.channels = channels + self.with_dcn = dcn is not None + self.dcn = dcn + fallback_on_stride = False + if self.with_dcn: + fallback_on_stride = self.dcn.pop('fallback_on_stride', False) + if self.with_dcn and not fallback_on_stride: + assert conv_cfg is None, 'conv_cfg must be None for DCN' + conv_cfg = dcn + self.conv = build_conv_layer( + conv_cfg, + in_channels, + channels * radix, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups * radix, + bias=False) + self.norm0_name, norm0 = build_norm_layer( + norm_cfg, channels * radix, postfix=0) + self.add_module(self.norm0_name, norm0) + self.relu = nn.ReLU(inplace=True) + self.fc1 = build_conv_layer( + None, channels, inter_channels, 1, groups=self.groups) + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, inter_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.fc2 = build_conv_layer( + None, inter_channels, channels * radix, 1, groups=self.groups) + self.rsoftmax = RSoftmax(radix, groups) + + @property + def norm0(self): + """nn.Module: the normalization layer named "norm0" """ + return getattr(self, self.norm0_name) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def forward(self, x): + x = self.conv(x) + x = self.norm0(x) + x = self.relu(x) + + batch, rchannel = x.shape[:2] + batch = x.size(0) + if self.radix > 1: + splits = x.view(batch, self.radix, -1, *x.shape[2:]) + gap = splits.sum(dim=1) + else: + gap = x + gap = F.adaptive_avg_pool2d(gap, 1) + gap = self.fc1(gap) + + gap = self.norm1(gap) + gap = self.relu(gap) + + atten = self.fc2(gap) + atten = self.rsoftmax(atten).view(batch, -1, 1, 1) + + if self.radix > 1: + attens = atten.view(batch, self.radix, -1, *atten.shape[2:]) + out = torch.sum(attens * splits, dim=1) + else: + out = atten * x + return out.contiguous() + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeSt. + + Args: + inplane (int): Input planes of this block. + planes (int): Middle planes of this block. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels in + SplitAttentionConv2d. Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + kwargs (dict): Key word arguments for base class. + """ + expansion = 4 + + def __init__(self, + inplanes, + planes, + groups=1, + base_width=4, + base_channels=64, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + """Bottleneck block for ResNeSt.""" + super(Bottleneck, self).__init__(inplanes, planes, **kwargs) + + if groups == 1: + width = self.planes + else: + width = math.floor(self.planes * + (base_width / base_channels)) * groups + + self.avg_down_stride = avg_down_stride and self.conv2_stride > 1 + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, width, postfix=1) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.inplanes, + width, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.with_modulated_dcn = False + self.conv2 = SplitAttentionConv2d( + width, + width, + kernel_size=3, + stride=1 if self.avg_down_stride else self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + radix=radix, + reduction_factor=reduction_factor, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + dcn=self.dcn) + delattr(self, self.norm2_name) + + if self.avg_down_stride: + self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1) + + self.conv3 = build_conv_layer( + self.conv_cfg, + width, + self.planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + out = self.conv2(out) + + if self.avg_down_stride: + out = self.avd_layer(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@BACKBONES.register_module() +class ResNeSt(ResNetV1d): + """ResNeSt backbone. + + This backbone is the implementation of `ResNeSt: + Split-Attention Networks `_. + + Args: + groups (int): Number of groups of Bottleneck. Default: 1 + base_width (int): Base width of Bottleneck. Default: 4 + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels in + SplitAttentionConv2d. Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + kwargs (dict): Keyword arguments for ResNet. + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)), + 200: (Bottleneck, (3, 24, 36, 3)) + } + + def __init__(self, + groups=1, + base_width=4, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + self.groups = groups + self.base_width = base_width + self.radix = radix + self.reduction_factor = reduction_factor + self.avg_down_stride = avg_down_stride + super(ResNeSt, self).__init__(**kwargs) + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``.""" + return ResLayer( + groups=self.groups, + base_width=self.base_width, + base_channels=self.base_channels, + radix=self.radix, + reduction_factor=self.reduction_factor, + avg_down_stride=self.avg_down_stride, + **kwargs) diff --git a/data_utils/easyportrait/mmseg/models/backbones/resnet.py b/data_utils/easyportrait/mmseg/models/backbones/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e8b961d5faed3a25eb09576b1339e5ce18ca9627 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/resnet.py @@ -0,0 +1,714 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer +from mmcv.runner import BaseModule +from mmcv.utils.parrots_wrapper import _BatchNorm + +from ..builder import BACKBONES +from ..utils import ResLayer + + +class BasicBlock(BaseModule): + """Basic block for ResNet.""" + + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None, + init_cfg=None): + super(BasicBlock, self).__init__(init_cfg) + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, planes, planes, 3, padding=1, bias=False) + self.add_module(self.norm2_name, norm2) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.with_cp = with_cp + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Bottleneck(BaseModule): + """Bottleneck block for ResNet. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is + "caffe", the stride-two layer is the first 1x1 conv layer. + """ + + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None, + init_cfg=None): + super(Bottleneck, self).__init__(init_cfg) + assert style in ['pytorch', 'caffe'] + assert dcn is None or isinstance(dcn, dict) + assert plugins is None or isinstance(plugins, list) + if plugins is not None: + allowed_position = ['after_conv1', 'after_conv2', 'after_conv3'] + assert all(p['position'] in allowed_position for p in plugins) + + self.inplanes = inplanes + self.planes = planes + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.dcn = dcn + self.with_dcn = dcn is not None + self.plugins = plugins + self.with_plugins = plugins is not None + + if self.with_plugins: + # collect plugins for conv1/conv2/conv3 + self.after_conv1_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv1' + ] + self.after_conv2_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv2' + ] + self.after_conv3_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv3' + ] + + if self.style == 'pytorch': + self.conv1_stride = 1 + self.conv2_stride = stride + else: + self.conv1_stride = stride + self.conv2_stride = 1 + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + norm_cfg, planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + fallback_on_stride = False + if self.with_dcn: + fallback_on_stride = dcn.pop('fallback_on_stride', False) + if not self.with_dcn or fallback_on_stride: + self.conv2 = build_conv_layer( + conv_cfg, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + else: + assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + self.conv2 = build_conv_layer( + dcn, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + conv_cfg, + planes, + planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + if self.with_plugins: + self.after_conv1_plugin_names = self.make_block_plugins( + planes, self.after_conv1_plugins) + self.after_conv2_plugin_names = self.make_block_plugins( + planes, self.after_conv2_plugins) + self.after_conv3_plugin_names = self.make_block_plugins( + planes * self.expansion, self.after_conv3_plugins) + + def make_block_plugins(self, in_channels, plugins): + """make plugins for block. + + Args: + in_channels (int): Input channels of plugin. + plugins (list[dict]): List of plugins cfg to build. + + Returns: + list[str]: List of the names of plugin. + """ + assert isinstance(plugins, list) + plugin_names = [] + for plugin in plugins: + plugin = plugin.copy() + name, layer = build_plugin_layer( + plugin, + in_channels=in_channels, + postfix=plugin.pop('postfix', '')) + assert not hasattr(self, name), f'duplicate plugin {name}' + self.add_module(name, layer) + plugin_names.append(name) + return plugin_names + + def forward_plugin(self, x, plugin_names): + """Forward function for plugins.""" + out = x + for name in plugin_names: + out = getattr(self, name)(x) + return out + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + @property + def norm3(self): + """nn.Module: normalization layer after the third convolution layer""" + return getattr(self, self.norm3_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@BACKBONES.register_module() +class ResNet(BaseModule): + """ResNet backbone. + + This backbone is the improved implementation of `Deep Residual Learning + for Image Recognition `_. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Number of stem channels. Default: 64. + base_channels (int): Number of base channels of res layer. Default: 64. + num_stages (int): Resnet stages, normally 4. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: (1, 2, 2, 2). + dilations (Sequence[int]): Dilation of each stage. + Default: (1, 1, 1, 1). + out_indices (Sequence[int]): Output from which stages. + Default: (0, 1, 2, 3). + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. Default: 'pytorch'. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): Dictionary to construct and config conv layer. + When conv_cfg is None, cfg will be set to dict(type='Conv2d'). + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + dcn (dict | None): Dictionary to construct and config DCN conv layer. + When dcn is not None, conv_cfg must be None. Default: None. + stage_with_dcn (Sequence[bool]): Whether to set DCN conv for each + stage. The length of stage_with_dcn is equal to num_stages. + Default: (False, False, False, False). + plugins (list[dict]): List of plugins for stages, each dict contains: + + - cfg (dict, required): Cfg dict to build plugin. + + - position (str, required): Position inside block to insert plugin, + options: 'after_conv1', 'after_conv2', 'after_conv3'. + + - stages (tuple[bool], optional): Stages to apply plugin, length + should be same as 'num_stages'. + Default: None. + multi_grid (Sequence[int]|None): Multi grid dilation rates of last + stage. Default: None. + contract_dilation (bool): Whether contract first dilation of each layer + Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + pretrained (str, optional): model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Example: + >>> from mmseg.models import ResNet + >>> import torch + >>> self = ResNet(depth=18) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 64, 8, 8) + (1, 128, 4, 4) + (1, 256, 2, 2) + (1, 512, 1, 1) + """ + + arch_settings = { + 18: (BasicBlock, (2, 2, 2, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth, + in_channels=3, + stem_channels=64, + base_channels=64, + num_stages=4, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(0, 1, 2, 3), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + dcn=None, + stage_with_dcn=(False, False, False, False), + plugins=None, + multi_grid=None, + contract_dilation=False, + with_cp=False, + zero_init_residual=True, + pretrained=None, + init_cfg=None): + super(ResNet, self).__init__(init_cfg) + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + + self.pretrained = pretrained + self.zero_init_residual = zero_init_residual + block_init_cfg = None + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + block = self.arch_settings[depth][0] + if self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', + val=0, + override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', + val=0, + override=dict(name='norm3')) + else: + raise TypeError('pretrained must be a str or None') + + self.depth = depth + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert num_stages >= 1 and num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.dcn = dcn + self.stage_with_dcn = stage_with_dcn + if dcn is not None: + assert len(stage_with_dcn) == num_stages + self.plugins = plugins + self.multi_grid = multi_grid + self.contract_dilation = contract_dilation + self.block, stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + self.inplanes = stem_channels + + self._make_stem_layer(in_channels, stem_channels) + + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = strides[i] + dilation = dilations[i] + dcn = self.dcn if self.stage_with_dcn[i] else None + if plugins is not None: + stage_plugins = self.make_stage_plugins(plugins, i) + else: + stage_plugins = None + # multi grid is applied to last layer only + stage_multi_grid = multi_grid if i == len( + self.stage_blocks) - 1 else None + planes = base_channels * 2**i + res_layer = self.make_res_layer( + block=self.block, + inplanes=self.inplanes, + planes=planes, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + dcn=dcn, + plugins=stage_plugins, + multi_grid=stage_multi_grid, + contract_dilation=contract_dilation, + init_cfg=block_init_cfg) + self.inplanes = planes * self.block.expansion + layer_name = f'layer{i+1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = self.block.expansion * base_channels * 2**( + len(self.stage_blocks) - 1) + + def make_stage_plugins(self, plugins, stage_idx): + """make plugins for ResNet 'stage_idx'th stage . + + Currently we support to insert 'context_block', + 'empirical_attention_block', 'nonlocal_block' into the backbone like + ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of + Bottleneck. + + An example of plugins format could be : + >>> plugins=[ + ... dict(cfg=dict(type='xxx', arg1='xxx'), + ... stages=(False, True, True, True), + ... position='after_conv2'), + ... dict(cfg=dict(type='yyy'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='1'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='2'), + ... stages=(True, True, True, True), + ... position='after_conv3') + ... ] + >>> self = ResNet(depth=18) + >>> stage_plugins = self.make_stage_plugins(plugins, 0) + >>> assert len(stage_plugins) == 3 + + Suppose 'stage_idx=0', the structure of blocks in the stage would be: + conv1-> conv2->conv3->yyy->zzz1->zzz2 + Suppose 'stage_idx=1', the structure of blocks in the stage would be: + conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2 + + If stages is missing, the plugin would be applied to all stages. + + Args: + plugins (list[dict]): List of plugins cfg to build. The postfix is + required if multiple same type plugins are inserted. + stage_idx (int): Index of stage to build + + Returns: + list[dict]: Plugins for current stage + """ + stage_plugins = [] + for plugin in plugins: + plugin = plugin.copy() + stages = plugin.pop('stages', None) + assert stages is None or len(stages) == self.num_stages + # whether to insert plugin into current stage + if stages is None or stages[stage_idx]: + stage_plugins.append(plugin) + + return stage_plugins + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``.""" + return ResLayer(**kwargs) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def _make_stem_layer(self, in_channels, stem_channels): + """Make stem layer for ResNet.""" + if self.deep_stem: + self.stem = nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels)[1], + nn.ReLU(inplace=True)) + else: + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, stem_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def _freeze_stages(self): + """Freeze stages param and norm stats.""" + if self.frozen_stages >= 0: + if self.deep_stem: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + else: + self.norm1.eval() + for m in [self.conv1, self.norm1]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = getattr(self, f'layer{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward(self, x): + """Forward function.""" + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode while keep normalization layer + freezed.""" + super(ResNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + +@BACKBONES.register_module() +class ResNetV1c(ResNet): + """ResNetV1c variant described in [1]_. + + Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv in + the input stem with three 3x3 convs. For more details please refer to `Bag + of Tricks for Image Classification with Convolutional Neural Networks + `_. + """ + + def __init__(self, **kwargs): + super(ResNetV1c, self).__init__( + deep_stem=True, avg_down=False, **kwargs) + + +@BACKBONES.register_module() +class ResNetV1d(ResNet): + """ResNetV1d variant described in [1]_. + + Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in + the input stem with three 3x3 convs. And in the downsampling block, a 2x2 + avg_pool with stride 2 is added before conv, whose stride is changed to 1. + """ + + def __init__(self, **kwargs): + super(ResNetV1d, self).__init__( + deep_stem=True, avg_down=True, **kwargs) diff --git a/data_utils/easyportrait/mmseg/models/backbones/resnext.py b/data_utils/easyportrait/mmseg/models/backbones/resnext.py new file mode 100644 index 0000000000000000000000000000000000000000..805c27bf332b365a0c374cfeb5a1ea6ea268ec43 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/resnext.py @@ -0,0 +1,150 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +from mmcv.cnn import build_conv_layer, build_norm_layer + +from ..builder import BACKBONES +from ..utils import ResLayer +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNet + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeXt. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is + "caffe", the stride-two layer is the first 1x1 conv layer. + """ + + def __init__(self, + inplanes, + planes, + groups=1, + base_width=4, + base_channels=64, + **kwargs): + super(Bottleneck, self).__init__(inplanes, planes, **kwargs) + + if groups == 1: + width = self.planes + else: + width = math.floor(self.planes * + (base_width / base_channels)) * groups + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, width, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, width, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.inplanes, + width, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + fallback_on_stride = False + self.with_modulated_dcn = False + if self.with_dcn: + fallback_on_stride = self.dcn.pop('fallback_on_stride', False) + if not self.with_dcn or fallback_on_stride: + self.conv2 = build_conv_layer( + self.conv_cfg, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + else: + assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + self.conv2 = build_conv_layer( + self.dcn, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + self.conv_cfg, + width, + self.planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + +@BACKBONES.register_module() +class ResNeXt(ResNet): + """ResNeXt backbone. + + This backbone is the implementation of `Aggregated + Residual Transformations for Deep Neural + Networks `_. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Normally 3. + num_stages (int): Resnet stages, normally 4. + groups (int): Group of resnext. + base_width (int): Base width of resnext. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. + norm_cfg (dict): dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + zero_init_residual (bool): whether to use zero init for last norm layer + in resblocks to let them behave as identity. + + Example: + >>> from mmseg.models import ResNeXt + >>> import torch + >>> self = ResNeXt(depth=50) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 256, 8, 8) + (1, 512, 4, 4) + (1, 1024, 2, 2) + (1, 2048, 1, 1) + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, groups=1, base_width=4, **kwargs): + self.groups = groups + self.base_width = base_width + super(ResNeXt, self).__init__(**kwargs) + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``""" + return ResLayer( + groups=self.groups, + base_width=self.base_width, + base_channels=self.base_channels, + **kwargs) diff --git a/data_utils/easyportrait/mmseg/models/backbones/stdc.py b/data_utils/easyportrait/mmseg/models/backbones/stdc.py new file mode 100644 index 0000000000000000000000000000000000000000..04f2f7a2a7574633b1878e0d10c0bb6371fdaffa --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/stdc.py @@ -0,0 +1,422 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Modified from https://github.com/MichaelFan01/STDC-Seg.""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmcv.runner.base_module import BaseModule, ModuleList, Sequential + +from mmseg.ops import resize +from ..builder import BACKBONES, build_backbone +from .bisenetv1 import AttentionRefinementModule + + +class STDCModule(BaseModule): + """STDCModule. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels before scaling. + stride (int): The number of stride for the first conv layer. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): The activation config for conv layers. + num_convs (int): Numbers of conv layers. + fusion_type (str): Type of fusion operation. Default: 'add'. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + stride, + norm_cfg=None, + act_cfg=None, + num_convs=4, + fusion_type='add', + init_cfg=None): + super(STDCModule, self).__init__(init_cfg=init_cfg) + assert num_convs > 1 + assert fusion_type in ['add', 'cat'] + self.stride = stride + self.with_downsample = True if self.stride == 2 else False + self.fusion_type = fusion_type + + self.layers = ModuleList() + conv_0 = ConvModule( + in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg) + + if self.with_downsample: + self.downsample = ConvModule( + out_channels // 2, + out_channels // 2, + kernel_size=3, + stride=2, + padding=1, + groups=out_channels // 2, + norm_cfg=norm_cfg, + act_cfg=None) + + if self.fusion_type == 'add': + self.layers.append(nn.Sequential(conv_0, self.downsample)) + self.skip = Sequential( + ConvModule( + in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=1, + groups=in_channels, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + in_channels, + out_channels, + 1, + norm_cfg=norm_cfg, + act_cfg=None)) + else: + self.layers.append(conv_0) + self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + else: + self.layers.append(conv_0) + + for i in range(1, num_convs): + out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i + self.layers.append( + ConvModule( + out_channels // 2**i, + out_channels // out_factor, + kernel_size=3, + stride=1, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, inputs): + if self.fusion_type == 'add': + out = self.forward_add(inputs) + else: + out = self.forward_cat(inputs) + return out + + def forward_add(self, inputs): + layer_outputs = [] + x = inputs.clone() + for layer in self.layers: + x = layer(x) + layer_outputs.append(x) + if self.with_downsample: + inputs = self.skip(inputs) + + return torch.cat(layer_outputs, dim=1) + inputs + + def forward_cat(self, inputs): + x0 = self.layers[0](inputs) + layer_outputs = [x0] + for i, layer in enumerate(self.layers[1:]): + if i == 0: + if self.with_downsample: + x = layer(self.downsample(x0)) + else: + x = layer(x0) + else: + x = layer(x) + layer_outputs.append(x) + if self.with_downsample: + layer_outputs[0] = self.skip(x0) + return torch.cat(layer_outputs, dim=1) + + +class FeatureFusionModule(BaseModule): + """Feature Fusion Module. This module is different from FeatureFusionModule + in BiSeNetV1. It uses two ConvModules in `self.attention` whose inter + channel number is calculated by given `scale_factor`, while + FeatureFusionModule in BiSeNetV1 only uses one ConvModule in + `self.conv_atten`. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + scale_factor (int): The number of channel scale factor. + Default: 4. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): The activation config for conv layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + scale_factor=4, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(FeatureFusionModule, self).__init__(init_cfg=init_cfg) + channels = out_channels // scale_factor + self.conv0 = ConvModule( + in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg) + self.attention = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + ConvModule( + out_channels, + channels, + 1, + norm_cfg=None, + bias=False, + act_cfg=act_cfg), + ConvModule( + channels, + out_channels, + 1, + norm_cfg=None, + bias=False, + act_cfg=None), nn.Sigmoid()) + + def forward(self, spatial_inputs, context_inputs): + inputs = torch.cat([spatial_inputs, context_inputs], dim=1) + x = self.conv0(inputs) + attn = self.attention(x) + x_attn = x * attn + return x_attn + x + + +@BACKBONES.register_module() +class STDCNet(BaseModule): + """This backbone is the implementation of `Rethinking BiSeNet For Real-time + Semantic Segmentation `_. + + Args: + stdc_type (int): The type of backbone structure, + `STDCNet1` and`STDCNet2` denotes two main backbones in paper, + whose FLOPs is 813M and 1446M, respectively. + in_channels (int): The num of input_channels. + channels (tuple[int]): The output channels for each stage. + bottleneck_type (str): The type of STDC Module type, the value must + be 'add' or 'cat'. + norm_cfg (dict): Config dict for normalization layer. + act_cfg (dict): The activation config for conv layers. + num_convs (int): Numbers of conv layer at each STDC Module. + Default: 4. + with_final_conv (bool): Whether add a conv layer at the Module output. + Default: True. + pretrained (str, optional): Model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Example: + >>> import torch + >>> stdc_type = 'STDCNet1' + >>> in_channels = 3 + >>> channels = (32, 64, 256, 512, 1024) + >>> bottleneck_type = 'cat' + >>> inputs = torch.rand(1, 3, 1024, 2048) + >>> self = STDCNet(stdc_type, in_channels, + ... channels, bottleneck_type).eval() + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 256, 128, 256]) + outputs[1].shape = torch.Size([1, 512, 64, 128]) + outputs[2].shape = torch.Size([1, 1024, 32, 64]) + """ + + arch_settings = { + 'STDCNet1': [(2, 1), (2, 1), (2, 1)], + 'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)] + } + + def __init__(self, + stdc_type, + in_channels, + channels, + bottleneck_type, + norm_cfg, + act_cfg, + num_convs=4, + with_final_conv=False, + pretrained=None, + init_cfg=None): + super(STDCNet, self).__init__(init_cfg=init_cfg) + assert stdc_type in self.arch_settings, \ + f'invalid structure {stdc_type} for STDCNet.' + assert bottleneck_type in ['add', 'cat'],\ + f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}' + + assert len(channels) == 5,\ + f'invalid channels length {len(channels)} for STDCNet.' + + self.in_channels = in_channels + self.channels = channels + self.stage_strides = self.arch_settings[stdc_type] + self.prtrained = pretrained + self.num_convs = num_convs + self.with_final_conv = with_final_conv + + self.stages = ModuleList([ + ConvModule( + self.in_channels, + self.channels[0], + kernel_size=3, + stride=2, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + self.channels[0], + self.channels[1], + kernel_size=3, + stride=2, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + ]) + # `self.num_shallow_features` is the number of shallow modules in + # `STDCNet`, which is noted as `Stage1` and `Stage2` in original paper. + # They are both not used for following modules like Attention + # Refinement Module and Feature Fusion Module. + # Thus they would be cut from `outs`. Please refer to Figure 4 + # of original paper for more details. + self.num_shallow_features = len(self.stages) + + for strides in self.stage_strides: + idx = len(self.stages) - 1 + self.stages.append( + self._make_stage(self.channels[idx], self.channels[idx + 1], + strides, norm_cfg, act_cfg, bottleneck_type)) + # After appending, `self.stages` is a ModuleList including several + # shallow modules and STDCModules. + # (len(self.stages) == + # self.num_shallow_features + len(self.stage_strides)) + if self.with_final_conv: + self.final_conv = ConvModule( + self.channels[-1], + max(1024, self.channels[-1]), + 1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def _make_stage(self, in_channels, out_channels, strides, norm_cfg, + act_cfg, bottleneck_type): + layers = [] + for i, stride in enumerate(strides): + layers.append( + STDCModule( + in_channels if i == 0 else out_channels, + out_channels, + stride, + norm_cfg, + act_cfg, + num_convs=self.num_convs, + fusion_type=bottleneck_type)) + return Sequential(*layers) + + def forward(self, x): + outs = [] + for stage in self.stages: + x = stage(x) + outs.append(x) + if self.with_final_conv: + outs[-1] = self.final_conv(outs[-1]) + outs = outs[self.num_shallow_features:] + return tuple(outs) + + +@BACKBONES.register_module() +class STDCContextPathNet(BaseModule): + """STDCNet with Context Path. The `outs` below is a list of three feature + maps from deep to shallow, whose height and width is from small to big, + respectively. The biggest feature map of `outs` is outputted for + `STDCHead`, where Detail Loss would be calculated by Detail Ground-truth. + The other two feature maps are used for Attention Refinement Module, + respectively. Besides, the biggest feature map of `outs` and the last + output of Attention Refinement Module are concatenated for Feature Fusion + Module. Then, this fusion feature map `feat_fuse` would be outputted for + `decode_head`. More details please refer to Figure 4 of original paper. + + Args: + backbone_cfg (dict): Config dict for stdc backbone. + last_in_channels (tuple(int)), The number of channels of last + two feature maps from stdc backbone. Default: (1024, 512). + out_channels (int): The channels of output feature maps. + Default: 128. + ffm_cfg (dict): Config dict for Feature Fusion Module. Default: + `dict(in_channels=512, out_channels=256, scale_factor=4)`. + upsample_mode (str): Algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'``. Default: ``'nearest'``. + align_corners (str): align_corners argument of F.interpolate. It + must be `None` if upsample_mode is ``'nearest'``. Default: None. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Return: + outputs (tuple): The tuple of list of output feature map for + auxiliary heads and decoder head. + """ + + def __init__(self, + backbone_cfg, + last_in_channels=(1024, 512), + out_channels=128, + ffm_cfg=dict( + in_channels=512, out_channels=256, scale_factor=4), + upsample_mode='nearest', + align_corners=None, + norm_cfg=dict(type='BN'), + init_cfg=None): + super(STDCContextPathNet, self).__init__(init_cfg=init_cfg) + self.backbone = build_backbone(backbone_cfg) + self.arms = ModuleList() + self.convs = ModuleList() + for channels in last_in_channels: + self.arms.append(AttentionRefinementModule(channels, out_channels)) + self.convs.append( + ConvModule( + out_channels, + out_channels, + 3, + padding=1, + norm_cfg=norm_cfg)) + self.conv_avg = ConvModule( + last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg) + + self.ffm = FeatureFusionModule(**ffm_cfg) + + self.upsample_mode = upsample_mode + self.align_corners = align_corners + + def forward(self, x): + outs = list(self.backbone(x)) + avg = F.adaptive_avg_pool2d(outs[-1], 1) + avg_feat = self.conv_avg(avg) + + feature_up = resize( + avg_feat, + size=outs[-1].shape[2:], + mode=self.upsample_mode, + align_corners=self.align_corners) + arms_out = [] + for i in range(len(self.arms)): + x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up + feature_up = resize( + x_arm, + size=outs[len(outs) - 1 - i - 1].shape[2:], + mode=self.upsample_mode, + align_corners=self.align_corners) + feature_up = self.convs[i](feature_up) + arms_out.append(feature_up) + + feat_fuse = self.ffm(outs[0], arms_out[1]) + + # The `outputs` has four feature maps. + # `outs[0]` is outputted for `STDCHead` auxiliary head. + # Two feature maps of `arms_out` are outputted for auxiliary head. + # `feat_fuse` is outputted for decoder head. + outputs = [outs[0]] + list(arms_out) + [feat_fuse] + return tuple(outputs) diff --git a/data_utils/easyportrait/mmseg/models/backbones/swin.py b/data_utils/easyportrait/mmseg/models/backbones/swin.py new file mode 100644 index 0000000000000000000000000000000000000000..cbf13288aff055b2e87dfe08849c6b4b4831424d --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/swin.py @@ -0,0 +1,756 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from collections import OrderedDict +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, build_dropout +from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_, + trunc_normal_init) +from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList, + load_state_dict) +from mmcv.utils import to_2tuple + +from ...utils import get_root_logger +from ..builder import BACKBONES +from ..utils.embed import PatchEmbed, PatchMerging + + +class WindowMSA(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): The height and width of the window. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. + init_cfg (dict | None, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = qk_scale or head_embed_dims**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # About 2x faster than original impl + Wh, Ww = self.window_size + rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww) + rel_position_index = rel_index_coords + rel_index_coords.T + rel_position_index = rel_position_index.flip(1).contiguous() + self.register_buffer('relative_position_index', rel_position_index) + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + + self.softmax = nn.Softmax(dim=-1) + + def init_weights(self): + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x, mask=None): + """ + Args: + + x (tensor): input features with shape of (num_windows*B, N, C) + mask (tensor | None, Optional): mask with shape of (num_windows, + Wh*Ww, Wh*Ww), value should be between (-inf, 0]. + """ + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + # make torchscript happy (cannot use tensor as tuple) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @staticmethod + def double_step_seq(step1, len1, step2, len2): + seq1 = torch.arange(0, step1 * len1, step1) + seq2 = torch.arange(0, step2 * len2, step2) + return (seq1[:, None] + seq2[None, :]).reshape(1, -1) + + +class ShiftWindowMSA(BaseModule): + """Shifted Window Multihead Self-Attention Module. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. + shift_size (int, optional): The shift step of each window towards + right-bottom. If zero, act as regular window-msa. Defaults to 0. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Defaults: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Defaults: 0. + proj_drop_rate (float, optional): Dropout ratio of output. + Defaults: 0. + dropout_layer (dict, optional): The dropout_layer used before output. + Defaults: dict(type='DropPath', drop_prob=0.). + init_cfg (dict, optional): The extra config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + shift_size=0, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0, + proj_drop_rate=0, + dropout_layer=dict(type='DropPath', drop_prob=0.), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.window_size = window_size + self.shift_size = shift_size + assert 0 <= self.shift_size < self.window_size + + self.w_msa = WindowMSA( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=to_2tuple(window_size), + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=proj_drop_rate, + init_cfg=None) + + self.drop = build_dropout(dropout_layer) + + def forward(self, query, hw_shape): + B, L, C = query.shape + H, W = hw_shape + assert L == H * W, 'input feature has wrong size' + query = query.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b)) + H_pad, W_pad = query.shape[1], query.shape[2] + + # cyclic shift + if self.shift_size > 0: + shifted_query = torch.roll( + query, + shifts=(-self.shift_size, -self.shift_size), + dims=(1, 2)) + + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device) + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + # nW, window_size, window_size, 1 + mask_windows = self.window_partition(img_mask) + mask_windows = mask_windows.view( + -1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + else: + shifted_query = query + attn_mask = None + + # nW*B, window_size, window_size, C + query_windows = self.window_partition(shifted_query) + # nW*B, window_size*window_size, C + query_windows = query_windows.view(-1, self.window_size**2, C) + + # W-MSA/SW-MSA (nW*B, window_size*window_size, C) + attn_windows = self.w_msa(query_windows, mask=attn_mask) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, + self.window_size, C) + + # B H' W' C + shifted_x = self.window_reverse(attn_windows, H_pad, W_pad) + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + x = self.drop(x) + return x + + def window_reverse(self, windows, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + window_size = self.window_size + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + def window_partition(self, x): + """ + Args: + x: (B, H, W, C) + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + window_size = self.window_size + x = x.view(B, H // window_size, window_size, W // window_size, + window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = windows.view(-1, window_size, window_size, C) + return windows + + +class SwinBlock(BaseModule): + """" + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + window_size (int, optional): The local window scale. Default: 7. + shift (bool, optional): whether to shift window or not. Default False. + qkv_bias (bool, optional): enable bias for qkv if True. Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop_rate (float, optional): Dropout rate. Default: 0. + attn_drop_rate (float, optional): Attention dropout rate. Default: 0. + drop_path_rate (float, optional): Stochastic depth rate. Default: 0. + act_cfg (dict, optional): The config dict of activation function. + Default: dict(type='GELU'). + norm_cfg (dict, optional): The config dict of normalization. + Default: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + init_cfg (dict | list | None, optional): The init config. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + window_size=7, + shift=False, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + + super(SwinBlock, self).__init__(init_cfg=init_cfg) + + self.with_cp = with_cp + + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = ShiftWindowMSA( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=window_size, + shift_size=window_size // 2 if shift else 0, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + init_cfg=None) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=2, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=True, + init_cfg=None) + + def forward(self, x, hw_shape): + + def _inner_forward(x): + identity = x + x = self.norm1(x) + x = self.attn(x, hw_shape) + + x = x + identity + + identity = x + x = self.norm2(x) + x = self.ffn(x, identity=identity) + + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class SwinBlockSequence(BaseModule): + """Implements one stage in Swin Transformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + depth (int): The number of blocks in this stage. + window_size (int, optional): The local window scale. Default: 7. + qkv_bias (bool, optional): enable bias for qkv if True. Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop_rate (float, optional): Dropout rate. Default: 0. + attn_drop_rate (float, optional): Attention dropout rate. Default: 0. + drop_path_rate (float | list[float], optional): Stochastic depth + rate. Default: 0. + downsample (BaseModule | None, optional): The downsample operation + module. Default: None. + act_cfg (dict, optional): The config dict of activation function. + Default: dict(type='GELU'). + norm_cfg (dict, optional): The config dict of normalization. + Default: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + init_cfg (dict | list | None, optional): The init config. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + depth, + window_size=7, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + downsample=None, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(drop_path_rate, list): + drop_path_rates = drop_path_rate + assert len(drop_path_rates) == depth + else: + drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)] + + self.blocks = ModuleList() + for i in range(depth): + block = SwinBlock( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=feedforward_channels, + window_size=window_size, + shift=False if i % 2 == 0 else True, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rates[i], + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + init_cfg=None) + self.blocks.append(block) + + self.downsample = downsample + + def forward(self, x, hw_shape): + for block in self.blocks: + x = block(x, hw_shape) + + if self.downsample: + x_down, down_hw_shape = self.downsample(x, hw_shape) + return x_down, down_hw_shape, x, hw_shape + else: + return x, hw_shape, x, hw_shape + + +@BACKBONES.register_module() +class SwinTransformer(BaseModule): + """Swin Transformer backbone. + + This backbone is the implementation of `Swin Transformer: + Hierarchical Vision Transformer using Shifted + Windows `_. + Inspiration from https://github.com/microsoft/Swin-Transformer. + + Args: + pretrain_img_size (int | tuple[int]): The size of input image when + pretrain. Defaults: 224. + in_channels (int): The num of input channels. + Defaults: 3. + embed_dims (int): The feature dimension. Default: 96. + patch_size (int | tuple[int]): Patch size. Default: 4. + window_size (int): Window size. Default: 7. + mlp_ratio (int | float): Ratio of mlp hidden dim to embedding dim. + Default: 4. + depths (tuple[int]): Depths of each Swin Transformer stage. + Default: (2, 2, 6, 2). + num_heads (tuple[int]): Parallel attention heads of each Swin + Transformer stage. Default: (3, 6, 12, 24). + strides (tuple[int]): The patch merging or patch embedding stride of + each Swin Transformer stage. (In swin, we set kernel size equal to + stride.) Default: (4, 2, 2, 2). + out_indices (tuple[int]): Output from which stages. + Default: (0, 1, 2, 3). + qkv_bias (bool, optional): If True, add a learnable bias to query, key, + value. Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + patch_norm (bool): If add a norm layer for patch embed and patch + merging. Default: True. + drop_rate (float): Dropout rate. Defaults: 0. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Defaults: 0.1. + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults: False. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LN'). + norm_cfg (dict): Config dict for normalization layer at + output of backone. Defaults: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + pretrained (str, optional): model pretrained path. Default: None. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + pretrain_img_size=224, + in_channels=3, + embed_dims=96, + patch_size=4, + window_size=7, + mlp_ratio=4, + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + strides=(4, 2, 2, 2), + out_indices=(0, 1, 2, 3), + qkv_bias=True, + qk_scale=None, + patch_norm=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + use_abs_pos_embed=False, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + with_cp=False, + pretrained=None, + frozen_stages=-1, + init_cfg=None): + self.frozen_stages = frozen_stages + + if isinstance(pretrain_img_size, int): + pretrain_img_size = to_2tuple(pretrain_img_size) + elif isinstance(pretrain_img_size, tuple): + if len(pretrain_img_size) == 1: + pretrain_img_size = to_2tuple(pretrain_img_size[0]) + assert len(pretrain_img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(pretrain_img_size)}' + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be specified at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + init_cfg = init_cfg + else: + raise TypeError('pretrained must be a str or None') + + super(SwinTransformer, self).__init__(init_cfg=init_cfg) + + num_layers = len(depths) + self.out_indices = out_indices + self.use_abs_pos_embed = use_abs_pos_embed + + assert strides[0] == patch_size, 'Use non-overlapping patch embed.' + + self.patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=strides[0], + padding='corner', + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None) + + if self.use_abs_pos_embed: + patch_row = pretrain_img_size[0] // patch_size + patch_col = pretrain_img_size[1] // patch_size + num_patches = patch_row * patch_col + self.absolute_pos_embed = nn.Parameter( + torch.zeros((1, num_patches, embed_dims))) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + # set stochastic depth decay rule + total_depth = sum(depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] + + self.stages = ModuleList() + in_channels = embed_dims + for i in range(num_layers): + if i < num_layers - 1: + downsample = PatchMerging( + in_channels=in_channels, + out_channels=2 * in_channels, + stride=strides[i + 1], + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None) + else: + downsample = None + + stage = SwinBlockSequence( + embed_dims=in_channels, + num_heads=num_heads[i], + feedforward_channels=int(mlp_ratio * in_channels), + depth=depths[i], + window_size=window_size, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])], + downsample=downsample, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + init_cfg=None) + self.stages.append(stage) + if downsample: + in_channels = downsample.out_channels + + self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)] + # Add a norm layer for each output + for i in out_indices: + layer = build_norm_layer(norm_cfg, self.num_features[i])[1] + layer_name = f'norm{i}' + self.add_module(layer_name, layer) + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + if self.use_abs_pos_embed: + self.absolute_pos_embed.requires_grad = False + self.drop_after_pos.eval() + + for i in range(1, self.frozen_stages + 1): + + if (i - 1) in self.out_indices: + norm_layer = getattr(self, f'norm{i-1}') + norm_layer.eval() + for param in norm_layer.parameters(): + param.requires_grad = False + + m = self.stages[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self): + logger = get_root_logger() + if self.init_cfg is None: + logger.warn(f'No pre-trained weights for ' + f'{self.__class__.__name__}, ' + f'training start from scratch') + if self.use_abs_pos_embed: + trunc_normal_(self.absolute_pos_embed, std=0.02) + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, nn.LayerNorm): + constant_init(m, val=1.0, bias=0.) + else: + assert 'checkpoint' in self.init_cfg, f'Only support ' \ + f'specify `Pretrained` in ' \ + f'`init_cfg` in ' \ + f'{self.__class__.__name__} ' + ckpt = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=logger, map_location='cpu') + if 'state_dict' in ckpt: + _state_dict = ckpt['state_dict'] + elif 'model' in ckpt: + _state_dict = ckpt['model'] + else: + _state_dict = ckpt + + state_dict = OrderedDict() + for k, v in _state_dict.items(): + if k.startswith('backbone.'): + state_dict[k[9:]] = v + else: + state_dict[k] = v + + # strip prefix of state_dict + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + # reshape absolute position embedding + if state_dict.get('absolute_pos_embed') is not None: + absolute_pos_embed = state_dict['absolute_pos_embed'] + N1, L, C1 = absolute_pos_embed.size() + N2, C2, H, W = self.absolute_pos_embed.size() + if N1 != N2 or C1 != C2 or L != H * W: + logger.warning('Error in loading absolute_pos_embed, pass') + else: + state_dict['absolute_pos_embed'] = absolute_pos_embed.view( + N2, H, W, C2).permute(0, 3, 1, 2).contiguous() + + # interpolate position bias table if needed + relative_position_bias_table_keys = [ + k for k in state_dict.keys() + if 'relative_position_bias_table' in k + ] + for table_key in relative_position_bias_table_keys: + table_pretrained = state_dict[table_key] + table_current = self.state_dict()[table_key] + L1, nH1 = table_pretrained.size() + L2, nH2 = table_current.size() + if nH1 != nH2: + logger.warning(f'Error in loading {table_key}, pass') + elif L1 != L2: + S1 = int(L1**0.5) + S2 = int(L2**0.5) + table_pretrained_resized = F.interpolate( + table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1), + size=(S2, S2), + mode='bicubic') + state_dict[table_key] = table_pretrained_resized.view( + nH2, L2).permute(1, 0).contiguous() + + # load state_dict + load_state_dict(self, state_dict, strict=False, logger=logger) + + def forward(self, x): + x, hw_shape = self.patch_embed(x) + + if self.use_abs_pos_embed: + x = x + self.absolute_pos_embed + x = self.drop_after_pos(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape, out, out_hw_shape = stage(x, hw_shape) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(out) + out = out.view(-1, *out_hw_shape, + self.num_features[i]).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + + return outs diff --git a/data_utils/easyportrait/mmseg/models/backbones/timm_backbone.py b/data_utils/easyportrait/mmseg/models/backbones/timm_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..01b29fc5ed9f8f0cfbfc25aca8f8628e1f2670d6 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/timm_backbone.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +try: + import timm +except ImportError: + timm = None + +from mmcv.cnn.bricks.registry import NORM_LAYERS +from mmcv.runner import BaseModule + +from ..builder import BACKBONES + + +@BACKBONES.register_module() +class TIMMBackbone(BaseModule): + """Wrapper to use backbones from timm library. More details can be found in + `timm `_ . + + Args: + model_name (str): Name of timm model to instantiate. + pretrained (bool): Load pretrained weights if True. + checkpoint_path (str): Path of checkpoint to load after + model is initialized. + in_channels (int): Number of input image channels. Default: 3. + init_cfg (dict, optional): Initialization config dict + **kwargs: Other timm & model specific arguments. + """ + + def __init__( + self, + model_name, + features_only=True, + pretrained=True, + checkpoint_path='', + in_channels=3, + init_cfg=None, + **kwargs, + ): + if timm is None: + raise RuntimeError('timm is not installed') + super(TIMMBackbone, self).__init__(init_cfg) + if 'norm_layer' in kwargs: + kwargs['norm_layer'] = NORM_LAYERS.get(kwargs['norm_layer']) + self.timm_model = timm.create_model( + model_name=model_name, + features_only=features_only, + pretrained=pretrained, + in_chans=in_channels, + checkpoint_path=checkpoint_path, + **kwargs, + ) + + # Make unused parameters None + self.timm_model.global_pool = None + self.timm_model.fc = None + self.timm_model.classifier = None + + # Hack to use pretrained weights from timm + if pretrained or checkpoint_path: + self._is_init = True + + def forward(self, x): + features = self.timm_model(x) + return features diff --git a/data_utils/easyportrait/mmseg/models/backbones/twins.py b/data_utils/easyportrait/mmseg/models/backbones/twins.py new file mode 100644 index 0000000000000000000000000000000000000000..6bd9469118ca76a0fe87e04e8a52dd8144791a43 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/twins.py @@ -0,0 +1,588 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import FFN +from mmcv.cnn.utils.weight_init import (constant_init, normal_init, + trunc_normal_init) +from mmcv.runner import BaseModule, ModuleList +from torch.nn.modules.batchnorm import _BatchNorm + +from mmseg.models.backbones.mit import EfficientMultiheadAttention +from mmseg.models.builder import BACKBONES +from ..utils.embed import PatchEmbed + + +class GlobalSubsampledAttention(EfficientMultiheadAttention): + """Global Sub-sampled Attention (Spatial Reduction Attention) + + This module is modified from EfficientMultiheadAttention, + which is a module from mmseg.models.backbones.mit.py. + Specifically, there is no difference between + `GlobalSubsampledAttention` and `EfficientMultiheadAttention`, + `GlobalSubsampledAttention` is built as a brand new class + because it is renamed as `Global sub-sampled attention (GSA)` + in paper. + + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. + Default: 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. Default: None. + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dims) + or (n, batch, embed_dims). Default: False. + qkv_bias (bool): enable bias for qkv if True. Default: True. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (int): The ratio of spatial reduction of GSA of PCPVT. + Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + attn_drop=0., + proj_drop=0., + dropout_layer=None, + batch_first=True, + qkv_bias=True, + norm_cfg=dict(type='LN'), + sr_ratio=1, + init_cfg=None): + super(GlobalSubsampledAttention, self).__init__( + embed_dims, + num_heads, + attn_drop=attn_drop, + proj_drop=proj_drop, + dropout_layer=dropout_layer, + batch_first=batch_first, + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + sr_ratio=sr_ratio, + init_cfg=init_cfg) + + +class GSAEncoderLayer(BaseModule): + """Implements one encoder layer with GSA. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0. + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): Enable bias for qkv if True. Default: True + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (float): Kernel_size of conv in Attention modules. Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + sr_ratio=1., + init_cfg=None): + super(GSAEncoderLayer, self).__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1] + self.attn = GlobalSubsampledAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + sr_ratio=sr_ratio) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=False) + + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path_rate) + ) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x, hw_shape): + x = x + self.drop_path(self.attn(self.norm1(x), hw_shape, identity=0.)) + x = x + self.drop_path(self.ffn(self.norm2(x))) + return x + + +class LocallyGroupedSelfAttention(BaseModule): + """Locally-grouped Self Attention (LSA) module. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. Default: 8 + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: False. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. + window_size(int): Window size of LSA. Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + window_size=1, + init_cfg=None): + super(LocallyGroupedSelfAttention, self).__init__(init_cfg=init_cfg) + + assert embed_dims % num_heads == 0, f'dim {embed_dims} should be ' \ + f'divided by num_heads ' \ + f'{num_heads}.' + self.embed_dims = embed_dims + self.num_heads = num_heads + head_dim = embed_dims // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + self.window_size = window_size + + def forward(self, x, hw_shape): + b, n, c = x.shape + h, w = hw_shape + x = x.view(b, h, w, c) + + # pad feature maps to multiples of Local-groups + pad_l = pad_t = 0 + pad_r = (self.window_size - w % self.window_size) % self.window_size + pad_b = (self.window_size - h % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + # calculate attention mask for LSA + Hp, Wp = x.shape[1:-1] + _h, _w = Hp // self.window_size, Wp // self.window_size + mask = torch.zeros((1, Hp, Wp), device=x.device) + mask[:, -pad_b:, :].fill_(1) + mask[:, :, -pad_r:].fill_(1) + + # [B, _h, _w, window_size, window_size, C] + x = x.reshape(b, _h, self.window_size, _w, self.window_size, + c).transpose(2, 3) + mask = mask.reshape(1, _h, self.window_size, _w, + self.window_size).transpose(2, 3).reshape( + 1, _h * _w, + self.window_size * self.window_size) + # [1, _h*_w, window_size*window_size, window_size*window_size] + attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-1000.0)).masked_fill( + attn_mask == 0, float(0.0)) + + # [3, B, _w*_h, nhead, window_size*window_size, dim] + qkv = self.qkv(x).reshape(b, _h * _w, + self.window_size * self.window_size, 3, + self.num_heads, c // self.num_heads).permute( + 3, 0, 1, 4, 2, 5) + q, k, v = qkv[0], qkv[1], qkv[2] + # [B, _h*_w, n_head, window_size*window_size, window_size*window_size] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn + attn_mask.unsqueeze(2) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(b, _h, _w, self.window_size, + self.window_size, c) + x = attn.transpose(2, 3).reshape(b, _h * self.window_size, + _w * self.window_size, c) + if pad_r > 0 or pad_b > 0: + x = x[:, :h, :w, :].contiguous() + + x = x.reshape(b, n, c) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LSAEncoderLayer(BaseModule): + """Implements one encoder layer in Twins-SVT. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): Enable bias for qkv if True. Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + window_size (int): Window size of LSA. Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + qk_scale=None, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + window_size=1, + init_cfg=None): + + super(LSAEncoderLayer, self).__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1] + self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads, + qkv_bias, qk_scale, + attn_drop_rate, drop_rate, + window_size) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=False) + + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path_rate) + ) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x, hw_shape): + x = x + self.drop_path(self.attn(self.norm1(x), hw_shape)) + x = x + self.drop_path(self.ffn(self.norm2(x))) + return x + + +class ConditionalPositionEncoding(BaseModule): + """The Conditional Position Encoding (CPE) module. + + The CPE is the implementation of 'Conditional Positional Encodings + for Vision Transformers '_. + + Args: + in_channels (int): Number of input channels. + embed_dims (int): The feature dimension. Default: 768. + stride (int): Stride of conv layer. Default: 1. + """ + + def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None): + super(ConditionalPositionEncoding, self).__init__(init_cfg=init_cfg) + self.proj = nn.Conv2d( + in_channels, + embed_dims, + kernel_size=3, + stride=stride, + padding=1, + bias=True, + groups=embed_dims) + self.stride = stride + + def forward(self, x, hw_shape): + b, n, c = x.shape + h, w = hw_shape + feat_token = x + cnn_feat = feat_token.transpose(1, 2).view(b, c, h, w) + if self.stride == 1: + x = self.proj(cnn_feat) + cnn_feat + else: + x = self.proj(cnn_feat) + x = x.flatten(2).transpose(1, 2) + return x + + +@BACKBONES.register_module() +class PCPVT(BaseModule): + """The backbone of Twins-PCPVT. + + This backbone is the implementation of `Twins: Revisiting the Design + of Spatial Attention in Vision Transformers + `_. + + Args: + in_channels (int): Number of input channels. Default: 3. + embed_dims (list): Embedding dimension. Default: [64, 128, 256, 512]. + patch_sizes (list): The patch sizes. Default: [4, 2, 2, 2]. + strides (list): The strides. Default: [4, 2, 2, 2]. + num_heads (int): Number of attention heads. Default: [1, 2, 4, 8]. + mlp_ratios (int): Ratio of mlp hidden dim to embedding dim. + Default: [4, 4, 4, 4]. + out_indices (tuple[int]): Output from which stages. + Default: (0, 1, 2, 3). + qkv_bias (bool): Enable bias for qkv if True. Default: False. + drop_rate (float): Probability of an element to be zeroed. + Default 0. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): Stochastic depth rate. Default 0.0 + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + depths (list): Depths of each stage. Default [3, 4, 6, 3] + sr_ratios (list): Kernel_size of conv in each Attn module in + Transformer encoder layer. Default: [8, 4, 2, 1]. + norm_after_stage(bool): Add extra norm. Default False. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + in_channels=3, + embed_dims=[64, 128, 256, 512], + patch_sizes=[4, 2, 2, 2], + strides=[4, 2, 2, 2], + num_heads=[1, 2, 4, 8], + mlp_ratios=[4, 4, 4, 4], + out_indices=(0, 1, 2, 3), + qkv_bias=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + norm_after_stage=False, + pretrained=None, + init_cfg=None): + super(PCPVT, self).__init__(init_cfg=init_cfg) + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + self.depths = depths + + # patch_embed + self.patch_embeds = ModuleList() + self.position_encoding_drops = ModuleList() + self.layers = ModuleList() + + for i in range(len(depths)): + self.patch_embeds.append( + PatchEmbed( + in_channels=in_channels if i == 0 else embed_dims[i - 1], + embed_dims=embed_dims[i], + conv_type='Conv2d', + kernel_size=patch_sizes[i], + stride=strides[i], + padding='corner', + norm_cfg=norm_cfg)) + + self.position_encoding_drops.append(nn.Dropout(p=drop_rate)) + + self.position_encodings = ModuleList([ + ConditionalPositionEncoding(embed_dim, embed_dim) + for embed_dim in embed_dims + ]) + + # transformer encoder + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + cur = 0 + + for k in range(len(depths)): + _block = ModuleList([ + GSAEncoderLayer( + embed_dims=embed_dims[k], + num_heads=num_heads[k], + feedforward_channels=mlp_ratios[k] * embed_dims[k], + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + drop_path_rate=dpr[cur + i], + num_fcs=2, + qkv_bias=qkv_bias, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + sr_ratio=sr_ratios[k]) for i in range(depths[k]) + ]) + self.layers.append(_block) + cur += depths[k] + + self.norm_name, norm = build_norm_layer( + norm_cfg, embed_dims[-1], postfix=1) + + self.out_indices = out_indices + self.norm_after_stage = norm_after_stage + if self.norm_after_stage: + self.norm_list = ModuleList() + for dim in embed_dims: + self.norm_list.append(build_norm_layer(norm_cfg, dim)[1]) + + def init_weights(self): + if self.init_cfg is not None: + super(PCPVT, self).init_weights() + else: + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init( + m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) + + def forward(self, x): + outputs = list() + + b = x.shape[0] + + for i in range(len(self.depths)): + x, hw_shape = self.patch_embeds[i](x) + h, w = hw_shape + x = self.position_encoding_drops[i](x) + for j, blk in enumerate(self.layers[i]): + x = blk(x, hw_shape) + if j == 0: + x = self.position_encodings[i](x, hw_shape) + if self.norm_after_stage: + x = self.norm_list[i](x) + x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() + + if i in self.out_indices: + outputs.append(x) + + return tuple(outputs) + + +@BACKBONES.register_module() +class SVT(PCPVT): + """The backbone of Twins-SVT. + + This backbone is the implementation of `Twins: Revisiting the Design + of Spatial Attention in Vision Transformers + `_. + + Args: + in_channels (int): Number of input channels. Default: 3. + embed_dims (list): Embedding dimension. Default: [64, 128, 256, 512]. + patch_sizes (list): The patch sizes. Default: [4, 2, 2, 2]. + strides (list): The strides. Default: [4, 2, 2, 2]. + num_heads (int): Number of attention heads. Default: [1, 2, 4]. + mlp_ratios (int): Ratio of mlp hidden dim to embedding dim. + Default: [4, 4, 4]. + out_indices (tuple[int]): Output from which stages. + Default: (0, 1, 2, 3). + qkv_bias (bool): Enable bias for qkv if True. Default: False. + drop_rate (float): Dropout rate. Default 0. + attn_drop_rate (float): Dropout ratio of attention weight. + Default 0.0 + drop_path_rate (float): Stochastic depth rate. Default 0.2. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + depths (list): Depths of each stage. Default [4, 4, 4]. + sr_ratios (list): Kernel_size of conv in each Attn module in + Transformer encoder layer. Default: [4, 2, 1]. + windiow_sizes (list): Window size of LSA. Default: [7, 7, 7], + input_features_slice(bool): Input features need slice. Default: False. + norm_after_stage(bool): Add extra norm. Default False. + strides (list): Strides in patch-Embedding modules. Default: (2, 2, 2) + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + in_channels=3, + embed_dims=[64, 128, 256], + patch_sizes=[4, 2, 2, 2], + strides=[4, 2, 2, 2], + num_heads=[1, 2, 4], + mlp_ratios=[4, 4, 4], + out_indices=(0, 1, 2, 3), + qkv_bias=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_cfg=dict(type='LN'), + depths=[4, 4, 4], + sr_ratios=[4, 2, 1], + windiow_sizes=[7, 7, 7], + norm_after_stage=True, + pretrained=None, + init_cfg=None): + super(SVT, self).__init__(in_channels, embed_dims, patch_sizes, + strides, num_heads, mlp_ratios, out_indices, + qkv_bias, drop_rate, attn_drop_rate, + drop_path_rate, norm_cfg, depths, sr_ratios, + norm_after_stage, pretrained, init_cfg) + # transformer encoder + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + for k in range(len(depths)): + for i in range(depths[k]): + if i % 2 == 0: + self.layers[k][i] = \ + LSAEncoderLayer( + embed_dims=embed_dims[k], + num_heads=num_heads[k], + feedforward_channels=mlp_ratios[k] * embed_dims[k], + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[sum(depths[:k])+i], + qkv_bias=qkv_bias, + window_size=windiow_sizes[k]) diff --git a/data_utils/easyportrait/mmseg/models/backbones/unet.py b/data_utils/easyportrait/mmseg/models/backbones/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..c2d33667f88fbdbea574c12b750cb540f79da387 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/unet.py @@ -0,0 +1,438 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer, + build_norm_layer) +from mmcv.runner import BaseModule +from mmcv.utils.parrots_wrapper import _BatchNorm + +from mmseg.ops import Upsample +from ..builder import BACKBONES +from ..utils import UpConvBlock + + +class BasicConvBlock(nn.Module): + """Basic convolutional block for UNet. + + This module consists of several plain convolutional layers. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + num_convs (int): Number of convolutional layers. Default: 2. + stride (int): Whether use stride convolution to downsample + the input feature map. If stride=2, it only uses stride convolution + in the first convolutional layer to downsample the input feature + map. Options are 1 or 2. Default: 1. + dilation (int): Whether use dilated convolution to expand the + receptive field. Set dilation rate of each convolutional layer and + the dilation rate of the first convolutional layer is always 1. + Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + num_convs=2, + stride=1, + dilation=1, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + dcn=None, + plugins=None): + super(BasicConvBlock, self).__init__() + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.with_cp = with_cp + convs = [] + for i in range(num_convs): + convs.append( + ConvModule( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride if i == 0 else 1, + dilation=1 if i == 0 else dilation, + padding=1 if i == 0 else dilation, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + self.convs = nn.Sequential(*convs) + + def forward(self, x): + """Forward function.""" + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.convs, x) + else: + out = self.convs(x) + return out + + +@UPSAMPLE_LAYERS.register_module() +class DeconvModule(nn.Module): + """Deconvolution upsample module in decoder for UNet (2X upsample). + + This module uses deconvolution to upsample feature map in the decoder + of UNet. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + kernel_size (int): Kernel size of the convolutional layer. Default: 4. + """ + + def __init__(self, + in_channels, + out_channels, + with_cp=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + *, + kernel_size=4, + scale_factor=2): + super(DeconvModule, self).__init__() + + assert (kernel_size - scale_factor >= 0) and\ + (kernel_size - scale_factor) % 2 == 0,\ + f'kernel_size should be greater than or equal to scale_factor '\ + f'and (kernel_size - scale_factor) should be even numbers, '\ + f'while the kernel size is {kernel_size} and scale_factor is '\ + f'{scale_factor}.' + + stride = scale_factor + padding = (kernel_size - scale_factor) // 2 + self.with_cp = with_cp + deconv = nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding) + + norm_name, norm = build_norm_layer(norm_cfg, out_channels) + activate = build_activation_layer(act_cfg) + self.deconv_upsamping = nn.Sequential(deconv, norm, activate) + + def forward(self, x): + """Forward function.""" + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.deconv_upsamping, x) + else: + out = self.deconv_upsamping(x) + return out + + +@UPSAMPLE_LAYERS.register_module() +class InterpConv(nn.Module): + """Interpolation upsample module in decoder for UNet. + + This module uses interpolation to upsample feature map in the decoder + of UNet. It consists of one interpolation upsample layer and one + convolutional layer. It can be one interpolation upsample layer followed + by one convolutional layer (conv_first=False) or one convolutional layer + followed by one interpolation upsample layer (conv_first=True). + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + conv_first (bool): Whether convolutional layer or interpolation + upsample layer first. Default: False. It means interpolation + upsample layer followed by one convolutional layer. + kernel_size (int): Kernel size of the convolutional layer. Default: 1. + stride (int): Stride of the convolutional layer. Default: 1. + padding (int): Padding of the convolutional layer. Default: 1. + upsample_cfg (dict): Interpolation config of the upsample layer. + Default: dict( + scale_factor=2, mode='bilinear', align_corners=False). + """ + + def __init__(self, + in_channels, + out_channels, + with_cp=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + *, + conv_cfg=None, + conv_first=False, + kernel_size=1, + stride=1, + padding=0, + upsample_cfg=dict( + scale_factor=2, mode='bilinear', align_corners=False)): + super(InterpConv, self).__init__() + + self.with_cp = with_cp + conv = ConvModule( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + upsample = Upsample(**upsample_cfg) + if conv_first: + self.interp_upsample = nn.Sequential(conv, upsample) + else: + self.interp_upsample = nn.Sequential(upsample, conv) + + def forward(self, x): + """Forward function.""" + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.interp_upsample, x) + else: + out = self.interp_upsample(x) + return out + + +@BACKBONES.register_module() +class UNet(BaseModule): + """UNet backbone. + + This backbone is the implementation of `U-Net: Convolutional Networks + for Biomedical Image Segmentation `_. + + Args: + in_channels (int): Number of input image channels. Default" 3. + base_channels (int): Number of base channels of each stage. + The output channels of the first stage. Default: 64. + num_stages (int): Number of stages in encoder, normally 5. Default: 5. + strides (Sequence[int 1 | 2]): Strides of each stage in encoder. + len(strides) is equal to num_stages. Normally the stride of the + first stage in encoder is 1. If strides[i]=2, it uses stride + convolution to downsample in the correspondence encoder stage. + Default: (1, 1, 1, 1, 1). + enc_num_convs (Sequence[int]): Number of convolutional layers in the + convolution block of the correspondence encoder stage. + Default: (2, 2, 2, 2, 2). + dec_num_convs (Sequence[int]): Number of convolutional layers in the + convolution block of the correspondence decoder stage. + Default: (2, 2, 2, 2). + downsamples (Sequence[int]): Whether use MaxPool to downsample the + feature map after the first stage of encoder + (stages: [1, num_stages)). If the correspondence encoder stage use + stride convolution (strides[i]=2), it will never use MaxPool to + downsample, even downsamples[i-1]=True. + Default: (True, True, True, True). + enc_dilations (Sequence[int]): Dilation rate of each stage in encoder. + Default: (1, 1, 1, 1, 1). + dec_dilations (Sequence[int]): Dilation rate of each stage in decoder. + Default: (1, 1, 1, 1). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + upsample_cfg (dict): The upsample config of the upsample module in + decoder. Default: dict(type='InterpConv'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + + Notice: + The input image size should be divisible by the whole downsample rate + of the encoder. More detail of the whole downsample rate can be found + in UNet._check_input_divisible. + """ + + def __init__(self, + in_channels=3, + base_channels=64, + num_stages=5, + strides=(1, 1, 1, 1, 1), + enc_num_convs=(2, 2, 2, 2, 2), + dec_num_convs=(2, 2, 2, 2), + downsamples=(True, True, True, True), + enc_dilations=(1, 1, 1, 1, 1), + dec_dilations=(1, 1, 1, 1), + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + norm_eval=False, + dcn=None, + plugins=None, + pretrained=None, + init_cfg=None): + super(UNet, self).__init__(init_cfg) + + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + assert len(strides) == num_stages, \ + 'The length of strides should be equal to num_stages, '\ + f'while the strides is {strides}, the length of '\ + f'strides is {len(strides)}, and the num_stages is '\ + f'{num_stages}.' + assert len(enc_num_convs) == num_stages, \ + 'The length of enc_num_convs should be equal to num_stages, '\ + f'while the enc_num_convs is {enc_num_convs}, the length of '\ + f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\ + f'{num_stages}.' + assert len(dec_num_convs) == (num_stages-1), \ + 'The length of dec_num_convs should be equal to (num_stages-1), '\ + f'while the dec_num_convs is {dec_num_convs}, the length of '\ + f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\ + f'{num_stages}.' + assert len(downsamples) == (num_stages-1), \ + 'The length of downsamples should be equal to (num_stages-1), '\ + f'while the downsamples is {downsamples}, the length of '\ + f'downsamples is {len(downsamples)}, and the num_stages is '\ + f'{num_stages}.' + assert len(enc_dilations) == num_stages, \ + 'The length of enc_dilations should be equal to num_stages, '\ + f'while the enc_dilations is {enc_dilations}, the length of '\ + f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\ + f'{num_stages}.' + assert len(dec_dilations) == (num_stages-1), \ + 'The length of dec_dilations should be equal to (num_stages-1), '\ + f'while the dec_dilations is {dec_dilations}, the length of '\ + f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\ + f'{num_stages}.' + self.num_stages = num_stages + self.strides = strides + self.downsamples = downsamples + self.norm_eval = norm_eval + self.base_channels = base_channels + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + for i in range(num_stages): + enc_conv_block = [] + if i != 0: + if strides[i] == 1 and downsamples[i - 1]: + enc_conv_block.append(nn.MaxPool2d(kernel_size=2)) + upsample = (strides[i] != 1 or downsamples[i - 1]) + self.decoder.append( + UpConvBlock( + conv_block=BasicConvBlock, + in_channels=base_channels * 2**i, + skip_channels=base_channels * 2**(i - 1), + out_channels=base_channels * 2**(i - 1), + num_convs=dec_num_convs[i - 1], + stride=1, + dilation=dec_dilations[i - 1], + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + upsample_cfg=upsample_cfg if upsample else None, + dcn=None, + plugins=None)) + + enc_conv_block.append( + BasicConvBlock( + in_channels=in_channels, + out_channels=base_channels * 2**i, + num_convs=enc_num_convs[i], + stride=strides[i], + dilation=enc_dilations[i], + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dcn=None, + plugins=None)) + self.encoder.append((nn.Sequential(*enc_conv_block))) + in_channels = base_channels * 2**i + + def forward(self, x): + self._check_input_divisible(x) + enc_outs = [] + for enc in self.encoder: + x = enc(x) + enc_outs.append(x) + dec_outs = [x] + for i in reversed(range(len(self.decoder))): + x = self.decoder[i](enc_outs[i], x) + dec_outs.append(x) + + return dec_outs + + def train(self, mode=True): + """Convert the model into training mode while keep normalization layer + freezed.""" + super(UNet, self).train(mode) + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def _check_input_divisible(self, x): + h, w = x.shape[-2:] + whole_downsample_rate = 1 + for i in range(1, self.num_stages): + if self.strides[i] == 2 or self.downsamples[i - 1]: + whole_downsample_rate *= 2 + assert (h % whole_downsample_rate == 0) \ + and (w % whole_downsample_rate == 0),\ + f'The input image size {(h, w)} should be divisible by the whole '\ + f'downsample rate {whole_downsample_rate}, when num_stages is '\ + f'{self.num_stages}, strides is {self.strides}, and downsamples '\ + f'is {self.downsamples}.' diff --git a/data_utils/easyportrait/mmseg/models/backbones/vit.py b/data_utils/easyportrait/mmseg/models/backbones/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..37b9a4f4b816786678f08f77c71e5bbaaa167f78 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/backbones/vit.py @@ -0,0 +1,440 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import warnings + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init, + trunc_normal_) +from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList, + load_state_dict) +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.utils import _pair as to_2tuple + +from mmseg.ops import resize +from mmseg.utils import get_root_logger +from ..builder import BACKBONES +from ..utils import PatchEmbed + + +class TransformerEncoderLayer(BaseModule): + """Implements one encoder layer in Vision Transformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0. + drop_path_rate (float): stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): enable bias for qkv if True. Default: True + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default: True. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + batch_first=True, + attn_cfg=dict(), + ffn_cfg=dict(), + with_cp=False): + super(TransformerEncoderLayer, self).__init__() + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, embed_dims, postfix=1) + self.add_module(self.norm1_name, norm1) + + attn_cfg.update( + dict( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + batch_first=batch_first, + bias=qkv_bias)) + + self.build_attn(attn_cfg) + + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + self.add_module(self.norm2_name, norm2) + + ffn_cfg.update( + dict( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate) + if drop_path_rate > 0 else None, + act_cfg=act_cfg)) + self.build_ffn(ffn_cfg) + self.with_cp = with_cp + + def build_attn(self, attn_cfg): + self.attn = MultiheadAttention(**attn_cfg) + + def build_ffn(self, ffn_cfg): + self.ffn = FFN(**ffn_cfg) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + @property + def norm2(self): + return getattr(self, self.norm2_name) + + def forward(self, x): + + def _inner_forward(x): + x = self.attn(self.norm1(x), identity=x) + x = self.ffn(self.norm2(x), identity=x) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +@BACKBONES.register_module() +class VisionTransformer(BaseModule): + """Vision Transformer. + + This backbone is the implementation of `An Image is Worth 16x16 Words: + Transformers for Image Recognition at + Scale `_. + + Args: + img_size (int | tuple): Input image size. Default: 224. + patch_size (int): The patch size. Default: 16. + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): embedding dimension. Default: 768. + num_layers (int): depth of transformer. Default: 12. + num_heads (int): number of attention heads. Default: 12. + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + out_indices (list | tuple | int): Output from which stages. + Default: -1. + qkv_bias (bool): enable bias for qkv if True. Default: True. + drop_rate (float): Probability of an element to be zeroed. + Default 0.0 + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): stochastic depth rate. Default 0.0 + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Default: True. + output_cls_token (bool): Whether output the cls_token. If set True, + `with_cls_token` must be True. Default: False. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + patch_norm (bool): Whether to add a norm in PatchEmbed Block. + Default: False. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Default: False. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Default: bicubic. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. Default: False. + pretrained (str, optional): model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_channels=3, + embed_dims=768, + num_layers=12, + num_heads=12, + mlp_ratio=4, + out_indices=-1, + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + with_cls_token=True, + output_cls_token=False, + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + patch_norm=False, + final_norm=False, + interpolate_mode='bicubic', + num_fcs=2, + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None): + super(VisionTransformer, self).__init__(init_cfg=init_cfg) + + if isinstance(img_size, int): + img_size = to_2tuple(img_size) + elif isinstance(img_size, tuple): + if len(img_size) == 1: + img_size = to_2tuple(img_size[0]) + assert len(img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(img_size)}' + + if output_cls_token: + assert with_cls_token is True, f'with_cls_token must be True if' \ + f'set output_cls_token to True, but got {with_cls_token}' + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + + self.img_size = img_size + self.patch_size = patch_size + self.interpolate_mode = interpolate_mode + self.norm_eval = norm_eval + self.with_cp = with_cp + self.pretrained = pretrained + + self.patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + padding='corner', + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None, + ) + + num_patches = (img_size[0] // patch_size) * \ + (img_size[1] // patch_size) + + self.with_cls_token = with_cls_token + self.output_cls_token = output_cls_token + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, embed_dims)) + self.drop_after_pos = nn.Dropout(p=drop_rate) + + if isinstance(out_indices, int): + if out_indices == -1: + out_indices = num_layers - 1 + self.out_indices = [out_indices] + elif isinstance(out_indices, list) or isinstance(out_indices, tuple): + self.out_indices = out_indices + else: + raise TypeError('out_indices must be type of int, list or tuple') + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, num_layers) + ] # stochastic depth decay rule + + self.layers = ModuleList() + for i in range(num_layers): + self.layers.append( + TransformerEncoderLayer( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=mlp_ratio * embed_dims, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + drop_path_rate=dpr[i], + num_fcs=num_fcs, + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + batch_first=True)) + + self.final_norm = final_norm + if final_norm: + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, embed_dims, postfix=1) + self.add_module(self.norm1_name, norm1) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + def init_weights(self): + if (isinstance(self.init_cfg, dict) + and self.init_cfg.get('type') == 'Pretrained'): + logger = get_root_logger() + checkpoint = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=logger, map_location='cpu') + + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + if 'pos_embed' in state_dict.keys(): + if self.pos_embed.shape != state_dict['pos_embed'].shape: + logger.info(msg=f'Resize the pos_embed shape from ' + f'{state_dict["pos_embed"].shape} to ' + f'{self.pos_embed.shape}') + h, w = self.img_size + pos_size = int( + math.sqrt(state_dict['pos_embed'].shape[1] - 1)) + state_dict['pos_embed'] = self.resize_pos_embed( + state_dict['pos_embed'], + (h // self.patch_size, w // self.patch_size), + (pos_size, pos_size), self.interpolate_mode) + + load_state_dict(self, state_dict, strict=False, logger=logger) + elif self.init_cfg is not None: + super(VisionTransformer, self).init_weights() + else: + # We only implement the 'jax_impl' initialization implemented at + # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + if 'ffn' in n: + nn.init.normal_(m.bias, mean=0., std=1e-6) + else: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) + + def _pos_embeding(self, patched_img, hw_shape, pos_embed): + """Positioning embeding method. + + Resize the pos_embed, if the input image size doesn't match + the training size. + Args: + patched_img (torch.Tensor): The patched image, it should be + shape of [B, L1, C]. + hw_shape (tuple): The downsampled image resolution. + pos_embed (torch.Tensor): The pos_embed weighs, it should be + shape of [B, L2, c]. + Return: + torch.Tensor: The pos encoded image feature. + """ + assert patched_img.ndim == 3 and pos_embed.ndim == 3, \ + 'the shapes of patched_img and pos_embed must be [B, L, C]' + x_len, pos_len = patched_img.shape[1], pos_embed.shape[1] + if x_len != pos_len: + if pos_len == (self.img_size[0] // self.patch_size) * ( + self.img_size[1] // self.patch_size) + 1: + pos_h = self.img_size[0] // self.patch_size + pos_w = self.img_size[1] // self.patch_size + else: + raise ValueError( + 'Unexpected shape of pos_embed, got {}.'.format( + pos_embed.shape)) + pos_embed = self.resize_pos_embed(pos_embed, hw_shape, + (pos_h, pos_w), + self.interpolate_mode) + return self.drop_after_pos(patched_img + pos_embed) + + @staticmethod + def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode): + """Resize pos_embed weights. + + Resize pos_embed using bicubic interpolate method. + Args: + pos_embed (torch.Tensor): Position embedding weights. + input_shpae (tuple): Tuple for (downsampled input image height, + downsampled input image width). + pos_shape (tuple): The resolution of downsampled origin training + image. + mode (str): Algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'``. Default: ``'nearest'`` + Return: + torch.Tensor: The resized pos_embed of shape [B, L_new, C] + """ + assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]' + pos_h, pos_w = pos_shape + # keep dim for easy deployment + cls_token_weight = pos_embed[:, 0:1] + pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):] + pos_embed_weight = pos_embed_weight.reshape( + 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2) + pos_embed_weight = resize( + pos_embed_weight, size=input_shpae, align_corners=False, mode=mode) + pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2) + pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1) + return pos_embed + + def forward(self, inputs): + B = inputs.shape[0] + + x, hw_shape = self.patch_embed(inputs) + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = self._pos_embeding(x, hw_shape, self.pos_embed) + + if not self.with_cls_token: + # Remove class token for transformer encoder input + x = x[:, 1:] + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i == len(self.layers) - 1: + if self.final_norm: + x = self.norm1(x) + if i in self.out_indices: + if self.with_cls_token: + # Remove class token and reshape token for decoder head + out = x[:, 1:] + else: + out = x + B, _, C = out.shape + out = out.reshape(B, hw_shape[0], hw_shape[1], + C).permute(0, 3, 1, 2).contiguous() + if self.output_cls_token: + out = [out, x[:, 0]] + outs.append(out) + + return tuple(outs) + + def train(self, mode=True): + super(VisionTransformer, self).train(mode) + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.LayerNorm): + m.eval() diff --git a/data_utils/easyportrait/mmseg/models/builder.py b/data_utils/easyportrait/mmseg/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..5e18e4e6430a80434d7cd2a3eed55ea343fccab6 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/builder.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +from mmcv.cnn import MODELS as MMCV_MODELS +from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION +from mmcv.utils import Registry + +MODELS = Registry('models', parent=MMCV_MODELS) +ATTENTION = Registry('attention', parent=MMCV_ATTENTION) + +BACKBONES = MODELS +NECKS = MODELS +HEADS = MODELS +LOSSES = MODELS +SEGMENTORS = MODELS + + +def build_backbone(cfg): + """Build backbone.""" + return BACKBONES.build(cfg) + + +def build_neck(cfg): + """Build neck.""" + return NECKS.build(cfg) + + +def build_head(cfg): + """Build head.""" + return HEADS.build(cfg) + + +def build_loss(cfg): + """Build loss.""" + return LOSSES.build(cfg) + + +def build_segmentor(cfg, train_cfg=None, test_cfg=None): + """Build segmentor.""" + if train_cfg is not None or test_cfg is not None: + warnings.warn( + 'train_cfg and test_cfg is deprecated, ' + 'please specify them in model', UserWarning) + assert cfg.get('train_cfg') is None or train_cfg is None, \ + 'train_cfg specified in both outer field and model field ' + assert cfg.get('test_cfg') is None or test_cfg is None, \ + 'test_cfg specified in both outer field and model field ' + return SEGMENTORS.build( + cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/__init__.py b/data_utils/easyportrait/mmseg/models/decode_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ce44bccfb05b72cbf0489aca517785f03022b0ef --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .ann_head import ANNHead +from .apc_head import APCHead +from .aspp_head import ASPPHead +from .cc_head import CCHead +from .da_head import DAHead +from .dm_head import DMHead +from .dnl_head import DNLHead +from .dpt_head import DPTHead +from .ema_head import EMAHead +from .enc_head import EncHead +from .fcn_head import FCNHead +from .fpn_head import FPNHead +from .gc_head import GCHead +from .ham_head import LightHamHead +from .isa_head import ISAHead +from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator +from .lraspp_head import LRASPPHead +from .nl_head import NLHead +from .ocr_head import OCRHead +from .point_head import PointHead +from .psa_head import PSAHead +from .psp_head import PSPHead +from .segformer_head import SegformerHead +from .segmenter_mask_head import SegmenterMaskTransformerHead +from .sep_aspp_head import DepthwiseSeparableASPPHead +from .sep_fcn_head import DepthwiseSeparableFCNHead +from .setr_mla_head import SETRMLAHead +from .setr_up_head import SETRUPHead +from .stdc_head import STDCHead +from .uper_head import UPerHead + +__all__ = [ + 'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead', + 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead', + 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead', + 'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead', + 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead', + 'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead', + 'KernelUpdateHead', 'KernelUpdator', 'LightHamHead' +] \ No newline at end of file diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/ann_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/ann_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c8d882e31963f09b9d250f0ce0bfde63f97d368d --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/ann_head.py @@ -0,0 +1,246 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from ..builder import HEADS +from ..utils import SelfAttentionBlock as _SelfAttentionBlock +from .decode_head import BaseDecodeHead + + +class PPMConcat(nn.ModuleList): + """Pyramid Pooling Module that only concat the features of each layer. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. + """ + + def __init__(self, pool_scales=(1, 3, 6, 8)): + super(PPMConcat, self).__init__( + [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales]) + + def forward(self, feats): + """Forward function.""" + ppm_outs = [] + for ppm in self: + ppm_out = ppm(feats) + ppm_outs.append(ppm_out.view(*feats.shape[:2], -1)) + concat_outs = torch.cat(ppm_outs, dim=2) + return concat_outs + + +class SelfAttentionBlock(_SelfAttentionBlock): + """Make a ANN used SelfAttentionBlock. + + Args: + low_in_channels (int): Input channels of lower level feature, + which is the key feature for self-attention. + high_in_channels (int): Input channels of higher level feature, + which is the query feature for self-attention. + channels (int): Output channels of key/query transform. + out_channels (int): Output channels. + share_key_query (bool): Whether share projection weight between key + and query projection. + query_scale (int): The scale of query feature map. + key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module of key feature. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict|None): Config of activation layers. + """ + + def __init__(self, low_in_channels, high_in_channels, channels, + out_channels, share_key_query, query_scale, key_pool_scales, + conv_cfg, norm_cfg, act_cfg): + key_psp = PPMConcat(key_pool_scales) + if query_scale > 1: + query_downsample = nn.MaxPool2d(kernel_size=query_scale) + else: + query_downsample = None + super(SelfAttentionBlock, self).__init__( + key_in_channels=low_in_channels, + query_in_channels=high_in_channels, + channels=channels, + out_channels=out_channels, + share_key_query=share_key_query, + query_downsample=query_downsample, + key_downsample=key_psp, + key_query_num_convs=1, + key_query_norm=True, + value_out_num_convs=1, + value_out_norm=False, + matmul_norm=True, + with_out=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + +class AFNB(nn.Module): + """Asymmetric Fusion Non-local Block(AFNB) + + Args: + low_in_channels (int): Input channels of lower level feature, + which is the key feature for self-attention. + high_in_channels (int): Input channels of higher level feature, + which is the query feature for self-attention. + channels (int): Output channels of key/query transform. + out_channels (int): Output channels. + and query projection. + query_scales (tuple[int]): The scales of query feature map. + Default: (1,) + key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module of key feature. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict|None): Config of activation layers. + """ + + def __init__(self, low_in_channels, high_in_channels, channels, + out_channels, query_scales, key_pool_scales, conv_cfg, + norm_cfg, act_cfg): + super(AFNB, self).__init__() + self.stages = nn.ModuleList() + for query_scale in query_scales: + self.stages.append( + SelfAttentionBlock( + low_in_channels=low_in_channels, + high_in_channels=high_in_channels, + channels=channels, + out_channels=out_channels, + share_key_query=False, + query_scale=query_scale, + key_pool_scales=key_pool_scales, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.bottleneck = ConvModule( + out_channels + high_in_channels, + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, low_feats, high_feats): + """Forward function.""" + priors = [stage(high_feats, low_feats) for stage in self.stages] + context = torch.stack(priors, dim=0).sum(dim=0) + output = self.bottleneck(torch.cat([context, high_feats], 1)) + return output + + +class APNB(nn.Module): + """Asymmetric Pyramid Non-local Block (APNB) + + Args: + in_channels (int): Input channels of key/query feature, + which is the key feature for self-attention. + channels (int): Output channels of key/query transform. + out_channels (int): Output channels. + query_scales (tuple[int]): The scales of query feature map. + Default: (1,) + key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module of key feature. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict|None): Config of activation layers. + """ + + def __init__(self, in_channels, channels, out_channels, query_scales, + key_pool_scales, conv_cfg, norm_cfg, act_cfg): + super(APNB, self).__init__() + self.stages = nn.ModuleList() + for query_scale in query_scales: + self.stages.append( + SelfAttentionBlock( + low_in_channels=in_channels, + high_in_channels=in_channels, + channels=channels, + out_channels=out_channels, + share_key_query=True, + query_scale=query_scale, + key_pool_scales=key_pool_scales, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.bottleneck = ConvModule( + 2 * in_channels, + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, feats): + """Forward function.""" + priors = [stage(feats, feats) for stage in self.stages] + context = torch.stack(priors, dim=0).sum(dim=0) + output = self.bottleneck(torch.cat([context, feats], 1)) + return output + + +@HEADS.register_module() +class ANNHead(BaseDecodeHead): + """Asymmetric Non-local Neural Networks for Semantic Segmentation. + + This head is the implementation of `ANNNet + `_. + + Args: + project_channels (int): Projection channels for Nonlocal. + query_scales (tuple[int]): The scales of query feature map. + Default: (1,) + key_pool_scales (tuple[int]): The pooling scales of key feature map. + Default: (1, 3, 6, 8). + """ + + def __init__(self, + project_channels, + query_scales=(1, ), + key_pool_scales=(1, 3, 6, 8), + **kwargs): + super(ANNHead, self).__init__( + input_transform='multiple_select', **kwargs) + assert len(self.in_channels) == 2 + low_in_channels, high_in_channels = self.in_channels + self.project_channels = project_channels + self.fusion = AFNB( + low_in_channels=low_in_channels, + high_in_channels=high_in_channels, + out_channels=high_in_channels, + channels=project_channels, + query_scales=query_scales, + key_pool_scales=key_pool_scales, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.bottleneck = ConvModule( + high_in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.context = APNB( + in_channels=self.channels, + out_channels=self.channels, + channels=project_channels, + query_scales=query_scales, + key_pool_scales=key_pool_scales, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + low_feats, high_feats = self._transform_inputs(inputs) + output = self.fusion(low_feats, high_feats) + output = self.dropout(output) + output = self.bottleneck(output) + output = self.context(output) + output = self.cls_seg(output) + + return output diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/apc_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/apc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3198fd18817fe0a3698de42352073ba2d8ca4b5e --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/apc_head.py @@ -0,0 +1,159 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.ops import resize +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +class ACM(nn.Module): + """Adaptive Context Module used in APCNet. + + Args: + pool_scale (int): Pooling scale used in Adaptive Context + Module to extract region features. + fusion (bool): Add one conv to fuse residual feature. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict | None): Config of conv layers. + norm_cfg (dict | None): Config of norm layers. + act_cfg (dict): Config of activation layers. + """ + + def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg, + norm_cfg, act_cfg): + super(ACM, self).__init__() + self.pool_scale = pool_scale + self.fusion = fusion + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.pooled_redu_conv = ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.input_redu_conv = ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.global_info = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0) + + self.residual_conv = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + if self.fusion: + self.fusion_conv = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, x): + """Forward function.""" + pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale) + # [batch_size, channels, h, w] + x = self.input_redu_conv(x) + # [batch_size, channels, pool_scale, pool_scale] + pooled_x = self.pooled_redu_conv(pooled_x) + batch_size = x.size(0) + # [batch_size, pool_scale * pool_scale, channels] + pooled_x = pooled_x.view(batch_size, self.channels, + -1).permute(0, 2, 1).contiguous() + # [batch_size, h * w, pool_scale * pool_scale] + affinity_matrix = self.gla(x + resize( + self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:]) + ).permute(0, 2, 3, 1).reshape( + batch_size, -1, self.pool_scale**2) + affinity_matrix = F.sigmoid(affinity_matrix) + # [batch_size, h * w, channels] + z_out = torch.matmul(affinity_matrix, pooled_x) + # [batch_size, channels, h * w] + z_out = z_out.permute(0, 2, 1).contiguous() + # [batch_size, channels, h, w] + z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3)) + z_out = self.residual_conv(z_out) + z_out = F.relu(z_out + x) + if self.fusion: + z_out = self.fusion_conv(z_out) + + return z_out + + +@HEADS.register_module() +class APCHead(BaseDecodeHead): + """Adaptive Pyramid Context Network for Semantic Segmentation. + + This head is the implementation of + `APCNet `_. + + Args: + pool_scales (tuple[int]): Pooling scales used in Adaptive Context + Module. Default: (1, 2, 3, 6). + fusion (bool): Add one conv to fuse residual feature. + """ + + def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs): + super(APCHead, self).__init__(**kwargs) + assert isinstance(pool_scales, (list, tuple)) + self.pool_scales = pool_scales + self.fusion = fusion + acm_modules = [] + for pool_scale in self.pool_scales: + acm_modules.append( + ACM(pool_scale, + self.fusion, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.acm_modules = nn.ModuleList(acm_modules) + self.bottleneck = ConvModule( + self.in_channels + len(pool_scales) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + acm_outs = [x] + for acm_module in self.acm_modules: + acm_outs.append(acm_module(x)) + acm_outs = torch.cat(acm_outs, dim=1) + output = self.bottleneck(acm_outs) + output = self.cls_seg(output) + return output diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/aspp_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/aspp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7059aee961183952ec60ca4d6ccea5d3f9ae3f79 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/aspp_head.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.ops import resize +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +class ASPPModule(nn.ModuleList): + """Atrous Spatial Pyramid Pooling (ASPP) Module. + + Args: + dilations (tuple[int]): Dilation rate of each layer. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict): Config of activation layers. + """ + + def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg, + act_cfg): + super(ASPPModule, self).__init__() + self.dilations = dilations + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + for dilation in dilations: + self.append( + ConvModule( + self.in_channels, + self.channels, + 1 if dilation == 1 else 3, + dilation=dilation, + padding=0 if dilation == 1 else dilation, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + def forward(self, x): + """Forward function.""" + aspp_outs = [] + for aspp_module in self: + aspp_outs.append(aspp_module(x)) + + return aspp_outs + + +@HEADS.register_module() +class ASPPHead(BaseDecodeHead): + """Rethinking Atrous Convolution for Semantic Image Segmentation. + + This head is the implementation of `DeepLabV3 + `_. + + Args: + dilations (tuple[int]): Dilation rates for ASPP module. + Default: (1, 6, 12, 18). + """ + + def __init__(self, dilations=(1, 6, 12, 18), **kwargs): + super(ASPPHead, self).__init__(**kwargs) + assert isinstance(dilations, (list, tuple)) + self.dilations = dilations + self.image_pool = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.aspp_modules = ASPPModule( + dilations, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.bottleneck = ConvModule( + (len(dilations) + 1) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + x = self._transform_inputs(inputs) + aspp_outs = [ + resize( + self.image_pool(x), + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + ] + aspp_outs.extend(self.aspp_modules(x)) + aspp_outs = torch.cat(aspp_outs, dim=1) + feats = self.bottleneck(aspp_outs) + return feats + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.cls_seg(output) + return output diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/cascade_decode_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/cascade_decode_head.py new file mode 100644 index 0000000000000000000000000000000000000000..f7c3da0d67f920af01cad802461ce74d44404d7d --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/cascade_decode_head.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + +from .decode_head import BaseDecodeHead + + +class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta): + """Base class for cascade decode head used in + :class:`CascadeEncoderDecoder.""" + + def __init__(self, *args, **kwargs): + super(BaseCascadeDecodeHead, self).__init__(*args, **kwargs) + + @abstractmethod + def forward(self, inputs, prev_output): + """Placeholder of forward function.""" + pass + + def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg, + train_cfg): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + prev_output (Tensor): The output of previous decode head. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + seg_logits = self.forward(inputs, prev_output) + losses = self.losses(seg_logits, gt_semantic_seg) + + return losses + + def forward_test(self, inputs, prev_output, img_metas, test_cfg): + """Forward function for testing. + + Args: + inputs (list[Tensor]): List of multi-level img features. + prev_output (Tensor): The output of previous decode head. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Output segmentation map. + """ + return self.forward(inputs, prev_output) diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/cc_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/cc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ed19eb46d123fa8da5b61cf98a6a3f78abf8a494 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/cc_head.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from ..builder import HEADS +from .fcn_head import FCNHead + +try: + from mmcv.ops import CrissCrossAttention +except ModuleNotFoundError: + CrissCrossAttention = None + + +@HEADS.register_module() +class CCHead(FCNHead): + """CCNet: Criss-Cross Attention for Semantic Segmentation. + + This head is the implementation of `CCNet + `_. + + Args: + recurrence (int): Number of recurrence of Criss Cross Attention + module. Default: 2. + """ + + def __init__(self, recurrence=2, **kwargs): + if CrissCrossAttention is None: + raise RuntimeError('Please install mmcv-full for ' + 'CrissCrossAttention ops') + super(CCHead, self).__init__(num_convs=2, **kwargs) + self.recurrence = recurrence + self.cca = CrissCrossAttention(self.channels) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs[0](x) + for _ in range(self.recurrence): + output = self.cca(output) + output = self.convs[1](output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/da_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/da_head.py new file mode 100644 index 0000000000000000000000000000000000000000..77fd6639c00bd4cad4cfa8f0da13a3f38dccf840 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/da_head.py @@ -0,0 +1,179 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F +from mmcv.cnn import ConvModule, Scale +from torch import nn + +from mmseg.core import add_prefix +from ..builder import HEADS +from ..utils import SelfAttentionBlock as _SelfAttentionBlock +from .decode_head import BaseDecodeHead + + +class PAM(_SelfAttentionBlock): + """Position Attention Module (PAM) + + Args: + in_channels (int): Input channels of key/query feature. + channels (int): Output channels of key/query transform. + """ + + def __init__(self, in_channels, channels): + super(PAM, self).__init__( + key_in_channels=in_channels, + query_in_channels=in_channels, + channels=channels, + out_channels=in_channels, + share_key_query=False, + query_downsample=None, + key_downsample=None, + key_query_num_convs=1, + key_query_norm=False, + value_out_num_convs=1, + value_out_norm=False, + matmul_norm=False, + with_out=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None) + + self.gamma = Scale(0) + + def forward(self, x): + """Forward function.""" + out = super(PAM, self).forward(x, x) + + out = self.gamma(out) + x + return out + + +class CAM(nn.Module): + """Channel Attention Module (CAM)""" + + def __init__(self): + super(CAM, self).__init__() + self.gamma = Scale(0) + + def forward(self, x): + """Forward function.""" + batch_size, channels, height, width = x.size() + proj_query = x.view(batch_size, channels, -1) + proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max( + energy, -1, keepdim=True)[0].expand_as(energy) - energy + attention = F.softmax(energy_new, dim=-1) + proj_value = x.view(batch_size, channels, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(batch_size, channels, height, width) + + out = self.gamma(out) + x + return out + + +@HEADS.register_module() +class DAHead(BaseDecodeHead): + """Dual Attention Network for Scene Segmentation. + + This head is the implementation of `DANet + `_. + + Args: + pam_channels (int): The channels of Position Attention Module(PAM). + """ + + def __init__(self, pam_channels, **kwargs): + super(DAHead, self).__init__(**kwargs) + self.pam_channels = pam_channels + self.pam_in_conv = ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.pam = PAM(self.channels, pam_channels) + self.pam_out_conv = ConvModule( + self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.pam_conv_seg = nn.Conv2d( + self.channels, self.num_classes, kernel_size=1) + + self.cam_in_conv = ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.cam = CAM() + self.cam_out_conv = ConvModule( + self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.cam_conv_seg = nn.Conv2d( + self.channels, self.num_classes, kernel_size=1) + + def pam_cls_seg(self, feat): + """PAM feature classification.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.pam_conv_seg(feat) + return output + + def cam_cls_seg(self, feat): + """CAM feature classification.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.cam_conv_seg(feat) + return output + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + pam_feat = self.pam_in_conv(x) + pam_feat = self.pam(pam_feat) + pam_feat = self.pam_out_conv(pam_feat) + pam_out = self.pam_cls_seg(pam_feat) + + cam_feat = self.cam_in_conv(x) + cam_feat = self.cam(cam_feat) + cam_feat = self.cam_out_conv(cam_feat) + cam_out = self.cam_cls_seg(cam_feat) + + feat_sum = pam_feat + cam_feat + pam_cam_out = self.cls_seg(feat_sum) + + return pam_cam_out, pam_out, cam_out + + def forward_test(self, inputs, img_metas, test_cfg): + """Forward function for testing, only ``pam_cam`` is used.""" + return self.forward(inputs)[0] + + def losses(self, seg_logit, seg_label): + """Compute ``pam_cam``, ``pam``, ``cam`` loss.""" + pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit + loss = dict() + loss.update( + add_prefix( + super(DAHead, self).losses(pam_cam_seg_logit, seg_label), + 'pam_cam')) + loss.update( + add_prefix( + super(DAHead, self).losses(pam_seg_logit, seg_label), 'pam')) + loss.update( + add_prefix( + super(DAHead, self).losses(cam_seg_logit, seg_label), 'cam')) + return loss diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/decode_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/decode_head.py new file mode 100644 index 0000000000000000000000000000000000000000..9a04a1bd0b526a688a712c5fe6edda7d088747a5 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/decode_head.py @@ -0,0 +1,312 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from abc import ABCMeta, abstractmethod + +import torch +import torch.nn as nn +from mmcv.runner import BaseModule, auto_fp16, force_fp32 + +from mmseg.core import build_pixel_sampler +from mmseg.ops import resize +from ..builder import build_loss +from ..losses import accuracy + + +class BaseDecodeHead(BaseModule, metaclass=ABCMeta): + """Base class for BaseDecodeHead. + + Args: + in_channels (int|Sequence[int]): Input channels. + channels (int): Channels after modules, before conv_seg. + num_classes (int): Number of classes. + out_channels (int): Output channels of conv_seg. + threshold (float): Threshold for binary segmentation in the case of + `out_channels==1`. Default: None. + dropout_ratio (float): Ratio of dropout layer. Default: 0.1. + conv_cfg (dict|None): Config of conv layers. Default: None. + norm_cfg (dict|None): Config of norm layers. Default: None. + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU') + in_index (int|Sequence[int]): Input feature index. Default: -1 + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + None: Only one select feature map is allowed. + Default: None. + loss_decode (dict | Sequence[dict]): Config of decode loss. + The `loss_name` is property of corresponding loss function which + could be shown in training log. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_ce'. + e.g. dict(type='CrossEntropyLoss'), + [dict(type='CrossEntropyLoss', loss_name='loss_ce'), + dict(type='DiceLoss', loss_name='loss_dice')] + Default: dict(type='CrossEntropyLoss'). + ignore_index (int | None): The label index to be ignored. When using + masked BCE loss, ignore_index should be set to None. Default: 255. + sampler (dict|None): The config of segmentation map sampler. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + downsample_label_ratio (int): The ratio to downsample seg_label + in losses. downsample_label_ratio > 1 will reduce memory usage. + Disabled if downsample_label_ratio = 0. + Default: 0. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__(self, + in_channels, + channels, + *, + num_classes, + out_channels=None, + threshold=None, + dropout_ratio=0.1, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + in_index=-1, + input_transform=None, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + ignore_index=255, + sampler=None, + align_corners=False, + downsample_label_ratio=0, + init_cfg=dict( + type='Normal', std=0.01, override=dict(name='conv_seg'))): + super(BaseDecodeHead, self).__init__(init_cfg) + self._init_inputs(in_channels, in_index, input_transform) + self.channels = channels + self.dropout_ratio = dropout_ratio + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.in_index = in_index + + self.ignore_index = ignore_index + self.align_corners = align_corners + self.downsample_label_ratio = downsample_label_ratio + if not isinstance(self.downsample_label_ratio, int) or \ + self.downsample_label_ratio < 0: + warnings.warn('downsample_label_ratio should ' + 'be set as an integer equal or larger than 0.') + + if out_channels is None: + if num_classes == 2: + warnings.warn('For binary segmentation, we suggest using' + '`out_channels = 1` to define the output' + 'channels of segmentor, and use `threshold`' + 'to convert seg_logist into a prediction' + 'applying a threshold') + out_channels = num_classes + + if out_channels != num_classes and out_channels != 1: + raise ValueError( + 'out_channels should be equal to num_classes,' + 'except binary segmentation set out_channels == 1 and' + f'num_classes == 2, but got out_channels={out_channels}' + f'and num_classes={num_classes}') + + if out_channels == 1 and threshold is None: + threshold = 0.3 + warnings.warn('threshold is not defined for binary, and defaults' + 'to 0.3') + self.num_classes = num_classes + self.out_channels = out_channels + self.threshold = threshold + + if isinstance(loss_decode, dict): + self.loss_decode = build_loss(loss_decode) + elif isinstance(loss_decode, (list, tuple)): + self.loss_decode = nn.ModuleList() + for loss in loss_decode: + self.loss_decode.append(build_loss(loss)) + else: + raise TypeError(f'loss_decode must be a dict or sequence of dict,\ + but got {type(loss_decode)}') + + if sampler is not None: + self.sampler = build_pixel_sampler(sampler, context=self) + else: + self.sampler = None + + self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1) + if dropout_ratio > 0: + self.dropout = nn.Dropout2d(dropout_ratio) + else: + self.dropout = None + self.fp16_enabled = False + + def extra_repr(self): + """Extra repr.""" + s = f'input_transform={self.input_transform}, ' \ + f'ignore_index={self.ignore_index}, ' \ + f'align_corners={self.align_corners}' + return s + + def _init_inputs(self, in_channels, in_index, input_transform): + """Check and initialize input transforms. + + The in_channels, in_index and input_transform must match. + Specifically, when input_transform is None, only single feature map + will be selected. So in_channels and in_index must be of type int. + When input_transform + + Args: + in_channels (int|Sequence[int]): Input channels. + in_index (int|Sequence[int]): Input feature index. + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + None: Only one select feature map is allowed. + """ + + if input_transform is not None: + assert input_transform in ['resize_concat', 'multiple_select'] + self.input_transform = input_transform + self.in_index = in_index + if input_transform is not None: + assert isinstance(in_channels, (list, tuple)) + assert isinstance(in_index, (list, tuple)) + assert len(in_channels) == len(in_index) + if input_transform == 'resize_concat': + self.in_channels = sum(in_channels) + else: + self.in_channels = in_channels + else: + assert isinstance(in_channels, int) + assert isinstance(in_index, int) + self.in_channels = in_channels + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + Tensor: The transformed inputs + """ + + if self.input_transform == 'resize_concat': + inputs = [inputs[i] for i in self.in_index] + upsampled_inputs = [ + resize( + input=x, + size=inputs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) for x in inputs + ] + inputs = torch.cat(upsampled_inputs, dim=1) + elif self.input_transform == 'multiple_select': + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + @auto_fp16() + @abstractmethod + def forward(self, inputs): + """Placeholder of forward function.""" + pass + + def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + seg_logits = self(inputs) + losses = self.losses(seg_logits, gt_semantic_seg) + return losses + + def forward_test(self, inputs, img_metas, test_cfg): + """Forward function for testing. + + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Output segmentation map. + """ + return self.forward(inputs) + + def cls_seg(self, feat): + """Classify each pixel.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.conv_seg(feat) + return output + + @force_fp32(apply_to=('seg_logit', )) + def losses(self, seg_logit, seg_label): + """Compute segmentation loss.""" + loss = dict() + if self.downsample_label_ratio > 0: + seg_label = seg_label.float() + target_size = (seg_label.shape[2] // self.downsample_label_ratio, + seg_label.shape[3] // self.downsample_label_ratio) + seg_label = resize( + input=seg_label, size=target_size, mode='nearest') + seg_label = seg_label.long() + seg_logit = resize( + input=seg_logit, + size=seg_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + if self.sampler is not None: + seg_weight = self.sampler.sample(seg_logit, seg_label) + else: + seg_weight = None + seg_label = seg_label.squeeze(1) + + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode( + seg_logit, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + else: + loss[loss_decode.loss_name] += loss_decode( + seg_logit, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + + loss['acc_seg'] = accuracy( + seg_logit, seg_label, ignore_index=self.ignore_index) + return loss diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/dm_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/dm_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ffaa870ab3a6f10afec7d07d395e8f405eda936f --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/dm_head.py @@ -0,0 +1,141 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer + +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +class DCM(nn.Module): + """Dynamic Convolutional Module used in DMNet. + + Args: + filter_size (int): The filter size of generated convolution kernel + used in Dynamic Convolutional Module. + fusion (bool): Add one conv to fuse DCM output feature. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict | None): Config of conv layers. + norm_cfg (dict | None): Config of norm layers. + act_cfg (dict): Config of activation layers. + """ + + def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg, + norm_cfg, act_cfg): + super(DCM, self).__init__() + self.filter_size = filter_size + self.fusion = fusion + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1, + 0) + + self.input_redu_conv = ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + if self.norm_cfg is not None: + self.norm = build_norm_layer(self.norm_cfg, self.channels)[1] + else: + self.norm = None + self.activate = build_activation_layer(self.act_cfg) + + if self.fusion: + self.fusion_conv = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, x): + """Forward function.""" + generated_filter = self.filter_gen_conv( + F.adaptive_avg_pool2d(x, self.filter_size)) + x = self.input_redu_conv(x) + b, c, h, w = x.shape + # [1, b * c, h, w], c = self.channels + x = x.view(1, b * c, h, w) + # [b * c, 1, filter_size, filter_size] + generated_filter = generated_filter.view(b * c, 1, self.filter_size, + self.filter_size) + pad = (self.filter_size - 1) // 2 + if (self.filter_size - 1) % 2 == 0: + p2d = (pad, pad, pad, pad) + else: + p2d = (pad + 1, pad, pad + 1, pad) + x = F.pad(input=x, pad=p2d, mode='constant', value=0) + # [1, b * c, h, w] + output = F.conv2d(input=x, weight=generated_filter, groups=b * c) + # [b, c, h, w] + output = output.view(b, c, h, w) + if self.norm is not None: + output = self.norm(output) + output = self.activate(output) + + if self.fusion: + output = self.fusion_conv(output) + + return output + + +@HEADS.register_module() +class DMHead(BaseDecodeHead): + """Dynamic Multi-scale Filters for Semantic Segmentation. + + This head is the implementation of + `DMNet `_. + + Args: + filter_sizes (tuple[int]): The size of generated convolutional filters + used in Dynamic Convolutional Module. Default: (1, 3, 5, 7). + fusion (bool): Add one conv to fuse DCM output feature. + """ + + def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs): + super(DMHead, self).__init__(**kwargs) + assert isinstance(filter_sizes, (list, tuple)) + self.filter_sizes = filter_sizes + self.fusion = fusion + dcm_modules = [] + for filter_size in self.filter_sizes: + dcm_modules.append( + DCM(filter_size, + self.fusion, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.dcm_modules = nn.ModuleList(dcm_modules) + self.bottleneck = ConvModule( + self.in_channels + len(filter_sizes) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + dcm_outs = [x] + for dcm_module in self.dcm_modules: + dcm_outs.append(dcm_module(x)) + dcm_outs = torch.cat(dcm_outs, dim=1) + output = self.bottleneck(dcm_outs) + output = self.cls_seg(output) + return output diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/dnl_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/dnl_head.py new file mode 100644 index 0000000000000000000000000000000000000000..dabf1542148d3dafa419fdb2884b006144742e5a --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/dnl_head.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcv.cnn import NonLocal2d +from torch import nn + +from ..builder import HEADS +from .fcn_head import FCNHead + + +class DisentangledNonLocal2d(NonLocal2d): + """Disentangled Non-Local Blocks. + + Args: + temperature (float): Temperature to adjust attention. Default: 0.05 + """ + + def __init__(self, *arg, temperature, **kwargs): + super().__init__(*arg, **kwargs) + self.temperature = temperature + self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1) + + def embedded_gaussian(self, theta_x, phi_x): + """Embedded gaussian with temperature.""" + + # NonLocal2d pairwise_weight: [N, HxW, HxW] + pairwise_weight = torch.matmul(theta_x, phi_x) + if self.use_scale: + # theta_x.shape[-1] is `self.inter_channels` + pairwise_weight /= torch.tensor( + theta_x.shape[-1], + dtype=torch.float, + device=pairwise_weight.device)**torch.tensor( + 0.5, device=pairwise_weight.device) + pairwise_weight /= torch.tensor( + self.temperature, device=pairwise_weight.device) + pairwise_weight = pairwise_weight.softmax(dim=-1) + return pairwise_weight + + def forward(self, x): + # x: [N, C, H, W] + n = x.size(0) + + # g_x: [N, HxW, C] + g_x = self.g(x).view(n, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + # theta_x: [N, HxW, C], phi_x: [N, C, HxW] + if self.mode == 'gaussian': + theta_x = x.view(n, self.in_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + if self.sub_sample: + phi_x = self.phi(x).view(n, self.in_channels, -1) + else: + phi_x = x.view(n, self.in_channels, -1) + elif self.mode == 'concatenation': + theta_x = self.theta(x).view(n, self.inter_channels, -1, 1) + phi_x = self.phi(x).view(n, self.inter_channels, 1, -1) + else: + theta_x = self.theta(x).view(n, self.inter_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + phi_x = self.phi(x).view(n, self.inter_channels, -1) + + # subtract mean + theta_x -= theta_x.mean(dim=-2, keepdim=True) + phi_x -= phi_x.mean(dim=-1, keepdim=True) + + pairwise_func = getattr(self, self.mode) + # pairwise_weight: [N, HxW, HxW] + pairwise_weight = pairwise_func(theta_x, phi_x) + + # y: [N, HxW, C] + y = torch.matmul(pairwise_weight, g_x) + # y: [N, C, H, W] + y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels, + *x.size()[2:]) + + # unary_mask: [N, 1, HxW] + unary_mask = self.conv_mask(x) + unary_mask = unary_mask.view(n, 1, -1) + unary_mask = unary_mask.softmax(dim=-1) + # unary_x: [N, 1, C] + unary_x = torch.matmul(unary_mask, g_x) + # unary_x: [N, C, 1, 1] + unary_x = unary_x.permute(0, 2, 1).contiguous().reshape( + n, self.inter_channels, 1, 1) + + output = x + self.conv_out(y + unary_x) + + return output + + +@HEADS.register_module() +class DNLHead(FCNHead): + """Disentangled Non-Local Neural Networks. + + This head is the implementation of `DNLNet + `_. + + Args: + reduction (int): Reduction factor of projection transform. Default: 2. + use_scale (bool): Whether to scale pairwise_weight by + sqrt(1/inter_channels). Default: False. + mode (str): The nonlocal mode. Options are 'embedded_gaussian', + 'dot_product'. Default: 'embedded_gaussian.'. + temperature (float): Temperature to adjust attention. Default: 0.05 + """ + + def __init__(self, + reduction=2, + use_scale=True, + mode='embedded_gaussian', + temperature=0.05, + **kwargs): + super(DNLHead, self).__init__(num_convs=2, **kwargs) + self.reduction = reduction + self.use_scale = use_scale + self.mode = mode + self.temperature = temperature + self.dnl_block = DisentangledNonLocal2d( + in_channels=self.channels, + reduction=self.reduction, + use_scale=self.use_scale, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + mode=self.mode, + temperature=self.temperature) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs[0](x) + output = self.dnl_block(output) + output = self.convs[1](output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/dpt_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..6c895d02dfc06c369de878e919ae5ccd5073d819 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/dpt_head.py @@ -0,0 +1,294 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, Linear, build_activation_layer +from mmcv.runner import BaseModule + +from mmseg.ops import resize +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +class ReassembleBlocks(BaseModule): + """ViTPostProcessBlock, process cls_token in ViT backbone output and + rearrange the feature vector to feature map. + + Args: + in_channels (int): ViT feature channels. Default: 768. + out_channels (List): output channels of each stage. + Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + def __init__(self, + in_channels=768, + out_channels=[96, 192, 384, 768], + readout_type='ignore', + patch_size=16, + init_cfg=None): + super(ReassembleBlocks, self).__init__(init_cfg) + + assert readout_type in ['ignore', 'add', 'project'] + self.readout_type = readout_type + self.patch_size = patch_size + + self.projects = nn.ModuleList([ + ConvModule( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + act_cfg=None, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + if self.readout_type == 'project': + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential( + Linear(2 * in_channels, in_channels), + build_activation_layer(dict(type='GELU')))) + + def forward(self, inputs): + assert isinstance(inputs, list) + out = [] + for i, x in enumerate(inputs): + assert len(x) == 2 + x, cls_token = x[0], x[1] + feature_shape = x.shape + if self.readout_type == 'project': + x = x.flatten(2).permute((0, 2, 1)) + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + x = x.permute(0, 2, 1).reshape(feature_shape) + elif self.readout_type == 'add': + x = x.flatten(2) + cls_token.unsqueeze(-1) + x = x.reshape(feature_shape) + else: + pass + x = self.projects[i](x) + x = self.resize_layers[i](x) + out.append(x) + return out + + +class PreActResidualConvUnit(BaseModule): + """ResidualConvUnit, pre-activate residual unit. + + Args: + in_channels (int): number of channels in the input feature map. + act_cfg (dict): dictionary to construct and config activation layer. + norm_cfg (dict): dictionary to construct and config norm layer. + stride (int): stride of the first block. Default: 1 + dilation (int): dilation rate for convs layers. Default: 1. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + def __init__(self, + in_channels, + act_cfg, + norm_cfg, + stride=1, + dilation=1, + init_cfg=None): + super(PreActResidualConvUnit, self).__init__(init_cfg) + + self.conv1 = ConvModule( + in_channels, + in_channels, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=False, + order=('act', 'conv', 'norm')) + + self.conv2 = ConvModule( + in_channels, + in_channels, + 3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=False, + order=('act', 'conv', 'norm')) + + def forward(self, inputs): + inputs_ = inputs.clone() + x = self.conv1(inputs) + x = self.conv2(x) + return x + inputs_ + + +class FeatureFusionBlock(BaseModule): + """FeatureFusionBlock, merge feature map from different stages. + + Args: + in_channels (int): Input channels. + act_cfg (dict): The activation config for ResidualConvUnit. + norm_cfg (dict): Config dict for normalization layer. + expand (bool): Whether expand the channels in post process block. + Default: False. + align_corners (bool): align_corner setting for bilinear upsample. + Default: True. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + def __init__(self, + in_channels, + act_cfg, + norm_cfg, + expand=False, + align_corners=True, + init_cfg=None): + super(FeatureFusionBlock, self).__init__(init_cfg) + + self.in_channels = in_channels + self.expand = expand + self.align_corners = align_corners + + self.out_channels = in_channels + if self.expand: + self.out_channels = in_channels // 2 + + self.project = ConvModule( + self.in_channels, + self.out_channels, + kernel_size=1, + act_cfg=None, + bias=True) + + self.res_conv_unit1 = PreActResidualConvUnit( + in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) + self.res_conv_unit2 = PreActResidualConvUnit( + in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) + + def forward(self, *inputs): + x = inputs[0] + if len(inputs) == 2: + if x.shape != inputs[1].shape: + res = resize( + inputs[1], + size=(x.shape[2], x.shape[3]), + mode='bilinear', + align_corners=False) + else: + res = inputs[1] + x = x + self.res_conv_unit1(res) + x = self.res_conv_unit2(x) + x = resize( + x, + scale_factor=2, + mode='bilinear', + align_corners=self.align_corners) + x = self.project(x) + return x + + +@HEADS.register_module() +class DPTHead(BaseDecodeHead): + """Vision Transformers for Dense Prediction. + + This head is implemented of `DPT `_. + + Args: + embed_dims (int): The embed dimension of the ViT backbone. + Default: 768. + post_process_channels (List): Out channels of post process conv + layers. Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + expand_channels (bool): Whether expand the channels in post process + block. Default: False. + act_cfg (dict): The activation config for residual conv unit. + Default dict(type='ReLU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + """ + + def __init__(self, + embed_dims=768, + post_process_channels=[96, 192, 384, 768], + readout_type='ignore', + patch_size=16, + expand_channels=False, + act_cfg=dict(type='ReLU'), + norm_cfg=dict(type='BN'), + **kwargs): + super(DPTHead, self).__init__(**kwargs) + + self.in_channels = self.in_channels + self.expand_channels = expand_channels + self.reassemble_blocks = ReassembleBlocks(embed_dims, + post_process_channels, + readout_type, patch_size) + + self.post_process_channels = [ + channel * math.pow(2, i) if expand_channels else channel + for i, channel in enumerate(post_process_channels) + ] + self.convs = nn.ModuleList() + for channel in self.post_process_channels: + self.convs.append( + ConvModule( + channel, + self.channels, + kernel_size=3, + padding=1, + act_cfg=None, + bias=False)) + self.fusion_blocks = nn.ModuleList() + for _ in range(len(self.convs)): + self.fusion_blocks.append( + FeatureFusionBlock(self.channels, act_cfg, norm_cfg)) + self.fusion_blocks[0].res_conv_unit1 = None + self.project = ConvModule( + self.channels, + self.channels, + kernel_size=3, + padding=1, + norm_cfg=norm_cfg) + self.num_fusion_blocks = len(self.fusion_blocks) + self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers) + self.num_post_process_channels = len(self.post_process_channels) + assert self.num_fusion_blocks == self.num_reassemble_blocks + assert self.num_reassemble_blocks == self.num_post_process_channels + + def forward(self, inputs): + assert len(inputs) == self.num_reassemble_blocks + x = self._transform_inputs(inputs) + x = self.reassemble_blocks(x) + x = [self.convs[i](feature) for i, feature in enumerate(x)] + out = self.fusion_blocks[0](x[-1]) + for i in range(1, len(self.fusion_blocks)): + out = self.fusion_blocks[i](out, x[-(i + 1)]) + out = self.project(out) + out = self.cls_seg(out) + return out diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/ema_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/ema_head.py new file mode 100644 index 0000000000000000000000000000000000000000..f6de16711187c6235d7b2e594e5d71ac0ab66b99 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/ema_head.py @@ -0,0 +1,169 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +def reduce_mean(tensor): + """Reduce mean when distributed training.""" + if not (dist.is_available() and dist.is_initialized()): + return tensor + tensor = tensor.clone() + dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM) + return tensor + + +class EMAModule(nn.Module): + """Expectation Maximization Attention Module used in EMANet. + + Args: + channels (int): Channels of the whole module. + num_bases (int): Number of bases. + num_stages (int): Number of the EM iterations. + """ + + def __init__(self, channels, num_bases, num_stages, momentum): + super(EMAModule, self).__init__() + assert num_stages >= 1, 'num_stages must be at least 1!' + self.num_bases = num_bases + self.num_stages = num_stages + self.momentum = momentum + + bases = torch.zeros(1, channels, self.num_bases) + bases.normal_(0, math.sqrt(2. / self.num_bases)) + # [1, channels, num_bases] + bases = F.normalize(bases, dim=1, p=2) + self.register_buffer('bases', bases) + + def forward(self, feats): + """Forward function.""" + batch_size, channels, height, width = feats.size() + # [batch_size, channels, height*width] + feats = feats.view(batch_size, channels, height * width) + # [batch_size, channels, num_bases] + bases = self.bases.repeat(batch_size, 1, 1) + + with torch.no_grad(): + for i in range(self.num_stages): + # [batch_size, height*width, num_bases] + attention = torch.einsum('bcn,bck->bnk', feats, bases) + attention = F.softmax(attention, dim=2) + # l1 norm + attention_normed = F.normalize(attention, dim=1, p=1) + # [batch_size, channels, num_bases] + bases = torch.einsum('bcn,bnk->bck', feats, attention_normed) + # l2 norm + bases = F.normalize(bases, dim=1, p=2) + + feats_recon = torch.einsum('bck,bnk->bcn', bases, attention) + feats_recon = feats_recon.view(batch_size, channels, height, width) + + if self.training: + bases = bases.mean(dim=0, keepdim=True) + bases = reduce_mean(bases) + # l2 norm + bases = F.normalize(bases, dim=1, p=2) + self.bases = (1 - + self.momentum) * self.bases + self.momentum * bases + + return feats_recon + + +@HEADS.register_module() +class EMAHead(BaseDecodeHead): + """Expectation Maximization Attention Networks for Semantic Segmentation. + + This head is the implementation of `EMANet + `_. + + Args: + ema_channels (int): EMA module channels + num_bases (int): Number of bases. + num_stages (int): Number of the EM iterations. + concat_input (bool): Whether concat the input and output of convs + before classification layer. Default: True + momentum (float): Momentum to update the base. Default: 0.1. + """ + + def __init__(self, + ema_channels, + num_bases, + num_stages, + concat_input=True, + momentum=0.1, + **kwargs): + super(EMAHead, self).__init__(**kwargs) + self.ema_channels = ema_channels + self.num_bases = num_bases + self.num_stages = num_stages + self.concat_input = concat_input + self.momentum = momentum + self.ema_module = EMAModule(self.ema_channels, self.num_bases, + self.num_stages, self.momentum) + + self.ema_in_conv = ConvModule( + self.in_channels, + self.ema_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + # project (0, inf) -> (-inf, inf) + self.ema_mid_conv = ConvModule( + self.ema_channels, + self.ema_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=None, + act_cfg=None) + for param in self.ema_mid_conv.parameters(): + param.requires_grad = False + + self.ema_out_conv = ConvModule( + self.ema_channels, + self.ema_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=None) + self.bottleneck = ConvModule( + self.ema_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if self.concat_input: + self.conv_cat = ConvModule( + self.in_channels + self.channels, + self.channels, + kernel_size=3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + feats = self.ema_in_conv(x) + identity = feats + feats = self.ema_mid_conv(feats) + recon = self.ema_module(feats) + recon = F.relu(recon, inplace=True) + recon = self.ema_out_conv(recon) + output = F.relu(identity + recon, inplace=True) + output = self.bottleneck(output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/enc_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/enc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..648c8906b91254658dd71acd6471f09e87e7a720 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/enc_head.py @@ -0,0 +1,188 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_norm_layer + +from mmseg.ops import Encoding, resize +from ..builder import HEADS, build_loss +from .decode_head import BaseDecodeHead + + +class EncModule(nn.Module): + """Encoding Module used in EncNet. + + Args: + in_channels (int): Input channels. + num_codes (int): Number of code words. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict): Config of activation layers. + """ + + def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg): + super(EncModule, self).__init__() + self.encoding_project = ConvModule( + in_channels, + in_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + # TODO: resolve this hack + # change to 1d + if norm_cfg is not None: + encoding_norm_cfg = norm_cfg.copy() + if encoding_norm_cfg['type'] in ['BN', 'IN']: + encoding_norm_cfg['type'] += '1d' + else: + encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace( + '2d', '1d') + else: + # fallback to BN1d + encoding_norm_cfg = dict(type='BN1d') + self.encoding = nn.Sequential( + Encoding(channels=in_channels, num_codes=num_codes), + build_norm_layer(encoding_norm_cfg, num_codes)[1], + nn.ReLU(inplace=True)) + self.fc = nn.Sequential( + nn.Linear(in_channels, in_channels), nn.Sigmoid()) + + def forward(self, x): + """Forward function.""" + encoding_projection = self.encoding_project(x) + encoding_feat = self.encoding(encoding_projection).mean(dim=1) + batch_size, channels, _, _ = x.size() + gamma = self.fc(encoding_feat) + y = gamma.view(batch_size, channels, 1, 1) + output = F.relu_(x + x * y) + return encoding_feat, output + + +@HEADS.register_module() +class EncHead(BaseDecodeHead): + """Context Encoding for Semantic Segmentation. + + This head is the implementation of `EncNet + `_. + + Args: + num_codes (int): Number of code words. Default: 32. + use_se_loss (bool): Whether use Semantic Encoding Loss (SE-loss) to + regularize the training. Default: True. + add_lateral (bool): Whether use lateral connection to fuse features. + Default: False. + loss_se_decode (dict): Config of decode loss. + Default: dict(type='CrossEntropyLoss', use_sigmoid=True). + """ + + def __init__(self, + num_codes=32, + use_se_loss=True, + add_lateral=False, + loss_se_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=0.2), + **kwargs): + super(EncHead, self).__init__( + input_transform='multiple_select', **kwargs) + self.use_se_loss = use_se_loss + self.add_lateral = add_lateral + self.num_codes = num_codes + self.bottleneck = ConvModule( + self.in_channels[-1], + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if add_lateral: + self.lateral_convs = nn.ModuleList() + for in_channels in self.in_channels[:-1]: # skip the last one + self.lateral_convs.append( + ConvModule( + in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.fusion = ConvModule( + len(self.in_channels) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.enc_module = EncModule( + self.channels, + num_codes=num_codes, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if self.use_se_loss: + self.loss_se_decode = build_loss(loss_se_decode) + self.se_layer = nn.Linear(self.channels, self.num_classes) + + def forward(self, inputs): + """Forward function.""" + inputs = self._transform_inputs(inputs) + feat = self.bottleneck(inputs[-1]) + if self.add_lateral: + laterals = [ + resize( + lateral_conv(inputs[i]), + size=feat.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + feat = self.fusion(torch.cat([feat, *laterals], 1)) + encode_feat, output = self.enc_module(feat) + output = self.cls_seg(output) + if self.use_se_loss: + se_output = self.se_layer(encode_feat) + return output, se_output + else: + return output + + def forward_test(self, inputs, img_metas, test_cfg): + """Forward function for testing, ignore se_loss.""" + if self.use_se_loss: + return self.forward(inputs)[0] + else: + return self.forward(inputs) + + @staticmethod + def _convert_to_onehot_labels(seg_label, num_classes): + """Convert segmentation label to onehot. + + Args: + seg_label (Tensor): Segmentation label of shape (N, H, W). + num_classes (int): Number of classes. + + Returns: + Tensor: Onehot labels of shape (N, num_classes). + """ + + batch_size = seg_label.size(0) + onehot_labels = seg_label.new_zeros((batch_size, num_classes)) + for i in range(batch_size): + hist = seg_label[i].float().histc( + bins=num_classes, min=0, max=num_classes - 1) + onehot_labels[i] = hist > 0 + return onehot_labels + + def losses(self, seg_logit, seg_label): + """Compute segmentation and semantic encoding loss.""" + seg_logit, se_seg_logit = seg_logit + loss = dict() + loss.update(super(EncHead, self).losses(seg_logit, seg_label)) + se_loss = self.loss_se_decode( + se_seg_logit, + self._convert_to_onehot_labels(seg_label, self.num_classes)) + loss['loss_se'] = se_loss + return loss diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/fcn_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/fcn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..e27be6923931ae9e63302ed770b95f0b7d47fc07 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/fcn_head.py @@ -0,0 +1,88 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +@HEADS.register_module() +class FCNHead(BaseDecodeHead): + """Fully Convolution Networks for Semantic Segmentation. + + This head is implemented of `FCNNet `_. + + Args: + num_convs (int): Number of convs in the head. Default: 2. + kernel_size (int): The kernel size for convs in the head. Default: 3. + concat_input (bool): Whether concat the input and output of convs + before classification layer. + dilation (int): The dilation rate for convs in the head. Default: 1. + """ + + def __init__(self, + num_convs=2, + kernel_size=3, + concat_input=True, + dilation=1, + **kwargs): + assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int) + self.num_convs = num_convs + self.concat_input = concat_input + self.kernel_size = kernel_size + super(FCNHead, self).__init__(**kwargs) + if num_convs == 0: + assert self.in_channels == self.channels + + conv_padding = (kernel_size // 2) * dilation + convs = [] + for i in range(num_convs): + _in_channels = self.in_channels if i == 0 else self.channels + convs.append( + ConvModule( + _in_channels, + self.channels, + kernel_size=kernel_size, + padding=conv_padding, + dilation=dilation, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + if len(convs) == 0: + self.convs = nn.Identity() + else: + self.convs = nn.Sequential(*convs) + if self.concat_input: + self.conv_cat = ConvModule( + self.in_channels + self.channels, + self.channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + x = self._transform_inputs(inputs) + feats = self.convs(x) + if self.concat_input: + feats = self.conv_cat(torch.cat([x, feats], dim=1)) + return feats + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.cls_seg(output) + return output diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/fpn_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/fpn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..e41f324cca471fe1d8fe7d95c8f5a8514d18f255 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/fpn_head.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.ops import Upsample, resize +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +@HEADS.register_module() +class FPNHead(BaseDecodeHead): + """Panoptic Feature Pyramid Networks. + + This head is the implementation of `Semantic FPN + `_. + + Args: + feature_strides (tuple[int]): The strides for input feature maps. + stack_lateral. All strides suppose to be power of 2. The first + one is of largest resolution. + """ + + def __init__(self, feature_strides, **kwargs): + super(FPNHead, self).__init__( + input_transform='multiple_select', **kwargs) + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + self.scale_heads = nn.ModuleList() + for i in range(len(feature_strides)): + head_length = max( + 1, + int(np.log2(feature_strides[i]) - np.log2(feature_strides[0]))) + scale_head = [] + for k in range(head_length): + scale_head.append( + ConvModule( + self.in_channels[i] if k == 0 else self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + if feature_strides[i] != feature_strides[0]: + scale_head.append( + Upsample( + scale_factor=2, + mode='bilinear', + align_corners=self.align_corners)) + self.scale_heads.append(nn.Sequential(*scale_head)) + + def forward(self, inputs): + + x = self._transform_inputs(inputs) + + output = self.scale_heads[0](x[0]) + for i in range(1, len(self.feature_strides)): + # non inplace + output = output + resize( + self.scale_heads[i](x[i]), + size=output.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + + output = self.cls_seg(output) + return output diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/gc_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/gc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..eed50742514d4d6cd3d861583f8ee85575e51373 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/gc_head.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcv.cnn import ContextBlock + +from ..builder import HEADS +from .fcn_head import FCNHead + + +@HEADS.register_module() +class GCHead(FCNHead): + """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond. + + This head is the implementation of `GCNet + `_. + + Args: + ratio (float): Multiplier of channels ratio. Default: 1/4. + pooling_type (str): The pooling type of context aggregation. + Options are 'att', 'avg'. Default: 'avg'. + fusion_types (tuple[str]): The fusion type for feature fusion. + Options are 'channel_add', 'channel_mul'. Default: ('channel_add',) + """ + + def __init__(self, + ratio=1 / 4., + pooling_type='att', + fusion_types=('channel_add', ), + **kwargs): + super(GCHead, self).__init__(num_convs=2, **kwargs) + self.ratio = ratio + self.pooling_type = pooling_type + self.fusion_types = fusion_types + self.gc_block = ContextBlock( + in_channels=self.channels, + ratio=self.ratio, + pooling_type=self.pooling_type, + fusion_types=self.fusion_types) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs[0](x) + output = self.gc_block(output) + output = self.convs[1](output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/ham_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/ham_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0546e394f032761132699db4b965320da5ac48d6 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/ham_head.py @@ -0,0 +1,252 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Originally from https://github.com/visual-attention-network/segnext +# Licensed under the Apache License, Version 2.0 (the "License") +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.ops import resize +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +class Matrix_Decomposition_2D_Base(nn.Module): + """Base class of 2D Matrix Decomposition. + Args: + MD_S (int): The number of spatial coefficient in + Matrix Decomposition, it may be used for calculation + of the number of latent dimension D in Matrix + Decomposition. Defaults: 1. + MD_R (int): The number of latent dimension R in + Matrix Decomposition. Defaults: 64. + train_steps (int): The number of iteration steps in + Multiplicative Update (MU) rule to solve Non-negative + Matrix Factorization (NMF) in training. Defaults: 6. + eval_steps (int): The number of iteration steps in + Multiplicative Update (MU) rule to solve Non-negative + Matrix Factorization (NMF) in evaluation. Defaults: 7. + inv_t (int): Inverted multiple number to make coefficient + smaller in softmax. Defaults: 100. + rand_init (bool): Whether to initialize randomly. + Defaults: True. + """ + + def __init__(self, + MD_S=1, + MD_R=64, + train_steps=6, + eval_steps=7, + inv_t=100, + rand_init=True): + super().__init__() + + self.S = MD_S + self.R = MD_R + + self.train_steps = train_steps + self.eval_steps = eval_steps + + self.inv_t = inv_t + + self.rand_init = rand_init + + def _build_bases(self, B, S, D, R, cuda=False): + raise NotImplementedError + + def local_step(self, x, bases, coef): + raise NotImplementedError + + def local_inference(self, x, bases): + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + coef = torch.bmm(x.transpose(1, 2), bases) + coef = F.softmax(self.inv_t * coef, dim=-1) + + steps = self.train_steps if self.training else self.eval_steps + for _ in range(steps): + bases, coef = self.local_step(x, bases, coef) + + return bases, coef + + def compute_coef(self, x, bases, coef): + raise NotImplementedError + + def forward(self, x, return_bases=False): + """Forward Function.""" + B, C, H, W = x.shape + + # (B, C, H, W) -> (B * S, D, N) + D = C // self.S + N = H * W + x = x.view(B * self.S, D, N) + cuda = 'cuda' in str(x.device) + if not self.rand_init and not hasattr(self, 'bases'): + bases = self._build_bases(1, self.S, D, self.R, cuda=cuda) + self.register_buffer('bases', bases) + + # (S, D, R) -> (B * S, D, R) + if self.rand_init: + bases = self._build_bases(B, self.S, D, self.R, cuda=cuda) + else: + bases = self.bases.repeat(B, 1, 1) + + bases, coef = self.local_inference(x, bases) + + # (B * S, N, R) + coef = self.compute_coef(x, bases, coef) + + # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N) + x = torch.bmm(bases, coef.transpose(1, 2)) + + # (B * S, D, N) -> (B, C, H, W) + x = x.view(B, C, H, W) + + return x + + +class NMF2D(Matrix_Decomposition_2D_Base): + """Non-negative Matrix Factorization (NMF) module. + It is inherited from ``Matrix_Decomposition_2D_Base`` module. + """ + + def __init__(self, args=dict()): + super().__init__(**args) + + self.inv_t = 1 + + def _build_bases(self, B, S, D, R, cuda=False): + """Build bases in initialization.""" + if cuda: + bases = torch.rand((B * S, D, R)).cuda() + else: + bases = torch.rand((B * S, D, R)) + + bases = F.normalize(bases, dim=1) + + return bases + + def local_step(self, x, bases, coef): + """Local step in iteration to renew bases and coefficient.""" + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + numerator = torch.bmm(x.transpose(1, 2), bases) + # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R) + denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) + # Multiplicative Update + coef = coef * numerator / (denominator + 1e-6) + + # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R) + numerator = torch.bmm(x, coef) + # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R) + denominator = bases.bmm(coef.transpose(1, 2).bmm(coef)) + # Multiplicative Update + bases = bases * numerator / (denominator + 1e-6) + + return bases, coef + + def compute_coef(self, x, bases, coef): + """Compute coefficient.""" + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + numerator = torch.bmm(x.transpose(1, 2), bases) + # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R) + denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) + # multiplication update + coef = coef * numerator / (denominator + 1e-6) + + return coef + + +class Hamburger(nn.Module): + """Hamburger Module. It consists of one slice of "ham" (matrix + decomposition) and two slices of "bread" (linear transformation). + Args: + ham_channels (int): Input and output channels of feature. + ham_kwargs (dict): Config of matrix decomposition module. + norm_cfg (dict | None): Config of norm layers. + """ + + def __init__(self, + ham_channels=512, + ham_kwargs=dict(), + norm_cfg=None, + **kwargs): + super().__init__() + + self.ham_in = ConvModule( + ham_channels, ham_channels, 1, norm_cfg=None, act_cfg=None) + + self.ham = NMF2D(ham_kwargs) + + self.ham_out = ConvModule( + ham_channels, ham_channels, 1, norm_cfg=norm_cfg, act_cfg=None) + + def forward(self, x): + enjoy = self.ham_in(x) + enjoy = F.relu(enjoy, inplace=True) + enjoy = self.ham(enjoy) + enjoy = self.ham_out(enjoy) + ham = F.relu(x + enjoy, inplace=True) + + return ham + + +@HEADS.register_module() +class LightHamHead(BaseDecodeHead): + """SegNeXt decode head. + This decode head is the implementation of `SegNeXt: Rethinking + Convolutional Attention Design for Semantic + Segmentation `_. + Inspiration from https://github.com/visual-attention-network/segnext. + Specifically, LightHamHead is inspired by HamNet from + `Is Attention Better Than Matrix Decomposition? + `. + Args: + ham_channels (int): input channels for Hamburger. + Defaults: 512. + ham_kwargs (int): kwagrs for Ham. Defaults: dict(). + """ + + def __init__(self, ham_channels=512, ham_kwargs=dict(), **kwargs): + super(LightHamHead, self).__init__( + input_transform='multiple_select', **kwargs) + self.ham_channels = ham_channels + + self.squeeze = ConvModule( + sum(self.in_channels), + self.ham_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.hamburger = Hamburger(ham_channels, ham_kwargs, **kwargs) + + self.align = ConvModule( + self.ham_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + inputs = self._transform_inputs(inputs) + + inputs = [ + resize( + level, + size=inputs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) for level in inputs + ] + + inputs = torch.cat(inputs, dim=1) + # apply a conv block to squeeze feature map + x = self.squeeze(inputs) + # apply hamburger module + x = self.hamburger(x) + + # apply a conv block to align feature map + output = self.align(x) + output = self.cls_seg(output) + return output \ No newline at end of file diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/isa_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/isa_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0bf3455571e315c9f50279e57b04549a6f1f3af8 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/isa_head.py @@ -0,0 +1,143 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from ..builder import HEADS +from ..utils import SelfAttentionBlock as _SelfAttentionBlock +from .decode_head import BaseDecodeHead + + +class SelfAttentionBlock(_SelfAttentionBlock): + """Self-Attention Module. + + Args: + in_channels (int): Input channels of key/query feature. + channels (int): Output channels of key/query transform. + conv_cfg (dict | None): Config of conv layers. + norm_cfg (dict | None): Config of norm layers. + act_cfg (dict | None): Config of activation layers. + """ + + def __init__(self, in_channels, channels, conv_cfg, norm_cfg, act_cfg): + super(SelfAttentionBlock, self).__init__( + key_in_channels=in_channels, + query_in_channels=in_channels, + channels=channels, + out_channels=in_channels, + share_key_query=False, + query_downsample=None, + key_downsample=None, + key_query_num_convs=2, + key_query_norm=True, + value_out_num_convs=1, + value_out_norm=False, + matmul_norm=True, + with_out=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.output_project = self.build_project( + in_channels, + in_channels, + num_convs=1, + use_conv_module=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x): + """Forward function.""" + context = super(SelfAttentionBlock, self).forward(x, x) + return self.output_project(context) + + +@HEADS.register_module() +class ISAHead(BaseDecodeHead): + """Interlaced Sparse Self-Attention for Semantic Segmentation. + + This head is the implementation of `ISA + `_. + + Args: + isa_channels (int): The channels of ISA Module. + down_factor (tuple[int]): The local group size of ISA. + """ + + def __init__(self, isa_channels, down_factor=(8, 8), **kwargs): + super(ISAHead, self).__init__(**kwargs) + self.down_factor = down_factor + + self.in_conv = ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.global_relation = SelfAttentionBlock( + self.channels, + isa_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.local_relation = SelfAttentionBlock( + self.channels, + isa_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.out_conv = ConvModule( + self.channels * 2, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x_ = self._transform_inputs(inputs) + x = self.in_conv(x_) + residual = x + + n, c, h, w = x.size() + loc_h, loc_w = self.down_factor # size of local group in H- and W-axes + glb_h, glb_w = math.ceil(h / loc_h), math.ceil(w / loc_w) + pad_h, pad_w = glb_h * loc_h - h, glb_w * loc_w - w + if pad_h > 0 or pad_w > 0: # pad if the size is not divisible + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2) + x = F.pad(x, padding) + + # global relation + x = x.view(n, c, glb_h, loc_h, glb_w, loc_w) + # do permutation to gather global group + x = x.permute(0, 3, 5, 1, 2, 4) # (n, loc_h, loc_w, c, glb_h, glb_w) + x = x.reshape(-1, c, glb_h, glb_w) + # apply attention within each global group + x = self.global_relation(x) # (n * loc_h * loc_w, c, glb_h, glb_w) + + # local relation + x = x.view(n, loc_h, loc_w, c, glb_h, glb_w) + # do permutation to gather local group + x = x.permute(0, 4, 5, 3, 1, 2) # (n, glb_h, glb_w, c, loc_h, loc_w) + x = x.reshape(-1, c, loc_h, loc_w) + # apply attention within each local group + x = self.local_relation(x) # (n * glb_h * glb_w, c, loc_h, loc_w) + + # permute each pixel back to its original position + x = x.view(n, glb_h, glb_w, c, loc_h, loc_w) + x = x.permute(0, 3, 1, 4, 2, 5) # (n, c, glb_h, loc_h, glb_w, loc_w) + x = x.reshape(n, c, glb_h * loc_h, glb_w * loc_w) + if pad_h > 0 or pad_w > 0: # remove padding + x = x[:, :, pad_h // 2:pad_h // 2 + h, pad_w // 2:pad_w // 2 + w] + + x = self.out_conv(torch.cat([x, residual], dim=1)) + out = self.cls_seg(x) + + return out diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/knet_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/knet_head.py new file mode 100644 index 0000000000000000000000000000000000000000..78a270277fddcfc86cef6c85c4290666b54cafeb --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/knet_head.py @@ -0,0 +1,457 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer +from mmcv.cnn.bricks.transformer import (FFN, TRANSFORMER_LAYER, + MultiheadAttention, + build_transformer_layer) + +from mmseg.models.builder import HEADS, build_head +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.utils import get_root_logger + + +@TRANSFORMER_LAYER.register_module() +class KernelUpdator(nn.Module): + """Dynamic Kernel Updator in Kernel Update Head. + + Args: + in_channels (int): The number of channels of input feature map. + Default: 256. + feat_channels (int): The number of middle-stage channels in + the kernel updator. Default: 64. + out_channels (int): The number of output channels. + gate_sigmoid (bool): Whether use sigmoid function in gate + mechanism. Default: True. + gate_norm_act (bool): Whether add normalization and activation + layer in gate mechanism. Default: False. + activate_out: Whether add activation after gate mechanism. + Default: False. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='LN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + """ + + def __init__( + self, + in_channels=256, + feat_channels=64, + out_channels=None, + gate_sigmoid=True, + gate_norm_act=False, + activate_out=False, + norm_cfg=dict(type='LN'), + act_cfg=dict(type='ReLU', inplace=True), + ): + super(KernelUpdator, self).__init__() + self.in_channels = in_channels + self.feat_channels = feat_channels + self.out_channels_raw = out_channels + self.gate_sigmoid = gate_sigmoid + self.gate_norm_act = gate_norm_act + self.activate_out = activate_out + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.out_channels = out_channels if out_channels else in_channels + + self.num_params_in = self.feat_channels + self.num_params_out = self.feat_channels + self.dynamic_layer = nn.Linear( + self.in_channels, self.num_params_in + self.num_params_out) + self.input_layer = nn.Linear(self.in_channels, + self.num_params_in + self.num_params_out, + 1) + self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1) + self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1) + if self.gate_norm_act: + self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1] + + self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1] + + self.activation = build_activation_layer(act_cfg) + + self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1) + self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1] + + def forward(self, update_feature, input_feature): + """Forward function of KernelUpdator. + + Args: + update_feature (torch.Tensor): Feature map assembled from + each group. It would be reshaped with last dimension + shape: `self.in_channels`. + input_feature (torch.Tensor): Intermediate feature + with shape: (N, num_classes, conv_kernel_size**2, channels). + Returns: + Tensor: The output tensor of shape (N*C1/C2, K*K, C2), where N is + the number of classes, C1 and C2 are the feature map channels of + KernelUpdateHead and KernelUpdator, respectively. + """ + + update_feature = update_feature.reshape(-1, self.in_channels) + num_proposals = update_feature.size(0) + # dynamic_layer works for + # phi_1 and psi_3 in Eq.(4) and (5) of K-Net paper + parameters = self.dynamic_layer(update_feature) + param_in = parameters[:, :self.num_params_in].view( + -1, self.feat_channels) + param_out = parameters[:, -self.num_params_out:].view( + -1, self.feat_channels) + + # input_layer works for + # phi_2 and psi_4 in Eq.(4) and (5) of K-Net paper + input_feats = self.input_layer( + input_feature.reshape(num_proposals, -1, self.feat_channels)) + input_in = input_feats[..., :self.num_params_in] + input_out = input_feats[..., -self.num_params_out:] + + # `gate_feats` is F^G in K-Net paper + gate_feats = input_in * param_in.unsqueeze(-2) + if self.gate_norm_act: + gate_feats = self.activation(self.gate_norm(gate_feats)) + + input_gate = self.input_norm_in(self.input_gate(gate_feats)) + update_gate = self.norm_in(self.update_gate(gate_feats)) + if self.gate_sigmoid: + input_gate = input_gate.sigmoid() + update_gate = update_gate.sigmoid() + param_out = self.norm_out(param_out) + input_out = self.input_norm_out(input_out) + + if self.activate_out: + param_out = self.activation(param_out) + input_out = self.activation(input_out) + + # Gate mechanism. Eq.(5) in original paper. + # param_out has shape (batch_size, feat_channels, out_channels) + features = update_gate * param_out.unsqueeze( + -2) + input_gate * input_out + + features = self.fc_layer(features) + features = self.fc_norm(features) + features = self.activation(features) + + return features + + +@HEADS.register_module() +class KernelUpdateHead(nn.Module): + """Kernel Update Head in K-Net. + + Args: + num_classes (int): Number of classes. Default: 150. + num_ffn_fcs (int): The number of fully-connected layers in + FFNs. Default: 2. + num_heads (int): The number of parallel attention heads. + Default: 8. + num_mask_fcs (int): The number of fully connected layers for + mask prediction. Default: 3. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 2048. + in_channels (int): The number of channels of input feature map. + Default: 256. + out_channels (int): The number of output channels. + Default: 256. + dropout (float): The Probability of an element to be + zeroed in MultiheadAttention and FFN. Default 0.0. + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + ffn_act_cfg (dict): Config of activation layers in FFN. + Default: dict(type='ReLU'). + conv_kernel_size (int): The kernel size of convolution in + Kernel Update Head for dynamic kernel updation. + Default: 1. + feat_transform_cfg (dict | None): Config of feature transform. + Default: None. + kernel_init (bool): Whether initiate mask kernel in mask head. + Default: False. + with_ffn (bool): Whether add FFN in kernel update head. + Default: True. + feat_gather_stride (int): Stride of convolution in feature transform. + Default: 1. + mask_transform_stride (int): Stride of mask transform. + Default: 1. + kernel_updator_cfg (dict): Config of kernel updator. + Default: dict( + type='DynamicConv', + in_channels=256, + feat_channels=64, + out_channels=256, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN')). + """ + + def __init__(self, + num_classes=150, + num_ffn_fcs=2, + num_heads=8, + num_mask_fcs=3, + feedforward_channels=2048, + in_channels=256, + out_channels=256, + dropout=0.0, + act_cfg=dict(type='ReLU', inplace=True), + ffn_act_cfg=dict(type='ReLU', inplace=True), + conv_kernel_size=1, + feat_transform_cfg=None, + kernel_init=False, + with_ffn=True, + feat_gather_stride=1, + mask_transform_stride=1, + kernel_updator_cfg=dict( + type='DynamicConv', + in_channels=256, + feat_channels=64, + out_channels=256, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN'))): + super(KernelUpdateHead, self).__init__() + self.num_classes = num_classes + self.in_channels = in_channels + self.out_channels = out_channels + self.fp16_enabled = False + self.dropout = dropout + self.num_heads = num_heads + self.kernel_init = kernel_init + self.with_ffn = with_ffn + self.conv_kernel_size = conv_kernel_size + self.feat_gather_stride = feat_gather_stride + self.mask_transform_stride = mask_transform_stride + + self.attention = MultiheadAttention(in_channels * conv_kernel_size**2, + num_heads, dropout) + self.attention_norm = build_norm_layer( + dict(type='LN'), in_channels * conv_kernel_size**2)[1] + self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg) + + if feat_transform_cfg is not None: + kernel_size = feat_transform_cfg.pop('kernel_size', 1) + transform_channels = in_channels + self.feat_transform = ConvModule( + transform_channels, + in_channels, + kernel_size, + stride=feat_gather_stride, + padding=int(feat_gather_stride // 2), + **feat_transform_cfg) + else: + self.feat_transform = None + + if self.with_ffn: + self.ffn = FFN( + in_channels, + feedforward_channels, + num_ffn_fcs, + act_cfg=ffn_act_cfg, + dropout=dropout) + self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1] + + self.mask_fcs = nn.ModuleList() + for _ in range(num_mask_fcs): + self.mask_fcs.append( + nn.Linear(in_channels, in_channels, bias=False)) + self.mask_fcs.append( + build_norm_layer(dict(type='LN'), in_channels)[1]) + self.mask_fcs.append(build_activation_layer(act_cfg)) + + self.fc_mask = nn.Linear(in_channels, out_channels) + + def init_weights(self): + """Use xavier initialization for all weight parameter and set + classification head bias as a specific value when use focal loss.""" + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + else: + # adopt the default initialization for + # the weight and bias of the layer norm + pass + if self.kernel_init: + logger = get_root_logger() + logger.info( + 'mask kernel in mask head is normal initialized by std 0.01') + nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01) + + def forward(self, x, proposal_feat, mask_preds, mask_shape=None): + """Forward function of Dynamic Instance Interactive Head. + + Args: + x (Tensor): Feature map from FPN with shape + (batch_size, feature_dimensions, H , W). + proposal_feat (Tensor): Intermediate feature get from + diihead in last stage, has shape + (batch_size, num_proposals, feature_dimensions) + mask_preds (Tensor): mask prediction from the former stage in shape + (batch_size, num_proposals, H, W). + + Returns: + Tuple: The first tensor is predicted mask with shape + (N, num_classes, H, W), the second tensor is dynamic kernel + with shape (N, num_classes, channels, K, K). + """ + N, num_proposals = proposal_feat.shape[:2] + if self.feat_transform is not None: + x = self.feat_transform(x) + + C, H, W = x.shape[-3:] + + mask_h, mask_w = mask_preds.shape[-2:] + if mask_h != H or mask_w != W: + gather_mask = F.interpolate( + mask_preds, (H, W), align_corners=False, mode='bilinear') + else: + gather_mask = mask_preds + + sigmoid_masks = gather_mask.softmax(dim=1) + + # Group Feature Assembling. Eq.(3) in original paper. + # einsum is faster than bmm by 30% + x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x) + + # obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C] + proposal_feat = proposal_feat.reshape(N, num_proposals, + self.in_channels, + -1).permute(0, 1, 3, 2) + obj_feat = self.kernel_update_conv(x_feat, proposal_feat) + + # [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C] + obj_feat = obj_feat.reshape(N, num_proposals, -1).permute(1, 0, 2) + obj_feat = self.attention_norm(self.attention(obj_feat)) + # [N, B, K*K*C] -> [B, N, K*K*C] + obj_feat = obj_feat.permute(1, 0, 2) + + # obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C] + obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels) + + # FFN + if self.with_ffn: + obj_feat = self.ffn_norm(self.ffn(obj_feat)) + + mask_feat = obj_feat + + for reg_layer in self.mask_fcs: + mask_feat = reg_layer(mask_feat) + + # [B, N, K*K, C] -> [B, N, C, K*K] + mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2) + + if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1): + mask_x = F.interpolate( + x, scale_factor=0.5, mode='bilinear', align_corners=False) + H, W = mask_x.shape[-2:] + else: + mask_x = x + # group conv is 5x faster than unfold and uses about 1/5 memory + # Group conv vs. unfold vs. concat batch, 2.9ms :13.5ms :3.8ms + # Group conv vs. unfold vs. concat batch, 278 : 1420 : 369 + # but in real training group conv is slower than concat batch + # so we keep using concat batch. + # fold_x = F.unfold( + # mask_x, + # self.conv_kernel_size, + # padding=int(self.conv_kernel_size // 2)) + # mask_feat = mask_feat.reshape(N, num_proposals, -1) + # new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x) + # [B, N, C, K*K] -> [B*N, C, K, K] + mask_feat = mask_feat.reshape(N, num_proposals, C, + self.conv_kernel_size, + self.conv_kernel_size) + # [B, C, H, W] -> [1, B*C, H, W] + new_mask_preds = [] + for i in range(N): + new_mask_preds.append( + F.conv2d( + mask_x[i:i + 1], + mask_feat[i], + padding=int(self.conv_kernel_size // 2))) + + new_mask_preds = torch.cat(new_mask_preds, dim=0) + new_mask_preds = new_mask_preds.reshape(N, num_proposals, H, W) + if self.mask_transform_stride == 2: + new_mask_preds = F.interpolate( + new_mask_preds, + scale_factor=2, + mode='bilinear', + align_corners=False) + + if mask_shape is not None and mask_shape[0] != H: + new_mask_preds = F.interpolate( + new_mask_preds, + mask_shape, + align_corners=False, + mode='bilinear') + + return new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape( + N, num_proposals, self.in_channels, self.conv_kernel_size, + self.conv_kernel_size) + + +@HEADS.register_module() +class IterativeDecodeHead(BaseDecodeHead): + """K-Net: Towards Unified Image Segmentation. + + This head is the implementation of + `K-Net: `_. + + Args: + num_stages (int): The number of stages (kernel update heads) + in IterativeDecodeHead. Default: 3. + kernel_generate_head:(dict): Config of kernel generate head which + generate mask predictions, dynamic kernels and class predictions + for next kernel update heads. + kernel_update_head (dict): Config of kernel update head which refine + dynamic kernels and class predictions iteratively. + + """ + + def __init__(self, num_stages, kernel_generate_head, kernel_update_head, + **kwargs): + # ``IterativeDecodeHead`` would skip initialization of + # ``BaseDecodeHead`` which would be called when building + # ``self.kernel_generate_head``. + super(BaseDecodeHead, self).__init__(**kwargs) + assert num_stages == len(kernel_update_head) + self.num_stages = num_stages + self.kernel_generate_head = build_head(kernel_generate_head) + self.kernel_update_head = nn.ModuleList() + self.align_corners = self.kernel_generate_head.align_corners + self.num_classes = self.kernel_generate_head.num_classes + self.input_transform = self.kernel_generate_head.input_transform + self.ignore_index = self.kernel_generate_head.ignore_index + self.out_channels = self.num_classes + + for head_cfg in kernel_update_head: + self.kernel_update_head.append(build_head(head_cfg)) + + def forward(self, inputs): + """Forward function.""" + feats = self.kernel_generate_head._forward_feature(inputs) + sem_seg = self.kernel_generate_head.cls_seg(feats) + seg_kernels = self.kernel_generate_head.conv_seg.weight.clone() + seg_kernels = seg_kernels[None].expand( + feats.size(0), *seg_kernels.size()) + + stage_segs = [sem_seg] + for i in range(self.num_stages): + sem_seg, seg_kernels = self.kernel_update_head[i](feats, + seg_kernels, + sem_seg) + stage_segs.append(sem_seg) + if self.training: + return stage_segs + # only return the prediction of the last stage during testing + return stage_segs[-1] + + def losses(self, seg_logit, seg_label): + losses = dict() + for i, logit in enumerate(seg_logit): + loss = self.kernel_generate_head.losses(logit, seg_label) + for k, v in loss.items(): + losses[f'{k}.s{i}'] = v + + return losses diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/lraspp_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/lraspp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c10ff0d822fa29d4dfa1c8a1e047e4540b1d5064 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/lraspp_head.py @@ -0,0 +1,91 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv import is_tuple_of +from mmcv.cnn import ConvModule + +from mmseg.ops import resize +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +@HEADS.register_module() +class LRASPPHead(BaseDecodeHead): + """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3. + + This head is the improved implementation of `Searching for MobileNetV3 + `_. + + Args: + branch_channels (tuple[int]): The number of output channels in every + each branch. Default: (32, 64). + """ + + def __init__(self, branch_channels=(32, 64), **kwargs): + super(LRASPPHead, self).__init__(**kwargs) + if self.input_transform != 'multiple_select': + raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform ' + f'must be \'multiple_select\'. But received ' + f'\'{self.input_transform}\'') + assert is_tuple_of(branch_channels, int) + assert len(branch_channels) == len(self.in_channels) - 1 + self.branch_channels = branch_channels + + self.convs = nn.Sequential() + self.conv_ups = nn.Sequential() + for i in range(len(branch_channels)): + self.convs.add_module( + f'conv{i}', + nn.Conv2d( + self.in_channels[i], branch_channels[i], 1, bias=False)) + self.conv_ups.add_module( + f'conv_up{i}', + ConvModule( + self.channels + branch_channels[i], + self.channels, + 1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + bias=False)) + + self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1) + + self.aspp_conv = ConvModule( + self.in_channels[-1], + self.channels, + 1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + bias=False) + self.image_pool = nn.Sequential( + nn.AvgPool2d(kernel_size=49, stride=(16, 20)), + ConvModule( + self.in_channels[2], + self.channels, + 1, + act_cfg=dict(type='Sigmoid'), + bias=False)) + + def forward(self, inputs): + """Forward function.""" + inputs = self._transform_inputs(inputs) + + x = inputs[-1] + + x = self.aspp_conv(x) * resize( + self.image_pool(x), + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + x = self.conv_up_input(x) + + for i in range(len(self.branch_channels) - 1, -1, -1): + x = resize( + x, + size=inputs[i].size()[2:], + mode='bilinear', + align_corners=self.align_corners) + x = torch.cat([x, self.convs[i](inputs[i])], 1) + x = self.conv_ups[i](x) + + return self.cls_seg(x) diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/nl_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/nl_head.py new file mode 100644 index 0000000000000000000000000000000000000000..637517e7a0e634f733049c6c4c2001cd844b22c1 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/nl_head.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcv.cnn import NonLocal2d + +from ..builder import HEADS +from .fcn_head import FCNHead + + +@HEADS.register_module() +class NLHead(FCNHead): + """Non-local Neural Networks. + + This head is the implementation of `NLNet + `_. + + Args: + reduction (int): Reduction factor of projection transform. Default: 2. + use_scale (bool): Whether to scale pairwise_weight by + sqrt(1/inter_channels). Default: True. + mode (str): The nonlocal mode. Options are 'embedded_gaussian', + 'dot_product'. Default: 'embedded_gaussian.'. + """ + + def __init__(self, + reduction=2, + use_scale=True, + mode='embedded_gaussian', + **kwargs): + super(NLHead, self).__init__(num_convs=2, **kwargs) + self.reduction = reduction + self.use_scale = use_scale + self.mode = mode + self.nl_block = NonLocal2d( + in_channels=self.channels, + reduction=self.reduction, + use_scale=self.use_scale, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + mode=self.mode) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs[0](x) + output = self.nl_block(output) + output = self.convs[1](output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/ocr_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/ocr_head.py new file mode 100644 index 0000000000000000000000000000000000000000..09eadfb1a6a61b752ab1ece5233617a0cb3a3a2b --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/ocr_head.py @@ -0,0 +1,128 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.ops import resize +from ..builder import HEADS +from ..utils import SelfAttentionBlock as _SelfAttentionBlock +from .cascade_decode_head import BaseCascadeDecodeHead + + +class SpatialGatherModule(nn.Module): + """Aggregate the context features according to the initial predicted + probability distribution. + + Employ the soft-weighted method to aggregate the context. + """ + + def __init__(self, scale): + super(SpatialGatherModule, self).__init__() + self.scale = scale + + def forward(self, feats, probs): + """Forward function.""" + batch_size, num_classes, height, width = probs.size() + channels = feats.size(1) + probs = probs.view(batch_size, num_classes, -1) + feats = feats.view(batch_size, channels, -1) + # [batch_size, height*width, num_classes] + feats = feats.permute(0, 2, 1) + # [batch_size, channels, height*width] + probs = F.softmax(self.scale * probs, dim=2) + # [batch_size, channels, num_classes] + ocr_context = torch.matmul(probs, feats) + ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3) + return ocr_context + + +class ObjectAttentionBlock(_SelfAttentionBlock): + """Make a OCR used SelfAttentionBlock.""" + + def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg, + act_cfg): + if scale > 1: + query_downsample = nn.MaxPool2d(kernel_size=scale) + else: + query_downsample = None + super(ObjectAttentionBlock, self).__init__( + key_in_channels=in_channels, + query_in_channels=in_channels, + channels=channels, + out_channels=in_channels, + share_key_query=False, + query_downsample=query_downsample, + key_downsample=None, + key_query_num_convs=2, + key_query_norm=True, + value_out_num_convs=1, + value_out_norm=True, + matmul_norm=True, + with_out=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.bottleneck = ConvModule( + in_channels * 2, + in_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, query_feats, key_feats): + """Forward function.""" + context = super(ObjectAttentionBlock, + self).forward(query_feats, key_feats) + output = self.bottleneck(torch.cat([context, query_feats], dim=1)) + if self.query_downsample is not None: + output = resize(query_feats) + + return output + + +@HEADS.register_module() +class OCRHead(BaseCascadeDecodeHead): + """Object-Contextual Representations for Semantic Segmentation. + + This head is the implementation of `OCRNet + `_. + + Args: + ocr_channels (int): The intermediate channels of OCR block. + scale (int): The scale of probability map in SpatialGatherModule in + Default: 1. + """ + + def __init__(self, ocr_channels, scale=1, **kwargs): + super(OCRHead, self).__init__(**kwargs) + self.ocr_channels = ocr_channels + self.scale = scale + self.object_context_block = ObjectAttentionBlock( + self.channels, + self.ocr_channels, + self.scale, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.spatial_gather_module = SpatialGatherModule(self.scale) + + self.bottleneck = ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs, prev_output): + """Forward function.""" + x = self._transform_inputs(inputs) + feats = self.bottleneck(x) + context = self.spatial_gather_module(feats, prev_output) + object_context = self.object_context_block(feats, context) + output = self.cls_seg(object_context) + + return output diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/point_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/point_head.py new file mode 100644 index 0000000000000000000000000000000000000000..5e605271c79f402f12093eae66d90425a9406825 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/point_head.py @@ -0,0 +1,364 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +try: + from mmcv.ops import point_sample +except ModuleNotFoundError: + point_sample = None + +from mmseg.models.builder import HEADS +from mmseg.ops import resize +from ..losses import accuracy +from .cascade_decode_head import BaseCascadeDecodeHead + + +def calculate_uncertainty(seg_logits): + """Estimate uncertainty based on seg logits. + + For each location of the prediction ``seg_logits`` we estimate + uncertainty as the difference between top first and top second + predicted logits. + + Args: + seg_logits (Tensor): Semantic segmentation logits, + shape (batch_size, num_classes, height, width). + + Returns: + scores (Tensor): T uncertainty scores with the most uncertain + locations having the highest uncertainty score, shape ( + batch_size, 1, height, width) + """ + top2_scores = torch.topk(seg_logits, k=2, dim=1)[0] + return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1) + + +@HEADS.register_module() +class PointHead(BaseCascadeDecodeHead): + """A mask point head use in PointRend. + + This head is implemented of `PointRend: Image Segmentation as + Rendering `_. + ``PointHead`` use shared multi-layer perceptron (equivalent to + nn.Conv1d) to predict the logit of input points. The fine-grained feature + and coarse feature will be concatenate together for predication. + + Args: + num_fcs (int): Number of fc layers in the head. Default: 3. + in_channels (int): Number of input channels. Default: 256. + fc_channels (int): Number of fc channels. Default: 256. + num_classes (int): Number of classes for logits. Default: 80. + class_agnostic (bool): Whether use class agnostic classification. + If so, the output channels of logits will be 1. Default: False. + coarse_pred_each_layer (bool): Whether concatenate coarse feature with + the output of each fc layer. Default: True. + conv_cfg (dict|None): Dictionary to construct and config conv layer. + Default: dict(type='Conv1d')) + norm_cfg (dict|None): Dictionary to construct and config norm layer. + Default: None. + loss_point (dict): Dictionary to construct and config loss layer of + point head. Default: dict(type='CrossEntropyLoss', use_mask=True, + loss_weight=1.0). + """ + + def __init__(self, + num_fcs=3, + coarse_pred_each_layer=True, + conv_cfg=dict(type='Conv1d'), + norm_cfg=None, + act_cfg=dict(type='ReLU', inplace=False), + **kwargs): + super(PointHead, self).__init__( + input_transform='multiple_select', + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + init_cfg=dict( + type='Normal', std=0.01, override=dict(name='fc_seg')), + **kwargs) + if point_sample is None: + raise RuntimeError('Please install mmcv-full for ' + 'point_sample ops') + + self.num_fcs = num_fcs + self.coarse_pred_each_layer = coarse_pred_each_layer + + fc_in_channels = sum(self.in_channels) + self.num_classes + fc_channels = self.channels + self.fcs = nn.ModuleList() + for k in range(num_fcs): + fc = ConvModule( + fc_in_channels, + fc_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.fcs.append(fc) + fc_in_channels = fc_channels + fc_in_channels += self.num_classes if self.coarse_pred_each_layer \ + else 0 + self.fc_seg = nn.Conv1d( + fc_in_channels, + self.num_classes, + kernel_size=1, + stride=1, + padding=0) + if self.dropout_ratio > 0: + self.dropout = nn.Dropout(self.dropout_ratio) + delattr(self, 'conv_seg') + + def cls_seg(self, feat): + """Classify each pixel with fc.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.fc_seg(feat) + return output + + def forward(self, fine_grained_point_feats, coarse_point_feats): + x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1) + for fc in self.fcs: + x = fc(x) + if self.coarse_pred_each_layer: + x = torch.cat((x, coarse_point_feats), dim=1) + return self.cls_seg(x) + + def _get_fine_grained_point_feats(self, x, points): + """Sample from fine grained features. + + Args: + x (list[Tensor]): Feature pyramid from by neck or backbone. + points (Tensor): Point coordinates, shape (batch_size, + num_points, 2). + + Returns: + fine_grained_feats (Tensor): Sampled fine grained feature, + shape (batch_size, sum(channels of x), num_points). + """ + + fine_grained_feats_list = [ + point_sample(_, points, align_corners=self.align_corners) + for _ in x + ] + if len(fine_grained_feats_list) > 1: + fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1) + else: + fine_grained_feats = fine_grained_feats_list[0] + + return fine_grained_feats + + def _get_coarse_point_feats(self, prev_output, points): + """Sample from fine grained features. + + Args: + prev_output (list[Tensor]): Prediction of previous decode head. + points (Tensor): Point coordinates, shape (batch_size, + num_points, 2). + + Returns: + coarse_feats (Tensor): Sampled coarse feature, shape (batch_size, + num_classes, num_points). + """ + + coarse_feats = point_sample( + prev_output, points, align_corners=self.align_corners) + + return coarse_feats + + def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg, + train_cfg): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + prev_output (Tensor): The output of previous decode head. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + x = self._transform_inputs(inputs) + with torch.no_grad(): + points = self.get_points_train( + prev_output, calculate_uncertainty, cfg=train_cfg) + fine_grained_point_feats = self._get_fine_grained_point_feats( + x, points) + coarse_point_feats = self._get_coarse_point_feats(prev_output, points) + point_logits = self.forward(fine_grained_point_feats, + coarse_point_feats) + point_label = point_sample( + gt_semantic_seg.float(), + points, + mode='nearest', + align_corners=self.align_corners) + point_label = point_label.squeeze(1).long() + + losses = self.losses(point_logits, point_label) + + return losses + + def forward_test(self, inputs, prev_output, img_metas, test_cfg): + """Forward function for testing. + + Args: + inputs (list[Tensor]): List of multi-level img features. + prev_output (Tensor): The output of previous decode head. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Output segmentation map. + """ + + x = self._transform_inputs(inputs) + refined_seg_logits = prev_output.clone() + for _ in range(test_cfg.subdivision_steps): + refined_seg_logits = resize( + refined_seg_logits, + scale_factor=test_cfg.scale_factor, + mode='bilinear', + align_corners=self.align_corners) + batch_size, channels, height, width = refined_seg_logits.shape + point_indices, points = self.get_points_test( + refined_seg_logits, calculate_uncertainty, cfg=test_cfg) + fine_grained_point_feats = self._get_fine_grained_point_feats( + x, points) + coarse_point_feats = self._get_coarse_point_feats( + prev_output, points) + point_logits = self.forward(fine_grained_point_feats, + coarse_point_feats) + + point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1) + refined_seg_logits = refined_seg_logits.reshape( + batch_size, channels, height * width) + refined_seg_logits = refined_seg_logits.scatter_( + 2, point_indices, point_logits) + refined_seg_logits = refined_seg_logits.view( + batch_size, channels, height, width) + + return refined_seg_logits + + def losses(self, point_logits, point_label): + """Compute segmentation loss.""" + loss = dict() + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_module in losses_decode: + loss['point' + loss_module.loss_name] = loss_module( + point_logits, point_label, ignore_index=self.ignore_index) + + loss['acc_point'] = accuracy( + point_logits, point_label, ignore_index=self.ignore_index) + return loss + + def get_points_train(self, seg_logits, uncertainty_func, cfg): + """Sample points for training. + + Sample points in [0, 1] x [0, 1] coordinate space based on their + uncertainty. The uncertainties are calculated for each point using + 'uncertainty_func' function that takes point's logit prediction as + input. + + Args: + seg_logits (Tensor): Semantic segmentation logits, shape ( + batch_size, num_classes, height, width). + uncertainty_func (func): uncertainty calculation function. + cfg (dict): Training config of point head. + + Returns: + point_coords (Tensor): A tensor of shape (batch_size, num_points, + 2) that contains the coordinates of ``num_points`` sampled + points. + """ + num_points = cfg.num_points + oversample_ratio = cfg.oversample_ratio + importance_sample_ratio = cfg.importance_sample_ratio + assert oversample_ratio >= 1 + assert 0 <= importance_sample_ratio <= 1 + batch_size = seg_logits.shape[0] + num_sampled = int(num_points * oversample_ratio) + point_coords = torch.rand( + batch_size, num_sampled, 2, device=seg_logits.device) + point_logits = point_sample(seg_logits, point_coords) + # It is crucial to calculate uncertainty based on the sampled + # prediction value for the points. Calculating uncertainties of the + # coarse predictions first and sampling them for points leads to + # incorrect results. To illustrate this: assume uncertainty func( + # logits)=-abs(logits), a sampled point between two coarse + # predictions with -1 and 1 logits has 0 logits, and therefore 0 + # uncertainty value. However, if we calculate uncertainties for the + # coarse predictions first, both will have -1 uncertainty, + # and sampled point will get -1 uncertainty. + point_uncertainties = uncertainty_func(point_logits) + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + idx = torch.topk( + point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_sampled * torch.arange( + batch_size, dtype=torch.long, device=seg_logits.device) + idx += shift[:, None] + point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( + batch_size, num_uncertain_points, 2) + if num_random_points > 0: + rand_point_coords = torch.rand( + batch_size, num_random_points, 2, device=seg_logits.device) + point_coords = torch.cat((point_coords, rand_point_coords), dim=1) + return point_coords + + def get_points_test(self, seg_logits, uncertainty_func, cfg): + """Sample points for testing. + + Find ``num_points`` most uncertain points from ``uncertainty_map``. + + Args: + seg_logits (Tensor): A tensor of shape (batch_size, num_classes, + height, width) for class-specific or class-agnostic prediction. + uncertainty_func (func): uncertainty calculation function. + cfg (dict): Testing config of point head. + + Returns: + point_indices (Tensor): A tensor of shape (batch_size, num_points) + that contains indices from [0, height x width) of the most + uncertain points. + point_coords (Tensor): A tensor of shape (batch_size, num_points, + 2) that contains [0, 1] x [0, 1] normalized coordinates of the + most uncertain points from the ``height x width`` grid . + """ + + num_points = cfg.subdivision_num_points + uncertainty_map = uncertainty_func(seg_logits) + batch_size, _, height, width = uncertainty_map.shape + h_step = 1.0 / height + w_step = 1.0 / width + + uncertainty_map = uncertainty_map.view(batch_size, height * width) + num_points = min(height * width, num_points) + point_indices = uncertainty_map.topk(num_points, dim=1)[1] + point_coords = torch.zeros( + batch_size, + num_points, + 2, + dtype=torch.float, + device=seg_logits.device) + point_coords[:, :, 0] = w_step / 2.0 + (point_indices % + width).float() * w_step + point_coords[:, :, 1] = h_step / 2.0 + (point_indices // + width).float() * h_step + return point_indices, point_coords diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/psa_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/psa_head.py new file mode 100644 index 0000000000000000000000000000000000000000..df7593cbcbd3c0de30d6a1cc7f5f1e0b8aa9e965 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/psa_head.py @@ -0,0 +1,197 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.ops import resize +from ..builder import HEADS +from .decode_head import BaseDecodeHead + +try: + from mmcv.ops import PSAMask +except ModuleNotFoundError: + PSAMask = None + + +@HEADS.register_module() +class PSAHead(BaseDecodeHead): + """Point-wise Spatial Attention Network for Scene Parsing. + + This head is the implementation of `PSANet + `_. + + Args: + mask_size (tuple[int]): The PSA mask size. It usually equals input + size. + psa_type (str): The type of psa module. Options are 'collect', + 'distribute', 'bi-direction'. Default: 'bi-direction' + compact (bool): Whether use compact map for 'collect' mode. + Default: True. + shrink_factor (int): The downsample factors of psa mask. Default: 2. + normalization_factor (float): The normalize factor of attention. + psa_softmax (bool): Whether use softmax for attention. + """ + + def __init__(self, + mask_size, + psa_type='bi-direction', + compact=False, + shrink_factor=2, + normalization_factor=1.0, + psa_softmax=True, + **kwargs): + if PSAMask is None: + raise RuntimeError('Please install mmcv-full for PSAMask ops') + super(PSAHead, self).__init__(**kwargs) + assert psa_type in ['collect', 'distribute', 'bi-direction'] + self.psa_type = psa_type + self.compact = compact + self.shrink_factor = shrink_factor + self.mask_size = mask_size + mask_h, mask_w = mask_size + self.psa_softmax = psa_softmax + if normalization_factor is None: + normalization_factor = mask_h * mask_w + self.normalization_factor = normalization_factor + + self.reduce = ConvModule( + self.in_channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.attention = nn.Sequential( + ConvModule( + self.channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + nn.Conv2d( + self.channels, mask_h * mask_w, kernel_size=1, bias=False)) + if psa_type == 'bi-direction': + self.reduce_p = ConvModule( + self.in_channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.attention_p = nn.Sequential( + ConvModule( + self.channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + nn.Conv2d( + self.channels, mask_h * mask_w, kernel_size=1, bias=False)) + self.psamask_collect = PSAMask('collect', mask_size) + self.psamask_distribute = PSAMask('distribute', mask_size) + else: + self.psamask = PSAMask(psa_type, mask_size) + self.proj = ConvModule( + self.channels * (2 if psa_type == 'bi-direction' else 1), + self.in_channels, + kernel_size=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.bottleneck = ConvModule( + self.in_channels * 2, + self.channels, + kernel_size=3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + identity = x + align_corners = self.align_corners + if self.psa_type in ['collect', 'distribute']: + out = self.reduce(x) + n, c, h, w = out.size() + if self.shrink_factor != 1: + if h % self.shrink_factor and w % self.shrink_factor: + h = (h - 1) // self.shrink_factor + 1 + w = (w - 1) // self.shrink_factor + 1 + align_corners = True + else: + h = h // self.shrink_factor + w = w // self.shrink_factor + align_corners = False + out = resize( + out, + size=(h, w), + mode='bilinear', + align_corners=align_corners) + y = self.attention(out) + if self.compact: + if self.psa_type == 'collect': + y = y.view(n, h * w, + h * w).transpose(1, 2).view(n, h * w, h, w) + else: + y = self.psamask(y) + if self.psa_softmax: + y = F.softmax(y, dim=1) + out = torch.bmm( + out.view(n, c, h * w), y.view(n, h * w, h * w)).view( + n, c, h, w) * (1.0 / self.normalization_factor) + else: + x_col = self.reduce(x) + x_dis = self.reduce_p(x) + n, c, h, w = x_col.size() + if self.shrink_factor != 1: + if h % self.shrink_factor and w % self.shrink_factor: + h = (h - 1) // self.shrink_factor + 1 + w = (w - 1) // self.shrink_factor + 1 + align_corners = True + else: + h = h // self.shrink_factor + w = w // self.shrink_factor + align_corners = False + x_col = resize( + x_col, + size=(h, w), + mode='bilinear', + align_corners=align_corners) + x_dis = resize( + x_dis, + size=(h, w), + mode='bilinear', + align_corners=align_corners) + y_col = self.attention(x_col) + y_dis = self.attention_p(x_dis) + if self.compact: + y_dis = y_dis.view(n, h * w, + h * w).transpose(1, 2).view(n, h * w, h, w) + else: + y_col = self.psamask_collect(y_col) + y_dis = self.psamask_distribute(y_dis) + if self.psa_softmax: + y_col = F.softmax(y_col, dim=1) + y_dis = F.softmax(y_dis, dim=1) + x_col = torch.bmm( + x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view( + n, c, h, w) * (1.0 / self.normalization_factor) + x_dis = torch.bmm( + x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view( + n, c, h, w) * (1.0 / self.normalization_factor) + out = torch.cat([x_col, x_dis], 1) + out = self.proj(out) + out = resize( + out, + size=identity.shape[2:], + mode='bilinear', + align_corners=align_corners) + out = self.bottleneck(torch.cat((identity, out), dim=1)) + out = self.cls_seg(out) + return out diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/psp_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/psp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..6990676ff24fbf7c6ff1f677efc665f778e9f13b --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/psp_head.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.ops import resize +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +class PPM(nn.ModuleList): + """Pooling Pyramid Module used in PSPNet. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict): Config of activation layers. + align_corners (bool): align_corners argument of F.interpolate. + """ + + def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, + act_cfg, align_corners, **kwargs): + super(PPM, self).__init__() + self.pool_scales = pool_scales + self.align_corners = align_corners + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + for pool_scale in pool_scales: + self.append( + nn.Sequential( + nn.AdaptiveAvgPool2d(pool_scale), + ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + **kwargs))) + + def forward(self, x): + """Forward function.""" + ppm_outs = [] + for ppm in self: + ppm_out = ppm(x) + upsampled_ppm_out = resize( + ppm_out, + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + ppm_outs.append(upsampled_ppm_out) + return ppm_outs + + +@HEADS.register_module() +class PSPHead(BaseDecodeHead): + """Pyramid Scene Parsing Network. + + This head is the implementation of + `PSPNet `_. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. Default: (1, 2, 3, 6). + """ + + def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): + super(PSPHead, self).__init__(**kwargs) + assert isinstance(pool_scales, (list, tuple)) + self.pool_scales = pool_scales + self.psp_modules = PPM( + self.pool_scales, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + self.bottleneck = ConvModule( + self.in_channels + len(pool_scales) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + x = self._transform_inputs(inputs) + psp_outs = [x] + psp_outs.extend(self.psp_modules(x)) + psp_outs = torch.cat(psp_outs, dim=1) + feats = self.bottleneck(psp_outs) + return feats + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.cls_seg(output) + return output diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/segformer_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/segformer_head.py new file mode 100644 index 0000000000000000000000000000000000000000..572ceac21566cc0a12365c45a8cbd7b857bdf6c7 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/segformer_head.py @@ -0,0 +1,142 @@ +# Modified from +# https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/decode_heads/segformer_head.py +# +# This work is licensed under the NVIDIA Source Code License. +# +# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. +# NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator +# Augmentation (ADA) +# +# 1. Definitions +# "Licensor" means any person or entity that distributes its Work. +# "Software" means the original work of authorship made available under +# this License. +# "Work" means the Software and any additions to or derivative works of +# the Software that are made available under this License. +# The terms "reproduce," "reproduction," "derivative works," and +# "distribution" have the meaning as provided under U.S. copyright law; +# provided, however, that for the purposes of this License, derivative +# works shall not include works that remain separable from, or merely +# link (or bind by name) to the interfaces of, the Work. +# Works, including the Software, are "made available" under this License +# by including in or with the Work either (a) a copyright notice +# referencing the applicability of this License to the Work, or (b) a +# copy of this License. +# 2. License Grants +# 2.1 Copyright Grant. Subject to the terms and conditions of this +# License, each Licensor grants to you a perpetual, worldwide, +# non-exclusive, royalty-free, copyright license to reproduce, +# prepare derivative works of, publicly display, publicly perform, +# sublicense and distribute its Work and any resulting derivative +# works in any form. +# 3. Limitations +# 3.1 Redistribution. You may reproduce or distribute the Work only +# if (a) you do so under this License, (b) you include a complete +# copy of this License with your distribution, and (c) you retain +# without modification any copyright, patent, trademark, or +# attribution notices that are present in the Work. +# 3.2 Derivative Works. You may specify that additional or different +# terms apply to the use, reproduction, and distribution of your +# derivative works of the Work ("Your Terms") only if (a) Your Terms +# provide that the use limitation in Section 3.3 applies to your +# derivative works, and (b) you identify the specific derivative +# works that are subject to Your Terms. Notwithstanding Your Terms, +# this License (including the redistribution requirements in Section +# 3.1) will continue to apply to the Work itself. +# 3.3 Use Limitation. The Work and any derivative works thereof only +# may be used or intended for use non-commercially. Notwithstanding +# the foregoing, NVIDIA and its affiliates may use the Work and any +# derivative works commercially. As used herein, "non-commercially" +# means for research or evaluation purposes only. +# 3.4 Patent Claims. If you bring or threaten to bring a patent claim +# against any Licensor (including any claim, cross-claim or +# counterclaim in a lawsuit) to enforce any patents that you allege +# are infringed by any Work, then your rights under this License from +# such Licensor (including the grant in Section 2.1) will terminate +# immediately. +# 3.5 Trademarks. This License does not grant any rights to use any +# Licensor’s or its affiliates’ names, logos, or trademarks, except +# as necessary to reproduce the notices described in this License. +# 3.6 Termination. If you violate any term of this License, then your +# rights under this License (including the grant in Section 2.1) will +# terminate immediately. +# 4. Disclaimer of Warranty. +# THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR +# NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER +# THIS LICENSE. +# 5. Limitation of Liability. +# EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL +# THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE +# SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, +# INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF +# OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK +# (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, +# LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER +# COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGES. + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.models.builder import HEADS +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.ops import resize + + +@HEADS.register_module() +class SegformerHead(BaseDecodeHead): + """The all mlp Head of segformer. + This head is the implementation of + `Segformer ` _. + Args: + interpolate_mode: The interpolate mode of MLP head upsample operation. + Default: 'bilinear'. + """ + + def __init__(self, interpolate_mode='bilinear', **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + + self.interpolate_mode = interpolate_mode + num_inputs = len(self.in_channels) + + assert num_inputs == len(self.in_index) + + self.convs = nn.ModuleList() + for i in range(num_inputs): + self.convs.append( + ConvModule( + in_channels=self.in_channels[i], + out_channels=self.channels, + kernel_size=1, + stride=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + self.fusion_conv = ConvModule( + in_channels=self.channels * num_inputs, + out_channels=self.channels, + kernel_size=1, + norm_cfg=self.norm_cfg) + + def forward(self, inputs): + # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32 + inputs = self._transform_inputs(inputs) + outs = [] + for idx in range(len(inputs)): + x = inputs[idx] + conv = self.convs[idx] + outs.append( + resize( + input=conv(x), + size=inputs[0].shape[2:], + mode=self.interpolate_mode, + align_corners=self.align_corners)) + + out = self.fusion_conv(torch.cat(outs, dim=1)) + + out = self.cls_seg(out) + + return out \ No newline at end of file diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/segmenter_mask_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/segmenter_mask_head.py new file mode 100644 index 0000000000000000000000000000000000000000..6a9b3d47ec5de0543e01aeb3df8092e6dbd5a7ea --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/segmenter_mask_head.py @@ -0,0 +1,133 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer +from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_, + trunc_normal_init) +from mmcv.runner import ModuleList + +from mmseg.models.backbones.vit import TransformerEncoderLayer +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +@HEADS.register_module() +class SegmenterMaskTransformerHead(BaseDecodeHead): + """Segmenter: Transformer for Semantic Segmentation. + + This head is the implementation of + `Segmenter: `_. + + Args: + backbone_cfg:(dict): Config of backbone of + Context Path. + in_channels (int): The number of channels of input image. + num_layers (int): The depth of transformer. + num_heads (int): The number of attention heads. + embed_dims (int): The number of embedding dimension. + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + drop_path_rate (float): stochastic depth rate. Default 0.1. + drop_rate (float): Probability of an element to be zeroed. + Default 0.0 + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): Enable bias for qkv if True. Default: True. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + init_std (float): The value of std in weight initialization. + Default: 0.02. + """ + + def __init__( + self, + in_channels, + num_layers, + num_heads, + embed_dims, + mlp_ratio=4, + drop_path_rate=0.1, + drop_rate=0.0, + attn_drop_rate=0.0, + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_std=0.02, + **kwargs, + ): + super(SegmenterMaskTransformerHead, self).__init__( + in_channels=in_channels, **kwargs) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] + self.layers = ModuleList() + for i in range(num_layers): + self.layers.append( + TransformerEncoderLayer( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=mlp_ratio * embed_dims, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + drop_path_rate=dpr[i], + num_fcs=num_fcs, + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + batch_first=True, + )) + + self.dec_proj = nn.Linear(in_channels, embed_dims) + + self.cls_emb = nn.Parameter( + torch.randn(1, self.num_classes, embed_dims)) + self.patch_proj = nn.Linear(embed_dims, embed_dims, bias=False) + self.classes_proj = nn.Linear(embed_dims, embed_dims, bias=False) + + self.decoder_norm = build_norm_layer( + norm_cfg, embed_dims, postfix=1)[1] + self.mask_norm = build_norm_layer( + norm_cfg, self.num_classes, postfix=2)[1] + + self.init_std = init_std + + delattr(self, 'conv_seg') + + def init_weights(self): + trunc_normal_(self.cls_emb, std=self.init_std) + trunc_normal_init(self.patch_proj, std=self.init_std) + trunc_normal_init(self.classes_proj, std=self.init_std) + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=self.init_std, bias=0) + elif isinstance(m, nn.LayerNorm): + constant_init(m, val=1.0, bias=0.0) + + def forward(self, inputs): + x = self._transform_inputs(inputs) + b, c, h, w = x.shape + x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c) + + x = self.dec_proj(x) + cls_emb = self.cls_emb.expand(x.size(0), -1, -1) + x = torch.cat((x, cls_emb), 1) + for layer in self.layers: + x = layer(x) + x = self.decoder_norm(x) + + patches = self.patch_proj(x[:, :-self.num_classes]) + cls_seg_feat = self.classes_proj(x[:, -self.num_classes:]) + + patches = F.normalize(patches, dim=2, p=2) + cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2) + + masks = patches @ cls_seg_feat.transpose(1, 2) + masks = self.mask_norm(masks) + masks = masks.permute(0, 2, 1).contiguous().view(b, -1, h, w) + + return masks diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/sep_aspp_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/sep_aspp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..4e894e28e3b192ce59d83ef09ec3105ed1436ae8 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/sep_aspp_head.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule + +from mmseg.ops import resize +from ..builder import HEADS +from .aspp_head import ASPPHead, ASPPModule + + +class DepthwiseSeparableASPPModule(ASPPModule): + """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable + conv.""" + + def __init__(self, **kwargs): + super(DepthwiseSeparableASPPModule, self).__init__(**kwargs) + for i, dilation in enumerate(self.dilations): + if dilation > 1: + self[i] = DepthwiseSeparableConvModule( + self.in_channels, + self.channels, + 3, + dilation=dilation, + padding=dilation, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + +@HEADS.register_module() +class DepthwiseSeparableASPPHead(ASPPHead): + """Encoder-Decoder with Atrous Separable Convolution for Semantic Image + Segmentation. + + This head is the implementation of `DeepLabV3+ + `_. + + Args: + c1_in_channels (int): The input channels of c1 decoder. If is 0, + the no decoder will be used. + c1_channels (int): The intermediate channels of c1 decoder. + """ + + def __init__(self, c1_in_channels, c1_channels, **kwargs): + super(DepthwiseSeparableASPPHead, self).__init__(**kwargs) + assert c1_in_channels >= 0 + self.aspp_modules = DepthwiseSeparableASPPModule( + dilations=self.dilations, + in_channels=self.in_channels, + channels=self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if c1_in_channels > 0: + self.c1_bottleneck = ConvModule( + c1_in_channels, + c1_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + else: + self.c1_bottleneck = None + self.sep_bottleneck = nn.Sequential( + DepthwiseSeparableConvModule( + self.channels + c1_channels, + self.channels, + 3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + DepthwiseSeparableConvModule( + self.channels, + self.channels, + 3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + aspp_outs = [ + resize( + self.image_pool(x), + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + ] + aspp_outs.extend(self.aspp_modules(x)) + aspp_outs = torch.cat(aspp_outs, dim=1) + output = self.bottleneck(aspp_outs) + if self.c1_bottleneck is not None: + c1_output = self.c1_bottleneck(inputs[0]) + output = resize( + input=output, + size=c1_output.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + output = torch.cat([output, c1_output], dim=1) + output = self.sep_bottleneck(output) + output = self.cls_seg(output) + return output diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/sep_fcn_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/sep_fcn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7f9658e08fbf3cd38e7c573acf3501e131d94198 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/sep_fcn_head.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import DepthwiseSeparableConvModule + +from ..builder import HEADS +from .fcn_head import FCNHead + + +@HEADS.register_module() +class DepthwiseSeparableFCNHead(FCNHead): + """Depthwise-Separable Fully Convolutional Network for Semantic + Segmentation. + + This head is implemented according to `Fast-SCNN: Fast Semantic + Segmentation Network `_. + + Args: + in_channels(int): Number of output channels of FFM. + channels(int): Number of middle-stage channels in the decode head. + concat_input(bool): Whether to concatenate original decode input into + the result of several consecutive convolution layers. + Default: True. + num_classes(int): Used to determine the dimension of + final prediction tensor. + in_index(int): Correspond with 'out_indices' in FastSCNN backbone. + norm_cfg (dict | None): Config of norm layers. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + loss_decode(dict): Config of loss type and some + relevant additional options. + dw_act_cfg (dict):Activation config of depthwise ConvModule. If it is + 'default', it will be the same as `act_cfg`. Default: None. + """ + + def __init__(self, dw_act_cfg=None, **kwargs): + super(DepthwiseSeparableFCNHead, self).__init__(**kwargs) + self.convs[0] = DepthwiseSeparableConvModule( + self.in_channels, + self.channels, + kernel_size=self.kernel_size, + padding=self.kernel_size // 2, + norm_cfg=self.norm_cfg, + dw_act_cfg=dw_act_cfg) + + for i in range(1, self.num_convs): + self.convs[i] = DepthwiseSeparableConvModule( + self.channels, + self.channels, + kernel_size=self.kernel_size, + padding=self.kernel_size // 2, + norm_cfg=self.norm_cfg, + dw_act_cfg=dw_act_cfg) + + if self.concat_input: + self.conv_cat = DepthwiseSeparableConvModule( + self.in_channels + self.channels, + self.channels, + kernel_size=self.kernel_size, + padding=self.kernel_size // 2, + norm_cfg=self.norm_cfg, + dw_act_cfg=dw_act_cfg) diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/setr_mla_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/setr_mla_head.py new file mode 100644 index 0000000000000000000000000000000000000000..6bb94ae33007aa29923680fd2e558c15320bec95 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/setr_mla_head.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.ops import Upsample +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +@HEADS.register_module() +class SETRMLAHead(BaseDecodeHead): + """Multi level feature aggretation head of SETR. + + MLA head of `SETR `_. + + Args: + mlahead_channels (int): Channels of conv-conv-4x of multi-level feature + aggregation. Default: 128. + up_scale (int): The scale factor of interpolate. Default:4. + """ + + def __init__(self, mla_channels=128, up_scale=4, **kwargs): + super(SETRMLAHead, self).__init__( + input_transform='multiple_select', **kwargs) + self.mla_channels = mla_channels + + num_inputs = len(self.in_channels) + + # Refer to self.cls_seg settings of BaseDecodeHead + assert self.channels == num_inputs * mla_channels + + self.up_convs = nn.ModuleList() + for i in range(num_inputs): + self.up_convs.append( + nn.Sequential( + ConvModule( + in_channels=self.in_channels[i], + out_channels=mla_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + ConvModule( + in_channels=mla_channels, + out_channels=mla_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + Upsample( + scale_factor=up_scale, + mode='bilinear', + align_corners=self.align_corners))) + + def forward(self, inputs): + inputs = self._transform_inputs(inputs) + outs = [] + for x, up_conv in zip(inputs, self.up_convs): + outs.append(up_conv(x)) + out = torch.cat(outs, dim=1) + out = self.cls_seg(out) + return out diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/setr_up_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/setr_up_head.py new file mode 100644 index 0000000000000000000000000000000000000000..87e7ea7faa5228e22847f5739055e17dc5415f32 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/setr_up_head.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule, build_norm_layer + +from mmseg.ops import Upsample +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +@HEADS.register_module() +class SETRUPHead(BaseDecodeHead): + """Naive upsampling head and Progressive upsampling head of SETR. + + Naive or PUP head of `SETR `_. + + Args: + norm_layer (dict): Config dict for input normalization. + Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True). + num_convs (int): Number of decoder convolutions. Default: 1. + up_scale (int): The scale factor of interpolate. Default:4. + kernel_size (int): The kernel size of convolution when decoding + feature information from backbone. Default: 3. + init_cfg (dict | list[dict] | None): Initialization config dict. + Default: dict( + type='Constant', val=1.0, bias=0, layer='LayerNorm'). + """ + + def __init__(self, + norm_layer=dict(type='LN', eps=1e-6, requires_grad=True), + num_convs=1, + up_scale=4, + kernel_size=3, + init_cfg=[ + dict(type='Constant', val=1.0, bias=0, layer='LayerNorm'), + dict( + type='Normal', + std=0.01, + override=dict(name='conv_seg')) + ], + **kwargs): + + assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.' + + super(SETRUPHead, self).__init__(init_cfg=init_cfg, **kwargs) + + assert isinstance(self.in_channels, int) + + _, self.norm = build_norm_layer(norm_layer, self.in_channels) + + self.up_convs = nn.ModuleList() + in_channels = self.in_channels + out_channels = self.channels + for _ in range(num_convs): + self.up_convs.append( + nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=int(kernel_size - 1) // 2, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + Upsample( + scale_factor=up_scale, + mode='bilinear', + align_corners=self.align_corners))) + in_channels = out_channels + + def forward(self, x): + x = self._transform_inputs(x) + + n, c, h, w = x.shape + x = x.reshape(n, c, h * w).transpose(2, 1).contiguous() + x = self.norm(x) + x = x.transpose(1, 2).reshape(n, c, h, w).contiguous() + + for up_conv in self.up_convs: + x = up_conv(x) + out = self.cls_seg(x) + return out diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/stdc_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/stdc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..bddf1eb47fd291448300ebac8af7fe4a98579c71 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/stdc_head.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F + +from ..builder import HEADS +from .fcn_head import FCNHead + + +@HEADS.register_module() +class STDCHead(FCNHead): + """This head is the implementation of `Rethinking BiSeNet For Real-time + Semantic Segmentation `_. + + Args: + boundary_threshold (float): The threshold of calculating boundary. + Default: 0.1. + """ + + def __init__(self, boundary_threshold=0.1, **kwargs): + super(STDCHead, self).__init__(**kwargs) + self.boundary_threshold = boundary_threshold + # Using register buffer to make laplacian kernel on the same + # device of `seg_label`. + self.register_buffer( + 'laplacian_kernel', + torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1], + dtype=torch.float32, + requires_grad=False).reshape((1, 1, 3, 3))) + self.fusion_kernel = torch.nn.Parameter( + torch.tensor([[6. / 10], [3. / 10], [1. / 10]], + dtype=torch.float32).reshape(1, 3, 1, 1), + requires_grad=False) + + def losses(self, seg_logit, seg_label): + """Compute Detail Aggregation Loss.""" + # Note: The paper claims `fusion_kernel` is a trainable 1x1 conv + # parameters. However, it is a constant in original repo and other + # codebase because it would not be added into computation graph + # after threshold operation. + seg_label = seg_label.to(self.laplacian_kernel) + boundary_targets = F.conv2d( + seg_label, self.laplacian_kernel, padding=1) + boundary_targets = boundary_targets.clamp(min=0) + boundary_targets[boundary_targets > self.boundary_threshold] = 1 + boundary_targets[boundary_targets <= self.boundary_threshold] = 0 + + boundary_targets_x2 = F.conv2d( + seg_label, self.laplacian_kernel, stride=2, padding=1) + boundary_targets_x2 = boundary_targets_x2.clamp(min=0) + + boundary_targets_x4 = F.conv2d( + seg_label, self.laplacian_kernel, stride=4, padding=1) + boundary_targets_x4 = boundary_targets_x4.clamp(min=0) + + boundary_targets_x4_up = F.interpolate( + boundary_targets_x4, boundary_targets.shape[2:], mode='nearest') + boundary_targets_x2_up = F.interpolate( + boundary_targets_x2, boundary_targets.shape[2:], mode='nearest') + + boundary_targets_x2_up[ + boundary_targets_x2_up > self.boundary_threshold] = 1 + boundary_targets_x2_up[ + boundary_targets_x2_up <= self.boundary_threshold] = 0 + + boundary_targets_x4_up[ + boundary_targets_x4_up > self.boundary_threshold] = 1 + boundary_targets_x4_up[ + boundary_targets_x4_up <= self.boundary_threshold] = 0 + + boundary_targets_pyramids = torch.stack( + (boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up), + dim=1) + + boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2) + boundary_targets_pyramid = F.conv2d(boundary_targets_pyramids, + self.fusion_kernel) + + boundary_targets_pyramid[ + boundary_targets_pyramid > self.boundary_threshold] = 1 + boundary_targets_pyramid[ + boundary_targets_pyramid <= self.boundary_threshold] = 0 + + loss = super(STDCHead, self).losses(seg_logit, + boundary_targets_pyramid.long()) + return loss diff --git a/data_utils/easyportrait/mmseg/models/decode_heads/uper_head.py b/data_utils/easyportrait/mmseg/models/decode_heads/uper_head.py new file mode 100644 index 0000000000000000000000000000000000000000..06b152a3f02d431cf7f7398dd996a1453a46100e --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/decode_heads/uper_head.py @@ -0,0 +1,140 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.ops import resize +from ..builder import HEADS +from .decode_head import BaseDecodeHead +from .psp_head import PPM + + +@HEADS.register_module() +class UPerHead(BaseDecodeHead): + """Unified Perceptual Parsing for Scene Understanding. + + This head is the implementation of `UPerNet + `_. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module applied on the last feature. Default: (1, 2, 3, 6). + """ + + def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): + super(UPerHead, self).__init__( + input_transform='multiple_select', **kwargs) + # PSP Module + self.psp_modules = PPM( + pool_scales, + self.in_channels[-1], + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + self.bottleneck = ConvModule( + self.in_channels[-1] + len(pool_scales) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + # FPN Module + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + for in_channels in self.in_channels[:-1]: # skip the top layer + l_conv = ConvModule( + in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + inplace=False) + fpn_conv = ConvModule( + self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + inplace=False) + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + self.fpn_bottleneck = ConvModule( + len(self.in_channels) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def psp_forward(self, inputs): + """Forward function of PSP module.""" + x = inputs[-1] + psp_outs = [x] + psp_outs.extend(self.psp_modules(x)) + psp_outs = torch.cat(psp_outs, dim=1) + output = self.bottleneck(psp_outs) + + return output + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + inputs = self._transform_inputs(inputs) + + # build laterals + laterals = [ + lateral_conv(inputs[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + laterals.append(self.psp_forward(inputs)) + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + resize( + laterals[i], + size=prev_shape, + mode='bilinear', + align_corners=self.align_corners) + + # build outputs + fpn_outs = [ + self.fpn_convs[i](laterals[i]) + for i in range(used_backbone_levels - 1) + ] + # append psp feature + fpn_outs.append(laterals[-1]) + + for i in range(used_backbone_levels - 1, 0, -1): + fpn_outs[i] = resize( + fpn_outs[i], + size=fpn_outs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) + fpn_outs = torch.cat(fpn_outs, dim=1) + feats = self.fpn_bottleneck(fpn_outs) + return feats + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.cls_seg(output) + return output diff --git a/data_utils/easyportrait/mmseg/models/losses/__init__.py b/data_utils/easyportrait/mmseg/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d7e019747d06e6b6cf746360b0d04861b4c6b7e0 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/losses/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .accuracy import Accuracy, accuracy +from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, + cross_entropy, mask_cross_entropy) +from .dice_loss import DiceLoss +from .focal_loss import FocalLoss +from .lovasz_loss import LovaszLoss +from .tversky_loss import TverskyLoss +from .utils import reduce_loss, weight_reduce_loss, weighted_loss + +__all__ = [ + 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', + 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', + 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss', + 'FocalLoss', 'TverskyLoss' +] diff --git a/data_utils/easyportrait/mmseg/models/losses/accuracy.py b/data_utils/easyportrait/mmseg/models/losses/accuracy.py new file mode 100644 index 0000000000000000000000000000000000000000..1d9e2d7701088adadd5b6bb71c718c986b87a066 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/losses/accuracy.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + + +def accuracy(pred, target, topk=1, thresh=None, ignore_index=None): + """Calculate accuracy according to the prediction and target. + + Args: + pred (torch.Tensor): The model prediction, shape (N, num_class, ...) + target (torch.Tensor): The target of each prediction, shape (N, , ...) + ignore_index (int | None): The label index to be ignored. Default: None + topk (int | tuple[int], optional): If the predictions in ``topk`` + matches the target, the predictions will be regarded as + correct ones. Defaults to 1. + thresh (float, optional): If not None, predictions with scores under + this threshold are considered incorrect. Default to None. + + Returns: + float | tuple[float]: If the input ``topk`` is a single integer, + the function will return a single float as accuracy. If + ``topk`` is a tuple containing multiple integers, the + function will return a tuple containing accuracies of + each ``topk`` number. + """ + assert isinstance(topk, (int, tuple)) + if isinstance(topk, int): + topk = (topk, ) + return_single = True + else: + return_single = False + + maxk = max(topk) + if pred.size(0) == 0: + accu = [pred.new_tensor(0.) for i in range(len(topk))] + return accu[0] if return_single else accu + assert pred.ndim == target.ndim + 1 + assert pred.size(0) == target.size(0) + assert maxk <= pred.size(1), \ + f'maxk {maxk} exceeds pred dimension {pred.size(1)}' + pred_value, pred_label = pred.topk(maxk, dim=1) + # transpose to shape (maxk, N, ...) + pred_label = pred_label.transpose(0, 1) + correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) + if thresh is not None: + # Only prediction values larger than thresh are counted as correct + correct = correct & (pred_value > thresh).t() + if ignore_index is not None: + correct = correct[:, target != ignore_index] + res = [] + eps = torch.finfo(torch.float32).eps + for k in topk: + # Avoid causing ZeroDivisionError when all pixels + # of an image are ignored + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps + if ignore_index is not None: + total_num = target[target != ignore_index].numel() + eps + else: + total_num = target.numel() + eps + res.append(correct_k.mul_(100.0 / total_num)) + return res[0] if return_single else res + + +class Accuracy(nn.Module): + """Accuracy calculation module.""" + + def __init__(self, topk=(1, ), thresh=None, ignore_index=None): + """Module to calculate the accuracy. + + Args: + topk (tuple, optional): The criterion used to calculate the + accuracy. Defaults to (1,). + thresh (float, optional): If not None, predictions with scores + under this threshold are considered incorrect. Default to None. + """ + super().__init__() + self.topk = topk + self.thresh = thresh + self.ignore_index = ignore_index + + def forward(self, pred, target): + """Forward function to calculate accuracy. + + Args: + pred (torch.Tensor): Prediction of models. + target (torch.Tensor): Target for each prediction. + + Returns: + tuple[float]: The accuracies under different topk criterions. + """ + return accuracy(pred, target, self.topk, self.thresh, + self.ignore_index) diff --git a/data_utils/easyportrait/mmseg/models/losses/cross_entropy_loss.py b/data_utils/easyportrait/mmseg/models/losses/cross_entropy_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..f62788d475567660793d486325a31cfa27e91572 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/losses/cross_entropy_loss.py @@ -0,0 +1,296 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..builder import LOSSES +from .utils import get_class_weight, weight_reduce_loss + + +def cross_entropy(pred, + label, + weight=None, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=-100, + avg_non_ignore=False): + """cross_entropy. The wrapper function for :func:`F.cross_entropy` + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + Default: None. + class_weight (list[float], optional): The weight for each class. + Default: None. + reduction (str, optional): The method used to reduce the loss. + Options are 'none', 'mean' and 'sum'. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Default: None. + ignore_index (int): Specifies a target value that is ignored and + does not contribute to the input gradients. When + ``avg_non_ignore `` is ``True``, and the ``reduction`` is + ``''mean''``, the loss is averaged over non-ignored targets. + Defaults: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + """ + + # class_weight is a manual rescaling weight given to each class. + # If given, has to be a Tensor of size C element-wise losses + loss = F.cross_entropy( + pred, + label, + weight=class_weight, + reduction='none', + ignore_index=ignore_index) + + # apply weights and do the reduction + # average loss over non-ignored elements + # pytorch's official cross_entropy average loss over non-ignored elements + # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa + if (avg_factor is None) and avg_non_ignore and reduction == 'mean': + avg_factor = label.numel() - (label == ignore_index).sum().item() + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index): + """Expand onehot labels to match the size of prediction.""" + bin_labels = labels.new_zeros(target_shape) + valid_mask = (labels >= 0) & (labels != ignore_index) + inds = torch.nonzero(valid_mask, as_tuple=True) + + if inds[0].numel() > 0: + if labels.dim() == 3: + bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1 + else: + bin_labels[inds[0], labels[valid_mask]] = 1 + + valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float() + + if label_weights is None: + bin_label_weights = valid_mask + else: + bin_label_weights = label_weights.unsqueeze(1).expand(target_shape) + bin_label_weights = bin_label_weights * valid_mask + + return bin_labels, bin_label_weights, valid_mask + + +def binary_cross_entropy(pred, + label, + weight=None, + reduction='mean', + avg_factor=None, + class_weight=None, + ignore_index=-100, + avg_non_ignore=False, + **kwargs): + """Calculate the binary CrossEntropy loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + Note: In bce loss, label < 0 is invalid. + weight (torch.Tensor, optional): Sample-wise loss weight. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (int): The label index to be ignored. Default: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + + Returns: + torch.Tensor: The calculated loss + """ + if pred.size(1) == 1: + # For binary class segmentation, the shape of pred is + # [N, 1, H, W] and that of label is [N, H, W]. + # As the ignore_index often set as 255, so the + # binary class label check should mask out + # ignore_index + assert label[label != ignore_index].max() <= 1, \ + 'For pred with shape [N, 1, H, W], its label must have at ' \ + 'most 2 classes' + pred = pred.squeeze(1) + if pred.dim() != label.dim(): + assert (pred.dim() == 2 and label.dim() == 1) or ( + pred.dim() == 4 and label.dim() == 3), \ + 'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \ + 'H, W], label shape [N, H, W] are supported' + # `weight` returned from `_expand_onehot_labels` + # has been treated for valid (non-ignore) pixels + label, weight, valid_mask = _expand_onehot_labels( + label, weight, pred.shape, ignore_index) + else: + # should mask out the ignored elements + valid_mask = ((label >= 0) & (label != ignore_index)).float() + if weight is not None: + weight = weight * valid_mask + else: + weight = valid_mask + # average loss over non-ignored and valid elements + if reduction == 'mean' and avg_factor is None and avg_non_ignore: + avg_factor = valid_mask.sum().item() + + loss = F.binary_cross_entropy_with_logits( + pred, label.float(), pos_weight=class_weight, reduction='none') + # do the reduction for the weighted loss + loss = weight_reduce_loss( + loss, weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def mask_cross_entropy(pred, + target, + label, + reduction='mean', + avg_factor=None, + class_weight=None, + ignore_index=None, + **kwargs): + """Calculate the CrossEntropy loss for masks. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + target (torch.Tensor): The learning label of the prediction. + label (torch.Tensor): ``label`` indicates the class label of the mask' + corresponding object. This will be used to select the mask in the + of the class which the object belongs to when the mask prediction + if not class-agnostic. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (None): Placeholder, to be consistent with other loss. + Default: None. + + Returns: + torch.Tensor: The calculated loss + """ + assert ignore_index is None, 'BCE loss does not support ignore_index' + # TODO: handle these two reserved arguments + assert reduction == 'mean' and avg_factor is None + num_rois = pred.size()[0] + inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) + pred_slice = pred[inds, label].squeeze(1) + return F.binary_cross_entropy_with_logits( + pred_slice, target, weight=class_weight, reduction='mean')[None] + + +@LOSSES.register_module() +class CrossEntropyLoss(nn.Module): + """CrossEntropyLoss. + + Args: + use_sigmoid (bool, optional): Whether the prediction uses sigmoid + instead of softmax. Defaults to False. + use_mask (bool, optional): Whether to use mask cross entropy loss. + Defaults to False. + reduction (str, optional): . Defaults to 'mean'. + Options are "none", "mean" and "sum". + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Defaults to 1.0. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_ce'. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + """ + + def __init__(self, + use_sigmoid=False, + use_mask=False, + reduction='mean', + class_weight=None, + loss_weight=1.0, + loss_name='loss_ce', + avg_non_ignore=False): + super(CrossEntropyLoss, self).__init__() + assert (use_sigmoid is False) or (use_mask is False) + self.use_sigmoid = use_sigmoid + self.use_mask = use_mask + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = get_class_weight(class_weight) + self.avg_non_ignore = avg_non_ignore + if not self.avg_non_ignore and self.reduction == 'mean': + warnings.warn( + 'Default ``avg_non_ignore`` is False, if you would like to ' + 'ignore the certain label and average loss over non-ignore ' + 'labels, which is the same with PyTorch official ' + 'cross_entropy, set ``avg_non_ignore=True``.') + + if self.use_sigmoid: + self.cls_criterion = binary_cross_entropy + elif self.use_mask: + self.cls_criterion = mask_cross_entropy + else: + self.cls_criterion = cross_entropy + self._loss_name = loss_name + + def extra_repr(self): + """Extra repr.""" + s = f'avg_non_ignore={self.avg_non_ignore}' + return s + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + ignore_index=-100, + **kwargs): + """Forward function.""" + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + # Note: for BCE loss, label < 0 is invalid. + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + weight, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + avg_non_ignore=self.avg_non_ignore, + ignore_index=ignore_index, + **kwargs) + return loss_cls + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/data_utils/easyportrait/mmseg/models/losses/dice_loss.py b/data_utils/easyportrait/mmseg/models/losses/dice_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a294bc2b1cca9b5b076b4dda55a5f8830cb5e89d --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/losses/dice_loss.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/ +segmentron/solver/loss.py (Apache-2.0 License)""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..builder import LOSSES +from .utils import get_class_weight, weighted_loss + + +@weighted_loss +def dice_loss(pred, + target, + valid_mask, + smooth=1, + exponent=2, + class_weight=None, + ignore_index=255): + assert pred.shape[0] == target.shape[0] + total_loss = 0 + num_classes = pred.shape[1] + for i in range(num_classes): + if i != ignore_index: + dice_loss = binary_dice_loss( + pred[:, i], + target[..., i], + valid_mask=valid_mask, + smooth=smooth, + exponent=exponent) + if class_weight is not None: + dice_loss *= class_weight[i] + total_loss += dice_loss + return total_loss / num_classes + + +@weighted_loss +def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwargs): + assert pred.shape[0] == target.shape[0] + pred = pred.reshape(pred.shape[0], -1) + target = target.reshape(target.shape[0], -1) + valid_mask = valid_mask.reshape(valid_mask.shape[0], -1) + + num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth + den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth + + return 1 - num / den + + +@LOSSES.register_module() +class DiceLoss(nn.Module): + """DiceLoss. + + This loss is proposed in `V-Net: Fully Convolutional Neural Networks for + Volumetric Medical Image Segmentation `_. + + Args: + smooth (float): A float number to smooth loss, and avoid NaN error. + Default: 1 + exponent (float): An float number to calculate denominator + value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Default to 1.0. + ignore_index (int | None): The label index to be ignored. Default: 255. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_dice'. + """ + + def __init__(self, + smooth=1, + exponent=2, + reduction='mean', + class_weight=None, + loss_weight=1.0, + ignore_index=255, + loss_name='loss_dice', + **kwargs): + super(DiceLoss, self).__init__() + self.smooth = smooth + self.exponent = exponent + self.reduction = reduction + self.class_weight = get_class_weight(class_weight) + self.loss_weight = loss_weight + self.ignore_index = ignore_index + self._loss_name = loss_name + + def forward(self, + pred, + target, + avg_factor=None, + reduction_override=None, + **kwargs): + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = pred.new_tensor(self.class_weight) + else: + class_weight = None + + pred = F.softmax(pred, dim=1) + num_classes = pred.shape[1] + one_hot_target = F.one_hot( + torch.clamp(target.long(), 0, num_classes - 1), + num_classes=num_classes) + valid_mask = (target != self.ignore_index).long() + + loss = self.loss_weight * dice_loss( + pred, + one_hot_target, + valid_mask=valid_mask, + reduction=reduction, + avg_factor=avg_factor, + smooth=self.smooth, + exponent=self.exponent, + class_weight=class_weight, + ignore_index=self.ignore_index) + return loss + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/data_utils/easyportrait/mmseg/models/losses/focal_loss.py b/data_utils/easyportrait/mmseg/models/losses/focal_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..cd43ce5a325d13a68a0f7e267993ae270847e758 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/losses/focal_loss.py @@ -0,0 +1,327 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/open-mmlab/mmdetection +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss + +from ..builder import LOSSES +from .utils import weight_reduce_loss + + +# This method is used when cuda is not available +def py_sigmoid_focal_loss(pred, + target, + one_hot_target=None, + weight=None, + gamma=2.0, + alpha=0.5, + class_weight=None, + valid_mask=None, + reduction='mean', + avg_factor=None): + """PyTorch version of `Focal Loss `_. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the + number of classes + target (torch.Tensor): The learning label of the prediction with + shape (N, C) + one_hot_target (None): Placeholder. It should be None. + weight (torch.Tensor, optional): Sample-wise loss weight. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float | list[float], optional): A balanced form for Focal Loss. + Defaults to 0.5. + class_weight (list[float], optional): Weight of each class. + Defaults to None. + valid_mask (torch.Tensor, optional): A mask uses 1 to mark the valid + samples and uses 0 to mark the ignored samples. Default: None. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + if isinstance(alpha, list): + alpha = pred.new_tensor(alpha) + pred_sigmoid = pred.sigmoid() + target = target.type_as(pred) + one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target) + focal_weight = (alpha * target + (1 - alpha) * + (1 - target)) * one_minus_pt.pow(gamma) + + loss = F.binary_cross_entropy_with_logits( + pred, target, reduction='none') * focal_weight + final_weight = torch.ones(1, pred.size(1)).type_as(loss) + if weight is not None: + if weight.shape != loss.shape and weight.size(0) == loss.size(0): + # For most cases, weight is of shape (N, ), + # which means it does not have the second axis num_class + weight = weight.view(-1, 1) + assert weight.dim() == loss.dim() + final_weight = final_weight * weight + if class_weight is not None: + final_weight = final_weight * pred.new_tensor(class_weight) + if valid_mask is not None: + final_weight = final_weight * valid_mask + loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor) + return loss + + +def sigmoid_focal_loss(pred, + target, + one_hot_target, + weight=None, + gamma=2.0, + alpha=0.5, + class_weight=None, + valid_mask=None, + reduction='mean', + avg_factor=None): + r"""A wrapper of cuda version `Focal Loss + `_. + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + target (torch.Tensor): The learning label of the prediction. It's shape + should be (N, ) + one_hot_target (torch.Tensor): The learning label with shape (N, C) + weight (torch.Tensor, optional): Sample-wise loss weight. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float | list[float], optional): A balanced form for Focal Loss. + Defaults to 0.5. + class_weight (list[float], optional): Weight of each class. + Defaults to None. + valid_mask (torch.Tensor, optional): A mask uses 1 to mark the valid + samples and uses 0 to mark the ignored samples. Default: None. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + # Function.apply does not accept keyword arguments, so the decorator + # "weighted_loss" is not applicable + final_weight = torch.ones(1, pred.size(1)).type_as(pred) + if isinstance(alpha, list): + # _sigmoid_focal_loss doesn't accept alpha of list type. Therefore, if + # a list is given, we set the input alpha as 0.5. This means setting + # equal weight for foreground class and background class. By + # multiplying the loss by 2, the effect of setting alpha as 0.5 is + # undone. The alpha of type list is used to regulate the loss in the + # post-processing process. + loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), + gamma, 0.5, None, 'none') * 2 + alpha = pred.new_tensor(alpha) + final_weight = final_weight * ( + alpha * one_hot_target + (1 - alpha) * (1 - one_hot_target)) + else: + loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), + gamma, alpha, None, 'none') + if weight is not None: + if weight.shape != loss.shape and weight.size(0) == loss.size(0): + # For most cases, weight is of shape (N, ), + # which means it does not have the second axis num_class + weight = weight.view(-1, 1) + assert weight.dim() == loss.dim() + final_weight = final_weight * weight + if class_weight is not None: + final_weight = final_weight * pred.new_tensor(class_weight) + if valid_mask is not None: + final_weight = final_weight * valid_mask + loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor) + return loss + + +@LOSSES.register_module() +class FocalLoss(nn.Module): + + def __init__(self, + use_sigmoid=True, + gamma=2.0, + alpha=0.5, + reduction='mean', + class_weight=None, + loss_weight=1.0, + loss_name='loss_focal'): + """`Focal Loss `_ + Args: + use_sigmoid (bool, optional): Whether to the prediction is + used for sigmoid or softmax. Defaults to True. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float | list[float], optional): A balanced form for Focal + Loss. Defaults to 0.5. When a list is provided, the length + of the list should be equal to the number of classes. + Please be careful that this parameter is not the + class-wise weight but the weight of a binary classification + problem. This binary classification problem regards the + pixels which belong to one class as the foreground + and the other pixels as the background, each element in + the list is the weight of the corresponding foreground class. + The value of alpha or each element of alpha should be a float + in the interval [0, 1]. If you want to specify the class-wise + weight, please use `class_weight` parameter. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. Options are "none", "mean" and + "sum". + class_weight (list[float], optional): Weight of each class. + Defaults to None. + loss_weight (float, optional): Weight of loss. Defaults to 1.0. + loss_name (str, optional): Name of the loss item. If you want this + loss item to be included into the backward graph, `loss_` must + be the prefix of the name. Defaults to 'loss_focal'. + """ + super(FocalLoss, self).__init__() + assert use_sigmoid is True, \ + 'AssertionError: Only sigmoid focal loss supported now.' + assert reduction in ('none', 'mean', 'sum'), \ + "AssertionError: reduction should be 'none', 'mean' or " \ + "'sum'" + assert isinstance(alpha, (float, list)), \ + 'AssertionError: alpha should be of type float' + assert isinstance(gamma, float), \ + 'AssertionError: gamma should be of type float' + assert isinstance(loss_weight, float), \ + 'AssertionError: loss_weight should be of type float' + assert isinstance(loss_name, str), \ + 'AssertionError: loss_name should be of type str' + assert isinstance(class_weight, list) or class_weight is None, \ + 'AssertionError: class_weight must be None or of type list' + self.use_sigmoid = use_sigmoid + self.gamma = gamma + self.alpha = alpha + self.reduction = reduction + self.class_weight = class_weight + self.loss_weight = loss_weight + self._loss_name = loss_name + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + ignore_index=255, + **kwargs): + """Forward function. + + Args: + pred (torch.Tensor): The prediction with shape + (N, C) where C = number of classes, or + (N, C, d_1, d_2, ..., d_K) with K≥1 in the + case of K-dimensional loss. + target (torch.Tensor): The ground truth. If containing class + indices, shape (N) where each value is 0≤targets[i]≤C−1, + or (N, d_1, d_2, ..., d_K) with K≥1 in the case of + K-dimensional loss. If containing class probabilities, + same shape as the input. + weight (torch.Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to + average the loss. Defaults to None. + reduction_override (str, optional): The reduction method used + to override the original reduction method of the loss. + Options are "none", "mean" and "sum". + ignore_index (int, optional): The label index to be ignored. + Default: 255 + Returns: + torch.Tensor: The calculated loss + """ + assert isinstance(ignore_index, int), \ + 'ignore_index must be of type int' + assert reduction_override in (None, 'none', 'mean', 'sum'), \ + "AssertionError: reduction should be 'none', 'mean' or " \ + "'sum'" + assert pred.shape == target.shape or \ + (pred.size(0) == target.size(0) and + pred.shape[2:] == target.shape[1:]), \ + "The shape of pred doesn't match the shape of target" + + original_shape = pred.shape + + # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k] + pred = pred.transpose(0, 1) + # [C, B, d_1, d_2, ..., d_k] -> [C, N] + pred = pred.reshape(pred.size(0), -1) + # [C, N] -> [N, C] + pred = pred.transpose(0, 1).contiguous() + + if original_shape == target.shape: + # target with shape [B, C, d_1, d_2, ...] + # transform it's shape into [N, C] + # [B, C, d_1, d_2, ...] -> [C, B, d_1, d_2, ..., d_k] + target = target.transpose(0, 1) + # [C, B, d_1, d_2, ..., d_k] -> [C, N] + target = target.reshape(target.size(0), -1) + # [C, N] -> [N, C] + target = target.transpose(0, 1).contiguous() + else: + # target with shape [B, d_1, d_2, ...] + # transform it's shape into [N, ] + target = target.view(-1).contiguous() + valid_mask = (target != ignore_index).view(-1, 1) + # avoid raising error when using F.one_hot() + target = torch.where(target == ignore_index, target.new_tensor(0), + target) + + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.use_sigmoid: + num_classes = pred.size(1) + if torch.cuda.is_available() and pred.is_cuda: + if target.dim() == 1: + one_hot_target = F.one_hot(target, num_classes=num_classes) + else: + one_hot_target = target + target = target.argmax(dim=1) + valid_mask = (target != ignore_index).view(-1, 1) + calculate_loss_func = sigmoid_focal_loss + else: + one_hot_target = None + if target.dim() == 1: + target = F.one_hot(target, num_classes=num_classes) + else: + valid_mask = (target.argmax(dim=1) != ignore_index).view( + -1, 1) + calculate_loss_func = py_sigmoid_focal_loss + + loss_cls = self.loss_weight * calculate_loss_func( + pred, + target, + one_hot_target, + weight, + gamma=self.gamma, + alpha=self.alpha, + class_weight=self.class_weight, + valid_mask=valid_mask, + reduction=reduction, + avg_factor=avg_factor) + + if reduction == 'none': + # [N, C] -> [C, N] + loss_cls = loss_cls.transpose(0, 1) + # [C, N] -> [C, B, d1, d2, ...] + # original_shape: [B, C, d1, d2, ...] + loss_cls = loss_cls.reshape(original_shape[1], + original_shape[0], + *original_shape[2:]) + # [C, B, d1, d2, ...] -> [B, C, d1, d2, ...] + loss_cls = loss_cls.transpose(0, 1).contiguous() + else: + raise NotImplementedError + return loss_cls + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/data_utils/easyportrait/mmseg/models/losses/lovasz_loss.py b/data_utils/easyportrait/mmseg/models/losses/lovasz_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..2bb0fad3931ea7d6140beca8b32a09379fbd7670 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/losses/lovasz_loss.py @@ -0,0 +1,323 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor +ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim +Berman 2018 ESAT-PSI KU Leuven (MIT License)""" + +import mmcv +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..builder import LOSSES +from .utils import get_class_weight, weight_reduce_loss + + +def lovasz_grad(gt_sorted): + """Computes gradient of the Lovasz extension w.r.t sorted errors. + + See Alg. 1 in paper. + """ + p = len(gt_sorted) + gts = gt_sorted.sum() + intersection = gts - gt_sorted.float().cumsum(0) + union = gts + (1 - gt_sorted).float().cumsum(0) + jaccard = 1. - intersection / union + if p > 1: # cover 1-pixel case + jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] + return jaccard + + +def flatten_binary_logits(logits, labels, ignore_index=None): + """Flattens predictions in the batch (binary case) Remove labels equal to + 'ignore_index'.""" + logits = logits.view(-1) + labels = labels.view(-1) + if ignore_index is None: + return logits, labels + valid = (labels != ignore_index) + vlogits = logits[valid] + vlabels = labels[valid] + return vlogits, vlabels + + +def flatten_probs(probs, labels, ignore_index=None): + """Flattens predictions in the batch.""" + if probs.dim() == 3: + # assumes output of a sigmoid layer + B, H, W = probs.size() + probs = probs.view(B, 1, H, W) + B, C, H, W = probs.size() + probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C) # B*H*W, C=P,C + labels = labels.view(-1) + if ignore_index is None: + return probs, labels + valid = (labels != ignore_index) + vprobs = probs[valid.nonzero().squeeze()] + vlabels = labels[valid] + return vprobs, vlabels + + +def lovasz_hinge_flat(logits, labels): + """Binary Lovasz hinge loss. + + Args: + logits (torch.Tensor): [P], logits at each prediction + (between -infty and +infty). + labels (torch.Tensor): [P], binary ground truth labels (0 or 1). + + Returns: + torch.Tensor: The calculated loss. + """ + if len(labels) == 0: + # only void pixels, the gradients should be 0 + return logits.sum() * 0. + signs = 2. * labels.float() - 1. + errors = (1. - logits * signs) + errors_sorted, perm = torch.sort(errors, dim=0, descending=True) + perm = perm.data + gt_sorted = labels[perm] + grad = lovasz_grad(gt_sorted) + loss = torch.dot(F.relu(errors_sorted), grad) + return loss + + +def lovasz_hinge(logits, + labels, + classes='present', + per_image=False, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=255): + """Binary Lovasz hinge loss. + + Args: + logits (torch.Tensor): [B, H, W], logits at each pixel + (between -infty and +infty). + labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1). + classes (str | list[int], optional): Placeholder, to be consistent with + other loss. Default: None. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + class_weight (list[float], optional): Placeholder, to be consistent + with other loss. Default: None. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. This parameter only works when per_image is True. + Default: None. + ignore_index (int | None): The label index to be ignored. Default: 255. + + Returns: + torch.Tensor: The calculated loss. + """ + if per_image: + loss = [ + lovasz_hinge_flat(*flatten_binary_logits( + logit.unsqueeze(0), label.unsqueeze(0), ignore_index)) + for logit, label in zip(logits, labels) + ] + loss = weight_reduce_loss( + torch.stack(loss), None, reduction, avg_factor) + else: + loss = lovasz_hinge_flat( + *flatten_binary_logits(logits, labels, ignore_index)) + return loss + + +def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None): + """Multi-class Lovasz-Softmax loss. + + Args: + probs (torch.Tensor): [P, C], class probabilities at each prediction + (between 0 and 1). + labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1). + classes (str | list[int], optional): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + class_weight (list[float], optional): The weight for each class. + Default: None. + + Returns: + torch.Tensor: The calculated loss. + """ + if probs.numel() == 0: + # only void pixels, the gradients should be 0 + return probs * 0. + C = probs.size(1) + losses = [] + class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes + for c in class_to_sum: + fg = (labels == c).float() # foreground for class c + if (classes == 'present' and fg.sum() == 0): + continue + if C == 1: + if len(classes) > 1: + raise ValueError('Sigmoid output possible only with 1 class') + class_pred = probs[:, 0] + else: + class_pred = probs[:, c] + errors = (fg - class_pred).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted)) + if class_weight is not None: + loss *= class_weight[c] + losses.append(loss) + return torch.stack(losses).mean() + + +def lovasz_softmax(probs, + labels, + classes='present', + per_image=False, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=255): + """Multi-class Lovasz-Softmax loss. + + Args: + probs (torch.Tensor): [B, C, H, W], class probabilities at each + prediction (between 0 and 1). + labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and + C - 1). + classes (str | list[int], optional): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + class_weight (list[float], optional): The weight for each class. + Default: None. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. This parameter only works when per_image is True. + Default: None. + ignore_index (int | None): The label index to be ignored. Default: 255. + + Returns: + torch.Tensor: The calculated loss. + """ + + if per_image: + loss = [ + lovasz_softmax_flat( + *flatten_probs( + prob.unsqueeze(0), label.unsqueeze(0), ignore_index), + classes=classes, + class_weight=class_weight) + for prob, label in zip(probs, labels) + ] + loss = weight_reduce_loss( + torch.stack(loss), None, reduction, avg_factor) + else: + loss = lovasz_softmax_flat( + *flatten_probs(probs, labels, ignore_index), + classes=classes, + class_weight=class_weight) + return loss + + +@LOSSES.register_module() +class LovaszLoss(nn.Module): + """LovaszLoss. + + This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate + for the optimization of the intersection-over-union measure in neural + networks `_. + + Args: + loss_type (str, optional): Binary or multi-class loss. + Default: 'multi_class'. Options are "binary" and "multi_class". + classes (str | list[int], optional): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Defaults to 1.0. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_lovasz'. + """ + + def __init__(self, + loss_type='multi_class', + classes='present', + per_image=False, + reduction='mean', + class_weight=None, + loss_weight=1.0, + loss_name='loss_lovasz'): + super(LovaszLoss, self).__init__() + assert loss_type in ('binary', 'multi_class'), "loss_type should be \ + 'binary' or 'multi_class'." + + if loss_type == 'binary': + self.cls_criterion = lovasz_hinge + else: + self.cls_criterion = lovasz_softmax + assert classes in ('all', 'present') or mmcv.is_list_of(classes, int) + if not per_image: + assert reduction == 'none', "reduction should be 'none' when \ + per_image is False." + + self.classes = classes + self.per_image = per_image + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = get_class_weight(class_weight) + self._loss_name = loss_name + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + """Forward function.""" + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + + # if multi-class loss, transform logits to probs + if self.cls_criterion == lovasz_softmax: + cls_score = F.softmax(cls_score, dim=1) + + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + self.classes, + self.per_image, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss_cls + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/data_utils/easyportrait/mmseg/models/losses/tversky_loss.py b/data_utils/easyportrait/mmseg/models/losses/tversky_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..4ad14f783c91ee0c88deb7e61afe0f542c3bb735 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/losses/tversky_loss.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Modified from +https://github.com/JunMa11/SegLoss/blob/master/losses_pytorch/dice_loss.py#L333 +(Apache-2.0 License)""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..builder import LOSSES +from .utils import get_class_weight, weighted_loss + + +@weighted_loss +def tversky_loss(pred, + target, + valid_mask, + alpha=0.3, + beta=0.7, + smooth=1, + class_weight=None, + ignore_index=255): + assert pred.shape[0] == target.shape[0] + total_loss = 0 + num_classes = pred.shape[1] + for i in range(num_classes): + if i != ignore_index: + tversky_loss = binary_tversky_loss( + pred[:, i], + target[..., i], + valid_mask=valid_mask, + alpha=alpha, + beta=beta, + smooth=smooth) + if class_weight is not None: + tversky_loss *= class_weight[i] + total_loss += tversky_loss + return total_loss / num_classes + + +@weighted_loss +def binary_tversky_loss(pred, + target, + valid_mask, + alpha=0.3, + beta=0.7, + smooth=1): + assert pred.shape[0] == target.shape[0] + pred = pred.reshape(pred.shape[0], -1) + target = target.reshape(target.shape[0], -1) + valid_mask = valid_mask.reshape(valid_mask.shape[0], -1) + + TP = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) + FP = torch.sum(torch.mul(pred, 1 - target) * valid_mask, dim=1) + FN = torch.sum(torch.mul(1 - pred, target) * valid_mask, dim=1) + tversky = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth) + + return 1 - tversky + + +@LOSSES.register_module() +class TverskyLoss(nn.Module): + """TverskyLoss. This loss is proposed in `Tversky loss function for image + segmentation using 3D fully convolutional deep networks. + + `_. + Args: + smooth (float): A float number to smooth loss, and avoid NaN error. + Default: 1. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Default to 1.0. + ignore_index (int | None): The label index to be ignored. Default: 255. + alpha(float, in [0, 1]): + The coefficient of false positives. Default: 0.3. + beta (float, in [0, 1]): + The coefficient of false negatives. Default: 0.7. + Note: alpha + beta = 1. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_tversky'. + """ + + def __init__(self, + smooth=1, + class_weight=None, + loss_weight=1.0, + ignore_index=255, + alpha=0.3, + beta=0.7, + loss_name='loss_tversky'): + super(TverskyLoss, self).__init__() + self.smooth = smooth + self.class_weight = get_class_weight(class_weight) + self.loss_weight = loss_weight + self.ignore_index = ignore_index + assert (alpha + beta == 1.0), 'Sum of alpha and beta but be 1.0!' + self.alpha = alpha + self.beta = beta + self._loss_name = loss_name + + def forward(self, pred, target, **kwargs): + if self.class_weight is not None: + class_weight = pred.new_tensor(self.class_weight) + else: + class_weight = None + + pred = F.softmax(pred, dim=1) + num_classes = pred.shape[1] + one_hot_target = F.one_hot( + torch.clamp(target.long(), 0, num_classes - 1), + num_classes=num_classes) + valid_mask = (target != self.ignore_index).long() + + loss = self.loss_weight * tversky_loss( + pred, + one_hot_target, + valid_mask=valid_mask, + alpha=self.alpha, + beta=self.beta, + smooth=self.smooth, + class_weight=class_weight, + ignore_index=self.ignore_index) + return loss + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/data_utils/easyportrait/mmseg/models/losses/utils.py b/data_utils/easyportrait/mmseg/models/losses/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..621f57c746c4de0abc3ad7a3c2ad35ef65c2fe32 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/losses/utils.py @@ -0,0 +1,126 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools + +import mmcv +import numpy as np +import torch +import torch.nn.functional as F + + +def get_class_weight(class_weight): + """Get class weight for loss function. + + Args: + class_weight (list[float] | str | None): If class_weight is a str, + take it as a file name and read from it. + """ + if isinstance(class_weight, str): + # take it as a file path + if class_weight.endswith('.npy'): + class_weight = np.load(class_weight) + else: + # pkl, json or yaml + class_weight = mmcv.load(class_weight) + + return class_weight + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are "none", "mean" and "sum". + + Return: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + elif reduction_enum == 2: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. + reduction (str): Same as built-in losses of PyTorch. + avg_factor (float): Average factor when computing the mean of losses. + + Returns: + Tensor: Processed loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + if weight.dim() > 1: + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + # Avoid causing ZeroDivisionError when avg_factor is 0.0, + # i.e., all labels of an image belong to ignore index. + eps = torch.finfo(torch.float32).eps + loss = loss.sum() / (avg_factor + eps) + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + avg_factor=None, **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, avg_factor=2) + tensor(1.5000) + """ + + @functools.wraps(loss_func) + def wrapper(pred, + target, + weight=None, + reduction='mean', + avg_factor=None, + **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + return wrapper diff --git a/data_utils/easyportrait/mmseg/models/necks/__init__.py b/data_utils/easyportrait/mmseg/models/necks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ff03186a92b78f942e79cff9eec9f5e2784c359a --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/necks/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .featurepyramid import Feature2Pyramid +from .fpn import FPN +from .ic_neck import ICNeck +from .jpu import JPU +from .mla_neck import MLANeck +from .multilevel_neck import MultiLevelNeck + +__all__ = [ + 'FPN', 'MultiLevelNeck', 'MLANeck', 'ICNeck', 'JPU', 'Feature2Pyramid' +] diff --git a/data_utils/easyportrait/mmseg/models/necks/featurepyramid.py b/data_utils/easyportrait/mmseg/models/necks/featurepyramid.py new file mode 100644 index 0000000000000000000000000000000000000000..82a00ceb1c4fa792538143f7af86b957822cce4d --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/necks/featurepyramid.py @@ -0,0 +1,67 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import build_norm_layer + +from ..builder import NECKS + + +@NECKS.register_module() +class Feature2Pyramid(nn.Module): + """Feature2Pyramid. + + A neck structure connect ViT backbone and decoder_heads. + + Args: + embed_dims (int): Embedding dimension. + rescales (list[float]): Different sampling multiples were + used to obtain pyramid features. Default: [4, 2, 1, 0.5]. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='SyncBN', requires_grad=True). + """ + + def __init__(self, + embed_dim, + rescales=[4, 2, 1, 0.5], + norm_cfg=dict(type='SyncBN', requires_grad=True)): + super(Feature2Pyramid, self).__init__() + self.rescales = rescales + self.upsample_4x = None + for k in self.rescales: + if k == 4: + self.upsample_4x = nn.Sequential( + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2), + build_norm_layer(norm_cfg, embed_dim)[1], + nn.GELU(), + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2), + ) + elif k == 2: + self.upsample_2x = nn.Sequential( + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2)) + elif k == 1: + self.identity = nn.Identity() + elif k == 0.5: + self.downsample_2x = nn.MaxPool2d(kernel_size=2, stride=2) + elif k == 0.25: + self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4) + else: + raise KeyError(f'invalid {k} for feature2pyramid') + + def forward(self, inputs): + assert len(inputs) == len(self.rescales) + outputs = [] + if self.upsample_4x is not None: + ops = [ + self.upsample_4x, self.upsample_2x, self.identity, + self.downsample_2x + ] + else: + ops = [ + self.upsample_2x, self.identity, self.downsample_2x, + self.downsample_4x + ] + for i in range(len(inputs)): + outputs.append(ops[i](inputs[i])) + return tuple(outputs) diff --git a/data_utils/easyportrait/mmseg/models/necks/fpn.py b/data_utils/easyportrait/mmseg/models/necks/fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..6997de9d4428a0625241b55c6d509afd8e2b515a --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/necks/fpn.py @@ -0,0 +1,213 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmcv.runner import BaseModule, auto_fp16 + +from mmseg.ops import resize +from ..builder import NECKS + + +@NECKS.register_module() +class FPN(BaseModule): + """Feature Pyramid Network. + + This neck is the implementation of `Feature Pyramid Networks for Object + Detection `_. + + Args: + in_channels (list[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + num_outs (int): Number of output scales. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + add_extra_convs (bool | str): If bool, it decides whether to add conv + layers on top of the original feature maps. Default to False. + If True, its actual mode is specified by `extra_convs_on_inputs`. + If str, it specifies the source feature map of the extra convs. + Only the following options are allowed + + - 'on_input': Last feat map of neck inputs (i.e. backbone feature). + - 'on_lateral': Last feature map after lateral convs. + - 'on_output': The last output feature map after fpn convs. + extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs + on the original feature from the backbone. If True, + it is equivalent to `add_extra_convs='on_input'`. If False, it is + equivalent to set `add_extra_convs='on_output'`. Default to True. + relu_before_extra_convs (bool): Whether to apply relu before the extra + conv. Default: False. + no_norm_on_lateral (bool): Whether to apply norm on lateral. + Default: False. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: None. + upsample_cfg (dict): Config dict for interpolate layer. + Default: dict(mode='nearest'). + init_cfg (dict or list[dict], optional): Initialization config dict. + + Example: + >>> import torch + >>> in_channels = [2, 3, 5, 7] + >>> scales = [340, 170, 84, 43] + >>> inputs = [torch.rand(1, c, s, s) + ... for c, s in zip(in_channels, scales)] + >>> self = FPN(in_channels, 11, len(in_channels)).eval() + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 11, 340, 340]) + outputs[1].shape = torch.Size([1, 11, 170, 170]) + outputs[2].shape = torch.Size([1, 11, 84, 84]) + outputs[3].shape = torch.Size([1, 11, 43, 43]) + """ + + def __init__(self, + in_channels, + out_channels, + num_outs, + start_level=0, + end_level=-1, + add_extra_convs=False, + extra_convs_on_inputs=False, + relu_before_extra_convs=False, + no_norm_on_lateral=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + upsample_cfg=dict(mode='nearest'), + init_cfg=dict( + type='Xavier', layer='Conv2d', distribution='uniform')): + super(FPN, self).__init__(init_cfg) + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.relu_before_extra_convs = relu_before_extra_convs + self.no_norm_on_lateral = no_norm_on_lateral + self.fp16_enabled = False + self.upsample_cfg = upsample_cfg.copy() + + if end_level == -1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level < inputs, no extra level is allowed + self.backbone_end_level = end_level + assert end_level <= len(in_channels) + assert num_outs == end_level - start_level + self.start_level = start_level + self.end_level = end_level + self.add_extra_convs = add_extra_convs + assert isinstance(add_extra_convs, (str, bool)) + if isinstance(add_extra_convs, str): + # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' + assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') + elif add_extra_convs: # True + if extra_convs_on_inputs: + # For compatibility with previous release + # TODO: deprecate `extra_convs_on_inputs` + self.add_extra_convs = 'on_input' + else: + self.add_extra_convs = 'on_output' + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, + act_cfg=act_cfg, + inplace=False) + fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + # add extra conv layers (e.g., RetinaNet) + extra_levels = num_outs - self.backbone_end_level + self.start_level + if self.add_extra_convs and extra_levels >= 1: + for i in range(extra_levels): + if i == 0 and self.add_extra_convs == 'on_input': + in_channels = self.in_channels[self.backbone_end_level - 1] + else: + in_channels = out_channels + extra_fpn_conv = ConvModule( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.fpn_convs.append(extra_fpn_conv) + + @auto_fp16() + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. + if 'scale_factor' in self.upsample_cfg: + laterals[i - 1] = laterals[i - 1] + resize( + laterals[i], **self.upsample_cfg) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + resize( + laterals[i], size=prev_shape, **self.upsample_cfg) + + # build outputs + # part 1: from original levels + outs = [ + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) + ] + # part 2: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + extra_source = inputs[self.backbone_end_level - 1] + elif self.add_extra_convs == 'on_lateral': + extra_source = laterals[-1] + elif self.add_extra_convs == 'on_output': + extra_source = outs[-1] + else: + raise NotImplementedError + outs.append(self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + outs.append(self.fpn_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.fpn_convs[i](outs[-1])) + return tuple(outs) diff --git a/data_utils/easyportrait/mmseg/models/necks/ic_neck.py b/data_utils/easyportrait/mmseg/models/necks/ic_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..a5d81cef8e16bd82f11c5ed35c04564a702b667d --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/necks/ic_neck.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmcv.runner import BaseModule + +from mmseg.ops import resize +from ..builder import NECKS + + +class CascadeFeatureFusion(BaseModule): + """Cascade Feature Fusion Unit in ICNet. + + Args: + low_channels (int): The number of input channels for + low resolution feature map. + high_channels (int): The number of input channels for + high resolution feature map. + out_channels (int): The number of output channels. + conv_cfg (dict): Dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN'). + act_cfg (dict): Dictionary to construct and config act layer. + Default: dict(type='ReLU'). + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Returns: + x (Tensor): The output tensor of shape (N, out_channels, H, W). + x_low (Tensor): The output tensor of shape (N, out_channels, H, W) + for Cascade Label Guidance in auxiliary heads. + """ + + def __init__(self, + low_channels, + high_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False, + init_cfg=None): + super(CascadeFeatureFusion, self).__init__(init_cfg=init_cfg) + self.align_corners = align_corners + self.conv_low = ConvModule( + low_channels, + out_channels, + 3, + padding=2, + dilation=2, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv_high = ConvModule( + high_channels, + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x_low, x_high): + x_low = resize( + x_low, + size=x_high.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + # Note: Different from original paper, `x_low` is underwent + # `self.conv_low` rather than another 1x1 conv classifier + # before being used for auxiliary head. + x_low = self.conv_low(x_low) + x_high = self.conv_high(x_high) + x = x_low + x_high + x = F.relu(x, inplace=True) + return x, x_low + + +@NECKS.register_module() +class ICNeck(BaseModule): + """ICNet for Real-Time Semantic Segmentation on High-Resolution Images. + + This head is the implementation of `ICHead + `_. + + Args: + in_channels (int): The number of input image channels. Default: 3. + out_channels (int): The numbers of output feature channels. + Default: 128. + conv_cfg (dict): Dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN'). + act_cfg (dict): Dictionary to construct and config act layer. + Default: dict(type='ReLU'). + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels=(64, 256, 256), + out_channels=128, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False, + init_cfg=None): + super(ICNeck, self).__init__(init_cfg=init_cfg) + assert len(in_channels) == 3, 'Length of input channels \ + must be 3!' + + self.in_channels = in_channels + self.out_channels = out_channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.align_corners = align_corners + self.cff_24 = CascadeFeatureFusion( + self.in_channels[2], + self.in_channels[1], + self.out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + + self.cff_12 = CascadeFeatureFusion( + self.out_channels, + self.in_channels[0], + self.out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + + def forward(self, inputs): + assert len(inputs) == 3, 'Length of input feature \ + maps must be 3!' + + x_sub1, x_sub2, x_sub4 = inputs + x_cff_24, x_24 = self.cff_24(x_sub4, x_sub2) + x_cff_12, x_12 = self.cff_12(x_cff_24, x_sub1) + # Note: `x_cff_12` is used for decode_head, + # `x_24` and `x_12` are used for auxiliary head. + return x_24, x_12, x_cff_12 diff --git a/data_utils/easyportrait/mmseg/models/necks/jpu.py b/data_utils/easyportrait/mmseg/models/necks/jpu.py new file mode 100644 index 0000000000000000000000000000000000000000..3cc6b9f428911a48a7fe3f1f2913812e45e4737e --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/necks/jpu.py @@ -0,0 +1,131 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule +from mmcv.runner import BaseModule + +from mmseg.ops import resize +from ..builder import NECKS + + +@NECKS.register_module() +class JPU(BaseModule): + """FastFCN: Rethinking Dilated Convolution in the Backbone + for Semantic Segmentation. + + This Joint Pyramid Upsampling (JPU) neck is the implementation of + `FastFCN `_. + + Args: + in_channels (Tuple[int], optional): The number of input channels + for each convolution operations before upsampling. + Default: (512, 1024, 2048). + mid_channels (int): The number of output channels of JPU. + Default: 512. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + dilations (tuple[int]): Dilation rate of each Depthwise + Separable ConvModule. Default: (1, 2, 4, 8). + align_corners (bool, optional): The align_corners argument of + resize operation. Default: False. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels=(512, 1024, 2048), + mid_channels=512, + start_level=0, + end_level=-1, + dilations=(1, 2, 4, 8), + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(JPU, self).__init__(init_cfg=init_cfg) + assert isinstance(in_channels, tuple) + assert isinstance(dilations, tuple) + self.in_channels = in_channels + self.mid_channels = mid_channels + self.start_level = start_level + self.num_ins = len(in_channels) + if end_level == -1: + self.backbone_end_level = self.num_ins + else: + self.backbone_end_level = end_level + assert end_level <= len(in_channels) + + self.dilations = dilations + self.align_corners = align_corners + + self.conv_layers = nn.ModuleList() + self.dilation_layers = nn.ModuleList() + for i in range(self.start_level, self.backbone_end_level): + conv_layer = nn.Sequential( + ConvModule( + self.in_channels[i], + self.mid_channels, + kernel_size=3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.conv_layers.append(conv_layer) + for i in range(len(dilations)): + dilation_layer = nn.Sequential( + DepthwiseSeparableConvModule( + in_channels=(self.backbone_end_level - self.start_level) * + self.mid_channels, + out_channels=self.mid_channels, + kernel_size=3, + stride=1, + padding=dilations[i], + dilation=dilations[i], + dw_norm_cfg=norm_cfg, + dw_act_cfg=None, + pw_norm_cfg=norm_cfg, + pw_act_cfg=act_cfg)) + self.dilation_layers.append(dilation_layer) + + def forward(self, inputs): + """Forward function.""" + assert len(inputs) == len(self.in_channels), 'Length of inputs must \ + be the same with self.in_channels!' + + feats = [ + self.conv_layers[i - self.start_level](inputs[i]) + for i in range(self.start_level, self.backbone_end_level) + ] + + h, w = feats[0].shape[2:] + for i in range(1, len(feats)): + feats[i] = resize( + feats[i], + size=(h, w), + mode='bilinear', + align_corners=self.align_corners) + + feat = torch.cat(feats, dim=1) + concat_feat = torch.cat([ + self.dilation_layers[i](feat) for i in range(len(self.dilations)) + ], + dim=1) + + outs = [] + + # Default: outs[2] is the output of JPU for decoder head, outs[1] is + # the feature map from backbone for auxiliary head. Additionally, + # outs[0] can also be used for auxiliary head. + for i in range(self.start_level, self.backbone_end_level - 1): + outs.append(inputs[i]) + outs.append(concat_feat) + return tuple(outs) diff --git a/data_utils/easyportrait/mmseg/models/necks/mla_neck.py b/data_utils/easyportrait/mmseg/models/necks/mla_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..1513e296daaedf83bea23ab2c168fb63482bed23 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/necks/mla_neck.py @@ -0,0 +1,118 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule, build_norm_layer + +from ..builder import NECKS + + +class MLAModule(nn.Module): + + def __init__(self, + in_channels=[1024, 1024, 1024, 1024], + out_channels=256, + norm_cfg=None, + act_cfg=None): + super(MLAModule, self).__init__() + self.channel_proj = nn.ModuleList() + for i in range(len(in_channels)): + self.channel_proj.append( + ConvModule( + in_channels=in_channels[i], + out_channels=out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.feat_extract = nn.ModuleList() + for i in range(len(in_channels)): + self.feat_extract.append( + ConvModule( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, inputs): + + # feat_list -> [p2, p3, p4, p5] + feat_list = [] + for x, conv in zip(inputs, self.channel_proj): + feat_list.append(conv(x)) + + # feat_list -> [p5, p4, p3, p2] + # mid_list -> [m5, m4, m3, m2] + feat_list = feat_list[::-1] + mid_list = [] + for feat in feat_list: + if len(mid_list) == 0: + mid_list.append(feat) + else: + mid_list.append(mid_list[-1] + feat) + + # mid_list -> [m5, m4, m3, m2] + # out_list -> [o2, o3, o4, o5] + out_list = [] + for mid, conv in zip(mid_list, self.feat_extract): + out_list.append(conv(mid)) + + return tuple(out_list) + + +@NECKS.register_module() +class MLANeck(nn.Module): + """Multi-level Feature Aggregation. + + This neck is `The Multi-level Feature Aggregation construction of + SETR `_. + + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + norm_layer (dict): Config dict for input normalization. + Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True). + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + norm_layer=dict(type='LN', eps=1e-6, requires_grad=True), + norm_cfg=None, + act_cfg=None): + super(MLANeck, self).__init__() + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + + # In order to build general vision transformer backbone, we have to + # move MLA to neck. + self.norm = nn.ModuleList([ + build_norm_layer(norm_layer, in_channels[i])[1] + for i in range(len(in_channels)) + ]) + + self.mla = MLAModule( + in_channels=in_channels, + out_channels=out_channels, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + + # Convert from nchw to nlc + outs = [] + for i in range(len(inputs)): + x = inputs[i] + n, c, h, w = x.shape + x = x.reshape(n, c, h * w).transpose(2, 1).contiguous() + x = self.norm[i](x) + x = x.transpose(1, 2).reshape(n, c, h, w).contiguous() + outs.append(x) + + outs = self.mla(outs) + return tuple(outs) diff --git a/data_utils/easyportrait/mmseg/models/necks/multilevel_neck.py b/data_utils/easyportrait/mmseg/models/necks/multilevel_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..5151f8762de46ae3e41da8f9683ee8df6f70711e --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/necks/multilevel_neck.py @@ -0,0 +1,78 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule, xavier_init + +from mmseg.ops import resize +from ..builder import NECKS + + +@NECKS.register_module() +class MultiLevelNeck(nn.Module): + """MultiLevelNeck. + + A neck structure connect vit backbone and decoder_heads. + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + scales (List[float]): Scale factors for each input feature map. + Default: [0.5, 1, 2, 4] + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + scales=[0.5, 1, 2, 4], + norm_cfg=None, + act_cfg=None): + super(MultiLevelNeck, self).__init__() + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.scales = scales + self.num_outs = len(scales) + self.lateral_convs = nn.ModuleList() + self.convs = nn.ModuleList() + for in_channel in in_channels: + self.lateral_convs.append( + ConvModule( + in_channel, + out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + for _ in range(self.num_outs): + self.convs.append( + ConvModule( + out_channels, + out_channels, + kernel_size=3, + padding=1, + stride=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + # default init_weights for conv(msra) and norm in ConvModule + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + xavier_init(m, distribution='uniform') + + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + inputs = [ + lateral_conv(inputs[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + # for len(inputs) not equal to self.num_outs + if len(inputs) == 1: + inputs = [inputs[0] for _ in range(self.num_outs)] + outs = [] + for i in range(self.num_outs): + x_resize = resize( + inputs[i], scale_factor=self.scales[i], mode='bilinear') + outs.append(self.convs[i](x_resize)) + return tuple(outs) diff --git a/data_utils/easyportrait/mmseg/models/segmentors/__init__.py b/data_utils/easyportrait/mmseg/models/segmentors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..387c858bd7a4e1e222db0fe99d85f4728ff48f21 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/segmentors/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import BaseSegmentor +from .cascade_encoder_decoder import CascadeEncoderDecoder +from .encoder_decoder import EncoderDecoder + +__all__ = ['BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder'] diff --git a/data_utils/easyportrait/mmseg/models/segmentors/base.py b/data_utils/easyportrait/mmseg/models/segmentors/base.py new file mode 100644 index 0000000000000000000000000000000000000000..76dc8f075a848d3df9e0d8d4f123ca458fb93aa4 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/segmentors/base.py @@ -0,0 +1,291 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from abc import ABCMeta, abstractmethod +from collections import OrderedDict + +import mmcv +import numpy as np +import torch +import torch.distributed as dist +from mmcv.runner import BaseModule, auto_fp16 + + +class BaseSegmentor(BaseModule, metaclass=ABCMeta): + """Base class for segmentors.""" + + def __init__(self, init_cfg=None): + super(BaseSegmentor, self).__init__(init_cfg) + self.fp16_enabled = False + + @property + def with_neck(self): + """bool: whether the segmentor has neck""" + return hasattr(self, 'neck') and self.neck is not None + + @property + def with_auxiliary_head(self): + """bool: whether the segmentor has auxiliary head""" + return hasattr(self, + 'auxiliary_head') and self.auxiliary_head is not None + + @property + def with_decode_head(self): + """bool: whether the segmentor has decode head""" + return hasattr(self, 'decode_head') and self.decode_head is not None + + @abstractmethod + def extract_feat(self, imgs): + """Placeholder for extract features from images.""" + pass + + @abstractmethod + def encode_decode(self, img, img_metas): + """Placeholder for encode images with backbone and decode into a + semantic segmentation map of the same size as input.""" + pass + + @abstractmethod + def forward_train(self, imgs, img_metas, **kwargs): + """Placeholder for Forward function for training.""" + pass + + @abstractmethod + def simple_test(self, img, img_meta, **kwargs): + """Placeholder for single image test.""" + pass + + @abstractmethod + def aug_test(self, imgs, img_metas, **kwargs): + """Placeholder for augmentation test.""" + pass + + def forward_test(self, imgs, img_metas, **kwargs): + """ + Args: + imgs (List[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains all images in the batch. + img_metas (List[List[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. + """ + for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]: + if not isinstance(var, list): + raise TypeError(f'{name} must be a list, but got ' + f'{type(var)}') + + num_augs = len(imgs) + if num_augs != len(img_metas): + raise ValueError(f'num of augmentations ({len(imgs)}) != ' + f'num of image meta ({len(img_metas)})') + # all images in the same aug batch all of the same ori_shape and pad + # shape + for img_meta in img_metas: + ori_shapes = [_['ori_shape'] for _ in img_meta] + assert all(shape == ori_shapes[0] for shape in ori_shapes) + img_shapes = [_['img_shape'] for _ in img_meta] + assert all(shape == img_shapes[0] for shape in img_shapes) + pad_shapes = [_['pad_shape'] for _ in img_meta] + assert all(shape == pad_shapes[0] for shape in pad_shapes) + + if num_augs == 1: + return self.simple_test(imgs[0], img_metas[0], **kwargs) + else: + return self.aug_test(imgs, img_metas, **kwargs) + + @auto_fp16(apply_to=('img', )) + def forward(self, img, img_metas, return_loss=True, **kwargs): + """Calls either :func:`forward_train` or :func:`forward_test` depending + on whether ``return_loss`` is ``True``. + + Note this setting will change the expected inputs. When + ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor + and List[dict]), and when ``resturn_loss=False``, img and img_meta + should be double nested (i.e. List[Tensor], List[List[dict]]), with + the outer list indicating test time augmentations. + """ + if return_loss: + return self.forward_train(img, img_metas, **kwargs) + else: + return self.forward_test(img, img_metas, **kwargs) + + def train_step(self, data_batch, optimizer, **kwargs): + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating is also defined in + this method, such as GAN. + + Args: + data (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. + + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, + ``num_samples``. + ``loss`` is a tensor for back propagation, which can be a + weighted sum of multiple losses. + ``log_vars`` contains all the variables to be sent to the + logger. + ``num_samples`` indicates the batch size (when the model is + DDP, it means the batch size on each GPU), which is used for + averaging the logs. + """ + losses = self(**data_batch) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, + log_vars=log_vars, + num_samples=len(data_batch['img_metas'])) + + return outputs + + def val_step(self, data_batch, optimizer=None, **kwargs): + """The iteration step during validation. + + This method shares the same signature as :func:`train_step`, but used + during val epochs. Note that the evaluation after training epochs is + not implemented with this method, but an evaluation hook. + """ + losses = self(**data_batch) + loss, log_vars = self._parse_losses(losses) + + log_vars_ = dict() + for loss_name, loss_value in log_vars.items(): + k = loss_name + '_val' + log_vars_[k] = loss_value + + outputs = dict( + loss=loss, + log_vars=log_vars_, + num_samples=len(data_batch['img_metas'])) + + return outputs + + @staticmethod + def _parse_losses(losses): + """Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw output of the network, which usually contain + losses and other necessary information. + + Returns: + tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor + which may be a weighted sum of all losses, log_vars contains + all the variables to be sent to the logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError( + f'{loss_name} is not a tensor or list of tensors') + + loss = sum(_value for _key, _value in log_vars.items() + if 'loss' in _key) + + # If the loss_vars has different length, raise assertion error + # to prevent GPUs from infinite waiting. + if dist.is_available() and dist.is_initialized(): + log_var_length = torch.tensor(len(log_vars), device=loss.device) + dist.all_reduce(log_var_length) + message = (f'rank {dist.get_rank()}' + + f' len(log_vars): {len(log_vars)}' + ' keys: ' + + ','.join(log_vars.keys()) + '\n') + assert log_var_length == len(log_vars) * dist.get_world_size(), \ + 'loss log variables are different across GPUs!\n' + message + + log_vars['loss'] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars + + def show_result(self, + img, + result, + palette=None, + win_name='', + show=False, + wait_time=0, + out_file=None, + opacity=0.5): + """Draw `result` over `img`. + + Args: + img (str or Tensor): The image to be displayed. + result (Tensor): The semantic segmentation results to draw over + `img`. + palette (list[list[int]]] | np.ndarray | None): The palette of + segmentation map. If None is given, random palette will be + generated. Default: None + win_name (str): The window name. + wait_time (int): Value of waitKey param. + Default: 0. + show (bool): Whether to show the image. + Default: False. + out_file (str or None): The filename to write the image. + Default: None. + opacity(float): Opacity of painted segmentation map. + Default 0.5. + Must be in (0, 1] range. + Returns: + img (Tensor): Only if not `show` or `out_file` + """ + img = mmcv.imread(img) + img = img.copy() + seg = result[0] + if palette is None: + if self.PALETTE is None: + # Get random state before set seed, + # and restore random state later. + # It will prevent loss of randomness, as the palette + # may be different in each iteration if not specified. + # See: https://github.com/open-mmlab/mmdetection/issues/5844 + state = np.random.get_state() + np.random.seed(42) + # random palette + palette = np.random.randint( + 0, 255, size=(len(self.CLASSES), 3)) + np.random.set_state(state) + else: + palette = self.PALETTE + palette = np.array(palette) + assert palette.shape[0] == len(self.CLASSES) + assert palette.shape[1] == 3 + assert len(palette.shape) == 2 + assert 0 < opacity <= 1.0 + color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) + for label, color in enumerate(palette): + color_seg[seg == label, :] = color + # convert to BGR + color_seg = color_seg[..., ::-1] + + img = img * (1 - opacity) + color_seg * opacity + img = img.astype(np.uint8) + # if out_file specified, do not show image in window + if out_file is not None: + show = False + + if show: + mmcv.imshow(img, win_name, wait_time) + if out_file is not None: + mmcv.imwrite(img, out_file) + + if not (show or out_file): + warnings.warn('show==False and out_file is not specified, only ' + 'result image will be returned') + return img diff --git a/data_utils/easyportrait/mmseg/models/segmentors/cascade_encoder_decoder.py b/data_utils/easyportrait/mmseg/models/segmentors/cascade_encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e9a9127a6d1d7b3aaecfa60d165dc60faaf2fe07 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/segmentors/cascade_encoder_decoder.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch import nn + +from mmseg.core import add_prefix +from mmseg.ops import resize +from .. import builder +from ..builder import SEGMENTORS +from .encoder_decoder import EncoderDecoder + + +@SEGMENTORS.register_module() +class CascadeEncoderDecoder(EncoderDecoder): + """Cascade Encoder Decoder segmentors. + + CascadeEncoderDecoder almost the same as EncoderDecoder, while decoders of + CascadeEncoderDecoder are cascaded. The output of previous decoder_head + will be the input of next decoder_head. + """ + + def __init__(self, + num_stages, + backbone, + decode_head, + neck=None, + auxiliary_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None): + self.num_stages = num_stages + super(CascadeEncoderDecoder, self).__init__( + backbone=backbone, + decode_head=decode_head, + neck=neck, + auxiliary_head=auxiliary_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + pretrained=pretrained, + init_cfg=init_cfg) + + def _init_decode_head(self, decode_head): + """Initialize ``decode_head``""" + assert isinstance(decode_head, list) + assert len(decode_head) == self.num_stages + self.decode_head = nn.ModuleList() + for i in range(self.num_stages): + self.decode_head.append(builder.build_head(decode_head[i])) + self.align_corners = self.decode_head[-1].align_corners + self.num_classes = self.decode_head[-1].num_classes + self.out_channels = self.decode_head[-1].out_channels + + def encode_decode(self, img, img_metas): + """Encode images with backbone and decode into a semantic segmentation + map of the same size as input.""" + x = self.extract_feat(img) + out = self.decode_head[0].forward_test(x, img_metas, self.test_cfg) + for i in range(1, self.num_stages): + out = self.decode_head[i].forward_test(x, out, img_metas, + self.test_cfg) + out = resize( + input=out, + size=img.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + + loss_decode = self.decode_head[0].forward_train( + x, img_metas, gt_semantic_seg, self.train_cfg) + + losses.update(add_prefix(loss_decode, 'decode_0')) + + for i in range(1, self.num_stages): + # forward test again, maybe unnecessary for most methods. + if i == 1: + prev_outputs = self.decode_head[0].forward_test( + x, img_metas, self.test_cfg) + else: + prev_outputs = self.decode_head[i - 1].forward_test( + x, prev_outputs, img_metas, self.test_cfg) + loss_decode = self.decode_head[i].forward_train( + x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg) + losses.update(add_prefix(loss_decode, f'decode_{i}')) + + return losses diff --git a/data_utils/easyportrait/mmseg/models/segmentors/encoder_decoder.py b/data_utils/easyportrait/mmseg/models/segmentors/encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e0ce8dfefaf6cd48aca290b41d623695935aaf99 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/segmentors/encoder_decoder.py @@ -0,0 +1,329 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmseg.core import add_prefix +from mmseg.ops import resize +from .. import builder +from ..builder import SEGMENTORS +from .base import BaseSegmentor + + +@SEGMENTORS.register_module() +class EncoderDecoder(BaseSegmentor): + """Encoder Decoder segmentors. + + EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. + Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. + """ + + def __init__(self, + backbone, + decode_head, + neck=None, + auxiliary_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None): + super(EncoderDecoder, self).__init__(init_cfg) + if pretrained is not None: + assert backbone.get('pretrained') is None, \ + 'both backbone and segmentor set pretrained weight' + backbone.pretrained = pretrained + self.backbone = builder.build_backbone(backbone) + if neck is not None: + self.neck = builder.build_neck(neck) + self._init_decode_head(decode_head) + self._init_auxiliary_head(auxiliary_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + assert self.with_decode_head + + def _init_decode_head(self, decode_head): + """Initialize ``decode_head``""" + self.decode_head = builder.build_head(decode_head) + self.align_corners = self.decode_head.align_corners + self.num_classes = self.decode_head.num_classes + self.out_channels = self.decode_head.out_channels + + def _init_auxiliary_head(self, auxiliary_head): + """Initialize ``auxiliary_head``""" + if auxiliary_head is not None: + if isinstance(auxiliary_head, list): + self.auxiliary_head = nn.ModuleList() + for head_cfg in auxiliary_head: + self.auxiliary_head.append(builder.build_head(head_cfg)) + else: + self.auxiliary_head = builder.build_head(auxiliary_head) + + def extract_feat(self, img): + """Extract features from images.""" + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def encode_decode(self, img, img_metas): + """Encode images with backbone and decode into a semantic segmentation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + out = resize( + input=out, + size=img.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.forward_train(x, img_metas, + gt_semantic_seg, + self.train_cfg) + + losses.update(add_prefix(loss_decode, 'decode')) + return losses + + def _decode_head_forward_test(self, x, img_metas): + """Run forward function and calculate loss for decode head in + inference.""" + seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg) + return seg_logits + + def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg): + """Run forward function and calculate loss for auxiliary head in + training.""" + losses = dict() + if isinstance(self.auxiliary_head, nn.ModuleList): + for idx, aux_head in enumerate(self.auxiliary_head): + loss_aux = aux_head.forward_train(x, img_metas, + gt_semantic_seg, + self.train_cfg) + losses.update(add_prefix(loss_aux, f'aux_{idx}')) + else: + loss_aux = self.auxiliary_head.forward_train( + x, img_metas, gt_semantic_seg, self.train_cfg) + losses.update(add_prefix(loss_aux, 'aux')) + + return losses + + def forward_dummy(self, img): + """Dummy forward function.""" + seg_logit = self.encode_decode(img, None) + + return seg_logit + + def forward_train(self, img, img_metas, gt_semantic_seg): + """Forward function for training. + + Args: + img (Tensor): Input images. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(img) + + losses = dict() + + loss_decode = self._decode_head_forward_train(x, img_metas, + gt_semantic_seg) + losses.update(loss_decode) + + if self.with_auxiliary_head: + loss_aux = self._auxiliary_head_forward_train( + x, img_metas, gt_semantic_seg) + losses.update(loss_aux) + + return losses + + # TODO refactor + def slide_inference(self, img, img_meta, rescale): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = img.size() + out_channels = self.out_channels + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, out_channels, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + crop_seg_logit = self.encode_decode(crop_img, img_meta) + preds += F.pad(crop_seg_logit, + (int(x1), int(preds.shape[3] - x2), int(y1), + int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy( + count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + if rescale: + # remove padding area + resize_shape = img_meta[0]['img_shape'][:2] + preds = preds[:, :, :resize_shape[0], :resize_shape[1]] + preds = resize( + preds, + size=img_meta[0]['ori_shape'][:2], + mode='bilinear', + align_corners=self.align_corners, + warning=False) + return preds + + def whole_inference(self, img, img_meta, rescale): + """Inference with full image.""" + + seg_logit = self.encode_decode(img, img_meta) + if rescale: + # support dynamic shape for onnx + if torch.onnx.is_in_onnx_export(): + size = img.shape[2:] + else: + # remove padding area + resize_shape = img_meta[0]['img_shape'][:2] + seg_logit = seg_logit[:, :, :resize_shape[0], :resize_shape[1]] + size = img_meta[0]['ori_shape'][:2] + seg_logit = resize( + seg_logit, + size=size, + mode='bilinear', + align_corners=self.align_corners, + warning=False) + + return seg_logit + + def inference(self, img, img_meta, rescale): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output segmentation map. + """ + + assert self.test_cfg.mode in ['slide', 'whole'] + ori_shape = img_meta[0]['ori_shape'] + assert all(_['ori_shape'] == ori_shape for _ in img_meta) + if self.test_cfg.mode == 'slide': + seg_logit = self.slide_inference(img, img_meta, rescale) + else: + seg_logit = self.whole_inference(img, img_meta, rescale) + if self.out_channels == 1: + output = F.sigmoid(seg_logit) + else: + output = F.softmax(seg_logit, dim=1) + flip = img_meta[0]['flip'] + if flip: + flip_direction = img_meta[0]['flip_direction'] + assert flip_direction in ['horizontal', 'vertical'] + if flip_direction == 'horizontal': + output = output.flip(dims=(3, )) + elif flip_direction == 'vertical': + output = output.flip(dims=(2, )) + + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + seg_logit = self.inference(img, img_meta, rescale) + if self.out_channels == 1: + seg_pred = (seg_logit > + self.decode_head.threshold).to(seg_logit).squeeze(1) + else: + seg_pred = seg_logit.argmax(dim=1) + if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + seg_pred = seg_pred.unsqueeze(0) + return seg_pred + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred + + def simple_test_logits(self, img, img_metas, rescale=True): + """Test without augmentations. + + Return numpy seg_map logits. + """ + seg_logit = self.inference(img[0], img_metas[0], rescale) + seg_logit = seg_logit.cpu().numpy() + return seg_logit + + def aug_test(self, imgs, img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented seg logit inplace + seg_logit = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale) + seg_logit += cur_seg_logit + seg_logit /= len(imgs) + if self.out_channels == 1: + seg_pred = (seg_logit > + self.decode_head.threshold).to(seg_logit).squeeze(1) + else: + seg_pred = seg_logit.argmax(dim=1) + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred + + def aug_test_logits(self, img, img_metas, rescale=True): + """Test with augmentations. + + Return seg_map logits. Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + + imgs = img + seg_logit = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale) + seg_logit += cur_seg_logit + + seg_logit /= len(imgs) + seg_logit = seg_logit.cpu().numpy() + return seg_logit diff --git a/data_utils/easyportrait/mmseg/models/utils/__init__.py b/data_utils/easyportrait/mmseg/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d8329021bc45dffb076751351bb0b44f56464f0 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/utils/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .embed import PatchEmbed +from .inverted_residual import InvertedResidual, InvertedResidualV3 +from .make_divisible import make_divisible +from .res_layer import ResLayer +from .se_layer import SELayer +from .self_attention_block import SelfAttentionBlock +from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc, + nlc_to_nchw) +from .up_conv_block import UpConvBlock + +__all__ = [ + 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', + 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed', + 'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc' +] diff --git a/data_utils/easyportrait/mmseg/models/utils/embed.py b/data_utils/easyportrait/mmseg/models/utils/embed.py new file mode 100644 index 0000000000000000000000000000000000000000..1515675e1eeec92ca1b6d417c5e202ce8aecc35f --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/utils/embed.py @@ -0,0 +1,330 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Sequence + +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmcv.runner.base_module import BaseModule +from mmcv.utils import to_2tuple + + +class AdaptivePadding(nn.Module): + """Applies padding to input (if needed) so that input can get fully covered + by filter you specified. It support two modes "same" and "corner". The + "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around + input. The "corner" mode would pad zero to bottom right. + + Args: + kernel_size (int | tuple): Size of the kernel: + stride (int | tuple): Stride of the filter. Default: 1: + dilation (int | tuple): Spacing between kernel elements. + Default: 1. + padding (str): Support "same" and "corner", "corner" mode + would pad zero to bottom right, and "same" mode would + pad zero around input. Default: "corner". + Example: + >>> kernel_size = 16 + >>> stride = 16 + >>> dilation = 1 + >>> input = torch.rand(1, 1, 15, 17) + >>> adap_pad = AdaptivePadding( + >>> kernel_size=kernel_size, + >>> stride=stride, + >>> dilation=dilation, + >>> padding="corner") + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + >>> input = torch.rand(1, 1, 16, 17) + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + """ + + def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'): + + super(AdaptivePadding, self).__init__() + + assert padding in ('same', 'corner') + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + self.padding = padding + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + + def get_pad_shape(self, input_shape): + input_h, input_w = input_shape + kernel_h, kernel_w = self.kernel_size + stride_h, stride_w = self.stride + output_h = math.ceil(input_h / stride_h) + output_w = math.ceil(input_w / stride_w) + pad_h = max((output_h - 1) * stride_h + + (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0) + pad_w = max((output_w - 1) * stride_w + + (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0) + return pad_h, pad_w + + def forward(self, x): + pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) + if pad_h > 0 or pad_w > 0: + if self.padding == 'corner': + x = F.pad(x, [0, pad_w, 0, pad_h]) + elif self.padding == 'same': + x = F.pad(x, [ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2 + ]) + return x + + +class PatchEmbed(BaseModule): + """Image to Patch Embedding. + + We use a conv layer to implement PatchEmbed. + + Args: + in_channels (int): The num of input channels. Default: 3 + embed_dims (int): The dimensions of embedding. Default: 768 + conv_type (str): The config dict for embedding + conv layer type selection. Default: "Conv2d". + kernel_size (int): The kernel_size of embedding conv. Default: 16. + stride (int, optional): The slide stride of embedding conv. + Default: None (Would be set as `kernel_size`). + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int): The dilation rate of embedding conv. Default: 1. + bias (bool): Bias of embed conv. Default: True. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: None. + input_size (int | tuple | None): The size of input, which will be + used to calculate the out size. Only work when `dynamic_size` + is False. Default: None. + init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, + in_channels=3, + embed_dims=768, + conv_type='Conv2d', + kernel_size=16, + stride=None, + padding='corner', + dilation=1, + bias=True, + norm_cfg=None, + input_size=None, + init_cfg=None): + super(PatchEmbed, self).__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + if stride is None: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of conv + padding = 0 + else: + self.adap_padding = None + padding = to_2tuple(padding) + + self.projection = build_conv_layer( + dict(type=conv_type), + in_channels=in_channels, + out_channels=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + else: + self.norm = None + + if input_size: + input_size = to_2tuple(input_size) + # `init_out_size` would be used outside to + # calculate the num_patches + # when `use_abs_pos_embed` outside + self.init_input_size = input_size + if self.adap_padding: + pad_h, pad_w = self.adap_padding.get_pad_shape(input_size) + input_h, input_w = input_size + input_h = input_h + pad_h + input_w = input_w + pad_w + input_size = (input_h, input_w) + + # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + h_out = (input_size[0] + 2 * padding[0] - dilation[0] * + (kernel_size[0] - 1) - 1) // stride[0] + 1 + w_out = (input_size[1] + 2 * padding[1] - dilation[1] * + (kernel_size[1] - 1) - 1) // stride[1] + 1 + self.init_out_size = (h_out, w_out) + else: + self.init_input_size = None + self.init_out_size = None + + def forward(self, x): + """ + Args: + x (Tensor): Has shape (B, C, H, W). In most case, C is 3. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, out_h * out_w, embed_dims) + - out_size (tuple[int]): Spatial shape of x, arrange as + (out_h, out_w). + """ + + if self.adap_padding: + x = self.adap_padding(x) + + x = self.projection(x) + out_size = (x.shape[2], x.shape[3]) + x = x.flatten(2).transpose(1, 2) + if self.norm is not None: + x = self.norm(x) + return x, out_size + + +class PatchMerging(BaseModule): + """Merge patch feature map. + + This layer groups feature map by kernel_size, and applies norm and linear + layers to the grouped feature map. Our implementation uses `nn.Unfold` to + merge patch, which is about 25% faster than original implementation. + Instead, we need to modify pretrained models for compatibility. + + Args: + in_channels (int): The num of input channels. + out_channels (int): The num of output channels. + kernel_size (int | tuple, optional): the kernel size in the unfold + layer. Defaults to 2. + stride (int | tuple, optional): the stride of the sliding blocks in the + unfold layer. Default: None. (Would be set as `kernel_size`) + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int | tuple, optional): dilation parameter in the unfold + layer. Default: 1. + bias (bool, optional): Whether to add bias in linear layer or not. + Defaults: False. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='LN'). + init_cfg (dict, optional): The extra config for initialization. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=2, + stride=None, + padding='corner', + dilation=1, + bias=False, + norm_cfg=dict(type='LN'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + if stride: + stride = stride + else: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of unfold + padding = 0 + else: + self.adap_padding = None + + padding = to_2tuple(padding) + self.sampler = nn.Unfold( + kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride) + + sample_dim = kernel_size[0] * kernel_size[1] * in_channels + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, sample_dim)[1] + else: + self.norm = None + + self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) + + def forward(self, x, input_size): + """ + Args: + x (Tensor): Has shape (B, H*W, C_in). + input_size (tuple[int]): The spatial shape of x, arrange as (H, W). + Default: None. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) + - out_size (tuple[int]): Spatial shape of x, arrange as + (Merged_H, Merged_W). + """ + B, L, C = x.shape + assert isinstance(input_size, Sequence), f'Expect ' \ + f'input_size is ' \ + f'`Sequence` ' \ + f'but get {input_size}' + + H, W = input_size + assert L == H * W, 'input feature has wrong size' + + x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W + # Use nn.Unfold to merge patch. About 25% faster than original method, + # but need to modify pretrained model for compatibility + + if self.adap_padding: + x = self.adap_padding(x) + H, W = x.shape[-2:] + + x = self.sampler(x) + # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) + + out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * + (self.sampler.kernel_size[0] - 1) - + 1) // self.sampler.stride[0] + 1 + out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * + (self.sampler.kernel_size[1] - 1) - + 1) // self.sampler.stride[1] + 1 + + output_size = (out_h, out_w) + x = x.transpose(1, 2) # B, H/2*W/2, 4*C + x = self.norm(x) if self.norm else x + x = self.reduction(x) + return x, output_size diff --git a/data_utils/easyportrait/mmseg/models/utils/inverted_residual.py b/data_utils/easyportrait/mmseg/models/utils/inverted_residual.py new file mode 100644 index 0000000000000000000000000000000000000000..c9cda76822a45baa0cbd27c98f1e7196e1e26e9a --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/utils/inverted_residual.py @@ -0,0 +1,213 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import ConvModule +from torch import nn +from torch.utils import checkpoint as cp + +from .se_layer import SELayer + + +class InvertedResidual(nn.Module): + """InvertedResidual block for MobileNetV2. + + Args: + in_channels (int): The input channels of the InvertedResidual block. + out_channels (int): The output channels of the InvertedResidual block. + stride (int): Stride of the middle (first) 3x3 convolution. + expand_ratio (int): Adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + dilation (int): Dilation rate of depthwise conv. Default: 1 + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + stride, + expand_ratio, + dilation=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + with_cp=False, + **kwargs): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2], f'stride must in [1, 2]. ' \ + f'But received {stride}.' + self.with_cp = with_cp + self.use_res_connect = self.stride == 1 and in_channels == out_channels + hidden_dim = int(round(in_channels * expand_ratio)) + + layers = [] + if expand_ratio != 1: + layers.append( + ConvModule( + in_channels=in_channels, + out_channels=hidden_dim, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs)) + layers.extend([ + ConvModule( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=hidden_dim, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs), + ConvModule( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None, + **kwargs) + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + + def _inner_forward(x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class InvertedResidualV3(nn.Module): + """Inverted Residual Block for MobileNetV3. + + Args: + in_channels (int): The input channels of this Module. + out_channels (int): The output channels of this Module. + mid_channels (int): The input channels of the depthwise convolution. + kernel_size (int): The kernel size of the depthwise convolution. + Default: 3. + stride (int): The stride of the depthwise convolution. Default: 1. + se_cfg (dict): Config dict for se layer. Default: None, which means no + se layer. + with_expand_conv (bool): Use expand conv or not. If set False, + mid_channels must be the same with in_channels. Default: True. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + mid_channels, + kernel_size=3, + stride=1, + se_cfg=None, + with_expand_conv=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False): + super(InvertedResidualV3, self).__init__() + self.with_res_shortcut = (stride == 1 and in_channels == out_channels) + assert stride in [1, 2] + self.with_cp = with_cp + self.with_se = se_cfg is not None + self.with_expand_conv = with_expand_conv + + if self.with_se: + assert isinstance(se_cfg, dict) + if not self.with_expand_conv: + assert mid_channels == in_channels + + if self.with_expand_conv: + self.expand_conv = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.depthwise_conv = ConvModule( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=mid_channels, + conv_cfg=dict( + type='Conv2dAdaptivePadding') if stride == 2 else conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + if self.with_se: + self.se = SELayer(**se_cfg) + + self.linear_conv = ConvModule( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, x): + + def _inner_forward(x): + out = x + + if self.with_expand_conv: + out = self.expand_conv(out) + + out = self.depthwise_conv(out) + + if self.with_se: + out = self.se(out) + + out = self.linear_conv(out) + + if self.with_res_shortcut: + return x + out + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out diff --git a/data_utils/easyportrait/mmseg/models/utils/make_divisible.py b/data_utils/easyportrait/mmseg/models/utils/make_divisible.py new file mode 100644 index 0000000000000000000000000000000000000000..ed42c2eeea2a6aed03a0be5516b8d1ef1139e486 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/utils/make_divisible.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +def make_divisible(value, divisor, min_value=None, min_ratio=0.9): + """Make divisible function. + + This function rounds the channel number to the nearest value that can be + divisible by the divisor. It is taken from the original tf repo. It ensures + that all layers have a channel number that is divisible by divisor. It can + be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa + + Args: + value (int): The original channel number. + divisor (int): The divisor to fully divide the channel number. + min_value (int): The minimum value of the output channel. + Default: None, means that the minimum value equal to the divisor. + min_ratio (float): The minimum ratio of the rounded channel number to + the original channel number. Default: 0.9. + + Returns: + int: The modified output channel number. + """ + + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than (1-min_ratio). + if new_value < min_ratio * value: + new_value += divisor + return new_value diff --git a/data_utils/easyportrait/mmseg/models/utils/res_layer.py b/data_utils/easyportrait/mmseg/models/utils/res_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..190a0c5d5a63846c9a42c6622635cf2456bd6635 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/utils/res_layer.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmcv.runner import Sequential +from torch import nn as nn + + +class ResLayer(Sequential): + """ResLayer to build ResNet style backbone. + + Args: + block (nn.Module): block used to build ResLayer. + inplanes (int): inplanes of block. + planes (int): planes of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + multi_grid (int | None): Multi grid dilation rates of last + stage. Default: None + contract_dilation (bool): Whether contract first dilation of each layer + Default: False + """ + + def __init__(self, + block, + inplanes, + planes, + num_blocks, + stride=1, + dilation=1, + avg_down=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + multi_grid=None, + contract_dilation=False, + **kwargs): + self.block = block + + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = [] + conv_stride = stride + if avg_down: + conv_stride = 1 + downsample.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + downsample.extend([ + build_conv_layer( + conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=conv_stride, + bias=False), + build_norm_layer(norm_cfg, planes * block.expansion)[1] + ]) + downsample = nn.Sequential(*downsample) + + layers = [] + if multi_grid is None: + if dilation > 1 and contract_dilation: + first_dilation = dilation // 2 + else: + first_dilation = dilation + else: + first_dilation = multi_grid[0] + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=stride, + dilation=first_dilation, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + inplanes = planes * block.expansion + for i in range(1, num_blocks): + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=1, + dilation=dilation if multi_grid is None else multi_grid[i], + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + super(ResLayer, self).__init__(*layers) diff --git a/data_utils/easyportrait/mmseg/models/utils/se_layer.py b/data_utils/easyportrait/mmseg/models/utils/se_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..16f52aa5c03b982c0d5f9ab9f145f48515a7ffa2 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/utils/se_layer.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import torch.nn as nn +from mmcv.cnn import ConvModule + +from .make_divisible import make_divisible + + +class SELayer(nn.Module): + """Squeeze-and-Excitation Module. + + Args: + channels (int): The input (and output) channels of the SE layer. + ratio (int): Squeeze ratio in SELayer, the intermediate channel will be + ``int(channels/ratio)``. Default: 16. + conv_cfg (None or dict): Config dict for convolution layer. + Default: None, which means using conv2d. + act_cfg (dict or Sequence[dict]): Config dict for activation layer. + If act_cfg is a dict, two activation layers will be configured + by this dict. If act_cfg is a sequence of dicts, the first + activation layer will be configured by the first dict and the + second activation layer will be configured by the second dict. + Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0, + divisor=6.0)). + """ + + def __init__(self, + channels, + ratio=16, + conv_cfg=None, + act_cfg=(dict(type='ReLU'), + dict(type='HSigmoid', bias=3.0, divisor=6.0))): + super(SELayer, self).__init__() + if isinstance(act_cfg, dict): + act_cfg = (act_cfg, act_cfg) + assert len(act_cfg) == 2 + assert mmcv.is_tuple_of(act_cfg, dict) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.conv1 = ConvModule( + in_channels=channels, + out_channels=make_divisible(channels // ratio, 8), + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[0]) + self.conv2 = ConvModule( + in_channels=make_divisible(channels // ratio, 8), + out_channels=channels, + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[1]) + + def forward(self, x): + out = self.global_avgpool(x) + out = self.conv1(out) + out = self.conv2(out) + return x * out diff --git a/data_utils/easyportrait/mmseg/models/utils/self_attention_block.py b/data_utils/easyportrait/mmseg/models/utils/self_attention_block.py new file mode 100644 index 0000000000000000000000000000000000000000..c945fa7168208fff513c4d397ad2c9a7ac4383ad --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/utils/self_attention_block.py @@ -0,0 +1,160 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcv.cnn import ConvModule, constant_init +from torch import nn as nn +from torch.nn import functional as F + + +class SelfAttentionBlock(nn.Module): + """General self-attention block/non-local block. + + Please refer to https://arxiv.org/abs/1706.03762 for details about key, + query and value. + + Args: + key_in_channels (int): Input channels of key feature. + query_in_channels (int): Input channels of query feature. + channels (int): Output channels of key/query transform. + out_channels (int): Output channels. + share_key_query (bool): Whether share projection weight between key + and query projection. + query_downsample (nn.Module): Query downsample module. + key_downsample (nn.Module): Key downsample module. + key_query_num_convs (int): Number of convs for key/query projection. + value_num_convs (int): Number of convs for value projection. + matmul_norm (bool): Whether normalize attention map with sqrt of + channels + with_out (bool): Whether use out projection. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict|None): Config of activation layers. + """ + + def __init__(self, key_in_channels, query_in_channels, channels, + out_channels, share_key_query, query_downsample, + key_downsample, key_query_num_convs, value_out_num_convs, + key_query_norm, value_out_norm, matmul_norm, with_out, + conv_cfg, norm_cfg, act_cfg): + super(SelfAttentionBlock, self).__init__() + if share_key_query: + assert key_in_channels == query_in_channels + self.key_in_channels = key_in_channels + self.query_in_channels = query_in_channels + self.out_channels = out_channels + self.channels = channels + self.share_key_query = share_key_query + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.key_project = self.build_project( + key_in_channels, + channels, + num_convs=key_query_num_convs, + use_conv_module=key_query_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if share_key_query: + self.query_project = self.key_project + else: + self.query_project = self.build_project( + query_in_channels, + channels, + num_convs=key_query_num_convs, + use_conv_module=key_query_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.value_project = self.build_project( + key_in_channels, + channels if with_out else out_channels, + num_convs=value_out_num_convs, + use_conv_module=value_out_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if with_out: + self.out_project = self.build_project( + channels, + out_channels, + num_convs=value_out_num_convs, + use_conv_module=value_out_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + else: + self.out_project = None + + self.query_downsample = query_downsample + self.key_downsample = key_downsample + self.matmul_norm = matmul_norm + + self.init_weights() + + def init_weights(self): + """Initialize weight of later layer.""" + if self.out_project is not None: + if not isinstance(self.out_project, ConvModule): + constant_init(self.out_project, 0) + + def build_project(self, in_channels, channels, num_convs, use_conv_module, + conv_cfg, norm_cfg, act_cfg): + """Build projection layer for key/query/value/out.""" + if use_conv_module: + convs = [ + ConvModule( + in_channels, + channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + ] + for _ in range(num_convs - 1): + convs.append( + ConvModule( + channels, + channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + else: + convs = [nn.Conv2d(in_channels, channels, 1)] + for _ in range(num_convs - 1): + convs.append(nn.Conv2d(channels, channels, 1)) + if len(convs) > 1: + convs = nn.Sequential(*convs) + else: + convs = convs[0] + return convs + + def forward(self, query_feats, key_feats): + """Forward function.""" + batch_size = query_feats.size(0) + query = self.query_project(query_feats) + if self.query_downsample is not None: + query = self.query_downsample(query) + query = query.reshape(*query.shape[:2], -1) + query = query.permute(0, 2, 1).contiguous() + + key = self.key_project(key_feats) + value = self.value_project(key_feats) + if self.key_downsample is not None: + key = self.key_downsample(key) + value = self.key_downsample(value) + key = key.reshape(*key.shape[:2], -1) + value = value.reshape(*value.shape[:2], -1) + value = value.permute(0, 2, 1).contiguous() + + sim_map = torch.matmul(query, key) + if self.matmul_norm: + sim_map = (self.channels**-.5) * sim_map + sim_map = F.softmax(sim_map, dim=-1) + + context = torch.matmul(sim_map, value) + context = context.permute(0, 2, 1).contiguous() + context = context.reshape(batch_size, -1, *query_feats.shape[2:]) + if self.out_project is not None: + context = self.out_project(context) + return context diff --git a/data_utils/easyportrait/mmseg/models/utils/shape_convert.py b/data_utils/easyportrait/mmseg/models/utils/shape_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..cce1e220b645d4b02df1ec2d9ed3137c8acba707 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/utils/shape_convert.py @@ -0,0 +1,107 @@ +# Copyright (c) OpenMMLab. All rights reserved. +def nlc_to_nchw(x, hw_shape): + """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. + + Args: + x (Tensor): The input tensor of shape [N, L, C] before conversion. + hw_shape (Sequence[int]): The height and width of output feature map. + + Returns: + Tensor: The output tensor of shape [N, C, H, W] after conversion. + """ + H, W = hw_shape + assert len(x.shape) == 3 + B, L, C = x.shape + assert L == H * W, 'The seq_len doesn\'t match H, W' + return x.transpose(1, 2).reshape(B, C, H, W) + + +def nchw_to_nlc(x): + """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. + + Args: + x (Tensor): The input tensor of shape [N, C, H, W] before conversion. + + Returns: + Tensor: The output tensor of shape [N, L, C] after conversion. + """ + assert len(x.shape) == 4 + return x.flatten(2).transpose(1, 2).contiguous() + + +def nchw2nlc2nchw(module, x, contiguous=False, **kwargs): + """Flatten [N, C, H, W] shape tensor `x` to [N, L, C] shape tensor. Use the + reshaped tensor as the input of `module`, and the convert the output of + `module`, whose shape is. + + [N, L, C], to [N, C, H, W]. + + Args: + module (Callable): A callable object the takes a tensor + with shape [N, L, C] as input. + x (Tensor): The input tensor of shape [N, C, H, W]. + contiguous: + contiguous (Bool): Whether to make the tensor contiguous + after each shape transform. + + Returns: + Tensor: The output tensor of shape [N, C, H, W]. + + Example: + >>> import torch + >>> import torch.nn as nn + >>> norm = nn.LayerNorm(4) + >>> feature_map = torch.rand(4, 4, 5, 5) + >>> output = nchw2nlc2nchw(norm, feature_map) + """ + B, C, H, W = x.shape + if not contiguous: + x = x.flatten(2).transpose(1, 2) + x = module(x, **kwargs) + x = x.transpose(1, 2).reshape(B, C, H, W) + else: + x = x.flatten(2).transpose(1, 2).contiguous() + x = module(x, **kwargs) + x = x.transpose(1, 2).reshape(B, C, H, W).contiguous() + return x + + +def nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs): + """Convert [N, L, C] shape tensor `x` to [N, C, H, W] shape tensor. Use the + reshaped tensor as the input of `module`, and convert the output of + `module`, whose shape is. + + [N, C, H, W], to [N, L, C]. + + Args: + module (Callable): A callable object the takes a tensor + with shape [N, C, H, W] as input. + x (Tensor): The input tensor of shape [N, L, C]. + hw_shape: (Sequence[int]): The height and width of the + feature map with shape [N, C, H, W]. + contiguous (Bool): Whether to make the tensor contiguous + after each shape transform. + + Returns: + Tensor: The output tensor of shape [N, L, C]. + + Example: + >>> import torch + >>> import torch.nn as nn + >>> conv = nn.Conv2d(16, 16, 3, 1, 1) + >>> feature_map = torch.rand(4, 25, 16) + >>> output = nlc2nchw2nlc(conv, feature_map, (5, 5)) + """ + H, W = hw_shape + assert len(x.shape) == 3 + B, L, C = x.shape + assert L == H * W, 'The seq_len doesn\'t match H, W' + if not contiguous: + x = x.transpose(1, 2).reshape(B, C, H, W) + x = module(x, **kwargs) + x = x.flatten(2).transpose(1, 2) + else: + x = x.transpose(1, 2).reshape(B, C, H, W).contiguous() + x = module(x, **kwargs) + x = x.flatten(2).transpose(1, 2).contiguous() + return x diff --git a/data_utils/easyportrait/mmseg/models/utils/up_conv_block.py b/data_utils/easyportrait/mmseg/models/utils/up_conv_block.py new file mode 100644 index 0000000000000000000000000000000000000000..d8396d9c2cc77135946ae4b43620e7b583915421 --- /dev/null +++ b/data_utils/easyportrait/mmseg/models/utils/up_conv_block.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, build_upsample_layer + + +class UpConvBlock(nn.Module): + """Upsample convolution block in decoder for UNet. + + This upsample convolution block consists of one upsample module + followed by one convolution block. The upsample module expands the + high-level low-resolution feature map and the convolution block fuses + the upsampled high-level low-resolution feature map and the low-level + high-resolution feature map from encoder. + + Args: + conv_block (nn.Sequential): Sequential of convolutional layers. + in_channels (int): Number of input channels of the high-level + skip_channels (int): Number of input channels of the low-level + high-resolution feature map from encoder. + out_channels (int): Number of output channels. + num_convs (int): Number of convolutional layers in the conv_block. + Default: 2. + stride (int): Stride of convolutional layer in conv_block. Default: 1. + dilation (int): Dilation rate of convolutional layer in conv_block. + Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + upsample_cfg (dict): The upsample config of the upsample module in + decoder. Default: dict(type='InterpConv'). If the size of + high-level feature map is the same as that of skip feature map + (low-level feature map from encoder), it does not need upsample the + high-level feature map and the upsample_cfg is None. + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + """ + + def __init__(self, + conv_block, + in_channels, + skip_channels, + out_channels, + num_convs=2, + stride=1, + dilation=1, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + dcn=None, + plugins=None): + super(UpConvBlock, self).__init__() + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.conv_block = conv_block( + in_channels=2 * skip_channels, + out_channels=out_channels, + num_convs=num_convs, + stride=stride, + dilation=dilation, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dcn=None, + plugins=None) + if upsample_cfg is not None: + self.upsample = build_upsample_layer( + cfg=upsample_cfg, + in_channels=in_channels, + out_channels=skip_channels, + with_cp=with_cp, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + else: + self.upsample = ConvModule( + in_channels, + skip_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, skip, x): + """Forward function.""" + + x = self.upsample(x) + out = torch.cat([skip, x], dim=1) + out = self.conv_block(out) + + return out diff --git a/data_utils/easyportrait/mmseg/ops/__init__.py b/data_utils/easyportrait/mmseg/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bc075cd4eb7af7a8d2ad146233d9d5973e7f036d --- /dev/null +++ b/data_utils/easyportrait/mmseg/ops/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .encoding import Encoding +from .wrappers import Upsample, resize + +__all__ = ['Upsample', 'resize', 'Encoding'] diff --git a/data_utils/easyportrait/mmseg/ops/encoding.py b/data_utils/easyportrait/mmseg/ops/encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..f397cc54e84feb6a9979c767d8602601d5613542 --- /dev/null +++ b/data_utils/easyportrait/mmseg/ops/encoding.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn +from torch.nn import functional as F + + +class Encoding(nn.Module): + """Encoding Layer: a learnable residual encoder. + + Input is of shape (batch_size, channels, height, width). + Output is of shape (batch_size, num_codes, channels). + + Args: + channels: dimension of the features or feature channels + num_codes: number of code words + """ + + def __init__(self, channels, num_codes): + super(Encoding, self).__init__() + # init codewords and smoothing factor + self.channels, self.num_codes = channels, num_codes + std = 1. / ((num_codes * channels)**0.5) + # [num_codes, channels] + self.codewords = nn.Parameter( + torch.empty(num_codes, channels, + dtype=torch.float).uniform_(-std, std), + requires_grad=True) + # [num_codes] + self.scale = nn.Parameter( + torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), + requires_grad=True) + + @staticmethod + def scaled_l2(x, codewords, scale): + num_codes, channels = codewords.size() + batch_size = x.size(0) + reshaped_scale = scale.view((1, 1, num_codes)) + expanded_x = x.unsqueeze(2).expand( + (batch_size, x.size(1), num_codes, channels)) + reshaped_codewords = codewords.view((1, 1, num_codes, channels)) + + scaled_l2_norm = reshaped_scale * ( + expanded_x - reshaped_codewords).pow(2).sum(dim=3) + return scaled_l2_norm + + @staticmethod + def aggregate(assignment_weights, x, codewords): + num_codes, channels = codewords.size() + reshaped_codewords = codewords.view((1, 1, num_codes, channels)) + batch_size = x.size(0) + + expanded_x = x.unsqueeze(2).expand( + (batch_size, x.size(1), num_codes, channels)) + encoded_feat = (assignment_weights.unsqueeze(3) * + (expanded_x - reshaped_codewords)).sum(dim=1) + return encoded_feat + + def forward(self, x): + assert x.dim() == 4 and x.size(1) == self.channels + # [batch_size, channels, height, width] + batch_size = x.size(0) + # [batch_size, height x width, channels] + x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous() + # assignment_weights: [batch_size, channels, num_codes] + assignment_weights = F.softmax( + self.scaled_l2(x, self.codewords, self.scale), dim=2) + # aggregate + encoded_feat = self.aggregate(assignment_weights, x, self.codewords) + return encoded_feat + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \ + f'x{self.channels})' + return repr_str diff --git a/data_utils/easyportrait/mmseg/ops/wrappers.py b/data_utils/easyportrait/mmseg/ops/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..bcababd51c1f3f61efb73d514cdb14ee66767cb4 --- /dev/null +++ b/data_utils/easyportrait/mmseg/ops/wrappers.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +import torch.nn.functional as F + + +def resize(input, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None, + warning=True): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > input_w: + if ((output_h > 1 and output_w > 1 and input_h > 1 + and input_w > 1) and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1)): + warnings.warn( + f'When align_corners={align_corners}, ' + 'the output would more aligned if ' + f'input size {(input_h, input_w)} is `x+1` and ' + f'out size {(output_h, output_w)} is `nx+1`') + return F.interpolate(input, size, scale_factor, mode, align_corners) + + +class Upsample(nn.Module): + + def __init__(self, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None): + super(Upsample, self).__init__() + self.size = size + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + if not self.size: + size = [int(t * self.scale_factor) for t in x.shape[-2:]] + else: + size = self.size + return resize(x, size, None, self.mode, self.align_corners) diff --git a/data_utils/easyportrait/mmseg/utils/__init__.py b/data_utils/easyportrait/mmseg/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3ef4b355c437ffbc230a07b3bbce34d5ddd745a --- /dev/null +++ b/data_utils/easyportrait/mmseg/utils/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .collect_env import collect_env +from .logger import get_root_logger +from .misc import find_latest_checkpoint +from .set_env import setup_multi_processes +from .util_distribution import build_ddp, build_dp, get_device + +__all__ = [ + 'get_root_logger', 'collect_env', 'find_latest_checkpoint', + 'setup_multi_processes', 'build_ddp', 'build_dp', 'get_device' +] diff --git a/data_utils/easyportrait/mmseg/utils/collect_env.py b/data_utils/easyportrait/mmseg/utils/collect_env.py new file mode 100644 index 0000000000000000000000000000000000000000..3379ecb06bb059f2694c7ce9e1f10fa6c4f938b9 --- /dev/null +++ b/data_utils/easyportrait/mmseg/utils/collect_env.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.utils import collect_env as collect_base_env +from mmcv.utils import get_git_hash + +import mmseg + + +def collect_env(): + """Collect the information of the running environments.""" + env_info = collect_base_env() + env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}' + + return env_info + + +if __name__ == '__main__': + for name, val in collect_env().items(): + print('{}: {}'.format(name, val)) diff --git a/data_utils/easyportrait/mmseg/utils/logger.py b/data_utils/easyportrait/mmseg/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..0cb3c78d6da0d6c850a4603a255ecdea93b15d2b --- /dev/null +++ b/data_utils/easyportrait/mmseg/utils/logger.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging + +from mmcv.utils import get_logger + + +def get_root_logger(log_file=None, log_level=logging.INFO): + """Get the root logger. + + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. The name of the root logger is the top-level package name, + e.g., "mmseg". + + Args: + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + + Returns: + logging.Logger: The root logger. + """ + + logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level) + + return logger diff --git a/data_utils/easyportrait/mmseg/utils/misc.py b/data_utils/easyportrait/mmseg/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..bd1b6b163c21ed8ef071a97f42b7fd8997e78215 --- /dev/null +++ b/data_utils/easyportrait/mmseg/utils/misc.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import glob +import os.path as osp +import warnings + + +def find_latest_checkpoint(path, suffix='pth'): + """This function is for finding the latest checkpoint. + + It will be used when automatically resume, modified from + https://github.com/open-mmlab/mmdetection/blob/dev-v2.20.0/mmdet/utils/misc.py + + Args: + path (str): The path to find checkpoints. + suffix (str): File extension for the checkpoint. Defaults to pth. + + Returns: + latest_path(str | None): File path of the latest checkpoint. + """ + if not osp.exists(path): + warnings.warn("The path of the checkpoints doesn't exist.") + return None + if osp.exists(osp.join(path, f'latest.{suffix}')): + return osp.join(path, f'latest.{suffix}') + + checkpoints = glob.glob(osp.join(path, f'*.{suffix}')) + if len(checkpoints) == 0: + warnings.warn('The are no checkpoints in the path') + return None + latest = -1 + latest_path = '' + for checkpoint in checkpoints: + if len(checkpoint) < len(latest_path): + continue + # `count` is iteration number, as checkpoints are saved as + # 'iter_xx.pth' or 'epoch_xx.pth' and xx is iteration number. + count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0]) + if count > latest: + latest = count + latest_path = checkpoint + return latest_path diff --git a/data_utils/easyportrait/mmseg/utils/set_env.py b/data_utils/easyportrait/mmseg/utils/set_env.py new file mode 100644 index 0000000000000000000000000000000000000000..bf184539918e05a26997b5239a361e5283e98443 --- /dev/null +++ b/data_utils/easyportrait/mmseg/utils/set_env.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import platform + +import cv2 +import torch.multiprocessing as mp + +from ..utils import get_root_logger + + +def setup_multi_processes(cfg): + """Setup multi-processing environment variables.""" + logger = get_root_logger() + + # set multi-process start method + if platform.system() != 'Windows': + mp_start_method = cfg.get('mp_start_method', None) + current_method = mp.get_start_method(allow_none=True) + if mp_start_method in ('fork', 'spawn', 'forkserver'): + logger.info( + f'Multi-processing start method `{mp_start_method}` is ' + f'different from the previous setting `{current_method}`.' + f'It will be force set to `{mp_start_method}`.') + mp.set_start_method(mp_start_method, force=True) + else: + logger.info( + f'Multi-processing start method is `{mp_start_method}`') + + # disable opencv multithreading to avoid system being overloaded + opencv_num_threads = cfg.get('opencv_num_threads', None) + if isinstance(opencv_num_threads, int): + logger.info(f'OpenCV num_threads is `{opencv_num_threads}`') + cv2.setNumThreads(opencv_num_threads) + else: + logger.info(f'OpenCV num_threads is `{cv2.getNumThreads()}') + + if cfg.data.workers_per_gpu > 1: + # setup OMP threads + # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa + omp_num_threads = cfg.get('omp_num_threads', None) + if 'OMP_NUM_THREADS' not in os.environ: + if isinstance(omp_num_threads, int): + logger.info(f'OMP num threads is {omp_num_threads}') + os.environ['OMP_NUM_THREADS'] = str(omp_num_threads) + else: + logger.info(f'OMP num threads is {os.environ["OMP_NUM_THREADS"] }') + + # setup MKL threads + if 'MKL_NUM_THREADS' not in os.environ: + mkl_num_threads = cfg.get('mkl_num_threads', None) + if isinstance(mkl_num_threads, int): + logger.info(f'MKL num threads is {mkl_num_threads}') + os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) + else: + logger.info(f'MKL num threads is {os.environ["MKL_NUM_THREADS"]}') diff --git a/data_utils/easyportrait/mmseg/utils/util_distribution.py b/data_utils/easyportrait/mmseg/utils/util_distribution.py new file mode 100644 index 0000000000000000000000000000000000000000..16651c225ab177991aea02a6f34cb3a054be89de --- /dev/null +++ b/data_utils/easyportrait/mmseg/utils/util_distribution.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import torch +from mmcv.parallel import MMDataParallel, MMDistributedDataParallel + +from mmseg import digit_version + +dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel} + +ddp_factory = {'cuda': MMDistributedDataParallel} + + +def build_dp(model, device='cuda', dim=0, *args, **kwargs): + """build DataParallel module by device type. + + if device is cuda, return a MMDataParallel module; if device is mlu, + return a MLUDataParallel module. + + Args: + model (:class:`nn.Module`): module to be parallelized. + device (str): device type, cuda, cpu or mlu. Defaults to cuda. + dim (int): Dimension used to scatter the data. Defaults to 0. + + Returns: + :class:`nn.Module`: parallelized module. + """ + if device == 'cuda': + model = model.cuda() + elif device == 'mlu': + assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \ + 'Please use MMCV >= 1.5.0 for MLU training!' + from mmcv.device.mlu import MLUDataParallel + dp_factory['mlu'] = MLUDataParallel + model = model.mlu() + + return dp_factory[device](model, dim=dim, *args, **kwargs) + + +def build_ddp(model, device='cuda', *args, **kwargs): + """Build DistributedDataParallel module by device type. + + If device is cuda, return a MMDistributedDataParallel module; + if device is mlu, return a MLUDistributedDataParallel module. + + Args: + model (:class:`nn.Module`): module to be parallelized. + device (str): device type, mlu or cuda. + + Returns: + :class:`nn.Module`: parallelized module. + + References: + .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel. + DistributedDataParallel.html + """ + assert device in ['cuda', 'mlu'], 'Only available for cuda or mlu devices.' + if device == 'cuda': + model = model.cuda() + elif device == 'mlu': + assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \ + 'Please use MMCV >= 1.5.0 for MLU training!' + from mmcv.device.mlu import MLUDistributedDataParallel + ddp_factory['mlu'] = MLUDistributedDataParallel + model = model.mlu() + + return ddp_factory[device](model, *args, **kwargs) + + +def is_mlu_available(): + """Returns a bool indicating if MLU is currently available.""" + return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available() + + +def get_device(): + """Returns an available device, cpu, cuda or mlu.""" + is_device_available = { + 'cuda': torch.cuda.is_available(), + 'mlu': is_mlu_available() + } + device_list = [k for k, v in is_device_available.items() if v] + return device_list[0] if len(device_list) == 1 else 'cpu' diff --git a/data_utils/easyportrait/mmseg/version.py b/data_utils/easyportrait/mmseg/version.py new file mode 100644 index 0000000000000000000000000000000000000000..c6f06467ba574f8f3b821c189b9eb99f0508726f --- /dev/null +++ b/data_utils/easyportrait/mmseg/version.py @@ -0,0 +1,18 @@ +# Copyright (c) Open-MMLab. All rights reserved. + +__version__ = '0.30.0' + + +def parse_version_info(version_str): + version_info = [] + for x in version_str.split('.'): + if x.isdigit(): + version_info.append(int(x)) + elif x.find('rc') != -1: + patch_version = x.split('rc') + version_info.append(int(patch_version[0])) + version_info.append(f'rc{patch_version[1]}') + return tuple(version_info) + + +version_info = parse_version_info(__version__) diff --git a/data_utils/face_parsing/logger.py b/data_utils/face_parsing/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..d3f9ddcc2cae221b4dd881d02404e848b5396f7e --- /dev/null +++ b/data_utils/face_parsing/logger.py @@ -0,0 +1,23 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + +import os.path as osp +import time +import sys +import logging + +import torch.distributed as dist + + +def setup_logger(logpth): + logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S')) + logfile = osp.join(logpth, logfile) + FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s' + log_level = logging.INFO + if dist.is_initialized() and not dist.get_rank()==0: + log_level = logging.ERROR + logging.basicConfig(level=log_level, format=FORMAT, filename=logfile) + logging.root.addHandler(logging.StreamHandler()) + + diff --git a/data_utils/face_parsing/model.py b/data_utils/face_parsing/model.py new file mode 100644 index 0000000000000000000000000000000000000000..c8b815170ca74a60213bfb7f3707d5201f9c3327 --- /dev/null +++ b/data_utils/face_parsing/model.py @@ -0,0 +1,285 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from resnet import Resnet18 +# from modules.bn import InPlaceABNSync as BatchNorm2d + + +class ConvBNReLU(nn.Module): + def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): + super(ConvBNReLU, self).__init__() + self.conv = nn.Conv2d(in_chan, + out_chan, + kernel_size = ks, + stride = stride, + padding = padding, + bias = False) + self.bn = nn.BatchNorm2d(out_chan) + self.init_weight() + + def forward(self, x): + x = self.conv(x) + x = F.relu(self.bn(x)) + return x + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + +class BiSeNetOutput(nn.Module): + def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): + super(BiSeNetOutput, self).__init__() + self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) + self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) + self.init_weight() + + def forward(self, x): + x = self.conv(x) + x = self.conv_out(x) + return x + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class AttentionRefinementModule(nn.Module): + def __init__(self, in_chan, out_chan, *args, **kwargs): + super(AttentionRefinementModule, self).__init__() + self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) + self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) + self.bn_atten = nn.BatchNorm2d(out_chan) + self.sigmoid_atten = nn.Sigmoid() + self.init_weight() + + def forward(self, x): + feat = self.conv(x) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv_atten(atten) + atten = self.bn_atten(atten) + atten = self.sigmoid_atten(atten) + out = torch.mul(feat, atten) + return out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + +class ContextPath(nn.Module): + def __init__(self, *args, **kwargs): + super(ContextPath, self).__init__() + self.resnet = Resnet18() + self.arm16 = AttentionRefinementModule(256, 128) + self.arm32 = AttentionRefinementModule(512, 128) + self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) + + self.init_weight() + + def forward(self, x): + H0, W0 = x.size()[2:] + feat8, feat16, feat32 = self.resnet(x) + H8, W8 = feat8.size()[2:] + H16, W16 = feat16.size()[2:] + H32, W32 = feat32.size()[2:] + + avg = F.avg_pool2d(feat32, feat32.size()[2:]) + avg = self.conv_avg(avg) + avg_up = F.interpolate(avg, (H32, W32), mode='nearest') + + feat32_arm = self.arm32(feat32) + feat32_sum = feat32_arm + avg_up + feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') + feat32_up = self.conv_head32(feat32_up) + + feat16_arm = self.arm16(feat16) + feat16_sum = feat16_arm + feat32_up + feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') + feat16_up = self.conv_head16(feat16_up) + + return feat8, feat16_up, feat32_up # x8, x8, x16 + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +### This is not used, since I replace this with the resnet feature with the same size +class SpatialPath(nn.Module): + def __init__(self, *args, **kwargs): + super(SpatialPath, self).__init__() + self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) + self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) + self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) + self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) + self.init_weight() + + def forward(self, x): + feat = self.conv1(x) + feat = self.conv2(feat) + feat = self.conv3(feat) + feat = self.conv_out(feat) + return feat + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class FeatureFusionModule(nn.Module): + def __init__(self, in_chan, out_chan, *args, **kwargs): + super(FeatureFusionModule, self).__init__() + self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + self.conv1 = nn.Conv2d(out_chan, + out_chan//4, + kernel_size = 1, + stride = 1, + padding = 0, + bias = False) + self.conv2 = nn.Conv2d(out_chan//4, + out_chan, + kernel_size = 1, + stride = 1, + padding = 0, + bias = False) + self.relu = nn.ReLU(inplace=True) + self.sigmoid = nn.Sigmoid() + self.init_weight() + + def forward(self, fsp, fcp): + fcat = torch.cat([fsp, fcp], dim=1) + feat = self.convblk(fcat) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv1(atten) + atten = self.relu(atten) + atten = self.conv2(atten) + atten = self.sigmoid(atten) + feat_atten = torch.mul(feat, atten) + feat_out = feat_atten + feat + return feat_out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class BiSeNet(nn.Module): + def __init__(self, n_classes, *args, **kwargs): + super(BiSeNet, self).__init__() + self.cp = ContextPath() + ## here self.sp is deleted + self.ffm = FeatureFusionModule(256, 256) + self.conv_out = BiSeNetOutput(256, 256, n_classes) + self.conv_out16 = BiSeNetOutput(128, 64, n_classes) + self.conv_out32 = BiSeNetOutput(128, 64, n_classes) + self.init_weight() + + def forward(self, x): + H, W = x.size()[2:] + feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature + feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature + feat_fuse = self.ffm(feat_sp, feat_cp8) + + feat_out = self.conv_out(feat_fuse) + feat_out16 = self.conv_out16(feat_cp8) + feat_out32 = self.conv_out32(feat_cp16) + + feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) + feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) + feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) + + # return feat_out, feat_out16, feat_out32 + return feat_out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] + for name, child in self.named_children(): + child_wd_params, child_nowd_params = child.get_params() + if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): + lr_mul_wd_params += child_wd_params + lr_mul_nowd_params += child_nowd_params + else: + wd_params += child_wd_params + nowd_params += child_nowd_params + return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params + + +if __name__ == "__main__": + net = BiSeNet(19) + net.cuda() + net.eval() + in_ten = torch.randn(16, 3, 640, 480).cuda() + out, out16, out32 = net(in_ten) + print(out.shape) + + net.get_params() diff --git a/data_utils/face_parsing/resnet.py b/data_utils/face_parsing/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..aa2bf95130e9815ba378cb6f73207068b81a04b9 --- /dev/null +++ b/data_utils/face_parsing/resnet.py @@ -0,0 +1,109 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as modelzoo + +# from modules.bn import InPlaceABNSync as BatchNorm2d + +resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, in_chan, out_chan, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(in_chan, out_chan, stride) + self.bn1 = nn.BatchNorm2d(out_chan) + self.conv2 = conv3x3(out_chan, out_chan) + self.bn2 = nn.BatchNorm2d(out_chan) + self.relu = nn.ReLU(inplace=True) + self.downsample = None + if in_chan != out_chan or stride != 1: + self.downsample = nn.Sequential( + nn.Conv2d(in_chan, out_chan, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(out_chan), + ) + + def forward(self, x): + residual = self.conv1(x) + residual = F.relu(self.bn1(residual)) + residual = self.conv2(residual) + residual = self.bn2(residual) + + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + out = shortcut + residual + out = self.relu(out) + return out + + +def create_layer_basic(in_chan, out_chan, bnum, stride=1): + layers = [BasicBlock(in_chan, out_chan, stride=stride)] + for i in range(bnum-1): + layers.append(BasicBlock(out_chan, out_chan, stride=1)) + return nn.Sequential(*layers) + + +class Resnet18(nn.Module): + def __init__(self): + super(Resnet18, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) + self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) + self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) + self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) + self.init_weight() + + def forward(self, x): + x = self.conv1(x) + x = F.relu(self.bn1(x)) + x = self.maxpool(x) + + x = self.layer1(x) + feat8 = self.layer2(x) # 1/8 + feat16 = self.layer3(feat8) # 1/16 + feat32 = self.layer4(feat16) # 1/32 + return feat8, feat16, feat32 + + def init_weight(self): + state_dict = modelzoo.load_url(resnet18_url) + self_state_dict = self.state_dict() + for k, v in state_dict.items(): + if 'fc' in k: continue + self_state_dict.update({k: v}) + self.load_state_dict(self_state_dict) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +if __name__ == "__main__": + net = Resnet18() + x = torch.randn(16, 3, 224, 224) + out = net(x) + print(out[0].size()) + print(out[1].size()) + print(out[2].size()) + net.get_params() diff --git a/data_utils/face_parsing/test.py b/data_utils/face_parsing/test.py new file mode 100644 index 0000000000000000000000000000000000000000..047aa8745224de1af3f9d9e7eb358fe5d4f377dc --- /dev/null +++ b/data_utils/face_parsing/test.py @@ -0,0 +1,105 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- +import numpy as np +from model import BiSeNet + +import torch + +import os +import os.path as osp + +from PIL import Image +import torchvision.transforms as transforms +import cv2 +from pathlib import Path +import configargparse +import tqdm + +# import ttach as tta + +def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg', + img_size=(512, 512)): + im = np.array(im) + vis_im = im.copy().astype(np.uint8) + vis_parsing_anno = parsing_anno.copy().astype(np.uint8) + vis_parsing_anno = cv2.resize( + vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) + vis_parsing_anno_color = np.zeros( + (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + np.array([255, 255, 255]) # + 255 + + num_of_class = np.max(vis_parsing_anno) + # print(num_of_class) + for pi in range(1, 14): + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0]) + + for pi in [11]: + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = np.array([100, 100, 100]) + + for pi in range(14, 16): + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 255, 0]) + for pi in range(16, 17): + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 0, 255]) + for pi in (17, 18): + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 0, 0]) + for pi in range(18, num_of_class+1): + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0]) + + vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) + index = np.where(vis_parsing_anno == num_of_class-1) + vis_im = cv2.resize(vis_parsing_anno_color, img_size, + interpolation=cv2.INTER_NEAREST) + if save_im: + cv2.imwrite(save_path, vis_im) + + +def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'): + + Path(respth).mkdir(parents=True, exist_ok=True) + + print(f'[INFO] loading model...') + n_classes = 19 + net = BiSeNet(n_classes=n_classes) + net.cuda() + net.load_state_dict(torch.load(cp)) + net.eval() + + to_tensor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + + image_paths = sorted(os.listdir(dspth)) + + with torch.no_grad(): + for image_path in tqdm.tqdm(image_paths): + if image_path.endswith('.jpg') or image_path.endswith('.png'): + img = Image.open(osp.join(dspth, image_path)) + ori_size = img.size + image = img.resize((512, 512), Image.BILINEAR) + image = image.convert("RGB") + img = to_tensor(image) + + # test-time augmentation. + inputs = torch.unsqueeze(img, 0) # [1, 3, 512, 512] + outputs = net(inputs.cuda()) + parsing = outputs.mean(0).cpu().numpy().argmax(0) + + image_path = int(image_path[:-4]) + image_path = str(image_path) + '.png' + + vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path), img_size=ori_size) + + +if __name__ == "__main__": + parser = configargparse.ArgumentParser() + parser.add_argument('--respath', type=str, default='./result/', help='result path for label') + parser.add_argument('--imgpath', type=str, default='./imgs/', help='path for input images') + parser.add_argument('--modelpath', type=str, default='data_utils/face_parsing/79999_iter.pth') + args = parser.parse_args() + evaluate(respth=args.respath, dspth=args.imgpath, cp=args.modelpath) diff --git a/data_utils/face_tracking/3DMM/.gitkeep b/data_utils/face_tracking/3DMM/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/data_utils/face_tracking/__init__.py b/data_utils/face_tracking/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/data_utils/face_tracking/convert_BFM.py b/data_utils/face_tracking/convert_BFM.py new file mode 100644 index 0000000000000000000000000000000000000000..486fe640744ca8c8b8c37caf32ee94e47117c8f3 --- /dev/null +++ b/data_utils/face_tracking/convert_BFM.py @@ -0,0 +1,39 @@ +import numpy as np +from scipy.io import loadmat + +original_BFM = loadmat("3DMM/01_MorphableModel.mat") +sub_inds = np.load("3DMM/topology_info.npy", allow_pickle=True).item()["sub_inds"] + +shapePC = original_BFM["shapePC"] +shapeEV = original_BFM["shapeEV"] +shapeMU = original_BFM["shapeMU"] +texPC = original_BFM["texPC"] +texEV = original_BFM["texEV"] +texMU = original_BFM["texMU"] + +b_shape = shapePC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3) +mu_shape = shapeMU.reshape(-1, 3) + +b_tex = texPC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3) +mu_tex = texMU.reshape(-1, 3) + +b_shape = b_shape[:, sub_inds, :].reshape(199, -1) +mu_shape = mu_shape[sub_inds, :].reshape(-1) +b_tex = b_tex[:, sub_inds, :].reshape(199, -1) +mu_tex = mu_tex[sub_inds, :].reshape(-1) + +exp_info = np.load("3DMM/exp_info.npy", allow_pickle=True).item() +np.save( + "3DMM/3DMM_info.npy", + { + "mu_shape": mu_shape, + "b_shape": b_shape, + "sig_shape": shapeEV.reshape(-1), + "mu_exp": exp_info["mu_exp"], + "b_exp": exp_info["base_exp"], + "sig_exp": exp_info["sig_exp"], + "mu_tex": mu_tex, + "b_tex": b_tex, + "sig_tex": texEV.reshape(-1), + }, +) diff --git a/data_utils/face_tracking/data_loader.py b/data_utils/face_tracking/data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..61e330f1af24ec3b26e85336c386767a16e079d5 --- /dev/null +++ b/data_utils/face_tracking/data_loader.py @@ -0,0 +1,16 @@ +import os +import torch +import numpy as np + + +def load_dir(path, start, end): + lmss = [] + imgs_paths = [] + for i in range(start, end): + if os.path.isfile(os.path.join(path, str(i) + ".lms")): + lms = np.loadtxt(os.path.join(path, str(i) + ".lms"), dtype=np.float32) + lmss.append(lms) + imgs_paths.append(os.path.join(path, str(i) + ".jpg")) + lmss = np.stack(lmss) + lmss = torch.as_tensor(lmss).cuda() + return lmss, imgs_paths diff --git a/data_utils/face_tracking/face_tracker.py b/data_utils/face_tracking/face_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..e97885686b88f80e738ad97b92a4ff7f34ce24d5 --- /dev/null +++ b/data_utils/face_tracking/face_tracker.py @@ -0,0 +1,390 @@ +import os +import sys +import cv2 +import argparse +from pathlib import Path +import torch +import numpy as np +from data_loader import load_dir +from facemodel import Face_3DMM +from util import * +from render_3dmm import Render_3DMM + + +# torch.autograd.set_detect_anomaly(True) + +dir_path = os.path.dirname(os.path.realpath(__file__)) + + +def set_requires_grad(tensor_list): + for tensor in tensor_list: + tensor.requires_grad = True + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--path", type=str, default="obama/ori_imgs", help="idname of target person" +) +parser.add_argument("--img_h", type=int, default=512, help="image height") +parser.add_argument("--img_w", type=int, default=512, help="image width") +parser.add_argument("--frame_num", type=int, default=11000, help="image number") +args = parser.parse_args() + +start_id = 0 +end_id = args.frame_num + +lms, img_paths = load_dir(args.path, start_id, end_id) +num_frames = lms.shape[0] +h, w = args.img_h, args.img_w +cxy = torch.tensor((w / 2.0, h / 2.0), dtype=torch.float).cuda() +id_dim, exp_dim, tex_dim, point_num = 100, 79, 100, 34650 +model_3dmm = Face_3DMM( + os.path.join(dir_path, "3DMM"), id_dim, exp_dim, tex_dim, point_num +) + +# only use one image per 40 to do fit the focal length +sel_ids = np.arange(0, num_frames, 40) +sel_num = sel_ids.shape[0] +arg_focal = 1600 +arg_landis = 1e5 + +print(f'[INFO] fitting focal length...') + +# fit the focal length +for focal in range(600, 1500, 100): + id_para = lms.new_zeros((1, id_dim), requires_grad=True) + exp_para = lms.new_zeros((sel_num, exp_dim), requires_grad=True) + euler_angle = lms.new_zeros((sel_num, 3), requires_grad=True) + trans = lms.new_zeros((sel_num, 3), requires_grad=True) + trans.data[:, 2] -= 7 + focal_length = lms.new_zeros(1, requires_grad=False) + focal_length.data += focal + set_requires_grad([id_para, exp_para, euler_angle, trans]) + + optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1) + optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=0.1) + + for iter in range(2000): + id_para_batch = id_para.expand(sel_num, -1) + geometry = model_3dmm.get_3dlandmarks( + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy + ) + proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) + loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach()) + loss = loss_lan + optimizer_frame.zero_grad() + loss.backward() + optimizer_frame.step() + # if iter % 100 == 0: + # print(focal, 'pose', iter, loss.item()) + + for iter in range(2500): + id_para_batch = id_para.expand(sel_num, -1) + geometry = model_3dmm.get_3dlandmarks( + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy + ) + proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) + loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach()) + loss_regid = torch.mean(id_para * id_para) + loss_regexp = torch.mean(exp_para * exp_para) + loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4 + optimizer_idexp.zero_grad() + optimizer_frame.zero_grad() + loss.backward() + optimizer_idexp.step() + optimizer_frame.step() + # if iter % 100 == 0: + # print(focal, 'poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item()) + + if iter % 1500 == 0 and iter >= 1500: + for param_group in optimizer_idexp.param_groups: + param_group["lr"] *= 0.2 + for param_group in optimizer_frame.param_groups: + param_group["lr"] *= 0.2 + + print(focal, loss_lan.item(), torch.mean(trans[:, 2]).item()) + + if loss_lan.item() < arg_landis: + arg_landis = loss_lan.item() + arg_focal = focal + +print("[INFO] find best focal:", arg_focal) + +print(f'[INFO] coarse fitting...') + +# for all frames, do a coarse fitting ??? +id_para = lms.new_zeros((1, id_dim), requires_grad=True) +exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True) +tex_para = lms.new_zeros( + (1, tex_dim), requires_grad=True +) # not optimized in this block ??? +euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True) +trans = lms.new_zeros((num_frames, 3), requires_grad=True) +light_para = lms.new_zeros((num_frames, 27), requires_grad=True) +trans.data[:, 2] -= 7 # ??? +focal_length = lms.new_zeros(1, requires_grad=True) +focal_length.data += arg_focal + +set_requires_grad([id_para, exp_para, tex_para, euler_angle, trans, light_para]) + +optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1) +optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=1) + +for iter in range(1500): + id_para_batch = id_para.expand(num_frames, -1) + geometry = model_3dmm.get_3dlandmarks( + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy + ) + proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) + loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach()) + loss = loss_lan + optimizer_frame.zero_grad() + loss.backward() + optimizer_frame.step() + if iter == 1000: + for param_group in optimizer_frame.param_groups: + param_group["lr"] = 0.1 + # if iter % 100 == 0: + # print('pose', iter, loss.item()) + +for param_group in optimizer_frame.param_groups: + param_group["lr"] = 0.1 + +for iter in range(2000): + id_para_batch = id_para.expand(num_frames, -1) + geometry = model_3dmm.get_3dlandmarks( + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy + ) + proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) + loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach()) + loss_regid = torch.mean(id_para * id_para) + loss_regexp = torch.mean(exp_para * exp_para) + loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4 + optimizer_idexp.zero_grad() + optimizer_frame.zero_grad() + loss.backward() + optimizer_idexp.step() + optimizer_frame.step() + # if iter % 100 == 0: + # print('poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item()) + if iter % 1000 == 0 and iter >= 1000: + for param_group in optimizer_idexp.param_groups: + param_group["lr"] *= 0.2 + for param_group in optimizer_frame.param_groups: + param_group["lr"] *= 0.2 + +print(loss_lan.item(), torch.mean(trans[:, 2]).item()) + +print(f'[INFO] fitting light...') + +batch_size = 32 + +device_default = torch.device("cuda:0") +device_render = torch.device("cuda:0") +renderer = Render_3DMM(arg_focal, h, w, batch_size, device_render) + +sel_ids = np.arange(0, num_frames, int(num_frames / batch_size))[:batch_size] +imgs = [] +for sel_id in sel_ids: + imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1]) +imgs = np.stack(imgs) +sel_imgs = torch.as_tensor(imgs).cuda() +sel_lms = lms[sel_ids] +sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True) +set_requires_grad([sel_light]) + +optimizer_tl = torch.optim.Adam([tex_para, sel_light], lr=0.1) +optimizer_id_frame = torch.optim.Adam([euler_angle, trans, exp_para, id_para], lr=0.01) + +for iter in range(71): + sel_exp_para, sel_euler, sel_trans = ( + exp_para[sel_ids], + euler_angle[sel_ids], + trans[sel_ids], + ) + sel_id_para = id_para.expand(batch_size, -1) + geometry = model_3dmm.get_3dlandmarks( + sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy + ) + proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy) + + loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach()) + loss_regid = torch.mean(id_para * id_para) + loss_regexp = torch.mean(sel_exp_para * sel_exp_para) + + sel_tex_para = tex_para.expand(batch_size, -1) + sel_texture = model_3dmm.forward_tex(sel_tex_para) + geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para) + rott_geo = forward_rott(geometry, sel_euler, sel_trans) + render_imgs = renderer( + rott_geo.to(device_render), + sel_texture.to(device_render), + sel_light.to(device_render), + ) + render_imgs = render_imgs.to(device_default) + + mask = (render_imgs[:, :, :, 3]).detach() > 0.0 + render_proj = sel_imgs.clone() + render_proj[mask] = render_imgs[mask][..., :3].byte() + loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask) + + if iter > 50: + loss = loss_col + loss_lan * 0.05 + loss_regid * 1.0 + loss_regexp * 0.8 + else: + loss = loss_col + loss_lan * 3 + loss_regid * 2.0 + loss_regexp * 1.0 + + optimizer_tl.zero_grad() + optimizer_id_frame.zero_grad() + loss.backward() + + optimizer_tl.step() + optimizer_id_frame.step() + + if iter % 50 == 0 and iter > 0: + for param_group in optimizer_id_frame.param_groups: + param_group["lr"] *= 0.2 + for param_group in optimizer_tl.param_groups: + param_group["lr"] *= 0.2 + # print(iter, loss_col.item(), loss_lan.item(), loss_regid.item(), loss_regexp.item()) + + +light_mean = torch.mean(sel_light, 0).unsqueeze(0).repeat(num_frames, 1) +light_para.data = light_mean + +exp_para = exp_para.detach() +euler_angle = euler_angle.detach() +trans = trans.detach() +light_para = light_para.detach() + +print(f'[INFO] fine frame-wise fitting...') + +for i in range(int((num_frames - 1) / batch_size + 1)): + + if (i + 1) * batch_size > num_frames: + start_n = num_frames - batch_size + sel_ids = np.arange(num_frames - batch_size, num_frames) + else: + start_n = i * batch_size + sel_ids = np.arange(i * batch_size, i * batch_size + batch_size) + + imgs = [] + for sel_id in sel_ids: + imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1]) + imgs = np.stack(imgs) + sel_imgs = torch.as_tensor(imgs).cuda() + sel_lms = lms[sel_ids] + + sel_exp_para = exp_para.new_zeros((batch_size, exp_dim), requires_grad=True) + sel_exp_para.data = exp_para[sel_ids].clone() + sel_euler = euler_angle.new_zeros((batch_size, 3), requires_grad=True) + sel_euler.data = euler_angle[sel_ids].clone() + sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True) + sel_trans.data = trans[sel_ids].clone() + sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True) + sel_light.data = light_para[sel_ids].clone() + + set_requires_grad([sel_exp_para, sel_euler, sel_trans, sel_light]) + + optimizer_cur_batch = torch.optim.Adam( + [sel_exp_para, sel_euler, sel_trans, sel_light], lr=0.005 + ) + + sel_id_para = id_para.expand(batch_size, -1).detach() + sel_tex_para = tex_para.expand(batch_size, -1).detach() + + pre_num = 5 + + if i > 0: + pre_ids = np.arange(start_n - pre_num, start_n) + + for iter in range(50): + + geometry = model_3dmm.get_3dlandmarks( + sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy + ) + proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy) + loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach()) + loss_regexp = torch.mean(sel_exp_para * sel_exp_para) + + sel_geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para) + sel_texture = model_3dmm.forward_tex(sel_tex_para) + geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para) + rott_geo = forward_rott(geometry, sel_euler, sel_trans) + render_imgs = renderer( + rott_geo.to(device_render), + sel_texture.to(device_render), + sel_light.to(device_render), + ) + render_imgs = render_imgs.to(device_default) + + mask = (render_imgs[:, :, :, 3]).detach() > 0.0 + + loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask) + + if i > 0: + geometry_lap = model_3dmm.forward_geo_sub( + id_para.expand(batch_size + pre_num, -1).detach(), + torch.cat((exp_para[pre_ids].detach(), sel_exp_para)), + model_3dmm.rigid_ids, + ) + rott_geo_lap = forward_rott( + geometry_lap, + torch.cat((euler_angle[pre_ids].detach(), sel_euler)), + torch.cat((trans[pre_ids].detach(), sel_trans)), + ) + loss_lap = cal_lap_loss( + [rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0] + ) + else: + geometry_lap = model_3dmm.forward_geo_sub( + id_para.expand(batch_size, -1).detach(), + sel_exp_para, + model_3dmm.rigid_ids, + ) + rott_geo_lap = forward_rott(geometry_lap, sel_euler, sel_trans) + loss_lap = cal_lap_loss( + [rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0] + ) + + + if iter > 30: + loss = loss_col * 0.5 + loss_lan * 1.5 + loss_lap * 100000 + loss_regexp * 1.0 + else: + loss = loss_col * 0.5 + loss_lan * 8 + loss_lap * 100000 + loss_regexp * 1.0 + + optimizer_cur_batch.zero_grad() + loss.backward() + optimizer_cur_batch.step() + + # if iter % 10 == 0: + # print( + # i, + # iter, + # loss_col.item(), + # loss_lan.item(), + # loss_lap.item(), + # loss_regexp.item(), + # ) + + print(str(i) + " of " + str(int((num_frames - 1) / batch_size + 1)) + " done") + + render_proj = sel_imgs.clone() + render_proj[mask] = render_imgs[mask][..., :3].byte() + + exp_para[sel_ids] = sel_exp_para.clone() + euler_angle[sel_ids] = sel_euler.clone() + trans[sel_ids] = sel_trans.clone() + light_para[sel_ids] = sel_light.clone() + +torch.save( + { + "id": id_para.detach().cpu(), + "exp": exp_para.detach().cpu(), + "euler": euler_angle.detach().cpu(), + "trans": trans.detach().cpu(), + "focal": focal_length.detach().cpu(), + }, + os.path.join(os.path.dirname(args.path), "track_params.pt"), +) + +print("params saved") diff --git a/data_utils/face_tracking/facemodel.py b/data_utils/face_tracking/facemodel.py new file mode 100644 index 0000000000000000000000000000000000000000..e54e96b750f6133c0380f2a1dd6030a036c6448d --- /dev/null +++ b/data_utils/face_tracking/facemodel.py @@ -0,0 +1,153 @@ +import torch +import torch.nn as nn +import numpy as np +import os +from util import * + + +class Face_3DMM(nn.Module): + def __init__(self, modelpath, id_dim, exp_dim, tex_dim, point_num): + super(Face_3DMM, self).__init__() + # id_dim = 100 + # exp_dim = 79 + # tex_dim = 100 + self.point_num = point_num + DMM_info = np.load( + os.path.join(modelpath, "3DMM_info.npy"), allow_pickle=True + ).item() + base_id = DMM_info["b_shape"][:id_dim, :] + mu_id = DMM_info["mu_shape"] + base_exp = DMM_info["b_exp"][:exp_dim, :] + mu_exp = DMM_info["mu_exp"] + mu = mu_id + mu_exp + mu = mu.reshape(-1, 3) + for i in range(3): + mu[:, i] -= np.mean(mu[:, i]) + mu = mu.reshape(-1) + self.base_id = torch.as_tensor(base_id).cuda() / 100000.0 + self.base_exp = torch.as_tensor(base_exp).cuda() / 100000.0 + self.mu = torch.as_tensor(mu).cuda() / 100000.0 + base_tex = DMM_info["b_tex"][:tex_dim, :] + mu_tex = DMM_info["mu_tex"] + self.base_tex = torch.as_tensor(base_tex).cuda() + self.mu_tex = torch.as_tensor(mu_tex).cuda() + sig_id = DMM_info["sig_shape"][:id_dim] + sig_tex = DMM_info["sig_tex"][:tex_dim] + sig_exp = DMM_info["sig_exp"][:exp_dim] + self.sig_id = torch.as_tensor(sig_id).cuda() + self.sig_tex = torch.as_tensor(sig_tex).cuda() + self.sig_exp = torch.as_tensor(sig_exp).cuda() + + keys_info = np.load( + os.path.join(modelpath, "keys_info.npy"), allow_pickle=True + ).item() + self.keyinds = torch.as_tensor(keys_info["keyinds"]).cuda() + self.left_contours = torch.as_tensor(keys_info["left_contour"]).cuda() + self.right_contours = torch.as_tensor(keys_info["right_contour"]).cuda() + self.rigid_ids = torch.as_tensor(keys_info["rigid_ids"]).cuda() + + def get_3dlandmarks(self, id_para, exp_para, euler_angle, trans, focal_length, cxy): + id_para = id_para * self.sig_id + exp_para = exp_para * self.sig_exp + batch_size = id_para.shape[0] + num_per_contour = self.left_contours.shape[1] + left_contours_flat = self.left_contours.reshape(-1) + right_contours_flat = self.right_contours.reshape(-1) + sel_index = torch.cat( + ( + 3 * left_contours_flat.unsqueeze(1), + 3 * left_contours_flat.unsqueeze(1) + 1, + 3 * left_contours_flat.unsqueeze(1) + 2, + ), + dim=1, + ).reshape(-1) + left_geometry = ( + torch.mm(id_para, self.base_id[:, sel_index]) + + torch.mm(exp_para, self.base_exp[:, sel_index]) + + self.mu[sel_index] + ) + left_geometry = left_geometry.view(batch_size, -1, 3) + proj_x = forward_transform( + left_geometry, euler_angle, trans, focal_length, cxy + )[:, :, 0] + proj_x = proj_x.reshape(batch_size, 8, num_per_contour) + arg_min = proj_x.argmin(dim=2) + left_geometry = left_geometry.view(batch_size * 8, num_per_contour, 3) + left_3dlands = left_geometry[ + torch.arange(batch_size * 8), arg_min.view(-1), : + ].view(batch_size, 8, 3) + + sel_index = torch.cat( + ( + 3 * right_contours_flat.unsqueeze(1), + 3 * right_contours_flat.unsqueeze(1) + 1, + 3 * right_contours_flat.unsqueeze(1) + 2, + ), + dim=1, + ).reshape(-1) + right_geometry = ( + torch.mm(id_para, self.base_id[:, sel_index]) + + torch.mm(exp_para, self.base_exp[:, sel_index]) + + self.mu[sel_index] + ) + right_geometry = right_geometry.view(batch_size, -1, 3) + proj_x = forward_transform( + right_geometry, euler_angle, trans, focal_length, cxy + )[:, :, 0] + proj_x = proj_x.reshape(batch_size, 8, num_per_contour) + arg_max = proj_x.argmax(dim=2) + right_geometry = right_geometry.view(batch_size * 8, num_per_contour, 3) + right_3dlands = right_geometry[ + torch.arange(batch_size * 8), arg_max.view(-1), : + ].view(batch_size, 8, 3) + + sel_index = torch.cat( + ( + 3 * self.keyinds.unsqueeze(1), + 3 * self.keyinds.unsqueeze(1) + 1, + 3 * self.keyinds.unsqueeze(1) + 2, + ), + dim=1, + ).reshape(-1) + geometry = ( + torch.mm(id_para, self.base_id[:, sel_index]) + + torch.mm(exp_para, self.base_exp[:, sel_index]) + + self.mu[sel_index] + ) + lands_3d = geometry.view(-1, self.keyinds.shape[0], 3) + lands_3d[:, :8, :] = left_3dlands + lands_3d[:, 9:17, :] = right_3dlands + return lands_3d + + def forward_geo_sub(self, id_para, exp_para, sub_index): + id_para = id_para * self.sig_id + exp_para = exp_para * self.sig_exp + sel_index = torch.cat( + ( + 3 * sub_index.unsqueeze(1), + 3 * sub_index.unsqueeze(1) + 1, + 3 * sub_index.unsqueeze(1) + 2, + ), + dim=1, + ).reshape(-1) + geometry = ( + torch.mm(id_para, self.base_id[:, sel_index]) + + torch.mm(exp_para, self.base_exp[:, sel_index]) + + self.mu[sel_index] + ) + return geometry.reshape(-1, sub_index.shape[0], 3) + + def forward_geo(self, id_para, exp_para): + id_para = id_para * self.sig_id + exp_para = exp_para * self.sig_exp + geometry = ( + torch.mm(id_para, self.base_id) + + torch.mm(exp_para, self.base_exp) + + self.mu + ) + return geometry.reshape(-1, self.point_num, 3) + + def forward_tex(self, tex_para): + tex_para = tex_para * self.sig_tex + texture = torch.mm(tex_para, self.base_tex) + self.mu_tex + return texture.reshape(-1, self.point_num, 3) diff --git a/data_utils/face_tracking/geo_transform.py b/data_utils/face_tracking/geo_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..4c092241f9ddcf5d85db2a1d94bd905b4527e5eb --- /dev/null +++ b/data_utils/face_tracking/geo_transform.py @@ -0,0 +1,69 @@ +"""This module contains functions for geometry transform and camera projection""" +import torch +import torch.nn as nn +import numpy as np + + +def euler2rot(euler_angle): + batch_size = euler_angle.shape[0] + theta = euler_angle[:, 0].reshape(-1, 1, 1) + phi = euler_angle[:, 1].reshape(-1, 1, 1) + psi = euler_angle[:, 2].reshape(-1, 1, 1) + one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device) + zero = torch.zeros( + (batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device + ) + rot_x = torch.cat( + ( + torch.cat((one, zero, zero), 1), + torch.cat((zero, theta.cos(), theta.sin()), 1), + torch.cat((zero, -theta.sin(), theta.cos()), 1), + ), + 2, + ) + rot_y = torch.cat( + ( + torch.cat((phi.cos(), zero, -phi.sin()), 1), + torch.cat((zero, one, zero), 1), + torch.cat((phi.sin(), zero, phi.cos()), 1), + ), + 2, + ) + rot_z = torch.cat( + ( + torch.cat((psi.cos(), -psi.sin(), zero), 1), + torch.cat((psi.sin(), psi.cos(), zero), 1), + torch.cat((zero, zero, one), 1), + ), + 2, + ) + return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) + + +def rot_trans_geo(geometry, rot, trans): + rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans.view(-1, 3, 1) + return rott_geo.permute(0, 2, 1) + + +def euler_trans_geo(geometry, euler, trans): + rot = euler2rot(euler) + return rot_trans_geo(geometry, rot, trans) + + +def proj_geo(rott_geo, camera_para): + fx = camera_para[:, 0] + fy = camera_para[:, 0] + cx = camera_para[:, 1] + cy = camera_para[:, 2] + + X = rott_geo[:, :, 0] + Y = rott_geo[:, :, 1] + Z = rott_geo[:, :, 2] + + fxX = fx[:, None] * X + fyY = fy[:, None] * Y + + proj_x = -fxX / Z + cx[:, None] + proj_y = fyY / Z + cy[:, None] + + return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2) diff --git a/data_utils/face_tracking/render_3dmm.py b/data_utils/face_tracking/render_3dmm.py new file mode 100644 index 0000000000000000000000000000000000000000..6a29e19312ad2834e785a8ced644bff47c04bc47 --- /dev/null +++ b/data_utils/face_tracking/render_3dmm.py @@ -0,0 +1,202 @@ +import torch +import torch.nn as nn +import numpy as np +import os +from pytorch3d.structures import Meshes +from pytorch3d.renderer import ( + look_at_view_transform, + PerspectiveCameras, + FoVPerspectiveCameras, + PointLights, + DirectionalLights, + Materials, + RasterizationSettings, + MeshRenderer, + MeshRasterizer, + SoftPhongShader, + TexturesUV, + TexturesVertex, + blending, +) + +from pytorch3d.ops import interpolate_face_attributes + +from pytorch3d.renderer.blending import ( + BlendParams, + hard_rgb_blend, + sigmoid_alpha_blend, + softmax_rgb_blend, +) + + +class SoftSimpleShader(nn.Module): + """ + Per pixel lighting - the lighting model is applied using the interpolated + coordinates and normals for each pixel. The blending function returns the + soft aggregated color using all the faces per pixel. + + To use the default values, simply initialize the shader with the desired + device e.g. + + """ + + def __init__( + self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None + ): + super().__init__() + self.lights = lights if lights is not None else PointLights(device=device) + self.materials = ( + materials if materials is not None else Materials(device=device) + ) + self.cameras = cameras + self.blend_params = blend_params if blend_params is not None else BlendParams() + + def to(self, device): + # Manually move to device modules which are not subclasses of nn.Module + self.cameras = self.cameras.to(device) + self.materials = self.materials.to(device) + self.lights = self.lights.to(device) + return self + + def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: + + texels = meshes.sample_textures(fragments) + blend_params = kwargs.get("blend_params", self.blend_params) + + cameras = kwargs.get("cameras", self.cameras) + if cameras is None: + msg = "Cameras must be specified either at initialization \ + or in the forward pass of SoftPhongShader" + raise ValueError(msg) + znear = kwargs.get("znear", getattr(cameras, "znear", 1.0)) + zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0)) + images = softmax_rgb_blend( + texels, fragments, blend_params, znear=znear, zfar=zfar + ) + return images + + +class Render_3DMM(nn.Module): + def __init__( + self, + focal=1015, + img_h=500, + img_w=500, + batch_size=1, + device=torch.device("cuda:0"), + ): + super(Render_3DMM, self).__init__() + + self.focal = focal + self.img_h = img_h + self.img_w = img_w + self.device = device + self.renderer = self.get_render(batch_size) + + dir_path = os.path.dirname(os.path.realpath(__file__)) + topo_info = np.load( + os.path.join(dir_path, "3DMM", "topology_info.npy"), allow_pickle=True + ).item() + self.tris = torch.as_tensor(topo_info["tris"]).to(self.device) + self.vert_tris = torch.as_tensor(topo_info["vert_tris"]).to(self.device) + + def compute_normal(self, geometry): + vert_1 = torch.index_select(geometry, 1, self.tris[:, 0]) + vert_2 = torch.index_select(geometry, 1, self.tris[:, 1]) + vert_3 = torch.index_select(geometry, 1, self.tris[:, 2]) + nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2) + tri_normal = nn.functional.normalize(nnorm, dim=2) + v_norm = tri_normal[:, self.vert_tris, :].sum(2) + vert_normal = v_norm / v_norm.norm(dim=2).unsqueeze(2) + return vert_normal + + def get_render(self, batch_size=1): + half_s = self.img_w * 0.5 + R, T = look_at_view_transform(10, 0, 0) + R = R.repeat(batch_size, 1, 1) + T = torch.zeros((batch_size, 3), dtype=torch.float32).to(self.device) + + cameras = FoVPerspectiveCameras( + device=self.device, + R=R, + T=T, + znear=0.01, + zfar=20, + fov=2 * np.arctan(self.img_w // 2 / self.focal) * 180.0 / np.pi, + ) + lights = PointLights( + device=self.device, + location=[[0.0, 0.0, 1e5]], + ambient_color=[[1, 1, 1]], + specular_color=[[0.0, 0.0, 0.0]], + diffuse_color=[[0.0, 0.0, 0.0]], + ) + sigma = 1e-4 + raster_settings = RasterizationSettings( + image_size=(self.img_h, self.img_w), + blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma / 18.0, + faces_per_pixel=2, + perspective_correct=False, + ) + blend_params = blending.BlendParams(background_color=[0, 0, 0]) + renderer = MeshRenderer( + rasterizer=MeshRasterizer(raster_settings=raster_settings, cameras=cameras), + shader=SoftSimpleShader( + lights=lights, blend_params=blend_params, cameras=cameras + ), + ) + return renderer.to(self.device) + + @staticmethod + def Illumination_layer(face_texture, norm, gamma): + + n_b, num_vertex, _ = face_texture.size() + n_v_full = n_b * num_vertex + gamma = gamma.view(-1, 3, 9).clone() + gamma[:, :, 0] += 0.8 + + gamma = gamma.permute(0, 2, 1) + + a0 = np.pi + a1 = 2 * np.pi / np.sqrt(3.0) + a2 = 2 * np.pi / np.sqrt(8.0) + c0 = 1 / np.sqrt(4 * np.pi) + c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi) + c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi) + d0 = 0.5 / np.sqrt(3.0) + + Y0 = torch.ones(n_v_full).to(gamma.device).float() * a0 * c0 + norm = norm.view(-1, 3) + nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2] + arrH = [] + + arrH.append(Y0) + arrH.append(-a1 * c1 * ny) + arrH.append(a1 * c1 * nz) + arrH.append(-a1 * c1 * nx) + arrH.append(a2 * c2 * nx * ny) + arrH.append(-a2 * c2 * ny * nz) + arrH.append(a2 * c2 * d0 * (3 * nz.pow(2) - 1)) + arrH.append(-a2 * c2 * nx * nz) + arrH.append(a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2))) + + H = torch.stack(arrH, 1) + Y = H.view(n_b, num_vertex, 9) + lighting = Y.bmm(gamma) + + face_color = face_texture * lighting + return face_color + + def forward(self, rott_geometry, texture, diffuse_sh): + face_normal = self.compute_normal(rott_geometry) + face_color = self.Illumination_layer(texture, face_normal, diffuse_sh) + face_color = TexturesVertex(face_color) + mesh = Meshes( + rott_geometry, + self.tris.float().repeat(rott_geometry.shape[0], 1, 1), + face_color, + ) + rendered_img = self.renderer(mesh) + rendered_img = torch.clamp(rendered_img, 0, 255) + + return rendered_img diff --git a/data_utils/face_tracking/render_land.py b/data_utils/face_tracking/render_land.py new file mode 100644 index 0000000000000000000000000000000000000000..e0050f709f210fa13bd7b47cdf8ff46f3c83ce5c --- /dev/null +++ b/data_utils/face_tracking/render_land.py @@ -0,0 +1,192 @@ +import torch +import torch.nn as nn +import render_util +import geo_transform +import numpy as np + + +def compute_tri_normal(geometry, tris): + geometry = geometry.permute(0, 2, 1) + tri_1 = tris[:, 0] + tri_2 = tris[:, 1] + tri_3 = tris[:, 2] + + vert_1 = torch.index_select(geometry, 2, tri_1) + vert_2 = torch.index_select(geometry, 2, tri_2) + vert_3 = torch.index_select(geometry, 2, tri_3) + + nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 1) + normal = nn.functional.normalize(nnorm).permute(0, 2, 1) + return normal + + +class Compute_normal_base(torch.autograd.Function): + @staticmethod + def forward(ctx, normal): + (normal_b,) = render_util.normal_base_forward(normal) + ctx.save_for_backward(normal) + return normal_b + + @staticmethod + def backward(ctx, grad_normal_b): + (normal,) = ctx.saved_tensors + (grad_normal,) = render_util.normal_base_backward(grad_normal_b, normal) + return grad_normal + + +class Normal_Base(torch.nn.Module): + def __init__(self): + super(Normal_Base, self).__init__() + + def forward(self, normal): + return Compute_normal_base.apply(normal) + + +def preprocess_render(geometry, euler, trans, cam, tris, vert_tris, ori_img): + point_num = geometry.shape[1] + rott_geo = geo_transform.euler_trans_geo(geometry, euler, trans) + proj_geo = geo_transform.proj_geo(rott_geo, cam) + rot_tri_normal = compute_tri_normal(rott_geo, tris) + rot_vert_normal = torch.index_select(rot_tri_normal, 1, vert_tris) + is_visible = -torch.bmm( + rot_vert_normal.reshape(-1, 1, 3), + nn.functional.normalize(rott_geo.reshape(-1, 3, 1)), + ).reshape(-1, point_num) + is_visible[is_visible < 0.01] = -1 + pixel_valid = torch.zeros( + (ori_img.shape[0], ori_img.shape[1] * ori_img.shape[2]), + dtype=torch.float32, + device=ori_img.device, + ) + return rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid + + +class Render_Face(torch.autograd.Function): + @staticmethod + def forward( + ctx, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid + ): + batch_size, h, w, _ = ori_img.shape + ori_img = ori_img.view(batch_size, -1, 3) + ori_size = torch.cat( + ( + torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) + * h, + torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) + * w, + ), + dim=1, + ).view(-1) + tri_index, tri_coord, render, real = render_util.render_face_forward( + proj_geo, ori_img, ori_size, texture, nbl, is_visible, tri_inds, pixel_valid + ) + ctx.save_for_backward( + ori_img, ori_size, proj_geo, texture, nbl, tri_inds, tri_index, tri_coord + ) + return render, real + + @staticmethod + def backward(ctx, grad_render, grad_real): + ( + ori_img, + ori_size, + proj_geo, + texture, + nbl, + tri_inds, + tri_index, + tri_coord, + ) = ctx.saved_tensors + grad_proj_geo, grad_texture, grad_nbl = render_util.render_face_backward( + grad_render, + grad_real, + ori_img, + ori_size, + proj_geo, + texture, + nbl, + tri_inds, + tri_index, + tri_coord, + ) + return grad_proj_geo, grad_texture, grad_nbl, None, None, None, None + + +class Render_RGB(nn.Module): + def __init__(self): + super(Render_RGB, self).__init__() + + def forward( + self, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid + ): + return Render_Face.apply( + proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid + ) + + +def cal_land(proj_geo, is_visible, lands_info, land_num): + (land_index,) = render_util.update_contour(lands_info, is_visible, land_num) + proj_land = torch.index_select(proj_geo.reshape(-1, 3), 0, land_index)[ + :, :2 + ].reshape(-1, land_num, 2) + return proj_land + + +class Render_Land(nn.Module): + def __init__(self): + super(Render_Land, self).__init__() + lands_info = np.loadtxt("../data/3DMM/lands_info.txt", dtype=np.int32) + self.lands_info = torch.as_tensor(lands_info).cuda() + tris = np.loadtxt("../data/3DMM/tris.txt", dtype=np.int64) + self.tris = torch.as_tensor(tris).cuda() - 1 + vert_tris = np.loadtxt("../data/3DMM/vert_tris.txt", dtype=np.int64) + self.vert_tris = torch.as_tensor(vert_tris).cuda() + self.normal_baser = Normal_Base().cuda() + self.renderer = Render_RGB().cuda() + + def render_mesh(self, geometry, euler, trans, cam, ori_img, light): + batch_size, h, w, _ = ori_img.shape + ori_img = ori_img.view(batch_size, -1, 3) + ori_size = torch.cat( + ( + torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) + * h, + torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) + * w, + ), + dim=1, + ).view(-1) + rott_geo, proj_geo, rot_tri_normal, _, _ = preprocess_render( + geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img + ) + tri_nb = self.normal_baser(rot_tri_normal.contiguous()) + nbl = torch.bmm( + tri_nb, (light.reshape(-1, 9, 3))[:, :, 0].unsqueeze(-1).repeat(1, 1, 3) + ) + texture = torch.ones_like(geometry) * 200 + (render,) = render_util.render_mesh( + proj_geo, ori_img, ori_size, texture, nbl, self.tris + ) + return render.view(batch_size, h, w, 3).byte() + + def cal_loss_rgb(self, geometry, euler, trans, cam, ori_img, light, texture, lands): + rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid = preprocess_render( + geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img + ) + tri_nb = self.normal_baser(rot_tri_normal.contiguous()) + nbl = torch.bmm(tri_nb, light.reshape(-1, 9, 3)) + render, real = self.renderer( + proj_geo, texture, nbl, ori_img, is_visible, self.tris, pixel_valid + ) + proj_land = cal_land(proj_geo, is_visible, self.lands_info, lands.shape[1]) + col_minus = torch.norm((render - real).reshape(-1, 3), dim=1).reshape( + ori_img.shape[0], -1 + ) + col_dis = torch.mean(col_minus * pixel_valid) / ( + torch.mean(pixel_valid) + 0.00001 + ) + land_dists = torch.norm((proj_land - lands).reshape(-1, 2), dim=1).reshape( + ori_img.shape[0], -1 + ) + lan_dis = torch.mean(land_dists) + return col_dis, lan_dis diff --git a/data_utils/face_tracking/util.py b/data_utils/face_tracking/util.py new file mode 100644 index 0000000000000000000000000000000000000000..d2554eb323903ca41f64d7c938adc5eca9172c3a --- /dev/null +++ b/data_utils/face_tracking/util.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def compute_tri_normal(geometry, tris): + tri_1 = tris[:, 0] + tri_2 = tris[:, 1] + tri_3 = tris[:, 2] + vert_1 = torch.index_select(geometry, 1, tri_1) + vert_2 = torch.index_select(geometry, 1, tri_2) + vert_3 = torch.index_select(geometry, 1, tri_3) + nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2) + normal = nn.functional.normalize(nnorm) + return normal + + +def euler2rot(euler_angle): + batch_size = euler_angle.shape[0] + theta = euler_angle[:, 0].reshape(-1, 1, 1) + phi = euler_angle[:, 1].reshape(-1, 1, 1) + psi = euler_angle[:, 2].reshape(-1, 1, 1) + one = torch.ones(batch_size, 1, 1).to(euler_angle.device) + zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device) + rot_x = torch.cat( + ( + torch.cat((one, zero, zero), 1), + torch.cat((zero, theta.cos(), theta.sin()), 1), + torch.cat((zero, -theta.sin(), theta.cos()), 1), + ), + 2, + ) + rot_y = torch.cat( + ( + torch.cat((phi.cos(), zero, -phi.sin()), 1), + torch.cat((zero, one, zero), 1), + torch.cat((phi.sin(), zero, phi.cos()), 1), + ), + 2, + ) + rot_z = torch.cat( + ( + torch.cat((psi.cos(), -psi.sin(), zero), 1), + torch.cat((psi.sin(), psi.cos(), zero), 1), + torch.cat((zero, zero, one), 1), + ), + 2, + ) + return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) + + +def rot_trans_pts(geometry, rot, trans): + rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans[:, :, None] + return rott_geo.permute(0, 2, 1) + + +def cal_lap_loss(tensor_list, weight_list): + lap_kernel = ( + torch.Tensor((-0.5, 1.0, -0.5)) + .unsqueeze(0) + .unsqueeze(0) + .float() + .to(tensor_list[0].device) + ) + loss_lap = 0 + for i in range(len(tensor_list)): + in_tensor = tensor_list[i] + in_tensor = in_tensor.view(-1, 1, in_tensor.shape[-1]) + out_tensor = F.conv1d(in_tensor, lap_kernel) + loss_lap += torch.mean(out_tensor ** 2) * weight_list[i] + return loss_lap + + +def proj_pts(rott_geo, focal_length, cxy): + cx, cy = cxy[0], cxy[1] + X = rott_geo[:, :, 0] + Y = rott_geo[:, :, 1] + Z = rott_geo[:, :, 2] + fxX = focal_length * X + fyY = focal_length * Y + proj_x = -fxX / Z + cx + proj_y = fyY / Z + cy + return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2) + + +def forward_rott(geometry, euler_angle, trans): + rot = euler2rot(euler_angle) + rott_geo = rot_trans_pts(geometry, rot, trans) + return rott_geo + + +def forward_transform(geometry, euler_angle, trans, focal_length, cxy): + rot = euler2rot(euler_angle) + rott_geo = rot_trans_pts(geometry, rot, trans) + proj_geo = proj_pts(rott_geo, focal_length, cxy) + return proj_geo + + +def cal_lan_loss(proj_lan, gt_lan): + return torch.mean((proj_lan - gt_lan) ** 2) + + +def cal_col_loss(pred_img, gt_img, img_mask): + pred_img = pred_img.float() + # loss = torch.sqrt(torch.sum(torch.square(pred_img - gt_img), 3))*img_mask/255 + loss = (torch.sum(torch.square(pred_img - gt_img), 3)) * img_mask / 255 + loss = torch.sum(loss, dim=(1, 2)) / torch.sum(img_mask, dim=(1, 2)) + loss = torch.mean(loss) + return loss diff --git a/data_utils/hubert.py b/data_utils/hubert.py new file mode 100644 index 0000000000000000000000000000000000000000..3fd989235f3a6eefdba03f3faae8ad6f77a5ac0a --- /dev/null +++ b/data_utils/hubert.py @@ -0,0 +1,94 @@ +from transformers import Wav2Vec2Processor, HubertModel +import soundfile as sf +import numpy as np +import torch + +print("Loading the Wav2Vec2 Processor...") +wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft") +print("Loading the HuBERT Model...") +hubert_model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft") + + +def get_hubert_from_16k_wav(wav_16k_name): + speech_16k, _ = sf.read(wav_16k_name) + hubert = get_hubert_from_16k_speech(speech_16k) + return hubert + +@torch.no_grad() +def get_hubert_from_16k_speech(speech, device="cuda:0"): + global hubert_model + hubert_model = hubert_model.to(device) + if speech.ndim ==2: + speech = speech[:, 0] # [T, 2] ==> [T,] + input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values # [1, T] + input_values_all = input_values_all.to(device) + # For long audio sequence, due to the memory limitation, we cannot process them in one run + # HuBERT process the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320 + # Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step. + # So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320 + # We have the equation to calculate out time step: T = floor((t-k)/s) + # To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip + # The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N + kernel = 400 + stride = 320 + clip_length = stride * 1000 + num_iter = input_values_all.shape[1] // clip_length + expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride + res_lst = [] + for i in range(num_iter): + if i == 0: + start_idx = 0 + end_idx = clip_length - stride + kernel + else: + start_idx = clip_length * i + end_idx = start_idx + (clip_length - stride + kernel) + input_values = input_values_all[:, start_idx: end_idx] + hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024] + res_lst.append(hidden_states[0]) + if num_iter > 0: + input_values = input_values_all[:, clip_length * num_iter:] + else: + input_values = input_values_all + # if input_values.shape[1] != 0: + if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it + hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024] + res_lst.append(hidden_states[0]) + else: + print("skip the latest ", input_values.shape[1]) + ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024] + # assert ret.shape[0] == expected_T + assert abs(ret.shape[0] - expected_T) <= 1 + if ret.shape[0] < expected_T: + ret = torch.nn.functional.pad(ret, (0,0,0,expected_T-ret.shape[0])) + else: + ret = ret[:expected_T] + return ret + +def make_even_first_dim(tensor): + size = list(tensor.size()) + if size[0] % 2 == 1: + size[0] -= 1 + return tensor[:size[0]] + return tensor + +import soundfile as sf +import numpy as np +import torch +from argparse import ArgumentParser +import librosa + +parser = ArgumentParser() +parser.add_argument('--wav', type=str, help='') +args = parser.parse_args() + +wav_name = args.wav + +speech_16k, sr = librosa.load(wav_name, sr=16000) +# speech_16k = librosa.resample(speech, orig_sr=sr, target_sr=16000) +# print("SR: {} to {}".format(sr, 16000)) +# print(speech.shape, speech_16k.shape) + +hubert_hidden = get_hubert_from_16k_speech(speech_16k) +hubert_hidden = make_even_first_dim(hubert_hidden).reshape(-1, 2, 1024) +np.save(wav_name.replace('.wav', '_hu.npy'), hubert_hidden.detach().numpy()) +print(hubert_hidden.detach().numpy().shape) \ No newline at end of file diff --git a/data_utils/process.py b/data_utils/process.py new file mode 100644 index 0000000000000000000000000000000000000000..93f566159fcb4b7d2439a0d8e33b96ebc7ad34e0 --- /dev/null +++ b/data_utils/process.py @@ -0,0 +1,405 @@ +import os +import glob +import tqdm +import json +import argparse +import cv2 +import numpy as np + +def extract_audio(path, out_path, sample_rate=16000): + + print(f'[INFO] ===== extract audio from {path} to {out_path} =====') + cmd = f'ffmpeg -i {path} -f wav -ar {sample_rate} {out_path}' + os.system(cmd) + print(f'[INFO] ===== extracted audio =====') + + +def extract_audio_features(path, mode='wav2vec'): + + print(f'[INFO] ===== extract audio labels for {path} =====') + if mode == 'wav2vec': + cmd = f'python nerf/asr.py --wav {path} --save_feats' + else: # deepspeech + cmd = f'python data_utils/deepspeech_features/extract_ds_features.py --input {path}' + os.system(cmd) + print(f'[INFO] ===== extracted audio labels =====') + + + +def extract_images(path, out_path, fps=25): + + print(f'[INFO] ===== extract images from {path} to {out_path} =====') + cmd = f'ffmpeg -i {path} -vf fps={fps} -qmin 1 -q:v 1 -start_number 0 {os.path.join(out_path, "%d.jpg")}' + os.system(cmd) + print(f'[INFO] ===== extracted images =====') + + +def extract_semantics(ori_imgs_dir, parsing_dir): + + print(f'[INFO] ===== extract semantics from {ori_imgs_dir} to {parsing_dir} =====') + cmd = f'python data_utils/face_parsing/test.py --respath={parsing_dir} --imgpath={ori_imgs_dir}' + os.system(cmd) + print(f'[INFO] ===== extracted semantics =====') + + +def extract_landmarks(ori_imgs_dir): + + print(f'[INFO] ===== extract face landmarks from {ori_imgs_dir} =====') + + import face_alignment + try: + fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False) + except: + fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False) + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) + for image_path in tqdm.tqdm(image_paths): + input = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] + input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB) + preds = fa.get_landmarks(input) + if len(preds) > 0: + lands = preds[0].reshape(-1, 2)[:,:2] + np.savetxt(image_path.replace('jpg', 'lms'), lands, '%f') + del fa + print(f'[INFO] ===== extracted face landmarks =====') + + +def extract_background(base_dir, ori_imgs_dir): + + print(f'[INFO] ===== extract background image from {ori_imgs_dir} =====') + + from sklearn.neighbors import NearestNeighbors + + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) + # only use 1/20 image_paths + image_paths = image_paths[::20] + # read one image to get H/W + tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3] + h, w = tmp_image.shape[:2] + + # nearest neighbors + all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose() + distss = [] + for image_path in tqdm.tqdm(image_paths): + parse_img = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png')) + bg = (parse_img[..., 0] == 255) & (parse_img[..., 1] == 255) & (parse_img[..., 2] == 255) + fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0) + nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys) + dists, _ = nbrs.kneighbors(all_xys) + distss.append(dists) + + distss = np.stack(distss) + max_dist = np.max(distss, 0) + max_id = np.argmax(distss, 0) + + bc_pixs = max_dist > 5 + bc_pixs_id = np.nonzero(bc_pixs) + bc_ids = max_id[bc_pixs] + + imgs = [] + num_pixs = distss.shape[1] + for image_path in image_paths: + img = cv2.imread(image_path) + imgs.append(img) + imgs = np.stack(imgs).reshape(-1, num_pixs, 3) + + bc_img = np.zeros((h*w, 3), dtype=np.uint8) + bc_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :] + bc_img = bc_img.reshape(h, w, 3) + + max_dist = max_dist.reshape(h, w) + bc_pixs = max_dist > 5 + bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose() + fg_xys = np.stack(np.nonzero(bc_pixs)).transpose() + nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys) + distances, indices = nbrs.kneighbors(bg_xys) + bg_fg_xys = fg_xys[indices[:, 0]] + bc_img[bg_xys[:, 0], bg_xys[:, 1], :] = bc_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :] + + cv2.imwrite(os.path.join(base_dir, 'bc.jpg'), bc_img) + + print(f'[INFO] ===== extracted background image =====') + + +def extract_torso_and_gt(base_dir, ori_imgs_dir): + + print(f'[INFO] ===== extract torso and gt images for {base_dir} =====') + + from scipy.ndimage import binary_erosion, binary_dilation + + # load bg + bg_image = cv2.imread(os.path.join(base_dir, 'bc.jpg'), cv2.IMREAD_UNCHANGED) + + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) + + for image_path in tqdm.tqdm(image_paths): + # read ori image + ori_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] + + # read semantics + seg = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png')) + head_part = (seg[..., 0] == 255) & (seg[..., 1] == 0) & (seg[..., 2] == 0) + neck_part = (seg[..., 0] == 0) & (seg[..., 1] == 255) & (seg[..., 2] == 0) + torso_part = (seg[..., 0] == 0) & (seg[..., 1] == 0) & (seg[..., 2] == 255) + bg_part = (seg[..., 0] == 255) & (seg[..., 1] == 255) & (seg[..., 2] == 255) + + # get gt image + gt_image = ori_image.copy() + gt_image[bg_part] = bg_image[bg_part] + cv2.imwrite(image_path.replace('ori_imgs', 'gt_imgs'), gt_image) + + # get torso image + torso_image = gt_image.copy() # rgb + torso_image[head_part] = bg_image[head_part] + torso_alpha = 255 * np.ones((gt_image.shape[0], gt_image.shape[1], 1), dtype=np.uint8) # alpha + + # torso part "vertical" in-painting... + L = 8 + 1 + torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2] + # lexsort: sort 2D coords first by y then by x, + # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes + inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1])) + torso_coords = torso_coords[inds] + # choose the top pixel for each column + u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True) + top_torso_coords = torso_coords[uid] # [m, 2] + # only keep top-is-head pixels + top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0]) + mask = head_part[tuple(top_torso_coords_up.T)] + if mask.any(): + top_torso_coords = top_torso_coords[mask] + # get the color + top_torso_colors = gt_image[tuple(top_torso_coords.T)] # [m, 3] + # construct inpaint coords (vertically up, or minus in x) + inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2] + inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2] + inpaint_torso_coords += inpaint_offsets + inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2] + inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3] + darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1] + inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3] + # set color + torso_image[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors + + inpaint_torso_mask = np.zeros_like(torso_image[..., 0]).astype(bool) + inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True + else: + inpaint_torso_mask = None + + + # neck part "vertical" in-painting... + push_down = 4 + L = 48 + push_down + 1 + + neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3) + + neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2] + # lexsort: sort 2D coords first by y then by x, + # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes + inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1])) + neck_coords = neck_coords[inds] + # choose the top pixel for each column + u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True) + top_neck_coords = neck_coords[uid] # [m, 2] + # only keep top-is-head pixels + top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0]) + mask = head_part[tuple(top_neck_coords_up.T)] + + top_neck_coords = top_neck_coords[mask] + # push these top down for 4 pixels to make the neck inpainting more natural... + offset_down = np.minimum(ucnt[mask] - 1, push_down) + top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1) + # get the color + top_neck_colors = gt_image[tuple(top_neck_coords.T)] # [m, 3] + # construct inpaint coords (vertically up, or minus in x) + inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2] + inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2] + inpaint_neck_coords += inpaint_offsets + inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2] + inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3] + darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1] + inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3] + # set color + torso_image[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors + + # apply blurring to the inpaint area to avoid vertical-line artifects... + inpaint_mask = np.zeros_like(torso_image[..., 0]).astype(bool) + inpaint_mask[tuple(inpaint_neck_coords.T)] = True + + blur_img = torso_image.copy() + blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT) + + torso_image[inpaint_mask] = blur_img[inpaint_mask] + + # set mask + mask = (neck_part | torso_part | inpaint_mask) + if inpaint_torso_mask is not None: + mask = mask | inpaint_torso_mask + torso_image[~mask] = 0 + torso_alpha[~mask] = 0 + + cv2.imwrite(image_path.replace('ori_imgs', 'torso_imgs').replace('.jpg', '.png'), np.concatenate([torso_image, torso_alpha], axis=-1)) + + print(f'[INFO] ===== extracted torso and gt images =====') + + +def face_tracking(ori_imgs_dir): + + print(f'[INFO] ===== perform face tracking =====') + + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) + + # read one image to get H/W + tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3] + h, w = tmp_image.shape[:2] + + cmd = f'python data_utils/face_tracking/face_tracker.py --path={ori_imgs_dir} --img_h={h} --img_w={w} --frame_num={len(image_paths)}' + + os.system(cmd) + + print(f'[INFO] ===== finished face tracking =====') + + +def save_transforms(base_dir, ori_imgs_dir): + print(f'[INFO] ===== save transforms =====') + + import torch + + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) + + # read one image to get H/W + tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3] + h, w = tmp_image.shape[:2] + + params_dict = torch.load(os.path.join(base_dir, 'track_params.pt')) + focal_len = params_dict['focal'] + euler_angle = params_dict['euler'] + trans = params_dict['trans'] / 10.0 + valid_num = euler_angle.shape[0] + + def euler2rot(euler_angle): + batch_size = euler_angle.shape[0] + theta = euler_angle[:, 0].reshape(-1, 1, 1) + phi = euler_angle[:, 1].reshape(-1, 1, 1) + psi = euler_angle[:, 2].reshape(-1, 1, 1) + one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device) + zero = torch.zeros((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device) + rot_x = torch.cat(( + torch.cat((one, zero, zero), 1), + torch.cat((zero, theta.cos(), theta.sin()), 1), + torch.cat((zero, -theta.sin(), theta.cos()), 1), + ), 2) + rot_y = torch.cat(( + torch.cat((phi.cos(), zero, -phi.sin()), 1), + torch.cat((zero, one, zero), 1), + torch.cat((phi.sin(), zero, phi.cos()), 1), + ), 2) + rot_z = torch.cat(( + torch.cat((psi.cos(), -psi.sin(), zero), 1), + torch.cat((psi.sin(), psi.cos(), zero), 1), + torch.cat((zero, zero, one), 1) + ), 2) + return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) + + + # train_val_split = int(valid_num*0.5) + # train_val_split = valid_num - 25 * 20 # take the last 20s as valid set. + train_val_split = int(valid_num * 10 / 11) + + train_ids = torch.arange(0, train_val_split) + val_ids = torch.arange(train_val_split, valid_num) + + rot = euler2rot(euler_angle) + rot_inv = rot.permute(0, 2, 1) + trans_inv = -torch.bmm(rot_inv, trans.unsqueeze(2)) + + pose = torch.eye(4, dtype=torch.float32) + save_ids = ['train', 'val'] + train_val_ids = [train_ids, val_ids] + mean_z = -float(torch.mean(trans[:, 2]).item()) + + for split in range(2): + transform_dict = dict() + transform_dict['focal_len'] = float(focal_len[0]) + transform_dict['cx'] = float(w/2.0) + transform_dict['cy'] = float(h/2.0) + transform_dict['frames'] = [] + ids = train_val_ids[split] + save_id = save_ids[split] + + for i in ids: + i = i.item() + frame_dict = dict() + frame_dict['img_id'] = i + frame_dict['aud_id'] = i + + pose[:3, :3] = rot_inv[i] + pose[:3, 3] = trans_inv[i, :, 0] + + frame_dict['transform_matrix'] = pose.numpy().tolist() + + transform_dict['frames'].append(frame_dict) + + with open(os.path.join(base_dir, 'transforms_' + save_id + '.json'), 'w') as fp: + json.dump(transform_dict, fp, indent=2, separators=(',', ': ')) + + print(f'[INFO] ===== finished saving transforms =====') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('path', type=str, help="path to video file") + parser.add_argument('--task', type=int, default=-1, help="-1 means all") + parser.add_argument('--asr', type=str, default='deepspeech', help="wav2vec or deepspeech") + + opt = parser.parse_args() + + base_dir = os.path.dirname(opt.path) + + wav_path = os.path.join(base_dir, 'aud.wav') + ori_imgs_dir = os.path.join(base_dir, 'ori_imgs') + parsing_dir = os.path.join(base_dir, 'parsing') + gt_imgs_dir = os.path.join(base_dir, 'gt_imgs') + torso_imgs_dir = os.path.join(base_dir, 'torso_imgs') + + os.makedirs(ori_imgs_dir, exist_ok=True) + os.makedirs(parsing_dir, exist_ok=True) + os.makedirs(gt_imgs_dir, exist_ok=True) + os.makedirs(torso_imgs_dir, exist_ok=True) + + + # extract audio + if opt.task == -1 or opt.task == 1: + extract_audio(opt.path, wav_path) + + # extract audio features + if opt.task == -1 or opt.task == 2: + extract_audio_features(wav_path, mode=opt.asr) + + # extract images + if opt.task == -1 or opt.task == 3: + extract_images(opt.path, ori_imgs_dir) + + # face parsing + if opt.task == -1 or opt.task == 4: + extract_semantics(ori_imgs_dir, parsing_dir) + + # extract bg + if opt.task == -1 or opt.task == 5: + extract_background(base_dir, ori_imgs_dir) + + # extract torso images and gt_images + if opt.task == -1 or opt.task == 6: + extract_torso_and_gt(base_dir, ori_imgs_dir) + + # extract face landmarks + if opt.task == -1 or opt.task == 7: + extract_landmarks(ori_imgs_dir) + + # face tracking + if opt.task == -1 or opt.task == 8: + face_tracking(ori_imgs_dir) + + # save transforms.json + if opt.task == -1 or opt.task == 9: + save_transforms(base_dir, ori_imgs_dir) + diff --git a/data_utils/wav2mel.py b/data_utils/wav2mel.py new file mode 100644 index 0000000000000000000000000000000000000000..f3609e363600b6b9448b653d613cb382c54f4513 --- /dev/null +++ b/data_utils/wav2mel.py @@ -0,0 +1,167 @@ +import librosa +import librosa.filters +import numpy as np +from scipy import signal +from wav2mel_hparams import hparams as hp +from librosa.core.audio import resample +import soundfile as sf + +def load_wav(path, sr): + return librosa.core.load(path, sr=sr) + +def preemphasis(wav, k, preemphasize=True): + if preemphasize: + return signal.lfilter([1, -k], [1], wav) + return wav + +def inv_preemphasis(wav, k, inv_preemphasize=True): + if inv_preemphasize: + return signal.lfilter([1], [1, -k], wav) + return wav + +def get_hop_size(): + hop_size = hp.hop_size + if hop_size is None: + assert hp.frame_shift_ms is not None + hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate) + return hop_size + +def linearspectrogram(wav): + D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) + S = _amp_to_db(np.abs(D)) - hp.ref_level_db + + if hp.signal_normalization: + return _normalize(S) + return S + +def melspectrogram(wav): + D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) + S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db + + if hp.signal_normalization: + return _normalize(S) + return S + +def _stft(y): + return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size) + +########################################################## +#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!) +def num_frames(length, fsize, fshift): + """Compute number of time frames of spectrogram + """ + pad = (fsize - fshift) + if length % fshift == 0: + M = (length + pad * 2 - fsize) // fshift + 1 + else: + M = (length + pad * 2 - fsize) // fshift + 2 + return M + + +def pad_lr(x, fsize, fshift): + """Compute left and right padding + """ + M = num_frames(len(x), fsize, fshift) + pad = (fsize - fshift) + T = len(x) + 2 * pad + r = (M - 1) * fshift + fsize - T + return pad, pad + r +########################################################## +#Librosa correct padding +def librosa_pad_lr(x, fsize, fshift): + return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] + +# Conversions +_mel_basis = None + +def _linear_to_mel(spectogram): + global _mel_basis + if _mel_basis is None: + _mel_basis = _build_mel_basis() + return np.dot(_mel_basis, spectogram) + +def _build_mel_basis(): + assert hp.fmax <= hp.sample_rate // 2 + return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels, + fmin=hp.fmin, fmax=hp.fmax) + +def _amp_to_db(x): + min_level = np.exp(hp.min_level_db / 20 * np.log(10)) + return 20 * np.log10(np.maximum(min_level, x)) + +def _db_to_amp(x): + return np.power(10.0, (x) * 0.05) + +def _normalize(S): + if hp.allow_clipping_in_normalization: + if hp.symmetric_mels: + return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value, + -hp.max_abs_value, hp.max_abs_value) + else: + return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value) + + assert S.max() <= 0 and S.min() - hp.min_level_db >= 0 + if hp.symmetric_mels: + return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value + else: + return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) + +def _denormalize(D): + if hp.allow_clipping_in_normalization: + if hp.symmetric_mels: + return (((np.clip(D, -hp.max_abs_value, + hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + + hp.min_level_db) + else: + return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) + + if hp.symmetric_mels: + return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db) + else: + return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) + + + +def wav2mel(wav, sr): + wav16k = resample(wav, orig_sr=sr, target_sr=16000) + # print('wav16k', wav16k.shape, wav16k.dtype) + mel = melspectrogram(wav16k) + # print('mel', mel.shape, mel.dtype) + if np.isnan(mel.reshape(-1)).sum() > 0: + raise ValueError( + 'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again') + # mel.dtype = np.float32 + mel_chunks = [] + mel_idx_multiplier = 80. / 25 + mel_step_size = 8 + i = start_idx = 0 + while start_idx < len(mel[0]): + start_idx = int(i * mel_idx_multiplier) + if start_idx + mel_step_size // 2 > len(mel[0]): + mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:]) + elif start_idx - mel_step_size // 2 < 0: + mel_chunks.append(mel[:, :mel_step_size]) + else: + mel_chunks.append(mel[:, start_idx - mel_step_size // 2 : start_idx + mel_step_size // 2]) + i += 1 + return mel_chunks + + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--wav', type=str, default='') + parser.add_argument('--save_feats', action='store_true') + + opt = parser.parse_args() + + wav, sr = librosa.core.load(opt.wav) + mel_chunks = np.array(wav2mel(wav.T, sr)) + print(mel_chunks.shape, mel_chunks.transpose(0,2,1).shape) + + if opt.save_feats: + save_path = opt.wav.replace('.wav', '_mel.npy') + np.save(save_path, mel_chunks.transpose(0,2,1)) + print(f"[INFO] saved logits to {save_path}") \ No newline at end of file diff --git a/data_utils/wav2mel_hparams.py b/data_utils/wav2mel_hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..184b4a50237aa52ae45a0fb314ce5c9bce4be06e --- /dev/null +++ b/data_utils/wav2mel_hparams.py @@ -0,0 +1,80 @@ +class HParams: + def __init__(self, **kwargs): + self.data = {} + + for key, value in kwargs.items(): + self.data[key] = value + + def __getattr__(self, key): + if key not in self.data: + raise AttributeError("'HParams' object has no attribute %s" % key) + return self.data[key] + + def set_hparam(self, key, value): + self.data[key] = value + +# Default hyperparameters +hparams = HParams( + num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality + # network + rescale=True, # Whether to rescale audio prior to preprocessing + rescaling_max=0.9, # Rescaling value + + # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction + # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder + # Does not work if n_ffit is not multiple of hop_size!! + use_lws=False, + + n_fft=800, # Extra window size is filled with 0 paddings to match this parameter + hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) + win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) + sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) + + frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) + + # Mel and Linear spectrograms normalization/scaling and clipping + signal_normalization=True, + # Whether to normalize mel spectrograms to some predefined range (following below parameters) + allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True + symmetric_mels=True, + # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, + # faster and cleaner convergence) + max_abs_value=4., + # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not + # be too big to avoid gradient explosion, + # not too small for fast convergence) + # Contribution by @begeekmyfriend + # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude + # levels. Also allows for better G&L phase reconstruction) + preemphasize=True, # whether to apply filter + preemphasis=0.97, # filter coefficient. + + # Limits + min_level_db=-100, + ref_level_db=20, + fmin=65, + # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To + # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) + fmax=6000, # To be increased/reduced depending on data. + + ###################### Our training parameters ################################# + img_size=96, + fps=25, + + batch_size=16, + initial_learning_rate=1e-4, + nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs + num_workers=16, + checkpoint_interval=3000, + eval_interval=3000, + save_optimizer_state=True, + + syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. + syncnet_batch_size=64, + syncnet_lr=1e-4, + syncnet_eval_interval=10000, + syncnet_checkpoint_interval=10000, + + disc_wt=0.07, + disc_initial_learning_rate=1e-4, +) diff --git a/data_utils/wav2vec.py b/data_utils/wav2vec.py new file mode 100644 index 0000000000000000000000000000000000000000..f5e0e8d118309de6b382eae4e23cade972eea5a5 --- /dev/null +++ b/data_utils/wav2vec.py @@ -0,0 +1,420 @@ +import time +import numpy as np +import torch +import torch.nn.functional as F +from transformers import AutoModelForCTC, AutoProcessor + +import pyaudio +import soundfile as sf +import resampy + +from queue import Queue +from threading import Thread, Event + + +def _read_frame(stream, exit_event, queue, chunk): + + while True: + if exit_event.is_set(): + print(f'[INFO] read frame thread ends') + break + frame = stream.read(chunk, exception_on_overflow=False) + frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk] + queue.put(frame) + +def _play_frame(stream, exit_event, queue, chunk): + + while True: + if exit_event.is_set(): + print(f'[INFO] play frame thread ends') + break + frame = queue.get() + frame = (frame * 32767).astype(np.int16).tobytes() + stream.write(frame, chunk) + +class ASR: + def __init__(self, opt): + + self.opt = opt + + self.play = opt.asr_play + + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.fps = opt.fps # 20 ms per frame + self.sample_rate = 16000 + self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000) + self.mode = 'live' if opt.asr_wav == '' else 'file' + + if 'esperanto' in self.opt.asr_model: + self.audio_dim = 44 + elif 'deepspeech' in self.opt.asr_model: + self.audio_dim = 29 + else: + self.audio_dim = 32 + + # prepare context cache + # each segment is (stride_left + ctx + stride_right) * 20ms, latency should be (ctx + stride_right) * 20ms + self.context_size = opt.m + self.stride_left_size = opt.l + self.stride_right_size = opt.r + self.text = '[START]\n' + self.terminated = False + self.frames = [] + + # pad left frames + if self.stride_left_size > 0: + self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size) + + + self.exit_event = Event() + self.audio_instance = pyaudio.PyAudio() + + # create input stream + if self.mode == 'file': + self.file_stream = self.create_file_stream() + else: + # start a background process to read frames + self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk) + self.queue = Queue() + self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk)) + + # play out the audio too...? + if self.play: + self.output_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=False, output=True, frames_per_buffer=self.chunk) + self.output_queue = Queue() + self.process_play_frame = Thread(target=_play_frame, args=(self.output_stream, self.exit_event, self.output_queue, self.chunk)) + + # current location of audio + self.idx = 0 + + # create wav2vec model + print(f'[INFO] loading ASR model {self.opt.asr_model}...') + self.processor = AutoProcessor.from_pretrained(opt.asr_model) + self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device) + + # prepare to save logits + if self.opt.asr_save_feats: + self.all_feats = [] + + # the extracted features + # use a loop queue to efficiently record endless features: [f--t---][-------][-------] + self.feat_buffer_size = 4 + self.feat_buffer_idx = 0 + self.feat_queue = torch.zeros(self.feat_buffer_size * self.context_size, self.audio_dim, dtype=torch.float32, device=self.device) + + # TODO: hard coded 16 and 8 window size... + self.front = self.feat_buffer_size * self.context_size - 8 # fake padding + self.tail = 8 + # attention window... + self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding... + + # warm up steps needed: mid + right + window_size + attention_size + self.warm_up_steps = self.context_size + self.stride_right_size + 8 + 2 * 3 + + self.listening = False + self.playing = False + + def listen(self): + # start + if self.mode == 'live' and not self.listening: + print(f'[INFO] starting read frame thread...') + self.process_read_frame.start() + self.listening = True + + if self.play and not self.playing: + print(f'[INFO] starting play frame thread...') + self.process_play_frame.start() + self.playing = True + + def stop(self): + + self.exit_event.set() + + if self.play: + self.output_stream.stop_stream() + self.output_stream.close() + if self.playing: + self.process_play_frame.join() + self.playing = False + + if self.mode == 'live': + self.input_stream.stop_stream() + self.input_stream.close() + if self.listening: + self.process_read_frame.join() + self.listening = False + + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + + self.stop() + + if self.mode == 'live': + # live mode: also print the result text. + self.text += '\n[END]' + print(self.text) + + def get_next_feat(self): + # return a [1/8, 16] window, for the next input to nerf side. + + while len(self.att_feats) < 8: + # [------f+++t-----] + if self.front < self.tail: + feat = self.feat_queue[self.front:self.tail] + # [++t-----------f+] + else: + feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0) + + self.front = (self.front + 2) % self.feat_queue.shape[0] + self.tail = (self.tail + 2) % self.feat_queue.shape[0] + + # print(self.front, self.tail, feat.shape) + + self.att_feats.append(feat.permute(1, 0)) + + att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16] + + # discard old + self.att_feats = self.att_feats[1:] + + return att_feat + + def run_step(self): + + if self.terminated: + return + + # get a frame of audio + frame = self.get_audio_frame() + + # the last frame + if frame is None: + # terminate, but always run the network for the left frames + self.terminated = True + else: + self.frames.append(frame) + # put to output + if self.play: + self.output_queue.put(frame) + # context not enough, do not run network. + if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size: + return + + inputs = np.concatenate(self.frames) # [N * chunk] + + # discard the old part to save memory + if not self.terminated: + self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):] + + logits, labels, text = self.frame_to_text(inputs) + feats = logits # better lips-sync than labels + + # save feats + if self.opt.asr_save_feats: + self.all_feats.append(feats) + + # record the feats efficiently.. (no concat, constant memory) + if not self.terminated: + start = self.feat_buffer_idx * self.context_size + end = start + feats.shape[0] + self.feat_queue[start:end] = feats + self.feat_buffer_idx = (self.feat_buffer_idx + 1) % self.feat_buffer_size + + # very naive, just concat the text output. + if text != '': + self.text = self.text + ' ' + text + + # will only run once at ternimation + if self.terminated: + self.text += '\n[END]' + print(self.text) + if self.opt.asr_save_feats: + print(f'[INFO] save all feats for training purpose... ') + feats = torch.cat(self.all_feats, dim=0) # [N, C] + # print('[INFO] before unfold', feats.shape) + window_size = 16 + padding = window_size // 2 + feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M] + feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1] + unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1] + unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C] + # print('[INFO] after unfold', unfold_feats.shape) + # save to a npy file + if 'esperanto' in self.opt.asr_model: + output_path = self.opt.asr_wav.replace('.wav', '_eo.npy') + else: + output_path = self.opt.asr_wav.replace('.wav', '.npy') + np.save(output_path, unfold_feats.cpu().numpy()) + print(f"[INFO] saved logits to {output_path}") + + def create_file_stream(self): + + stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64 + stream = stream.astype(np.float32) + + if stream.ndim > 1: + print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') + stream = stream[:, 0] + + if sample_rate != self.sample_rate: + print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') + stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate) + + print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}') + + return stream + + + def create_pyaudio_stream(self): + + import pyaudio + + print(f'[INFO] creating live audio stream ...') + + audio = pyaudio.PyAudio() + + # get devices + info = audio.get_host_api_info_by_index(0) + n_devices = info.get('deviceCount') + + for i in range(0, n_devices): + if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: + name = audio.get_device_info_by_host_api_device_index(0, i).get('name') + print(f'[INFO] choose audio device {name}, id {i}') + break + + # get stream + stream = audio.open(input_device_index=i, + format=pyaudio.paInt16, + channels=1, + rate=self.sample_rate, + input=True, + frames_per_buffer=self.chunk) + + return audio, stream + + + def get_audio_frame(self): + + if self.mode == 'file': + + if self.idx < self.file_stream.shape[0]: + frame = self.file_stream[self.idx: self.idx + self.chunk] + self.idx = self.idx + self.chunk + return frame + else: + return None + + else: + + frame = self.queue.get() + # print(f'[INFO] get frame {frame.shape}') + + self.idx = self.idx + self.chunk + + return frame + + + def frame_to_text(self, frame): + # frame: [N * 320], N = (context_size + 2 * stride_size) + + inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True) + + with torch.no_grad(): + result = self.model(inputs.input_values.to(self.device)) + logits = result.logits # [1, N - 1, 32] + + # cut off stride + left = max(0, self.stride_left_size) + right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input. + + # do not cut right if terminated. + if self.terminated: + right = logits.shape[1] + + logits = logits[:, left:right] + + # print(frame.shape, inputs.input_values.shape, logits.shape) + + predicted_ids = torch.argmax(logits, dim=-1) + transcription = self.processor.batch_decode(predicted_ids)[0].lower() + + + # for esperanto + # labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '‘', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '–', 'fi', 'l', 'p', '’', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]']) + + # labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z']) + # print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()])) + # print(predicted_ids[0]) + # print(transcription) + + return logits[0], predicted_ids[0], transcription # [N,] + + + def run(self): + + self.listen() + + while not self.terminated: + self.run_step() + + def clear_queue(self): + # clear the queue, to reduce potential latency... + print(f'[INFO] clear queue') + if self.mode == 'live': + self.queue.queue.clear() + if self.play: + self.output_queue.queue.clear() + + def warm_up(self): + + self.listen() + + print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s') + t = time.time() + for _ in range(self.warm_up_steps): + self.run_step() + if torch.cuda.is_available(): + torch.cuda.synchronize() + t = time.time() - t + print(f'[INFO] warm-up done, actual latency = {t:.6f}s') + + self.clear_queue() + + + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--wav', type=str, default='') + parser.add_argument('--play', action='store_true', help="play out the audio") + + parser.add_argument('--model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') + # parser.add_argument('--model', type=str, default='facebook/wav2vec2-large-960h-lv60-self') + + parser.add_argument('--save_feats', action='store_true') + # audio FPS + parser.add_argument('--fps', type=int, default=50) + # sliding window left-middle-right length. + parser.add_argument('-l', type=int, default=10) + parser.add_argument('-m', type=int, default=50) + parser.add_argument('-r', type=int, default=10) + + opt = parser.parse_args() + + # fix + opt.asr_wav = opt.wav + opt.asr_play = opt.play + opt.asr_model = opt.model + opt.asr_save_feats = opt.save_feats + + if 'deepspeech' in opt.asr_model: + raise ValueError("DeepSpeech features should not use this code to extract...") + + with ASR(opt) as asr: + asr.run() \ No newline at end of file diff --git a/encoding.py b/encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..17b1d7355d0be31b2bfc72b9f1de4adea0e095cb --- /dev/null +++ b/encoding.py @@ -0,0 +1,78 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class FreqEncoder(nn.Module): + def __init__(self, input_dim, max_freq_log2, N_freqs, + log_sampling=True, include_input=True, + periodic_fns=(torch.sin, torch.cos)): + + super().__init__() + + self.input_dim = input_dim + self.include_input = include_input + self.periodic_fns = periodic_fns + + self.output_dim = 0 + if self.include_input: + self.output_dim += self.input_dim + + self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns) + + if log_sampling: + self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs) + else: + self.freq_bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, N_freqs) + + self.freq_bands = self.freq_bands.numpy().tolist() + + def forward(self, input, **kwargs): + + out = [] + if self.include_input: + out.append(input) + + for i in range(len(self.freq_bands)): + freq = self.freq_bands[i] + for p_fn in self.periodic_fns: + out.append(p_fn(input * freq)) + + out = torch.cat(out, dim=-1) + + + return out + +def get_encoder(encoding, input_dim=3, + multires=6, + degree=4, + num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, + **kwargs): + + if encoding == 'None': + return lambda x, **kwargs: x, input_dim + + elif encoding == 'frequency': + #encoder = FreqEncoder(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True) + from freqencoder import FreqEncoder + encoder = FreqEncoder(input_dim=input_dim, degree=multires) + + elif encoding == 'sphere_harmonics': + from shencoder import SHEncoder + encoder = SHEncoder(input_dim=input_dim, degree=degree) + + elif encoding == 'hashgrid': + from gridencoder import GridEncoder + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners) + + elif encoding == 'tiledgrid': + from gridencoder import GridEncoder + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners) + + elif encoding == 'ash': + from ashencoder import AshEncoder + encoder = AshEncoder(input_dim=input_dim, output_dim=16, log2_hashmap_size=log2_hashmap_size, resolution=desired_resolution) + + else: + raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]') + + return encoder, encoder.output_dim \ No newline at end of file diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..8a213d09b5343815a4c60ef7550d130321018a2b --- /dev/null +++ b/environment.yml @@ -0,0 +1,45 @@ +name: talking_gaussian +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - cudatoolkit=11.3 + - plyfile=0.8.1 + - python=3.7.13 + - pip=22.3.1 + - pytorch=1.12.1 + - torchaudio=0.12.1 + - torchvision=0.13.1 + - tqdm + - ffmpeg + - openh264 + - pip: + - ./submodules/diff-gaussian-rasterization + - ./submodules/simple-knn + - ./gridencoder + - numpy + - pillow + - scipy + - tensorboard + - opencv-python + - tensorboardX + + - pandas + - tqdm + - matplotlib + - PyMCubes==0.1.4 + - rich + - packaging + - scikit-learn + + - face_alignment + - python_speech_features + - numba + - resampy + - pyaudio + - soundfile + - configargparse + + - lpips + - imageio-ffmpeg diff --git a/gaussian-splatting/.gitignore b/gaussian-splatting/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..114376106621fc03ff8f923c03536e9dafd0d3f5 --- /dev/null +++ b/gaussian-splatting/.gitignore @@ -0,0 +1,8 @@ +*.pyc +.vscode +output +build +diff_rasterization/diff_rast.egg-info +diff_rasterization/dist +tensorboard_3d +screenshots \ No newline at end of file diff --git a/gaussian-splatting/.gitmodules b/gaussian-splatting/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..d20bef20b74670643e1f7849848528cc41000160 --- /dev/null +++ b/gaussian-splatting/.gitmodules @@ -0,0 +1,9 @@ +[submodule "submodules/simple-knn"] + path = submodules/simple-knn + url = https://gitlab.inria.fr/bkerbl/simple-knn.git +[submodule "submodules/diff-gaussian-rasterization"] + path = submodules/diff-gaussian-rasterization + url = https://github.com/graphdeco-inria/diff-gaussian-rasterization +[submodule "SIBR_viewers"] + path = SIBR_viewers + url = https://gitlab.inria.fr/sibr/sibr_core.git diff --git a/gaussian-splatting/LICENSE.md b/gaussian-splatting/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..18445c6d34aedbf1ab9d282223f8f10ce38cd79a --- /dev/null +++ b/gaussian-splatting/LICENSE.md @@ -0,0 +1,91 @@ +Gaussian-Splatting License +=========================== + +**Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. +The *Software* is in the process of being registered with the Agence pour la Protection des +Programmes (APP). + +The *Software* is still being developed by the *Licensor*. + +*Licensor*'s goal is to allow the research community to use, test and evaluate +the *Software*. + +## 1. Definitions + +*Licensee* means any person or entity that uses the *Software* and distributes +its *Work*. + +*Licensor* means the owners of the *Software*, i.e Inria and MPII + +*Software* means the original work of authorship made available under this +License ie gaussian-splatting. + +*Work* means the *Software* and any additions to or derivative works of the +*Software* that are made available under this License. + + +## 2. Purpose +This license is intended to define the rights granted to the *Licensee* by +Licensors under the *Software*. + +## 3. Rights granted + +For the above reasons Licensors have decided to distribute the *Software*. +Licensors grant non-exclusive rights to use the *Software* for research purposes +to research users (both academic and industrial), free of charge, without right +to sublicense.. The *Software* may be used "non-commercially", i.e., for research +and/or evaluation purposes only. + +Subject to the terms and conditions of this License, you are granted a +non-exclusive, royalty-free, license to reproduce, prepare derivative works of, +publicly display, publicly perform and distribute its *Work* and any resulting +derivative works in any form. + +## 4. Limitations + +**4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do +so under this License, (b) you include a complete copy of this License with +your distribution, and (c) you retain without modification any copyright, +patent, trademark, or attribution notices that are present in the *Work*. + +**4.2 Derivative Works.** You may specify that additional or different terms apply +to the use, reproduction, and distribution of your derivative works of the *Work* +("Your Terms") only if (a) Your Terms provide that the use limitation in +Section 2 applies to your derivative works, and (b) you identify the specific +derivative works that are subject to Your Terms. Notwithstanding Your Terms, +this License (including the redistribution requirements in Section 3.1) will +continue to apply to the *Work* itself. + +**4.3** Any other use without of prior consent of Licensors is prohibited. Research +users explicitly acknowledge having received from Licensors all information +allowing to appreciate the adequacy between of the *Software* and their needs and +to undertake all necessary precautions for its execution and use. + +**4.4** The *Software* is provided both as a compiled library file and as source +code. In case of using the *Software* for a publication or other results obtained +through the use of the *Software*, users are strongly encouraged to cite the +corresponding publications as explained in the documentation of the *Software*. + +## 5. Disclaimer + +THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES +WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY +UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL +CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES +OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL +USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR +ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE +AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE +GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) +HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR +IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. + +## 6. Files subject to permissive licenses +The contents of the file ```utils/loss_utils.py``` are based on publicly available code authored by Evan Su, which falls under the permissive MIT license. + +Title: pytorch-ssim\ +Project code: https://github.com/Po-Hsun-Su/pytorch-ssim\ +Copyright Evan Su, 2017\ +License: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/LICENSE.txt (MIT) \ No newline at end of file diff --git a/gaussian-splatting/README.md b/gaussian-splatting/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4cbd3326db4ec7fca26ce3bce104ce7b93d1cca8 --- /dev/null +++ b/gaussian-splatting/README.md @@ -0,0 +1,522 @@ +# 3D Gaussian Splatting for Real-Time Radiance Field Rendering +Bernhard Kerbl*, Georgios Kopanas*, Thomas Leimkühler, George Drettakis (* indicates equal contribution)
+| [Webpage](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/) | [Full Paper](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/3d_gaussian_splatting_high.pdf) | [Video](https://youtu.be/T_kXY43VZnk) | [Other GRAPHDECO Publications](http://www-sop.inria.fr/reves/publis/gdindex.php) | [FUNGRAPH project page](https://fungraph.inria.fr) |
+| [T&T+DB COLMAP (650MB)](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/datasets/input/tandt_db.zip) | [Pre-trained Models (14 GB)](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/datasets/pretrained/models.zip) | [Viewers for Windows (60MB)](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/binaries/viewers.zip) | [Evaluation Images (7 GB)](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/evaluation/images.zip) |
+![Teaser image](assets/teaser.png) + +This repository contains the official authors implementation associated with the paper "3D Gaussian Splatting for Real-Time Radiance Field Rendering", which can be found [here](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/). We further provide the reference images used to create the error metrics reported in the paper, as well as recently created, pre-trained models. + + + + + + +Abstract: *Radiance Field methods have recently revolutionized novel-view synthesis of scenes captured with multiple photos or videos. However, achieving high visual quality still requires neural networks that are costly to train and render, while recent faster methods inevitably trade off speed for quality. For unbounded and complete scenes (rather than isolated objects) and 1080p resolution rendering, no current method can achieve real-time display rates. We introduce three key elements that allow us to achieve state-of-the-art visual quality while maintaining competitive training times and importantly allow high-quality real-time (≥ 30 fps) novel-view synthesis at 1080p resolution. First, starting from sparse points produced during camera calibration, we represent the scene with 3D Gaussians that preserve desirable properties of continuous volumetric radiance fields for scene optimization while avoiding unnecessary computation in empty space; Second, we perform interleaved optimization/density control of the 3D Gaussians, notably optimizing anisotropic covariance to achieve an accurate representation of the scene; Third, we develop a fast visibility-aware rendering algorithm that supports anisotropic splatting and both accelerates training and allows realtime rendering. We demonstrate state-of-the-art visual quality and real-time rendering on several established datasets.* + +
+
+

BibTeX

+
@Article{kerbl3Dgaussians,
+      author       = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
+      title        = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
+      journal      = {ACM Transactions on Graphics},
+      number       = {4},
+      volume       = {42},
+      month        = {July},
+      year         = {2023},
+      url          = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
+}
+
+
+ + +## Funding and Acknowledgments + +This research was funded by the ERC Advanced grant FUNGRAPH No 788065. The authors are grateful to Adobe for generous donations, the OPAL infrastructure from Université Côte d’Azur and for the HPC resources from GENCI–IDRIS (Grant 2022-AD011013409). The authors thank the anonymous reviewers for their valuable feedback, P. Hedman and A. Tewari for proofreading earlier drafts also T. Müller, A. Yu and S. Fridovich-Keil for helping with the comparisons. + +## Step-by-step Tutorial + +Jonathan Stephens made a fantastic step-by-step tutorial for setting up Gaussian Splatting on your machine, along with instructions for creating usable datasets from videos. If the instructions below are too dry for you, go ahead and check it out [here](https://www.youtube.com/watch?v=UXtuigy_wYc). + +## Colab + +User [camenduru](https://github.com/camenduru) was kind enough to provide a Colab template that uses this repo's source (status: August 2023!) for quick and easy access to the method. Please check it out [here](https://github.com/camenduru/gaussian-splatting-colab). + +## Cloning the Repository + +The repository contains submodules, thus please check it out with +```shell +# SSH +git clone git@github.com:graphdeco-inria/gaussian-splatting.git --recursive +``` +or +```shell +# HTTPS +git clone https://github.com/graphdeco-inria/gaussian-splatting --recursive +``` + +## Overview + +The codebase has 4 main components: +- A PyTorch-based optimizer to produce a 3D Gaussian model from SfM inputs +- A network viewer that allows to connect to and visualize the optimization process +- An OpenGL-based real-time viewer to render trained models in real-time. +- A script to help you turn your own images into optimization-ready SfM data sets + +The components have different requirements w.r.t. both hardware and software. They have been tested on Windows 10 and Ubuntu Linux 22.04. Instructions for setting up and running each of them are found in the sections below. + +## New features [Please check regularly!] + +We will be adding several new features soon. In the meantime Orange has kindly added [OpenXR support](#openXR-support) for VR viewing. Please come back soon, we will be adding other features, building among others on recent 3DGS followup papers. + +## Optimizer + +The optimizer uses PyTorch and CUDA extensions in a Python environment to produce trained models. + +### Hardware Requirements + +- CUDA-ready GPU with Compute Capability 7.0+ +- 24 GB VRAM (to train to paper evaluation quality) +- Please see FAQ for smaller VRAM configurations + +### Software Requirements +- Conda (recommended for easy setup) +- C++ Compiler for PyTorch extensions (we used Visual Studio 2019 for Windows) +- CUDA SDK 11 for PyTorch extensions, install *after* Visual Studio (we used 11.8, **known issues with 11.6**) +- C++ Compiler and CUDA SDK must be compatible + +### Setup + +#### Local Setup + +Our default, provided install method is based on Conda package and environment management: +```shell +SET DISTUTILS_USE_SDK=1 # Windows only +conda env create --file environment.yml +conda activate gaussian_splatting +``` +Please note that this process assumes that you have CUDA SDK **11** installed, not **12**. For modifications, see below. + +Tip: Downloading packages and creating a new environment with Conda can require a significant amount of disk space. By default, Conda will use the main system hard drive. You can avoid this by specifying a different package download location and an environment on a different drive: + +```shell +conda config --add pkgs_dirs / +conda env create --file environment.yml --prefix //gaussian_splatting +conda activate //gaussian_splatting +``` + +#### Modifications + +If you can afford the disk space, we recommend using our environment files for setting up a training environment identical to ours. If you want to make modifications, please note that major version changes might affect the results of our method. However, our (limited) experiments suggest that the codebase works just fine inside a more up-to-date environment (Python 3.8, PyTorch 2.0.0, CUDA 12). Make sure to create an environment where PyTorch and its CUDA runtime version match and the installed CUDA SDK has no major version difference with PyTorch's CUDA version. + +#### Known Issues + +Some users experience problems building the submodules on Windows (```cl.exe: File not found``` or similar). Please consider the workaround for this problem from the FAQ. + +### Running + +To run the optimizer, simply use + +```shell +python train.py -s +``` + +
+Command Line Arguments for train.py + + #### --source_path / -s + Path to the source directory containing a COLMAP or Synthetic NeRF data set. + #### --model_path / -m + Path where the trained model should be stored (```output/``` by default). + #### --images / -i + Alternative subdirectory for COLMAP images (```images``` by default). + #### --eval + Add this flag to use a MipNeRF360-style training/test split for evaluation. + #### --resolution / -r + Specifies resolution of the loaded images before training. If provided ```1, 2, 4``` or ```8```, uses original, 1/2, 1/4 or 1/8 resolution, respectively. For all other values, rescales the width to the given number while maintaining image aspect. **If not set and input image width exceeds 1.6K pixels, inputs are automatically rescaled to this target.** + #### --data_device + Specifies where to put the source image data, ```cuda``` by default, recommended to use ```cpu``` if training on large/high-resolution dataset, will reduce VRAM consumption, but slightly slow down training. Thanks to [HrsPythonix](https://github.com/HrsPythonix). + #### --white_background / -w + Add this flag to use white background instead of black (default), e.g., for evaluation of NeRF Synthetic dataset. + #### --sh_degree + Order of spherical harmonics to be used (no larger than 3). ```3``` by default. + #### --convert_SHs_python + Flag to make pipeline compute forward and backward of SHs with PyTorch instead of ours. + #### --convert_cov3D_python + Flag to make pipeline compute forward and backward of the 3D covariance with PyTorch instead of ours. + #### --debug + Enables debug mode if you experience erros. If the rasterizer fails, a ```dump``` file is created that you may forward to us in an issue so we can take a look. + #### --debug_from + Debugging is **slow**. You may specify an iteration (starting from 0) after which the above debugging becomes active. + #### --iterations + Number of total iterations to train for, ```30_000``` by default. + #### --ip + IP to start GUI server on, ```127.0.0.1``` by default. + #### --port + Port to use for GUI server, ```6009``` by default. + #### --test_iterations + Space-separated iterations at which the training script computes L1 and PSNR over test set, ```7000 30000``` by default. + #### --save_iterations + Space-separated iterations at which the training script saves the Gaussian model, ```7000 30000 ``` by default. + #### --checkpoint_iterations + Space-separated iterations at which to store a checkpoint for continuing later, saved in the model directory. + #### --start_checkpoint + Path to a saved checkpoint to continue training from. + #### --quiet + Flag to omit any text written to standard out pipe. + #### --feature_lr + Spherical harmonics features learning rate, ```0.0025``` by default. + #### --opacity_lr + Opacity learning rate, ```0.05``` by default. + #### --scaling_lr + Scaling learning rate, ```0.005``` by default. + #### --rotation_lr + Rotation learning rate, ```0.001``` by default. + #### --position_lr_max_steps + Number of steps (from 0) where position learning rate goes from ```initial``` to ```final```. ```30_000``` by default. + #### --position_lr_init + Initial 3D position learning rate, ```0.00016``` by default. + #### --position_lr_final + Final 3D position learning rate, ```0.0000016``` by default. + #### --position_lr_delay_mult + Position learning rate multiplier (cf. Plenoxels), ```0.01``` by default. + #### --densify_from_iter + Iteration where densification starts, ```500``` by default. + #### --densify_until_iter + Iteration where densification stops, ```15_000``` by default. + #### --densify_grad_threshold + Limit that decides if points should be densified based on 2D position gradient, ```0.0002``` by default. + #### --densification_interval + How frequently to densify, ```100``` (every 100 iterations) by default. + #### --opacity_reset_interval + How frequently to reset opacity, ```3_000``` by default. + #### --lambda_dssim + Influence of SSIM on total loss from 0 to 1, ```0.2``` by default. + #### --percent_dense + Percentage of scene extent (0--1) a point must exceed to be forcibly densified, ```0.01``` by default. + +
+
+ +Note that similar to MipNeRF360, we target images at resolutions in the 1-1.6K pixel range. For convenience, arbitrary-size inputs can be passed and will be automatically resized if their width exceeds 1600 pixels. We recommend to keep this behavior, but you may force training to use your higher-resolution images by setting ```-r 1```. + +The MipNeRF360 scenes are hosted by the paper authors [here](https://jonbarron.info/mipnerf360/). You can find our SfM data sets for Tanks&Temples and Deep Blending [here](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/datasets/input/tandt_db.zip). If you do not provide an output model directory (```-m```), trained models are written to folders with randomized unique names inside the ```output``` directory. At this point, the trained models may be viewed with the real-time viewer (see further below). + +### Evaluation +By default, the trained models use all available images in the dataset. To train them while withholding a test set for evaluation, use the ```--eval``` flag. This way, you can render training/test sets and produce error metrics as follows: +```shell +python train.py -s --eval # Train with train/test split +python render.py -m # Generate renderings +python metrics.py -m # Compute error metrics on renderings +``` + +If you want to evaluate our [pre-trained models](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/datasets/pretrained/models.zip), you will have to download the corresponding source data sets and indicate their location to ```render.py``` with an additional ```--source_path/-s``` flag. Note: The pre-trained models were created with the release codebase. This code base has been cleaned up and includes bugfixes, hence the metrics you get from evaluating them will differ from those in the paper. +```shell +python render.py -m -s +python metrics.py -m +``` + +
+Command Line Arguments for render.py + + #### --model_path / -m + Path to the trained model directory you want to create renderings for. + #### --skip_train + Flag to skip rendering the training set. + #### --skip_test + Flag to skip rendering the test set. + #### --quiet + Flag to omit any text written to standard out pipe. + + **The below parameters will be read automatically from the model path, based on what was used for training. However, you may override them by providing them explicitly on the command line.** + + #### --source_path / -s + Path to the source directory containing a COLMAP or Synthetic NeRF data set. + #### --images / -i + Alternative subdirectory for COLMAP images (```images``` by default). + #### --eval + Add this flag to use a MipNeRF360-style training/test split for evaluation. + #### --resolution / -r + Changes the resolution of the loaded images before training. If provided ```1, 2, 4``` or ```8```, uses original, 1/2, 1/4 or 1/8 resolution, respectively. For all other values, rescales the width to the given number while maintaining image aspect. ```1``` by default. + #### --white_background / -w + Add this flag to use white background instead of black (default), e.g., for evaluation of NeRF Synthetic dataset. + #### --convert_SHs_python + Flag to make pipeline render with computed SHs from PyTorch instead of ours. + #### --convert_cov3D_python + Flag to make pipeline render with computed 3D covariance from PyTorch instead of ours. + +
+ +
+Command Line Arguments for metrics.py + + #### --model_paths / -m + Space-separated list of model paths for which metrics should be computed. +
+
+ +We further provide the ```full_eval.py``` script. This script specifies the routine used in our evaluation and demonstrates the use of some additional parameters, e.g., ```--images (-i)``` to define alternative image directories within COLMAP data sets. If you have downloaded and extracted all the training data, you can run it like this: +```shell +python full_eval.py -m360 -tat -db +``` +In the current version, this process takes about 7h on our reference machine containing an A6000. If you want to do the full evaluation on our pre-trained models, you can specify their download location and skip training. +```shell +python full_eval.py -o --skip_training -m360 -tat -db +``` + +If you want to compute the metrics on our paper's [evaluation images](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/evaluation/images.zip), you can also skip rendering. In this case it is not necessary to provide the source datasets. You can compute metrics for multiple image sets at a time. +```shell +python full_eval.py -m /garden ... --skip_training --skip_rendering +``` + +
+Command Line Arguments for full_eval.py + + #### --skip_training + Flag to skip training stage. + #### --skip_rendering + Flag to skip rendering stage. + #### --skip_metrics + Flag to skip metrics calculation stage. + #### --output_path + Directory to put renderings and results in, ```./eval``` by default, set to pre-trained model location if evaluating them. + #### --mipnerf360 / -m360 + Path to MipNeRF360 source datasets, required if training or rendering. + #### --tanksandtemples / -tat + Path to Tanks&Temples source datasets, required if training or rendering. + #### --deepblending / -db + Path to Deep Blending source datasets, required if training or rendering. +
+
+ +## Interactive Viewers +We provide two interactive viewers for our method: remote and real-time. Our viewing solutions are based on the [SIBR](https://sibr.gitlabpages.inria.fr/) framework, developed by the GRAPHDECO group for several novel-view synthesis projects. + +### Hardware Requirements +- OpenGL 4.5-ready GPU and drivers (or latest MESA software) +- 4 GB VRAM recommended +- CUDA-ready GPU with Compute Capability 7.0+ (only for Real-Time Viewer) + +### Software Requirements +- Visual Studio or g++, **not Clang** (we used Visual Studio 2019 for Windows) +- CUDA SDK 11, install *after* Visual Studio (we used 11.8) +- CMake (recent version, we used 3.24) +- 7zip (only on Windows) + +### Pre-built Windows Binaries +We provide pre-built binaries for Windows [here](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/binaries/viewers.zip). We recommend using them on Windows for an efficient setup, since the building of SIBR involves several external dependencies that must be downloaded and compiled on-the-fly. + +### Installation from Source +If you cloned with submodules (e.g., using ```--recursive```), the source code for the viewers is found in ```SIBR_viewers```. The network viewer runs within the SIBR framework for Image-based Rendering applications. + +#### Windows +CMake should take care of your dependencies. +```shell +cd SIBR_viewers +cmake -Bbuild . +cmake --build build --target install --config RelWithDebInfo +``` +You may specify a different configuration, e.g. ```Debug``` if you need more control during development. + +#### Ubuntu 22.04 +You will need to install a few dependencies before running the project setup. +```shell +# Dependencies +sudo apt install -y libglew-dev libassimp-dev libboost-all-dev libgtk-3-dev libopencv-dev libglfw3-dev libavdevice-dev libavcodec-dev libeigen3-dev libxxf86vm-dev libembree-dev +# Project setup +cd SIBR_viewers +cmake -Bbuild . -DCMAKE_BUILD_TYPE=Release # add -G Ninja to build faster +cmake --build build -j24 --target install +``` + +#### Ubuntu 20.04 +Backwards compatibility with Focal Fossa is not fully tested, but building SIBR with CMake should still work after invoking +```shell +git checkout fossa_compatibility +``` + +### Navigation in SIBR Viewers +The SIBR interface provides several methods of navigating the scene. By default, you will be started with an FPS navigator, which you can control with ```W, A, S, D, Q, E``` for camera translation and ```I, K, J, L, U, O``` for rotation. Alternatively, you may want to use a Trackball-style navigator (select from the floating menu). You can also snap to a camera from the data set with the ```Snap to``` button or find the closest camera with ```Snap to closest```. The floating menues also allow you to change the navigation speed. You can use the ```Scaling Modifier``` to control the size of the displayed Gaussians, or show the initial point cloud. + +### Running the Network Viewer + + + +https://github.com/graphdeco-inria/gaussian-splatting/assets/40643808/90a2e4d3-cf2e-4633-b35f-bfe284e28ff7 + + + +After extracting or installing the viewers, you may run the compiled ```SIBR_remoteGaussian_app[_config]``` app in ```/bin```, e.g.: +```shell +.//bin/SIBR_remoteGaussian_app +``` +The network viewer allows you to connect to a running training process on the same or a different machine. If you are training on the same machine and OS, no command line parameters should be required: the optimizer communicates the location of the training data to the network viewer. By default, optimizer and network viewer will try to establish a connection on **localhost** on port **6009**. You can change this behavior by providing matching ```--ip``` and ```--port``` parameters to both the optimizer and the network viewer. If for some reason the path used by the optimizer to find the training data is not reachable by the network viewer (e.g., due to them running on different (virtual) machines), you may specify an override location to the viewer by using ```-s ```. + +
+Primary Command Line Arguments for Network Viewer + + #### --path / -s + Argument to override model's path to source dataset. + #### --ip + IP to use for connection to a running training script. + #### --port + Port to use for connection to a running training script. + #### --rendering-size + Takes two space separated numbers to define the resolution at which network rendering occurs, ```1200``` width by default. + Note that to enforce an aspect that differs from the input images, you need ```--force-aspect-ratio``` too. + #### --load_images + Flag to load source dataset images to be displayed in the top view for each camera. +
+
+ +### Running the Real-Time Viewer + + + + +https://github.com/graphdeco-inria/gaussian-splatting/assets/40643808/0940547f-1d82-4c2f-a616-44eabbf0f816 + + + + +After extracting or installing the viewers, you may run the compiled ```SIBR_gaussianViewer_app[_config]``` app in ```/bin```, e.g.: +```shell +.//bin/SIBR_gaussianViewer_app -m +``` + +It should suffice to provide the ```-m``` parameter pointing to a trained model directory. Alternatively, you can specify an override location for training input data using ```-s```. To use a specific resolution other than the auto-chosen one, specify ```--rendering-size ```. Combine it with ```--force-aspect-ratio``` if you want the exact resolution and don't mind image distortion. + +**To unlock the full frame rate, please disable V-Sync on your machine and also in the application (Menu → Display). In a multi-GPU system (e.g., laptop) your OpenGL/Display GPU should be the same as your CUDA GPU (e.g., by setting the application's GPU preference on Windows, see below) for maximum performance.** + +![Teaser image](assets/select.png) + +In addition to the initial point cloud and the splats, you also have the option to visualize the Gaussians by rendering them as ellipsoids from the floating menu. +SIBR has many other functionalities, please see the [documentation](https://sibr.gitlabpages.inria.fr/) for more details on the viewer, navigation options etc. There is also a Top View (available from the menu) that shows the placement of the input cameras and the original SfM point cloud; please note that Top View slows rendering when enabled. The real-time viewer also uses slightly more aggressive, fast culling, which can be toggled in the floating menu. If you ever encounter an issue that can be solved by turning fast culling off, please let us know. + +
+Primary Command Line Arguments for Real-Time Viewer + + #### --model-path / -m + Path to trained model. + #### --iteration + Specifies which of state to load if multiple are available. Defaults to latest available iteration. + #### --path / -s + Argument to override model's path to source dataset. + #### --rendering-size + Takes two space separated numbers to define the resolution at which real-time rendering occurs, ```1200``` width by default. Note that to enforce an aspect that differs from the input images, you need ```--force-aspect-ratio``` too. + #### --load_images + Flag to load source dataset images to be displayed in the top view for each camera. + #### --device + Index of CUDA device to use for rasterization if multiple are available, ```0``` by default. + #### --no_interop + Disables CUDA/GL interop forcibly. Use on systems that may not behave according to spec (e.g., WSL2 with MESA GL 4.5 software rendering). +
+
+ +## Processing your own Scenes + +Our COLMAP loaders expect the following dataset structure in the source path location: + +``` + +|---images +| |--- +| |--- +| |---... +|---sparse + |---0 + |---cameras.bin + |---images.bin + |---points3D.bin +``` + +For rasterization, the camera models must be either a SIMPLE_PINHOLE or PINHOLE camera. We provide a converter script ```convert.py```, to extract undistorted images and SfM information from input images. Optionally, you can use ImageMagick to resize the undistorted images. This rescaling is similar to MipNeRF360, i.e., it creates images with 1/2, 1/4 and 1/8 the original resolution in corresponding folders. To use them, please first install a recent version of COLMAP (ideally CUDA-powered) and ImageMagick. Put the images you want to use in a directory ```/input```. +``` + +|---input + |--- + |--- + |---... +``` + If you have COLMAP and ImageMagick on your system path, you can simply run +```shell +python convert.py -s [--resize] #If not resizing, ImageMagick is not needed +``` +Alternatively, you can use the optional parameters ```--colmap_executable``` and ```--magick_executable``` to point to the respective paths. Please note that on Windows, the executable should point to the COLMAP ```.bat``` file that takes care of setting the execution environment. Once done, `````` will contain the expected COLMAP data set structure with undistorted, resized input images, in addition to your original images and some temporary (distorted) data in the directory ```distorted```. + +If you have your own COLMAP dataset without undistortion (e.g., using ```OPENCV``` camera), you can try to just run the last part of the script: Put the images in ```input``` and the COLMAP info in a subdirectory ```distorted```: +``` + +|---input +| |--- +| |--- +| |---... +|---distorted + |---database.db + |---sparse + |---0 + |---... +``` +Then run +```shell +python convert.py -s --skip_matching [--resize] #If not resizing, ImageMagick is not needed +``` + +
+Command Line Arguments for convert.py + + #### --no_gpu + Flag to avoid using GPU in COLMAP. + #### --skip_matching + Flag to indicate that COLMAP info is available for images. + #### --source_path / -s + Location of the inputs. + #### --camera + Which camera model to use for the early matching steps, ```OPENCV``` by default. + #### --resize + Flag for creating resized versions of input images. + #### --colmap_executable + Path to the COLMAP executable (```.bat``` on Windows). + #### --magick_executable + Path to the ImageMagick executable. +
+
+ +### OpenXR support + +OpenXR is supported in the branch gaussian_code_release_openxr +Within that branch, you can find documentation for VR support [here](https://gitlab.inria.fr/sibr/sibr_core/-/tree/gaussian_code_release_openxr?ref_type=heads). + +## FAQ +- *Where do I get data sets, e.g., those referenced in ```full_eval.py```?* The MipNeRF360 data set is provided by the authors of the original paper on the project site. Note that two of the data sets cannot be openly shared and require you to consult the authors directly. For Tanks&Temples and Deep Blending, please use the download links provided at the top of the page. Alternatively, you may access the cloned data (status: August 2023!) from [HuggingFace](https://huggingface.co/camenduru/gaussian-splatting) + + +- *How can I use this for a much larger dataset, like a city district?* The current method was not designed for these, but given enough memory, it should work out. However, the approach can struggle in multi-scale detail scenes (extreme close-ups, mixed with far-away shots). This is usually the case in, e.g., driving data sets (cars close up, buildings far away). For such scenes, you can lower the ```--position_lr_init```, ```--position_lr_final``` and ```--scaling_lr``` (x0.3, x0.1, ...). The more extensive the scene, the lower these values should be. Below, we use default learning rates (left) and ```--position_lr_init 0.000016 --scaling_lr 0.001"``` (right). + +| ![Default learning rate result](assets/worse.png "title-1") | ![Reduced learning rate result](assets/better.png "title-2") | +| --- | --- | + +- *I'm on Windows and I can't manage to build the submodules, what do I do?* Consider following the steps in the excellent video tutorial [here](https://www.youtube.com/watch?v=UXtuigy_wYc), hopefully they should help. The order in which the steps are done is important! Alternatively, consider using the linked Colab template. + +- *It still doesn't work. It says something about ```cl.exe```. What do I do?* User Henry Pearce found a workaround. You can you try adding the visual studio path to your environment variables (your version number might differ); +```C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.29.30133\bin\Hostx64\x64``` +Then make sure you start a new conda prompt and cd to your repo location and try this; +``` +conda activate gaussian_splatting +cd /gaussian-splatting +pip install submodules\diff-gaussian-rasterization +pip install submodules\simple-knn +``` + +- *I'm on macOS/Puppy Linux/Greenhat and I can't manage to build, what do I do?* Sorry, we can't provide support for platforms outside of the ones we list in this README. Consider using the linked Colab template. + +- *I don't have 24 GB of VRAM for training, what do I do?* The VRAM consumption is determined by the number of points that are being optimized, which increases over time. If you only want to train to 7k iterations, you will need significantly less. To do the full training routine and avoid running out of memory, you can increase the ```--densify_grad_threshold```, ```--densification_interval``` or reduce the value of ```--densify_until_iter```. Note however that this will affect the quality of the result. Also try setting ```--test_iterations``` to ```-1``` to avoid memory spikes during testing. If ```--densify_grad_threshold``` is very high, no densification should occur and training should complete if the scene itself loads successfully. + +- *24 GB of VRAM for reference quality training is still a lot! Can't we do it with less?* Yes, most likely. By our calculations it should be possible with **way** less memory (~8GB). If we can find the time we will try to achieve this. If some PyTorch veteran out there wants to tackle this, we look forward to your pull request! + + +- *How can I use the differentiable Gaussian rasterizer for my own project?* Easy, it is included in this repo as a submodule ```diff-gaussian-rasterization```. Feel free to check out and install the package. It's not really documented, but using it from the Python side is very straightforward (cf. ```gaussian_renderer/__init__.py```). + +- *Wait, but `````` isn't optimized and could be much better?* There are several parts we didn't even have time to think about improving (yet). The performance you get with this prototype is probably a rather slow baseline for what is physically possible. + +- *Something is broken, how did this happen?* We tried hard to provide a solid and comprehensible basis to make use of the paper's method. We have refactored the code quite a bit, but we have limited capacity to test all possible usage scenarios. Thus, if part of the website, the code or the performance is lacking, please create an issue. If we find the time, we will do our best to address it. diff --git a/gaussian-splatting/arguments/__init__.py b/gaussian-splatting/arguments/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1e13a551ebe455cf43f1d8785c75ac659e7d4b2d --- /dev/null +++ b/gaussian-splatting/arguments/__init__.py @@ -0,0 +1,112 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from argparse import ArgumentParser, Namespace +import sys +import os + +class GroupParams: + pass + +class ParamGroup: + def __init__(self, parser: ArgumentParser, name : str, fill_none = False): + group = parser.add_argument_group(name) + for key, value in vars(self).items(): + shorthand = False + if key.startswith("_"): + shorthand = True + key = key[1:] + t = type(value) + value = value if not fill_none else None + if shorthand: + if t == bool: + group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") + else: + group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) + else: + if t == bool: + group.add_argument("--" + key, default=value, action="store_true") + else: + group.add_argument("--" + key, default=value, type=t) + + def extract(self, args): + group = GroupParams() + for arg in vars(args).items(): + if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): + setattr(group, arg[0], arg[1]) + return group + +class ModelParams(ParamGroup): + def __init__(self, parser, sentinel=False): + self.sh_degree = 3 + self._source_path = "" + self._model_path = "" + self._images = "images" + self._resolution = -1 + self._white_background = False + self.data_device = "cuda" + self.eval = False + super().__init__(parser, "Loading Parameters", sentinel) + + def extract(self, args): + g = super().extract(args) + g.source_path = os.path.abspath(g.source_path) + return g + +class PipelineParams(ParamGroup): + def __init__(self, parser): + self.convert_SHs_python = False + self.compute_cov3D_python = False + self.debug = False + super().__init__(parser, "Pipeline Parameters") + +class OptimizationParams(ParamGroup): + def __init__(self, parser): + self.iterations = 30_000 + self.position_lr_init = 0.00016 + self.position_lr_final = 0.0000016 + self.position_lr_delay_mult = 0.01 + self.position_lr_max_steps = 30_000 + self.feature_lr = 0.0025 + self.opacity_lr = 0.05 + self.scaling_lr = 0.005 + self.rotation_lr = 0.001 + self.percent_dense = 0.01 + self.lambda_dssim = 0.2 + self.densification_interval = 100 + self.opacity_reset_interval = 3000 + self.densify_from_iter = 500 + self.densify_until_iter = 15_000 + self.densify_grad_threshold = 0.0002 + self.random_background = False + super().__init__(parser, "Optimization Parameters") + +def get_combined_args(parser : ArgumentParser): + cmdlne_string = sys.argv[1:] + cfgfile_string = "Namespace()" + args_cmdline = parser.parse_args(cmdlne_string) + + try: + cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") + print("Looking for config file in", cfgfilepath) + with open(cfgfilepath) as cfg_file: + print("Config file found: {}".format(cfgfilepath)) + cfgfile_string = cfg_file.read() + except TypeError: + print("Config file not found at") + pass + args_cfgfile = eval(cfgfile_string) + + merged_dict = vars(args_cfgfile).copy() + for k,v in vars(args_cmdline).items(): + if v != None: + merged_dict[k] = v + return Namespace(**merged_dict) diff --git a/gaussian-splatting/assets/better.png b/gaussian-splatting/assets/better.png new file mode 100644 index 0000000000000000000000000000000000000000..019c143829caa4e182c055307dd43c013329f88c Binary files /dev/null and b/gaussian-splatting/assets/better.png differ diff --git a/gaussian-splatting/assets/logo_graphdeco.png b/gaussian-splatting/assets/logo_graphdeco.png new file mode 100644 index 0000000000000000000000000000000000000000..4818ac47bdb1568c333857578d6a1fbc8038ee02 Binary files /dev/null and b/gaussian-splatting/assets/logo_graphdeco.png differ diff --git a/gaussian-splatting/assets/logo_inria.png b/gaussian-splatting/assets/logo_inria.png new file mode 100644 index 0000000000000000000000000000000000000000..f395b7a72aae9724195336429e8daf233c93576e Binary files /dev/null and b/gaussian-splatting/assets/logo_inria.png differ diff --git a/gaussian-splatting/assets/logo_mpi.png b/gaussian-splatting/assets/logo_mpi.png new file mode 100644 index 0000000000000000000000000000000000000000..2282e7f18597f5f57be85cd8d34332c34cc83a65 Binary files /dev/null and b/gaussian-splatting/assets/logo_mpi.png differ diff --git a/gaussian-splatting/assets/logo_mpi.svg b/gaussian-splatting/assets/logo_mpi.svg new file mode 100644 index 0000000000000000000000000000000000000000..6cb3a0069891a3cb4e9c8f8ece200891356840b9 --- /dev/null +++ b/gaussian-splatting/assets/logo_mpi.svg @@ -0,0 +1,488 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/gaussian-splatting/assets/logo_uca.png b/gaussian-splatting/assets/logo_uca.png new file mode 100644 index 0000000000000000000000000000000000000000..e7f1a6f0aa3bd1ce150b333ec5797a63b71db6e5 Binary files /dev/null and b/gaussian-splatting/assets/logo_uca.png differ diff --git a/gaussian-splatting/assets/select.png b/gaussian-splatting/assets/select.png new file mode 100644 index 0000000000000000000000000000000000000000..58a0ad9fbef31d5b53d4dd001ad5a1bf708c1704 Binary files /dev/null and b/gaussian-splatting/assets/select.png differ diff --git a/gaussian-splatting/assets/teaser.png b/gaussian-splatting/assets/teaser.png new file mode 100644 index 0000000000000000000000000000000000000000..98b8166be43d3e3b8df90c1d5c11822a85556dc9 Binary files /dev/null and b/gaussian-splatting/assets/teaser.png differ diff --git a/gaussian-splatting/assets/worse.png b/gaussian-splatting/assets/worse.png new file mode 100644 index 0000000000000000000000000000000000000000..b384897d82fcdc28076275ab63985388c3bc41e5 Binary files /dev/null and b/gaussian-splatting/assets/worse.png differ diff --git a/gaussian-splatting/convert.py b/gaussian-splatting/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..78948848f4849a88d686542790cd04f34f34beb0 --- /dev/null +++ b/gaussian-splatting/convert.py @@ -0,0 +1,124 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import logging +from argparse import ArgumentParser +import shutil + +# This Python script is based on the shell converter script provided in the MipNerF 360 repository. +parser = ArgumentParser("Colmap converter") +parser.add_argument("--no_gpu", action='store_true') +parser.add_argument("--skip_matching", action='store_true') +parser.add_argument("--source_path", "-s", required=True, type=str) +parser.add_argument("--camera", default="OPENCV", type=str) +parser.add_argument("--colmap_executable", default="", type=str) +parser.add_argument("--resize", action="store_true") +parser.add_argument("--magick_executable", default="", type=str) +args = parser.parse_args() +colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" +magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" +use_gpu = 1 if not args.no_gpu else 0 + +if not args.skip_matching: + os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) + + ## Feature extraction + feat_extracton_cmd = colmap_command + " feature_extractor "\ + "--database_path " + args.source_path + "/distorted/database.db \ + --image_path " + args.source_path + "/input \ + --ImageReader.single_camera 1 \ + --ImageReader.camera_model " + args.camera + " \ + --SiftExtraction.use_gpu " + str(use_gpu) + exit_code = os.system(feat_extracton_cmd) + if exit_code != 0: + logging.error(f"Feature extraction failed with code {exit_code}. Exiting.") + exit(exit_code) + + ## Feature matching + feat_matching_cmd = colmap_command + " exhaustive_matcher \ + --database_path " + args.source_path + "/distorted/database.db \ + --SiftMatching.use_gpu " + str(use_gpu) + exit_code = os.system(feat_matching_cmd) + if exit_code != 0: + logging.error(f"Feature matching failed with code {exit_code}. Exiting.") + exit(exit_code) + + ### Bundle adjustment + # The default Mapper tolerance is unnecessarily large, + # decreasing it speeds up bundle adjustment steps. + mapper_cmd = (colmap_command + " mapper \ + --database_path " + args.source_path + "/distorted/database.db \ + --image_path " + args.source_path + "/input \ + --output_path " + args.source_path + "/distorted/sparse \ + --Mapper.ba_global_function_tolerance=0.000001") + exit_code = os.system(mapper_cmd) + if exit_code != 0: + logging.error(f"Mapper failed with code {exit_code}. Exiting.") + exit(exit_code) + +### Image undistortion +## We need to undistort our images into ideal pinhole intrinsics. +img_undist_cmd = (colmap_command + " image_undistorter \ + --image_path " + args.source_path + "/input \ + --input_path " + args.source_path + "/distorted/sparse/0 \ + --output_path " + args.source_path + "\ + --output_type COLMAP") +exit_code = os.system(img_undist_cmd) +if exit_code != 0: + logging.error(f"Mapper failed with code {exit_code}. Exiting.") + exit(exit_code) + +files = os.listdir(args.source_path + "/sparse") +os.makedirs(args.source_path + "/sparse/0", exist_ok=True) +# Copy each file from the source directory to the destination directory +for file in files: + if file == '0': + continue + source_file = os.path.join(args.source_path, "sparse", file) + destination_file = os.path.join(args.source_path, "sparse", "0", file) + shutil.move(source_file, destination_file) + +if(args.resize): + print("Copying and resizing...") + + # Resize images. + os.makedirs(args.source_path + "/images_2", exist_ok=True) + os.makedirs(args.source_path + "/images_4", exist_ok=True) + os.makedirs(args.source_path + "/images_8", exist_ok=True) + # Get the list of files in the source directory + files = os.listdir(args.source_path + "/images") + # Copy each file from the source directory to the destination directory + for file in files: + source_file = os.path.join(args.source_path, "images", file) + + destination_file = os.path.join(args.source_path, "images_2", file) + shutil.copy2(source_file, destination_file) + exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file) + if exit_code != 0: + logging.error(f"50% resize failed with code {exit_code}. Exiting.") + exit(exit_code) + + destination_file = os.path.join(args.source_path, "images_4", file) + shutil.copy2(source_file, destination_file) + exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file) + if exit_code != 0: + logging.error(f"25% resize failed with code {exit_code}. Exiting.") + exit(exit_code) + + destination_file = os.path.join(args.source_path, "images_8", file) + shutil.copy2(source_file, destination_file) + exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file) + if exit_code != 0: + logging.error(f"12.5% resize failed with code {exit_code}. Exiting.") + exit(exit_code) + +print("Done.") diff --git a/gaussian-splatting/environment.yml b/gaussian-splatting/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..d479ec71524a44aee91860f9ef54a593620aef13 --- /dev/null +++ b/gaussian-splatting/environment.yml @@ -0,0 +1,17 @@ +name: gaussian_splatting +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - cudatoolkit=11.6 + - plyfile + - python=3.7.13 + - pip=22.3.1 + - pytorch=1.12.1 + - torchaudio=0.12.1 + - torchvision=0.13.1 + - tqdm + - pip: + - submodules/diff-gaussian-rasterization + - submodules/simple-knn diff --git a/gaussian-splatting/full_eval.py b/gaussian-splatting/full_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..8fbb12247724b25563e215b4409ded9af1cbdd04 --- /dev/null +++ b/gaussian-splatting/full_eval.py @@ -0,0 +1,75 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import os +from argparse import ArgumentParser + +mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"] +mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"] +tanks_and_temples_scenes = ["truck", "train"] +deep_blending_scenes = ["drjohnson", "playroom"] + +parser = ArgumentParser(description="Full evaluation script parameters") +parser.add_argument("--skip_training", action="store_true") +parser.add_argument("--skip_rendering", action="store_true") +parser.add_argument("--skip_metrics", action="store_true") +parser.add_argument("--output_path", default="./eval") +args, _ = parser.parse_known_args() + +all_scenes = [] +all_scenes.extend(mipnerf360_outdoor_scenes) +all_scenes.extend(mipnerf360_indoor_scenes) +all_scenes.extend(tanks_and_temples_scenes) +all_scenes.extend(deep_blending_scenes) + +if not args.skip_training or not args.skip_rendering: + parser.add_argument('--mipnerf360', "-m360", required=True, type=str) + parser.add_argument("--tanksandtemples", "-tat", required=True, type=str) + parser.add_argument("--deepblending", "-db", required=True, type=str) + args = parser.parse_args() + +if not args.skip_training: + common_args = " --quiet --eval --test_iterations -1 " + for scene in mipnerf360_outdoor_scenes: + source = args.mipnerf360 + "/" + scene + os.system("python train.py -s " + source + " -i images_4 -m " + args.output_path + "/" + scene + common_args) + for scene in mipnerf360_indoor_scenes: + source = args.mipnerf360 + "/" + scene + os.system("python train.py -s " + source + " -i images_2 -m " + args.output_path + "/" + scene + common_args) + for scene in tanks_and_temples_scenes: + source = args.tanksandtemples + "/" + scene + os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) + for scene in deep_blending_scenes: + source = args.deepblending + "/" + scene + os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) + +if not args.skip_rendering: + all_sources = [] + for scene in mipnerf360_outdoor_scenes: + all_sources.append(args.mipnerf360 + "/" + scene) + for scene in mipnerf360_indoor_scenes: + all_sources.append(args.mipnerf360 + "/" + scene) + for scene in tanks_and_temples_scenes: + all_sources.append(args.tanksandtemples + "/" + scene) + for scene in deep_blending_scenes: + all_sources.append(args.deepblending + "/" + scene) + + common_args = " --quiet --eval --skip_train" + for scene, source in zip(all_scenes, all_sources): + os.system("python render.py --iteration 7000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) + os.system("python render.py --iteration 30000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) + +if not args.skip_metrics: + scenes_string = "" + for scene in all_scenes: + scenes_string += "\"" + args.output_path + "/" + scene + "\" " + + os.system("python metrics.py -m " + scenes_string) \ No newline at end of file diff --git a/gaussian-splatting/gaussian_renderer/__init__.py b/gaussian-splatting/gaussian_renderer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f74e336af41e042dfb9f1c308e40caf17d0b3211 --- /dev/null +++ b/gaussian-splatting/gaussian_renderer/__init__.py @@ -0,0 +1,100 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import math +from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer +from scene.gaussian_model import GaussianModel +from utils.sh_utils import eval_sh + +def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None): + """ + Render the scene. + + Background tensor (bg_color) must be on GPU! + """ + + # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means + screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 + try: + screenspace_points.retain_grad() + except: + pass + + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + raster_settings = GaussianRasterizationSettings( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=pc.active_sh_degree, + campos=viewpoint_camera.camera_center, + prefiltered=False, + debug=pipe.debug + ) + + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + means3D = pc.get_xyz + means2D = screenspace_points + opacity = pc.get_opacity + + # If precomputed 3d covariance is provided, use it. If not, then it will be computed from + # scaling / rotation by the rasterizer. + scales = None + rotations = None + cov3D_precomp = None + if pipe.compute_cov3D_python: + cov3D_precomp = pc.get_covariance(scaling_modifier) + else: + scales = pc.get_scaling + rotations = pc.get_rotation + + # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors + # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. + shs = None + colors_precomp = None + if override_color is None: + if pipe.convert_SHs_python: + shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) + dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) + dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) + sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) + colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) + else: + shs = pc.get_features + else: + colors_precomp = override_color + + # Rasterize visible Gaussians to image, obtain their radii (on screen). + rendered_image, radii = rasterizer( + means3D = means3D, + means2D = means2D, + shs = shs, + colors_precomp = colors_precomp, + opacities = opacity, + scales = scales, + rotations = rotations, + cov3D_precomp = cov3D_precomp) + + # Those Gaussians that were frustum culled or had a radius of 0 were not visible. + # They will be excluded from value updates used in the splitting criteria. + return {"render": rendered_image, + "viewspace_points": screenspace_points, + "visibility_filter" : radii > 0, + "radii": radii} diff --git a/gaussian-splatting/gaussian_renderer/network_gui.py b/gaussian-splatting/gaussian_renderer/network_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..df2f9dae782b24527ae5b09f91ca4009361de53f --- /dev/null +++ b/gaussian-splatting/gaussian_renderer/network_gui.py @@ -0,0 +1,86 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import traceback +import socket +import json +from scene.cameras import MiniCam + +host = "127.0.0.1" +port = 6009 + +conn = None +addr = None + +listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + +def init(wish_host, wish_port): + global host, port, listener + host = wish_host + port = wish_port + listener.bind((host, port)) + listener.listen() + listener.settimeout(0) + +def try_connect(): + global conn, addr, listener + try: + conn, addr = listener.accept() + print(f"\nConnected by {addr}") + conn.settimeout(None) + except Exception as inst: + pass + +def read(): + global conn + messageLength = conn.recv(4) + messageLength = int.from_bytes(messageLength, 'little') + message = conn.recv(messageLength) + return json.loads(message.decode("utf-8")) + +def send(message_bytes, verify): + global conn + if message_bytes != None: + conn.sendall(message_bytes) + conn.sendall(len(verify).to_bytes(4, 'little')) + conn.sendall(bytes(verify, 'ascii')) + +def receive(): + message = read() + + width = message["resolution_x"] + height = message["resolution_y"] + + if width != 0 and height != 0: + try: + do_training = bool(message["train"]) + fovy = message["fov_y"] + fovx = message["fov_x"] + znear = message["z_near"] + zfar = message["z_far"] + do_shs_python = bool(message["shs_python"]) + do_rot_scale_python = bool(message["rot_scale_python"]) + keep_alive = bool(message["keep_alive"]) + scaling_modifier = message["scaling_modifier"] + world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() + world_view_transform[:,1] = -world_view_transform[:,1] + world_view_transform[:,2] = -world_view_transform[:,2] + full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() + full_proj_transform[:,1] = -full_proj_transform[:,1] + custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) + except Exception as e: + print("") + traceback.print_exc() + raise e + return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier + else: + return None, None, None, None, None, None \ No newline at end of file diff --git a/gaussian-splatting/lpipsPyTorch/__init__.py b/gaussian-splatting/lpipsPyTorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2a6297daa457d1d041c9491dfdf6a75994ffe06e --- /dev/null +++ b/gaussian-splatting/lpipsPyTorch/__init__.py @@ -0,0 +1,21 @@ +import torch + +from .modules.lpips import LPIPS + + +def lpips(x: torch.Tensor, + y: torch.Tensor, + net_type: str = 'alex', + version: str = '0.1'): + r"""Function that measures + Learned Perceptual Image Patch Similarity (LPIPS). + + Arguments: + x, y (torch.Tensor): the input tensors to compare. + net_type (str): the network type to compare the features: + 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. + version (str): the version of LPIPS. Default: 0.1. + """ + device = x.device + criterion = LPIPS(net_type, version).to(device) + return criterion(x, y) diff --git a/gaussian-splatting/lpipsPyTorch/modules/lpips.py b/gaussian-splatting/lpipsPyTorch/modules/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..9cd001d1e1036c7f8f5db62e81446e2ff2db80ab --- /dev/null +++ b/gaussian-splatting/lpipsPyTorch/modules/lpips.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn + +from .networks import get_network, LinLayers +from .utils import get_state_dict + + +class LPIPS(nn.Module): + r"""Creates a criterion that measures + Learned Perceptual Image Patch Similarity (LPIPS). + + Arguments: + net_type (str): the network type to compare the features: + 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. + version (str): the version of LPIPS. Default: 0.1. + """ + def __init__(self, net_type: str = 'alex', version: str = '0.1'): + + assert version in ['0.1'], 'v0.1 is only supported now' + + super(LPIPS, self).__init__() + + # pretrained network + self.net = get_network(net_type) + + # linear layers + self.lin = LinLayers(self.net.n_channels_list) + self.lin.load_state_dict(get_state_dict(net_type, version)) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + feat_x, feat_y = self.net(x), self.net(y) + + diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] + res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] + + return torch.sum(torch.cat(res, 0), 0, True) diff --git a/gaussian-splatting/lpipsPyTorch/modules/networks.py b/gaussian-splatting/lpipsPyTorch/modules/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..d36c6a56163004d49c321da5e26404af9baa4c2a --- /dev/null +++ b/gaussian-splatting/lpipsPyTorch/modules/networks.py @@ -0,0 +1,96 @@ +from typing import Sequence + +from itertools import chain + +import torch +import torch.nn as nn +from torchvision import models + +from .utils import normalize_activation + + +def get_network(net_type: str): + if net_type == 'alex': + return AlexNet() + elif net_type == 'squeeze': + return SqueezeNet() + elif net_type == 'vgg': + return VGG16() + else: + raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') + + +class LinLayers(nn.ModuleList): + def __init__(self, n_channels_list: Sequence[int]): + super(LinLayers, self).__init__([ + nn.Sequential( + nn.Identity(), + nn.Conv2d(nc, 1, 1, 1, 0, bias=False) + ) for nc in n_channels_list + ]) + + for param in self.parameters(): + param.requires_grad = False + + +class BaseNet(nn.Module): + def __init__(self): + super(BaseNet, self).__init__() + + # register buffer + self.register_buffer( + 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer( + 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def set_requires_grad(self, state: bool): + for param in chain(self.parameters(), self.buffers()): + param.requires_grad = state + + def z_score(self, x: torch.Tensor): + return (x - self.mean) / self.std + + def forward(self, x: torch.Tensor): + x = self.z_score(x) + + output = [] + for i, (_, layer) in enumerate(self.layers._modules.items(), 1): + x = layer(x) + if i in self.target_layers: + output.append(normalize_activation(x)) + if len(output) == len(self.target_layers): + break + return output + + +class SqueezeNet(BaseNet): + def __init__(self): + super(SqueezeNet, self).__init__() + + self.layers = models.squeezenet1_1(True).features + self.target_layers = [2, 5, 8, 10, 11, 12, 13] + self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] + + self.set_requires_grad(False) + + +class AlexNet(BaseNet): + def __init__(self): + super(AlexNet, self).__init__() + + self.layers = models.alexnet(True).features + self.target_layers = [2, 5, 8, 10, 12] + self.n_channels_list = [64, 192, 384, 256, 256] + + self.set_requires_grad(False) + + +class VGG16(BaseNet): + def __init__(self): + super(VGG16, self).__init__() + + self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features + self.target_layers = [4, 9, 16, 23, 30] + self.n_channels_list = [64, 128, 256, 512, 512] + + self.set_requires_grad(False) diff --git a/gaussian-splatting/lpipsPyTorch/modules/utils.py b/gaussian-splatting/lpipsPyTorch/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3d15a0983775810ef6239c561c67939b2b9ee3b5 --- /dev/null +++ b/gaussian-splatting/lpipsPyTorch/modules/utils.py @@ -0,0 +1,30 @@ +from collections import OrderedDict + +import torch + + +def normalize_activation(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def get_state_dict(net_type: str = 'alex', version: str = '0.1'): + # build url + url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ + + f'master/lpips/weights/v{version}/{net_type}.pth' + + # download + old_state_dict = torch.hub.load_state_dict_from_url( + url, progress=True, + map_location=None if torch.cuda.is_available() else torch.device('cpu') + ) + + # rename keys + new_state_dict = OrderedDict() + for key, val in old_state_dict.items(): + new_key = key + new_key = new_key.replace('lin', '') + new_key = new_key.replace('model.', '') + new_state_dict[new_key] = val + + return new_state_dict diff --git a/gaussian-splatting/metrics.py b/gaussian-splatting/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..f7393a4c9b6978bb34c7121a66628335690b3279 --- /dev/null +++ b/gaussian-splatting/metrics.py @@ -0,0 +1,103 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from pathlib import Path +import os +from PIL import Image +import torch +import torchvision.transforms.functional as tf +from utils.loss_utils import ssim +from lpipsPyTorch import lpips +import json +from tqdm import tqdm +from utils.image_utils import psnr +from argparse import ArgumentParser + +def readImages(renders_dir, gt_dir): + renders = [] + gts = [] + image_names = [] + for fname in os.listdir(renders_dir): + render = Image.open(renders_dir / fname) + gt = Image.open(gt_dir / fname) + renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) + gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) + image_names.append(fname) + return renders, gts, image_names + +def evaluate(model_paths): + + full_dict = {} + per_view_dict = {} + full_dict_polytopeonly = {} + per_view_dict_polytopeonly = {} + print("") + + for scene_dir in model_paths: + try: + print("Scene:", scene_dir) + full_dict[scene_dir] = {} + per_view_dict[scene_dir] = {} + full_dict_polytopeonly[scene_dir] = {} + per_view_dict_polytopeonly[scene_dir] = {} + + test_dir = Path(scene_dir) / "test" + + for method in os.listdir(test_dir): + print("Method:", method) + + full_dict[scene_dir][method] = {} + per_view_dict[scene_dir][method] = {} + full_dict_polytopeonly[scene_dir][method] = {} + per_view_dict_polytopeonly[scene_dir][method] = {} + + method_dir = test_dir / method + gt_dir = method_dir/ "gt" + renders_dir = method_dir / "renders" + renders, gts, image_names = readImages(renders_dir, gt_dir) + + ssims = [] + psnrs = [] + lpipss = [] + + for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): + ssims.append(ssim(renders[idx], gts[idx])) + psnrs.append(psnr(renders[idx], gts[idx])) + lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg')) + + print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) + print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) + print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) + print("") + + full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), + "PSNR": torch.tensor(psnrs).mean().item(), + "LPIPS": torch.tensor(lpipss).mean().item()}) + per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, + "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, + "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}}) + + with open(scene_dir + "/results.json", 'w') as fp: + json.dump(full_dict[scene_dir], fp, indent=True) + with open(scene_dir + "/per_view.json", 'w') as fp: + json.dump(per_view_dict[scene_dir], fp, indent=True) + except: + print("Unable to compute metrics for model", scene_dir) + +if __name__ == "__main__": + device = torch.device("cuda:0") + torch.cuda.set_device(device) + + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[]) + args = parser.parse_args() + evaluate(args.model_paths) diff --git a/gaussian-splatting/render.py b/gaussian-splatting/render.py new file mode 100644 index 0000000000000000000000000000000000000000..fc6b82de8d967b233d7f7a958dfec4f4fe67243d --- /dev/null +++ b/gaussian-splatting/render.py @@ -0,0 +1,66 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +from scene import Scene +import os +from tqdm import tqdm +from os import makedirs +from gaussian_renderer import render +import torchvision +from utils.general_utils import safe_state +from argparse import ArgumentParser +from arguments import ModelParams, PipelineParams, get_combined_args +from gaussian_renderer import GaussianModel + +def render_set(model_path, name, iteration, views, gaussians, pipeline, background): + render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") + gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") + + makedirs(render_path, exist_ok=True) + makedirs(gts_path, exist_ok=True) + + for idx, view in enumerate(tqdm(views, desc="Rendering progress")): + rendering = render(view, gaussians, pipeline, background)["render"] + gt = view.original_image[0:3, :, :] + torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) + torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) + +def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool): + with torch.no_grad(): + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) + + bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + if not skip_train: + render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) + + if not skip_test: + render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background) + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Testing script parameters") + model = ModelParams(parser, sentinel=True) + pipeline = PipelineParams(parser) + parser.add_argument("--iteration", default=-1, type=int) + parser.add_argument("--skip_train", action="store_true") + parser.add_argument("--skip_test", action="store_true") + parser.add_argument("--quiet", action="store_true") + args = get_combined_args(parser) + print("Rendering " + args.model_path) + + # Initialize system state (RNG) + safe_state(args.quiet) + + render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test) \ No newline at end of file diff --git a/gaussian-splatting/scene/__init__.py b/gaussian-splatting/scene/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2b31398b8527b7f37428fcaff4dac5b156d936db --- /dev/null +++ b/gaussian-splatting/scene/__init__.py @@ -0,0 +1,93 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import random +import json +from utils.system_utils import searchForMaxIteration +from scene.dataset_readers import sceneLoadTypeCallbacks +from scene.gaussian_model import GaussianModel +from arguments import ModelParams +from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON + +class Scene: + + gaussians : GaussianModel + + def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]): + """b + :param path: Path to colmap scene main folder. + """ + self.model_path = args.model_path + self.loaded_iter = None + self.gaussians = gaussians + + if load_iteration: + if load_iteration == -1: + self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) + else: + self.loaded_iter = load_iteration + print("Loading trained model at iteration {}".format(self.loaded_iter)) + + self.train_cameras = {} + self.test_cameras = {} + + if os.path.exists(os.path.join(args.source_path, "sparse")): + scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval) + elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): + print("Found transforms_train.json file, assuming Blender data set!") + scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) + else: + assert False, "Could not recognize scene type!" + + if not self.loaded_iter: + with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: + dest_file.write(src_file.read()) + json_cams = [] + camlist = [] + if scene_info.test_cameras: + camlist.extend(scene_info.test_cameras) + if scene_info.train_cameras: + camlist.extend(scene_info.train_cameras) + for id, cam in enumerate(camlist): + json_cams.append(camera_to_JSON(id, cam)) + with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: + json.dump(json_cams, file) + + if shuffle: + random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling + random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling + + self.cameras_extent = scene_info.nerf_normalization["radius"] + + for resolution_scale in resolution_scales: + print("Loading Training Cameras") + self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) + print("Loading Test Cameras") + self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) + + if self.loaded_iter: + self.gaussians.load_ply(os.path.join(self.model_path, + "point_cloud", + "iteration_" + str(self.loaded_iter), + "point_cloud.ply")) + else: + self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) + + def save(self, iteration): + point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) + self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) + + def getTrainCameras(self, scale=1.0): + return self.train_cameras[scale] + + def getTestCameras(self, scale=1.0): + return self.test_cameras[scale] \ No newline at end of file diff --git a/gaussian-splatting/scene/cameras.py b/gaussian-splatting/scene/cameras.py new file mode 100644 index 0000000000000000000000000000000000000000..abf6e5242bc46ef1915ce24619a8319d0b7591c7 --- /dev/null +++ b/gaussian-splatting/scene/cameras.py @@ -0,0 +1,71 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +from torch import nn +import numpy as np +from utils.graphics_utils import getWorld2View2, getProjectionMatrix + +class Camera(nn.Module): + def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, + image_name, uid, + trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" + ): + super(Camera, self).__init__() + + self.uid = uid + self.colmap_id = colmap_id + self.R = R + self.T = T + self.FoVx = FoVx + self.FoVy = FoVy + self.image_name = image_name + + try: + self.data_device = torch.device(data_device) + except Exception as e: + print(e) + print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) + self.data_device = torch.device("cuda") + + self.original_image = image.clamp(0.0, 1.0).to(self.data_device) + self.image_width = self.original_image.shape[2] + self.image_height = self.original_image.shape[1] + + if gt_alpha_mask is not None: + self.original_image *= gt_alpha_mask.to(self.data_device) + else: + self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) + + self.zfar = 100.0 + self.znear = 0.01 + + self.trans = trans + self.scale = scale + + self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() + self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() + self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) + self.camera_center = self.world_view_transform.inverse()[3, :3] + +class MiniCam: + def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): + self.image_width = width + self.image_height = height + self.FoVy = fovy + self.FoVx = fovx + self.znear = znear + self.zfar = zfar + self.world_view_transform = world_view_transform + self.full_proj_transform = full_proj_transform + view_inv = torch.inverse(self.world_view_transform) + self.camera_center = view_inv[3][:3] + diff --git a/gaussian-splatting/scene/colmap_loader.py b/gaussian-splatting/scene/colmap_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6fba6a9c961f52c88780ecb44d7821b4cb73ee --- /dev/null +++ b/gaussian-splatting/scene/colmap_loader.py @@ -0,0 +1,294 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import numpy as np +import collections +import struct + +CameraModel = collections.namedtuple( + "CameraModel", ["model_id", "model_name", "num_params"]) +Camera = collections.namedtuple( + "Camera", ["id", "model", "width", "height", "params"]) +BaseImage = collections.namedtuple( + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) +Point3D = collections.namedtuple( + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) +} +CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) + for camera_model in CAMERA_MODELS]) +CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) + for camera_model in CAMERA_MODELS]) + + +def qvec2rotmat(qvec): + return np.array([ + [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], + [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], + [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = np.array([ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + +def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. + """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + +def read_points3D_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + xyzs = None + rgbs = None + errors = None + num_points = 0 + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + num_points += 1 + + + xyzs = np.empty((num_points, 3)) + rgbs = np.empty((num_points, 3)) + errors = np.empty((num_points, 1)) + count = 0 + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + xyz = np.array(tuple(map(float, elems[1:4]))) + rgb = np.array(tuple(map(int, elems[4:7]))) + error = np.array(float(elems[7])) + xyzs[count] = xyz + rgbs[count] = rgb + errors[count] = error + count += 1 + + return xyzs, rgbs, errors + +def read_points3D_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + + + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + + xyzs = np.empty((num_points, 3)) + rgbs = np.empty((num_points, 3)) + errors = np.empty((num_points, 1)) + + for p_id in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd") + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes( + fid, num_bytes=8, format_char_sequence="Q")[0] + track_elems = read_next_bytes( + fid, num_bytes=8*track_length, + format_char_sequence="ii"*track_length) + xyzs[p_id] = xyz + rgbs[p_id] = rgb + errors[p_id] = error + return xyzs, rgbs, errors + +def read_intrinsics_text(path): + """ + Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py + """ + cameras = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera(id=camera_id, model=model, + width=width, height=height, + params=params) + return cameras + +def read_extrinsics_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi") + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + image_name = "" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + image_name += current_char.decode("utf-8") + current_char = read_next_bytes(fid, 1, "c")[0] + num_points2D = read_next_bytes(fid, num_bytes=8, + format_char_sequence="Q")[0] + x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, + format_char_sequence="ddq"*num_points2D) + xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), + tuple(map(float, x_y_id_s[1::3]))]) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + return images + + +def read_intrinsics_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_cameras): + camera_properties = read_next_bytes( + fid, num_bytes=24, format_char_sequence="iiQQ") + camera_id = camera_properties[0] + model_id = camera_properties[1] + model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name + width = camera_properties[2] + height = camera_properties[3] + num_params = CAMERA_MODEL_IDS[model_id].num_params + params = read_next_bytes(fid, num_bytes=8*num_params, + format_char_sequence="d"*num_params) + cameras[camera_id] = Camera(id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params)) + assert len(cameras) == num_cameras + return cameras + + +def read_extrinsics_text(path): + """ + Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py + """ + images = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack([tuple(map(float, elems[0::3])), + tuple(map(float, elems[1::3]))]) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + return images + + +def read_colmap_bin_array(path): + """ + Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py + + :param path: path to the colmap binary file. + :return: nd array with the floating point values in the value + """ + with open(path, "rb") as fid: + width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, + usecols=(0, 1, 2), dtype=int) + fid.seek(0) + num_delimiter = 0 + byte = fid.read(1) + while True: + if byte == b"&": + num_delimiter += 1 + if num_delimiter >= 3: + break + byte = fid.read(1) + array = np.fromfile(fid, np.float32) + array = array.reshape((width, height, channels), order="F") + return np.transpose(array, (1, 0, 2)).squeeze() diff --git a/gaussian-splatting/scene/dataset_readers.py b/gaussian-splatting/scene/dataset_readers.py new file mode 100644 index 0000000000000000000000000000000000000000..2a6f904a92048e9146e5999fb4c0f0ff7152aa03 --- /dev/null +++ b/gaussian-splatting/scene/dataset_readers.py @@ -0,0 +1,260 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import sys +from PIL import Image +from typing import NamedTuple +from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \ + read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text +from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal +import numpy as np +import json +from pathlib import Path +from plyfile import PlyData, PlyElement +from utils.sh_utils import SH2RGB +from scene.gaussian_model import BasicPointCloud + +class CameraInfo(NamedTuple): + uid: int + R: np.array + T: np.array + FovY: np.array + FovX: np.array + image: np.array + image_path: str + image_name: str + width: int + height: int + +class SceneInfo(NamedTuple): + point_cloud: BasicPointCloud + train_cameras: list + test_cameras: list + nerf_normalization: dict + ply_path: str + +def getNerfppNorm(cam_info): + def get_center_and_diag(cam_centers): + cam_centers = np.hstack(cam_centers) + avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) + center = avg_cam_center + dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) + diagonal = np.max(dist) + return center.flatten(), diagonal + + cam_centers = [] + + for cam in cam_info: + W2C = getWorld2View2(cam.R, cam.T) + C2W = np.linalg.inv(W2C) + cam_centers.append(C2W[:3, 3:4]) + + center, diagonal = get_center_and_diag(cam_centers) + radius = diagonal * 1.1 + + translate = -center + + return {"translate": translate, "radius": radius} + +def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): + cam_infos = [] + for idx, key in enumerate(cam_extrinsics): + sys.stdout.write('\r') + # the exact output you're looking for: + sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics))) + sys.stdout.flush() + + extr = cam_extrinsics[key] + intr = cam_intrinsics[extr.camera_id] + height = intr.height + width = intr.width + + uid = intr.id + R = np.transpose(qvec2rotmat(extr.qvec)) + T = np.array(extr.tvec) + + if intr.model=="SIMPLE_PINHOLE": + focal_length_x = intr.params[0] + FovY = focal2fov(focal_length_x, height) + FovX = focal2fov(focal_length_x, width) + elif intr.model=="PINHOLE": + focal_length_x = intr.params[0] + focal_length_y = intr.params[1] + FovY = focal2fov(focal_length_y, height) + FovX = focal2fov(focal_length_x, width) + else: + assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" + + image_path = os.path.join(images_folder, os.path.basename(extr.name)) + image_name = os.path.basename(image_path).split(".")[0] + image = Image.open(image_path) + + cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, + image_path=image_path, image_name=image_name, width=width, height=height) + cam_infos.append(cam_info) + sys.stdout.write('\n') + return cam_infos + +def fetchPly(path): + plydata = PlyData.read(path) + vertices = plydata['vertex'] + positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T + colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 + normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T + return BasicPointCloud(points=positions, colors=colors, normals=normals) + +def storePly(path, xyz, rgb): + # Define the dtype for the structured array + dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), + ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), + ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] + + normals = np.zeros_like(xyz) + + elements = np.empty(xyz.shape[0], dtype=dtype) + attributes = np.concatenate((xyz, normals, rgb), axis=1) + elements[:] = list(map(tuple, attributes)) + + # Create the PlyData object and write to file + vertex_element = PlyElement.describe(elements, 'vertex') + ply_data = PlyData([vertex_element]) + ply_data.write(path) + +def readColmapSceneInfo(path, images, eval, llffhold=8): + try: + cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") + cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") + cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) + cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) + except: + cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") + cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") + cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) + cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) + + reading_dir = "images" if images == None else images + cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir)) + cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name) + + if eval: + train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] + test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] + else: + train_cam_infos = cam_infos + test_cam_infos = [] + + nerf_normalization = getNerfppNorm(train_cam_infos) + + ply_path = os.path.join(path, "sparse/0/points3D.ply") + bin_path = os.path.join(path, "sparse/0/points3D.bin") + txt_path = os.path.join(path, "sparse/0/points3D.txt") + if not os.path.exists(ply_path): + print("Converting point3d.bin to .ply, will happen only the first time you open the scene.") + try: + xyz, rgb, _ = read_points3D_binary(bin_path) + except: + xyz, rgb, _ = read_points3D_text(txt_path) + storePly(ply_path, xyz, rgb) + try: + pcd = fetchPly(ply_path) + except: + pcd = None + + scene_info = SceneInfo(point_cloud=pcd, + train_cameras=train_cam_infos, + test_cameras=test_cam_infos, + nerf_normalization=nerf_normalization, + ply_path=ply_path) + return scene_info + +def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"): + cam_infos = [] + + with open(os.path.join(path, transformsfile)) as json_file: + contents = json.load(json_file) + fovx = contents["camera_angle_x"] + + frames = contents["frames"] + for idx, frame in enumerate(frames): + cam_name = os.path.join(path, frame["file_path"] + extension) + + # NeRF 'transform_matrix' is a camera-to-world transform + c2w = np.array(frame["transform_matrix"]) + # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward) + c2w[:3, 1:3] *= -1 + + # get the world-to-camera transform and set R, T + w2c = np.linalg.inv(c2w) + R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code + T = w2c[:3, 3] + + image_path = os.path.join(path, cam_name) + image_name = Path(cam_name).stem + image = Image.open(image_path) + + im_data = np.array(image.convert("RGBA")) + + bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0]) + + norm_data = im_data / 255.0 + arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) + image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") + + fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) + FovY = fovy + FovX = fovx + + cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, + image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1])) + + return cam_infos + +def readNerfSyntheticInfo(path, white_background, eval, extension=".png"): + print("Reading Training Transforms") + train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension) + print("Reading Test Transforms") + test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension) + + if not eval: + train_cam_infos.extend(test_cam_infos) + test_cam_infos = [] + + nerf_normalization = getNerfppNorm(train_cam_infos) + + ply_path = os.path.join(path, "points3d.ply") + if not os.path.exists(ply_path): + # Since this data set has no colmap data, we start with random points + num_pts = 100_000 + print(f"Generating random point cloud ({num_pts})...") + + # We create random points inside the bounds of the synthetic Blender scenes + xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 + shs = np.random.random((num_pts, 3)) / 255.0 + pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) + + storePly(ply_path, xyz, SH2RGB(shs) * 255) + try: + pcd = fetchPly(ply_path) + except: + pcd = None + + scene_info = SceneInfo(point_cloud=pcd, + train_cameras=train_cam_infos, + test_cameras=test_cam_infos, + nerf_normalization=nerf_normalization, + ply_path=ply_path) + return scene_info + +sceneLoadTypeCallbacks = { + "Colmap": readColmapSceneInfo, + "Blender" : readNerfSyntheticInfo +} \ No newline at end of file diff --git a/gaussian-splatting/scene/gaussian_model.py b/gaussian-splatting/scene/gaussian_model.py new file mode 100644 index 0000000000000000000000000000000000000000..632a1e8e160f023dcc9a25badae21649f2b8d0c2 --- /dev/null +++ b/gaussian-splatting/scene/gaussian_model.py @@ -0,0 +1,407 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import numpy as np +from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation +from torch import nn +import os +from utils.system_utils import mkdir_p +from plyfile import PlyData, PlyElement +from utils.sh_utils import RGB2SH +from simple_knn._C import distCUDA2 +from utils.graphics_utils import BasicPointCloud +from utils.general_utils import strip_symmetric, build_scaling_rotation + +class GaussianModel: + + def setup_functions(self): + def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): + L = build_scaling_rotation(scaling_modifier * scaling, rotation) + actual_covariance = L @ L.transpose(1, 2) + symm = strip_symmetric(actual_covariance) + return symm + + self.scaling_activation = torch.exp + self.scaling_inverse_activation = torch.log + + self.covariance_activation = build_covariance_from_scaling_rotation + + self.opacity_activation = torch.sigmoid + self.inverse_opacity_activation = inverse_sigmoid + + self.rotation_activation = torch.nn.functional.normalize + + + def __init__(self, sh_degree : int): + self.active_sh_degree = 0 + self.max_sh_degree = sh_degree + self._xyz = torch.empty(0) + self._features_dc = torch.empty(0) + self._features_rest = torch.empty(0) + self._scaling = torch.empty(0) + self._rotation = torch.empty(0) + self._opacity = torch.empty(0) + self.max_radii2D = torch.empty(0) + self.xyz_gradient_accum = torch.empty(0) + self.denom = torch.empty(0) + self.optimizer = None + self.percent_dense = 0 + self.spatial_lr_scale = 0 + self.setup_functions() + + def capture(self): + return ( + self.active_sh_degree, + self._xyz, + self._features_dc, + self._features_rest, + self._scaling, + self._rotation, + self._opacity, + self.max_radii2D, + self.xyz_gradient_accum, + self.denom, + self.optimizer.state_dict(), + self.spatial_lr_scale, + ) + + def restore(self, model_args, training_args): + (self.active_sh_degree, + self._xyz, + self._features_dc, + self._features_rest, + self._scaling, + self._rotation, + self._opacity, + self.max_radii2D, + xyz_gradient_accum, + denom, + opt_dict, + self.spatial_lr_scale) = model_args + self.training_setup(training_args) + self.xyz_gradient_accum = xyz_gradient_accum + self.denom = denom + self.optimizer.load_state_dict(opt_dict) + + @property + def get_scaling(self): + return self.scaling_activation(self._scaling) + + @property + def get_rotation(self): + return self.rotation_activation(self._rotation) + + @property + def get_xyz(self): + return self._xyz + + @property + def get_features(self): + features_dc = self._features_dc + features_rest = self._features_rest + return torch.cat((features_dc, features_rest), dim=1) + + @property + def get_opacity(self): + return self.opacity_activation(self._opacity) + + def get_covariance(self, scaling_modifier = 1): + return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation) + + def oneupSHdegree(self): + if self.active_sh_degree < self.max_sh_degree: + self.active_sh_degree += 1 + + def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float): + self.spatial_lr_scale = spatial_lr_scale + fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() + fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda()) + features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() + features[:, :3, 0 ] = fused_color + features[:, 3:, 1:] = 0.0 + + print("Number of points at initialisation : ", fused_point_cloud.shape[0]) + + dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001) + scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3) + rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") + rots[:, 0] = 1 + + opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) + + self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) + self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) + self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True)) + self._scaling = nn.Parameter(scales.requires_grad_(True)) + self._rotation = nn.Parameter(rots.requires_grad_(True)) + self._opacity = nn.Parameter(opacities.requires_grad_(True)) + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + + def training_setup(self, training_args): + self.percent_dense = training_args.percent_dense + self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + + l = [ + {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"}, + {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, + {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, + {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, + {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"}, + {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"} + ] + + self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) + self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale, + lr_final=training_args.position_lr_final*self.spatial_lr_scale, + lr_delay_mult=training_args.position_lr_delay_mult, + max_steps=training_args.position_lr_max_steps) + + def update_learning_rate(self, iteration): + ''' Learning rate scheduling per step ''' + for param_group in self.optimizer.param_groups: + if param_group["name"] == "xyz": + lr = self.xyz_scheduler_args(iteration) + param_group['lr'] = lr + return lr + + def construct_list_of_attributes(self): + l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] + # All channels except the 3 DC + for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): + l.append('f_dc_{}'.format(i)) + for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]): + l.append('f_rest_{}'.format(i)) + l.append('opacity') + for i in range(self._scaling.shape[1]): + l.append('scale_{}'.format(i)) + for i in range(self._rotation.shape[1]): + l.append('rot_{}'.format(i)) + return l + + def save_ply(self, path): + mkdir_p(os.path.dirname(path)) + + xyz = self._xyz.detach().cpu().numpy() + normals = np.zeros_like(xyz) + f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + opacities = self._opacity.detach().cpu().numpy() + scale = self._scaling.detach().cpu().numpy() + rotation = self._rotation.detach().cpu().numpy() + + dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] + + elements = np.empty(xyz.shape[0], dtype=dtype_full) + attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, 'vertex') + PlyData([el]).write(path) + + def reset_opacity(self): + opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01)) + optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") + self._opacity = optimizable_tensors["opacity"] + + def load_ply(self, path): + plydata = PlyData.read(path) + + xyz = np.stack((np.asarray(plydata.elements[0]["x"]), + np.asarray(plydata.elements[0]["y"]), + np.asarray(plydata.elements[0]["z"])), axis=1) + opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] + + features_dc = np.zeros((xyz.shape[0], 3, 1)) + features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) + features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) + features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) + + extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] + extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) + assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3 + features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) + for idx, attr_name in enumerate(extra_f_names): + features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) + # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) + features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) + + scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] + scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) + scales = np.zeros((xyz.shape[0], len(scale_names))) + for idx, attr_name in enumerate(scale_names): + scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] + rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) + rots = np.zeros((xyz.shape[0], len(rot_names))) + for idx, attr_name in enumerate(rot_names): + rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) + self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) + self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) + self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) + + self.active_sh_degree = self.max_sh_degree + + def replace_tensor_to_optimizer(self, tensor, name): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + if group["name"] == name: + stored_state = self.optimizer.state.get(group['params'][0], None) + stored_state["exp_avg"] = torch.zeros_like(tensor) + stored_state["exp_avg_sq"] = torch.zeros_like(tensor) + + del self.optimizer.state[group['params'][0]] + group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) + self.optimizer.state[group['params'][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + return optimizable_tensors + + def _prune_optimizer(self, mask): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + stored_state = self.optimizer.state.get(group['params'][0], None) + if stored_state is not None: + stored_state["exp_avg"] = stored_state["exp_avg"][mask] + stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] + + del self.optimizer.state[group['params'][0]] + group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) + self.optimizer.state[group['params'][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + else: + group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) + optimizable_tensors[group["name"]] = group["params"][0] + return optimizable_tensors + + def prune_points(self, mask): + valid_points_mask = ~mask + optimizable_tensors = self._prune_optimizer(valid_points_mask) + + self._xyz = optimizable_tensors["xyz"] + self._features_dc = optimizable_tensors["f_dc"] + self._features_rest = optimizable_tensors["f_rest"] + self._opacity = optimizable_tensors["opacity"] + self._scaling = optimizable_tensors["scaling"] + self._rotation = optimizable_tensors["rotation"] + + self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] + + self.denom = self.denom[valid_points_mask] + self.max_radii2D = self.max_radii2D[valid_points_mask] + + def cat_tensors_to_optimizer(self, tensors_dict): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + assert len(group["params"]) == 1 + extension_tensor = tensors_dict[group["name"]] + stored_state = self.optimizer.state.get(group['params'][0], None) + if stored_state is not None: + + stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0) + stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0) + + del self.optimizer.state[group['params'][0]] + group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) + self.optimizer.state[group['params'][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + else: + group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) + optimizable_tensors[group["name"]] = group["params"][0] + + return optimizable_tensors + + def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation): + d = {"xyz": new_xyz, + "f_dc": new_features_dc, + "f_rest": new_features_rest, + "opacity": new_opacities, + "scaling" : new_scaling, + "rotation" : new_rotation} + + optimizable_tensors = self.cat_tensors_to_optimizer(d) + self._xyz = optimizable_tensors["xyz"] + self._features_dc = optimizable_tensors["f_dc"] + self._features_rest = optimizable_tensors["f_rest"] + self._opacity = optimizable_tensors["opacity"] + self._scaling = optimizable_tensors["scaling"] + self._rotation = optimizable_tensors["rotation"] + + self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + + def densify_and_split(self, grads, grad_threshold, scene_extent, N=2): + n_init_points = self.get_xyz.shape[0] + # Extract points that satisfy the gradient condition + padded_grad = torch.zeros((n_init_points), device="cuda") + padded_grad[:grads.shape[0]] = grads.squeeze() + selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) + selected_pts_mask = torch.logical_and(selected_pts_mask, + torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent) + + stds = self.get_scaling[selected_pts_mask].repeat(N,1) + means =torch.zeros((stds.size(0), 3),device="cuda") + samples = torch.normal(mean=means, std=stds) + rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1) + new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1) + new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N)) + new_rotation = self._rotation[selected_pts_mask].repeat(N,1) + new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1) + new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1) + new_opacity = self._opacity[selected_pts_mask].repeat(N,1) + + self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation) + + prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) + self.prune_points(prune_filter) + + def densify_and_clone(self, grads, grad_threshold, scene_extent): + # Extract points that satisfy the gradient condition + selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) + selected_pts_mask = torch.logical_and(selected_pts_mask, + torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent) + + new_xyz = self._xyz[selected_pts_mask] + new_features_dc = self._features_dc[selected_pts_mask] + new_features_rest = self._features_rest[selected_pts_mask] + new_opacities = self._opacity[selected_pts_mask] + new_scaling = self._scaling[selected_pts_mask] + new_rotation = self._rotation[selected_pts_mask] + + self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation) + + def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): + grads = self.xyz_gradient_accum / self.denom + grads[grads.isnan()] = 0.0 + + self.densify_and_clone(grads, max_grad, extent) + self.densify_and_split(grads, max_grad, extent) + + prune_mask = (self.get_opacity < min_opacity).squeeze() + if max_screen_size: + big_points_vs = self.max_radii2D > max_screen_size + big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent + prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws) + self.prune_points(prune_mask) + + torch.cuda.empty_cache() + + def add_densification_stats(self, viewspace_point_tensor, update_filter): + self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True) + self.denom[update_filter] += 1 \ No newline at end of file diff --git a/gaussian-splatting/train.py b/gaussian-splatting/train.py new file mode 100644 index 0000000000000000000000000000000000000000..5d819b3481ee4bad45d18b55ae27c68690e32fab --- /dev/null +++ b/gaussian-splatting/train.py @@ -0,0 +1,222 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import torch +from random import randint +from utils.loss_utils import l1_loss, ssim +from gaussian_renderer import render, network_gui +import sys +from scene import Scene, GaussianModel +from utils.general_utils import safe_state +import uuid +from tqdm import tqdm +from utils.image_utils import psnr +from argparse import ArgumentParser, Namespace +from arguments import ModelParams, PipelineParams, OptimizationParams +try: + from torch.utils.tensorboard import SummaryWriter + TENSORBOARD_FOUND = True +except ImportError: + TENSORBOARD_FOUND = False + +def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): + first_iter = 0 + tb_writer = prepare_output_and_logger(dataset) + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians) + gaussians.training_setup(opt) + if checkpoint: + (model_params, first_iter) = torch.load(checkpoint) + gaussians.restore(model_params, opt) + + bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + iter_start = torch.cuda.Event(enable_timing = True) + iter_end = torch.cuda.Event(enable_timing = True) + + viewpoint_stack = None + ema_loss_for_log = 0.0 + progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") + first_iter += 1 + for iteration in range(first_iter, opt.iterations + 1): + if network_gui.conn == None: + network_gui.try_connect() + while network_gui.conn != None: + try: + net_image_bytes = None + custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() + if custom_cam != None: + net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"] + net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) + network_gui.send(net_image_bytes, dataset.source_path) + if do_training and ((iteration < int(opt.iterations)) or not keep_alive): + break + except Exception as e: + network_gui.conn = None + + iter_start.record() + + gaussians.update_learning_rate(iteration) + + # Every 1000 its we increase the levels of SH up to a maximum degree + if iteration % 1000 == 0: + gaussians.oneupSHdegree() + + # Pick a random Camera + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras().copy() + viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) + + # Render + if (iteration - 1) == debug_from: + pipe.debug = True + + bg = torch.rand((3), device="cuda") if opt.random_background else background + + render_pkg = render(viewpoint_cam, gaussians, pipe, bg) + image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] + + # Loss + gt_image = viewpoint_cam.original_image.cuda() + Ll1 = l1_loss(image, gt_image) + loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) + loss.backward() + + iter_end.record() + + with torch.no_grad(): + # Progress bar + ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log + if iteration % 10 == 0: + progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) + progress_bar.update(10) + if iteration == opt.iterations: + progress_bar.close() + + # Log and save + training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) + if (iteration in saving_iterations): + print("\n[ITER {}] Saving Gaussians".format(iteration)) + scene.save(iteration) + + # Densification + if iteration < opt.densify_until_iter: + # Keep track of max radii in image-space for pruning + gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) + gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) + + if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: + size_threshold = 20 if iteration > opt.opacity_reset_interval else None + gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) + + if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): + gaussians.reset_opacity() + + # Optimizer step + if iteration < opt.iterations: + gaussians.optimizer.step() + gaussians.optimizer.zero_grad(set_to_none = True) + + if (iteration in checkpoint_iterations): + print("\n[ITER {}] Saving Checkpoint".format(iteration)) + torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") + +def prepare_output_and_logger(args): + if not args.model_path: + if os.getenv('OAR_JOB_ID'): + unique_str=os.getenv('OAR_JOB_ID') + else: + unique_str = str(uuid.uuid4()) + args.model_path = os.path.join("./output/", unique_str[0:10]) + + # Set up output folder + print("Output folder: {}".format(args.model_path)) + os.makedirs(args.model_path, exist_ok = True) + with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: + cfg_log_f.write(str(Namespace(**vars(args)))) + + # Create Tensorboard writer + tb_writer = None + if TENSORBOARD_FOUND: + tb_writer = SummaryWriter(args.model_path) + else: + print("Tensorboard not available: not logging progress") + return tb_writer + +def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs): + if tb_writer: + tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) + tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) + tb_writer.add_scalar('iter_time', elapsed, iteration) + + # Report test and samples of training set + if iteration in testing_iterations: + torch.cuda.empty_cache() + validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, + {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) + + for config in validation_configs: + if config['cameras'] and len(config['cameras']) > 0: + l1_test = 0.0 + psnr_test = 0.0 + for idx, viewpoint in enumerate(config['cameras']): + image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0) + gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) + if tb_writer and (idx < 5): + tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) + if iteration == testing_iterations[0]: + tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) + l1_test += l1_loss(image, gt_image).mean().double() + psnr_test += psnr(image, gt_image).mean().double() + psnr_test /= len(config['cameras']) + l1_test /= len(config['cameras']) + print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) + if tb_writer: + tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) + tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) + + if tb_writer: + tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) + tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) + torch.cuda.empty_cache() + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + lp = ModelParams(parser) + op = OptimizationParams(parser) + pp = PipelineParams(parser) + parser.add_argument('--ip', type=str, default="127.0.0.1") + parser.add_argument('--port', type=int, default=6009) + parser.add_argument('--debug_from', type=int, default=-1) + parser.add_argument('--detect_anomaly', action='store_true', default=False) + parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000]) + parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000]) + parser.add_argument("--quiet", action="store_true") + parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--start_checkpoint", type=str, default = None) + args = parser.parse_args(sys.argv[1:]) + args.save_iterations.append(args.iterations) + + print("Optimizing " + args.model_path) + + # Initialize system state (RNG) + safe_state(args.quiet) + + # Start GUI server, configure and run training + network_gui.init(args.ip, args.port) + torch.autograd.set_detect_anomaly(args.detect_anomaly) + training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) + + # All done + print("\nTraining complete.") diff --git a/gaussian-splatting/utils/camera_utils.py b/gaussian-splatting/utils/camera_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1a54d0ada0361997109c462cde1e088ea5da9ff2 --- /dev/null +++ b/gaussian-splatting/utils/camera_utils.py @@ -0,0 +1,82 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from scene.cameras import Camera +import numpy as np +from utils.general_utils import PILtoTorch +from utils.graphics_utils import fov2focal + +WARNED = False + +def loadCam(args, id, cam_info, resolution_scale): + orig_w, orig_h = cam_info.image.size + + if args.resolution in [1, 2, 4, 8]: + resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) + else: # should be a type that converts to float + if args.resolution == -1: + if orig_w > 1600: + global WARNED + if not WARNED: + print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " + "If this is not desired, please explicitly specify '--resolution/-r' as 1") + WARNED = True + global_down = orig_w / 1600 + else: + global_down = 1 + else: + global_down = orig_w / args.resolution + + scale = float(global_down) * float(resolution_scale) + resolution = (int(orig_w / scale), int(orig_h / scale)) + + resized_image_rgb = PILtoTorch(cam_info.image, resolution) + + gt_image = resized_image_rgb[:3, ...] + loaded_mask = None + + if resized_image_rgb.shape[1] == 4: + loaded_mask = resized_image_rgb[3:4, ...] + + return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, + FoVx=cam_info.FovX, FoVy=cam_info.FovY, + image=gt_image, gt_alpha_mask=loaded_mask, + image_name=cam_info.image_name, uid=id, data_device=args.data_device) + +def cameraList_from_camInfos(cam_infos, resolution_scale, args): + camera_list = [] + + for id, c in enumerate(cam_infos): + camera_list.append(loadCam(args, id, c, resolution_scale)) + + return camera_list + +def camera_to_JSON(id, camera : Camera): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = camera.R.transpose() + Rt[:3, 3] = camera.T + Rt[3, 3] = 1.0 + + W2C = np.linalg.inv(Rt) + pos = W2C[:3, 3] + rot = W2C[:3, :3] + serializable_array_2d = [x.tolist() for x in rot] + camera_entry = { + 'id' : id, + 'img_name' : camera.image_name, + 'width' : camera.width, + 'height' : camera.height, + 'position': pos.tolist(), + 'rotation': serializable_array_2d, + 'fy' : fov2focal(camera.FovY, camera.height), + 'fx' : fov2focal(camera.FovX, camera.width) + } + return camera_entry diff --git a/gaussian-splatting/utils/general_utils.py b/gaussian-splatting/utils/general_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..541c0825229a2d86e84460b765879f86f724a59d --- /dev/null +++ b/gaussian-splatting/utils/general_utils.py @@ -0,0 +1,133 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import sys +from datetime import datetime +import numpy as np +import random + +def inverse_sigmoid(x): + return torch.log(x/(1-x)) + +def PILtoTorch(pil_image, resolution): + resized_image_PIL = pil_image.resize(resolution) + resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 + if len(resized_image.shape) == 3: + return resized_image.permute(2, 0, 1) + else: + return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) + +def get_expon_lr_func( + lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 +): + """ + Copied from Plenoxels + + Continuous learning rate decay function. Adapted from JaxNeRF + The returned rate is lr_init when step=0 and lr_final when step=max_steps, and + is log-linearly interpolated elsewhere (equivalent to exponential decay). + If lr_delay_steps>0 then the learning rate will be scaled by some smooth + function of lr_delay_mult, such that the initial learning rate is + lr_init*lr_delay_mult at the beginning of optimization but will be eased back + to the normal learning rate when steps>lr_delay_steps. + :param conf: config subtree 'lr' or similar + :param max_steps: int, the number of steps during optimization. + :return HoF which takes step as input + """ + + def helper(step): + if step < 0 or (lr_init == 0.0 and lr_final == 0.0): + # Disable this parameter + return 0.0 + if lr_delay_steps > 0: + # A kind of reverse cosine decay. + delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( + 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) + ) + else: + delay_rate = 1.0 + t = np.clip(step / max_steps, 0, 1) + log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) + return delay_rate * log_lerp + + return helper + +def strip_lowerdiag(L): + uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") + + uncertainty[:, 0] = L[:, 0, 0] + uncertainty[:, 1] = L[:, 0, 1] + uncertainty[:, 2] = L[:, 0, 2] + uncertainty[:, 3] = L[:, 1, 1] + uncertainty[:, 4] = L[:, 1, 2] + uncertainty[:, 5] = L[:, 2, 2] + return uncertainty + +def strip_symmetric(sym): + return strip_lowerdiag(sym) + +def build_rotation(r): + norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) + + q = r / norm[:, None] + + R = torch.zeros((q.size(0), 3, 3), device='cuda') + + r = q[:, 0] + x = q[:, 1] + y = q[:, 2] + z = q[:, 3] + + R[:, 0, 0] = 1 - 2 * (y*y + z*z) + R[:, 0, 1] = 2 * (x*y - r*z) + R[:, 0, 2] = 2 * (x*z + r*y) + R[:, 1, 0] = 2 * (x*y + r*z) + R[:, 1, 1] = 1 - 2 * (x*x + z*z) + R[:, 1, 2] = 2 * (y*z - r*x) + R[:, 2, 0] = 2 * (x*z - r*y) + R[:, 2, 1] = 2 * (y*z + r*x) + R[:, 2, 2] = 1 - 2 * (x*x + y*y) + return R + +def build_scaling_rotation(s, r): + L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") + R = build_rotation(r) + + L[:,0,0] = s[:,0] + L[:,1,1] = s[:,1] + L[:,2,2] = s[:,2] + + L = R @ L + return L + +def safe_state(silent): + old_f = sys.stdout + class F: + def __init__(self, silent): + self.silent = silent + + def write(self, x): + if not self.silent: + if x.endswith("\n"): + old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) + else: + old_f.write(x) + + def flush(self): + old_f.flush() + + sys.stdout = F(silent) + + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + torch.cuda.set_device(torch.device("cuda:0")) diff --git a/gaussian-splatting/utils/graphics_utils.py b/gaussian-splatting/utils/graphics_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4627d837c74fcdffc898fa0c3071cb7b316802b --- /dev/null +++ b/gaussian-splatting/utils/graphics_utils.py @@ -0,0 +1,77 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import math +import numpy as np +from typing import NamedTuple + +class BasicPointCloud(NamedTuple): + points : np.array + colors : np.array + normals : np.array + +def geom_transform_points(points, transf_matrix): + P, _ = points.shape + ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) + points_hom = torch.cat([points, ones], dim=1) + points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) + + denom = points_out[..., 3:] + 0.0000001 + return (points_out[..., :3] / denom).squeeze(dim=0) + +def getWorld2View(R, t): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = R.transpose() + Rt[:3, 3] = t + Rt[3, 3] = 1.0 + return np.float32(Rt) + +def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = R.transpose() + Rt[:3, 3] = t + Rt[3, 3] = 1.0 + + C2W = np.linalg.inv(Rt) + cam_center = C2W[:3, 3] + cam_center = (cam_center + translate) * scale + C2W[:3, 3] = cam_center + Rt = np.linalg.inv(C2W) + return np.float32(Rt) + +def getProjectionMatrix(znear, zfar, fovX, fovY): + tanHalfFovY = math.tan((fovY / 2)) + tanHalfFovX = math.tan((fovX / 2)) + + top = tanHalfFovY * znear + bottom = -top + right = tanHalfFovX * znear + left = -right + + P = torch.zeros(4, 4) + + z_sign = 1.0 + + P[0, 0] = 2.0 * znear / (right - left) + P[1, 1] = 2.0 * znear / (top - bottom) + P[0, 2] = (right + left) / (right - left) + P[1, 2] = (top + bottom) / (top - bottom) + P[3, 2] = z_sign + P[2, 2] = z_sign * zfar / (zfar - znear) + P[2, 3] = -(zfar * znear) / (zfar - znear) + return P + +def fov2focal(fov, pixels): + return pixels / (2 * math.tan(fov / 2)) + +def focal2fov(focal, pixels): + return 2*math.atan(pixels/(2*focal)) \ No newline at end of file diff --git a/gaussian-splatting/utils/image_utils.py b/gaussian-splatting/utils/image_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cdeaa1b6d250e549181ab165070f82ccd31b3eb9 --- /dev/null +++ b/gaussian-splatting/utils/image_utils.py @@ -0,0 +1,19 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch + +def mse(img1, img2): + return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) + +def psnr(img1, img2): + mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) + return 20 * torch.log10(1.0 / torch.sqrt(mse)) diff --git a/gaussian-splatting/utils/loss_utils.py b/gaussian-splatting/utils/loss_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9defc23a913e5d861aa5adc63270050884923094 --- /dev/null +++ b/gaussian-splatting/utils/loss_utils.py @@ -0,0 +1,64 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import torch.nn.functional as F +from torch.autograd import Variable +from math import exp + +def l1_loss(network_output, gt): + return torch.abs((network_output - gt)).mean() + +def l2_loss(network_output, gt): + return ((network_output - gt) ** 2).mean() + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + +def ssim(img1, img2, window_size=11, size_average=True): + channel = img1.size(-3) + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + diff --git a/gaussian-splatting/utils/sh_utils.py b/gaussian-splatting/utils/sh_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bbca7d192aa3a7edf8c5b2d24dee535eac765785 --- /dev/null +++ b/gaussian-splatting/utils/sh_utils.py @@ -0,0 +1,118 @@ +# Copyright 2021 The PlenOctree Authors. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +import torch + +C0 = 0.28209479177387814 +C1 = 0.4886025119029199 +C2 = [ + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396 +] +C3 = [ + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435 +] +C4 = [ + 2.5033429417967046, + -1.7701307697799304, + 0.9461746957575601, + -0.6690465435572892, + 0.10578554691520431, + -0.6690465435572892, + 0.47308734787878004, + -1.7701307697799304, + 0.6258357354491761, +] + + +def eval_sh(deg, sh, dirs): + """ + Evaluate spherical harmonics at unit directions + using hardcoded SH polynomials. + Works with torch/np/jnp. + ... Can be 0 or more batch dimensions. + Args: + deg: int SH deg. Currently, 0-3 supported + sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] + dirs: jnp.ndarray unit directions [..., 3] + Returns: + [..., C] + """ + assert deg <= 4 and deg >= 0 + coeff = (deg + 1) ** 2 + assert sh.shape[-1] >= coeff + + result = C0 * sh[..., 0] + if deg > 0: + x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] + result = (result - + C1 * y * sh[..., 1] + + C1 * z * sh[..., 2] - + C1 * x * sh[..., 3]) + + if deg > 1: + xx, yy, zz = x * x, y * y, z * z + xy, yz, xz = x * y, y * z, x * z + result = (result + + C2[0] * xy * sh[..., 4] + + C2[1] * yz * sh[..., 5] + + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + + C2[3] * xz * sh[..., 7] + + C2[4] * (xx - yy) * sh[..., 8]) + + if deg > 2: + result = (result + + C3[0] * y * (3 * xx - yy) * sh[..., 9] + + C3[1] * xy * z * sh[..., 10] + + C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + + C3[5] * z * (xx - yy) * sh[..., 14] + + C3[6] * x * (xx - 3 * yy) * sh[..., 15]) + + if deg > 3: + result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + + C4[1] * yz * (3 * xx - yy) * sh[..., 17] + + C4[2] * xy * (7 * zz - 1) * sh[..., 18] + + C4[3] * yz * (7 * zz - 3) * sh[..., 19] + + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + + C4[5] * xz * (7 * zz - 3) * sh[..., 21] + + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) + return result + +def RGB2SH(rgb): + return (rgb - 0.5) / C0 + +def SH2RGB(sh): + return sh * C0 + 0.5 \ No newline at end of file diff --git a/gaussian-splatting/utils/system_utils.py b/gaussian-splatting/utils/system_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..90ca6d7f77610c967affe313398777cd86920e8e --- /dev/null +++ b/gaussian-splatting/utils/system_utils.py @@ -0,0 +1,28 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from errno import EEXIST +from os import makedirs, path +import os + +def mkdir_p(folder_path): + # Creates a directory. equivalent to using mkdir -p on the command line + try: + makedirs(folder_path) + except OSError as exc: # Python >2.5 + if exc.errno == EEXIST and path.isdir(folder_path): + pass + else: + raise + +def searchForMaxIteration(folder): + saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] + return max(saved_iters) diff --git a/gaussian_renderer/__init__.py b/gaussian_renderer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0eaa82f994402980b1f429614e18661af10f6bd --- /dev/null +++ b/gaussian_renderer/__init__.py @@ -0,0 +1,271 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import math +from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer +from scene.gaussian_model import GaussianModel +from scene.motion_net import MotionNetwork, MouthMotionNetwork +from utils.sh_utils import eval_sh + +def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None): + """ + Render the scene. + + Background tensor (bg_color) must be on GPU! + """ + + # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means + screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 + try: + screenspace_points.retain_grad() + except: + pass + + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + raster_settings = GaussianRasterizationSettings( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=pc.active_sh_degree, + campos=viewpoint_camera.camera_center, + prefiltered=False, + debug=pipe.debug + ) + + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + means3D = pc.get_xyz + means2D = screenspace_points + opacity = pc.get_opacity + + # If precomputed 3d covariance is provided, use it. If not, then it will be computed from + # scaling / rotation by the rasterizer. + scales = None + rotations = None + cov3D_precomp = None + if pipe.compute_cov3D_python: + cov3D_precomp = pc.get_covariance(scaling_modifier) + else: + scales = pc.get_scaling + rotations = pc.get_rotation + + # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors + # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. + shs = None + colors_precomp = None + if override_color is None: + if pipe.convert_SHs_python: + shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) + dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) + dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) + sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) + colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) + else: + shs = pc.get_features + else: + colors_precomp = override_color + + # Rasterize visible Gaussians to image, obtain their radii (on screen). + rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( + means3D = means3D, + means2D = means2D, + shs = shs, + colors_precomp = colors_precomp, + opacities = opacity, + scales = scales, + rotations = rotations, + cov3D_precomp = cov3D_precomp) + + # Those Gaussians that were frustum culled or had a radius of 0 were not visible. + # They will be excluded from value updates used in the splitting criteria. + return {"render": rendered_image, + "viewspace_points": screenspace_points, + "visibility_filter" : radii > 0, + "depth": rendered_depth, + "alpha": rendered_alpha, + "radii": radii} + + +def render_motion(viewpoint_camera, pc : GaussianModel, motion_net : MotionNetwork, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, frame_idx = None, return_attn = False): + """ + Render the scene. + + Background tensor (bg_color) must be on GPU! + """ + + # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means + screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 + try: + screenspace_points.retain_grad() + except: + pass + + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + raster_settings = GaussianRasterizationSettings( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=pc.active_sh_degree, + campos=viewpoint_camera.camera_center, + prefiltered=False, + debug=pipe.debug + ) + + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + audio_feat = viewpoint_camera.talking_dict["auds"].cuda() + exp_feat = viewpoint_camera.talking_dict["au_exp"].cuda() + + # ind_code = motion_net.individual_codes[frame_idx if frame_idx is not None else viewpoint_camera.talking_dict["img_id"]] + ind_code = None + motion_preds = motion_net(pc.get_xyz, audio_feat, exp_feat, ind_code) # + means3D = pc.get_xyz + motion_preds['d_xyz'] + means2D = screenspace_points + # opacity = pc.opacity_activation(pc._opacity + motion_preds['d_opa']) + opacity = pc.get_opacity + + cov3D_precomp = None + # scales = pc.get_scaling + scales = pc.scaling_activation(pc._scaling + motion_preds['d_scale']) + rotations = pc.rotation_activation(pc._rotation + motion_preds['d_rot']) + + colors_precomp = None + shs = pc.get_features + + # Rasterize visible Gaussians to image, obtain their radii (on screen). + rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( + means3D = means3D, + means2D = means2D, + shs = shs, + colors_precomp = colors_precomp, + opacities = opacity, + scales = scales, + rotations = rotations, + cov3D_precomp = cov3D_precomp) + + # Attn + rendered_attn = None + if return_attn: + attn_precomp = torch.cat([motion_preds['ambient_aud'], motion_preds['ambient_eye'], torch.zeros_like(motion_preds['ambient_eye'])], dim=-1) + rendered_attn, _, _, _ = rasterizer( + means3D = means3D.detach(), + means2D = means2D, + shs = None, + colors_precomp = attn_precomp, + opacities = opacity.detach(), + scales = scales.detach(), + rotations = rotations.detach(), + cov3D_precomp = cov3D_precomp) + + + # Those Gaussians that were frustum culled or had a radius of 0 were not visible. + # They will be excluded from value updates used in the splitting criteria. + return {"render": rendered_image, + "viewspace_points": screenspace_points, + "visibility_filter" : radii > 0, + "depth": rendered_depth, + "alpha": rendered_alpha, + "radii": radii, + "motion": motion_preds, + 'attn': rendered_attn} + + + + + +def render_motion_mouth(viewpoint_camera, pc : GaussianModel, motion_net : MouthMotionNetwork, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, frame_idx = None, return_attn = False): + """ + Render the scene. + + Background tensor (bg_color) must be on GPU! + """ + + # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means + screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 + try: + screenspace_points.retain_grad() + except: + pass + + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + raster_settings = GaussianRasterizationSettings( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=pc.active_sh_degree, + campos=viewpoint_camera.camera_center, + prefiltered=False, + debug=pipe.debug + ) + + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + audio_feat = viewpoint_camera.talking_dict["auds"].cuda() + + motion_preds = motion_net(pc.get_xyz, audio_feat) + means3D = pc.get_xyz + motion_preds['d_xyz'] + means2D = screenspace_points + opacity = pc.get_opacity + + cov3D_precomp = None + scales = pc.get_scaling + rotations = pc.get_rotation + + colors_precomp = None + shs = pc.get_features + + # Rasterize visible Gaussians to image, obtain their radii (on screen). + rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( + means3D = means3D, + means2D = means2D, + shs = shs, + colors_precomp = colors_precomp, + opacities = opacity, + scales = scales, + rotations = rotations, + cov3D_precomp = cov3D_precomp) + + + # Those Gaussians that were frustum culled or had a radius of 0 were not visible. + # They will be excluded from value updates used in the splitting criteria. + return {"render": rendered_image, + "viewspace_points": screenspace_points, + "visibility_filter" : radii > 0, + "depth": rendered_depth, + "alpha": rendered_alpha, + "radii": radii, + "motion": motion_preds} + diff --git a/gaussian_renderer/network_gui.py b/gaussian_renderer/network_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..df2f9dae782b24527ae5b09f91ca4009361de53f --- /dev/null +++ b/gaussian_renderer/network_gui.py @@ -0,0 +1,86 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import traceback +import socket +import json +from scene.cameras import MiniCam + +host = "127.0.0.1" +port = 6009 + +conn = None +addr = None + +listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + +def init(wish_host, wish_port): + global host, port, listener + host = wish_host + port = wish_port + listener.bind((host, port)) + listener.listen() + listener.settimeout(0) + +def try_connect(): + global conn, addr, listener + try: + conn, addr = listener.accept() + print(f"\nConnected by {addr}") + conn.settimeout(None) + except Exception as inst: + pass + +def read(): + global conn + messageLength = conn.recv(4) + messageLength = int.from_bytes(messageLength, 'little') + message = conn.recv(messageLength) + return json.loads(message.decode("utf-8")) + +def send(message_bytes, verify): + global conn + if message_bytes != None: + conn.sendall(message_bytes) + conn.sendall(len(verify).to_bytes(4, 'little')) + conn.sendall(bytes(verify, 'ascii')) + +def receive(): + message = read() + + width = message["resolution_x"] + height = message["resolution_y"] + + if width != 0 and height != 0: + try: + do_training = bool(message["train"]) + fovy = message["fov_y"] + fovx = message["fov_x"] + znear = message["z_near"] + zfar = message["z_far"] + do_shs_python = bool(message["shs_python"]) + do_rot_scale_python = bool(message["rot_scale_python"]) + keep_alive = bool(message["keep_alive"]) + scaling_modifier = message["scaling_modifier"] + world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() + world_view_transform[:,1] = -world_view_transform[:,1] + world_view_transform[:,2] = -world_view_transform[:,2] + full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() + full_proj_transform[:,1] = -full_proj_transform[:,1] + custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) + except Exception as e: + print("") + traceback.print_exc() + raise e + return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier + else: + return None, None, None, None, None, None \ No newline at end of file diff --git a/gridencoder/__init__.py b/gridencoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1476cef5314e0918b963d1ac64ee0613a7743d5 --- /dev/null +++ b/gridencoder/__init__.py @@ -0,0 +1 @@ +from .grid import GridEncoder \ No newline at end of file diff --git a/gridencoder/backend.py b/gridencoder/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..d99acb1f4353786e16468948780f377008d94872 --- /dev/null +++ b/gridencoder/backend.py @@ -0,0 +1,40 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_grid_encoder', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'gridencoder.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/gridencoder/grid.py b/gridencoder/grid.py new file mode 100644 index 0000000000000000000000000000000000000000..32b8bead0d9d0575b0988302afb1794dffcfe72d --- /dev/null +++ b/gridencoder/grid.py @@ -0,0 +1,185 @@ +import numpy as np + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _gridencoder as _backend +except ImportError: + from .backend import _backend + +_gridtype_to_id = { + 'hash': 0, + 'tiled': 1, +} + +_interp_to_id = { + 'linear': 0, + 'smoothstep': 1, +} + +class _grid_encode(Function): + @staticmethod + @custom_fwd + def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0): + # inputs: [B, D], float in [0, 1] + # embeddings: [sO, C], float + # offsets: [L + 1], int + # RETURN: [B, F], float + + inputs = inputs.contiguous() + + B, D = inputs.shape # batch size, coord dim + L = offsets.shape[0] - 1 # level + C = embeddings.shape[1] # embedding dim for each level + S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f + H = base_resolution # base resolution + + # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision) + # if C % 2 != 0, force float, since half for atomicAdd is very slow. + if torch.is_autocast_enabled() and C % 2 == 0: + embeddings = embeddings.to(torch.half) + + # L first, optimize cache for cuda kernel, but needs an extra permute later + outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype) + + if calc_grad_inputs: + dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype) + else: + dy_dx = None + + _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners, interpolation) + + # permute back to [B, L * C] + outputs = outputs.permute(1, 0, 2).reshape(B, L * C) + + ctx.save_for_backward(inputs, embeddings, offsets, dy_dx) + ctx.dims = [B, D, C, L, S, H, gridtype, interpolation] + ctx.align_corners = align_corners + + return outputs + + @staticmethod + #@once_differentiable + @custom_bwd + def backward(ctx, grad): + + inputs, embeddings, offsets, dy_dx = ctx.saved_tensors + B, D, C, L, S, H, gridtype, interpolation = ctx.dims + align_corners = ctx.align_corners + + # grad: [B, L * C] --> [L, B, C] + grad = grad.view(B, L, C).permute(1, 0, 2).contiguous() + + grad_embeddings = torch.zeros_like(embeddings) + + if dy_dx is not None: + grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype) + else: + grad_inputs = None + + _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interpolation) + + if dy_dx is not None: + grad_inputs = grad_inputs.to(inputs.dtype) + + return grad_inputs, grad_embeddings, None, None, None, None, None, None, None + + + +grid_encode = _grid_encode.apply + + +class GridEncoder(nn.Module): + def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, interpolation='linear'): + super().__init__() + + # the finest resolution desired at the last level, if provided, overridee per_level_scale + if desired_resolution is not None: + per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1)) + + self.input_dim = input_dim # coord dims, 2 or 3 + self.num_levels = num_levels # num levels, each level multiply resolution by 2 + self.level_dim = level_dim # encode channels per level + self.per_level_scale = per_level_scale # multiply resolution by this scale at each level. + self.log2_hashmap_size = log2_hashmap_size + self.base_resolution = base_resolution + self.output_dim = num_levels * level_dim + self.gridtype = gridtype + self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" + self.interpolation = interpolation + self.interp_id = _interp_to_id[interpolation] # "linear" or "smoothstep" + self.align_corners = align_corners + + # allocate parameters + offsets = [] + offset = 0 + self.max_params = 2 ** log2_hashmap_size + for i in range(num_levels): + resolution = int(np.ceil(base_resolution * per_level_scale ** i)) + params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number + params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible + offsets.append(offset) + offset += params_in_level + offsets.append(offset) + offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) + self.register_buffer('offsets', offsets) + + self.n_params = offsets[-1] * level_dim + + # parameters + self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) + + self.reset_parameters() + + def reset_parameters(self): + std = 1e-4 + self.embeddings.data.uniform_(-std, std) + + def __repr__(self): + return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}" + + def forward(self, inputs, bound=1): + # inputs: [..., input_dim], normalized real world positions in [-bound, bound] + # return: [..., num_levels * level_dim] + + inputs = (inputs + bound) / (2 * bound) # map to [0, 1] + + #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.view(-1, self.input_dim) + + outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id) + outputs = outputs.view(prefix_shape + [self.output_dim]) + + #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) + + return outputs + + # always run in float precision! + @torch.cuda.amp.autocast(enabled=False) + def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000): + # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss. + + D = self.input_dim + C = self.embeddings.shape[1] # embedding dim for each level + L = self.offsets.shape[0] - 1 # level + S = np.log2(self.per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f + H = self.base_resolution # base resolution + + if inputs is None: + # randomized in [0, 1] + inputs = torch.rand(B, self.input_dim, device=self.embeddings.device) + else: + inputs = (inputs + bound) / (2 * bound) # map to [0, 1] + inputs = inputs.view(-1, self.input_dim) + B = inputs.shape[0] + + if self.embeddings.grad is None: + raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!') + + _backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners) \ No newline at end of file diff --git a/gridencoder/setup.py b/gridencoder/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..714bf1cad7880fe25dca319414748c15e86cc48e --- /dev/null +++ b/gridencoder/setup.py @@ -0,0 +1,50 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='gridencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_gridencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'gridencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/gridencoder/src/bindings.cpp b/gridencoder/src/bindings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..93dea943c939cffc7ec73c76410aeff7afddc1f9 --- /dev/null +++ b/gridencoder/src/bindings.cpp @@ -0,0 +1,9 @@ +#include + +#include "gridencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); + m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); + m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)"); +} \ No newline at end of file diff --git a/gridencoder/src/gridencoder.cu b/gridencoder/src/gridencoder.cu new file mode 100644 index 0000000000000000000000000000000000000000..cba5e94f5f4ca6b728bc9006c79e80cb0fce62dd --- /dev/null +++ b/gridencoder/src/gridencoder.cu @@ -0,0 +1,645 @@ +#include +#include +#include + +#include +#include + +#include +#include + +#include +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + + +// just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... program will never reach here! + __device__ inline at::Half atomicAdd(at::Half *address, at::Half val) { + // requires CUDA >= 10 and ARCH >= 70 + // this is very slow compared to float or __half2, never use it. + //return atomicAdd(reinterpret_cast<__half*>(address), val); +} + + +template +__host__ __device__ inline T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +template +__host__ __device__ inline T clamp(const T v, const T2 lo, const T2 hi) { + return min(max(v, lo), hi); +} + +template +__device__ inline T smoothstep(T val) { + return val*val*(3.0f - 2.0f * val); +} + +template +__device__ inline T smoothstep_derivative(T val) { + return 6*val*(1.0f - val); +} + + +template +__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) { + + // coherent type of hashing + constexpr uint32_t primes[7] = { 1u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u }; + + uint32_t result = 0; + #pragma unroll + for (uint32_t i = 0; i < D; ++i) { + result ^= pos_grid[i] * primes[i]; + } + + return result; +} + + +template +__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) { + uint32_t stride = 1; + uint32_t index = 0; + + #pragma unroll + for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) { + index += pos_grid[d] * stride; + stride *= align_corners ? resolution: (resolution + 1); + } + + // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97. + // gridtype: 0 == hash, 1 == tiled + if (gridtype == 0 && stride > hashmap_size) { + index = fast_hash(pos_grid); + } + + return (index % hashmap_size) * C + ch; +} + + +template +__global__ void kernel_grid( + const float * __restrict__ inputs, + const scalar_t * __restrict__ grid, + const int * __restrict__ offsets, + scalar_t * __restrict__ outputs, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + scalar_t * __restrict__ dy_dx, + const uint32_t gridtype, + const bool align_corners, + const uint32_t interp +) { + const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; + + if (b >= B) return; + + const uint32_t level = blockIdx.y; + + // locate + grid += (uint32_t)offsets[level] * C; + inputs += b * D; + outputs += level * B * C + b * C; + + // check input range (should be in [0, 1]) + bool flag_oob = false; + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + flag_oob = true; + } + } + // if input out of bound, just set output to 0 + if (flag_oob) { + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + outputs[ch] = 0; + } + if (dy_dx) { + dy_dx += b * D * L * C + level * D * C; // B L D C + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + dy_dx[d * C + ch] = 0; + } + } + } + return; + } + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // calculate coordinate (always use float for precision!) + float pos[D]; + float pos_deriv[D]; + uint32_t pos_grid[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + pos[d] -= (float)pos_grid[d]; + // smoothstep instead of linear + if (interp == 1) { + pos_deriv[d] = smoothstep_derivative(pos[d]); + pos[d] = smoothstep(pos[d]); + } else { + pos_deriv[d] = 1.0f; // linear deriv is default to 1 + } + + } + + //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]); + + // interpolate + scalar_t results[C] = {0}; // temp results in register + + #pragma unroll + for (uint32_t idx = 0; idx < (1 << D); idx++) { + float w = 1; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if ((idx & (1 << d)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + + // writing to register (fast) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + results[ch] += w * grid[index + ch]; + } + + //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]); + } + + // writing to global memory (slow) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + outputs[ch] = results[ch]; + } + + // prepare dy_dx + // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9 + if (dy_dx) { + + dy_dx += b * D * L * C + level * D * C; // B L D C + + #pragma unroll + for (uint32_t gd = 0; gd < D; gd++) { + + scalar_t results_grad[C] = {0}; + + #pragma unroll + for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) { + float w = scale; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t nd = 0; nd < D - 1; nd++) { + const uint32_t d = (nd >= gd) ? (nd + 1) : nd; + + if ((idx & (1 << nd)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + pos_grid_local[gd] = pos_grid[gd]; + uint32_t index_left = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + pos_grid_local[gd] = pos_grid[gd] + 1; + uint32_t index_right = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]) * pos_deriv[gd]; + } + } + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + dy_dx[gd * C + ch] = results_grad[ch]; + } + } + } +} + + +template +__global__ void kernel_grid_backward( + const scalar_t * __restrict__ grad, + const float * __restrict__ inputs, + const scalar_t * __restrict__ grid, + const int * __restrict__ offsets, + scalar_t * __restrict__ grad_grid, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + const uint32_t gridtype, + const bool align_corners, + const uint32_t interp +) { + const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C; + if (b >= B) return; + + const uint32_t level = blockIdx.y; + const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C; + + // locate + grad_grid += offsets[level] * C; + inputs += b * D; + grad += level * B * C + b * C + ch; // L, B, C + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // check input range (should be in [0, 1]) + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + return; // grad is init as 0, so we simply return. + } + } + + // calculate coordinate + float pos[D]; + uint32_t pos_grid[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + pos[d] -= (float)pos_grid[d]; + // smoothstep instead of linear + if (interp == 1) { + pos[d] = smoothstep(pos[d]); + } + } + + scalar_t grad_cur[N_C] = {0}; // fetch to register + #pragma unroll + for (uint32_t c = 0; c < N_C; c++) { + grad_cur[c] = grad[c]; + } + + // interpolate + #pragma unroll + for (uint32_t idx = 0; idx < (1 << D); idx++) { + float w = 1; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if ((idx & (1 << d)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + uint32_t index = get_grid_index(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local); + + // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0 + // TODO: use float which is better than __half, if N_C % 2 != 0 + if (std::is_same::value && N_C % 2 == 0) { + #pragma unroll + for (uint32_t c = 0; c < N_C; c += 2) { + // process two __half at once (by interpreting as a __half2) + __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])}; + atomicAdd((__half2*)&grad_grid[index + c], v); + } + // float, or __half when N_C % 2 != 0 (which means C == 1) + } else { + #pragma unroll + for (uint32_t c = 0; c < N_C; c++) { + atomicAdd(&grad_grid[index + c], w * grad_cur[c]); + } + } + } +} + + +template +__global__ void kernel_input_backward( + const scalar_t * __restrict__ grad, + const scalar_t * __restrict__ dy_dx, + scalar_t * __restrict__ grad_inputs, + uint32_t B, uint32_t L +) { + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * D) return; + + const uint32_t b = t / D; + const uint32_t d = t - b * D; + + dy_dx += b * L * D * C; + + scalar_t result = 0; + + # pragma unroll + for (int l = 0; l < L; l++) { + # pragma unroll + for (int ch = 0; ch < C; ch++) { + result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch]; + } + } + + grad_inputs[t] = result; +} + + +template +void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + static constexpr uint32_t N_THREAD = 512; + const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 }; + switch (C) { + case 1: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 2: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 4: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 8: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.) +// H: base resolution +// dy_dx: [B, L * D * C] +template +void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + switch (D) { + case 2: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 3: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 4: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 5: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + +template +void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + static constexpr uint32_t N_THREAD = 256; + const uint32_t N_C = std::min(2u, C); // n_features_per_thread + const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 }; + switch (C) { + case 1: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 2: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 4: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 8: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + +// grad: [L, B, C], float +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// grad_embeddings: [sO, C] +// H: base resolution +template +void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + switch (D) { + case 2: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break; + case 3: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break; + case 4: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break; + case 5: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + + +void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + CHECK_CUDA(inputs); + CHECK_CUDA(embeddings); + CHECK_CUDA(offsets); + CHECK_CUDA(outputs); + // CHECK_CUDA(dy_dx); + + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(embeddings); + CHECK_CONTIGUOUS(offsets); + CHECK_CONTIGUOUS(outputs); + // CHECK_CONTIGUOUS(dy_dx); + + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(embeddings); + CHECK_IS_INT(offsets); + CHECK_IS_FLOATING(outputs); + // CHECK_IS_FLOATING(dy_dx); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + embeddings.scalar_type(), "grid_encode_forward", ([&] { + grid_encode_forward_cuda(inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), outputs.data_ptr(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, gridtype, align_corners, interp); + })); +} + +void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + CHECK_CUDA(grad); + CHECK_CUDA(inputs); + CHECK_CUDA(embeddings); + CHECK_CUDA(offsets); + CHECK_CUDA(grad_embeddings); + // CHECK_CUDA(dy_dx); + // CHECK_CUDA(grad_inputs); + + CHECK_CONTIGUOUS(grad); + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(embeddings); + CHECK_CONTIGUOUS(offsets); + CHECK_CONTIGUOUS(grad_embeddings); + // CHECK_CONTIGUOUS(dy_dx); + // CHECK_CONTIGUOUS(grad_inputs); + + CHECK_IS_FLOATING(grad); + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(embeddings); + CHECK_IS_INT(offsets); + CHECK_IS_FLOATING(grad_embeddings); + // CHECK_IS_FLOATING(dy_dx); + // CHECK_IS_FLOATING(grad_inputs); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "grid_encode_backward", ([&] { + grid_encode_backward_cuda(grad.data_ptr(), inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), grad_embeddings.data_ptr(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr() : nullptr, gridtype, align_corners, interp); + })); + +} + + +template +__global__ void kernel_grad_tv( + const scalar_t * __restrict__ inputs, + const scalar_t * __restrict__ grid, + scalar_t * __restrict__ grad, + const int * __restrict__ offsets, + const float weight, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + const uint32_t gridtype, + const bool align_corners +) { + const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; + + if (b >= B) return; + + const uint32_t level = blockIdx.y; + + // locate + inputs += b * D; + grid += (uint32_t)offsets[level] * C; + grad += (uint32_t)offsets[level] * C; + + // check input range (should be in [0, 1]) + bool flag_oob = false; + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + flag_oob = true; + } + } + + // if input out of bound, do nothing + if (flag_oob) return; + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // calculate coordinate + float pos[D]; + uint32_t pos_grid[D]; // [0, resolution] + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + // pos[d] -= (float)pos_grid[d]; // not used + } + + //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]); + + // total variation on pos_grid + scalar_t results[C] = {0}; // temp results in register + scalar_t idelta[C] = {0}; + + uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid); + + scalar_t w = weight / (2 * D); + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + + uint32_t cur_d = pos_grid[d]; + scalar_t grad_val; + + // right side + if (cur_d < resolution) { + pos_grid[d] = cur_d + 1; + uint32_t index_right = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + // results[ch] += w * clamp(grid[index + ch] - grid[index_right + ch], -1.0f, 1.0f); + grad_val = (grid[index + ch] - grid[index_right + ch]); + results[ch] += grad_val; + idelta[ch] += grad_val * grad_val; + } + } + + // left side + if (cur_d > 0) { + pos_grid[d] = cur_d - 1; + uint32_t index_left = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + // results[ch] += w * clamp(grid[index + ch] - grid[index_left + ch], -1.0f, 1.0f); + grad_val = (grid[index + ch] - grid[index_left + ch]); + results[ch] += grad_val; + idelta[ch] += grad_val * grad_val; + } + } + + // reset + pos_grid[d] = cur_d; + } + + // writing to global memory (slow) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + // index may collide, so use atomic! + atomicAdd(&grad[index + ch], w * results[ch] * rsqrtf(idelta[ch] + 1e-9f)); + } + +} + + +template +void kernel_grad_tv_wrapper(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) { + static constexpr uint32_t N_THREAD = 512; + const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 }; + switch (C) { + case 1: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + case 2: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + case 4: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + case 8: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + +template +void grad_total_variation_cuda(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) { + switch (D) { + case 2: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + case 3: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + case 4: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + case 5: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + +void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) { + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + embeddings.scalar_type(), "grad_total_variation", ([&] { + grad_total_variation_cuda(inputs.data_ptr(), embeddings.data_ptr(), grad.data_ptr(), offsets.data_ptr(), weight, B, D, C, L, S, H, gridtype, align_corners); + })); +} \ No newline at end of file diff --git a/gridencoder/src/gridencoder.h b/gridencoder/src/gridencoder.h new file mode 100644 index 0000000000000000000000000000000000000000..1b385755d13711b04df4866dd654e88b48054554 --- /dev/null +++ b/gridencoder/src/gridencoder.h @@ -0,0 +1,17 @@ +#ifndef _HASH_ENCODE_H +#define _HASH_ENCODE_H + +#include +#include + +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// outputs: [B, L * C], float +// H: base resolution +void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp); +void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp); + +void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners); + +#endif \ No newline at end of file diff --git a/install.sh b/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..6e75d40cbdff9727c1a451e3ac768ea35ab0f9f1 --- /dev/null +++ b/install.sh @@ -0,0 +1,10 @@ +# git clone git@github.com:Fictionarry/TalkingGaussian.git --recursive + +# conda env create --file environment.yml +# conda activate talking_gaussian +# pip install "git+https://github.com/facebookresearch/pytorch3d.git" +# pip install tensorflow-gpu==2.8.0 + +python ./submodules/diff-gaussian-rasterization/setup.py +python ./submodules/simple-knn//setup.py +#./gridencoder \ No newline at end of file diff --git a/install.txt b/install.txt new file mode 100644 index 0000000000000000000000000000000000000000..c92b665bb770f3cf89ecc91bcf1dffc3665030d6 --- /dev/null +++ b/install.txt @@ -0,0 +1,25 @@ +numpy +pillow +scipy +tensorboard +opencv-python +tensorboardX + +pandas +tqdm +matplotlib +PyMCubes==0.1.4 +rich +packaging +scikit-learn + +face_alignment +python_speech_features +numba +resampy +pyaudio +soundfile +configargparse + +lpips +imageio-ffmpeg \ No newline at end of file diff --git a/lpipsPyTorch/__init__.py b/lpipsPyTorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2a6297daa457d1d041c9491dfdf6a75994ffe06e --- /dev/null +++ b/lpipsPyTorch/__init__.py @@ -0,0 +1,21 @@ +import torch + +from .modules.lpips import LPIPS + + +def lpips(x: torch.Tensor, + y: torch.Tensor, + net_type: str = 'alex', + version: str = '0.1'): + r"""Function that measures + Learned Perceptual Image Patch Similarity (LPIPS). + + Arguments: + x, y (torch.Tensor): the input tensors to compare. + net_type (str): the network type to compare the features: + 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. + version (str): the version of LPIPS. Default: 0.1. + """ + device = x.device + criterion = LPIPS(net_type, version).to(device) + return criterion(x, y) diff --git a/lpipsPyTorch/modules/lpips.py b/lpipsPyTorch/modules/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..9cd001d1e1036c7f8f5db62e81446e2ff2db80ab --- /dev/null +++ b/lpipsPyTorch/modules/lpips.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn + +from .networks import get_network, LinLayers +from .utils import get_state_dict + + +class LPIPS(nn.Module): + r"""Creates a criterion that measures + Learned Perceptual Image Patch Similarity (LPIPS). + + Arguments: + net_type (str): the network type to compare the features: + 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. + version (str): the version of LPIPS. Default: 0.1. + """ + def __init__(self, net_type: str = 'alex', version: str = '0.1'): + + assert version in ['0.1'], 'v0.1 is only supported now' + + super(LPIPS, self).__init__() + + # pretrained network + self.net = get_network(net_type) + + # linear layers + self.lin = LinLayers(self.net.n_channels_list) + self.lin.load_state_dict(get_state_dict(net_type, version)) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + feat_x, feat_y = self.net(x), self.net(y) + + diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] + res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] + + return torch.sum(torch.cat(res, 0), 0, True) diff --git a/lpipsPyTorch/modules/networks.py b/lpipsPyTorch/modules/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..d36c6a56163004d49c321da5e26404af9baa4c2a --- /dev/null +++ b/lpipsPyTorch/modules/networks.py @@ -0,0 +1,96 @@ +from typing import Sequence + +from itertools import chain + +import torch +import torch.nn as nn +from torchvision import models + +from .utils import normalize_activation + + +def get_network(net_type: str): + if net_type == 'alex': + return AlexNet() + elif net_type == 'squeeze': + return SqueezeNet() + elif net_type == 'vgg': + return VGG16() + else: + raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') + + +class LinLayers(nn.ModuleList): + def __init__(self, n_channels_list: Sequence[int]): + super(LinLayers, self).__init__([ + nn.Sequential( + nn.Identity(), + nn.Conv2d(nc, 1, 1, 1, 0, bias=False) + ) for nc in n_channels_list + ]) + + for param in self.parameters(): + param.requires_grad = False + + +class BaseNet(nn.Module): + def __init__(self): + super(BaseNet, self).__init__() + + # register buffer + self.register_buffer( + 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer( + 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def set_requires_grad(self, state: bool): + for param in chain(self.parameters(), self.buffers()): + param.requires_grad = state + + def z_score(self, x: torch.Tensor): + return (x - self.mean) / self.std + + def forward(self, x: torch.Tensor): + x = self.z_score(x) + + output = [] + for i, (_, layer) in enumerate(self.layers._modules.items(), 1): + x = layer(x) + if i in self.target_layers: + output.append(normalize_activation(x)) + if len(output) == len(self.target_layers): + break + return output + + +class SqueezeNet(BaseNet): + def __init__(self): + super(SqueezeNet, self).__init__() + + self.layers = models.squeezenet1_1(True).features + self.target_layers = [2, 5, 8, 10, 11, 12, 13] + self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] + + self.set_requires_grad(False) + + +class AlexNet(BaseNet): + def __init__(self): + super(AlexNet, self).__init__() + + self.layers = models.alexnet(True).features + self.target_layers = [2, 5, 8, 10, 12] + self.n_channels_list = [64, 192, 384, 256, 256] + + self.set_requires_grad(False) + + +class VGG16(BaseNet): + def __init__(self): + super(VGG16, self).__init__() + + self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features + self.target_layers = [4, 9, 16, 23, 30] + self.n_channels_list = [64, 128, 256, 512, 512] + + self.set_requires_grad(False) diff --git a/lpipsPyTorch/modules/utils.py b/lpipsPyTorch/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3d15a0983775810ef6239c561c67939b2b9ee3b5 --- /dev/null +++ b/lpipsPyTorch/modules/utils.py @@ -0,0 +1,30 @@ +from collections import OrderedDict + +import torch + + +def normalize_activation(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def get_state_dict(net_type: str = 'alex', version: str = '0.1'): + # build url + url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ + + f'master/lpips/weights/v{version}/{net_type}.pth' + + # download + old_state_dict = torch.hub.load_state_dict_from_url( + url, progress=True, + map_location=None if torch.cuda.is_available() else torch.device('cpu') + ) + + # rename keys + new_state_dict = OrderedDict() + for key, val in old_state_dict.items(): + new_key = key + new_key = new_key.replace('lin', '') + new_key = new_key.replace('model.', '') + new_state_dict[new_key] = val + + return new_state_dict diff --git a/metrics.py b/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..977aaa7e08fd385adb04460682fa8ebd55fb2ccb --- /dev/null +++ b/metrics.py @@ -0,0 +1,218 @@ +import cv2 +import sys +import lpips +import numpy as np +from matplotlib import pyplot as plt +import torch + +class LMDMeter: + def __init__(self, backend='dlib', region='mouth'): + self.backend = backend + self.region = region # mouth or face + + if self.backend == 'dlib': + import dlib + + # load checkpoint manually + self.predictor_path = './shape_predictor_68_face_landmarks.dat' + if not os.path.exists(self.predictor_path): + raise FileNotFoundError('Please download dlib checkpoint from http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2') + + self.detector = dlib.get_frontal_face_detector() + self.predictor = dlib.shape_predictor(self.predictor_path) + + else: + + import face_alignment + try: + self.predictor = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False) + except: + self.predictor = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False) + + self.V = 0 + self.N = 0 + + def get_landmarks(self, img): + + if self.backend == 'dlib': + dets = self.detector(img, 1) + for det in dets: + shape = self.predictor(img, det) + # ref: https://github.com/PyImageSearch/imutils/blob/c12f15391fcc945d0d644b85194b8c044a392e0a/imutils/face_utils/helpers.py + lms = np.zeros((68, 2), dtype=np.int32) + for i in range(0, 68): + lms[i, 0] = shape.part(i).x + lms[i, 1] = shape.part(i).y + break + + else: + lms = self.predictor.get_landmarks(img)[-1] + + # self.vis_landmarks(img, lms) + lms = lms.astype(np.float32) + + return lms + + def vis_landmarks(self, img, lms): + plt.imshow(img) + plt.plot(lms[48:68, 0], lms[48:68, 1], marker='o', markersize=1, linestyle='-', lw=2) + plt.show() + + def clear(self): + self.V = 0 + self.N = 0 + + def prepare_inputs(self, *inputs): + outputs = [] + for i, inp in enumerate(inputs): + inp = inp.detach().cpu().numpy() + inp = (inp * 255).astype(np.uint8) + outputs.append(inp) + return outputs + + def update(self, preds, truths): + # assert B == 1 + preds, truths = self.prepare_inputs(preds[0], truths[0]) # [H, W, 3] numpy array + + # get lms + lms_pred = self.get_landmarks(preds) + lms_truth = self.get_landmarks(truths) + + if self.region == 'mouth': + lms_pred = lms_pred[48:68] + lms_truth = lms_truth[48:68] + + # avarage + lms_pred = lms_pred - lms_pred.mean(0) + lms_truth = lms_truth - lms_truth.mean(0) + + # distance + dist = np.sqrt(((lms_pred - lms_truth) ** 2).sum(1)).mean(0) + + self.V += dist + self.N += 1 + + def measure(self): + return self.V / self.N + + def write(self, writer, global_step, prefix=""): + writer.add_scalar(os.path.join(prefix, f"LMD ({self.backend})"), self.measure(), global_step) + + def report(self): + return f'LMD ({self.backend}) = {self.measure():.6f}' + + +class PSNRMeter: + def __init__(self): + self.V = 0 + self.N = 0 + + def clear(self): + self.V = 0 + self.N = 0 + + def prepare_inputs(self, *inputs): + outputs = [] + for i, inp in enumerate(inputs): + if torch.is_tensor(inp): + inp = inp.detach().cpu().numpy() + outputs.append(inp) + + return outputs + + def update(self, preds, truths): + preds, truths = self.prepare_inputs(preds, truths) # [B, N, 3] or [B, H, W, 3], range in [0, 1] + + # simplified since max_pixel_value is 1 here. + psnr = -10 * np.log10(np.mean((preds - truths) ** 2)) + + self.V += psnr + self.N += 1 + + def measure(self): + return self.V / self.N + + def write(self, writer, global_step, prefix=""): + writer.add_scalar(os.path.join(prefix, "PSNR"), self.measure(), global_step) + + def report(self): + return f'PSNR = {self.measure():.6f}' + + +class LPIPSMeter: + def __init__(self, net='alex', device=None): + self.V = 0 + self.N = 0 + self.net = net + + self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.fn = lpips.LPIPS(net=net).eval().to(self.device) + + def clear(self): + self.V = 0 + self.N = 0 + + def prepare_inputs(self, *inputs): + outputs = [] + for i, inp in enumerate(inputs): + inp = inp.permute(0, 3, 1, 2).contiguous() # [B, 3, H, W] + inp = inp.to(self.device) + outputs.append(inp) + return outputs + + def update(self, preds, truths): + preds, truths = self.prepare_inputs(preds, truths) # [B, H, W, 3] --> [B, 3, H, W], range in [0, 1] + v = self.fn(truths, preds, normalize=True).item() # normalize=True: [0, 1] to [-1, 1] + self.V += v + self.N += 1 + + def measure(self): + return self.V / self.N + + def write(self, writer, global_step, prefix=""): + writer.add_scalar(os.path.join(prefix, f"LPIPS ({self.net})"), self.measure(), global_step) + + def report(self): + return f'LPIPS ({self.net}) = {self.measure():.6f}' + + + + +lmd_meter = LMDMeter(backend='fan') +psnr_meter = PSNRMeter() +lpips_meter = LPIPSMeter() + +lmd_meter.clear() +psnr_meter.clear() +lpips_meter.clear() + +vid_path_1 = sys.argv[1] +vid_path_2 = sys.argv[2] + +capture_1 = cv2.VideoCapture(vid_path_1) +capture_2 = cv2.VideoCapture(vid_path_2) + +counter = 0 +while True: + ret_1, frame_1 = capture_1.read() + ret_2, frame_2 = capture_2.read() + + if not ret_1 * ret_2: + break + + # plt.imshow(frame_1[:, :, ::-1]) + # plt.show() + inp_1 = torch.FloatTensor(frame_1[..., ::-1] / 255.0)[None, ...].cuda() + inp_2 = torch.FloatTensor(frame_2[..., ::-1] / 255.0)[None, ...].cuda() + lmd_meter.update(inp_1, inp_2) + psnr_meter.update(inp_1, inp_2) + lpips_meter.update(inp_1, inp_2) + + counter+=1 + if counter % 100 == 0: + print(counter) + +print(lmd_meter.report()) +print(psnr_meter.report()) +print(lpips_meter.report()) + diff --git a/scene/__init__.py b/scene/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eed05e810d6107ecb7ef48bf9969635444811ae3 --- /dev/null +++ b/scene/__init__.py @@ -0,0 +1,96 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import gc +import os +import random +import json +from utils.system_utils import searchForMaxIteration +from scene.dataset_readers import sceneLoadTypeCallbacks +from scene.gaussian_model import GaussianModel +from scene.motion_net import MotionNetwork, MouthMotionNetwork +from arguments import ModelParams +from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON + +class Scene: + + gaussians : GaussianModel + + def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]): + """b + :param path: Path to colmap scene main folder. + """ + self.model_path = args.model_path + self.loaded_iter = None + self.gaussians = gaussians + + if load_iteration: + if load_iteration == -1: + self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) + else: + self.loaded_iter = load_iteration + print("Loading trained model at iteration {}".format(self.loaded_iter)) + + self.train_cameras = {} + self.test_cameras = {} + + if os.path.exists(os.path.join(args.source_path, "transforms_train.json")): + print("Found transforms_train.json file, assuming Blender data set!") + scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval, args=args) + else: + assert False, "Could not recognize scene type!" + + if not self.loaded_iter: + with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: + dest_file.write(src_file.read()) + json_cams = [] + camlist = [] + if scene_info.test_cameras: + camlist.extend(scene_info.test_cameras) + if scene_info.train_cameras: + camlist.extend(scene_info.train_cameras) + for id, cam in enumerate(camlist): + json_cams.append(camera_to_JSON(id, cam)) + with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: + json.dump(json_cams, file) + + if shuffle: + random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling + random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling + + self.cameras_extent = scene_info.nerf_normalization["radius"] + + for resolution_scale in resolution_scales: + print("Loading Training Cameras") + self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) + print("Loading Test Cameras") + self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) + + if self.loaded_iter: + self.gaussians.load_ply(os.path.join(self.model_path, + "point_cloud", + "iteration_" + str(self.loaded_iter), + "point_cloud.ply")) + else: + self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) + + gc.collect() + + + def save(self, iteration): + point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) + self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) + + def getTrainCameras(self, scale=1.0): + return self.train_cameras[scale] + + def getTestCameras(self, scale=1.0): + return self.test_cameras[scale] \ No newline at end of file diff --git a/scene/cameras.py b/scene/cameras.py new file mode 100644 index 0000000000000000000000000000000000000000..39be12fcaf958557a32533126a73c201351eda2a --- /dev/null +++ b/scene/cameras.py @@ -0,0 +1,77 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +from torch import nn +import numpy as np +from utils.graphics_utils import getWorld2View2, getProjectionMatrix + +class Camera(nn.Module): + def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, background, talking_dict, + image_name, uid, + trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" + ): + super(Camera, self).__init__() + + self.uid = uid + self.colmap_id = colmap_id + self.R = R + self.T = T + self.FoVx = FoVx + self.FoVy = FoVy + self.image_name = image_name + self.talking_dict = talking_dict + + try: + self.data_device = torch.device(data_device) + except Exception as e: + print(e) + print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) + self.data_device = torch.device("cuda") + + self.original_image = image.clamp(0, 255).to(self.data_device) + self.image_width = self.original_image.shape[2] + self.image_height = self.original_image.shape[1] + + self.background = background.clamp(0, 255).to(self.data_device) + + # for key in self.mask.keys(): + # self.mask[key] = torch.as_tensor(self.mask[key], device=self.data_device) + + if gt_alpha_mask is not None: + self.original_image *= gt_alpha_mask.to(self.data_device) + # else: + # self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) + + self.zfar = 100.0 + self.znear = 0.01 + + self.trans = trans + self.scale = scale + + self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() + self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() + self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) + self.camera_center = self.world_view_transform.inverse()[3, :3] + +class MiniCam: + def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): + self.image_width = width + self.image_height = height + self.FoVy = fovy + self.FoVx = fovx + self.znear = znear + self.zfar = zfar + self.world_view_transform = world_view_transform + self.full_proj_transform = full_proj_transform + view_inv = torch.inverse(self.world_view_transform) + self.camera_center = view_inv[3][:3] + diff --git a/scene/colmap_loader.py b/scene/colmap_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6fba6a9c961f52c88780ecb44d7821b4cb73ee --- /dev/null +++ b/scene/colmap_loader.py @@ -0,0 +1,294 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import numpy as np +import collections +import struct + +CameraModel = collections.namedtuple( + "CameraModel", ["model_id", "model_name", "num_params"]) +Camera = collections.namedtuple( + "Camera", ["id", "model", "width", "height", "params"]) +BaseImage = collections.namedtuple( + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) +Point3D = collections.namedtuple( + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) +} +CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) + for camera_model in CAMERA_MODELS]) +CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) + for camera_model in CAMERA_MODELS]) + + +def qvec2rotmat(qvec): + return np.array([ + [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], + [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], + [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = np.array([ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + +def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. + """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + +def read_points3D_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + xyzs = None + rgbs = None + errors = None + num_points = 0 + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + num_points += 1 + + + xyzs = np.empty((num_points, 3)) + rgbs = np.empty((num_points, 3)) + errors = np.empty((num_points, 1)) + count = 0 + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + xyz = np.array(tuple(map(float, elems[1:4]))) + rgb = np.array(tuple(map(int, elems[4:7]))) + error = np.array(float(elems[7])) + xyzs[count] = xyz + rgbs[count] = rgb + errors[count] = error + count += 1 + + return xyzs, rgbs, errors + +def read_points3D_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + + + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + + xyzs = np.empty((num_points, 3)) + rgbs = np.empty((num_points, 3)) + errors = np.empty((num_points, 1)) + + for p_id in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd") + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes( + fid, num_bytes=8, format_char_sequence="Q")[0] + track_elems = read_next_bytes( + fid, num_bytes=8*track_length, + format_char_sequence="ii"*track_length) + xyzs[p_id] = xyz + rgbs[p_id] = rgb + errors[p_id] = error + return xyzs, rgbs, errors + +def read_intrinsics_text(path): + """ + Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py + """ + cameras = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera(id=camera_id, model=model, + width=width, height=height, + params=params) + return cameras + +def read_extrinsics_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi") + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + image_name = "" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + image_name += current_char.decode("utf-8") + current_char = read_next_bytes(fid, 1, "c")[0] + num_points2D = read_next_bytes(fid, num_bytes=8, + format_char_sequence="Q")[0] + x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, + format_char_sequence="ddq"*num_points2D) + xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), + tuple(map(float, x_y_id_s[1::3]))]) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + return images + + +def read_intrinsics_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_cameras): + camera_properties = read_next_bytes( + fid, num_bytes=24, format_char_sequence="iiQQ") + camera_id = camera_properties[0] + model_id = camera_properties[1] + model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name + width = camera_properties[2] + height = camera_properties[3] + num_params = CAMERA_MODEL_IDS[model_id].num_params + params = read_next_bytes(fid, num_bytes=8*num_params, + format_char_sequence="d"*num_params) + cameras[camera_id] = Camera(id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params)) + assert len(cameras) == num_cameras + return cameras + + +def read_extrinsics_text(path): + """ + Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py + """ + images = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack([tuple(map(float, elems[0::3])), + tuple(map(float, elems[1::3]))]) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + return images + + +def read_colmap_bin_array(path): + """ + Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py + + :param path: path to the colmap binary file. + :return: nd array with the floating point values in the value + """ + with open(path, "rb") as fid: + width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, + usecols=(0, 1, 2), dtype=int) + fid.seek(0) + num_delimiter = 0 + byte = fid.read(1) + while True: + if byte == b"&": + num_delimiter += 1 + if num_delimiter >= 3: + break + byte = fid.read(1) + array = np.fromfile(fid, np.float32) + array = array.reshape((width, height, channels), order="F") + return np.transpose(array, (1, 0, 2)).squeeze() diff --git a/scene/dataset_readers.py b/scene/dataset_readers.py new file mode 100644 index 0000000000000000000000000000000000000000..1f49d170301086eeff983f98424fc311e8eb1a7f --- /dev/null +++ b/scene/dataset_readers.py @@ -0,0 +1,294 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import sys +import torch +from PIL import Image +from typing import NamedTuple +from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal +import numpy as np +import json +from pathlib import Path +from plyfile import PlyData, PlyElement +from tqdm import tqdm +import pandas as pd + +from utils.sh_utils import SH2RGB +from utils.audio_utils import get_audio_features +from scene.gaussian_model import BasicPointCloud + +class CameraInfo(NamedTuple): + uid: int + R: np.array + T: np.array + FovY: np.array + FovX: np.array + image: np.array + image_path: str + image_name: str + width: int + height: int + background: np.array + talking_dict: dict + +class SceneInfo(NamedTuple): + point_cloud: BasicPointCloud + train_cameras: list + test_cameras: list + nerf_normalization: dict + ply_path: str + +def getNerfppNorm(cam_info): + def get_center_and_diag(cam_centers): + cam_centers = np.hstack(cam_centers) + avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) + center = avg_cam_center + dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) + diagonal = np.max(dist) + return center.flatten(), diagonal + + cam_centers = [] + + for cam in cam_info: + W2C = getWorld2View2(cam.R, cam.T) + C2W = np.linalg.inv(W2C) + cam_centers.append(C2W[:3, 3:4]) + + center, diagonal = get_center_and_diag(cam_centers) + radius = diagonal * 1.1 + + translate = -center + + return {"translate": translate, "radius": radius} + +def fetchPly(path): + plydata = PlyData.read(path) + vertices = plydata['vertex'] + positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T + colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 + normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T + return BasicPointCloud(points=positions, colors=colors, normals=normals) + +def storePly(path, xyz, rgb): + # Define the dtype for the structured array + dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), + ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), + ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] + + normals = np.zeros_like(xyz) + + elements = np.empty(xyz.shape[0], dtype=dtype) + attributes = np.concatenate((xyz, normals, rgb), axis=1) + elements[:] = list(map(tuple, attributes)) + + # Create the PlyData object and write to file + vertex_element = PlyElement.describe(elements, 'vertex') + ply_data = PlyData([vertex_element]) + ply_data.write(path) + +def readCamerasFromTransforms(path, transformsfile, white_background, extension=".jpg", audio_file='', audio_extractor='deepspeech'): + cam_infos = [] + postfix_dict = {"deepspeech": "ds", "esperanto": "eo", "hubert": "hu"} + + with open(os.path.join(path, transformsfile)) as json_file: + contents = json.load(json_file) + focal_len = contents["focal_len"] + bg_img = np.array(Image.open(os.path.join(path, 'bc.jpg')).convert("RGB")) + + frames = contents["frames"] + + if audio_file == '': + aud_features = np.load(os.path.join(path, 'aud_{}.npy'.format(postfix_dict[audio_extractor]))) + else: + aud_features = np.load(audio_file) + aud_features = torch.from_numpy(aud_features) + aud_features = aud_features.float().permute(0, 2, 1) + auds = aud_features + + au_info=pd.read_csv(os.path.join(path, 'au.csv')) + au_blink = au_info[' AU45_r'].values + au25 = au_info[' AU25_r'].values + au25 = np.clip(au25, 0, np.percentile(au25, 95)) + + au25_25, au25_50, au25_75, au25_100 = np.percentile(au25, 25), np.percentile(au25, 50), np.percentile(au25, 75), au25.max() + + au_exp = [] + for i in [1,4,5,6,7,45]: + _key = ' AU' + str(i).zfill(2) + '_r' + au_exp_t = au_info[_key].values + if i == 45: + au_exp_t = au_exp_t.clip(0, 2) + au_exp.append(au_exp_t[:, None]) + au_exp = np.concatenate(au_exp, axis=-1, dtype=np.float32) + + ldmks_lips = [] + ldmks_mouth = [] + ldmks_lhalf = [] + + for idx, frame in tqdm(enumerate(frames)): + lms = np.loadtxt(os.path.join(path, 'ori_imgs', str(frame['img_id']) + '.lms')) # [68, 2] + lips = slice(48, 60) + mouth = slice(60, 68) + xmin, xmax = int(lms[lips, 1].min()), int(lms[lips, 1].max()) + ymin, ymax = int(lms[lips, 0].min()), int(lms[lips, 0].max()) + + ldmks_lips.append([int(xmin), int(xmax), int(ymin), int(ymax)]) + ldmks_mouth.append([int(lms[mouth, 1].min()), int(lms[mouth, 1].max())]) + + lh_xmin, lh_xmax = int(lms[31:36, 1].min()), int(lms[:, 1].max()) # actually lower half area + xmin, xmax = int(lms[:, 1].min()), int(lms[:, 1].max()) + ymin, ymax = int(lms[:, 0].min()), int(lms[:, 0].max()) + # self.face_rect.append([xmin, xmax, ymin, ymax]) + ldmks_lhalf.append([lh_xmin, lh_xmax, ymin, ymax]) + + ldmks_lips = np.array(ldmks_lips) + ldmks_mouth = np.array(ldmks_mouth) + ldmks_lhalf = np.array(ldmks_lhalf) + mouth_lb = (ldmks_mouth[:, 1] - ldmks_mouth[:, 0]).min() + mouth_ub = (ldmks_mouth[:, 1] - ldmks_mouth[:, 0]).max() + + + + for idx, frame in tqdm(enumerate(frames)): + cam_name = os.path.join(path, 'gt_imgs', str(frame["img_id"]) + extension) + + # NeRF 'transform_matrix' is a camera-to-world transform + c2w = np.array(frame["transform_matrix"]) + # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward) + c2w[:3, 1:3] *= -1 + + # get the world-to-camera transform and set R, T + w2c = np.linalg.inv(c2w) + R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code + T = w2c[:3, 3] + + image_path = os.path.join(path, cam_name) + image_name = Path(cam_name).stem + image = Image.open(image_path) + w, h = image.size[0], image.size[1] + image = np.array(image.convert("RGB")) + + torso_img_path = os.path.join(path, 'torso_imgs', str(frame['img_id']) + '.png') + torso_img = np.array(Image.open(torso_img_path).convert("RGBA")) * 1.0 + bg = torso_img[..., :3] * torso_img[..., 3:] / 255.0 + bg_img * (1 - torso_img[..., 3:] / 255.0) + bg = bg.astype(np.uint8) + # bg = Image.fromarray(np.array(bg, dtype=np.byte), "RGB") + # bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0]) + + talking_dict = {} + talking_dict['img_id'] = frame['img_id'] + + teeth_mask_path = os.path.join(path, 'teeth_mask', str(frame['img_id']) + '.npy') + teeth_mask = np.load(teeth_mask_path) + + mask_path = os.path.join(path, 'parsing', str(frame['img_id']) + '.png') + mask = np.array(Image.open(mask_path).convert("RGB")) * 1.0 + talking_dict['face_mask'] = (mask[:, :, 2] > 254) * (mask[:, :, 0] == 0) * (mask[:, :, 1] == 0) ^ teeth_mask + talking_dict['hair_mask'] = (mask[:, :, 0] < 1) * (mask[:, :, 1] < 1) * (mask[:, :, 2] < 1) + talking_dict['mouth_mask'] = (mask[:, :, 0] == 100) * (mask[:, :, 1] == 100) * (mask[:, :, 2] == 100) + teeth_mask + + + if audio_file == '': + talking_dict['auds'] = get_audio_features(auds, 2, frame['img_id']) + if frame['img_id'] > auds.shape[0]: + print("[warnining] audio feature is too short") + break + else: + talking_dict['auds'] = get_audio_features(auds, 2, idx) + if idx >= auds.shape[0]: + break + + + talking_dict['blink'] = torch.as_tensor(np.clip(au_blink[frame['img_id']], 0, 2) / 2) + talking_dict['au25'] = [au25[frame['img_id']], au25_25, au25_50, au25_75, au25_100] + + talking_dict['au_exp'] = torch.as_tensor(au_exp[frame['img_id']]) + + + [xmin, xmax, ymin, ymax] = ldmks_lips[idx].tolist() + # padding to H == W + cx = (xmin + xmax) // 2 + cy = (ymin + ymax) // 2 + + l = max(xmax - xmin, ymax - ymin) // 2 + xmin = cx - l + xmax = cx + l + ymin = cy - l + ymax = cy + l + + talking_dict['lips_rect'] = [xmin, xmax, ymin, ymax] + talking_dict['lhalf_rect'] = ldmks_lhalf[idx] + talking_dict['mouth_bound'] = [mouth_lb, mouth_ub, ldmks_mouth[idx, 1] - ldmks_mouth[idx, 0]] + talking_dict['img_id'] = frame['img_id'] + + + # norm_data = im_data / 255.0 + # arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) + # image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") + + FovX = focal2fov(focal_len, w) + FovY = focal2fov(focal_len, h) + + cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, + image_path=image_path, image_name=image_name, width=w, height=h, background=bg, talking_dict=talking_dict)) + + # if idx > 200: break + # if idx > 6500: break + + return cam_infos + +def readNerfSyntheticInfo(path, white_background, eval, extension=".jpg", args=None): + audio_file = args.audio + audio_extractor = args.audio_extractor + if not eval: + print("Reading Training Transforms") + train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension, audio_file, audio_extractor) + print("Reading Test Transforms") + test_cam_infos = readCamerasFromTransforms(path, "transforms_val.json", white_background, extension, audio_file, audio_extractor) + + # if not eval: + # train_cam_infos.extend(test_cam_infos) + # test_cam_infos = [] + if eval: + train_cam_infos = test_cam_infos + + nerf_normalization = getNerfppNorm(train_cam_infos) + + + ply_path = os.path.join(path, "points3d.ply") + if not os.path.exists(ply_path) or True: + # Since this data set has no colmap data, we start with random points + num_pts = args.init_num + print(f"Generating random point cloud ({num_pts})...") + + # We create random points inside the bounds of the synthetic Blender scenes + xyz = np.random.random((num_pts, 3)) * 0.2 - 0.1 + shs = np.random.random((num_pts, 3)) / 255.0 + pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) + + storePly(ply_path, xyz, SH2RGB(shs) * 255) + try: + pcd = fetchPly(ply_path) + except: + pcd = None + + scene_info = SceneInfo(point_cloud=pcd, + train_cameras=train_cam_infos, + test_cameras=test_cam_infos, + nerf_normalization=nerf_normalization, + ply_path=ply_path) + return scene_info + +sceneLoadTypeCallbacks = { + "Colmap": None, + "Blender" : readNerfSyntheticInfo +} diff --git a/scene/gaussian_model.py b/scene/gaussian_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1233bc57334623ac339f3e52cb10c160ca3d52b3 --- /dev/null +++ b/scene/gaussian_model.py @@ -0,0 +1,444 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import numpy as np +from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation +from torch import nn +import os +from utils.system_utils import mkdir_p +from plyfile import PlyData, PlyElement +from utils.sh_utils import RGB2SH +from simple_knn._C import distCUDA2 +from utils.graphics_utils import BasicPointCloud +from utils.general_utils import strip_symmetric, build_scaling_rotation + +class GaussianModel: + + def setup_functions(self): + def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): + L = build_scaling_rotation(scaling_modifier * scaling, rotation) + actual_covariance = L @ L.transpose(1, 2) + symm = strip_symmetric(actual_covariance) + return symm + + self.scaling_activation = torch.nn.functional.softplus # torch.exp + self.scaling_inverse_activation = lambda x: x + torch.log(-torch.expm1(-x)) # torch.log + + self.covariance_activation = build_covariance_from_scaling_rotation + + self.opacity_activation = torch.sigmoid + self.inverse_opacity_activation = inverse_sigmoid + + self.rotation_activation = torch.nn.functional.normalize + + + def __init__(self, sh_degree : int): + self.active_sh_degree = 0 + self.max_sh_degree = sh_degree + self._xyz = torch.empty(0) + self._features_dc = torch.empty(0) + self._features_rest = torch.empty(0) + self._identity = torch.empty(0) + self._scaling = torch.empty(0) + self._rotation = torch.empty(0) + self._opacity = torch.empty(0) + self.max_radii2D = torch.empty(0) + self.xyz_gradient_accum = torch.empty(0) + self.denom = torch.empty(0) + self.optimizer = None + self.percent_dense = 0 + self.spatial_lr_scale = 0 + self.setup_functions() + + def capture(self): + return ( + self.active_sh_degree, + self._xyz, + self._features_dc, + self._features_rest, + self._identity, + self._scaling, + self._rotation, + self._opacity, + self.max_radii2D, + self.xyz_gradient_accum, + self.denom, + self.optimizer.state_dict(), + self.spatial_lr_scale, + ) + + def restore(self, model_args, training_args): + (self.active_sh_degree, + self._xyz, + self._features_dc, + self._features_rest, + self._identity, + self._scaling, + self._rotation, + self._opacity, + self.max_radii2D, + xyz_gradient_accum, + denom, + opt_dict, + self.spatial_lr_scale) = model_args + if training_args is not None: + self.training_setup(training_args) + self.optimizer.load_state_dict(opt_dict) + self.xyz_gradient_accum = xyz_gradient_accum + self.denom = denom + + + @property + def get_scaling(self): + return self.scaling_activation(self._scaling) + + @property + def get_rotation(self): + return self.rotation_activation(self._rotation) + + @property + def get_xyz(self): + return self._xyz + + @property + def get_features(self): + features_dc = self._features_dc + features_rest = self._features_rest + return torch.cat((features_dc, features_rest), dim=1) + + @property + def get_identity(self): + return self._identity + + @property + def get_opacity(self): + return self.opacity_activation(self._opacity) + + def get_covariance(self, scaling_modifier = 1): + return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation) + + def oneupSHdegree(self): + if self.active_sh_degree < self.max_sh_degree: + self.active_sh_degree += 1 + + def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float): + self.spatial_lr_scale = spatial_lr_scale + fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() + fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda()) + features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() + features[:, :3, 0 ] = fused_color + features[:, 3:, 1:] = 0.0 + + print("Number of points at initialisation : ", fused_point_cloud.shape[0]) + + dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001) + scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3) + rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") + rots[:, 0] = 1 + + opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) + + identity = torch.zeros((fused_point_cloud.shape[0], 1), device="cuda") + + self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) + self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) + self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True)) + self._identity = nn.Parameter(identity.requires_grad_(True)) + self._scaling = nn.Parameter(scales.requires_grad_(True)) + self._rotation = nn.Parameter(rots.requires_grad_(True)) + self._opacity = nn.Parameter(opacities.requires_grad_(True)) + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + + def training_setup(self, training_args): + self.percent_dense = training_args.percent_dense + self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + + l = [ + {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"}, + {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, + {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, + {'params': [self._identity], 'lr': 1e-2, "name": "identity"}, + {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, + {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"}, + {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"} + ] + + self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) + self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale, + lr_final=training_args.position_lr_final*self.spatial_lr_scale, + lr_delay_mult=training_args.position_lr_delay_mult, + max_steps=training_args.position_lr_max_steps) + + def update_learning_rate(self, iteration): + ''' Learning rate scheduling per step ''' + for param_group in self.optimizer.param_groups: + if param_group["name"] == "xyz": + lr = self.xyz_scheduler_args(iteration) + param_group['lr'] = lr + return lr + + def construct_list_of_attributes(self): + l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] + # All channels except the 3 DC + for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): + l.append('f_dc_{}'.format(i)) + for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]): + l.append('f_rest_{}'.format(i)) + l.append('opacity') + for i in range(self._scaling.shape[1]): + l.append('scale_{}'.format(i)) + for i in range(self._rotation.shape[1]): + l.append('rot_{}'.format(i)) + return l + + def save_ply(self, path): + mkdir_p(os.path.dirname(path)) + + xyz = self._xyz.detach().cpu().numpy() + normals = np.zeros_like(xyz) + f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + opacities = self._opacity.detach().cpu().numpy() + scale = self._scaling.detach().cpu().numpy() + rotation = self._rotation.detach().cpu().numpy() + + dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] + + elements = np.empty(xyz.shape[0], dtype=dtype_full) + attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, 'vertex') + PlyData([el]).write(path) + + def save_deformed_ply(self, xyz, scale, rotation, path): + mkdir_p(os.path.dirname(path)) + + xyz = xyz.detach().cpu().numpy() + normals = np.zeros_like(xyz) + f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + opacities = self._opacity.detach().cpu().numpy() + scale = torch.log(self.scaling_activation(scale)).detach().cpu().numpy() + rotation = rotation.detach().cpu().numpy() + + dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] + + elements = np.empty(xyz.shape[0], dtype=dtype_full) + attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, 'vertex') + PlyData([el]).write(path) + + def reset_opacity(self): + opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01)) + optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") + self._opacity = optimizable_tensors["opacity"] + + def load_ply(self, path): + plydata = PlyData.read(path) + + xyz = np.stack((np.asarray(plydata.elements[0]["x"]), + np.asarray(plydata.elements[0]["y"]), + np.asarray(plydata.elements[0]["z"])), axis=1) + opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] + + features_dc = np.zeros((xyz.shape[0], 3, 1)) + features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) + features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) + features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) + + extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] + extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) + assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3 + features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) + for idx, attr_name in enumerate(extra_f_names): + features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) + # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) + features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) + + scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] + scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) + scales = np.zeros((xyz.shape[0], len(scale_names))) + for idx, attr_name in enumerate(scale_names): + scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] + rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) + rots = np.zeros((xyz.shape[0], len(rot_names))) + for idx, attr_name in enumerate(rot_names): + rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) + self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) + self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) + self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) + + self.active_sh_degree = self.max_sh_degree + + def replace_tensor_to_optimizer(self, tensor, name): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + if group["name"] == name: + stored_state = self.optimizer.state.get(group['params'][0], None) + stored_state["exp_avg"] = torch.zeros_like(tensor) + stored_state["exp_avg_sq"] = torch.zeros_like(tensor) + + del self.optimizer.state[group['params'][0]] + group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) + self.optimizer.state[group['params'][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + return optimizable_tensors + + def _prune_optimizer(self, mask): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + stored_state = self.optimizer.state.get(group['params'][0], None) + if stored_state is not None: + stored_state["exp_avg"] = stored_state["exp_avg"][mask] + stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] + + del self.optimizer.state[group['params'][0]] + group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) + self.optimizer.state[group['params'][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + else: + group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) + optimizable_tensors[group["name"]] = group["params"][0] + return optimizable_tensors + + def prune_points(self, mask): + valid_points_mask = ~mask + optimizable_tensors = self._prune_optimizer(valid_points_mask) + + self._xyz = optimizable_tensors["xyz"] + self._features_dc = optimizable_tensors["f_dc"] + self._features_rest = optimizable_tensors["f_rest"] + self._identity = optimizable_tensors["identity"] + self._opacity = optimizable_tensors["opacity"] + self._scaling = optimizable_tensors["scaling"] + self._rotation = optimizable_tensors["rotation"] + + self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] + + self.denom = self.denom[valid_points_mask] + self.max_radii2D = self.max_radii2D[valid_points_mask] + + def cat_tensors_to_optimizer(self, tensors_dict): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + assert len(group["params"]) == 1 + extension_tensor = tensors_dict[group["name"]] + stored_state = self.optimizer.state.get(group['params'][0], None) + if stored_state is not None: + + stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0) + stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0) + + del self.optimizer.state[group['params'][0]] + group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) + self.optimizer.state[group['params'][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + else: + group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) + optimizable_tensors[group["name"]] = group["params"][0] + + return optimizable_tensors + + def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_identity, new_opacities, new_scaling, new_rotation): + d = {"xyz": new_xyz, + "f_dc": new_features_dc, + "f_rest": new_features_rest, + "identity": new_identity, + "opacity": new_opacities, + "scaling" : new_scaling, + "rotation" : new_rotation} + + optimizable_tensors = self.cat_tensors_to_optimizer(d) + self._xyz = optimizable_tensors["xyz"] + self._features_dc = optimizable_tensors["f_dc"] + self._features_rest = optimizable_tensors["f_rest"] + self._identity = optimizable_tensors["identity"] + self._opacity = optimizable_tensors["opacity"] + self._scaling = optimizable_tensors["scaling"] + self._rotation = optimizable_tensors["rotation"] + + self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + + def densify_and_split(self, grads, grad_threshold, scene_extent, N=2): + n_init_points = self.get_xyz.shape[0] + # Extract points that satisfy the gradient condition + padded_grad = torch.zeros((n_init_points), device="cuda") + padded_grad[:grads.shape[0]] = grads.squeeze() + selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) + selected_pts_mask = torch.logical_and(selected_pts_mask, + torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent) + + stds = self.get_scaling[selected_pts_mask].repeat(N,1) + means =torch.zeros((stds.size(0), 3),device="cuda") + samples = torch.normal(mean=means, std=stds) + rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1) + new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1) + new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N)) + new_rotation = self._rotation[selected_pts_mask].repeat(N,1) + new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1) + new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1) + new_identity = self._identity[selected_pts_mask].repeat(N,1) + new_opacity = self._opacity[selected_pts_mask].repeat(N,1) + + self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_identity, new_opacity, new_scaling, new_rotation) + + prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) + self.prune_points(prune_filter) + + def densify_and_clone(self, grads, grad_threshold, scene_extent): + # Extract points that satisfy the gradient condition + selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) + selected_pts_mask = torch.logical_and(selected_pts_mask, + torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent) + + new_xyz = self._xyz[selected_pts_mask] + new_features_dc = self._features_dc[selected_pts_mask] + new_features_rest = self._features_rest[selected_pts_mask] + new_identity = self._identity[selected_pts_mask] + new_opacities = self._opacity[selected_pts_mask] + new_scaling = self._scaling[selected_pts_mask] + new_rotation = self._rotation[selected_pts_mask] + + self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_identity, new_opacities, new_scaling, new_rotation) + + def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): + grads = self.xyz_gradient_accum / self.denom + grads[grads.isnan()] = 0.0 + + self.densify_and_clone(grads, max_grad, extent) + self.densify_and_split(grads, max_grad, extent) + + prune_mask = (self.get_opacity < min_opacity).squeeze() + if max_screen_size: + big_points_vs = self.max_radii2D > max_screen_size + big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent + prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws) + self.prune_points(prune_mask) + + torch.cuda.empty_cache() + + def add_densification_stats(self, viewspace_point_tensor, update_filter): + self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True) + self.denom[update_filter] += 1 \ No newline at end of file diff --git a/scene/motion_net.py b/scene/motion_net.py new file mode 100644 index 0000000000000000000000000000000000000000..aedef6cf853bc3efbb56b3f5dfe3734181864301 --- /dev/null +++ b/scene/motion_net.py @@ -0,0 +1,357 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from encoding import get_encoder + +# Audio feature extractor +class AudioAttNet(nn.Module): + def __init__(self, dim_aud=64, seq_len=8): + super(AudioAttNet, self).__init__() + self.seq_len = seq_len + self.dim_aud = dim_aud + self.attentionConvNet = nn.Sequential( # b x subspace_dim x seq_len + nn.Conv1d(self.dim_aud, 16, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(16, 8, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(8, 4, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(4, 2, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(2, 1, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True) + ) + self.attentionNet = nn.Sequential( + nn.Linear(in_features=self.seq_len, out_features=self.seq_len, bias=True), + nn.Softmax(dim=1) + ) + + def forward(self, x): + # x: [1, seq_len, dim_aud] + y = x.permute(0, 2, 1) # [1, dim_aud, seq_len] + y = self.attentionConvNet(y) + y = self.attentionNet(y.view(1, self.seq_len)).view(1, self.seq_len, 1) + return torch.sum(y * x, dim=1) # [1, dim_aud] + + +# Audio feature extractor +class AudioNet(nn.Module): + def __init__(self, dim_in=29, dim_aud=64, win_size=16): + super(AudioNet, self).__init__() + self.win_size = win_size + self.dim_aud = dim_aud + self.encoder_conv = nn.Sequential( # n x 29 x 16 + nn.Conv1d(dim_in, 32, kernel_size=3, stride=2, padding=1, bias=True), # n x 32 x 8 + nn.LeakyReLU(0.02, True), + nn.Conv1d(32, 32, kernel_size=3, stride=2, padding=1, bias=True), # n x 32 x 4 + nn.LeakyReLU(0.02, True), + nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1, bias=True), # n x 64 x 2 + nn.LeakyReLU(0.02, True), + nn.Conv1d(64, 64, kernel_size=3, stride=2, padding=1, bias=True), # n x 64 x 1 + nn.LeakyReLU(0.02, True), + ) + self.encoder_fc1 = nn.Sequential( + nn.Linear(64, 64), + nn.LeakyReLU(0.02, True), + nn.Linear(64, dim_aud), + ) + + def forward(self, x): + half_w = int(self.win_size/2) + x = x[:, :, 8-half_w:8+half_w] + x = self.encoder_conv(x).squeeze(-1) + x = self.encoder_fc1(x) + return x + + +class MLP(nn.Module): + def __init__(self, dim_in, dim_out, dim_hidden, num_layers): + super().__init__() + self.dim_in = dim_in + self.dim_out = dim_out + self.dim_hidden = dim_hidden + self.num_layers = num_layers + + net = [] + for l in range(num_layers): + net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=False)) + + self.net = nn.ModuleList(net) + + def forward(self, x): + for l in range(self.num_layers): + x = self.net[l](x) + if l != self.num_layers - 1: + x = F.relu(x, inplace=True) + # x = F.dropout(x, p=0.1, training=self.training) + + return x + + +class MotionNetwork(nn.Module): + def __init__(self, + audio_dim = 32, + ind_dim = 0, + args = None, + ): + super(MotionNetwork, self).__init__() + + if 'esperanto' in args.audio_extractor: + self.audio_in_dim = 44 + elif 'deepspeech' in args.audio_extractor: + self.audio_in_dim = 29 + elif 'hubert' in args.audio_extractor: + self.audio_in_dim = 1024 + else: + raise NotImplementedError + + self.bound = 0.15 + self.exp_eye = True + + + self.individual_dim = ind_dim + if self.individual_dim > 0: + self.individual_codes = nn.Parameter(torch.randn(10000, self.individual_dim) * 0.1) + + # audio network + self.audio_dim = audio_dim + self.audio_net = AudioNet(self.audio_in_dim, self.audio_dim) + + self.audio_att_net = AudioAttNet(self.audio_dim) + + # DYNAMIC PART + self.num_levels = 12 + self.level_dim = 1 + self.encoder_xy, self.in_dim_xy = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=16, log2_hashmap_size=17, desired_resolution=256 * self.bound) + self.encoder_yz, self.in_dim_yz = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=16, log2_hashmap_size=17, desired_resolution=256 * self.bound) + self.encoder_xz, self.in_dim_xz = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=16, log2_hashmap_size=17, desired_resolution=256 * self.bound) + + self.in_dim = self.in_dim_xy + self.in_dim_yz + self.in_dim_xz + + + self.num_layers = 3 + self.hidden_dim = 64 + + self.exp_in_dim = 6 - 1 + self.eye_dim = 6 if self.exp_eye else 0 + self.exp_encode_net = MLP(self.exp_in_dim, self.eye_dim - 1, 16, 2) + + self.eye_att_net = MLP(self.in_dim, self.eye_dim, 16, 2) + + # rot: 4 xyz: 3 opac: 1 scale: 3 + self.out_dim = 11 + self.sigma_net = MLP(self.in_dim + self.audio_dim + self.eye_dim + self.individual_dim, self.out_dim, self.hidden_dim, self.num_layers) + + self.aud_ch_att_net = MLP(self.in_dim, self.audio_dim, 32, 2) + + + @staticmethod + @torch.jit.script + def split_xyz(x): + xy, yz, xz = x[:, :-1], x[:, 1:], torch.cat([x[:,:1], x[:,-1:]], dim=-1) + return xy, yz, xz + + + def encode_x(self, xyz, bound): + # x: [N, 3], in [-bound, bound] + N, M = xyz.shape + xy, yz, xz = self.split_xyz(xyz) + feat_xy = self.encoder_xy(xy, bound=bound) + feat_yz = self.encoder_yz(yz, bound=bound) + feat_xz = self.encoder_xz(xz, bound=bound) + + return torch.cat([feat_xy, feat_yz, feat_xz], dim=-1) + + + def encode_audio(self, a): + # a: [1, 29, 16] or [8, 29, 16], audio features from deepspeech + # if emb, a should be: [1, 16] or [8, 16] + + # fix audio traininig + if a is None: return None + + enc_a = self.audio_net(a) # [1/8, 64] + enc_a = self.audio_att_net(enc_a.unsqueeze(0)) # [1, 64] + + return enc_a + + + def forward(self, x, a, e=None, c=None): + # x: [N, 3], in [-bound, bound] + enc_x = self.encode_x(x, bound=self.bound) + + enc_a = self.encode_audio(a) + enc_a = enc_a.repeat(enc_x.shape[0], 1) + aud_ch_att = self.aud_ch_att_net(enc_x) + enc_w = enc_a * aud_ch_att + + eye_att = torch.relu(self.eye_att_net(enc_x)) + enc_e = self.exp_encode_net(e[:-1]) + enc_e = torch.cat([enc_e, e[-1:]], dim=-1) + enc_e = enc_e * eye_att + if c is not None: + c = c.repeat(enc_x.shape[0], 1) + h = torch.cat([enc_x, enc_w, enc_e, c], dim=-1) + else: + h = torch.cat([enc_x, enc_w, enc_e], dim=-1) + + h = self.sigma_net(h) + + d_xyz = h[..., :3] * 1e-2 + d_rot = h[..., 3:7] + d_opa = h[..., 7:8] + d_scale = h[..., 8:11] + return { + 'd_xyz': d_xyz, + 'd_rot': d_rot, + 'd_opa': d_opa, + 'd_scale': d_scale, + 'ambient_aud' : aud_ch_att.norm(dim=-1, keepdim=True), + 'ambient_eye' : eye_att.norm(dim=-1, keepdim=True), + } + + + # optimizer utils + def get_params(self, lr, lr_net, wd=0): + + params = [ + {'params': self.audio_net.parameters(), 'lr': lr_net, 'weight_decay': wd}, + {'params': self.encoder_xy.parameters(), 'lr': lr}, + {'params': self.encoder_yz.parameters(), 'lr': lr}, + {'params': self.encoder_xz.parameters(), 'lr': lr}, + {'params': self.sigma_net.parameters(), 'lr': lr_net, 'weight_decay': wd}, + ] + params.append({'params': self.audio_att_net.parameters(), 'lr': lr_net * 5, 'weight_decay': 0.0001}) + if self.individual_dim > 0: + params.append({'params': self.individual_codes, 'lr': lr_net, 'weight_decay': wd}) + + params.append({'params': self.aud_ch_att_net.parameters(), 'lr': lr_net, 'weight_decay': wd}) + params.append({'params': self.eye_att_net.parameters(), 'lr': lr_net, 'weight_decay': wd}) + params.append({'params': self.exp_encode_net.parameters(), 'lr': lr_net, 'weight_decay': wd}) + + return params + + + + +class MouthMotionNetwork(nn.Module): + def __init__(self, + audio_dim = 32, + ind_dim = 0, + args = None, + ): + super(MouthMotionNetwork, self).__init__() + + if 'esperanto' in args.audio_extractor: + self.audio_in_dim = 44 + elif 'deepspeech' in args.audio_extractor: + self.audio_in_dim = 29 + elif 'hubert' in args.audio_extractor: + self.audio_in_dim = 1024 + else: + raise NotImplementedError + + + self.bound = 0.15 + + + self.individual_dim = ind_dim + if self.individual_dim > 0: + self.individual_codes = nn.Parameter(torch.randn(10000, self.individual_dim) * 0.1) + + # audio network + self.audio_dim = audio_dim + self.audio_net = AudioNet(self.audio_in_dim, self.audio_dim) + + self.audio_att_net = AudioAttNet(self.audio_dim) + + # DYNAMIC PART + self.num_levels = 12 + self.level_dim = 1 + self.encoder_xy, self.in_dim_xy = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=17, desired_resolution=384 * self.bound) + self.encoder_yz, self.in_dim_yz = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=17, desired_resolution=384 * self.bound) + self.encoder_xz, self.in_dim_xz = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=17, desired_resolution=384 * self.bound) + + self.in_dim = self.in_dim_xy + self.in_dim_yz + self.in_dim_xz + + ## sigma network + self.num_layers = 3 + self.hidden_dim = 32 + + self.out_dim = 3 + self.sigma_net = MLP(self.in_dim + self.audio_dim + self.individual_dim, self.out_dim, self.hidden_dim, self.num_layers) + + self.aud_ch_att_net = MLP(self.in_dim, self.audio_dim, 32, 2) + + + def encode_audio(self, a): + # a: [1, 29, 16] or [8, 29, 16], audio features from deepspeech + # if emb, a should be: [1, 16] or [8, 16] + + # fix audio traininig + if a is None: return None + + enc_a = self.audio_net(a) # [1/8, 64] + enc_a = self.audio_att_net(enc_a.unsqueeze(0)) # [1, 64] + + return enc_a + + + @staticmethod + @torch.jit.script + def split_xyz(x): + xy, yz, xz = x[:, :-1], x[:, 1:], torch.cat([x[:,:1], x[:,-1:]], dim=-1) + return xy, yz, xz + + + def encode_x(self, xyz, bound): + # x: [N, 3], in [-bound, bound] + N, M = xyz.shape + xy, yz, xz = self.split_xyz(xyz) + feat_xy = self.encoder_xy(xy, bound=bound) + feat_yz = self.encoder_yz(yz, bound=bound) + feat_xz = self.encoder_xz(xz, bound=bound) + + return torch.cat([feat_xy, feat_yz, feat_xz], dim=-1) + + + def forward(self, x, a): + # x: [N, 3], in [-bound, bound] + enc_x = self.encode_x(x, bound=self.bound) + + enc_a = self.encode_audio(a) + enc_w = enc_a.repeat(enc_x.shape[0], 1) + # aud_ch_att = self.aud_ch_att_net(enc_x) + # enc_w = enc_a * aud_ch_att + + h = torch.cat([enc_x, enc_w], dim=-1) + + h = self.sigma_net(h) + + d_xyz = h * 1e-2 + d_xyz[..., 0] = d_xyz[..., 0] / 5 + d_xyz[..., 2] = d_xyz[..., 2] / 5 + return { + 'd_xyz': d_xyz, + # 'ambient_aud' : aud_ch_att.norm(dim=-1, keepdim=True), + } + + + # optimizer utils + def get_params(self, lr, lr_net, wd=0): + + params = [ + {'params': self.audio_net.parameters(), 'lr': lr_net, 'weight_decay': wd}, + {'params': self.encoder_xy.parameters(), 'lr': lr}, + {'params': self.encoder_yz.parameters(), 'lr': lr}, + {'params': self.encoder_xz.parameters(), 'lr': lr}, + {'params': self.sigma_net.parameters(), 'lr': lr_net, 'weight_decay': wd}, + ] + params.append({'params': self.audio_att_net.parameters(), 'lr': lr_net * 5, 'weight_decay': 0.0001}) + if self.individual_dim > 0: + params.append({'params': self.individual_codes, 'lr': lr_net, 'weight_decay': wd}) + + params.append({'params': self.aud_ch_att_net.parameters(), 'lr': lr_net, 'weight_decay': wd}) + + return params diff --git a/scripts/prepare.sh b/scripts/prepare.sh new file mode 100644 index 0000000000000000000000000000000000000000..b3befd5265b446ca7fc13bbcaf39b4136489d25e --- /dev/null +++ b/scripts/prepare.sh @@ -0,0 +1,8 @@ +wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_parsing/79999_iter.pth?raw=true -O data_utils/face_parsing/79999_iter.pth + +mkdir data_utils/face_tracking/3DMM + +wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/exp_info.npy?raw=true -O data_utils/face_tracking/3DMM/exp_info.npy +wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/keys_info.npy?raw=true -O data_utils/face_tracking/3DMM/keys_info.npy +wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/sub_mesh.obj?raw=true -O data_utils/face_tracking/3DMM/sub_mesh.obj +wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/topology_info.npy?raw=true -O data_utils/face_tracking/3DMM/topology_info.npy \ No newline at end of file diff --git a/scripts/train_xx.sh b/scripts/train_xx.sh new file mode 100644 index 0000000000000000000000000000000000000000..734d5f6a17c8eb4df4488eb5693ea747df6575b7 --- /dev/null +++ b/scripts/train_xx.sh @@ -0,0 +1,18 @@ +dataset=$1 +workspace=$2 +gpu_id=$3 +audio_extractor='deepspeech' # deepspeech, esperanto, hubert + +export CUDA_VISIBLE_DEVICES=$gpu_id + +python train_mouth.py -s $dataset -m $workspace --audio_extractor $audio_extractor +python train_face.py -s $dataset -m $workspace --init_num 2000 --densify_grad_threshold 0.0005 --audio_extractor $audio_extractor +python train_fuse.py -s $dataset -m $workspace --opacity_lr 0.001 --audio_extractor $audio_extractor + +# # Parallel. Ensure that you have aleast 2 GPUs, and over N x 64GB memory for about N x 5k frames (IMPORTANT! Otherwise the computer will crash). +# CUDA_VISIBLE_DEVICES=$gpu_id python train_mouth.py -s $dataset -m $workspace --audio_extractor $audio_extractor & +# CUDA_VISIBLE_DEVICES=$((gpu_id+1)) python train_face.py -s $dataset -m $workspace --init_num 2000 --densify_grad_threshold 0.0005 --audio_extractor $audio_extractor +# CUDA_VISIBLE_DEVICES=$gpu_id python train_fuse.py -s $dataset -m $workspace --opacity_lr 0.001 --audio_extractor $audio_extractor + +python synthesize_fuse.py -s $dataset -m $workspace --eval --audio_extractor $audio_extractor +python metrics.py $workspace/test/ours_None/renders/out.mp4 $workspace/test/ours_None/gt/out.mp4 \ No newline at end of file diff --git a/submodules/diff-gaussian-rasterization/.gitignore b/submodules/diff-gaussian-rasterization/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..dbab9eca465b1ea6b69f196a6064a075950d359d --- /dev/null +++ b/submodules/diff-gaussian-rasterization/.gitignore @@ -0,0 +1,7 @@ +build/ +diff_gaussian_rasterization.egg-info/ +dist/ + +__pycache__ + +*.so \ No newline at end of file diff --git a/submodules/diff-gaussian-rasterization/.gitmodules b/submodules/diff-gaussian-rasterization/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..4553c29f4224a9a8723482bc9aca759a97693a64 --- /dev/null +++ b/submodules/diff-gaussian-rasterization/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/glm"] + path = third_party/glm + url = https://github.com/g-truc/glm.git diff --git a/submodules/diff-gaussian-rasterization/CMakeLists.txt b/submodules/diff-gaussian-rasterization/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9fdb9cbfdcc0fc2b81595cd616912067291ea0b1 --- /dev/null +++ b/submodules/diff-gaussian-rasterization/CMakeLists.txt @@ -0,0 +1,36 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +cmake_minimum_required(VERSION 3.20) + +project(DiffRast LANGUAGES CUDA CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_EXTENSIONS OFF) +set(CMAKE_CUDA_STANDARD 17) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + +add_library(CudaRasterizer + cuda_rasterizer/backward.h + cuda_rasterizer/backward.cu + cuda_rasterizer/forward.h + cuda_rasterizer/forward.cu + cuda_rasterizer/auxiliary.h + cuda_rasterizer/rasterizer_impl.cu + cuda_rasterizer/rasterizer_impl.h + cuda_rasterizer/rasterizer.h +) + +set_target_properties(CudaRasterizer PROPERTIES CUDA_ARCHITECTURES "75;86") + +target_include_directories(CudaRasterizer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cuda_rasterizer) +target_include_directories(CudaRasterizer PRIVATE third_party/glm ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) diff --git a/submodules/diff-gaussian-rasterization/LICENSE.md b/submodules/diff-gaussian-rasterization/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..c869e695fa63bfde6f887d63a24a2a71f03480ac --- /dev/null +++ b/submodules/diff-gaussian-rasterization/LICENSE.md @@ -0,0 +1,83 @@ +Gaussian-Splatting License +=========================== + +**Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. +The *Software* is in the process of being registered with the Agence pour la Protection des +Programmes (APP). + +The *Software* is still being developed by the *Licensor*. + +*Licensor*'s goal is to allow the research community to use, test and evaluate +the *Software*. + +## 1. Definitions + +*Licensee* means any person or entity that uses the *Software* and distributes +its *Work*. + +*Licensor* means the owners of the *Software*, i.e Inria and MPII + +*Software* means the original work of authorship made available under this +License ie gaussian-splatting. + +*Work* means the *Software* and any additions to or derivative works of the +*Software* that are made available under this License. + + +## 2. Purpose +This license is intended to define the rights granted to the *Licensee* by +Licensors under the *Software*. + +## 3. Rights granted + +For the above reasons Licensors have decided to distribute the *Software*. +Licensors grant non-exclusive rights to use the *Software* for research purposes +to research users (both academic and industrial), free of charge, without right +to sublicense.. The *Software* may be used "non-commercially", i.e., for research +and/or evaluation purposes only. + +Subject to the terms and conditions of this License, you are granted a +non-exclusive, royalty-free, license to reproduce, prepare derivative works of, +publicly display, publicly perform and distribute its *Work* and any resulting +derivative works in any form. + +## 4. Limitations + +**4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do +so under this License, (b) you include a complete copy of this License with +your distribution, and (c) you retain without modification any copyright, +patent, trademark, or attribution notices that are present in the *Work*. + +**4.2 Derivative Works.** You may specify that additional or different terms apply +to the use, reproduction, and distribution of your derivative works of the *Work* +("Your Terms") only if (a) Your Terms provide that the use limitation in +Section 2 applies to your derivative works, and (b) you identify the specific +derivative works that are subject to Your Terms. Notwithstanding Your Terms, +this License (including the redistribution requirements in Section 3.1) will +continue to apply to the *Work* itself. + +**4.3** Any other use without of prior consent of Licensors is prohibited. Research +users explicitly acknowledge having received from Licensors all information +allowing to appreciate the adequacy between of the *Software* and their needs and +to undertake all necessary precautions for its execution and use. + +**4.4** The *Software* is provided both as a compiled library file and as source +code. In case of using the *Software* for a publication or other results obtained +through the use of the *Software*, users are strongly encouraged to cite the +corresponding publications as explained in the documentation of the *Software*. + +## 5. Disclaimer + +THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES +WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY +UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL +CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES +OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL +USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR +ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE +AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE +GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) +HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR +IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. diff --git a/submodules/diff-gaussian-rasterization/README.md b/submodules/diff-gaussian-rasterization/README.md new file mode 100644 index 0000000000000000000000000000000000000000..407956b912aaf311a2ab760c4fa745379d11cdb5 --- /dev/null +++ b/submodules/diff-gaussian-rasterization/README.md @@ -0,0 +1,35 @@ +# Differential Gaussian Rasterization + +**NOTE**: this is a modified version to support depth & alpha rendering (both forward and backward) from the [original repository](https://github.com/graphdeco-inria/diff-gaussian-rasterization). + +```python +rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( + means3D=means3D, + means2D=means2D, + shs=shs, + colors_precomp=colors_precomp, + opacities=opacity, + scales=scales, + rotations=rotations, + cov3D_precomp=cov3D_precomp, +) +``` + + +Used as the rasterization engine for the paper "3D Gaussian Splatting for Real-Time Rendering of Radiance Fields". If you can make use of it in your own research, please be so kind to cite us. + +
+
+

BibTeX

+
@Article{kerbl3Dgaussians,
+      author       = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
+      title        = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
+      journal      = {ACM Transactions on Graphics},
+      number       = {4},
+      volume       = {42},
+      month        = {July},
+      year         = {2023},
+      url          = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
+}
+
+
diff --git a/submodules/diff-gaussian-rasterization/cuda_rasterizer/auxiliary.h b/submodules/diff-gaussian-rasterization/cuda_rasterizer/auxiliary.h new file mode 100644 index 0000000000000000000000000000000000000000..4d4b9b78ad491ad8033002c1fce0a336aedd34d1 --- /dev/null +++ b/submodules/diff-gaussian-rasterization/cuda_rasterizer/auxiliary.h @@ -0,0 +1,175 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#ifndef CUDA_RASTERIZER_AUXILIARY_H_INCLUDED +#define CUDA_RASTERIZER_AUXILIARY_H_INCLUDED + +#include "config.h" +#include "stdio.h" + +#define BLOCK_SIZE (BLOCK_X * BLOCK_Y) +#define NUM_WARPS (BLOCK_SIZE/32) + +// Spherical harmonics coefficients +__device__ const float SH_C0 = 0.28209479177387814f; +__device__ const float SH_C1 = 0.4886025119029199f; +__device__ const float SH_C2[] = { + 1.0925484305920792f, + -1.0925484305920792f, + 0.31539156525252005f, + -1.0925484305920792f, + 0.5462742152960396f +}; +__device__ const float SH_C3[] = { + -0.5900435899266435f, + 2.890611442640554f, + -0.4570457994644658f, + 0.3731763325901154f, + -0.4570457994644658f, + 1.445305721320277f, + -0.5900435899266435f +}; + +__forceinline__ __device__ float ndc2Pix(float v, int S) +{ + return ((v + 1.0) * S - 1.0) * 0.5; +} + +__forceinline__ __device__ void getRect(const float2 p, int max_radius, uint2& rect_min, uint2& rect_max, dim3 grid) +{ + rect_min = { + min(grid.x, max((int)0, (int)((p.x - max_radius) / BLOCK_X))), + min(grid.y, max((int)0, (int)((p.y - max_radius) / BLOCK_Y))) + }; + rect_max = { + min(grid.x, max((int)0, (int)((p.x + max_radius + BLOCK_X - 1) / BLOCK_X))), + min(grid.y, max((int)0, (int)((p.y + max_radius + BLOCK_Y - 1) / BLOCK_Y))) + }; +} + +__forceinline__ __device__ float3 transformPoint4x3(const float3& p, const float* matrix) +{ + float3 transformed = { + matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12], + matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13], + matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14], + }; + return transformed; +} + +__forceinline__ __device__ float4 transformPoint4x4(const float3& p, const float* matrix) +{ + float4 transformed = { + matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12], + matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13], + matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14], + matrix[3] * p.x + matrix[7] * p.y + matrix[11] * p.z + matrix[15] + }; + return transformed; +} + +__forceinline__ __device__ float3 transformVec4x3(const float3& p, const float* matrix) +{ + float3 transformed = { + matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z, + matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z, + matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z, + }; + return transformed; +} + +__forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, const float* matrix) +{ + float3 transformed = { + matrix[0] * p.x + matrix[1] * p.y + matrix[2] * p.z, + matrix[4] * p.x + matrix[5] * p.y + matrix[6] * p.z, + matrix[8] * p.x + matrix[9] * p.y + matrix[10] * p.z, + }; + return transformed; +} + +__forceinline__ __device__ float dnormvdz(float3 v, float3 dv) +{ + float sum2 = v.x * v.x + v.y * v.y + v.z * v.z; + float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); + float dnormvdz = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32; + return dnormvdz; +} + +__forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv) +{ + float sum2 = v.x * v.x + v.y * v.y + v.z * v.z; + float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); + + float3 dnormvdv; + dnormvdv.x = ((+sum2 - v.x * v.x) * dv.x - v.y * v.x * dv.y - v.z * v.x * dv.z) * invsum32; + dnormvdv.y = (-v.x * v.y * dv.x + (sum2 - v.y * v.y) * dv.y - v.z * v.y * dv.z) * invsum32; + dnormvdv.z = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32; + return dnormvdv; +} + +__forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv) +{ + float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w; + float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); + + float4 vdv = { v.x * dv.x, v.y * dv.y, v.z * dv.z, v.w * dv.w }; + float vdv_sum = vdv.x + vdv.y + vdv.z + vdv.w; + float4 dnormvdv; + dnormvdv.x = ((sum2 - v.x * v.x) * dv.x - v.x * (vdv_sum - vdv.x)) * invsum32; + dnormvdv.y = ((sum2 - v.y * v.y) * dv.y - v.y * (vdv_sum - vdv.y)) * invsum32; + dnormvdv.z = ((sum2 - v.z * v.z) * dv.z - v.z * (vdv_sum - vdv.z)) * invsum32; + dnormvdv.w = ((sum2 - v.w * v.w) * dv.w - v.w * (vdv_sum - vdv.w)) * invsum32; + return dnormvdv; +} + +__forceinline__ __device__ float sigmoid(float x) +{ + return 1.0f / (1.0f + expf(-x)); +} + +__forceinline__ __device__ bool in_frustum(int idx, + const float* orig_points, + const float* viewmatrix, + const float* projmatrix, + bool prefiltered, + float3& p_view) +{ + float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] }; + + // Bring points to screen space + float4 p_hom = transformPoint4x4(p_orig, projmatrix); + float p_w = 1.0f / (p_hom.w + 0.0000001f); + float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w }; + p_view = transformPoint4x3(p_orig, viewmatrix); + + if (p_view.z <= 0.2f)// || ((p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y < -1.3 || p_proj.y > 1.3))) + { + if (prefiltered) + { + printf("Point is filtered although prefiltered is set. This shouldn't happen!"); + __trap(); + } + return false; + } + return true; +} + +#define CHECK_CUDA(A, debug) \ +A; if(debug) { \ +auto ret = cudaDeviceSynchronize(); \ +if (ret != cudaSuccess) { \ +std::cerr << "\n[CUDA ERROR] in " << __FILE__ << "\nLine " << __LINE__ << ": " << cudaGetErrorString(ret); \ +throw std::runtime_error(cudaGetErrorString(ret)); \ +} \ +} + +#endif \ No newline at end of file diff --git a/submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.cu b/submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.cu new file mode 100644 index 0000000000000000000000000000000000000000..60af994640327132dcb88fb92d575621561f16f2 --- /dev/null +++ b/submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.cu @@ -0,0 +1,712 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#include "backward.h" +#include "auxiliary.h" +#include +#include +namespace cg = cooperative_groups; + +// Backward pass for conversion of spherical harmonics to RGB for +// each Gaussian. +__device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, const bool* clamped, const glm::vec3* dL_dcolor, glm::vec3* dL_dmeans, glm::vec3* dL_dshs) +{ + // Compute intermediate values, as it is done during forward + glm::vec3 pos = means[idx]; + glm::vec3 dir_orig = pos - campos; + glm::vec3 dir = dir_orig / glm::length(dir_orig); + + glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs; + + // Use PyTorch rule for clamping: if clamping was applied, + // gradient becomes 0. + glm::vec3 dL_dRGB = dL_dcolor[idx]; + dL_dRGB.x *= clamped[3 * idx + 0] ? 0 : 1; + dL_dRGB.y *= clamped[3 * idx + 1] ? 0 : 1; + dL_dRGB.z *= clamped[3 * idx + 2] ? 0 : 1; + + glm::vec3 dRGBdx(0, 0, 0); + glm::vec3 dRGBdy(0, 0, 0); + glm::vec3 dRGBdz(0, 0, 0); + float x = dir.x; + float y = dir.y; + float z = dir.z; + + // Target location for this Gaussian to write SH gradients to + glm::vec3* dL_dsh = dL_dshs + idx * max_coeffs; + + // No tricks here, just high school-level calculus. + float dRGBdsh0 = SH_C0; + dL_dsh[0] = dRGBdsh0 * dL_dRGB; + if (deg > 0) + { + float dRGBdsh1 = -SH_C1 * y; + float dRGBdsh2 = SH_C1 * z; + float dRGBdsh3 = -SH_C1 * x; + dL_dsh[1] = dRGBdsh1 * dL_dRGB; + dL_dsh[2] = dRGBdsh2 * dL_dRGB; + dL_dsh[3] = dRGBdsh3 * dL_dRGB; + + dRGBdx = -SH_C1 * sh[3]; + dRGBdy = -SH_C1 * sh[1]; + dRGBdz = SH_C1 * sh[2]; + + if (deg > 1) + { + float xx = x * x, yy = y * y, zz = z * z; + float xy = x * y, yz = y * z, xz = x * z; + + float dRGBdsh4 = SH_C2[0] * xy; + float dRGBdsh5 = SH_C2[1] * yz; + float dRGBdsh6 = SH_C2[2] * (2.f * zz - xx - yy); + float dRGBdsh7 = SH_C2[3] * xz; + float dRGBdsh8 = SH_C2[4] * (xx - yy); + dL_dsh[4] = dRGBdsh4 * dL_dRGB; + dL_dsh[5] = dRGBdsh5 * dL_dRGB; + dL_dsh[6] = dRGBdsh6 * dL_dRGB; + dL_dsh[7] = dRGBdsh7 * dL_dRGB; + dL_dsh[8] = dRGBdsh8 * dL_dRGB; + + dRGBdx += SH_C2[0] * y * sh[4] + SH_C2[2] * 2.f * -x * sh[6] + SH_C2[3] * z * sh[7] + SH_C2[4] * 2.f * x * sh[8]; + dRGBdy += SH_C2[0] * x * sh[4] + SH_C2[1] * z * sh[5] + SH_C2[2] * 2.f * -y * sh[6] + SH_C2[4] * 2.f * -y * sh[8]; + dRGBdz += SH_C2[1] * y * sh[5] + SH_C2[2] * 2.f * 2.f * z * sh[6] + SH_C2[3] * x * sh[7]; + + if (deg > 2) + { + float dRGBdsh9 = SH_C3[0] * y * (3.f * xx - yy); + float dRGBdsh10 = SH_C3[1] * xy * z; + float dRGBdsh11 = SH_C3[2] * y * (4.f * zz - xx - yy); + float dRGBdsh12 = SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy); + float dRGBdsh13 = SH_C3[4] * x * (4.f * zz - xx - yy); + float dRGBdsh14 = SH_C3[5] * z * (xx - yy); + float dRGBdsh15 = SH_C3[6] * x * (xx - 3.f * yy); + dL_dsh[9] = dRGBdsh9 * dL_dRGB; + dL_dsh[10] = dRGBdsh10 * dL_dRGB; + dL_dsh[11] = dRGBdsh11 * dL_dRGB; + dL_dsh[12] = dRGBdsh12 * dL_dRGB; + dL_dsh[13] = dRGBdsh13 * dL_dRGB; + dL_dsh[14] = dRGBdsh14 * dL_dRGB; + dL_dsh[15] = dRGBdsh15 * dL_dRGB; + + dRGBdx += ( + SH_C3[0] * sh[9] * 3.f * 2.f * xy + + SH_C3[1] * sh[10] * yz + + SH_C3[2] * sh[11] * -2.f * xy + + SH_C3[3] * sh[12] * -3.f * 2.f * xz + + SH_C3[4] * sh[13] * (-3.f * xx + 4.f * zz - yy) + + SH_C3[5] * sh[14] * 2.f * xz + + SH_C3[6] * sh[15] * 3.f * (xx - yy)); + + dRGBdy += ( + SH_C3[0] * sh[9] * 3.f * (xx - yy) + + SH_C3[1] * sh[10] * xz + + SH_C3[2] * sh[11] * (-3.f * yy + 4.f * zz - xx) + + SH_C3[3] * sh[12] * -3.f * 2.f * yz + + SH_C3[4] * sh[13] * -2.f * xy + + SH_C3[5] * sh[14] * -2.f * yz + + SH_C3[6] * sh[15] * -3.f * 2.f * xy); + + dRGBdz += ( + SH_C3[1] * sh[10] * xy + + SH_C3[2] * sh[11] * 4.f * 2.f * yz + + SH_C3[3] * sh[12] * 3.f * (2.f * zz - xx - yy) + + SH_C3[4] * sh[13] * 4.f * 2.f * xz + + SH_C3[5] * sh[14] * (xx - yy)); + } + } + } + + // The view direction is an input to the computation. View direction + // is influenced by the Gaussian's mean, so SHs gradients + // must propagate back into 3D position. + glm::vec3 dL_ddir(glm::dot(dRGBdx, dL_dRGB), glm::dot(dRGBdy, dL_dRGB), glm::dot(dRGBdz, dL_dRGB)); + + // Account for normalization of direction + float3 dL_dmean = dnormvdv(float3{ dir_orig.x, dir_orig.y, dir_orig.z }, float3{ dL_ddir.x, dL_ddir.y, dL_ddir.z }); + + // Gradients of loss w.r.t. Gaussian means, but only the portion + // that is caused because the mean affects the view-dependent color. + // Additional mean gradient is accumulated in below methods. + dL_dmeans[idx] += glm::vec3(dL_dmean.x, dL_dmean.y, dL_dmean.z); +} + +// Backward version of INVERSE 2D covariance matrix computation +// (due to length launched as separate kernel before other +// backward steps contained in preprocess) +__global__ void computeCov2DCUDA(int P, + const float3* means, + const int* radii, + const float* cov3Ds, + const float h_x, float h_y, + const float tan_fovx, float tan_fovy, + const float* view_matrix, + const float* dL_dconics, + float3* dL_dmeans, + float* dL_dcov) +{ + auto idx = cg::this_grid().thread_rank(); + if (idx >= P || !(radii[idx] > 0)) + return; + + // Reading location of 3D covariance for this Gaussian + const float* cov3D = cov3Ds + 6 * idx; + + // Fetch gradients, recompute 2D covariance and relevant + // intermediate forward results needed in the backward. + float3 mean = means[idx]; + float3 dL_dconic = { dL_dconics[4 * idx], dL_dconics[4 * idx + 1], dL_dconics[4 * idx + 3] }; + float3 t = transformPoint4x3(mean, view_matrix); + + const float limx = 1.3f * tan_fovx; + const float limy = 1.3f * tan_fovy; + const float txtz = t.x / t.z; + const float tytz = t.y / t.z; + t.x = min(limx, max(-limx, txtz)) * t.z; + t.y = min(limy, max(-limy, tytz)) * t.z; + + const float x_grad_mul = txtz < -limx || txtz > limx ? 0 : 1; + const float y_grad_mul = tytz < -limy || tytz > limy ? 0 : 1; + + glm::mat3 J = glm::mat3(h_x / t.z, 0.0f, -(h_x * t.x) / (t.z * t.z), + 0.0f, h_y / t.z, -(h_y * t.y) / (t.z * t.z), + 0, 0, 0); + + glm::mat3 W = glm::mat3( + view_matrix[0], view_matrix[4], view_matrix[8], + view_matrix[1], view_matrix[5], view_matrix[9], + view_matrix[2], view_matrix[6], view_matrix[10]); + + glm::mat3 Vrk = glm::mat3( + cov3D[0], cov3D[1], cov3D[2], + cov3D[1], cov3D[3], cov3D[4], + cov3D[2], cov3D[4], cov3D[5]); + + glm::mat3 T = W * J; + + glm::mat3 cov2D = glm::transpose(T) * glm::transpose(Vrk) * T; + + // Use helper variables for 2D covariance entries. More compact. + float a = cov2D[0][0] += 0.3f; + float b = cov2D[0][1]; + float c = cov2D[1][1] += 0.3f; + + float denom = a * c - b * b; + float dL_da = 0, dL_db = 0, dL_dc = 0; + float denom2inv = 1.0f / ((denom * denom) + 0.0000001f); + + if (denom2inv != 0) + { + // Gradients of loss w.r.t. entries of 2D covariance matrix, + // given gradients of loss w.r.t. conic matrix (inverse covariance matrix). + // e.g., dL / da = dL / d_conic_a * d_conic_a / d_a + dL_da = denom2inv * (-c * c * dL_dconic.x + 2 * b * c * dL_dconic.y + (denom - a * c) * dL_dconic.z); + dL_dc = denom2inv * (-a * a * dL_dconic.z + 2 * a * b * dL_dconic.y + (denom - a * c) * dL_dconic.x); + dL_db = denom2inv * 2 * (b * c * dL_dconic.x - (denom + 2 * b * b) * dL_dconic.y + a * b * dL_dconic.z); + + // Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry, + // given gradients w.r.t. 2D covariance matrix (diagonal). + // cov2D = transpose(T) * transpose(Vrk) * T; + dL_dcov[6 * idx + 0] = (T[0][0] * T[0][0] * dL_da + T[0][0] * T[1][0] * dL_db + T[1][0] * T[1][0] * dL_dc); + dL_dcov[6 * idx + 3] = (T[0][1] * T[0][1] * dL_da + T[0][1] * T[1][1] * dL_db + T[1][1] * T[1][1] * dL_dc); + dL_dcov[6 * idx + 5] = (T[0][2] * T[0][2] * dL_da + T[0][2] * T[1][2] * dL_db + T[1][2] * T[1][2] * dL_dc); + + // Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry, + // given gradients w.r.t. 2D covariance matrix (off-diagonal). + // Off-diagonal elements appear twice --> double the gradient. + // cov2D = transpose(T) * transpose(Vrk) * T; + dL_dcov[6 * idx + 1] = 2 * T[0][0] * T[0][1] * dL_da + (T[0][0] * T[1][1] + T[0][1] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][1] * dL_dc; + dL_dcov[6 * idx + 2] = 2 * T[0][0] * T[0][2] * dL_da + (T[0][0] * T[1][2] + T[0][2] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][2] * dL_dc; + dL_dcov[6 * idx + 4] = 2 * T[0][2] * T[0][1] * dL_da + (T[0][1] * T[1][2] + T[0][2] * T[1][1]) * dL_db + 2 * T[1][1] * T[1][2] * dL_dc; + } + else + { + for (int i = 0; i < 6; i++) + dL_dcov[6 * idx + i] = 0; + } + + // Gradients of loss w.r.t. upper 2x3 portion of intermediate matrix T + // cov2D = transpose(T) * transpose(Vrk) * T; + float dL_dT00 = 2 * (T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_da + + (T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_db; + float dL_dT01 = 2 * (T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_da + + (T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_db; + float dL_dT02 = 2 * (T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_da + + (T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_db; + float dL_dT10 = 2 * (T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_dc + + (T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_db; + float dL_dT11 = 2 * (T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_dc + + (T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_db; + float dL_dT12 = 2 * (T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_dc + + (T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_db; + + // Gradients of loss w.r.t. upper 3x2 non-zero entries of Jacobian matrix + // T = W * J + float dL_dJ00 = W[0][0] * dL_dT00 + W[0][1] * dL_dT01 + W[0][2] * dL_dT02; + float dL_dJ02 = W[2][0] * dL_dT00 + W[2][1] * dL_dT01 + W[2][2] * dL_dT02; + float dL_dJ11 = W[1][0] * dL_dT10 + W[1][1] * dL_dT11 + W[1][2] * dL_dT12; + float dL_dJ12 = W[2][0] * dL_dT10 + W[2][1] * dL_dT11 + W[2][2] * dL_dT12; + + float tz = 1.f / t.z; + float tz2 = tz * tz; + float tz3 = tz2 * tz; + + // Gradients of loss w.r.t. transformed Gaussian mean t + float dL_dtx = x_grad_mul * -h_x * tz2 * dL_dJ02; + float dL_dty = y_grad_mul * -h_y * tz2 * dL_dJ12; + float dL_dtz = -h_x * tz2 * dL_dJ00 - h_y * tz2 * dL_dJ11 + (2 * h_x * t.x) * tz3 * dL_dJ02 + (2 * h_y * t.y) * tz3 * dL_dJ12; + + // Account for transformation of mean to t + // t = transformPoint4x3(mean, view_matrix); + float3 dL_dmean = transformVec4x3Transpose({ dL_dtx, dL_dty, dL_dtz }, view_matrix); + + // Gradients of loss w.r.t. Gaussian means, but only the portion + // that is caused because the mean affects the covariance matrix. + // Additional mean gradient is accumulated in BACKWARD::preprocess. + dL_dmeans[idx] = dL_dmean; +} + +// Backward pass for the conversion of scale and rotation to a +// 3D covariance matrix for each Gaussian. +__device__ void computeCov3D(int idx, const glm::vec3 scale, float mod, const glm::vec4 rot, const float* dL_dcov3Ds, glm::vec3* dL_dscales, glm::vec4* dL_drots) +{ + // Recompute (intermediate) results for the 3D covariance computation. + glm::vec4 q = rot;// / glm::length(rot); + float r = q.x; + float x = q.y; + float y = q.z; + float z = q.w; + + glm::mat3 R = glm::mat3( + 1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z), 2.f * (x * z + r * y), + 2.f * (x * y + r * z), 1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x), + 2.f * (x * z - r * y), 2.f * (y * z + r * x), 1.f - 2.f * (x * x + y * y) + ); + + glm::mat3 S = glm::mat3(1.0f); + + glm::vec3 s = mod * scale; + S[0][0] = s.x; + S[1][1] = s.y; + S[2][2] = s.z; + + glm::mat3 M = S * R; + + const float* dL_dcov3D = dL_dcov3Ds + 6 * idx; + + glm::vec3 dunc(dL_dcov3D[0], dL_dcov3D[3], dL_dcov3D[5]); + glm::vec3 ounc = 0.5f * glm::vec3(dL_dcov3D[1], dL_dcov3D[2], dL_dcov3D[4]); + + // Convert per-element covariance loss gradients to matrix form + glm::mat3 dL_dSigma = glm::mat3( + dL_dcov3D[0], 0.5f * dL_dcov3D[1], 0.5f * dL_dcov3D[2], + 0.5f * dL_dcov3D[1], dL_dcov3D[3], 0.5f * dL_dcov3D[4], + 0.5f * dL_dcov3D[2], 0.5f * dL_dcov3D[4], dL_dcov3D[5] + ); + + // Compute loss gradient w.r.t. matrix M + // dSigma_dM = 2 * M + glm::mat3 dL_dM = 2.0f * M * dL_dSigma; + + glm::mat3 Rt = glm::transpose(R); + glm::mat3 dL_dMt = glm::transpose(dL_dM); + + // Gradients of loss w.r.t. scale + glm::vec3* dL_dscale = dL_dscales + idx; + dL_dscale->x = glm::dot(Rt[0], dL_dMt[0]); + dL_dscale->y = glm::dot(Rt[1], dL_dMt[1]); + dL_dscale->z = glm::dot(Rt[2], dL_dMt[2]); + + dL_dMt[0] *= s.x; + dL_dMt[1] *= s.y; + dL_dMt[2] *= s.z; + + // Gradients of loss w.r.t. normalized quaternion + glm::vec4 dL_dq; + dL_dq.x = 2 * z * (dL_dMt[0][1] - dL_dMt[1][0]) + 2 * y * (dL_dMt[2][0] - dL_dMt[0][2]) + 2 * x * (dL_dMt[1][2] - dL_dMt[2][1]); + dL_dq.y = 2 * y * (dL_dMt[1][0] + dL_dMt[0][1]) + 2 * z * (dL_dMt[2][0] + dL_dMt[0][2]) + 2 * r * (dL_dMt[1][2] - dL_dMt[2][1]) - 4 * x * (dL_dMt[2][2] + dL_dMt[1][1]); + dL_dq.z = 2 * x * (dL_dMt[1][0] + dL_dMt[0][1]) + 2 * r * (dL_dMt[2][0] - dL_dMt[0][2]) + 2 * z * (dL_dMt[1][2] + dL_dMt[2][1]) - 4 * y * (dL_dMt[2][2] + dL_dMt[0][0]); + dL_dq.w = 2 * r * (dL_dMt[0][1] - dL_dMt[1][0]) + 2 * x * (dL_dMt[2][0] + dL_dMt[0][2]) + 2 * y * (dL_dMt[1][2] + dL_dMt[2][1]) - 4 * z * (dL_dMt[1][1] + dL_dMt[0][0]); + + // Gradients of loss w.r.t. unnormalized quaternion + float4* dL_drot = (float4*)(dL_drots + idx); + *dL_drot = float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w };//dnormvdv(float4{ rot.x, rot.y, rot.z, rot.w }, float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w }); +} + +// Backward pass of the preprocessing steps, except +// for the covariance computation and inversion +// (those are handled by a previous kernel call) +template +__global__ void preprocessCUDA( + int P, int D, int M, + const float3* means, + const int* radii, + const float* shs, + const bool* clamped, + const glm::vec3* scales, + const glm::vec4* rotations, + const float scale_modifier, + const float* view, + const float* proj, + const glm::vec3* campos, + const float3* dL_dmean2D, + glm::vec3* dL_dmeans, + float* dL_dcolor, + float* dL_ddepth, + float* dL_dcov3D, + float* dL_dsh, + glm::vec3* dL_dscale, + glm::vec4* dL_drot) +{ + auto idx = cg::this_grid().thread_rank(); + if (idx >= P || !(radii[idx] > 0)) + return; + + float3 m = means[idx]; + + // Taking care of gradients from the screenspace points + float4 m_hom = transformPoint4x4(m, proj); + float m_w = 1.0f / (m_hom.w + 0.0000001f); + + // Compute loss gradient w.r.t. 3D means due to gradients of 2D means + // from rendering procedure + glm::vec3 dL_dmean; + float mul1 = (proj[0] * m.x + proj[4] * m.y + proj[8] * m.z + proj[12]) * m_w * m_w; + float mul2 = (proj[1] * m.x + proj[5] * m.y + proj[9] * m.z + proj[13]) * m_w * m_w; + dL_dmean.x = (proj[0] * m_w - proj[3] * mul1) * dL_dmean2D[idx].x + (proj[1] * m_w - proj[3] * mul2) * dL_dmean2D[idx].y; + dL_dmean.y = (proj[4] * m_w - proj[7] * mul1) * dL_dmean2D[idx].x + (proj[5] * m_w - proj[7] * mul2) * dL_dmean2D[idx].y; + dL_dmean.z = (proj[8] * m_w - proj[11] * mul1) * dL_dmean2D[idx].x + (proj[9] * m_w - proj[11] * mul2) * dL_dmean2D[idx].y; + + // That's the second part of the mean gradient. Previous computation + // of cov2D and following SH conversion also affects it. + dL_dmeans[idx] += dL_dmean; + + // the w must be equal to 1 for view^T * [x,y,z,1] + float3 m_view = transformPoint4x3(m, view); + + // Compute loss gradient w.r.t. 3D means due to gradients of depth + // from rendering procedure + glm::vec3 dL_dmean2; + float mul3 = view[2] * m.x + view[6] * m.y + view[10] * m.z + view[14]; + dL_dmean2.x = (view[2] - view[3] * mul3) * dL_ddepth[idx]; + dL_dmean2.y = (view[6] - view[7] * mul3) * dL_ddepth[idx]; + dL_dmean2.z = (view[10] - view[11] * mul3) * dL_ddepth[idx]; + + // That's the third part of the mean gradient. + dL_dmeans[idx] += dL_dmean2; + + // Compute gradient updates due to computing colors from SHs + if (shs) + computeColorFromSH(idx, D, M, (glm::vec3*)means, *campos, shs, clamped, (glm::vec3*)dL_dcolor, (glm::vec3*)dL_dmeans, (glm::vec3*)dL_dsh); + + // Compute gradient updates due to computing covariance from scale/rotation + if (scales) + computeCov3D(idx, scales[idx], scale_modifier, rotations[idx], dL_dcov3D, dL_dscale, dL_drot); +} + +// Backward version of the rendering procedure. +template +__global__ void __launch_bounds__(BLOCK_X * BLOCK_Y) +renderCUDA( + const uint2* __restrict__ ranges, + const uint32_t* __restrict__ point_list, + int W, int H, + const float* __restrict__ bg_color, + const float2* __restrict__ points_xy_image, + const float4* __restrict__ conic_opacity, + const float* __restrict__ colors, + const float* __restrict__ depths, + const float* __restrict__ alphas, + const uint32_t* __restrict__ n_contrib, + const float* __restrict__ dL_dpixels, + const float* __restrict__ dL_dpixel_depths, + const float* __restrict__ dL_dalphas, + float3* __restrict__ dL_dmean2D, + float4* __restrict__ dL_dconic2D, + float* __restrict__ dL_dopacity, + float* __restrict__ dL_dcolors, + float* __restrict__ dL_ddepths +) +{ + // We rasterize again. Compute necessary block info. + auto block = cg::this_thread_block(); + const uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X; + const uint2 pix_min = { block.group_index().x * BLOCK_X, block.group_index().y * BLOCK_Y }; + const uint2 pix_max = { min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y , H) }; + const uint2 pix = { pix_min.x + block.thread_index().x, pix_min.y + block.thread_index().y }; + const uint32_t pix_id = W * pix.y + pix.x; + const float2 pixf = { (float)pix.x, (float)pix.y }; + + const bool inside = pix.x < W&& pix.y < H; + const uint2 range = ranges[block.group_index().y * horizontal_blocks + block.group_index().x]; + + const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE); + + bool done = !inside; + int toDo = range.y - range.x; + + __shared__ int collected_id[BLOCK_SIZE]; + __shared__ float2 collected_xy[BLOCK_SIZE]; + __shared__ float4 collected_conic_opacity[BLOCK_SIZE]; + __shared__ float collected_colors[C * BLOCK_SIZE]; + __shared__ float collected_depths[BLOCK_SIZE]; + + // In the forward, we stored the final value for T, the + // product of all (1 - alpha) factors. + const float T_final = inside ? (1 - alphas[pix_id]) : 0; + float T = T_final; + + // We start from the back. The ID of the last contributing + // Gaussian is known from each pixel from the forward. + uint32_t contributor = toDo; + const int last_contributor = inside ? n_contrib[pix_id] : 0; + + float accum_rec[C] = { 0 }; + float dL_dpixel[C]; + float accum_depth_rec = 0; + float dL_dpixel_depth; + float accum_alpha_rec = 0; + float dL_dalpha; + if (inside) { + for (int i = 0; i < C; i++) + dL_dpixel[i] = dL_dpixels[i * H * W + pix_id]; + dL_dpixel_depth = dL_dpixel_depths[pix_id]; + dL_dalpha = dL_dalphas[pix_id]; + } + + float last_alpha = 0; + float last_color[C] = { 0 }; + float last_depth = 0; + + // Gradient of pixel coordinate w.r.t. normalized + // screen-space viewport corrdinates (-1 to 1) + const float ddelx_dx = 0.5 * W; + const float ddely_dy = 0.5 * H; + + // Traverse all Gaussians + for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE) + { + // Load auxiliary data into shared memory, start in the BACK + // and load them in revers order. + block.sync(); + const int progress = i * BLOCK_SIZE + block.thread_rank(); + if (range.x + progress < range.y) + { + const int coll_id = point_list[range.y - progress - 1]; + collected_id[block.thread_rank()] = coll_id; + collected_xy[block.thread_rank()] = points_xy_image[coll_id]; + collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id]; + for (int i = 0; i < C; i++) + collected_colors[i * BLOCK_SIZE + block.thread_rank()] = colors[coll_id * C + i]; + collected_depths[block.thread_rank()] = depths[coll_id]; + } + block.sync(); + + // Iterate over Gaussians + for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++) + { + // Keep track of current Gaussian ID. Skip, if this one + // is behind the last contributor for this pixel. + contributor--; + if (contributor >= last_contributor) + continue; + + // Compute blending values, as before. + const float2 xy = collected_xy[j]; + const float2 d = { xy.x - pixf.x, xy.y - pixf.y }; + const float4 con_o = collected_conic_opacity[j]; + const float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y; + if (power > 0.0f) + continue; + + const float G = exp(power); + const float alpha = min(0.99f, con_o.w * G); + if (alpha < 1.0f / 255.0f) + continue; + + T = T / (1.f - alpha); + const float dchannel_dcolor = alpha * T; + const float dpixel_depth_ddepth = alpha * T; + + // Propagate gradients to per-Gaussian colors and keep + // gradients w.r.t. alpha (blending factor for a Gaussian/pixel + // pair). + float dL_dopa = 0.0f; + const int global_id = collected_id[j]; + for (int ch = 0; ch < C; ch++) + { + const float c = collected_colors[ch * BLOCK_SIZE + j]; + // Update last color (to be used in the next iteration) + accum_rec[ch] = last_alpha * last_color[ch] + (1.f - last_alpha) * accum_rec[ch]; + last_color[ch] = c; + + const float dL_dchannel = dL_dpixel[ch]; + dL_dopa += (c - accum_rec[ch]) * dL_dchannel; + // Update the gradients w.r.t. color of the Gaussian. + // Atomic, since this pixel is just one of potentially + // many that were affected by this Gaussian. + atomicAdd(&(dL_dcolors[global_id * C + ch]), dchannel_dcolor * dL_dchannel); + } + + // Propagate gradients from pixel depth to opacity + const float c_d = collected_depths[j]; + accum_depth_rec = last_alpha * last_depth + (1.f - last_alpha) * accum_depth_rec; + last_depth = c_d; + dL_dopa += (c_d - accum_depth_rec) * dL_dpixel_depth; + atomicAdd(&(dL_ddepths[global_id]), dpixel_depth_ddepth * dL_dpixel_depth); + + // Propagate gradients from pixel alpha (weights_sum) to opacity + accum_alpha_rec = last_alpha + (1.f - last_alpha) * accum_alpha_rec; + dL_dopa += (1 - accum_alpha_rec) * dL_dalpha; //- (alpha - accum_alpha_rec) * dL_dalpha; + + dL_dopa *= T; + // Update last alpha (to be used in the next iteration) + last_alpha = alpha; + + // Account for fact that alpha also influences how much of + // the background color is added if nothing left to blend + float bg_dot_dpixel = 0; + for (int i = 0; i < C; i++) + bg_dot_dpixel += bg_color[i] * dL_dpixel[i]; + dL_dopa += (-T_final / (1.f - alpha)) * bg_dot_dpixel; + + + // Helpful reusable temporary variables + const float dL_dG = con_o.w * dL_dopa; + const float gdx = G * d.x; + const float gdy = G * d.y; + const float dG_ddelx = -gdx * con_o.x - gdy * con_o.y; + const float dG_ddely = -gdy * con_o.z - gdx * con_o.y; + + // Update gradients w.r.t. 2D mean position of the Gaussian + atomicAdd(&dL_dmean2D[global_id].x, dL_dG * dG_ddelx * ddelx_dx); + atomicAdd(&dL_dmean2D[global_id].y, dL_dG * dG_ddely * ddely_dy); + + // Update gradients w.r.t. 2D covariance (2x2 matrix, symmetric) + atomicAdd(&dL_dconic2D[global_id].x, -0.5f * gdx * d.x * dL_dG); + atomicAdd(&dL_dconic2D[global_id].y, -0.5f * gdx * d.y * dL_dG); + atomicAdd(&dL_dconic2D[global_id].w, -0.5f * gdy * d.y * dL_dG); + + // Update gradients w.r.t. opacity of the Gaussian + atomicAdd(&(dL_dopacity[global_id]), G * dL_dopa); + } + } +} + +void BACKWARD::preprocess( + int P, int D, int M, + const float3* means3D, + const int* radii, + const float* shs, + const bool* clamped, + const glm::vec3* scales, + const glm::vec4* rotations, + const float scale_modifier, + const float* cov3Ds, + const float* viewmatrix, + const float* projmatrix, + const float focal_x, float focal_y, + const float tan_fovx, float tan_fovy, + const glm::vec3* campos, + const float3* dL_dmean2D, + const float* dL_dconic, + glm::vec3* dL_dmean3D, + float* dL_dcolor, + float* dL_ddepth, + float* dL_dcov3D, + float* dL_dsh, + glm::vec3* dL_dscale, + glm::vec4* dL_drot) +{ + // Propagate gradients for the path of 2D conic matrix computation. + // Somewhat long, thus it is its own kernel rather than being part of + // "preprocess". When done, loss gradient w.r.t. 3D means has been + // modified and gradient w.r.t. 3D covariance matrix has been computed. + computeCov2DCUDA << <(P + 255) / 256, 256 >> > ( + P, + means3D, + radii, + cov3Ds, + focal_x, + focal_y, + tan_fovx, + tan_fovy, + viewmatrix, + dL_dconic, + (float3*)dL_dmean3D, + dL_dcov3D); + + // Propagate gradients for remaining steps: finish 3D mean gradients, + // propagate color gradients to SH (if desireD), propagate 3D covariance + // matrix gradients to scale and rotation. + preprocessCUDA << < (P + 255) / 256, 256 >> > ( + P, D, M, + (float3*)means3D, + radii, + shs, + clamped, + (glm::vec3*)scales, + (glm::vec4*)rotations, + scale_modifier, + viewmatrix, + projmatrix, + campos, + (float3*)dL_dmean2D, + (glm::vec3*)dL_dmean3D, + dL_dcolor, + dL_ddepth, + dL_dcov3D, + dL_dsh, + dL_dscale, + dL_drot); +} + +void BACKWARD::render( + const dim3 grid, const dim3 block, + const uint2* ranges, + const uint32_t* point_list, + int W, int H, + const float* bg_color, + const float2* means2D, + const float4* conic_opacity, + const float* colors, + const float* depths, + const float* alphas, + const uint32_t* n_contrib, + const float* dL_dpixels, + const float* dL_dpixel_depths, + const float* dL_dalphas, + float3* dL_dmean2D, + float4* dL_dconic2D, + float* dL_dopacity, + float* dL_dcolors, + float* dL_ddepths) +{ + renderCUDA << > >( + ranges, + point_list, + W, H, + bg_color, + means2D, + conic_opacity, + colors, + depths, + alphas, + n_contrib, + dL_dpixels, + dL_dpixel_depths, + dL_dalphas, + dL_dmean2D, + dL_dconic2D, + dL_dopacity, + dL_dcolors, + dL_ddepths + ); +} \ No newline at end of file diff --git a/submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.h b/submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.h new file mode 100644 index 0000000000000000000000000000000000000000..fe53b01b61fdaaad0ce7e1af5d57bd9a60bd9758 --- /dev/null +++ b/submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.h @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#ifndef CUDA_RASTERIZER_BACKWARD_H_INCLUDED +#define CUDA_RASTERIZER_BACKWARD_H_INCLUDED + +#include +#include "cuda_runtime.h" +#include "device_launch_parameters.h" +#define GLM_FORCE_CUDA +#include + +namespace BACKWARD +{ + void render( + const dim3 grid, dim3 block, + const uint2* ranges, + const uint32_t* point_list, + int W, int H, + const float* bg_color, + const float2* means2D, + const float4* conic_opacity, + const float* colors, + const float* depths, + const float* alphas, + const uint32_t* n_contrib, + const float* dL_dpixels, + const float* dL_dpixel_depths, + const float* dL_dalphas, + float3* dL_dmean2D, + float4* dL_dconic2D, + float* dL_dopacity, + float* dL_dcolors, + float* dL_ddepths); + + void preprocess( + int P, int D, int M, + const float3* means, + const int* radii, + const float* shs, + const bool* clamped, + const glm::vec3* scales, + const glm::vec4* rotations, + const float scale_modifier, + const float* cov3Ds, + const float* view, + const float* proj, + const float focal_x, float focal_y, + const float tan_fovx, float tan_fovy, + const glm::vec3* campos, + const float3* dL_dmean2D, + const float* dL_dconics, + glm::vec3* dL_dmeans, + float* dL_dcolor, + float* dL_ddepth, + float* dL_dcov3D, + float* dL_dsh, + glm::vec3* dL_dscale, + glm::vec4* dL_drot); +} + +#endif \ No newline at end of file diff --git a/submodules/diff-gaussian-rasterization/cuda_rasterizer/config.h b/submodules/diff-gaussian-rasterization/cuda_rasterizer/config.h new file mode 100644 index 0000000000000000000000000000000000000000..2a912fb34824349caadffe435fc1ab4b31e5aa4f --- /dev/null +++ b/submodules/diff-gaussian-rasterization/cuda_rasterizer/config.h @@ -0,0 +1,19 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#ifndef CUDA_RASTERIZER_CONFIG_H_INCLUDED +#define CUDA_RASTERIZER_CONFIG_H_INCLUDED + +#define NUM_CHANNELS 3 // Default 3, RGB +#define BLOCK_X 16 +#define BLOCK_Y 16 + +#endif \ No newline at end of file diff --git a/submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.cu b/submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.cu new file mode 100644 index 0000000000000000000000000000000000000000..dc28500edd17de9e0b35835b86d3877307ef38c8 --- /dev/null +++ b/submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.cu @@ -0,0 +1,466 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#include "forward.h" +#include "auxiliary.h" +#include +#include +namespace cg = cooperative_groups; + +// Forward method for converting the input spherical harmonics +// coefficients of each Gaussian to a simple RGB color. +__device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, bool* clamped) +{ + // The implementation is loosely based on code for + // "Differentiable Point-Based Radiance Fields for + // Efficient View Synthesis" by Zhang et al. (2022) + glm::vec3 pos = means[idx]; + glm::vec3 dir = pos - campos; + dir = dir / glm::length(dir); + + glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs; + glm::vec3 result = SH_C0 * sh[0]; + + if (deg > 0) + { + float x = dir.x; + float y = dir.y; + float z = dir.z; + result = result - SH_C1 * y * sh[1] + SH_C1 * z * sh[2] - SH_C1 * x * sh[3]; + + if (deg > 1) + { + float xx = x * x, yy = y * y, zz = z * z; + float xy = x * y, yz = y * z, xz = x * z; + result = result + + SH_C2[0] * xy * sh[4] + + SH_C2[1] * yz * sh[5] + + SH_C2[2] * (2.0f * zz - xx - yy) * sh[6] + + SH_C2[3] * xz * sh[7] + + SH_C2[4] * (xx - yy) * sh[8]; + + if (deg > 2) + { + result = result + + SH_C3[0] * y * (3.0f * xx - yy) * sh[9] + + SH_C3[1] * xy * z * sh[10] + + SH_C3[2] * y * (4.0f * zz - xx - yy) * sh[11] + + SH_C3[3] * z * (2.0f * zz - 3.0f * xx - 3.0f * yy) * sh[12] + + SH_C3[4] * x * (4.0f * zz - xx - yy) * sh[13] + + SH_C3[5] * z * (xx - yy) * sh[14] + + SH_C3[6] * x * (xx - 3.0f * yy) * sh[15]; + } + } + } + result += 0.5f; + + // RGB colors are clamped to positive values. If values are + // clamped, we need to keep track of this for the backward pass. + clamped[3 * idx + 0] = (result.x < 0); + clamped[3 * idx + 1] = (result.y < 0); + clamped[3 * idx + 2] = (result.z < 0); + return glm::max(result, 0.0f); +} + +// Forward version of 2D covariance matrix computation +__device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y, float tan_fovx, float tan_fovy, const float* cov3D, const float* viewmatrix) +{ + // The following models the steps outlined by equations 29 + // and 31 in "EWA Splatting" (Zwicker et al., 2002). + // Additionally considers aspect / scaling of viewport. + // Transposes used to account for row-/column-major conventions. + float3 t = transformPoint4x3(mean, viewmatrix); + + const float limx = 1.3f * tan_fovx; + const float limy = 1.3f * tan_fovy; + const float txtz = t.x / t.z; + const float tytz = t.y / t.z; + t.x = min(limx, max(-limx, txtz)) * t.z; + t.y = min(limy, max(-limy, tytz)) * t.z; + + glm::mat3 J = glm::mat3( + focal_x / t.z, 0.0f, -(focal_x * t.x) / (t.z * t.z), + 0.0f, focal_y / t.z, -(focal_y * t.y) / (t.z * t.z), + 0, 0, 0); + + glm::mat3 W = glm::mat3( + viewmatrix[0], viewmatrix[4], viewmatrix[8], + viewmatrix[1], viewmatrix[5], viewmatrix[9], + viewmatrix[2], viewmatrix[6], viewmatrix[10]); + + glm::mat3 T = W * J; + + glm::mat3 Vrk = glm::mat3( + cov3D[0], cov3D[1], cov3D[2], + cov3D[1], cov3D[3], cov3D[4], + cov3D[2], cov3D[4], cov3D[5]); + + glm::mat3 cov = glm::transpose(T) * glm::transpose(Vrk) * T; + + // Apply low-pass filter: every Gaussian should be at least + // one pixel wide/high. Discard 3rd row and column. + cov[0][0] += 0.3f; + cov[1][1] += 0.3f; + return { float(cov[0][0]), float(cov[0][1]), float(cov[1][1]) }; +} + +// Forward method for converting scale and rotation properties of each +// Gaussian to a 3D covariance matrix in world space. Also takes care +// of quaternion normalization. +__device__ void computeCov3D(const glm::vec3 scale, float mod, const glm::vec4 rot, float* cov3D) +{ + // Create scaling matrix + glm::mat3 S = glm::mat3(1.0f); + S[0][0] = mod * scale.x; + S[1][1] = mod * scale.y; + S[2][2] = mod * scale.z; + + // Normalize quaternion to get valid rotation + glm::vec4 q = rot;// / glm::length(rot); + float r = q.x; + float x = q.y; + float y = q.z; + float z = q.w; + + // Compute rotation matrix from quaternion + glm::mat3 R = glm::mat3( + 1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z), 2.f * (x * z + r * y), + 2.f * (x * y + r * z), 1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x), + 2.f * (x * z - r * y), 2.f * (y * z + r * x), 1.f - 2.f * (x * x + y * y) + ); + + glm::mat3 M = S * R; + + // Compute 3D world covariance matrix Sigma + glm::mat3 Sigma = glm::transpose(M) * M; + + // Covariance is symmetric, only store upper right + cov3D[0] = Sigma[0][0]; + cov3D[1] = Sigma[0][1]; + cov3D[2] = Sigma[0][2]; + cov3D[3] = Sigma[1][1]; + cov3D[4] = Sigma[1][2]; + cov3D[5] = Sigma[2][2]; +} + +// Perform initial steps for each Gaussian prior to rasterization. +template +__global__ void preprocessCUDA(int P, int D, int M, + const float* orig_points, + const glm::vec3* scales, + const float scale_modifier, + const glm::vec4* rotations, + const float* opacities, + const float* shs, + bool* clamped, + const float* cov3D_precomp, + const float* colors_precomp, + const float* viewmatrix, + const float* projmatrix, + const glm::vec3* cam_pos, + const int W, int H, + const float tan_fovx, float tan_fovy, + const float focal_x, float focal_y, + int* radii, + float2* points_xy_image, + float* depths, + float* cov3Ds, + float* rgb, + float4* conic_opacity, + const dim3 grid, + uint32_t* tiles_touched, + bool prefiltered) +{ + auto idx = cg::this_grid().thread_rank(); + if (idx >= P) + return; + + // Initialize radius and touched tiles to 0. If this isn't changed, + // this Gaussian will not be processed further. + radii[idx] = 0; + tiles_touched[idx] = 0; + + // Perform near culling, quit if outside. + float3 p_view; + if (!in_frustum(idx, orig_points, viewmatrix, projmatrix, prefiltered, p_view)) + return; + + // Transform point by projecting + float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] }; + float4 p_hom = transformPoint4x4(p_orig, projmatrix); + float p_w = 1.0f / (p_hom.w + 0.0000001f); + float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w }; + + // If 3D covariance matrix is precomputed, use it, otherwise compute + // from scaling and rotation parameters. + const float* cov3D; + if (cov3D_precomp != nullptr) + { + cov3D = cov3D_precomp + idx * 6; + } + else + { + computeCov3D(scales[idx], scale_modifier, rotations[idx], cov3Ds + idx * 6); + cov3D = cov3Ds + idx * 6; + } + + // Compute 2D screen-space covariance matrix + float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D, viewmatrix); + + // Invert covariance (EWA algorithm) + float det = (cov.x * cov.z - cov.y * cov.y); + if (det == 0.0f) + return; + float det_inv = 1.f / det; + float3 conic = { cov.z * det_inv, -cov.y * det_inv, cov.x * det_inv }; + + // Compute extent in screen space (by finding eigenvalues of + // 2D covariance matrix). Use extent to compute a bounding rectangle + // of screen-space tiles that this Gaussian overlaps with. Quit if + // rectangle covers 0 tiles. + float mid = 0.5f * (cov.x + cov.z); + float lambda1 = mid + sqrt(max(0.1f, mid * mid - det)); + float lambda2 = mid - sqrt(max(0.1f, mid * mid - det)); + float my_radius = ceil(3.f * sqrt(max(lambda1, lambda2))); + float2 point_image = { ndc2Pix(p_proj.x, W), ndc2Pix(p_proj.y, H) }; + uint2 rect_min, rect_max; + getRect(point_image, my_radius, rect_min, rect_max, grid); + if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0) + return; + + // If colors have been precomputed, use them, otherwise convert + // spherical harmonics coefficients to RGB color. + if (colors_precomp == nullptr) + { + glm::vec3 result = computeColorFromSH(idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped); + rgb[idx * C + 0] = result.x; + rgb[idx * C + 1] = result.y; + rgb[idx * C + 2] = result.z; + } + + // Store some useful helper data for the next steps. + depths[idx] = p_view.z; + radii[idx] = my_radius; + points_xy_image[idx] = point_image; + // Inverse 2D covariance and opacity neatly pack into one float4 + conic_opacity[idx] = { conic.x, conic.y, conic.z, opacities[idx] }; + tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x); +} + +// Main rasterization method. Collaboratively works on one tile per +// block, each thread treats one pixel. Alternates between fetching +// and rasterizing data. +template +__global__ void __launch_bounds__(BLOCK_X * BLOCK_Y) +renderCUDA( + const uint2* __restrict__ ranges, + const uint32_t* __restrict__ point_list, + int W, int H, + const float2* __restrict__ points_xy_image, + const float* __restrict__ features, + const float* __restrict__ depths, + const float4* __restrict__ conic_opacity, + float* __restrict__ out_alpha, + uint32_t* __restrict__ n_contrib, + const float* __restrict__ bg_color, + float* __restrict__ out_color, + float* __restrict__ out_depth) +{ + // Identify current tile and associated min/max pixel range. + auto block = cg::this_thread_block(); + uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X; + uint2 pix_min = { block.group_index().x * BLOCK_X, block.group_index().y * BLOCK_Y }; + uint2 pix_max = { min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y , H) }; + uint2 pix = { pix_min.x + block.thread_index().x, pix_min.y + block.thread_index().y }; + uint32_t pix_id = W * pix.y + pix.x; + float2 pixf = { (float)pix.x, (float)pix.y }; + + // Check if this thread is associated with a valid pixel or outside. + bool inside = pix.x < W&& pix.y < H; + // Done threads can help with fetching, but don't rasterize + bool done = !inside; + + // Load start/end range of IDs to process in bit sorted list. + uint2 range = ranges[block.group_index().y * horizontal_blocks + block.group_index().x]; + const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE); + int toDo = range.y - range.x; + + // Allocate storage for batches of collectively fetched data. + __shared__ int collected_id[BLOCK_SIZE]; + __shared__ float2 collected_xy[BLOCK_SIZE]; + __shared__ float4 collected_conic_opacity[BLOCK_SIZE]; + + // Initialize helper variables + float T = 1.0f; + uint32_t contributor = 0; + uint32_t last_contributor = 0; + float C[CHANNELS] = { 0 }; + float weight = 0; + float D = 0; + + // Iterate over batches until all done or range is complete + for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE) + { + // End if entire block votes that it is done rasterizing + int num_done = __syncthreads_count(done); + if (num_done == BLOCK_SIZE) + break; + + // Collectively fetch per-Gaussian data from global to shared + int progress = i * BLOCK_SIZE + block.thread_rank(); + if (range.x + progress < range.y) + { + int coll_id = point_list[range.x + progress]; + collected_id[block.thread_rank()] = coll_id; + collected_xy[block.thread_rank()] = points_xy_image[coll_id]; + collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id]; + } + block.sync(); + + // Iterate over current batch + for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++) + { + // Keep track of current position in range + contributor++; + + // Resample using conic matrix (cf. "Surface + // Splatting" by Zwicker et al., 2001) + float2 xy = collected_xy[j]; + float2 d = { xy.x - pixf.x, xy.y - pixf.y }; + float4 con_o = collected_conic_opacity[j]; + float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y; + if (power > 0.0f) + continue; + + // Eq. (2) from 3D Gaussian splatting paper. + // Obtain alpha by multiplying with Gaussian opacity + // and its exponential falloff from mean. + // Avoid numerical instabilities (see paper appendix). + float alpha = min(0.99f, con_o.w * exp(power)); + if (alpha < 1.0f / 255.0f) + continue; + float test_T = T * (1 - alpha); + if (test_T < 0.0001f) + { + done = true; + continue; + } + + // Eq. (3) from 3D Gaussian splatting paper. + for (int ch = 0; ch < CHANNELS; ch++) + C[ch] += features[collected_id[j] * CHANNELS + ch] * alpha * T; + weight += alpha * T; + D += depths[collected_id[j]] * alpha * T; + + T = test_T; + + // Keep track of last range entry to update this + // pixel. + last_contributor = contributor; + } + } + + // All threads that treat valid pixel write out their final + // rendering data to the frame and auxiliary buffers. + if (inside) + { + n_contrib[pix_id] = last_contributor; + for (int ch = 0; ch < CHANNELS; ch++) + out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch]; + out_alpha[pix_id] = weight; //1 - T; + out_depth[pix_id] = D; + } +} + +void FORWARD::render( + const dim3 grid, dim3 block, + const uint2* ranges, + const uint32_t* point_list, + int W, int H, + const float2* means2D, + const float* colors, + const float* depths, + const float4* conic_opacity, + float* out_alpha, + uint32_t* n_contrib, + const float* bg_color, + float* out_color, + float* out_depth) +{ + renderCUDA << > > ( + ranges, + point_list, + W, H, + means2D, + colors, + depths, + conic_opacity, + out_alpha, + n_contrib, + bg_color, + out_color, + out_depth); +} + +void FORWARD::preprocess(int P, int D, int M, + const float* means3D, + const glm::vec3* scales, + const float scale_modifier, + const glm::vec4* rotations, + const float* opacities, + const float* shs, + bool* clamped, + const float* cov3D_precomp, + const float* colors_precomp, + const float* viewmatrix, + const float* projmatrix, + const glm::vec3* cam_pos, + const int W, int H, + const float focal_x, float focal_y, + const float tan_fovx, float tan_fovy, + int* radii, + float2* means2D, + float* depths, + float* cov3Ds, + float* rgb, + float4* conic_opacity, + const dim3 grid, + uint32_t* tiles_touched, + bool prefiltered) +{ + preprocessCUDA << <(P + 255) / 256, 256 >> > ( + P, D, M, + means3D, + scales, + scale_modifier, + rotations, + opacities, + shs, + clamped, + cov3D_precomp, + colors_precomp, + viewmatrix, + projmatrix, + cam_pos, + W, H, + tan_fovx, tan_fovy, + focal_x, focal_y, + radii, + means2D, + depths, + cov3Ds, + rgb, + conic_opacity, + grid, + tiles_touched, + prefiltered + ); +} \ No newline at end of file diff --git a/submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.h b/submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.h new file mode 100644 index 0000000000000000000000000000000000000000..6fa308f01b28726d613334798bae0fc309eec776 --- /dev/null +++ b/submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.h @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#ifndef CUDA_RASTERIZER_FORWARD_H_INCLUDED +#define CUDA_RASTERIZER_FORWARD_H_INCLUDED + +#include +#include "cuda_runtime.h" +#include "device_launch_parameters.h" +#define GLM_FORCE_CUDA +#include + +namespace FORWARD +{ + // Perform initial steps for each Gaussian prior to rasterization. + void preprocess(int P, int D, int M, + const float* orig_points, + const glm::vec3* scales, + const float scale_modifier, + const glm::vec4* rotations, + const float* opacities, + const float* shs, + bool* clamped, + const float* cov3D_precomp, + const float* colors_precomp, + const float* viewmatrix, + const float* projmatrix, + const glm::vec3* cam_pos, + const int W, int H, + const float focal_x, float focal_y, + const float tan_fovx, float tan_fovy, + int* radii, + float2* points_xy_image, + float* depths, + float* cov3Ds, + float* colors, + float4* conic_opacity, + const dim3 grid, + uint32_t* tiles_touched, + bool prefiltered); + + // Main rasterization method. + void render( + const dim3 grid, dim3 block, + const uint2* ranges, + const uint32_t* point_list, + int W, int H, + const float2* points_xy_image, + const float* features, + const float* depths, + const float4* conic_opacity, + float* out_alpha, + uint32_t* n_contrib, + const float* bg_color, + float* out_color, + float* out_depth); +} + + +#endif \ No newline at end of file diff --git a/submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer.h b/submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer.h new file mode 100644 index 0000000000000000000000000000000000000000..10270f6843bb0c396a82e27c22a68533409a39c1 --- /dev/null +++ b/submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer.h @@ -0,0 +1,94 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#ifndef CUDA_RASTERIZER_H_INCLUDED +#define CUDA_RASTERIZER_H_INCLUDED + +#include +#include + +namespace CudaRasterizer +{ + class Rasterizer + { + public: + + static void markVisible( + int P, + float* means3D, + float* viewmatrix, + float* projmatrix, + bool* present); + + static int forward( + std::function geometryBuffer, + std::function binningBuffer, + std::function imageBuffer, + const int P, int D, int M, + const float* background, + const int width, int height, + const float* means3D, + const float* shs, + const float* colors_precomp, + const float* opacities, + const float* scales, + const float scale_modifier, + const float* rotations, + const float* cov3D_precomp, + const float* viewmatrix, + const float* projmatrix, + const float* cam_pos, + const float tan_fovx, float tan_fovy, + const bool prefiltered, + float* out_color, + float* out_depth, + float* out_alpha, + int* radii = nullptr, + bool debug = false); + + static void backward( + const int P, int D, int M, int R, + const float* background, + const int width, int height, + const float* means3D, + const float* shs, + const float* colors_precomp, + const float* alphas, + const float* scales, + const float scale_modifier, + const float* rotations, + const float* cov3D_precomp, + const float* viewmatrix, + const float* projmatrix, + const float* campos, + const float tan_fovx, float tan_fovy, + const int* radii, + char* geom_buffer, + char* binning_buffer, + char* image_buffer, + const float* dL_dpix, + const float* dL_dpix_depth, + const float* dL_dalphas, + float* dL_dmean2D, + float* dL_dconic, + float* dL_dopacity, + float* dL_dcolor, + float* dL_ddepth, + float* dL_dmean3D, + float* dL_dcov3D, + float* dL_dsh, + float* dL_dscale, + float* dL_drot, + bool debug); + }; +}; + +#endif \ No newline at end of file diff --git a/submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.cu b/submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.cu new file mode 100644 index 0000000000000000000000000000000000000000..f17793a12d58cd2bb1d7e11c8a196768dedb8086 --- /dev/null +++ b/submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.cu @@ -0,0 +1,447 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#include "rasterizer_impl.h" +#include +#include +#include +#include +#include +#include "cuda_runtime.h" +#include "device_launch_parameters.h" +#include +#include +#define GLM_FORCE_CUDA +#include + +#include +#include +namespace cg = cooperative_groups; + +#include "auxiliary.h" +#include "forward.h" +#include "backward.h" + +// Helper function to find the next-highest bit of the MSB +// on the CPU. +uint32_t getHigherMsb(uint32_t n) +{ + uint32_t msb = sizeof(n) * 4; + uint32_t step = msb; + while (step > 1) + { + step /= 2; + if (n >> msb) + msb += step; + else + msb -= step; + } + if (n >> msb) + msb++; + return msb; +} + +// Wrapper method to call auxiliary coarse frustum containment test. +// Mark all Gaussians that pass it. +__global__ void checkFrustum(int P, + const float* orig_points, + const float* viewmatrix, + const float* projmatrix, + bool* present) +{ + auto idx = cg::this_grid().thread_rank(); + if (idx >= P) + return; + + float3 p_view; + present[idx] = in_frustum(idx, orig_points, viewmatrix, projmatrix, false, p_view); +} + +// Generates one key/value pair for all Gaussian / tile overlaps. +// Run once per Gaussian (1:N mapping). +__global__ void duplicateWithKeys( + int P, + const float2* points_xy, + const float* depths, + const uint32_t* offsets, + uint64_t* gaussian_keys_unsorted, + uint32_t* gaussian_values_unsorted, + int* radii, + dim3 grid) +{ + auto idx = cg::this_grid().thread_rank(); + if (idx >= P) + return; + + // Generate no key/value pair for invisible Gaussians + if (radii[idx] > 0) + { + // Find this Gaussian's offset in buffer for writing keys/values. + uint32_t off = (idx == 0) ? 0 : offsets[idx - 1]; + uint2 rect_min, rect_max; + + getRect(points_xy[idx], radii[idx], rect_min, rect_max, grid); + + // For each tile that the bounding rect overlaps, emit a + // key/value pair. The key is | tile ID | depth |, + // and the value is the ID of the Gaussian. Sorting the values + // with this key yields Gaussian IDs in a list, such that they + // are first sorted by tile and then by depth. + for (int y = rect_min.y; y < rect_max.y; y++) + { + for (int x = rect_min.x; x < rect_max.x; x++) + { + uint64_t key = y * grid.x + x; + key <<= 32; + key |= *((uint32_t*)&depths[idx]); + gaussian_keys_unsorted[off] = key; + gaussian_values_unsorted[off] = idx; + off++; + } + } + } +} + +// Check keys to see if it is at the start/end of one tile's range in +// the full sorted list. If yes, write start/end of this tile. +// Run once per instanced (duplicated) Gaussian ID. +__global__ void identifyTileRanges(int L, uint64_t* point_list_keys, uint2* ranges) +{ + auto idx = cg::this_grid().thread_rank(); + if (idx >= L) + return; + + // Read tile ID from key. Update start/end of tile range if at limit. + uint64_t key = point_list_keys[idx]; + uint32_t currtile = key >> 32; + if (idx == 0) + ranges[currtile].x = 0; + else + { + uint32_t prevtile = point_list_keys[idx - 1] >> 32; + if (currtile != prevtile) + { + ranges[prevtile].y = idx; + ranges[currtile].x = idx; + } + } + if (idx == L - 1) + ranges[currtile].y = L; +} + +// Mark Gaussians as visible/invisible, based on view frustum testing +void CudaRasterizer::Rasterizer::markVisible( + int P, + float* means3D, + float* viewmatrix, + float* projmatrix, + bool* present) +{ + checkFrustum << <(P + 255) / 256, 256 >> > ( + P, + means3D, + viewmatrix, projmatrix, + present); +} + +CudaRasterizer::GeometryState CudaRasterizer::GeometryState::fromChunk(char*& chunk, size_t P) +{ + GeometryState geom; + obtain(chunk, geom.depths, P, 128); + obtain(chunk, geom.clamped, P * 3, 128); + obtain(chunk, geom.internal_radii, P, 128); + obtain(chunk, geom.means2D, P, 128); + obtain(chunk, geom.cov3D, P * 6, 128); + obtain(chunk, geom.conic_opacity, P, 128); + obtain(chunk, geom.rgb, P * 3, 128); + obtain(chunk, geom.tiles_touched, P, 128); + cub::DeviceScan::InclusiveSum(nullptr, geom.scan_size, geom.tiles_touched, geom.tiles_touched, P); + obtain(chunk, geom.scanning_space, geom.scan_size, 128); + obtain(chunk, geom.point_offsets, P, 128); + return geom; +} + +CudaRasterizer::ImageState CudaRasterizer::ImageState::fromChunk(char*& chunk, size_t N) +{ + ImageState img; + obtain(chunk, img.n_contrib, N, 128); + obtain(chunk, img.ranges, N, 128); + return img; +} + +CudaRasterizer::BinningState CudaRasterizer::BinningState::fromChunk(char*& chunk, size_t P) +{ + BinningState binning; + obtain(chunk, binning.point_list, P, 128); + obtain(chunk, binning.point_list_unsorted, P, 128); + obtain(chunk, binning.point_list_keys, P, 128); + obtain(chunk, binning.point_list_keys_unsorted, P, 128); + cub::DeviceRadixSort::SortPairs( + nullptr, binning.sorting_size, + binning.point_list_keys_unsorted, binning.point_list_keys, + binning.point_list_unsorted, binning.point_list, P); + obtain(chunk, binning.list_sorting_space, binning.sorting_size, 128); + return binning; +} + +// Forward rendering procedure for differentiable rasterization +// of Gaussians. +int CudaRasterizer::Rasterizer::forward( + std::function geometryBuffer, + std::function binningBuffer, + std::function imageBuffer, + const int P, int D, int M, + const float* background, + const int width, int height, + const float* means3D, + const float* shs, + const float* colors_precomp, + const float* opacities, + const float* scales, + const float scale_modifier, + const float* rotations, + const float* cov3D_precomp, + const float* viewmatrix, + const float* projmatrix, + const float* cam_pos, + const float tan_fovx, float tan_fovy, + const bool prefiltered, + float* out_color, + float* out_depth, + float* out_alpha, + int* radii, + bool debug) +{ + const float focal_y = height / (2.0f * tan_fovy); + const float focal_x = width / (2.0f * tan_fovx); + + size_t chunk_size = required(P); + char* chunkptr = geometryBuffer(chunk_size); + GeometryState geomState = GeometryState::fromChunk(chunkptr, P); + + if (radii == nullptr) + { + radii = geomState.internal_radii; + } + + dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1); + dim3 block(BLOCK_X, BLOCK_Y, 1); + + // Dynamically resize image-based auxiliary buffers during training + size_t img_chunk_size = required(width * height); + char* img_chunkptr = imageBuffer(img_chunk_size); + ImageState imgState = ImageState::fromChunk(img_chunkptr, width * height); + + if (NUM_CHANNELS != 3 && colors_precomp == nullptr) + { + throw std::runtime_error("For non-RGB, provide precomputed Gaussian colors!"); + } + + // Run preprocessing per-Gaussian (transformation, bounding, conversion of SHs to RGB) + CHECK_CUDA(FORWARD::preprocess( + P, D, M, + means3D, + (glm::vec3*)scales, + scale_modifier, + (glm::vec4*)rotations, + opacities, + shs, + geomState.clamped, + cov3D_precomp, + colors_precomp, + viewmatrix, projmatrix, + (glm::vec3*)cam_pos, + width, height, + focal_x, focal_y, + tan_fovx, tan_fovy, + radii, + geomState.means2D, + geomState.depths, + geomState.cov3D, + geomState.rgb, + geomState.conic_opacity, + tile_grid, + geomState.tiles_touched, + prefiltered + ), debug) + + // Compute prefix sum over full list of touched tile counts by Gaussians + // E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8] + CHECK_CUDA(cub::DeviceScan::InclusiveSum(geomState.scanning_space, geomState.scan_size, geomState.tiles_touched, geomState.point_offsets, P), debug) + + // Retrieve total number of Gaussian instances to launch and resize aux buffers + int num_rendered; + CHECK_CUDA(cudaMemcpy(&num_rendered, geomState.point_offsets + P - 1, sizeof(int), cudaMemcpyDeviceToHost), debug); + + size_t binning_chunk_size = required(num_rendered); + char* binning_chunkptr = binningBuffer(binning_chunk_size); + BinningState binningState = BinningState::fromChunk(binning_chunkptr, num_rendered); + + // For each instance to be rendered, produce adequate [ tile | depth ] key + // and corresponding dublicated Gaussian indices to be sorted + duplicateWithKeys << <(P + 255) / 256, 256 >> > ( + P, + geomState.means2D, + geomState.depths, + geomState.point_offsets, + binningState.point_list_keys_unsorted, + binningState.point_list_unsorted, + radii, + tile_grid) + CHECK_CUDA(, debug) + + int bit = getHigherMsb(tile_grid.x * tile_grid.y); + + // Sort complete list of (duplicated) Gaussian indices by keys + CHECK_CUDA(cub::DeviceRadixSort::SortPairs( + binningState.list_sorting_space, + binningState.sorting_size, + binningState.point_list_keys_unsorted, binningState.point_list_keys, + binningState.point_list_unsorted, binningState.point_list, + num_rendered, 0, 32 + bit), debug) + + CHECK_CUDA(cudaMemset(imgState.ranges, 0, tile_grid.x * tile_grid.y * sizeof(uint2)), debug); + + // Identify start and end of per-tile workloads in sorted list + if (num_rendered > 0) + identifyTileRanges << <(num_rendered + 255) / 256, 256 >> > ( + num_rendered, + binningState.point_list_keys, + imgState.ranges); + CHECK_CUDA(, debug); + + // Let each tile blend its range of Gaussians independently in parallel + const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : geomState.rgb; + CHECK_CUDA(FORWARD::render( + tile_grid, block, + imgState.ranges, + binningState.point_list, + width, height, + geomState.means2D, + feature_ptr, + geomState.depths, + geomState.conic_opacity, + out_alpha, + imgState.n_contrib, + background, + out_color, + out_depth), debug); + + return num_rendered; +} + +// Produce necessary gradients for optimization, corresponding +// to forward render pass +void CudaRasterizer::Rasterizer::backward( + const int P, int D, int M, int R, + const float* background, + const int width, int height, + const float* means3D, + const float* shs, + const float* colors_precomp, + const float* alphas, + const float* scales, + const float scale_modifier, + const float* rotations, + const float* cov3D_precomp, + const float* viewmatrix, + const float* projmatrix, + const float* campos, + const float tan_fovx, float tan_fovy, + const int* radii, + char* geom_buffer, + char* binning_buffer, + char* img_buffer, + const float* dL_dpix, + const float* dL_dpix_depth, + const float* dL_dalphas, + float* dL_dmean2D, + float* dL_dconic, + float* dL_dopacity, + float* dL_dcolor, + float* dL_ddepth, + float* dL_dmean3D, + float* dL_dcov3D, + float* dL_dsh, + float* dL_dscale, + float* dL_drot, + bool debug) +{ + GeometryState geomState = GeometryState::fromChunk(geom_buffer, P); + BinningState binningState = BinningState::fromChunk(binning_buffer, R); + ImageState imgState = ImageState::fromChunk(img_buffer, width * height); + + if (radii == nullptr) + { + radii = geomState.internal_radii; + } + + const float focal_y = height / (2.0f * tan_fovy); + const float focal_x = width / (2.0f * tan_fovx); + + const dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1); + const dim3 block(BLOCK_X, BLOCK_Y, 1); + + // Compute loss gradients w.r.t. 2D mean position, conic matrix, + // opacity and RGB of Gaussians from per-pixel loss gradients. + // If we were given precomputed colors and not SHs, use them. + const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : geomState.rgb; + const float* depth_ptr = geomState.depths; + CHECK_CUDA(BACKWARD::render( + tile_grid, + block, + imgState.ranges, + binningState.point_list, + width, height, + background, + geomState.means2D, + geomState.conic_opacity, + color_ptr, + depth_ptr, + alphas, + imgState.n_contrib, + dL_dpix, + dL_dpix_depth, + dL_dalphas, + (float3*)dL_dmean2D, + (float4*)dL_dconic, + dL_dopacity, + dL_dcolor, + dL_ddepth), debug) + + // Take care of the rest of preprocessing. Was the precomputed covariance + // given to us or a scales/rot pair? If precomputed, pass that. If not, + // use the one we computed ourselves. + const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : geomState.cov3D; + CHECK_CUDA(BACKWARD::preprocess(P, D, M, + (float3*)means3D, + radii, + shs, + geomState.clamped, + (glm::vec3*)scales, + (glm::vec4*)rotations, + scale_modifier, + cov3D_ptr, + viewmatrix, + projmatrix, + focal_x, focal_y, + tan_fovx, tan_fovy, + (glm::vec3*)campos, + (float3*)dL_dmean2D, + dL_dconic, + (glm::vec3*)dL_dmean3D, + dL_dcolor, + dL_ddepth, + dL_dcov3D, + dL_dsh, + (glm::vec3*)dL_dscale, + (glm::vec4*)dL_drot), debug) +} \ No newline at end of file diff --git a/submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h b/submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..ff2d117ff3e71443cd9b349c908b641ead687b71 --- /dev/null +++ b/submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#pragma once + +#include +#include +#include "rasterizer.h" +#include + +namespace CudaRasterizer +{ + template + static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment) + { + std::size_t offset = (reinterpret_cast(chunk) + alignment - 1) & ~(alignment - 1); + ptr = reinterpret_cast(offset); + chunk = reinterpret_cast(ptr + count); + } + + struct GeometryState + { + size_t scan_size; + float* depths; + char* scanning_space; + bool* clamped; + int* internal_radii; + float2* means2D; + float* cov3D; + float4* conic_opacity; + float* rgb; + uint32_t* point_offsets; + uint32_t* tiles_touched; + + static GeometryState fromChunk(char*& chunk, size_t P); + }; + + struct ImageState + { + uint2* ranges; + uint32_t* n_contrib; + + static ImageState fromChunk(char*& chunk, size_t N); + }; + + struct BinningState + { + size_t sorting_size; + uint64_t* point_list_keys_unsorted; + uint64_t* point_list_keys; + uint32_t* point_list_unsorted; + uint32_t* point_list; + char* list_sorting_space; + + static BinningState fromChunk(char*& chunk, size_t P); + }; + + template + size_t required(size_t P) + { + char* size = nullptr; + T::fromChunk(size, P); + return ((size_t)size) + 128; + } +}; \ No newline at end of file diff --git a/submodules/diff-gaussian-rasterization/diff_gaussian_rasterization/__init__.py b/submodules/diff-gaussian-rasterization/diff_gaussian_rasterization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81b1ecb17e9824213a5bc5ab6bed3688fbf73418 --- /dev/null +++ b/submodules/diff-gaussian-rasterization/diff_gaussian_rasterization/__init__.py @@ -0,0 +1,224 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from typing import NamedTuple +import torch.nn as nn +import torch +from . import _C + +def cpu_deep_copy_tuple(input_tuple): + copied_tensors = [item.cpu().clone() if isinstance(item, torch.Tensor) else item for item in input_tuple] + return tuple(copied_tensors) + +def rasterize_gaussians( + means3D, + means2D, + sh, + colors_precomp, + opacities, + scales, + rotations, + cov3Ds_precomp, + raster_settings, +): + return _RasterizeGaussians.apply( + means3D, + means2D, + sh, + colors_precomp, + opacities, + scales, + rotations, + cov3Ds_precomp, + raster_settings, + ) + +class _RasterizeGaussians(torch.autograd.Function): + @staticmethod + def forward( + ctx, + means3D, + means2D, + sh, + colors_precomp, + opacities, + scales, + rotations, + cov3Ds_precomp, + raster_settings, + ): + + # Restructure arguments the way that the C++ lib expects them + args = ( + raster_settings.bg, + means3D, + colors_precomp, + opacities, + scales, + rotations, + raster_settings.scale_modifier, + cov3Ds_precomp, + raster_settings.viewmatrix, + raster_settings.projmatrix, + raster_settings.tanfovx, + raster_settings.tanfovy, + raster_settings.image_height, + raster_settings.image_width, + sh, + raster_settings.sh_degree, + raster_settings.campos, + raster_settings.prefiltered, + raster_settings.debug + ) + + # Invoke C++/CUDA rasterizer + if raster_settings.debug: + cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted + try: + num_rendered, color, depth, alpha, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) + except Exception as ex: + torch.save(cpu_args, "snapshot_fw.dump") + print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.") + raise ex + else: + num_rendered, color, depth, alpha, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) + + # Keep relevant tensors for backward + ctx.raster_settings = raster_settings + ctx.num_rendered = num_rendered + ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer, alpha) + return color, radii, depth, alpha + + @staticmethod + def backward(ctx, grad_color, grad_radii, grad_depth, grad_alpha): + + # Restore necessary values from context + num_rendered = ctx.num_rendered + raster_settings = ctx.raster_settings + colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer, alpha = ctx.saved_tensors + + # Restructure args as C++ method expects them + args = (raster_settings.bg, + means3D, + radii, + colors_precomp, + scales, + rotations, + raster_settings.scale_modifier, + cov3Ds_precomp, + raster_settings.viewmatrix, + raster_settings.projmatrix, + raster_settings.tanfovx, + raster_settings.tanfovy, + grad_color, + grad_depth, + grad_alpha, + sh, + raster_settings.sh_degree, + raster_settings.campos, + geomBuffer, + num_rendered, + binningBuffer, + imgBuffer, + alpha, + raster_settings.debug) + + # Compute gradients for relevant tensors by invoking backward method + if raster_settings.debug: + cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted + try: + grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) + except Exception as ex: + torch.save(cpu_args, "snapshot_bw.dump") + print("\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n") + raise ex + else: + grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) + + grads = ( + grad_means3D, + grad_means2D, + grad_sh, + grad_colors_precomp, + grad_opacities, + grad_scales, + grad_rotations, + grad_cov3Ds_precomp, + None, + ) + + return grads + +class GaussianRasterizationSettings(NamedTuple): + image_height: int + image_width: int + tanfovx : float + tanfovy : float + bg : torch.Tensor + scale_modifier : float + viewmatrix : torch.Tensor + projmatrix : torch.Tensor + sh_degree : int + campos : torch.Tensor + prefiltered : bool + debug : bool + +class GaussianRasterizer(nn.Module): + def __init__(self, raster_settings): + super().__init__() + self.raster_settings = raster_settings + + def markVisible(self, positions): + # Mark visible points (based on frustum culling for camera) with a boolean + with torch.no_grad(): + raster_settings = self.raster_settings + visible = _C.mark_visible( + positions, + raster_settings.viewmatrix, + raster_settings.projmatrix) + + return visible + + def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None): + + raster_settings = self.raster_settings + + if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None): + raise Exception('Please provide excatly one of either SHs or precomputed colors!') + + if ((scales is None or rotations is None) and cov3D_precomp is None) or ((scales is not None or rotations is not None) and cov3D_precomp is not None): + raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!') + + if shs is None: + shs = torch.Tensor([]) + if colors_precomp is None: + colors_precomp = torch.Tensor([]) + + if scales is None: + scales = torch.Tensor([]) + if rotations is None: + rotations = torch.Tensor([]) + if cov3D_precomp is None: + cov3D_precomp = torch.Tensor([]) + + # Invoke C++/CUDA rasterization routine + return rasterize_gaussians( + means3D, + means2D, + shs, + colors_precomp, + opacities, + scales, + rotations, + cov3D_precomp, + raster_settings, + ) + diff --git a/submodules/diff-gaussian-rasterization/ext.cpp b/submodules/diff-gaussian-rasterization/ext.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d768779579761238347972a973fbd1603d44235e --- /dev/null +++ b/submodules/diff-gaussian-rasterization/ext.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#include +#include "rasterize_points.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("rasterize_gaussians", &RasterizeGaussiansCUDA); + m.def("rasterize_gaussians_backward", &RasterizeGaussiansBackwardCUDA); + m.def("mark_visible", &markVisible); +} \ No newline at end of file diff --git a/submodules/diff-gaussian-rasterization/rasterize_points.cu b/submodules/diff-gaussian-rasterization/rasterize_points.cu new file mode 100644 index 0000000000000000000000000000000000000000..ee9f8282fcdd261227e0fb97d3c487f1406d98c8 --- /dev/null +++ b/submodules/diff-gaussian-rasterization/rasterize_points.cu @@ -0,0 +1,229 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "cuda_rasterizer/config.h" +#include "cuda_rasterizer/rasterizer.h" +#include +#include +#include + +std::function resizeFunctional(torch::Tensor& t) { + auto lambda = [&t](size_t N) { + t.resize_({(long long)N}); + return reinterpret_cast(t.contiguous().data_ptr()); + }; + return lambda; +} + +std::tuple +RasterizeGaussiansCUDA( + const torch::Tensor& background, + const torch::Tensor& means3D, + const torch::Tensor& colors, + const torch::Tensor& opacity, + const torch::Tensor& scales, + const torch::Tensor& rotations, + const float scale_modifier, + const torch::Tensor& cov3D_precomp, + const torch::Tensor& viewmatrix, + const torch::Tensor& projmatrix, + const float tan_fovx, + const float tan_fovy, + const int image_height, + const int image_width, + const torch::Tensor& sh, + const int degree, + const torch::Tensor& campos, + const bool prefiltered, + const bool debug) +{ + if (means3D.ndimension() != 2 || means3D.size(1) != 3) { + AT_ERROR("means3D must have dimensions (num_points, 3)"); + } + + const int P = means3D.size(0); + const int H = image_height; + const int W = image_width; + + auto int_opts = means3D.options().dtype(torch::kInt32); + auto float_opts = means3D.options().dtype(torch::kFloat32); + + torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts); + torch::Tensor out_depth = torch::full({1, H, W}, 0.0, float_opts); + torch::Tensor out_alpha = torch::full({1, H, W}, 0.0, float_opts); + torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32)); + + torch::Device device(torch::kCUDA); + torch::TensorOptions options(torch::kByte); + torch::Tensor geomBuffer = torch::empty({0}, options.device(device)); + torch::Tensor binningBuffer = torch::empty({0}, options.device(device)); + torch::Tensor imgBuffer = torch::empty({0}, options.device(device)); + std::function geomFunc = resizeFunctional(geomBuffer); + std::function binningFunc = resizeFunctional(binningBuffer); + std::function imgFunc = resizeFunctional(imgBuffer); + + int rendered = 0; + if(P != 0) + { + int M = 0; + if(sh.size(0) != 0) + { + M = sh.size(1); + } + + rendered = CudaRasterizer::Rasterizer::forward( + geomFunc, + binningFunc, + imgFunc, + P, degree, M, + background.contiguous().data(), + W, H, + means3D.contiguous().data(), + sh.contiguous().data_ptr(), + colors.contiguous().data(), + opacity.contiguous().data(), + scales.contiguous().data_ptr(), + scale_modifier, + rotations.contiguous().data_ptr(), + cov3D_precomp.contiguous().data(), + viewmatrix.contiguous().data(), + projmatrix.contiguous().data(), + campos.contiguous().data(), + tan_fovx, + tan_fovy, + prefiltered, + out_color.contiguous().data(), + out_depth.contiguous().data(), + out_alpha.contiguous().data(), + radii.contiguous().data(), + debug); + } + return std::make_tuple(rendered, out_color, out_depth, out_alpha, radii, geomBuffer, binningBuffer, imgBuffer); +} + +std::tuple + RasterizeGaussiansBackwardCUDA( + const torch::Tensor& background, + const torch::Tensor& means3D, + const torch::Tensor& radii, + const torch::Tensor& colors, + const torch::Tensor& scales, + const torch::Tensor& rotations, + const float scale_modifier, + const torch::Tensor& cov3D_precomp, + const torch::Tensor& viewmatrix, + const torch::Tensor& projmatrix, + const float tan_fovx, + const float tan_fovy, + const torch::Tensor& dL_dout_color, + const torch::Tensor& dL_dout_depth, + const torch::Tensor& dL_dout_alpha, + const torch::Tensor& sh, + const int degree, + const torch::Tensor& campos, + const torch::Tensor& geomBuffer, + const int R, + const torch::Tensor& binningBuffer, + const torch::Tensor& imageBuffer, + const torch::Tensor& alphas, + const bool debug) +{ + const int P = means3D.size(0); + const int H = dL_dout_color.size(1); + const int W = dL_dout_color.size(2); + + int M = 0; + if(sh.size(0) != 0) + { + M = sh.size(1); + } + + torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, means3D.options()); + torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options()); + torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options()); + torch::Tensor dL_ddepths = torch::zeros({P, 1}, means3D.options()); + torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, means3D.options()); + torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options()); + torch::Tensor dL_dcov3D = torch::zeros({P, 6}, means3D.options()); + torch::Tensor dL_dsh = torch::zeros({P, M, 3}, means3D.options()); + torch::Tensor dL_dscales = torch::zeros({P, 3}, means3D.options()); + torch::Tensor dL_drotations = torch::zeros({P, 4}, means3D.options()); + + if(P != 0) + { + CudaRasterizer::Rasterizer::backward(P, degree, M, R, + background.contiguous().data(), + W, H, + means3D.contiguous().data(), + sh.contiguous().data(), + colors.contiguous().data(), + alphas.contiguous().data(), + scales.data_ptr(), + scale_modifier, + rotations.data_ptr(), + cov3D_precomp.contiguous().data(), + viewmatrix.contiguous().data(), + projmatrix.contiguous().data(), + campos.contiguous().data(), + tan_fovx, + tan_fovy, + radii.contiguous().data(), + reinterpret_cast(geomBuffer.contiguous().data_ptr()), + reinterpret_cast(binningBuffer.contiguous().data_ptr()), + reinterpret_cast(imageBuffer.contiguous().data_ptr()), + dL_dout_color.contiguous().data(), + dL_dout_depth.contiguous().data(), + dL_dout_alpha.contiguous().data(), + dL_dmeans2D.contiguous().data(), + dL_dconic.contiguous().data(), + dL_dopacity.contiguous().data(), + dL_dcolors.contiguous().data(), + dL_ddepths.contiguous().data(), + dL_dmeans3D.contiguous().data(), + dL_dcov3D.contiguous().data(), + dL_dsh.contiguous().data(), + dL_dscales.contiguous().data(), + dL_drotations.contiguous().data(), + debug); + } + + return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dcov3D, dL_dsh, dL_dscales, dL_drotations); +} + +torch::Tensor markVisible( + torch::Tensor& means3D, + torch::Tensor& viewmatrix, + torch::Tensor& projmatrix) +{ + const int P = means3D.size(0); + + torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool)); + + if(P != 0) + { + CudaRasterizer::Rasterizer::markVisible(P, + means3D.contiguous().data(), + viewmatrix.contiguous().data(), + projmatrix.contiguous().data(), + present.contiguous().data()); + } + + return present; +} \ No newline at end of file diff --git a/submodules/diff-gaussian-rasterization/rasterize_points.h b/submodules/diff-gaussian-rasterization/rasterize_points.h new file mode 100644 index 0000000000000000000000000000000000000000..2fc5e72b5ae6170afe6a510db0e32e00d989313e --- /dev/null +++ b/submodules/diff-gaussian-rasterization/rasterize_points.h @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#pragma once +#include +#include +#include +#include + +std::tuple +RasterizeGaussiansCUDA( + const torch::Tensor& background, + const torch::Tensor& means3D, + const torch::Tensor& colors, + const torch::Tensor& opacity, + const torch::Tensor& scales, + const torch::Tensor& rotations, + const float scale_modifier, + const torch::Tensor& cov3D_precomp, + const torch::Tensor& viewmatrix, + const torch::Tensor& projmatrix, + const float tan_fovx, + const float tan_fovy, + const int image_height, + const int image_width, + const torch::Tensor& sh, + const int degree, + const torch::Tensor& campos, + const bool prefiltered, + const bool debug); + +std::tuple + RasterizeGaussiansBackwardCUDA( + const torch::Tensor& background, + const torch::Tensor& means3D, + const torch::Tensor& radii, + const torch::Tensor& colors, + const torch::Tensor& scales, + const torch::Tensor& rotations, + const float scale_modifier, + const torch::Tensor& cov3D_precomp, + const torch::Tensor& viewmatrix, + const torch::Tensor& projmatrix, + const float tan_fovx, + const float tan_fovy, + const torch::Tensor& dL_dout_color, + const torch::Tensor& dL_dout_depth, + const torch::Tensor& dL_dout_alpha, + const torch::Tensor& sh, + const int degree, + const torch::Tensor& campos, + const torch::Tensor& geomBuffer, + const int R, + const torch::Tensor& binningBuffer, + const torch::Tensor& imageBuffer, + const torch::Tensor& alpha, + const bool debug); + +torch::Tensor markVisible( + torch::Tensor& means3D, + torch::Tensor& viewmatrix, + torch::Tensor& projmatrix); \ No newline at end of file diff --git a/submodules/diff-gaussian-rasterization/setup.py b/submodules/diff-gaussian-rasterization/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..bb7220d2934d006ea756e35ecb0f391403b43d64 --- /dev/null +++ b/submodules/diff-gaussian-rasterization/setup.py @@ -0,0 +1,34 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from setuptools import setup +from torch.utils.cpp_extension import CUDAExtension, BuildExtension +import os +os.path.dirname(os.path.abspath(__file__)) + +setup( + name="diff_gaussian_rasterization", + packages=['diff_gaussian_rasterization'], + ext_modules=[ + CUDAExtension( + name="diff_gaussian_rasterization._C", + sources=[ + "cuda_rasterizer/rasterizer_impl.cu", + "cuda_rasterizer/forward.cu", + "cuda_rasterizer/backward.cu", + "rasterize_points.cu", + "ext.cpp"], + extra_compile_args={"nvcc": ["-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/")]}) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/submodules/diff-gaussian-rasterization/third_party/stbi_image_write.h b/submodules/diff-gaussian-rasterization/third_party/stbi_image_write.h new file mode 100644 index 0000000000000000000000000000000000000000..023d71e460f4f2d8265d7519a9e3a2c1858fadb0 --- /dev/null +++ b/submodules/diff-gaussian-rasterization/third_party/stbi_image_write.h @@ -0,0 +1,1724 @@ +/* stb_image_write - v1.16 - public domain - http://nothings.org/stb + writes out PNG/BMP/TGA/JPEG/HDR images to C stdio - Sean Barrett 2010-2015 + no warranty implied; use at your own risk + + Before #including, + + #define STB_IMAGE_WRITE_IMPLEMENTATION + + in the file that you want to have the implementation. + + Will probably not work correctly with strict-aliasing optimizations. + +ABOUT: + + This header file is a library for writing images to C stdio or a callback. + + The PNG output is not optimal; it is 20-50% larger than the file + written by a decent optimizing implementation; though providing a custom + zlib compress function (see STBIW_ZLIB_COMPRESS) can mitigate that. + This library is designed for source code compactness and simplicity, + not optimal image file size or run-time performance. + +BUILDING: + + You can #define STBIW_ASSERT(x) before the #include to avoid using assert.h. + You can #define STBIW_MALLOC(), STBIW_REALLOC(), and STBIW_FREE() to replace + malloc,realloc,free. + You can #define STBIW_MEMMOVE() to replace memmove() + You can #define STBIW_ZLIB_COMPRESS to use a custom zlib-style compress function + for PNG compression (instead of the builtin one), it must have the following signature: + unsigned char * my_compress(unsigned char *data, int data_len, int *out_len, int quality); + The returned data will be freed with STBIW_FREE() (free() by default), + so it must be heap allocated with STBIW_MALLOC() (malloc() by default), + +UNICODE: + + If compiling for Windows and you wish to use Unicode filenames, compile + with + #define STBIW_WINDOWS_UTF8 + and pass utf8-encoded filenames. Call stbiw_convert_wchar_to_utf8 to convert + Windows wchar_t filenames to utf8. + +USAGE: + + There are five functions, one for each image file format: + + int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes); + int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data); + int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data); + int stbi_write_jpg(char const *filename, int w, int h, int comp, const void *data, int quality); + int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data); + + void stbi_flip_vertically_on_write(int flag); // flag is non-zero to flip data vertically + + There are also five equivalent functions that use an arbitrary write function. You are + expected to open/close your file-equivalent before and after calling these: + + int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data, int stride_in_bytes); + int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data); + int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data); + int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data); + int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality); + + where the callback is: + void stbi_write_func(void *context, void *data, int size); + + You can configure it with these global variables: + int stbi_write_tga_with_rle; // defaults to true; set to 0 to disable RLE + int stbi_write_png_compression_level; // defaults to 8; set to higher for more compression + int stbi_write_force_png_filter; // defaults to -1; set to 0..5 to force a filter mode + + + You can define STBI_WRITE_NO_STDIO to disable the file variant of these + functions, so the library will not use stdio.h at all. However, this will + also disable HDR writing, because it requires stdio for formatted output. + + Each function returns 0 on failure and non-0 on success. + + The functions create an image file defined by the parameters. The image + is a rectangle of pixels stored from left-to-right, top-to-bottom. + Each pixel contains 'comp' channels of data stored interleaved with 8-bits + per channel, in the following order: 1=Y, 2=YA, 3=RGB, 4=RGBA. (Y is + monochrome color.) The rectangle is 'w' pixels wide and 'h' pixels tall. + The *data pointer points to the first byte of the top-left-most pixel. + For PNG, "stride_in_bytes" is the distance in bytes from the first byte of + a row of pixels to the first byte of the next row of pixels. + + PNG creates output files with the same number of components as the input. + The BMP format expands Y to RGB in the file format and does not + output alpha. + + PNG supports writing rectangles of data even when the bytes storing rows of + data are not consecutive in memory (e.g. sub-rectangles of a larger image), + by supplying the stride between the beginning of adjacent rows. The other + formats do not. (Thus you cannot write a native-format BMP through the BMP + writer, both because it is in BGR order and because it may have padding + at the end of the line.) + + PNG allows you to set the deflate compression level by setting the global + variable 'stbi_write_png_compression_level' (it defaults to 8). + + HDR expects linear float data. Since the format is always 32-bit rgb(e) + data, alpha (if provided) is discarded, and for monochrome data it is + replicated across all three channels. + + TGA supports RLE or non-RLE compressed data. To use non-RLE-compressed + data, set the global variable 'stbi_write_tga_with_rle' to 0. + + JPEG does ignore alpha channels in input data; quality is between 1 and 100. + Higher quality looks better but results in a bigger image. + JPEG baseline (no JPEG progressive). + +CREDITS: + + + Sean Barrett - PNG/BMP/TGA + Baldur Karlsson - HDR + Jean-Sebastien Guay - TGA monochrome + Tim Kelsey - misc enhancements + Alan Hickman - TGA RLE + Emmanuel Julien - initial file IO callback implementation + Jon Olick - original jo_jpeg.cpp code + Daniel Gibson - integrate JPEG, allow external zlib + Aarni Koskela - allow choosing PNG filter + + bugfixes: + github:Chribba + Guillaume Chereau + github:jry2 + github:romigrou + Sergio Gonzalez + Jonas Karlsson + Filip Wasil + Thatcher Ulrich + github:poppolopoppo + Patrick Boettcher + github:xeekworx + Cap Petschulat + Simon Rodriguez + Ivan Tikhonov + github:ignotion + Adam Schackart + Andrew Kensler + +LICENSE + + See end of file for license information. + +*/ + +#ifndef INCLUDE_STB_IMAGE_WRITE_H +#define INCLUDE_STB_IMAGE_WRITE_H + +#include + +// if STB_IMAGE_WRITE_STATIC causes problems, try defining STBIWDEF to 'inline' or 'static inline' +#ifndef STBIWDEF +#ifdef STB_IMAGE_WRITE_STATIC +#define STBIWDEF static +#else +#ifdef __cplusplus +#define STBIWDEF extern "C" +#else +#define STBIWDEF extern +#endif +#endif +#endif + +#ifndef STB_IMAGE_WRITE_STATIC // C++ forbids static forward declarations +STBIWDEF int stbi_write_tga_with_rle; +STBIWDEF int stbi_write_png_compression_level; +STBIWDEF int stbi_write_force_png_filter; +#endif + +#ifndef STBI_WRITE_NO_STDIO +STBIWDEF int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes); +STBIWDEF int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data); +STBIWDEF int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data); +STBIWDEF int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data); +STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality); + +#ifdef STBIW_WINDOWS_UTF8 +STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input); +#endif +#endif + +typedef void stbi_write_func(void *context, void *data, int size); + +STBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data, int stride_in_bytes); +STBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data); +STBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data); +STBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data); +STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality); + +STBIWDEF void stbi_flip_vertically_on_write(int flip_boolean); + +#endif//INCLUDE_STB_IMAGE_WRITE_H + +#ifdef STB_IMAGE_WRITE_IMPLEMENTATION + +#ifdef _WIN32 + #ifndef _CRT_SECURE_NO_WARNINGS + #define _CRT_SECURE_NO_WARNINGS + #endif + #ifndef _CRT_NONSTDC_NO_DEPRECATE + #define _CRT_NONSTDC_NO_DEPRECATE + #endif +#endif + +#ifndef STBI_WRITE_NO_STDIO +#include +#endif // STBI_WRITE_NO_STDIO + +#include +#include +#include +#include + +#if defined(STBIW_MALLOC) && defined(STBIW_FREE) && (defined(STBIW_REALLOC) || defined(STBIW_REALLOC_SIZED)) +// ok +#elif !defined(STBIW_MALLOC) && !defined(STBIW_FREE) && !defined(STBIW_REALLOC) && !defined(STBIW_REALLOC_SIZED) +// ok +#else +#error "Must define all or none of STBIW_MALLOC, STBIW_FREE, and STBIW_REALLOC (or STBIW_REALLOC_SIZED)." +#endif + +#ifndef STBIW_MALLOC +#define STBIW_MALLOC(sz) malloc(sz) +#define STBIW_REALLOC(p,newsz) realloc(p,newsz) +#define STBIW_FREE(p) free(p) +#endif + +#ifndef STBIW_REALLOC_SIZED +#define STBIW_REALLOC_SIZED(p,oldsz,newsz) STBIW_REALLOC(p,newsz) +#endif + + +#ifndef STBIW_MEMMOVE +#define STBIW_MEMMOVE(a,b,sz) memmove(a,b,sz) +#endif + + +#ifndef STBIW_ASSERT +#include +#define STBIW_ASSERT(x) assert(x) +#endif + +#define STBIW_UCHAR(x) (unsigned char) ((x) & 0xff) + +#ifdef STB_IMAGE_WRITE_STATIC +static int stbi_write_png_compression_level = 8; +static int stbi_write_tga_with_rle = 1; +static int stbi_write_force_png_filter = -1; +#else +int stbi_write_png_compression_level = 8; +int stbi_write_tga_with_rle = 1; +int stbi_write_force_png_filter = -1; +#endif + +static int stbi__flip_vertically_on_write = 0; + +STBIWDEF void stbi_flip_vertically_on_write(int flag) +{ + stbi__flip_vertically_on_write = flag; +} + +typedef struct +{ + stbi_write_func *func; + void *context; + unsigned char buffer[64]; + int buf_used; +} stbi__write_context; + +// initialize a callback-based context +static void stbi__start_write_callbacks(stbi__write_context *s, stbi_write_func *c, void *context) +{ + s->func = c; + s->context = context; +} + +#ifndef STBI_WRITE_NO_STDIO + +static void stbi__stdio_write(void *context, void *data, int size) +{ + fwrite(data,1,size,(FILE*) context); +} + +#if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8) +#ifdef __cplusplus +#define STBIW_EXTERN extern "C" +#else +#define STBIW_EXTERN extern +#endif +STBIW_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide); +STBIW_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char *str, int cbmb, const char *defchar, int *used_default); + +STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input) +{ + return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL); +} +#endif + +static FILE *stbiw__fopen(char const *filename, char const *mode) +{ + FILE *f; +#if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8) + wchar_t wMode[64]; + wchar_t wFilename[1024]; + if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename)/sizeof(*wFilename))) + return 0; + + if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode)/sizeof(*wMode))) + return 0; + +#if defined(_MSC_VER) && _MSC_VER >= 1400 + if (0 != _wfopen_s(&f, wFilename, wMode)) + f = 0; +#else + f = _wfopen(wFilename, wMode); +#endif + +#elif defined(_MSC_VER) && _MSC_VER >= 1400 + if (0 != fopen_s(&f, filename, mode)) + f=0; +#else + f = fopen(filename, mode); +#endif + return f; +} + +static int stbi__start_write_file(stbi__write_context *s, const char *filename) +{ + FILE *f = stbiw__fopen(filename, "wb"); + stbi__start_write_callbacks(s, stbi__stdio_write, (void *) f); + return f != NULL; +} + +static void stbi__end_write_file(stbi__write_context *s) +{ + fclose((FILE *)s->context); +} + +#endif // !STBI_WRITE_NO_STDIO + +typedef unsigned int stbiw_uint32; +typedef int stb_image_write_test[sizeof(stbiw_uint32)==4 ? 1 : -1]; + +static void stbiw__writefv(stbi__write_context *s, const char *fmt, va_list v) +{ + while (*fmt) { + switch (*fmt++) { + case ' ': break; + case '1': { unsigned char x = STBIW_UCHAR(va_arg(v, int)); + s->func(s->context,&x,1); + break; } + case '2': { int x = va_arg(v,int); + unsigned char b[2]; + b[0] = STBIW_UCHAR(x); + b[1] = STBIW_UCHAR(x>>8); + s->func(s->context,b,2); + break; } + case '4': { stbiw_uint32 x = va_arg(v,int); + unsigned char b[4]; + b[0]=STBIW_UCHAR(x); + b[1]=STBIW_UCHAR(x>>8); + b[2]=STBIW_UCHAR(x>>16); + b[3]=STBIW_UCHAR(x>>24); + s->func(s->context,b,4); + break; } + default: + STBIW_ASSERT(0); + return; + } + } +} + +static void stbiw__writef(stbi__write_context *s, const char *fmt, ...) +{ + va_list v; + va_start(v, fmt); + stbiw__writefv(s, fmt, v); + va_end(v); +} + +static void stbiw__write_flush(stbi__write_context *s) +{ + if (s->buf_used) { + s->func(s->context, &s->buffer, s->buf_used); + s->buf_used = 0; + } +} + +static void stbiw__putc(stbi__write_context *s, unsigned char c) +{ + s->func(s->context, &c, 1); +} + +static void stbiw__write1(stbi__write_context *s, unsigned char a) +{ + if ((size_t)s->buf_used + 1 > sizeof(s->buffer)) + stbiw__write_flush(s); + s->buffer[s->buf_used++] = a; +} + +static void stbiw__write3(stbi__write_context *s, unsigned char a, unsigned char b, unsigned char c) +{ + int n; + if ((size_t)s->buf_used + 3 > sizeof(s->buffer)) + stbiw__write_flush(s); + n = s->buf_used; + s->buf_used = n+3; + s->buffer[n+0] = a; + s->buffer[n+1] = b; + s->buffer[n+2] = c; +} + +static void stbiw__write_pixel(stbi__write_context *s, int rgb_dir, int comp, int write_alpha, int expand_mono, unsigned char *d) +{ + unsigned char bg[3] = { 255, 0, 255}, px[3]; + int k; + + if (write_alpha < 0) + stbiw__write1(s, d[comp - 1]); + + switch (comp) { + case 2: // 2 pixels = mono + alpha, alpha is written separately, so same as 1-channel case + case 1: + if (expand_mono) + stbiw__write3(s, d[0], d[0], d[0]); // monochrome bmp + else + stbiw__write1(s, d[0]); // monochrome TGA + break; + case 4: + if (!write_alpha) { + // composite against pink background + for (k = 0; k < 3; ++k) + px[k] = bg[k] + ((d[k] - bg[k]) * d[3]) / 255; + stbiw__write3(s, px[1 - rgb_dir], px[1], px[1 + rgb_dir]); + break; + } + /* FALLTHROUGH */ + case 3: + stbiw__write3(s, d[1 - rgb_dir], d[1], d[1 + rgb_dir]); + break; + } + if (write_alpha > 0) + stbiw__write1(s, d[comp - 1]); +} + +static void stbiw__write_pixels(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, void *data, int write_alpha, int scanline_pad, int expand_mono) +{ + stbiw_uint32 zero = 0; + int i,j, j_end; + + if (y <= 0) + return; + + if (stbi__flip_vertically_on_write) + vdir *= -1; + + if (vdir < 0) { + j_end = -1; j = y-1; + } else { + j_end = y; j = 0; + } + + for (; j != j_end; j += vdir) { + for (i=0; i < x; ++i) { + unsigned char *d = (unsigned char *) data + (j*x+i)*comp; + stbiw__write_pixel(s, rgb_dir, comp, write_alpha, expand_mono, d); + } + stbiw__write_flush(s); + s->func(s->context, &zero, scanline_pad); + } +} + +static int stbiw__outfile(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, int expand_mono, void *data, int alpha, int pad, const char *fmt, ...) +{ + if (y < 0 || x < 0) { + return 0; + } else { + va_list v; + va_start(v, fmt); + stbiw__writefv(s, fmt, v); + va_end(v); + stbiw__write_pixels(s,rgb_dir,vdir,x,y,comp,data,alpha,pad, expand_mono); + return 1; + } +} + +static int stbi_write_bmp_core(stbi__write_context *s, int x, int y, int comp, const void *data) +{ + if (comp != 4) { + // write RGB bitmap + int pad = (-x*3) & 3; + return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *) data,0,pad, + "11 4 22 4" "4 44 22 444444", + 'B', 'M', 14+40+(x*3+pad)*y, 0,0, 14+40, // file header + 40, x,y, 1,24, 0,0,0,0,0,0); // bitmap header + } else { + // RGBA bitmaps need a v4 header + // use BI_BITFIELDS mode with 32bpp and alpha mask + // (straight BI_RGB with alpha mask doesn't work in most readers) + return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *)data,1,0, + "11 4 22 4" "4 44 22 444444 4444 4 444 444 444 444", + 'B', 'M', 14+108+x*y*4, 0, 0, 14+108, // file header + 108, x,y, 1,32, 3,0,0,0,0,0, 0xff0000,0xff00,0xff,0xff000000u, 0, 0,0,0, 0,0,0, 0,0,0, 0,0,0); // bitmap V4 header + } +} + +STBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data) +{ + stbi__write_context s = { 0 }; + stbi__start_write_callbacks(&s, func, context); + return stbi_write_bmp_core(&s, x, y, comp, data); +} + +#ifndef STBI_WRITE_NO_STDIO +STBIWDEF int stbi_write_bmp(char const *filename, int x, int y, int comp, const void *data) +{ + stbi__write_context s = { 0 }; + if (stbi__start_write_file(&s,filename)) { + int r = stbi_write_bmp_core(&s, x, y, comp, data); + stbi__end_write_file(&s); + return r; + } else + return 0; +} +#endif //!STBI_WRITE_NO_STDIO + +static int stbi_write_tga_core(stbi__write_context *s, int x, int y, int comp, void *data) +{ + int has_alpha = (comp == 2 || comp == 4); + int colorbytes = has_alpha ? comp-1 : comp; + int format = colorbytes < 2 ? 3 : 2; // 3 color channels (RGB/RGBA) = 2, 1 color channel (Y/YA) = 3 + + if (y < 0 || x < 0) + return 0; + + if (!stbi_write_tga_with_rle) { + return stbiw__outfile(s, -1, -1, x, y, comp, 0, (void *) data, has_alpha, 0, + "111 221 2222 11", 0, 0, format, 0, 0, 0, 0, 0, x, y, (colorbytes + has_alpha) * 8, has_alpha * 8); + } else { + int i,j,k; + int jend, jdir; + + stbiw__writef(s, "111 221 2222 11", 0,0,format+8, 0,0,0, 0,0,x,y, (colorbytes + has_alpha) * 8, has_alpha * 8); + + if (stbi__flip_vertically_on_write) { + j = 0; + jend = y; + jdir = 1; + } else { + j = y-1; + jend = -1; + jdir = -1; + } + for (; j != jend; j += jdir) { + unsigned char *row = (unsigned char *) data + j * x * comp; + int len; + + for (i = 0; i < x; i += len) { + unsigned char *begin = row + i * comp; + int diff = 1; + len = 1; + + if (i < x - 1) { + ++len; + diff = memcmp(begin, row + (i + 1) * comp, comp); + if (diff) { + const unsigned char *prev = begin; + for (k = i + 2; k < x && len < 128; ++k) { + if (memcmp(prev, row + k * comp, comp)) { + prev += comp; + ++len; + } else { + --len; + break; + } + } + } else { + for (k = i + 2; k < x && len < 128; ++k) { + if (!memcmp(begin, row + k * comp, comp)) { + ++len; + } else { + break; + } + } + } + } + + if (diff) { + unsigned char header = STBIW_UCHAR(len - 1); + stbiw__write1(s, header); + for (k = 0; k < len; ++k) { + stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin + k * comp); + } + } else { + unsigned char header = STBIW_UCHAR(len - 129); + stbiw__write1(s, header); + stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin); + } + } + } + stbiw__write_flush(s); + } + return 1; +} + +STBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data) +{ + stbi__write_context s = { 0 }; + stbi__start_write_callbacks(&s, func, context); + return stbi_write_tga_core(&s, x, y, comp, (void *) data); +} + +#ifndef STBI_WRITE_NO_STDIO +STBIWDEF int stbi_write_tga(char const *filename, int x, int y, int comp, const void *data) +{ + stbi__write_context s = { 0 }; + if (stbi__start_write_file(&s,filename)) { + int r = stbi_write_tga_core(&s, x, y, comp, (void *) data); + stbi__end_write_file(&s); + return r; + } else + return 0; +} +#endif + +// ************************************************************************************************* +// Radiance RGBE HDR writer +// by Baldur Karlsson + +#define stbiw__max(a, b) ((a) > (b) ? (a) : (b)) + +#ifndef STBI_WRITE_NO_STDIO + +static void stbiw__linear_to_rgbe(unsigned char *rgbe, float *linear) +{ + int exponent; + float maxcomp = stbiw__max(linear[0], stbiw__max(linear[1], linear[2])); + + if (maxcomp < 1e-32f) { + rgbe[0] = rgbe[1] = rgbe[2] = rgbe[3] = 0; + } else { + float normalize = (float) frexp(maxcomp, &exponent) * 256.0f/maxcomp; + + rgbe[0] = (unsigned char)(linear[0] * normalize); + rgbe[1] = (unsigned char)(linear[1] * normalize); + rgbe[2] = (unsigned char)(linear[2] * normalize); + rgbe[3] = (unsigned char)(exponent + 128); + } +} + +static void stbiw__write_run_data(stbi__write_context *s, int length, unsigned char databyte) +{ + unsigned char lengthbyte = STBIW_UCHAR(length+128); + STBIW_ASSERT(length+128 <= 255); + s->func(s->context, &lengthbyte, 1); + s->func(s->context, &databyte, 1); +} + +static void stbiw__write_dump_data(stbi__write_context *s, int length, unsigned char *data) +{ + unsigned char lengthbyte = STBIW_UCHAR(length); + STBIW_ASSERT(length <= 128); // inconsistent with spec but consistent with official code + s->func(s->context, &lengthbyte, 1); + s->func(s->context, data, length); +} + +static void stbiw__write_hdr_scanline(stbi__write_context *s, int width, int ncomp, unsigned char *scratch, float *scanline) +{ + unsigned char scanlineheader[4] = { 2, 2, 0, 0 }; + unsigned char rgbe[4]; + float linear[3]; + int x; + + scanlineheader[2] = (width&0xff00)>>8; + scanlineheader[3] = (width&0x00ff); + + /* skip RLE for images too small or large */ + if (width < 8 || width >= 32768) { + for (x=0; x < width; x++) { + switch (ncomp) { + case 4: /* fallthrough */ + case 3: linear[2] = scanline[x*ncomp + 2]; + linear[1] = scanline[x*ncomp + 1]; + linear[0] = scanline[x*ncomp + 0]; + break; + default: + linear[0] = linear[1] = linear[2] = scanline[x*ncomp + 0]; + break; + } + stbiw__linear_to_rgbe(rgbe, linear); + s->func(s->context, rgbe, 4); + } + } else { + int c,r; + /* encode into scratch buffer */ + for (x=0; x < width; x++) { + switch(ncomp) { + case 4: /* fallthrough */ + case 3: linear[2] = scanline[x*ncomp + 2]; + linear[1] = scanline[x*ncomp + 1]; + linear[0] = scanline[x*ncomp + 0]; + break; + default: + linear[0] = linear[1] = linear[2] = scanline[x*ncomp + 0]; + break; + } + stbiw__linear_to_rgbe(rgbe, linear); + scratch[x + width*0] = rgbe[0]; + scratch[x + width*1] = rgbe[1]; + scratch[x + width*2] = rgbe[2]; + scratch[x + width*3] = rgbe[3]; + } + + s->func(s->context, scanlineheader, 4); + + /* RLE each component separately */ + for (c=0; c < 4; c++) { + unsigned char *comp = &scratch[width*c]; + + x = 0; + while (x < width) { + // find first run + r = x; + while (r+2 < width) { + if (comp[r] == comp[r+1] && comp[r] == comp[r+2]) + break; + ++r; + } + if (r+2 >= width) + r = width; + // dump up to first run + while (x < r) { + int len = r-x; + if (len > 128) len = 128; + stbiw__write_dump_data(s, len, &comp[x]); + x += len; + } + // if there's a run, output it + if (r+2 < width) { // same test as what we break out of in search loop, so only true if we break'd + // find next byte after run + while (r < width && comp[r] == comp[x]) + ++r; + // output run up to r + while (x < r) { + int len = r-x; + if (len > 127) len = 127; + stbiw__write_run_data(s, len, comp[x]); + x += len; + } + } + } + } + } +} + +static int stbi_write_hdr_core(stbi__write_context *s, int x, int y, int comp, float *data) +{ + if (y <= 0 || x <= 0 || data == NULL) + return 0; + else { + // Each component is stored separately. Allocate scratch space for full output scanline. + unsigned char *scratch = (unsigned char *) STBIW_MALLOC(x*4); + int i, len; + char buffer[128]; + char header[] = "#?RADIANCE\n# Written by stb_image_write.h\nFORMAT=32-bit_rle_rgbe\n"; + s->func(s->context, header, sizeof(header)-1); + +#ifdef __STDC_LIB_EXT1__ + len = sprintf_s(buffer, sizeof(buffer), "EXPOSURE= 1.0000000000000\n\n-Y %d +X %d\n", y, x); +#else + len = sprintf(buffer, "EXPOSURE= 1.0000000000000\n\n-Y %d +X %d\n", y, x); +#endif + s->func(s->context, buffer, len); + + for(i=0; i < y; i++) + stbiw__write_hdr_scanline(s, x, comp, scratch, data + comp*x*(stbi__flip_vertically_on_write ? y-1-i : i)); + STBIW_FREE(scratch); + return 1; + } +} + +STBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const float *data) +{ + stbi__write_context s = { 0 }; + stbi__start_write_callbacks(&s, func, context); + return stbi_write_hdr_core(&s, x, y, comp, (float *) data); +} + +STBIWDEF int stbi_write_hdr(char const *filename, int x, int y, int comp, const float *data) +{ + stbi__write_context s = { 0 }; + if (stbi__start_write_file(&s,filename)) { + int r = stbi_write_hdr_core(&s, x, y, comp, (float *) data); + stbi__end_write_file(&s); + return r; + } else + return 0; +} +#endif // STBI_WRITE_NO_STDIO + + +////////////////////////////////////////////////////////////////////////////// +// +// PNG writer +// + +#ifndef STBIW_ZLIB_COMPRESS +// stretchy buffer; stbiw__sbpush() == vector<>::push_back() -- stbiw__sbcount() == vector<>::size() +#define stbiw__sbraw(a) ((int *) (void *) (a) - 2) +#define stbiw__sbm(a) stbiw__sbraw(a)[0] +#define stbiw__sbn(a) stbiw__sbraw(a)[1] + +#define stbiw__sbneedgrow(a,n) ((a)==0 || stbiw__sbn(a)+n >= stbiw__sbm(a)) +#define stbiw__sbmaybegrow(a,n) (stbiw__sbneedgrow(a,(n)) ? stbiw__sbgrow(a,n) : 0) +#define stbiw__sbgrow(a,n) stbiw__sbgrowf((void **) &(a), (n), sizeof(*(a))) + +#define stbiw__sbpush(a, v) (stbiw__sbmaybegrow(a,1), (a)[stbiw__sbn(a)++] = (v)) +#define stbiw__sbcount(a) ((a) ? stbiw__sbn(a) : 0) +#define stbiw__sbfree(a) ((a) ? STBIW_FREE(stbiw__sbraw(a)),0 : 0) + +static void *stbiw__sbgrowf(void **arr, int increment, int itemsize) +{ + int m = *arr ? 2*stbiw__sbm(*arr)+increment : increment+1; + void *p = STBIW_REALLOC_SIZED(*arr ? stbiw__sbraw(*arr) : 0, *arr ? (stbiw__sbm(*arr)*itemsize + sizeof(int)*2) : 0, itemsize * m + sizeof(int)*2); + STBIW_ASSERT(p); + if (p) { + if (!*arr) ((int *) p)[1] = 0; + *arr = (void *) ((int *) p + 2); + stbiw__sbm(*arr) = m; + } + return *arr; +} + +static unsigned char *stbiw__zlib_flushf(unsigned char *data, unsigned int *bitbuffer, int *bitcount) +{ + while (*bitcount >= 8) { + stbiw__sbpush(data, STBIW_UCHAR(*bitbuffer)); + *bitbuffer >>= 8; + *bitcount -= 8; + } + return data; +} + +static int stbiw__zlib_bitrev(int code, int codebits) +{ + int res=0; + while (codebits--) { + res = (res << 1) | (code & 1); + code >>= 1; + } + return res; +} + +static unsigned int stbiw__zlib_countm(unsigned char *a, unsigned char *b, int limit) +{ + int i; + for (i=0; i < limit && i < 258; ++i) + if (a[i] != b[i]) break; + return i; +} + +static unsigned int stbiw__zhash(unsigned char *data) +{ + stbiw_uint32 hash = data[0] + (data[1] << 8) + (data[2] << 16); + hash ^= hash << 3; + hash += hash >> 5; + hash ^= hash << 4; + hash += hash >> 17; + hash ^= hash << 25; + hash += hash >> 6; + return hash; +} + +#define stbiw__zlib_flush() (out = stbiw__zlib_flushf(out, &bitbuf, &bitcount)) +#define stbiw__zlib_add(code,codebits) \ + (bitbuf |= (code) << bitcount, bitcount += (codebits), stbiw__zlib_flush()) +#define stbiw__zlib_huffa(b,c) stbiw__zlib_add(stbiw__zlib_bitrev(b,c),c) +// default huffman tables +#define stbiw__zlib_huff1(n) stbiw__zlib_huffa(0x30 + (n), 8) +#define stbiw__zlib_huff2(n) stbiw__zlib_huffa(0x190 + (n)-144, 9) +#define stbiw__zlib_huff3(n) stbiw__zlib_huffa(0 + (n)-256,7) +#define stbiw__zlib_huff4(n) stbiw__zlib_huffa(0xc0 + (n)-280,8) +#define stbiw__zlib_huff(n) ((n) <= 143 ? stbiw__zlib_huff1(n) : (n) <= 255 ? stbiw__zlib_huff2(n) : (n) <= 279 ? stbiw__zlib_huff3(n) : stbiw__zlib_huff4(n)) +#define stbiw__zlib_huffb(n) ((n) <= 143 ? stbiw__zlib_huff1(n) : stbiw__zlib_huff2(n)) + +#define stbiw__ZHASH 16384 + +#endif // STBIW_ZLIB_COMPRESS + +STBIWDEF unsigned char * stbi_zlib_compress(unsigned char *data, int data_len, int *out_len, int quality) +{ +#ifdef STBIW_ZLIB_COMPRESS + // user provided a zlib compress implementation, use that + return STBIW_ZLIB_COMPRESS(data, data_len, out_len, quality); +#else // use builtin + static unsigned short lengthc[] = { 3,4,5,6,7,8,9,10,11,13,15,17,19,23,27,31,35,43,51,59,67,83,99,115,131,163,195,227,258, 259 }; + static unsigned char lengtheb[]= { 0,0,0,0,0,0,0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0 }; + static unsigned short distc[] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193,257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577, 32768 }; + static unsigned char disteb[] = { 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13 }; + unsigned int bitbuf=0; + int i,j, bitcount=0; + unsigned char *out = NULL; + unsigned char ***hash_table = (unsigned char***) STBIW_MALLOC(stbiw__ZHASH * sizeof(unsigned char**)); + if (hash_table == NULL) + return NULL; + if (quality < 5) quality = 5; + + stbiw__sbpush(out, 0x78); // DEFLATE 32K window + stbiw__sbpush(out, 0x5e); // FLEVEL = 1 + stbiw__zlib_add(1,1); // BFINAL = 1 + stbiw__zlib_add(1,2); // BTYPE = 1 -- fixed huffman + + for (i=0; i < stbiw__ZHASH; ++i) + hash_table[i] = NULL; + + i=0; + while (i < data_len-3) { + // hash next 3 bytes of data to be compressed + int h = stbiw__zhash(data+i)&(stbiw__ZHASH-1), best=3; + unsigned char *bestloc = 0; + unsigned char **hlist = hash_table[h]; + int n = stbiw__sbcount(hlist); + for (j=0; j < n; ++j) { + if (hlist[j]-data > i-32768) { // if entry lies within window + int d = stbiw__zlib_countm(hlist[j], data+i, data_len-i); + if (d >= best) { best=d; bestloc=hlist[j]; } + } + } + // when hash table entry is too long, delete half the entries + if (hash_table[h] && stbiw__sbn(hash_table[h]) == 2*quality) { + STBIW_MEMMOVE(hash_table[h], hash_table[h]+quality, sizeof(hash_table[h][0])*quality); + stbiw__sbn(hash_table[h]) = quality; + } + stbiw__sbpush(hash_table[h],data+i); + + if (bestloc) { + // "lazy matching" - check match at *next* byte, and if it's better, do cur byte as literal + h = stbiw__zhash(data+i+1)&(stbiw__ZHASH-1); + hlist = hash_table[h]; + n = stbiw__sbcount(hlist); + for (j=0; j < n; ++j) { + if (hlist[j]-data > i-32767) { + int e = stbiw__zlib_countm(hlist[j], data+i+1, data_len-i-1); + if (e > best) { // if next match is better, bail on current match + bestloc = NULL; + break; + } + } + } + } + + if (bestloc) { + int d = (int) (data+i - bestloc); // distance back + STBIW_ASSERT(d <= 32767 && best <= 258); + for (j=0; best > lengthc[j+1]-1; ++j); + stbiw__zlib_huff(j+257); + if (lengtheb[j]) stbiw__zlib_add(best - lengthc[j], lengtheb[j]); + for (j=0; d > distc[j+1]-1; ++j); + stbiw__zlib_add(stbiw__zlib_bitrev(j,5),5); + if (disteb[j]) stbiw__zlib_add(d - distc[j], disteb[j]); + i += best; + } else { + stbiw__zlib_huffb(data[i]); + ++i; + } + } + // write out final bytes + for (;i < data_len; ++i) + stbiw__zlib_huffb(data[i]); + stbiw__zlib_huff(256); // end of block + // pad with 0 bits to byte boundary + while (bitcount) + stbiw__zlib_add(0,1); + + for (i=0; i < stbiw__ZHASH; ++i) + (void) stbiw__sbfree(hash_table[i]); + STBIW_FREE(hash_table); + + // store uncompressed instead if compression was worse + if (stbiw__sbn(out) > data_len + 2 + ((data_len+32766)/32767)*5) { + stbiw__sbn(out) = 2; // truncate to DEFLATE 32K window and FLEVEL = 1 + for (j = 0; j < data_len;) { + int blocklen = data_len - j; + if (blocklen > 32767) blocklen = 32767; + stbiw__sbpush(out, data_len - j == blocklen); // BFINAL = ?, BTYPE = 0 -- no compression + stbiw__sbpush(out, STBIW_UCHAR(blocklen)); // LEN + stbiw__sbpush(out, STBIW_UCHAR(blocklen >> 8)); + stbiw__sbpush(out, STBIW_UCHAR(~blocklen)); // NLEN + stbiw__sbpush(out, STBIW_UCHAR(~blocklen >> 8)); + memcpy(out+stbiw__sbn(out), data+j, blocklen); + stbiw__sbn(out) += blocklen; + j += blocklen; + } + } + + { + // compute adler32 on input + unsigned int s1=1, s2=0; + int blocklen = (int) (data_len % 5552); + j=0; + while (j < data_len) { + for (i=0; i < blocklen; ++i) { s1 += data[j+i]; s2 += s1; } + s1 %= 65521; s2 %= 65521; + j += blocklen; + blocklen = 5552; + } + stbiw__sbpush(out, STBIW_UCHAR(s2 >> 8)); + stbiw__sbpush(out, STBIW_UCHAR(s2)); + stbiw__sbpush(out, STBIW_UCHAR(s1 >> 8)); + stbiw__sbpush(out, STBIW_UCHAR(s1)); + } + *out_len = stbiw__sbn(out); + // make returned pointer freeable + STBIW_MEMMOVE(stbiw__sbraw(out), out, *out_len); + return (unsigned char *) stbiw__sbraw(out); +#endif // STBIW_ZLIB_COMPRESS +} + +static unsigned int stbiw__crc32(unsigned char *buffer, int len) +{ +#ifdef STBIW_CRC32 + return STBIW_CRC32(buffer, len); +#else + static unsigned int crc_table[256] = + { + 0x00000000, 0x77073096, 0xEE0E612C, 0x990951BA, 0x076DC419, 0x706AF48F, 0xE963A535, 0x9E6495A3, + 0x0eDB8832, 0x79DCB8A4, 0xE0D5E91E, 0x97D2D988, 0x09B64C2B, 0x7EB17CBD, 0xE7B82D07, 0x90BF1D91, + 0x1DB71064, 0x6AB020F2, 0xF3B97148, 0x84BE41DE, 0x1ADAD47D, 0x6DDDE4EB, 0xF4D4B551, 0x83D385C7, + 0x136C9856, 0x646BA8C0, 0xFD62F97A, 0x8A65C9EC, 0x14015C4F, 0x63066CD9, 0xFA0F3D63, 0x8D080DF5, + 0x3B6E20C8, 0x4C69105E, 0xD56041E4, 0xA2677172, 0x3C03E4D1, 0x4B04D447, 0xD20D85FD, 0xA50AB56B, + 0x35B5A8FA, 0x42B2986C, 0xDBBBC9D6, 0xACBCF940, 0x32D86CE3, 0x45DF5C75, 0xDCD60DCF, 0xABD13D59, + 0x26D930AC, 0x51DE003A, 0xC8D75180, 0xBFD06116, 0x21B4F4B5, 0x56B3C423, 0xCFBA9599, 0xB8BDA50F, + 0x2802B89E, 0x5F058808, 0xC60CD9B2, 0xB10BE924, 0x2F6F7C87, 0x58684C11, 0xC1611DAB, 0xB6662D3D, + 0x76DC4190, 0x01DB7106, 0x98D220BC, 0xEFD5102A, 0x71B18589, 0x06B6B51F, 0x9FBFE4A5, 0xE8B8D433, + 0x7807C9A2, 0x0F00F934, 0x9609A88E, 0xE10E9818, 0x7F6A0DBB, 0x086D3D2D, 0x91646C97, 0xE6635C01, + 0x6B6B51F4, 0x1C6C6162, 0x856530D8, 0xF262004E, 0x6C0695ED, 0x1B01A57B, 0x8208F4C1, 0xF50FC457, + 0x65B0D9C6, 0x12B7E950, 0x8BBEB8EA, 0xFCB9887C, 0x62DD1DDF, 0x15DA2D49, 0x8CD37CF3, 0xFBD44C65, + 0x4DB26158, 0x3AB551CE, 0xA3BC0074, 0xD4BB30E2, 0x4ADFA541, 0x3DD895D7, 0xA4D1C46D, 0xD3D6F4FB, + 0x4369E96A, 0x346ED9FC, 0xAD678846, 0xDA60B8D0, 0x44042D73, 0x33031DE5, 0xAA0A4C5F, 0xDD0D7CC9, + 0x5005713C, 0x270241AA, 0xBE0B1010, 0xC90C2086, 0x5768B525, 0x206F85B3, 0xB966D409, 0xCE61E49F, + 0x5EDEF90E, 0x29D9C998, 0xB0D09822, 0xC7D7A8B4, 0x59B33D17, 0x2EB40D81, 0xB7BD5C3B, 0xC0BA6CAD, + 0xEDB88320, 0x9ABFB3B6, 0x03B6E20C, 0x74B1D29A, 0xEAD54739, 0x9DD277AF, 0x04DB2615, 0x73DC1683, + 0xE3630B12, 0x94643B84, 0x0D6D6A3E, 0x7A6A5AA8, 0xE40ECF0B, 0x9309FF9D, 0x0A00AE27, 0x7D079EB1, + 0xF00F9344, 0x8708A3D2, 0x1E01F268, 0x6906C2FE, 0xF762575D, 0x806567CB, 0x196C3671, 0x6E6B06E7, + 0xFED41B76, 0x89D32BE0, 0x10DA7A5A, 0x67DD4ACC, 0xF9B9DF6F, 0x8EBEEFF9, 0x17B7BE43, 0x60B08ED5, + 0xD6D6A3E8, 0xA1D1937E, 0x38D8C2C4, 0x4FDFF252, 0xD1BB67F1, 0xA6BC5767, 0x3FB506DD, 0x48B2364B, + 0xD80D2BDA, 0xAF0A1B4C, 0x36034AF6, 0x41047A60, 0xDF60EFC3, 0xA867DF55, 0x316E8EEF, 0x4669BE79, + 0xCB61B38C, 0xBC66831A, 0x256FD2A0, 0x5268E236, 0xCC0C7795, 0xBB0B4703, 0x220216B9, 0x5505262F, + 0xC5BA3BBE, 0xB2BD0B28, 0x2BB45A92, 0x5CB36A04, 0xC2D7FFA7, 0xB5D0CF31, 0x2CD99E8B, 0x5BDEAE1D, + 0x9B64C2B0, 0xEC63F226, 0x756AA39C, 0x026D930A, 0x9C0906A9, 0xEB0E363F, 0x72076785, 0x05005713, + 0x95BF4A82, 0xE2B87A14, 0x7BB12BAE, 0x0CB61B38, 0x92D28E9B, 0xE5D5BE0D, 0x7CDCEFB7, 0x0BDBDF21, + 0x86D3D2D4, 0xF1D4E242, 0x68DDB3F8, 0x1FDA836E, 0x81BE16CD, 0xF6B9265B, 0x6FB077E1, 0x18B74777, + 0x88085AE6, 0xFF0F6A70, 0x66063BCA, 0x11010B5C, 0x8F659EFF, 0xF862AE69, 0x616BFFD3, 0x166CCF45, + 0xA00AE278, 0xD70DD2EE, 0x4E048354, 0x3903B3C2, 0xA7672661, 0xD06016F7, 0x4969474D, 0x3E6E77DB, + 0xAED16A4A, 0xD9D65ADC, 0x40DF0B66, 0x37D83BF0, 0xA9BCAE53, 0xDEBB9EC5, 0x47B2CF7F, 0x30B5FFE9, + 0xBDBDF21C, 0xCABAC28A, 0x53B39330, 0x24B4A3A6, 0xBAD03605, 0xCDD70693, 0x54DE5729, 0x23D967BF, + 0xB3667A2E, 0xC4614AB8, 0x5D681B02, 0x2A6F2B94, 0xB40BBE37, 0xC30C8EA1, 0x5A05DF1B, 0x2D02EF8D + }; + + unsigned int crc = ~0u; + int i; + for (i=0; i < len; ++i) + crc = (crc >> 8) ^ crc_table[buffer[i] ^ (crc & 0xff)]; + return ~crc; +#endif +} + +#define stbiw__wpng4(o,a,b,c,d) ((o)[0]=STBIW_UCHAR(a),(o)[1]=STBIW_UCHAR(b),(o)[2]=STBIW_UCHAR(c),(o)[3]=STBIW_UCHAR(d),(o)+=4) +#define stbiw__wp32(data,v) stbiw__wpng4(data, (v)>>24,(v)>>16,(v)>>8,(v)); +#define stbiw__wptag(data,s) stbiw__wpng4(data, s[0],s[1],s[2],s[3]) + +static void stbiw__wpcrc(unsigned char **data, int len) +{ + unsigned int crc = stbiw__crc32(*data - len - 4, len+4); + stbiw__wp32(*data, crc); +} + +static unsigned char stbiw__paeth(int a, int b, int c) +{ + int p = a + b - c, pa = abs(p-a), pb = abs(p-b), pc = abs(p-c); + if (pa <= pb && pa <= pc) return STBIW_UCHAR(a); + if (pb <= pc) return STBIW_UCHAR(b); + return STBIW_UCHAR(c); +} + +// @OPTIMIZE: provide an option that always forces left-predict or paeth predict +static void stbiw__encode_png_line(unsigned char *pixels, int stride_bytes, int width, int height, int y, int n, int filter_type, signed char *line_buffer) +{ + static int mapping[] = { 0,1,2,3,4 }; + static int firstmap[] = { 0,1,0,5,6 }; + int *mymap = (y != 0) ? mapping : firstmap; + int i; + int type = mymap[filter_type]; + unsigned char *z = pixels + stride_bytes * (stbi__flip_vertically_on_write ? height-1-y : y); + int signed_stride = stbi__flip_vertically_on_write ? -stride_bytes : stride_bytes; + + if (type==0) { + memcpy(line_buffer, z, width*n); + return; + } + + // first loop isn't optimized since it's just one pixel + for (i = 0; i < n; ++i) { + switch (type) { + case 1: line_buffer[i] = z[i]; break; + case 2: line_buffer[i] = z[i] - z[i-signed_stride]; break; + case 3: line_buffer[i] = z[i] - (z[i-signed_stride]>>1); break; + case 4: line_buffer[i] = (signed char) (z[i] - stbiw__paeth(0,z[i-signed_stride],0)); break; + case 5: line_buffer[i] = z[i]; break; + case 6: line_buffer[i] = z[i]; break; + } + } + switch (type) { + case 1: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-n]; break; + case 2: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-signed_stride]; break; + case 3: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - ((z[i-n] + z[i-signed_stride])>>1); break; + case 4: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], z[i-signed_stride], z[i-signed_stride-n]); break; + case 5: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - (z[i-n]>>1); break; + case 6: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], 0,0); break; + } +} + +STBIWDEF unsigned char *stbi_write_png_to_mem(const unsigned char *pixels, int stride_bytes, int x, int y, int n, int *out_len) +{ + int force_filter = stbi_write_force_png_filter; + int ctype[5] = { -1, 0, 4, 2, 6 }; + unsigned char sig[8] = { 137,80,78,71,13,10,26,10 }; + unsigned char *out,*o, *filt, *zlib; + signed char *line_buffer; + int j,zlen; + + if (stride_bytes == 0) + stride_bytes = x * n; + + if (force_filter >= 5) { + force_filter = -1; + } + + filt = (unsigned char *) STBIW_MALLOC((x*n+1) * y); if (!filt) return 0; + line_buffer = (signed char *) STBIW_MALLOC(x * n); if (!line_buffer) { STBIW_FREE(filt); return 0; } + for (j=0; j < y; ++j) { + int filter_type; + if (force_filter > -1) { + filter_type = force_filter; + stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, force_filter, line_buffer); + } else { // Estimate the best filter by running through all of them: + int best_filter = 0, best_filter_val = 0x7fffffff, est, i; + for (filter_type = 0; filter_type < 5; filter_type++) { + stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, filter_type, line_buffer); + + // Estimate the entropy of the line using this filter; the less, the better. + est = 0; + for (i = 0; i < x*n; ++i) { + est += abs((signed char) line_buffer[i]); + } + if (est < best_filter_val) { + best_filter_val = est; + best_filter = filter_type; + } + } + if (filter_type != best_filter) { // If the last iteration already got us the best filter, don't redo it + stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, best_filter, line_buffer); + filter_type = best_filter; + } + } + // when we get here, filter_type contains the filter type, and line_buffer contains the data + filt[j*(x*n+1)] = (unsigned char) filter_type; + STBIW_MEMMOVE(filt+j*(x*n+1)+1, line_buffer, x*n); + } + STBIW_FREE(line_buffer); + zlib = stbi_zlib_compress(filt, y*( x*n+1), &zlen, stbi_write_png_compression_level); + STBIW_FREE(filt); + if (!zlib) return 0; + + // each tag requires 12 bytes of overhead + out = (unsigned char *) STBIW_MALLOC(8 + 12+13 + 12+zlen + 12); + if (!out) return 0; + *out_len = 8 + 12+13 + 12+zlen + 12; + + o=out; + STBIW_MEMMOVE(o,sig,8); o+= 8; + stbiw__wp32(o, 13); // header length + stbiw__wptag(o, "IHDR"); + stbiw__wp32(o, x); + stbiw__wp32(o, y); + *o++ = 8; + *o++ = STBIW_UCHAR(ctype[n]); + *o++ = 0; + *o++ = 0; + *o++ = 0; + stbiw__wpcrc(&o,13); + + stbiw__wp32(o, zlen); + stbiw__wptag(o, "IDAT"); + STBIW_MEMMOVE(o, zlib, zlen); + o += zlen; + STBIW_FREE(zlib); + stbiw__wpcrc(&o, zlen); + + stbiw__wp32(o,0); + stbiw__wptag(o, "IEND"); + stbiw__wpcrc(&o,0); + + STBIW_ASSERT(o == out + *out_len); + + return out; +} + +#ifndef STBI_WRITE_NO_STDIO +STBIWDEF int stbi_write_png(char const *filename, int x, int y, int comp, const void *data, int stride_bytes) +{ + FILE *f; + int len; + unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len); + if (png == NULL) return 0; + + f = stbiw__fopen(filename, "wb"); + if (!f) { STBIW_FREE(png); return 0; } + fwrite(png, 1, len, f); + fclose(f); + STBIW_FREE(png); + return 1; +} +#endif + +STBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int stride_bytes) +{ + int len; + unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len); + if (png == NULL) return 0; + func(context, png, len); + STBIW_FREE(png); + return 1; +} + + +/* *************************************************************************** + * + * JPEG writer + * + * This is based on Jon Olick's jo_jpeg.cpp: + * public domain Simple, Minimalistic JPEG writer - http://www.jonolick.com/code.html + */ + +static const unsigned char stbiw__jpg_ZigZag[] = { 0,1,5,6,14,15,27,28,2,4,7,13,16,26,29,42,3,8,12,17,25,30,41,43,9,11,18, + 24,31,40,44,53,10,19,23,32,39,45,52,54,20,22,33,38,46,51,55,60,21,34,37,47,50,56,59,61,35,36,48,49,57,58,62,63 }; + +static void stbiw__jpg_writeBits(stbi__write_context *s, int *bitBufP, int *bitCntP, const unsigned short *bs) { + int bitBuf = *bitBufP, bitCnt = *bitCntP; + bitCnt += bs[1]; + bitBuf |= bs[0] << (24 - bitCnt); + while(bitCnt >= 8) { + unsigned char c = (bitBuf >> 16) & 255; + stbiw__putc(s, c); + if(c == 255) { + stbiw__putc(s, 0); + } + bitBuf <<= 8; + bitCnt -= 8; + } + *bitBufP = bitBuf; + *bitCntP = bitCnt; +} + +static void stbiw__jpg_DCT(float *d0p, float *d1p, float *d2p, float *d3p, float *d4p, float *d5p, float *d6p, float *d7p) { + float d0 = *d0p, d1 = *d1p, d2 = *d2p, d3 = *d3p, d4 = *d4p, d5 = *d5p, d6 = *d6p, d7 = *d7p; + float z1, z2, z3, z4, z5, z11, z13; + + float tmp0 = d0 + d7; + float tmp7 = d0 - d7; + float tmp1 = d1 + d6; + float tmp6 = d1 - d6; + float tmp2 = d2 + d5; + float tmp5 = d2 - d5; + float tmp3 = d3 + d4; + float tmp4 = d3 - d4; + + // Even part + float tmp10 = tmp0 + tmp3; // phase 2 + float tmp13 = tmp0 - tmp3; + float tmp11 = tmp1 + tmp2; + float tmp12 = tmp1 - tmp2; + + d0 = tmp10 + tmp11; // phase 3 + d4 = tmp10 - tmp11; + + z1 = (tmp12 + tmp13) * 0.707106781f; // c4 + d2 = tmp13 + z1; // phase 5 + d6 = tmp13 - z1; + + // Odd part + tmp10 = tmp4 + tmp5; // phase 2 + tmp11 = tmp5 + tmp6; + tmp12 = tmp6 + tmp7; + + // The rotator is modified from fig 4-8 to avoid extra negations. + z5 = (tmp10 - tmp12) * 0.382683433f; // c6 + z2 = tmp10 * 0.541196100f + z5; // c2-c6 + z4 = tmp12 * 1.306562965f + z5; // c2+c6 + z3 = tmp11 * 0.707106781f; // c4 + + z11 = tmp7 + z3; // phase 5 + z13 = tmp7 - z3; + + *d5p = z13 + z2; // phase 6 + *d3p = z13 - z2; + *d1p = z11 + z4; + *d7p = z11 - z4; + + *d0p = d0; *d2p = d2; *d4p = d4; *d6p = d6; +} + +static void stbiw__jpg_calcBits(int val, unsigned short bits[2]) { + int tmp1 = val < 0 ? -val : val; + val = val < 0 ? val-1 : val; + bits[1] = 1; + while(tmp1 >>= 1) { + ++bits[1]; + } + bits[0] = val & ((1<0)&&(DU[end0pos]==0); --end0pos) { + } + // end0pos = first element in reverse order !=0 + if(end0pos == 0) { + stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB); + return DU[0]; + } + for(i = 1; i <= end0pos; ++i) { + int startpos = i; + int nrzeroes; + unsigned short bits[2]; + for (; DU[i]==0 && i<=end0pos; ++i) { + } + nrzeroes = i-startpos; + if ( nrzeroes >= 16 ) { + int lng = nrzeroes>>4; + int nrmarker; + for (nrmarker=1; nrmarker <= lng; ++nrmarker) + stbiw__jpg_writeBits(s, bitBuf, bitCnt, M16zeroes); + nrzeroes &= 15; + } + stbiw__jpg_calcBits(DU[i], bits); + stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTAC[(nrzeroes<<4)+bits[1]]); + stbiw__jpg_writeBits(s, bitBuf, bitCnt, bits); + } + if(end0pos != 63) { + stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB); + } + return DU[0]; +} + +static int stbi_write_jpg_core(stbi__write_context *s, int width, int height, int comp, const void* data, int quality) { + // Constants that don't pollute global namespace + static const unsigned char std_dc_luminance_nrcodes[] = {0,0,1,5,1,1,1,1,1,1,0,0,0,0,0,0,0}; + static const unsigned char std_dc_luminance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11}; + static const unsigned char std_ac_luminance_nrcodes[] = {0,0,2,1,3,3,2,4,3,5,5,4,4,0,0,1,0x7d}; + static const unsigned char std_ac_luminance_values[] = { + 0x01,0x02,0x03,0x00,0x04,0x11,0x05,0x12,0x21,0x31,0x41,0x06,0x13,0x51,0x61,0x07,0x22,0x71,0x14,0x32,0x81,0x91,0xa1,0x08, + 0x23,0x42,0xb1,0xc1,0x15,0x52,0xd1,0xf0,0x24,0x33,0x62,0x72,0x82,0x09,0x0a,0x16,0x17,0x18,0x19,0x1a,0x25,0x26,0x27,0x28, + 0x29,0x2a,0x34,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58,0x59, + 0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x83,0x84,0x85,0x86,0x87,0x88,0x89, + 0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4,0xb5,0xb6, + 0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda,0xe1,0xe2, + 0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf1,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa + }; + static const unsigned char std_dc_chrominance_nrcodes[] = {0,0,3,1,1,1,1,1,1,1,1,1,0,0,0,0,0}; + static const unsigned char std_dc_chrominance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11}; + static const unsigned char std_ac_chrominance_nrcodes[] = {0,0,2,1,2,4,4,3,4,7,5,4,4,0,1,2,0x77}; + static const unsigned char std_ac_chrominance_values[] = { + 0x00,0x01,0x02,0x03,0x11,0x04,0x05,0x21,0x31,0x06,0x12,0x41,0x51,0x07,0x61,0x71,0x13,0x22,0x32,0x81,0x08,0x14,0x42,0x91, + 0xa1,0xb1,0xc1,0x09,0x23,0x33,0x52,0xf0,0x15,0x62,0x72,0xd1,0x0a,0x16,0x24,0x34,0xe1,0x25,0xf1,0x17,0x18,0x19,0x1a,0x26, + 0x27,0x28,0x29,0x2a,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58, + 0x59,0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x82,0x83,0x84,0x85,0x86,0x87, + 0x88,0x89,0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4, + 0xb5,0xb6,0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda, + 0xe2,0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa + }; + // Huffman tables + static const unsigned short YDC_HT[256][2] = { {0,2},{2,3},{3,3},{4,3},{5,3},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9}}; + static const unsigned short UVDC_HT[256][2] = { {0,2},{1,2},{2,2},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9},{1022,10},{2046,11}}; + static const unsigned short YAC_HT[256][2] = { + {10,4},{0,2},{1,2},{4,3},{11,4},{26,5},{120,7},{248,8},{1014,10},{65410,16},{65411,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {12,4},{27,5},{121,7},{502,9},{2038,11},{65412,16},{65413,16},{65414,16},{65415,16},{65416,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {28,5},{249,8},{1015,10},{4084,12},{65417,16},{65418,16},{65419,16},{65420,16},{65421,16},{65422,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {58,6},{503,9},{4085,12},{65423,16},{65424,16},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {59,6},{1016,10},{65430,16},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {122,7},{2039,11},{65438,16},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {123,7},{4086,12},{65446,16},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {250,8},{4087,12},{65454,16},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {504,9},{32704,15},{65462,16},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {505,9},{65470,16},{65471,16},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {506,9},{65479,16},{65480,16},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {1017,10},{65488,16},{65489,16},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {1018,10},{65497,16},{65498,16},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {2040,11},{65506,16},{65507,16},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {65515,16},{65516,16},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{0,0},{0,0},{0,0},{0,0},{0,0}, + {2041,11},{65525,16},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0} + }; + static const unsigned short UVAC_HT[256][2] = { + {0,2},{1,2},{4,3},{10,4},{24,5},{25,5},{56,6},{120,7},{500,9},{1014,10},{4084,12},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {11,4},{57,6},{246,8},{501,9},{2038,11},{4085,12},{65416,16},{65417,16},{65418,16},{65419,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {26,5},{247,8},{1015,10},{4086,12},{32706,15},{65420,16},{65421,16},{65422,16},{65423,16},{65424,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {27,5},{248,8},{1016,10},{4087,12},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{65430,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {58,6},{502,9},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{65438,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {59,6},{1017,10},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{65446,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {121,7},{2039,11},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{65454,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {122,7},{2040,11},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{65462,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {249,8},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{65470,16},{65471,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {503,9},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{65479,16},{65480,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {504,9},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{65488,16},{65489,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {505,9},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{65497,16},{65498,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {506,9},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{65506,16},{65507,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {2041,11},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{65515,16},{65516,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {16352,14},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{65525,16},{0,0},{0,0},{0,0},{0,0},{0,0}, + {1018,10},{32707,15},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0} + }; + static const int YQT[] = {16,11,10,16,24,40,51,61,12,12,14,19,26,58,60,55,14,13,16,24,40,57,69,56,14,17,22,29,51,87,80,62,18,22, + 37,56,68,109,103,77,24,35,55,64,81,104,113,92,49,64,78,87,103,121,120,101,72,92,95,98,112,100,103,99}; + static const int UVQT[] = {17,18,24,47,99,99,99,99,18,21,26,66,99,99,99,99,24,26,56,99,99,99,99,99,47,66,99,99,99,99,99,99, + 99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99}; + static const float aasf[] = { 1.0f * 2.828427125f, 1.387039845f * 2.828427125f, 1.306562965f * 2.828427125f, 1.175875602f * 2.828427125f, + 1.0f * 2.828427125f, 0.785694958f * 2.828427125f, 0.541196100f * 2.828427125f, 0.275899379f * 2.828427125f }; + + int row, col, i, k, subsample; + float fdtbl_Y[64], fdtbl_UV[64]; + unsigned char YTable[64], UVTable[64]; + + if(!data || !width || !height || comp > 4 || comp < 1) { + return 0; + } + + quality = quality ? quality : 90; + subsample = quality <= 90 ? 1 : 0; + quality = quality < 1 ? 1 : quality > 100 ? 100 : quality; + quality = quality < 50 ? 5000 / quality : 200 - quality * 2; + + for(i = 0; i < 64; ++i) { + int uvti, yti = (YQT[i]*quality+50)/100; + YTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (yti < 1 ? 1 : yti > 255 ? 255 : yti); + uvti = (UVQT[i]*quality+50)/100; + UVTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (uvti < 1 ? 1 : uvti > 255 ? 255 : uvti); + } + + for(row = 0, k = 0; row < 8; ++row) { + for(col = 0; col < 8; ++col, ++k) { + fdtbl_Y[k] = 1 / (YTable [stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]); + fdtbl_UV[k] = 1 / (UVTable[stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]); + } + } + + // Write Headers + { + static const unsigned char head0[] = { 0xFF,0xD8,0xFF,0xE0,0,0x10,'J','F','I','F',0,1,1,0,0,1,0,1,0,0,0xFF,0xDB,0,0x84,0 }; + static const unsigned char head2[] = { 0xFF,0xDA,0,0xC,3,1,0,2,0x11,3,0x11,0,0x3F,0 }; + const unsigned char head1[] = { 0xFF,0xC0,0,0x11,8,(unsigned char)(height>>8),STBIW_UCHAR(height),(unsigned char)(width>>8),STBIW_UCHAR(width), + 3,1,(unsigned char)(subsample?0x22:0x11),0,2,0x11,1,3,0x11,1,0xFF,0xC4,0x01,0xA2,0 }; + s->func(s->context, (void*)head0, sizeof(head0)); + s->func(s->context, (void*)YTable, sizeof(YTable)); + stbiw__putc(s, 1); + s->func(s->context, UVTable, sizeof(UVTable)); + s->func(s->context, (void*)head1, sizeof(head1)); + s->func(s->context, (void*)(std_dc_luminance_nrcodes+1), sizeof(std_dc_luminance_nrcodes)-1); + s->func(s->context, (void*)std_dc_luminance_values, sizeof(std_dc_luminance_values)); + stbiw__putc(s, 0x10); // HTYACinfo + s->func(s->context, (void*)(std_ac_luminance_nrcodes+1), sizeof(std_ac_luminance_nrcodes)-1); + s->func(s->context, (void*)std_ac_luminance_values, sizeof(std_ac_luminance_values)); + stbiw__putc(s, 1); // HTUDCinfo + s->func(s->context, (void*)(std_dc_chrominance_nrcodes+1), sizeof(std_dc_chrominance_nrcodes)-1); + s->func(s->context, (void*)std_dc_chrominance_values, sizeof(std_dc_chrominance_values)); + stbiw__putc(s, 0x11); // HTUACinfo + s->func(s->context, (void*)(std_ac_chrominance_nrcodes+1), sizeof(std_ac_chrominance_nrcodes)-1); + s->func(s->context, (void*)std_ac_chrominance_values, sizeof(std_ac_chrominance_values)); + s->func(s->context, (void*)head2, sizeof(head2)); + } + + // Encode 8x8 macroblocks + { + static const unsigned short fillBits[] = {0x7F, 7}; + int DCY=0, DCU=0, DCV=0; + int bitBuf=0, bitCnt=0; + // comp == 2 is grey+alpha (alpha is ignored) + int ofsG = comp > 2 ? 1 : 0, ofsB = comp > 2 ? 2 : 0; + const unsigned char *dataR = (const unsigned char *)data; + const unsigned char *dataG = dataR + ofsG; + const unsigned char *dataB = dataR + ofsB; + int x, y, pos; + if(subsample) { + for(y = 0; y < height; y += 16) { + for(x = 0; x < width; x += 16) { + float Y[256], U[256], V[256]; + for(row = y, pos = 0; row < y+16; ++row) { + // row >= height => use last input row + int clamped_row = (row < height) ? row : height - 1; + int base_p = (stbi__flip_vertically_on_write ? (height-1-clamped_row) : clamped_row)*width*comp; + for(col = x; col < x+16; ++col, ++pos) { + // if col >= width => use pixel from last input column + int p = base_p + ((col < width) ? col : (width-1))*comp; + float r = dataR[p], g = dataG[p], b = dataB[p]; + Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128; + U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b; + V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b; + } + } + DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+0, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); + DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+8, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); + DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+128, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); + DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+136, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); + + // subsample U,V + { + float subU[64], subV[64]; + int yy, xx; + for(yy = 0, pos = 0; yy < 8; ++yy) { + for(xx = 0; xx < 8; ++xx, ++pos) { + int j = yy*32+xx*2; + subU[pos] = (U[j+0] + U[j+1] + U[j+16] + U[j+17]) * 0.25f; + subV[pos] = (V[j+0] + V[j+1] + V[j+16] + V[j+17]) * 0.25f; + } + } + DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subU, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT); + DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subV, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT); + } + } + } + } else { + for(y = 0; y < height; y += 8) { + for(x = 0; x < width; x += 8) { + float Y[64], U[64], V[64]; + for(row = y, pos = 0; row < y+8; ++row) { + // row >= height => use last input row + int clamped_row = (row < height) ? row : height - 1; + int base_p = (stbi__flip_vertically_on_write ? (height-1-clamped_row) : clamped_row)*width*comp; + for(col = x; col < x+8; ++col, ++pos) { + // if col >= width => use pixel from last input column + int p = base_p + ((col < width) ? col : (width-1))*comp; + float r = dataR[p], g = dataG[p], b = dataB[p]; + Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128; + U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b; + V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b; + } + } + + DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y, 8, fdtbl_Y, DCY, YDC_HT, YAC_HT); + DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, U, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT); + DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, V, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT); + } + } + } + + // Do the bit alignment of the EOI marker + stbiw__jpg_writeBits(s, &bitBuf, &bitCnt, fillBits); + } + + // EOI + stbiw__putc(s, 0xFF); + stbiw__putc(s, 0xD9); + + return 1; +} + +STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality) +{ + stbi__write_context s = { 0 }; + stbi__start_write_callbacks(&s, func, context); + return stbi_write_jpg_core(&s, x, y, comp, (void *) data, quality); +} + + +#ifndef STBI_WRITE_NO_STDIO +STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality) +{ + stbi__write_context s = { 0 }; + if (stbi__start_write_file(&s,filename)) { + int r = stbi_write_jpg_core(&s, x, y, comp, data, quality); + stbi__end_write_file(&s); + return r; + } else + return 0; +} +#endif + +#endif // STB_IMAGE_WRITE_IMPLEMENTATION + +/* Revision history + 1.16 (2021-07-11) + make Deflate code emit uncompressed blocks when it would otherwise expand + support writing BMPs with alpha channel + 1.15 (2020-07-13) unknown + 1.14 (2020-02-02) updated JPEG writer to downsample chroma channels + 1.13 + 1.12 + 1.11 (2019-08-11) + + 1.10 (2019-02-07) + support utf8 filenames in Windows; fix warnings and platform ifdefs + 1.09 (2018-02-11) + fix typo in zlib quality API, improve STB_I_W_STATIC in C++ + 1.08 (2018-01-29) + add stbi__flip_vertically_on_write, external zlib, zlib quality, choose PNG filter + 1.07 (2017-07-24) + doc fix + 1.06 (2017-07-23) + writing JPEG (using Jon Olick's code) + 1.05 ??? + 1.04 (2017-03-03) + monochrome BMP expansion + 1.03 ??? + 1.02 (2016-04-02) + avoid allocating large structures on the stack + 1.01 (2016-01-16) + STBIW_REALLOC_SIZED: support allocators with no realloc support + avoid race-condition in crc initialization + minor compile issues + 1.00 (2015-09-14) + installable file IO function + 0.99 (2015-09-13) + warning fixes; TGA rle support + 0.98 (2015-04-08) + added STBIW_MALLOC, STBIW_ASSERT etc + 0.97 (2015-01-18) + fixed HDR asserts, rewrote HDR rle logic + 0.96 (2015-01-17) + add HDR output + fix monochrome BMP + 0.95 (2014-08-17) + add monochrome TGA output + 0.94 (2014-05-31) + rename private functions to avoid conflicts with stb_image.h + 0.93 (2014-05-27) + warning fixes + 0.92 (2010-08-01) + casts to unsigned char to fix warnings + 0.91 (2010-07-17) + first public release + 0.90 first internal release +*/ + +/* +------------------------------------------------------------------------------ +This software is available under 2 licenses -- choose whichever you prefer. +------------------------------------------------------------------------------ +ALTERNATIVE A - MIT License +Copyright (c) 2017 Sean Barrett +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +------------------------------------------------------------------------------ +ALTERNATIVE B - Public Domain (www.unlicense.org) +This is free and unencumbered software released into the public domain. +Anyone is free to copy, modify, publish, use, compile, sell, or distribute this +software, either in source code form or as a compiled binary, for any purpose, +commercial or non-commercial, and by any means. +In jurisdictions that recognize copyright laws, the author or authors of this +software dedicate any and all copyright interest in the software to the public +domain. We make this dedication for the benefit of the public at large and to +the detriment of our heirs and successors. We intend this dedication to be an +overt act of relinquishment in perpetuity of all present and future rights to +this software under copyright law. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +------------------------------------------------------------------------------ +*/ \ No newline at end of file diff --git a/submodules/simple-knn/ext.cpp b/submodules/simple-knn/ext.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ae6cefe6ce61a38352a88d07b69a8e6cb9de5b31 --- /dev/null +++ b/submodules/simple-knn/ext.cpp @@ -0,0 +1,17 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#include +#include "spatial.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("distCUDA2", &distCUDA2); +} diff --git a/submodules/simple-knn/setup.py b/submodules/simple-knn/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..580d2bd8dc190ce642d87501d53de4f6d9d46c64 --- /dev/null +++ b/submodules/simple-knn/setup.py @@ -0,0 +1,35 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from setuptools import setup +from torch.utils.cpp_extension import CUDAExtension, BuildExtension +import os + +cxx_compiler_flags = [] + +if os.name == 'nt': + cxx_compiler_flags.append("/wd4624") + +setup( + name="simple_knn", + ext_modules=[ + CUDAExtension( + name="simple_knn._C", + sources=[ + "spatial.cu", + "simple_knn.cu", + "ext.cpp"], + extra_compile_args={"nvcc": [], "cxx": cxx_compiler_flags}) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/submodules/simple-knn/simple_knn.cu b/submodules/simple-knn/simple_knn.cu new file mode 100644 index 0000000000000000000000000000000000000000..e72e4c96ea9d161514835fc2fcee62b94954f2d9 --- /dev/null +++ b/submodules/simple-knn/simple_knn.cu @@ -0,0 +1,221 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#define BOX_SIZE 1024 + +#include "cuda_runtime.h" +#include "device_launch_parameters.h" +#include "simple_knn.h" +#include +#include +#include +#include +#include +#include +#define __CUDACC__ +#include +#include + +namespace cg = cooperative_groups; + +struct CustomMin +{ + __device__ __forceinline__ + float3 operator()(const float3& a, const float3& b) const { + return { min(a.x, b.x), min(a.y, b.y), min(a.z, b.z) }; + } +}; + +struct CustomMax +{ + __device__ __forceinline__ + float3 operator()(const float3& a, const float3& b) const { + return { max(a.x, b.x), max(a.y, b.y), max(a.z, b.z) }; + } +}; + +__host__ __device__ uint32_t prepMorton(uint32_t x) +{ + x = (x | (x << 16)) & 0x030000FF; + x = (x | (x << 8)) & 0x0300F00F; + x = (x | (x << 4)) & 0x030C30C3; + x = (x | (x << 2)) & 0x09249249; + return x; +} + +__host__ __device__ uint32_t coord2Morton(float3 coord, float3 minn, float3 maxx) +{ + uint32_t x = prepMorton(((coord.x - minn.x) / (maxx.x - minn.x)) * ((1 << 10) - 1)); + uint32_t y = prepMorton(((coord.y - minn.y) / (maxx.y - minn.y)) * ((1 << 10) - 1)); + uint32_t z = prepMorton(((coord.z - minn.z) / (maxx.z - minn.z)) * ((1 << 10) - 1)); + + return x | (y << 1) | (z << 2); +} + +__global__ void coord2Morton(int P, const float3* points, float3 minn, float3 maxx, uint32_t* codes) +{ + auto idx = cg::this_grid().thread_rank(); + if (idx >= P) + return; + + codes[idx] = coord2Morton(points[idx], minn, maxx); +} + +struct MinMax +{ + float3 minn; + float3 maxx; +}; + +__global__ void boxMinMax(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes) +{ + auto idx = cg::this_grid().thread_rank(); + + MinMax me; + if (idx < P) + { + me.minn = points[indices[idx]]; + me.maxx = points[indices[idx]]; + } + else + { + me.minn = { FLT_MAX, FLT_MAX, FLT_MAX }; + me.maxx = { -FLT_MAX,-FLT_MAX,-FLT_MAX }; + } + + __shared__ MinMax redResult[BOX_SIZE]; + + for (int off = BOX_SIZE / 2; off >= 1; off /= 2) + { + if (threadIdx.x < 2 * off) + redResult[threadIdx.x] = me; + __syncthreads(); + + if (threadIdx.x < off) + { + MinMax other = redResult[threadIdx.x + off]; + me.minn.x = min(me.minn.x, other.minn.x); + me.minn.y = min(me.minn.y, other.minn.y); + me.minn.z = min(me.minn.z, other.minn.z); + me.maxx.x = max(me.maxx.x, other.maxx.x); + me.maxx.y = max(me.maxx.y, other.maxx.y); + me.maxx.z = max(me.maxx.z, other.maxx.z); + } + __syncthreads(); + } + + if (threadIdx.x == 0) + boxes[blockIdx.x] = me; +} + +__device__ __host__ float distBoxPoint(const MinMax& box, const float3& p) +{ + float3 diff = { 0, 0, 0 }; + if (p.x < box.minn.x || p.x > box.maxx.x) + diff.x = min(abs(p.x - box.minn.x), abs(p.x - box.maxx.x)); + if (p.y < box.minn.y || p.y > box.maxx.y) + diff.y = min(abs(p.y - box.minn.y), abs(p.y - box.maxx.y)); + if (p.z < box.minn.z || p.z > box.maxx.z) + diff.z = min(abs(p.z - box.minn.z), abs(p.z - box.maxx.z)); + return diff.x * diff.x + diff.y * diff.y + diff.z * diff.z; +} + +template +__device__ void updateKBest(const float3& ref, const float3& point, float* knn) +{ + float3 d = { point.x - ref.x, point.y - ref.y, point.z - ref.z }; + float dist = d.x * d.x + d.y * d.y + d.z * d.z; + for (int j = 0; j < K; j++) + { + if (knn[j] > dist) + { + float t = knn[j]; + knn[j] = dist; + dist = t; + } + } +} + +__global__ void boxMeanDist(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes, float* dists) +{ + int idx = cg::this_grid().thread_rank(); + if (idx >= P) + return; + + float3 point = points[indices[idx]]; + float best[3] = { FLT_MAX, FLT_MAX, FLT_MAX }; + + for (int i = max(0, idx - 3); i <= min(P - 1, idx + 3); i++) + { + if (i == idx) + continue; + updateKBest<3>(point, points[indices[i]], best); + } + + float reject = best[2]; + best[0] = FLT_MAX; + best[1] = FLT_MAX; + best[2] = FLT_MAX; + + for (int b = 0; b < (P + BOX_SIZE - 1) / BOX_SIZE; b++) + { + MinMax box = boxes[b]; + float dist = distBoxPoint(box, point); + if (dist > reject || dist > best[2]) + continue; + + for (int i = b * BOX_SIZE; i < min(P, (b + 1) * BOX_SIZE); i++) + { + if (i == idx) + continue; + updateKBest<3>(point, points[indices[i]], best); + } + } + dists[indices[idx]] = (best[0] + best[1] + best[2]) / 3.0f; +} + +void SimpleKNN::knn(int P, float3* points, float* meanDists) +{ + float3* result; + cudaMalloc(&result, sizeof(float3)); + size_t temp_storage_bytes; + + float3 init = { 0, 0, 0 }, minn, maxx; + + cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, points, result, P, CustomMin(), init); + thrust::device_vector temp_storage(temp_storage_bytes); + + cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMin(), init); + cudaMemcpy(&minn, result, sizeof(float3), cudaMemcpyDeviceToHost); + + cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMax(), init); + cudaMemcpy(&maxx, result, sizeof(float3), cudaMemcpyDeviceToHost); + + thrust::device_vector morton(P); + thrust::device_vector morton_sorted(P); + coord2Morton << <(P + 255) / 256, 256 >> > (P, points, minn, maxx, morton.data().get()); + + thrust::device_vector indices(P); + thrust::sequence(indices.begin(), indices.end()); + thrust::device_vector indices_sorted(P); + + cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P); + temp_storage.resize(temp_storage_bytes); + + cub::DeviceRadixSort::SortPairs(temp_storage.data().get(), temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P); + + uint32_t num_boxes = (P + BOX_SIZE - 1) / BOX_SIZE; + thrust::device_vector boxes(num_boxes); + boxMinMax << > > (P, points, indices_sorted.data().get(), boxes.data().get()); + boxMeanDist << > > (P, points, indices_sorted.data().get(), boxes.data().get(), meanDists); + + cudaFree(result); +} \ No newline at end of file diff --git a/submodules/simple-knn/simple_knn.h b/submodules/simple-knn/simple_knn.h new file mode 100644 index 0000000000000000000000000000000000000000..3fcfdb87c53faaadc1fd820d5deeb1b2b5c21a86 --- /dev/null +++ b/submodules/simple-knn/simple_knn.h @@ -0,0 +1,21 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#ifndef SIMPLEKNN_H_INCLUDED +#define SIMPLEKNN_H_INCLUDED + +class SimpleKNN +{ +public: + static void knn(int P, float3* points, float* meanDists); +}; + +#endif \ No newline at end of file diff --git a/submodules/simple-knn/simple_knn/.gitkeep b/submodules/simple-knn/simple_knn/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/submodules/simple-knn/spatial.cu b/submodules/simple-knn/spatial.cu new file mode 100644 index 0000000000000000000000000000000000000000..1a6a654ba6f8c6a1856a40d14fb7a53c96602bad --- /dev/null +++ b/submodules/simple-knn/spatial.cu @@ -0,0 +1,26 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#include "spatial.h" +#include "simple_knn.h" + +torch::Tensor +distCUDA2(const torch::Tensor& points) +{ + const int P = points.size(0); + + auto float_opts = points.options().dtype(torch::kFloat32); + torch::Tensor means = torch::full({P}, 0.0, float_opts); + + SimpleKNN::knn(P, (float3*)points.contiguous().data(), means.contiguous().data()); + + return means; +} \ No newline at end of file diff --git a/submodules/simple-knn/spatial.h b/submodules/simple-knn/spatial.h new file mode 100644 index 0000000000000000000000000000000000000000..280c953a0321a769e433a43535fd36c251b730f0 --- /dev/null +++ b/submodules/simple-knn/spatial.h @@ -0,0 +1,14 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#include + +torch::Tensor distCUDA2(const torch::Tensor& points); \ No newline at end of file diff --git a/synthesize_fuse.py b/synthesize_fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..c15b7139cca1bb4b2cbc73405d76e33c7743a676 --- /dev/null +++ b/synthesize_fuse.py @@ -0,0 +1,125 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import imageio +import numpy as np +import torch +from scene import Scene +import os +from tqdm import tqdm +from os import makedirs +from gaussian_renderer import render_motion, render_motion_mouth +import torchvision +from utils.general_utils import safe_state +from argparse import ArgumentParser +from arguments import ModelParams, PipelineParams, get_combined_args +from gaussian_renderer import GaussianModel, MotionNetwork, MouthMotionNetwork + +import torch.nn.functional as F + +def dilate_fn(bin_img, ksize=13): + pad = (ksize - 1) // 2 + out = F.max_pool2d(bin_img, kernel_size=ksize, stride=1, padding=pad) + return out + +def render_set(model_path, name, iteration, views, gaussians, motion_net, gaussians_mouth, motion_net_mouth, pipeline, background, fast, dilate): + render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") + gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") + + makedirs(render_path, exist_ok=True) + makedirs(gts_path, exist_ok=True) + + all_preds = [] + all_gts = [] + + all_preds_face = [] + all_preds_mouth = [] + + + for idx, view in enumerate(tqdm(views, desc="Rendering progress", ascii=True)): + with torch.no_grad(): + render_pkg = render_motion(view, gaussians, motion_net, pipeline, background, frame_idx=0) + render_pkg_mouth = render_motion_mouth(view, gaussians_mouth, motion_net_mouth, pipeline, background, frame_idx=0) + # gt = view.original_image[0:3, :, :] + # torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) + # torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) + + if dilate: + alpha_mouth = dilate_fn(render_pkg_mouth["alpha"][None])[0] + else: + alpha_mouth = render_pkg_mouth["alpha"] + + mouth_image = render_pkg_mouth["render"] + view.background.cuda() / 255.0 * (1.0 - alpha_mouth) + + # alpha = gaussian_blur(render_pkg["alpha"], [3, 3], 2) + alpha = render_pkg["alpha"] + image = render_pkg["render"] + mouth_image * (1.0 - alpha) + + pred = (image[0:3, ...].clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy()* 255).astype(np.uint8) + all_preds.append(pred) + + if not fast: + all_preds_face.append((render_pkg["render"].clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy()* 255).astype(np.uint8)) + all_preds_mouth.append((render_pkg_mouth["render"].clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy()* 255).astype(np.uint8)) + + all_gts.append(view.original_image.permute(1, 2, 0).cpu().numpy().astype(np.uint8)) + + imageio.mimwrite(os.path.join(render_path, 'out.mp4'), all_preds, fps=25, quality=8, macro_block_size=1) + if not fast: + imageio.mimwrite(os.path.join(gts_path, 'out.mp4'), all_gts, fps=25, quality=8, macro_block_size=1) + + imageio.mimwrite(os.path.join(render_path, 'out_face.mp4'), all_preds_face, fps=25, quality=8, macro_block_size=1) + imageio.mimwrite(os.path.join(render_path, 'out_mouth.mp4'), all_preds_mouth, fps=25, quality=8, macro_block_size=1) + + + +def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, use_train : bool, fast, dilate): + with torch.no_grad(): + gaussians = GaussianModel(dataset.sh_degree) + gaussians_mouth = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians, shuffle=False) + + motion_net = MotionNetwork(args=dataset).cuda() + motion_net_mouth = MouthMotionNetwork(args=dataset).cuda() + + (model_params, motion_params, model_mouth_params, motion_mouth_params) = torch.load(os.path.join(dataset.model_path, "chkpnt_fuse_latest.pth")) + motion_net.load_state_dict(motion_params, strict=False) + gaussians.restore(model_params, None) + + motion_net_mouth.load_state_dict(motion_mouth_params, strict=False) + gaussians_mouth.restore(model_mouth_params, None) + + + # motion_net.fix(gaussians.get_xyz.cuda()) + # motion_net_mouth.fix(gaussians_mouth.get_xyz.cuda()) + + bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + render_set(dataset.model_path, "test" if not use_train else "train", scene.loaded_iter, scene.getTestCameras() if not use_train else scene.getTrainCameras(), gaussians, motion_net, gaussians_mouth, motion_net_mouth, pipeline, background, fast, dilate) + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Testing script parameters") + model = ModelParams(parser) + pipeline = PipelineParams(parser) + parser.add_argument("--iteration", default=-1, type=int) + parser.add_argument("--use_train", action="store_true") + parser.add_argument("--quiet", action="store_true") + parser.add_argument("--fast", action="store_true") + parser.add_argument("--dilate", action="store_true") + args = get_combined_args(parser) + print("Rendering " + args.model_path) + + # Initialize system state (RNG) + safe_state(args.quiet) + + render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.use_train, args.fast, args.dilate) diff --git a/torch-ngp/.gitignore b/torch-ngp/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..2f100d837ff6216c727061e45ffa72330e7bb335 --- /dev/null +++ b/torch-ngp/.gitignore @@ -0,0 +1,17 @@ +__pycache__/ +build/ +*.egg-info/ +*.so + +tmp* +data/ +trial*/ +.vs/ + +#**dnerf* +#dnerf/ +dnerf/network_rf.py +dnerf/network_basis_perpoint.py + +**neus* +neus/ \ No newline at end of file diff --git a/torch-ngp/.gitmodules b/torch-ngp/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..8b4407c1e9cf8f90ea2d588c8bfa3b8ef7c60c3d --- /dev/null +++ b/torch-ngp/.gitmodules @@ -0,0 +1,3 @@ +[submodule "ffmlp/dependencies/cutlass"] + path = ffmlp/dependencies/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/torch-ngp/LICENSE b/torch-ngp/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f565fb72319dab56d43bc896a87dadc93de1e1f8 --- /dev/null +++ b/torch-ngp/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 hawkey + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/torch-ngp/activation.py b/torch-ngp/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..17e8187e55d74e9c73fb7f7698d111f8b204fd35 --- /dev/null +++ b/torch-ngp/activation.py @@ -0,0 +1,18 @@ +import torch +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + +class _trunc_exp(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # cast to float32 + def forward(ctx, x): + ctx.save_for_backward(x) + return torch.exp(x) + + @staticmethod + @custom_bwd + def backward(ctx, g): + x = ctx.saved_tensors[0] + return g * torch.exp(x.clamp(-15, 15)) + +trunc_exp = _trunc_exp.apply \ No newline at end of file diff --git a/torch-ngp/assets/bg_model.jpg b/torch-ngp/assets/bg_model.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8d208b59af60aef331c01f863eba9090451f461d Binary files /dev/null and b/torch-ngp/assets/bg_model.jpg differ diff --git a/torch-ngp/assets/ccnerf.jpg b/torch-ngp/assets/ccnerf.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ce2dffec5849c2828b814b8e9e60c8f06433fe4c Binary files /dev/null and b/torch-ngp/assets/ccnerf.jpg differ diff --git a/torch-ngp/assets/fox.jpg b/torch-ngp/assets/fox.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4ec7aef7edff25cfad7fd40e90eee75de2bc6225 Binary files /dev/null and b/torch-ngp/assets/fox.jpg differ diff --git a/torch-ngp/assets/gallery.md b/torch-ngp/assets/gallery.md new file mode 100644 index 0000000000000000000000000000000000000000..ee2fac3564fa1853ff6f3c09ffd7b4c821303495 --- /dev/null +++ b/torch-ngp/assets/gallery.md @@ -0,0 +1,26 @@ +# Gallery + +## D-NeRF + +https://user-images.githubusercontent.com/25863658/175821784-63ba79f6-29be-47b5-b3fc-dab5282fce7a.mp4 + + +## Instant-ngp NeRF + +Fox: + +![fox](fox.jpg) + +LLFF: + +![llff](llff.jpg) + +Tanks&Temples: + +![truck](truck.jpg) + +## CCNeRF + +Composition example: + +![ccnerf](ccnerf.jpg) diff --git a/torch-ngp/assets/llff.jpg b/torch-ngp/assets/llff.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8cf768cfeb75d626992bbab185f059e34803ea94 Binary files /dev/null and b/torch-ngp/assets/llff.jpg differ diff --git a/torch-ngp/assets/truck.jpg b/torch-ngp/assets/truck.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fc502f02cccd1d4d7a8c58f51750d2682d367ce5 Binary files /dev/null and b/torch-ngp/assets/truck.jpg differ diff --git a/torch-ngp/assets/update_logs.md b/torch-ngp/assets/update_logs.md new file mode 100644 index 0000000000000000000000000000000000000000..d6c5dd5235e393b2d3351d65c8e55ac2a524935e --- /dev/null +++ b/torch-ngp/assets/update_logs.md @@ -0,0 +1,47 @@ +## Update logs + +* 7.28: support saving video at test. +* 7.26: add a CUDA-based freqencoder (though not used by default), add LPIPS metric. +* 7.16: add temporal basis based dynamic nerf (experimental). It trains much faster compared to the deformation based dynamic nerf, but performance is much worse for now... +* 6.29: add support for HyperNeRF's dataset. + * we use a simplified pinhole camera model, may introduce bias. +* 6.26: add support for D-NeRF. + * issue: to enable the `--cuda_ray` in a dynamic scene, we have to record different density grid for different time. This lead to much slower `update_extra_status` and much larger `density_grid` since there is an additional time dimension. Current work arounds: (1) only use 64 time intervals, (2) update it every 100 steps (compared to the 16 steps in static nerf), (3) stop updation after 100 times since the grid should be stable now. +* 6.16: add support for CCNeRF. +* 6.15: fixed a bug in raymarching, improved PSNR. Density thresh is directly applied on sigmas now (removed the empirical scaling factor). +* 6.6: fix gridencoder to always use more accurate float32 inputs (coords), slightly improved performance (matched with tcnn). +* 6.3: implement morton3D, misc improvements. +* 5.29: fix a random bg color issue, add color_space option, better results for blender dataset. +* 5.28: add a background model (set bg_radius > 0), which can suppress noises for real-world 360 datasets. +* 5.21: expose more parameters to control, implement packbits. +* 4.30: performance improvement (better lr_scheduler). +* 4.25: add Tanks&Temples dataset support. +* 4.18: add some experimental utils for random pose sampling and combined training with CLIP. +* 4.13: add LLFF dataset support. +* 4.13: also implmented tiled grid encoder according to this [issue](https://github.com/NVlabs/instant-ngp/issues/97). +* 4.12: optimized dataloader, add error_map sampling (experimental, will slow down training since will only sample hard rays...) +* 4.10: add Windows support. +* 4.9: use 6D AABB instead of a single `bound` for more flexible rendering. More options in GUI to control the AABB and `dt_gamma` for adaptive ray marching. +* 4.9: implemented multi-res density grid (cascade) and adaptive ray marching. Now the fox renders much faster! +* 4.6: fixed TensorCP hyper-parameters. +* 4.3: add `mark_untrained_grid` to prevent training on out-of-camera regions. Add custom dataset instructions. +* 3.31: better compatibility for lower pytorch versions. +* 3.29: fix training speed for the fox dataset (balanced speed with performance...). +* 3.27: major update. basically improve performance, and support tensoRF model. +* 3.22: reverted from pre-generating rays as it takes too much CPU memory, still the PSNR for Lego can reach ~33 now. +* 3.14: fixed the precision related issue for `fp16` mode, and it renders much better quality. Added PSNR metric for NeRF. +* 3.14: linearly scale `desired_resolution` with `bound` according to https://github.com/ashawkey/torch-ngp/issues/23. +* 3.11: raymarching now supports supervising weights_sum (pixel alpha, or mask) directly, and bg_color is separated from CUDA to make it more flexible. Add an option to preload data into GPU. +* 3.9: add fov for gui. +* 3.1: add type='all' for blender dataset (load train + val + test data), which is the default behavior of instant-ngp. +* 2.28: density_grid now stores density on the voxel center (with randomness), instead of on the grid. This should improve the rendering quality, such as the black strips in the lego scene. +* 2.23: better support for the blender dataset. +* 2.22: add GUI for NeRF training. +* 2.21: add GUI for NeRF visualizing. +* 2.20: cuda raymarching is finally stable now! +* 2.15: add the official [tinycudann](https://github.com/NVlabs/tiny-cuda-nn) as an alternative backend. +* 2.10: add cuda_ray, can train/infer faster, but performance is worse currently. +* 2.6: add support for RGBA image. +* 1.30: fixed atomicAdd() to use __half2 in HashGrid Encoder's backward, now the training speed with fp16 is as expected! +* 1.29: finished an experimental binding of fully-fused MLP. replace SHEncoder with a CUDA implementation. +* 1.26: add fp16 support for HashGrid Encoder (requires CUDA >= 10 and GPU ARCH >= 70 for now...). \ No newline at end of file diff --git a/torch-ngp/dnerf/gui.py b/torch-ngp/dnerf/gui.py new file mode 100644 index 0000000000000000000000000000000000000000..89c8c74299461d21583a8b05a63bb57362072d51 --- /dev/null +++ b/torch-ngp/dnerf/gui.py @@ -0,0 +1,444 @@ +import math +import torch +import numpy as np +import dearpygui.dearpygui as dpg +from scipy.spatial.transform import Rotation as R + +from .utils import * + + +class OrbitCamera: + def __init__(self, W, H, r=2, fovy=60): + self.W = W + self.H = H + self.radius = r # camera distance from center + self.fovy = fovy # in degree + self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point + self.rot = R.from_quat([1, 0, 0, 0]) # init camera matrix: [[1, 0, 0], [0, -1, 0], [0, 0, 1]] (to suit ngp convention) + self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized! + + # pose + @property + def pose(self): + # first move camera to radius + res = np.eye(4, dtype=np.float32) + res[2, 3] -= self.radius + # rotate + rot = np.eye(4, dtype=np.float32) + rot[:3, :3] = self.rot.as_matrix() + res = rot @ res + # translate + res[:3, 3] -= self.center + return res + + # intrinsics + @property + def intrinsics(self): + focal = self.H / (2 * np.tan(np.radians(self.fovy) / 2)) + return np.array([focal, focal, self.W // 2, self.H // 2]) + + def orbit(self, dx, dy): + # rotate along camera up/side axis! + side = self.rot.as_matrix()[:3, 0] # why this is side --> ? # already normalized. + rotvec_x = self.up * np.radians(-0.1 * dx) + rotvec_y = side * np.radians(-0.1 * dy) + self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot + + def scale(self, delta): + self.radius *= 1.1 ** (-delta) + + def pan(self, dx, dy, dz=0): + # pan in camera coordinate system (careful on the sensitivity!) + self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([dx, dy, dz]) + + +class NeRFGUI: + def __init__(self, opt, trainer, train_loader=None, debug=True): + self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. + self.W = opt.W + self.H = opt.H + self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy) + self.debug = debug + self.bg_color = torch.ones(3, dtype=torch.float32) # default white bg + self.training = False + self.step = 0 # training step + + self.trainer = trainer + self.train_loader = train_loader + if train_loader is not None: + self.trainer.error_map = train_loader._data.error_map + + self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32) + self.need_update = True # camera moved, should reset accumulation + self.spp = 1 # sample per pixel + self.mode = 'image' # choose from ['image', 'depth'] + self.time = 0 # time for dynamic scene, in [0, 1] + + self.dynamic_resolution = True + self.downscale = 1 + self.train_steps = 16 + + dpg.create_context() + self.register_dpg() + self.test_step() + + + def __del__(self): + dpg.destroy_context() + + + def train_step(self): + + starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + starter.record() + + outputs = self.trainer.train_gui(self.train_loader, step=self.train_steps) + + ender.record() + torch.cuda.synchronize() + t = starter.elapsed_time(ender) + + self.step += self.train_steps + self.need_update = True + + dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)') + dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}') + + # dynamic train steps + # max allowed train time per-frame is 500 ms + full_t = t / self.train_steps * 16 + train_steps = min(16, max(4, int(16 * 500 / full_t))) + if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8: + self.train_steps = train_steps + + def prepare_buffer(self, outputs): + if self.mode == 'image': + return outputs['image'] + else: + return np.expand_dims(outputs['depth'], -1).repeat(3, -1) + + + def test_step(self): + # TODO: seems we have to move data from GPU --> CPU --> GPU? + + if self.need_update or self.spp < self.opt.max_spp: + + starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + starter.record() + + outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, self.time, self.bg_color, self.spp, self.downscale) + + ender.record() + torch.cuda.synchronize() + t = starter.elapsed_time(ender) + + # update dynamic resolution + if self.dynamic_resolution: + # max allowed infer time per-frame is 200 ms + full_t = t / (self.downscale ** 2) + downscale = min(1, max(1/4, math.sqrt(200 / full_t))) + if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8: + self.downscale = downscale + + if self.need_update: + self.render_buffer = self.prepare_buffer(outputs) + self.spp = 1 + self.need_update = False + else: + self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1) + self.spp += 1 + + dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)') + dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}') + dpg.set_value("_log_spp", self.spp) + dpg.set_value("_texture", self.render_buffer) + + + def register_dpg(self): + + ### register texture + + with dpg.texture_registry(show=False): + dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture") + + ### register window + + # the rendered image, as the primary window + with dpg.window(tag="_primary_window", width=self.W, height=self.H): + + # add the texture + dpg.add_image("_texture") + + dpg.set_primary_window("_primary_window", True) + + # control window + with dpg.window(label="Control", tag="_control_window", width=400, height=300): + + # button theme + with dpg.theme() as theme_button: + with dpg.theme_component(dpg.mvButton): + dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18)) + dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47)) + dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83)) + dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5) + dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3) + + # time + if not self.opt.test: + with dpg.group(horizontal=True): + dpg.add_text("Train time: ") + dpg.add_text("no data", tag="_log_train_time") + + with dpg.group(horizontal=True): + dpg.add_text("Infer time: ") + dpg.add_text("no data", tag="_log_infer_time") + + with dpg.group(horizontal=True): + dpg.add_text("SPP: ") + dpg.add_text("1", tag="_log_spp") + + # train button + if not self.opt.test: + with dpg.collapsing_header(label="Train", default_open=True): + + # train / stop + with dpg.group(horizontal=True): + dpg.add_text("Train: ") + + def callback_train(sender, app_data): + if self.training: + self.training = False + dpg.configure_item("_button_train", label="start") + else: + self.training = True + dpg.configure_item("_button_train", label="stop") + + dpg.add_button(label="start", tag="_button_train", callback=callback_train) + dpg.bind_item_theme("_button_train", theme_button) + + def callback_reset(sender, app_data): + @torch.no_grad() + def weight_reset(m: nn.Module): + reset_parameters = getattr(m, "reset_parameters", None) + if callable(reset_parameters): + m.reset_parameters() + self.trainer.model.apply(fn=weight_reset) + self.trainer.model.reset_extra_state() # for cuda_ray density_grid and step_counter + self.need_update = True + + dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset) + dpg.bind_item_theme("_button_reset", theme_button) + + # save ckpt + with dpg.group(horizontal=True): + dpg.add_text("Checkpoint: ") + + def callback_save(sender, app_data): + self.trainer.save_checkpoint(full=True, best=False) + dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1])) + self.trainer.epoch += 1 # use epoch to indicate different calls. + + dpg.add_button(label="save", tag="_button_save", callback=callback_save) + dpg.bind_item_theme("_button_save", theme_button) + + dpg.add_text("", tag="_log_ckpt") + + # save mesh + with dpg.group(horizontal=True): + dpg.add_text("Marching Cubes: ") + + def callback_mesh(sender, app_data): + self.trainer.save_mesh(resolution=256, threshold=10) + dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply') + self.trainer.epoch += 1 # use epoch to indicate different calls. + + dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh) + dpg.bind_item_theme("_button_mesh", theme_button) + + dpg.add_text("", tag="_log_mesh") + + with dpg.group(horizontal=True): + dpg.add_text("", tag="_log_train_log") + + + # rendering options + with dpg.collapsing_header(label="Options", default_open=True): + + # dynamic rendering resolution + with dpg.group(horizontal=True): + + def callback_set_dynamic_resolution(sender, app_data): + if self.dynamic_resolution: + self.dynamic_resolution = False + self.downscale = 1 + else: + self.dynamic_resolution = True + self.need_update = True + + dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution) + dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution") + + # mode combo + def callback_change_mode(sender, app_data): + self.mode = app_data + self.need_update = True + + dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode) + + # time slider + def callback_set_time(sender, app_data): + self.time = app_data + self.need_update = True + + dpg.add_slider_float(label="time", min_value=0.0, max_value=1.0, format="%.5f", default_value=self.time, callback=callback_set_time) + + # bg_color picker + def callback_change_bg(sender, app_data): + self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32) # only need RGB in [0, 1] + self.need_update = True + + dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg) + + # fov slider + def callback_set_fovy(sender, app_data): + self.cam.fovy = app_data + self.need_update = True + + dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy) + + # dt_gamma slider + def callback_set_dt_gamma(sender, app_data): + self.opt.dt_gamma = app_data + self.need_update = True + + dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma) + + # max_steps slider + def callback_set_max_steps(sender, app_data): + self.opt.max_steps = app_data + self.need_update = True + + dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps) + + # aabb slider + def callback_set_aabb(sender, app_data, user_data): + # user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax) + self.trainer.model.aabb_infer[user_data] = app_data + + # also change train aabb ? [better not...] + #self.trainer.model.aabb_train[user_data] = app_data + + self.need_update = True + + dpg.add_separator() + dpg.add_text("Axis-aligned bounding box:") + + with dpg.group(horizontal=True): + dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0) + dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3) + + with dpg.group(horizontal=True): + dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1) + dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4) + + with dpg.group(horizontal=True): + dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2) + dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5) + + + # debug info + if self.debug: + with dpg.collapsing_header(label="Debug"): + # pose + dpg.add_separator() + dpg.add_text("Camera Pose:") + dpg.add_text(str(self.cam.pose), tag="_log_pose") + + + ### register camera handler + + def callback_camera_drag_rotate(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + dx = app_data[1] + dy = app_data[2] + + self.cam.orbit(dx, dy) + self.need_update = True + + if self.debug: + dpg.set_value("_log_pose", str(self.cam.pose)) + + + def callback_camera_wheel_scale(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + delta = app_data + + self.cam.scale(delta) + self.need_update = True + + if self.debug: + dpg.set_value("_log_pose", str(self.cam.pose)) + + + def callback_camera_drag_pan(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + dx = app_data[1] + dy = app_data[2] + + self.cam.pan(dx, dy) + self.need_update = True + + if self.debug: + dpg.set_value("_log_pose", str(self.cam.pose)) + + + with dpg.handler_registry(): + dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate) + dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale) + dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan) + + + dpg.create_viewport(title='torch-ngp', width=self.W, height=self.H, resizable=False) + + # TODO: seems dearpygui doesn't support resizing texture... + # def callback_resize(sender, app_data): + # self.W = app_data[0] + # self.H = app_data[1] + # # how to reload texture ??? + + # dpg.set_viewport_resize_callback(callback_resize) + + ### global theme + with dpg.theme() as theme_no_padding: + with dpg.theme_component(dpg.mvAll): + # set all padding to 0 to avoid scroll bar + dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core) + dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core) + dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core) + + dpg.bind_item_theme("_primary_window", theme_no_padding) + + dpg.setup_dearpygui() + + #dpg.show_metrics() + + dpg.show_viewport() + + + def render(self): + + while dpg.is_dearpygui_running(): + # update texture every frame + if self.training: + self.train_step() + self.test_step() + dpg.render_dearpygui_frame() \ No newline at end of file diff --git a/torch-ngp/dnerf/network.py b/torch-ngp/dnerf/network.py new file mode 100644 index 0000000000000000000000000000000000000000..601d8c96fb42059d3995b05de08b2bfa22742ebc --- /dev/null +++ b/torch-ngp/dnerf/network.py @@ -0,0 +1,270 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from encoding import get_encoder +from activation import trunc_exp +from .renderer import NeRFRenderer + + +class NeRFNetwork(NeRFRenderer): + def __init__(self, + encoding="tiledgrid", + encoding_dir="sphere_harmonics", + encoding_time="frequency", + encoding_deform="frequency", # "hashgrid" seems worse + encoding_bg="hashgrid", + num_layers=2, + hidden_dim=64, + geo_feat_dim=15, + num_layers_color=3, + hidden_dim_color=64, + num_layers_bg=2, + hidden_dim_bg=64, + num_layers_deform=5, # a deeper MLP is very necessary for performance. + hidden_dim_deform=128, + bound=1, + **kwargs, + ): + super().__init__(bound, **kwargs) + + # deformation network + self.num_layers_deform = num_layers_deform + self.hidden_dim_deform = hidden_dim_deform + self.encoder_deform, self.in_dim_deform = get_encoder(encoding_deform, multires=10) + self.encoder_time, self.in_dim_time = get_encoder(encoding_time, input_dim=1, multires=6) + + + deform_net = [] + for l in range(num_layers_deform): + if l == 0: + in_dim = self.in_dim_deform + self.in_dim_time # grid dim + time + else: + in_dim = hidden_dim_deform + + if l == num_layers_deform - 1: + out_dim = 3 # deformation for xyz + else: + out_dim = hidden_dim_deform + + deform_net.append(nn.Linear(in_dim, out_dim, bias=False)) + + self.deform_net = nn.ModuleList(deform_net) + + + # sigma network + self.num_layers = num_layers + self.hidden_dim = hidden_dim + self.geo_feat_dim = geo_feat_dim + self.encoder, self.in_dim = get_encoder(encoding, desired_resolution=2048 * bound) + + sigma_net = [] + for l in range(num_layers): + if l == 0: + in_dim = self.in_dim + self.in_dim_time + self.in_dim_deform # concat everything + else: + in_dim = hidden_dim + + if l == num_layers - 1: + out_dim = 1 + self.geo_feat_dim # 1 sigma + features for color + else: + out_dim = hidden_dim + + sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) + + self.sigma_net = nn.ModuleList(sigma_net) + + # color network + self.num_layers_color = num_layers_color + self.hidden_dim_color = hidden_dim_color + self.encoder_dir, self.in_dim_dir = get_encoder(encoding_dir) + + color_net = [] + for l in range(num_layers_color): + if l == 0: + in_dim = self.in_dim_dir + self.geo_feat_dim + else: + in_dim = hidden_dim_color + + if l == num_layers_color - 1: + out_dim = 3 # 3 rgb + else: + out_dim = hidden_dim_color + + color_net.append(nn.Linear(in_dim, out_dim, bias=False)) + + self.color_net = nn.ModuleList(color_net) + + # background network + if self.bg_radius > 0: + self.num_layers_bg = num_layers_bg + self.hidden_dim_bg = hidden_dim_bg + self.encoder_bg, self.in_dim_bg = get_encoder(encoding_bg, input_dim=2, num_levels=4, log2_hashmap_size=19, desired_resolution=2048) # much smaller hashgrid + + bg_net = [] + for l in range(num_layers_bg): + if l == 0: + in_dim = self.in_dim_bg + self.in_dim_dir + else: + in_dim = hidden_dim_bg + + if l == num_layers_bg - 1: + out_dim = 3 # 3 rgb + else: + out_dim = hidden_dim_bg + + bg_net.append(nn.Linear(in_dim, out_dim, bias=False)) + + self.bg_net = nn.ModuleList(bg_net) + else: + self.bg_net = None + + + def forward(self, x, d, t): + # x: [N, 3], in [-bound, bound] + # d: [N, 3], nomalized in [-1, 1] + # t: [1, 1], in [0, 1] + + # deform + enc_ori_x = self.encoder_deform(x, bound=self.bound) # [N, C] + enc_t = self.encoder_time(t) # [1, 1] --> [1, C'] + if enc_t.shape[0] == 1: + enc_t = enc_t.repeat(x.shape[0], 1) # [1, C'] --> [N, C'] + + deform = torch.cat([enc_ori_x, enc_t], dim=1) # [N, C + C'] + for l in range(self.num_layers_deform): + deform = self.deform_net[l](deform) + if l != self.num_layers_deform - 1: + deform = F.relu(deform, inplace=True) + + x = x + deform + + # sigma + x = self.encoder(x, bound=self.bound) + h = torch.cat([x, enc_ori_x, enc_t], dim=1) + for l in range(self.num_layers): + h = self.sigma_net[l](h) + if l != self.num_layers - 1: + h = F.relu(h, inplace=True) + + #sigma = F.relu(h[..., 0]) + sigma = trunc_exp(h[..., 0]) + geo_feat = h[..., 1:] + + # color + d = self.encoder_dir(d) + h = torch.cat([d, geo_feat], dim=-1) + for l in range(self.num_layers_color): + h = self.color_net[l](h) + if l != self.num_layers_color - 1: + h = F.relu(h, inplace=True) + + # sigmoid activation for rgb + rgbs = torch.sigmoid(h) + + return sigma, rgbs, deform + + def density(self, x, t): + # x: [N, 3], in [-bound, bound] + # t: [1, 1], in [0, 1] + + results = {} + + # deformation + enc_ori_x = self.encoder_deform(x, bound=self.bound) # [N, C] + enc_t = self.encoder_time(t) # [1, 1] --> [1, C'] + if enc_t.shape[0] == 1: + enc_t = enc_t.repeat(x.shape[0], 1) # [1, C'] --> [N, C'] + + deform = torch.cat([enc_ori_x, enc_t], dim=1) # [N, C + C'] + for l in range(self.num_layers_deform): + deform = self.deform_net[l](deform) + if l != self.num_layers_deform - 1: + deform = F.relu(deform, inplace=True) + + x = x + deform + results['deform'] = deform + + # sigma + x = self.encoder(x, bound=self.bound) + h = torch.cat([x, enc_ori_x, enc_t], dim=1) + for l in range(self.num_layers): + h = self.sigma_net[l](h) + if l != self.num_layers - 1: + h = F.relu(h, inplace=True) + + #sigma = F.relu(h[..., 0]) + sigma = trunc_exp(h[..., 0]) + geo_feat = h[..., 1:] + + results['sigma'] = sigma + results['geo_feat'] = geo_feat + + return results + + def background(self, x, d): + # x: [N, 2], in [-1, 1] + + h = self.encoder_bg(x) # [N, C] + d = self.encoder_dir(d) + + h = torch.cat([d, h], dim=-1) + for l in range(self.num_layers_bg): + h = self.bg_net[l](h) + if l != self.num_layers_bg - 1: + h = F.relu(h, inplace=True) + + # sigmoid activation for rgb + rgbs = torch.sigmoid(h) + + return rgbs + + # allow masked inference + def color(self, x, d, mask=None, geo_feat=None, **kwargs): + # x: [N, 3] in [-bound, bound] + # t: [1, 1], in [0, 1] + # mask: [N,], bool, indicates where we actually needs to compute rgb. + + if mask is not None: + rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] + # in case of empty mask + if not mask.any(): + return rgbs + x = x[mask] + d = d[mask] + geo_feat = geo_feat[mask] + + d = self.encoder_dir(d) + h = torch.cat([d, geo_feat], dim=-1) + for l in range(self.num_layers_color): + h = self.color_net[l](h) + if l != self.num_layers_color - 1: + h = F.relu(h, inplace=True) + + # sigmoid activation for rgb + h = torch.sigmoid(h) + + if mask is not None: + rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 + else: + rgbs = h + + return rgbs + + # optimizer utils + def get_params(self, lr, lr_net): + + params = [ + {'params': self.encoder.parameters(), 'lr': lr}, + {'params': self.sigma_net.parameters(), 'lr': lr_net}, + {'params': self.encoder_dir.parameters(), 'lr': lr}, + {'params': self.color_net.parameters(), 'lr': lr_net}, + {'params': self.encoder_deform.parameters(), 'lr': lr}, + {'params': self.encoder_time.parameters(), 'lr': lr}, + {'params': self.deform_net.parameters(), 'lr': lr_net}, + ] + if self.bg_radius > 0: + params.append({'params': self.encoder_bg.parameters(), 'lr': lr}) + params.append({'params': self.bg_net.parameters(), 'lr': lr_net}) + + return params diff --git a/torch-ngp/dnerf/network_basis.py b/torch-ngp/dnerf/network_basis.py new file mode 100644 index 0000000000000000000000000000000000000000..133a03a3d55786ae4740d5d532ec8c65ba0887c5 --- /dev/null +++ b/torch-ngp/dnerf/network_basis.py @@ -0,0 +1,262 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from encoding import get_encoder +from activation import trunc_exp +from .renderer import NeRFRenderer + + +class NeRFNetwork(NeRFRenderer): + def __init__(self, + encoding="tiledgrid", + encoding_dir="sphere_harmonics", + encoding_time="frequency", + encoding_bg="hashgrid", + num_layers=2, + hidden_dim=64, + geo_feat_dim=32, + num_layers_color=3, + hidden_dim_color=64, + num_layers_bg=2, + hidden_dim_bg=64, + sigma_basis_dim=32, + color_basis_dim=8, + num_layers_basis=5, + hidden_dim_basis=128, + bound=1, + **kwargs, + ): + super().__init__(bound, **kwargs) + + # basis network + self.num_layers_basis = num_layers_basis + self.hidden_dim_basis = hidden_dim_basis + self.sigma_basis_dim = sigma_basis_dim + self.color_basis_dim = color_basis_dim + self.encoder_time, self.in_dim_time = get_encoder(encoding_time, input_dim=1, multires=6) + + basis_net = [] + for l in range(num_layers_basis): + if l == 0: + in_dim = self.in_dim_time + else: + in_dim = hidden_dim_basis + + if l == num_layers_basis - 1: + out_dim = self.sigma_basis_dim + self.color_basis_dim + else: + out_dim = hidden_dim_basis + + basis_net.append(nn.Linear(in_dim, out_dim, bias=False)) + + self.basis_net = nn.ModuleList(basis_net) + + # sigma network + self.num_layers = num_layers + self.hidden_dim = hidden_dim + self.geo_feat_dim = geo_feat_dim + self.encoder, self.in_dim = get_encoder(encoding, desired_resolution=2048 * bound) + + sigma_net = [] + for l in range(num_layers): + if l == 0: + in_dim = self.in_dim + else: + in_dim = hidden_dim + + if l == num_layers - 1: + out_dim = self.sigma_basis_dim + self.geo_feat_dim # SB sigma + features for color + else: + out_dim = hidden_dim + + sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) + + self.sigma_net = nn.ModuleList(sigma_net) + + # color network + self.num_layers_color = num_layers_color + self.hidden_dim_color = hidden_dim_color + self.encoder_dir, self.in_dim_dir = get_encoder(encoding_dir) + + color_net = [] + for l in range(num_layers_color): + if l == 0: + in_dim = self.in_dim_dir + self.geo_feat_dim + else: + in_dim = hidden_dim_color + + if l == num_layers_color - 1: + out_dim = 3 * self.color_basis_dim # 3 * CB rgb + else: + out_dim = hidden_dim_color + + color_net.append(nn.Linear(in_dim, out_dim, bias=False)) + + self.color_net = nn.ModuleList(color_net) + + # background network + if self.bg_radius > 0: + self.num_layers_bg = num_layers_bg + self.hidden_dim_bg = hidden_dim_bg + self.encoder_bg, self.in_dim_bg = get_encoder(encoding_bg, input_dim=2, num_levels=4, log2_hashmap_size=19, desired_resolution=2048) # much smaller hashgrid + + bg_net = [] + for l in range(num_layers_bg): + if l == 0: + in_dim = self.in_dim_bg + self.in_dim_dir + else: + in_dim = hidden_dim_bg + + if l == num_layers_bg - 1: + out_dim = 3 # 3 rgb + else: + out_dim = hidden_dim_bg + + bg_net.append(nn.Linear(in_dim, out_dim, bias=False)) + + self.bg_net = nn.ModuleList(bg_net) + else: + self.bg_net = None + + + def forward(self, x, d, t): + # x: [N, 3], in [-bound, bound] + # d: [N, 3], nomalized in [-1, 1] + # t: [1, 1], in [0, 1] + + # time --> basis + enc_t = self.encoder_time(t) # [1, 1] --> [1, C'] + h = enc_t + for l in range(self.num_layers_basis): + h = self.basis_net[l](h) + if l != self.num_layers_basis - 1: + h = F.relu(h, inplace=True) + + sigma_basis = h[0, :self.sigma_basis_dim] + color_basis = h[0, self.sigma_basis_dim:] + + # sigma + x = self.encoder(x, bound=self.bound) + h = x + for l in range(self.num_layers): + h = self.sigma_net[l](h) + if l != self.num_layers - 1: + h = F.relu(h, inplace=True) + + sigma = trunc_exp(h[..., :self.sigma_basis_dim] @ sigma_basis) + geo_feat = h[..., self.sigma_basis_dim:] + + # color + d = self.encoder_dir(d) + h = torch.cat([d, geo_feat], dim=-1) + for l in range(self.num_layers_color): + h = self.color_net[l](h) + if l != self.num_layers_color - 1: + h = F.relu(h, inplace=True) + + # sigmoid activation for rgb + rgbs = torch.sigmoid(h.view(-1, 3, self.color_basis_dim) @ color_basis) + + return sigma, rgbs, None + + def density(self, x, t): + # x: [N, 3], in [-bound, bound] + # t: [1, 1], in [0, 1] + + results = {} + + # time --> basis + enc_t = self.encoder_time(t) # [1, 1] --> [1, C'] + h = enc_t + for l in range(self.num_layers_basis): + h = self.basis_net[l](h) + if l != self.num_layers_basis - 1: + h = F.relu(h, inplace=True) + + sigma_basis = h[0, :self.sigma_basis_dim] + color_basis = h[0, self.sigma_basis_dim:] + + # sigma + x = self.encoder(x, bound=self.bound) + h = x + for l in range(self.num_layers): + h = self.sigma_net[l](h) + if l != self.num_layers - 1: + h = F.relu(h, inplace=True) + + sigma = trunc_exp(h[..., :self.sigma_basis_dim] @ sigma_basis) + geo_feat = h[..., self.sigma_basis_dim:] + + results['sigma'] = sigma + results['geo_feat'] = geo_feat + # results['color_basis'] = color_basis + + return results + + def background(self, x, d): + # x: [N, 2], in [-1, 1] + + h = self.encoder_bg(x) # [N, C] + d = self.encoder_dir(d) + + h = torch.cat([d, h], dim=-1) + for l in range(self.num_layers_bg): + h = self.bg_net[l](h) + if l != self.num_layers_bg - 1: + h = F.relu(h, inplace=True) + + # sigmoid activation for rgb + rgbs = torch.sigmoid(h) + + return rgbs + + # TODO: non cuda-ray mode is broken for now... (how to pass color_basis to self.color()) + # # allow masked inference + # def color(self, x, d, mask=None, geo_feat=None, **kwargs): + # # x: [N, 3] in [-bound, bound] + # # t: [1, 1], in [0, 1] + # # mask: [N,], bool, indicates where we actually needs to compute rgb. + + # if mask is not None: + # rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] + # # in case of empty mask + # if not mask.any(): + # return rgbs + # x = x[mask] + # d = d[mask] + # geo_feat = geo_feat[mask] + + # d = self.encoder_dir(d) + # h = torch.cat([d, geo_feat], dim=-1) + # for l in range(self.num_layers_color): + # h = self.color_net[l](h) + # if l != self.num_layers_color - 1: + # h = F.relu(h, inplace=True) + + # # sigmoid activation for rgb + # h = torch.sigmoid(h) + + # if mask is not None: + # rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 + # else: + # rgbs = h + + # return rgbs + + # optimizer utils + def get_params(self, lr, lr_net): + + params = [ + {'params': self.encoder.parameters(), 'lr': lr}, + {'params': self.sigma_net.parameters(), 'lr': lr_net}, + {'params': self.encoder_dir.parameters(), 'lr': lr}, + {'params': self.color_net.parameters(), 'lr': lr_net}, + {'params': self.encoder_time.parameters(), 'lr': lr}, + {'params': self.basis_net.parameters(), 'lr': lr_net}, + ] + if self.bg_radius > 0: + params.append({'params': self.encoder_bg.parameters(), 'lr': lr}) + params.append({'params': self.bg_net.parameters(), 'lr': lr_net}) + + return params diff --git a/torch-ngp/dnerf/network_hyper.py b/torch-ngp/dnerf/network_hyper.py new file mode 100644 index 0000000000000000000000000000000000000000..fb4abba44e42dad6db764677ae830d6525ddd3b0 --- /dev/null +++ b/torch-ngp/dnerf/network_hyper.py @@ -0,0 +1,261 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from encoding import get_encoder +from activation import trunc_exp +from .renderer import NeRFRenderer + + +class NeRFNetwork(NeRFRenderer): + def __init__(self, + encoding="tiledgrid", + encoding_dir="sphere_harmonics", + encoding_time="frequency", + encoding_bg="hashgrid", + num_layers=2, + hidden_dim=64, + geo_feat_dim=32, + num_layers_color=3, + hidden_dim_color=64, + num_layers_bg=2, + hidden_dim_bg=64, + num_layers_ambient=5, + hidden_dim_ambient=128, + ambient_dim=1, + bound=1, + **kwargs, + ): + super().__init__(bound, **kwargs) + + # ambient network + self.num_layers_ambient = num_layers_ambient + self.hidden_dim_ambient = hidden_dim_ambient + self.ambient_dim = ambient_dim + self.encoder_time, self.in_dim_time = get_encoder(encoding_time, input_dim=1, multires=6) + + ambient_net = [] + for l in range(num_layers_ambient): + if l == 0: + in_dim = self.in_dim_time + else: + in_dim = hidden_dim_ambient + + if l == num_layers_ambient - 1: + out_dim = self.ambient_dim + else: + out_dim = hidden_dim_ambient + + ambient_net.append(nn.Linear(in_dim, out_dim, bias=False)) + + self.ambient_net = nn.ModuleList(ambient_net) + + # sigma network + self.num_layers = num_layers + self.hidden_dim = hidden_dim + self.geo_feat_dim = geo_feat_dim + self.encoder, self.in_dim = get_encoder(encoding, input_dim=3+self.ambient_dim, desired_resolution=2048 * bound) + + sigma_net = [] + for l in range(num_layers): + if l == 0: + in_dim = self.in_dim + else: + in_dim = hidden_dim + + if l == num_layers - 1: + out_dim = 1 + self.geo_feat_dim # 1 sigma + features for color + else: + out_dim = hidden_dim + + sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) + + self.sigma_net = nn.ModuleList(sigma_net) + + # color network + self.num_layers_color = num_layers_color + self.hidden_dim_color = hidden_dim_color + self.encoder_dir, self.in_dim_dir = get_encoder(encoding_dir) + + color_net = [] + for l in range(num_layers_color): + if l == 0: + in_dim = self.in_dim_dir + self.geo_feat_dim + else: + in_dim = hidden_dim_color + + if l == num_layers_color - 1: + out_dim = 3 # 3 rgb + else: + out_dim = hidden_dim_color + + color_net.append(nn.Linear(in_dim, out_dim, bias=False)) + + self.color_net = nn.ModuleList(color_net) + + # background network + if self.bg_radius > 0: + self.num_layers_bg = num_layers_bg + self.hidden_dim_bg = hidden_dim_bg + self.encoder_bg, self.in_dim_bg = get_encoder(encoding_bg, input_dim=2, num_levels=4, log2_hashmap_size=19, desired_resolution=2048) # much smaller hashgrid + + bg_net = [] + for l in range(num_layers_bg): + if l == 0: + in_dim = self.in_dim_bg + self.in_dim_dir + else: + in_dim = hidden_dim_bg + + if l == num_layers_bg - 1: + out_dim = 3 # 3 rgb + else: + out_dim = hidden_dim_bg + + bg_net.append(nn.Linear(in_dim, out_dim, bias=False)) + + self.bg_net = nn.ModuleList(bg_net) + else: + self.bg_net = None + + + def forward(self, x, d, t): + # x: [N, 3], in [-bound, bound] + # d: [N, 3], nomalized in [-1, 1] + # t: [1, 1], in [0, 1] + + # time --> ambient + enc_t = self.encoder_time(t) # [1, 1] --> [1, C'] + # if enc_t.shape[0] == 1: + # enc_t = enc_t.repeat(x.shape[0], 1) # [1, C'] --> [N, C'] + ambient = enc_t + for l in range(self.num_layers_ambient): + ambient = self.ambient_net[l](ambient) + if l != self.num_layers_ambient - 1: + ambient = F.relu(ambient, inplace=True) + + ambient = F.tanh(ambient) * self.bound + x = torch.cat([x, ambient.repeat(x.shape[0], 1)], dim=1) + + # sigma + x = self.encoder(x, bound=self.bound) + h = x + for l in range(self.num_layers): + h = self.sigma_net[l](h) + if l != self.num_layers - 1: + h = F.relu(h, inplace=True) + + sigma = trunc_exp(h[..., 0]) + geo_feat = h[..., 1:] + + # color + d = self.encoder_dir(d) + h = torch.cat([d, geo_feat], dim=-1) + for l in range(self.num_layers_color): + h = self.color_net[l](h) + if l != self.num_layers_color - 1: + h = F.relu(h, inplace=True) + + # sigmoid activation for rgb + rgbs = torch.sigmoid(h) + + return sigma, rgbs, None + + def density(self, x, t): + # x: [N, 3], in [-bound, bound] + # t: [1, 1], in [0, 1] + + results = {} + + # time --> ambient + enc_t = self.encoder_time(t) # [1, 1] --> [1, C'] + ambient = enc_t + for l in range(self.num_layers_ambient): + ambient = self.ambient_net[l](ambient) + if l != self.num_layers_ambient - 1: + ambient = F.relu(ambient, inplace=True) + + ambient = F.tanh(ambient) * self.bound + x = torch.cat([x, ambient.repeat(x.shape[0], 1)], dim=1) + + # sigma + x = self.encoder(x, bound=self.bound) + h = x + for l in range(self.num_layers): + h = self.sigma_net[l](h) + if l != self.num_layers - 1: + h = F.relu(h, inplace=True) + + sigma = trunc_exp(h[..., 0]) + geo_feat = h[..., 1:] + + results['sigma'] = sigma + results['geo_feat'] = geo_feat + + return results + + def background(self, x, d): + # x: [N, 2], in [-1, 1] + + h = self.encoder_bg(x) # [N, C] + d = self.encoder_dir(d) + + h = torch.cat([d, h], dim=-1) + for l in range(self.num_layers_bg): + h = self.bg_net[l](h) + if l != self.num_layers_bg - 1: + h = F.relu(h, inplace=True) + + # sigmoid activation for rgb + rgbs = torch.sigmoid(h) + + return rgbs + + + # allow masked inference + def color(self, x, d, mask=None, geo_feat=None, **kwargs): + # x: [N, 3] in [-bound, bound] + # t: [1, 1], in [0, 1] + # mask: [N,], bool, indicates where we actually needs to compute rgb. + + if mask is not None: + rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] + # in case of empty mask + if not mask.any(): + return rgbs + x = x[mask] + d = d[mask] + geo_feat = geo_feat[mask] + + d = self.encoder_dir(d) + h = torch.cat([d, geo_feat], dim=-1) + for l in range(self.num_layers_color): + h = self.color_net[l](h) + if l != self.num_layers_color - 1: + h = F.relu(h, inplace=True) + + # sigmoid activation for rgb + h = torch.sigmoid(h) + + if mask is not None: + rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 + else: + rgbs = h + + return rgbs + + # optimizer utils + def get_params(self, lr, lr_net): + + params = [ + {'params': self.encoder.parameters(), 'lr': lr}, + {'params': self.sigma_net.parameters(), 'lr': lr_net}, + {'params': self.encoder_dir.parameters(), 'lr': lr}, + {'params': self.color_net.parameters(), 'lr': lr_net}, + {'params': self.encoder_time.parameters(), 'lr': lr}, + {'params': self.ambient_net.parameters(), 'lr': lr_net}, + ] + if self.bg_radius > 0: + params.append({'params': self.encoder_bg.parameters(), 'lr': lr}) + params.append({'params': self.bg_net.parameters(), 'lr': lr_net}) + + return params diff --git a/torch-ngp/dnerf/provider.py b/torch-ngp/dnerf/provider.py new file mode 100644 index 0000000000000000000000000000000000000000..3824c96f3b1c1d89d26c86dab67c9e8f04cc0899 --- /dev/null +++ b/torch-ngp/dnerf/provider.py @@ -0,0 +1,361 @@ +import os +import cv2 +import glob +import json +import tqdm +import numpy as np +from scipy.spatial.transform import Slerp, Rotation + +import trimesh + +import torch +from torch.utils.data import DataLoader + +from .utils import get_rays, srgb_to_linear + + +# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50 +def nerf_matrix_to_ngp(pose, scale=0.33, offset=[0, 0, 0]): + # for the fox dataset, 0.33 scales camera radius to ~ 2 + new_pose = np.array([ + [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale + offset[0]], + [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale + offset[1]], + [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale + offset[2]], + [0, 0, 0, 1], + ], dtype=np.float32) + return new_pose + + +def visualize_poses(poses, size=0.1): + # poses: [B, 4, 4] + + axes = trimesh.creation.axis(axis_length=4) + box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline() + box.colors = np.array([[128, 128, 128]] * len(box.entities)) + objects = [axes, box] + + for pose in poses: + # a camera is visualized with 8 line segments. + pos = pose[:3, 3] + a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] + b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] + c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] + d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] + + dir = (a + b + c + d) / 4 - pos + dir = dir / (np.linalg.norm(dir) + 1e-8) + o = pos + dir * 3 + + segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a], [pos, o]]) + segs = trimesh.load_path(segs) + objects.append(segs) + + trimesh.Scene(objects).show() + + +def rand_poses(size, device, radius=1, theta_range=[np.pi/3, 2*np.pi/3], phi_range=[0, 2*np.pi]): + ''' generate random poses from an orbit camera + Args: + size: batch size of generated poses. + device: where to allocate the output. + radius: camera radius + theta_range: [min, max], should be in [0, \pi] + phi_range: [min, max], should be in [0, 2\pi] + Return: + poses: [size, 4, 4] + ''' + + def normalize(vectors): + return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10) + + thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0] + phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0] + + centers = torch.stack([ + radius * torch.sin(thetas) * torch.sin(phis), + radius * torch.cos(thetas), + radius * torch.sin(thetas) * torch.cos(phis), + ], dim=-1) # [B, 3] + + # lookat + forward_vector = - normalize(centers) + up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1) # confused at the coordinate system... + right_vector = normalize(torch.cross(forward_vector, up_vector, dim=-1)) + up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1)) + + poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1) + poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1) + poses[:, :3, 3] = centers + + return poses + + +class NeRFDataset: + def __init__(self, opt, device, type='train', downscale=1, n_test=10): + super().__init__() + + self.opt = opt + self.device = device + self.type = type # train, val, test + self.downscale = downscale + self.root_path = opt.path + self.preload = opt.preload # preload data into GPU + self.scale = opt.scale # camera radius scale to make sure camera are inside the bounding box. + self.offset = opt.offset # camera offset + self.bound = opt.bound # bounding box half length, also used as the radius to random sample poses. + self.fp16 = opt.fp16 # if preload, load into fp16. + + self.training = self.type in ['train', 'all', 'trainval'] + self.num_rays = self.opt.num_rays if self.training else -1 + + self.rand_pose = opt.rand_pose + + # auto-detect transforms.json and split mode. + if os.path.exists(os.path.join(self.root_path, 'transforms.json')): + self.mode = 'colmap' # manually split, use view-interpolation for test. + elif os.path.exists(os.path.join(self.root_path, 'transforms_train.json')): + self.mode = 'blender' # provided split + else: + raise NotImplementedError(f'[NeRFDataset] Cannot find transforms*.json under {self.root_path}') + + # load nerf-compatible format data. + if self.mode == 'colmap': + with open(os.path.join(self.root_path, 'transforms.json'), 'r') as f: + transform = json.load(f) + elif self.mode == 'blender': + # load all splits (train/valid/test), this is what instant-ngp in fact does... + if type == 'all': + transform_paths = glob.glob(os.path.join(self.root_path, '*.json')) + transform = None + for transform_path in transform_paths: + with open(transform_path, 'r') as f: + tmp_transform = json.load(f) + if transform is None: + transform = tmp_transform + else: + transform['frames'].extend(tmp_transform['frames']) + # load train and val split + elif type == 'trainval': + with open(os.path.join(self.root_path, f'transforms_train.json'), 'r') as f: + transform = json.load(f) + with open(os.path.join(self.root_path, f'transforms_val.json'), 'r') as f: + transform_val = json.load(f) + transform['frames'].extend(transform_val['frames']) + # only load one specified split + else: + with open(os.path.join(self.root_path, f'transforms_{type}.json'), 'r') as f: + transform = json.load(f) + + else: + raise NotImplementedError(f'unknown dataset mode: {self.mode}') + + # load image size + if 'h' in transform and 'w' in transform: + self.H = int(transform['h']) // downscale + self.W = int(transform['w']) // downscale + else: + # we have to actually read an image to get H and W later. + self.H = self.W = None + + # read images + frames = transform["frames"] + #frames = sorted(frames, key=lambda d: d['file_path']) # why do I sort... + + # for colmap, manually interpolate a test set. + if self.mode == 'colmap' and type == 'test': + + # choose two random poses, and interpolate between. + f0, f1 = np.random.choice(frames, 2, replace=False) + pose0 = nerf_matrix_to_ngp(np.array(f0['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4] + pose1 = nerf_matrix_to_ngp(np.array(f1['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4] + time0 = f0['time'] if 'time' in f0 else int(os.path.basename(f0['file_path'])[:-4]) + time1 = f1['time'] if 'time' in f1 else int(os.path.basename(f1['file_path'])[:-4]) + rots = Rotation.from_matrix(np.stack([pose0[:3, :3], pose1[:3, :3]])) + slerp = Slerp([0, 1], rots) + + self.poses = [] + self.images = None + self.times = [] + for i in range(n_test + 1): + ratio = np.sin(((i / n_test) - 0.5) * np.pi) * 0.5 + 0.5 + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = slerp(ratio).as_matrix() + pose[:3, 3] = (1 - ratio) * pose0[:3, 3] + ratio * pose1[:3, 3] + self.poses.append(pose) + time = (1 - ratio) * time0 + ratio * time1 + self.times.append(time) + + # manually find max time to normalize + if 'time' not in f0: + max_time = 0 + for f in frames: + max_time = max(max_time, int(os.path.basename(f['file_path'])[:-4])) + self.times = [t / max_time for t in self.times] + + else: + # for colmap, manually split a valid set (the first frame). + if self.mode == 'colmap': + if type == 'train': + frames = frames[1:] + elif type == 'val': + frames = frames[:1] + # else 'all' or 'trainval' : use all frames + + self.poses = [] + self.images = [] + self.times = [] + + # assume frames are already sorted by time! + for f in tqdm.tqdm(frames, desc=f'Loading {type} data'): + f_path = os.path.join(self.root_path, f['file_path']) + if self.mode == 'blender' and '.' not in os.path.basename(f_path): + f_path += '.png' # so silly... + + # there are non-exist paths in fox... + if not os.path.exists(f_path): + continue + + pose = np.array(f['transform_matrix'], dtype=np.float32) # [4, 4] + pose = nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset) + + image = cv2.imread(f_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] o [H, W, 4] + if self.H is None or self.W is None: + self.H = image.shape[0] // downscale + self.W = image.shape[1] // downscale + + # add support for the alpha channel as a mask. + if image.shape[-1] == 3: + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + else: + image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) + + if image.shape[0] != self.H or image.shape[1] != self.W: + image = cv2.resize(image, (self.W, self.H), interpolation=cv2.INTER_AREA) + + image = image.astype(np.float32) / 255 # [H, W, 3/4] + + # frame time + if 'time' in f: + time = f['time'] + else: + time = int(os.path.basename(f['file_path'])[:-4]) # assume frame index as time + + self.poses.append(pose) + self.images.append(image) + self.times.append(time) + + self.poses = torch.from_numpy(np.stack(self.poses, axis=0)) # [N, 4, 4] + if self.images is not None: + self.images = torch.from_numpy(np.stack(self.images, axis=0)) # [N, H, W, C] + self.times = torch.from_numpy(np.asarray(self.times, dtype=np.float32)).view(-1, 1) # [N, 1] + + # manual normalize + if self.times.max() > 1: + self.times = self.times / (self.times.max() + 1e-8) # normalize to [0, 1] + + # calculate mean radius of all camera poses + self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item() + #print(f'[INFO] dataset camera poses: radius = {self.radius:.4f}, bound = {self.bound}') + + # initialize error_map + if self.training and self.opt.error_map: + self.error_map = torch.ones([self.images.shape[0], 128 * 128], dtype=torch.float) # [B, 128 * 128], flattened for easy indexing, fixed resolution... + else: + self.error_map = None + + # [debug] uncomment to view all training poses. + # visualize_poses(self.poses.numpy()) + + # [debug] uncomment to view examples of randomly generated poses. + # visualize_poses(rand_poses(100, self.device, radius=self.radius).cpu().numpy()) + + if self.preload: + self.poses = self.poses.to(self.device) + if self.images is not None: + # TODO: linear use pow, but pow for half is only available for torch >= 1.10 ? + if self.fp16 and self.opt.color_space != 'linear': + dtype = torch.half + else: + dtype = torch.float + self.images = self.images.to(dtype).to(self.device) + if self.error_map is not None: + self.error_map = self.error_map.to(self.device) + self.times = self.times.to(self.device) + + # load intrinsics + if 'fl_x' in transform or 'fl_y' in transform: + fl_x = (transform['fl_x'] if 'fl_x' in transform else transform['fl_y']) / downscale + fl_y = (transform['fl_y'] if 'fl_y' in transform else transform['fl_x']) / downscale + elif 'camera_angle_x' in transform or 'camera_angle_y' in transform: + # blender, assert in radians. already downscaled since we use H/W + fl_x = self.W / (2 * np.tan(transform['camera_angle_x'] / 2)) if 'camera_angle_x' in transform else None + fl_y = self.H / (2 * np.tan(transform['camera_angle_y'] / 2)) if 'camera_angle_y' in transform else None + if fl_x is None: fl_x = fl_y + if fl_y is None: fl_y = fl_x + else: + raise RuntimeError('Failed to load focal length, please check the transforms.json!') + + cx = (transform['cx'] / downscale) if 'cx' in transform else (self.W / 2) + cy = (transform['cy'] / downscale) if 'cy' in transform else (self.H / 2) + + self.intrinsics = np.array([fl_x, fl_y, cx, cy]) + + + def collate(self, index): + + B = len(index) # a list of length 1 + + # random pose without gt images. + if self.rand_pose == 0 or index[0] >= len(self.poses): + + poses = rand_poses(B, self.device, radius=self.radius) + + # sample a low-resolution but full image for CLIP + s = np.sqrt(self.H * self.W / self.num_rays) # only in training, assert num_rays > 0 + rH, rW = int(self.H / s), int(self.W / s) + rays = get_rays(poses, self.intrinsics / s, rH, rW, -1) + + return { + 'H': rH, + 'W': rW, + 'rays_o': rays['rays_o'], + 'rays_d': rays['rays_d'], + } + + poses = self.poses[index].to(self.device) # [B, 4, 4] + times = self.times[index].to(self.device) # [B, 1] + + error_map = None if self.error_map is None else self.error_map[index] + + rays = get_rays(poses, self.intrinsics, self.H, self.W, self.num_rays, error_map) + + results = { + 'time': times, + 'H': self.H, + 'W': self.W, + 'rays_o': rays['rays_o'], + 'rays_d': rays['rays_d'], + } + + if self.images is not None: + images = self.images[index].to(self.device) # [B, H, W, 3/4] + if self.training: + C = images.shape[-1] + images = torch.gather(images.view(B, -1, C), 1, torch.stack(C * [rays['inds']], -1)) # [B, N, 3/4] + results['images'] = images + + # need inds to update error_map + if error_map is not None: + results['index'] = index + results['inds_coarse'] = rays['inds_coarse'] + + return results + + def dataloader(self): + size = len(self.poses) + if self.training and self.rand_pose > 0: + size += size // self.rand_pose # index >= size means we use random pose. + loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0) + loader._data = self # an ugly fix... we need to access error_map & poses in trainer. + loader.has_gt = self.images is not None + return loader \ No newline at end of file diff --git a/torch-ngp/dnerf/renderer.py b/torch-ngp/dnerf/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..2d4c6f2bc21a2081509d751de1de96acbfcb6dfb --- /dev/null +++ b/torch-ngp/dnerf/renderer.py @@ -0,0 +1,591 @@ +import math +import trimesh +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import raymarching +from .utils import custom_meshgrid + +def sample_pdf(bins, weights, n_samples, det=False): + # This implementation is from NeRF + # bins: [B, T], old_z_vals + # weights: [B, T - 1], bin weights. + # return: [B, n_samples], new_z_vals + + # Get pdf + weights = weights + 1e-5 # prevent nans + pdf = weights / torch.sum(weights, -1, keepdim=True) + cdf = torch.cumsum(pdf, -1) + cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) + # Take uniform samples + if det: + u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device) + u = u.expand(list(cdf.shape[:-1]) + [n_samples]) + else: + u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device) + + # Invert CDF + u = u.contiguous() + inds = torch.searchsorted(cdf, u, right=True) + below = torch.max(torch.zeros_like(inds - 1), inds - 1) + above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) + inds_g = torch.stack([below, above], -1) # (B, n_samples, 2) + + matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] + cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) + bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) + + denom = (cdf_g[..., 1] - cdf_g[..., 0]) + denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) + t = (u - cdf_g[..., 0]) / denom + samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) + + return samples + + +def plot_pointcloud(pc, color=None): + # pc: [N, 3] + # color: [N, 3/4] + print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0)) + pc = trimesh.PointCloud(pc, color) + # axis + axes = trimesh.creation.axis(axis_length=4) + # sphere + sphere = trimesh.creation.icosphere(radius=1) + trimesh.Scene([pc, axes, sphere]).show() + + +class NeRFRenderer(nn.Module): + def __init__(self, + bound=1, + cuda_ray=False, + density_scale=1, # scale up deltas (or sigmas), to make the density grid more sharp. larger value than 1 usually improves performance. + min_near=0.2, + density_thresh=0.01, + bg_radius=-1, + ): + super().__init__() + + self.bound = bound + self.cascade = 1 + math.ceil(math.log2(bound)) + self.time_size = 64 + self.grid_size = 128 + self.density_scale = density_scale + self.min_near = min_near + self.density_thresh = density_thresh + self.bg_radius = bg_radius # radius of the background sphere. + + # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax) + # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing. + aabb_train = torch.FloatTensor([-bound, -bound, -bound, bound, bound, bound]) + aabb_infer = aabb_train.clone() + self.register_buffer('aabb_train', aabb_train) + self.register_buffer('aabb_infer', aabb_infer) + + # extra state for cuda raymarching + self.cuda_ray = cuda_ray + if cuda_ray: + # density grid (with an extra time dimension) + density_grid = torch.zeros(self.time_size, self.cascade, self.grid_size ** 3) # [T, CAS, H * H * H] + density_bitfield = torch.zeros(self.time_size, self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [T, CAS * H * H * H // 8] + self.register_buffer('density_grid', density_grid) + self.register_buffer('density_bitfield', density_bitfield) + self.mean_density = 0 + self.iter_density = 0 + # time stamps for density grid + times = ((torch.arange(self.time_size, dtype=torch.float32) + 0.5) / self.time_size).view(-1, 1, 1) # [T, 1, 1] + self.register_buffer('times', times) + # step counter + step_counter = torch.zeros(16, 2, dtype=torch.int32) # 16 is hardcoded for averaging... + self.register_buffer('step_counter', step_counter) + self.mean_count = 0 + self.local_step = 0 + + def forward(self, x, d, t): + raise NotImplementedError() + + # separated density and color query (can accelerate non-cuda-ray mode.) + def density(self, x, t): + raise NotImplementedError() + + def color(self, x, d, t, mask=None, **kwargs): + raise NotImplementedError() + + def reset_extra_state(self): + if not self.cuda_ray: + return + # density grid + self.density_grid.zero_() + self.mean_density = 0 + self.iter_density = 0 + # step counter + self.step_counter.zero_() + self.mean_count = 0 + self.local_step = 0 + + def run(self, rays_o, rays_d, time, num_steps=128, upsample_steps=128, bg_color=None, perturb=False, **kwargs): + # rays_o, rays_d: [B, N, 3], assumes B == 1 + # time: [B, 1] + # bg_color: [3] in range [0, 1] + # return: image: [B, N, 3], depth: [B, N] + + prefix = rays_o.shape[:-1] + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # N = B * N, in fact + device = rays_o.device + + # choose aabb + aabb = self.aabb_train if self.training else self.aabb_infer + + # sample steps + nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near) + nears.unsqueeze_(-1) + fars.unsqueeze_(-1) + + #print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}') + + z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0) # [1, T] + z_vals = z_vals.expand((N, num_steps)) # [N, T] + z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars] + + # perturb z_vals + sample_dist = (fars - nears) / num_steps + if perturb: + z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist + #z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs. + + # generate xyzs + xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3] + xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip. + + #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy()) + + # query SDF and RGB + density_outputs = self.density(xyzs.reshape(-1, 3), time) + + #sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T] + for k, v in density_outputs.items(): + density_outputs[k] = v.view(N, num_steps, -1) + + # upsample z_vals (nerf-like) + if upsample_steps > 0: + with torch.no_grad(): + + deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1] + deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1) + + alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1)) # [N, T] + alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+1] + weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T] + + # sample new z_vals + z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1]) # [N, T-1] + new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach() # [N, t] + + new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3] + new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:]) # a manual clip. + + # only forward new points to save computation + new_density_outputs = self.density(new_xyzs.reshape(-1, 3), time) + #new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t] + for k, v in new_density_outputs.items(): + new_density_outputs[k] = v.view(N, upsample_steps, -1) + + # re-order + z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t] + z_vals, z_index = torch.sort(z_vals, dim=1) + + xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3] + xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs)) + + for k in density_outputs: + tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1) + density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output)) + + deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1] + deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1) + alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1)) # [N, T+t] + alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+t+1] + weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t] + + dirs = rays_d.view(-1, 1, 3).expand_as(xyzs) + for k, v in density_outputs.items(): + density_outputs[k] = v.view(-1, v.shape[-1]) + + mask = weights > 1e-4 # hard coded + rgbs = self.color(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), mask=mask.reshape(-1), **density_outputs) + rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3] + + #print(xyzs.shape, 'valid_rgb:', mask.sum().item()) + + # calculate weight_sum (mask) + weights_sum = weights.sum(dim=-1) # [N] + + # calculate depth + ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1) + depth = torch.sum(weights * ori_z_vals, dim=-1) + + # calculate color + image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1] + + # mix background color + if self.bg_radius > 0: + # use the bg model to calculate bg_color + sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1] + bg_color = self.background(sph, rays_d.reshape(-1, 3)) # [N, 3] + elif bg_color is None: + bg_color = 1 + + image = image + (1 - weights_sum).unsqueeze(-1) * bg_color + + image = image.view(*prefix, 3) + depth = depth.view(*prefix) + + # tmp: reg loss in mip-nerf 360 + # z_vals_shifted = torch.cat([z_vals[..., 1:], sample_dist * torch.ones_like(z_vals[..., :1])], dim=-1) + # mid_zs = (z_vals + z_vals_shifted) / 2 # [N, T] + # loss_dist = (torch.abs(mid_zs.unsqueeze(1) - mid_zs.unsqueeze(2)) * (weights.unsqueeze(1) * weights.unsqueeze(2))).sum() + 1/3 * ((z_vals_shifted - z_vals_shifted) * (weights ** 2)).sum() + + return { + 'depth': depth, + 'image': image, + 'deform': density_outputs['deform'], + } + + + def run_cuda(self, rays_o, rays_d, time, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, **kwargs): + # rays_o, rays_d: [B, N, 3], assumes B == 1 + # time: [B, 1], B == 1, so only one time is used. + # return: image: [B, N, 3], depth: [B, N] + + prefix = rays_o.shape[:-1] + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # N = B * N, in fact + device = rays_o.device + + # pre-calculate near far + nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near) + + # mix background color + if self.bg_radius > 0: + # use the bg model to calculate bg_color + sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1] + bg_color = self.background(sph, rays_d) # [N, 3] + elif bg_color is None: + bg_color = 1 + + # determine the correct frame of density grid to use + t = torch.floor(time[0][0] * self.time_size).clamp(min=0, max=self.time_size - 1).long() + + results = {} + + if self.training: + # setup counter + counter = self.step_counter[self.local_step % 16] + counter.zero_() # set to 0 + self.local_step += 1 + + xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield[t], self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps) + + #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy()) + + sigmas, rgbs, deform = self(xyzs, dirs, time) + # density_outputs = self.density(xyzs, time) # [M,], use a dict since it may include extra things, like geo_feat for rgb. + # sigmas = density_outputs['sigma'] + # rgbs = self.color(xyzs, dirs, **density_outputs) + sigmas = self.density_scale * sigmas + + #print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})') + + # special case for CCNeRF's residual learning + if len(sigmas.shape) == 2: + K = sigmas.shape[0] + depths = [] + images = [] + for k in range(K): + weights_sum, depth, image = raymarching.composite_rays_train(sigmas[k], rgbs[k], deltas, rays) + image = image + (1 - weights_sum).unsqueeze(-1) * bg_color + depth = torch.clamp(depth - nears, min=0) / (fars - nears) + images.append(image.view(*prefix, 3)) + depths.append(depth.view(*prefix)) + + depth = torch.stack(depths, axis=0) # [K, B, N] + image = torch.stack(images, axis=0) # [K, B, N, 3] + + else: + + weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays) + image = image + (1 - weights_sum).unsqueeze(-1) * bg_color + depth = torch.clamp(depth - nears, min=0) / (fars - nears) + image = image.view(*prefix, 3) + depth = depth.view(*prefix) + + results['deform'] = deform + + else: + + # allocate outputs + # if use autocast, must init as half so it won't be autocasted and lose reference. + #dtype = torch.half if torch.is_autocast_enabled() else torch.float32 + # output should always be float32! only network inference uses half. + dtype = torch.float32 + + weights_sum = torch.zeros(N, dtype=dtype, device=device) + depth = torch.zeros(N, dtype=dtype, device=device) + image = torch.zeros(N, 3, dtype=dtype, device=device) + + n_alive = N + rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N] + rays_t = nears.clone() # [N] + + step = 0 + + while step < max_steps: + + # count alive rays + n_alive = rays_alive.shape[0] + + # exit loop + if n_alive <= 0: + break + + # decide compact_steps + n_step = max(min(N // n_alive, 8), 1) + + xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield[t], self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps) + + sigmas, rgbs, _ = self(xyzs, dirs, time) + # density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb. + # sigmas = density_outputs['sigma'] + # rgbs = self.color(xyzs, dirs, **density_outputs) + sigmas = self.density_scale * sigmas + + raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image) + + rays_alive = rays_alive[rays_alive >= 0] + + #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}') + + step += n_step + + image = image + (1 - weights_sum).unsqueeze(-1) * bg_color + depth = torch.clamp(depth - nears, min=0) / (fars - nears) + image = image.view(*prefix, 3) + depth = depth.view(*prefix) + + results['depth'] = depth + results['image'] = image + + return results + + @torch.no_grad() + def mark_untrained_grid(self, poses, intrinsic, S=64): + # poses: [B, 4, 4] + # intrinsic: [3, 3] + + if not self.cuda_ray: + return + + if isinstance(poses, np.ndarray): + poses = torch.from_numpy(poses) + + B = poses.shape[0] + + fx, fy, cx, cy = intrinsic + + X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + + count = torch.zeros_like(self.density_grid[0]) + poses = poses.to(count.device) + + # 5-level loop, forgive me... + + for xs in X: + for ys in Y: + for zs in Z: + + # construct points + xx, yy, zz = custom_meshgrid(xs, ys, zs) + coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128) + indices = raymarching.morton3D(coords).long() # [N] + world_xyzs = (2 * coords.float() / (self.grid_size - 1) - 1).unsqueeze(0) # [1, N, 3] in [-1, 1] + + # cascading + for cas in range(self.cascade): + bound = min(2 ** cas, self.bound) + half_grid_size = bound / self.grid_size + # scale to current cascade's resolution + cas_world_xyzs = world_xyzs * (bound - half_grid_size) + + # split batch to avoid OOM + head = 0 + while head < B: + tail = min(head + S, B) + + # world2cam transform (poses is c2w, so we need to transpose it. Another transpose is needed for batched matmul, so the final form is without transpose.) + cam_xyzs = cas_world_xyzs - poses[head:tail, :3, 3].unsqueeze(1) + cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3] # [S, N, 3] + + # query if point is covered by any camera + mask_z = cam_xyzs[:, :, 2] > 0 # [S, N] + mask_x = torch.abs(cam_xyzs[:, :, 0]) < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2 + mask_y = torch.abs(cam_xyzs[:, :, 1]) < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2 + mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1) # [N] + + # update count + count[cas, indices] += mask + head += S + + # mark untrained grid as -1 + self.density_grid[count.unsqueeze(0).expand_as(self.density_grid) == 0] = -1 + + print(f'[mark untrained grid] {(count == 0).sum()} from {self.grid_size ** 3 * self.cascade}') + + @torch.no_grad() + def update_extra_state(self, decay=0.95, S=128): + # call before each epoch to update extra states. + + if not self.cuda_ray: + return + + ### update density grid + + tmp_grid = - torch.ones_like(self.density_grid) + + # full update. + if self.iter_density < 16: + #if True: + X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + + for t, time in enumerate(self.times): + for xs in X: + for ys in Y: + for zs in Z: + + # construct points + xx, yy, zz = custom_meshgrid(xs, ys, zs) + coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128) + indices = raymarching.morton3D(coords).long() # [N] + xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1] + + # cascading + for cas in range(self.cascade): + bound = min(2 ** cas, self.bound) + half_grid_size = bound / self.grid_size + half_time_size = 0.5 / self.time_size + # scale to current cascade's resolution + cas_xyzs = xyzs * (bound - half_grid_size) + # add noise in coord [-hgs, hgs] + cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size + # add noise in time [-hts, hts] + time_perturb = time + (torch.rand_like(time) * 2 - 1) * half_time_size + # query density + sigmas = self.density(cas_xyzs, time_perturb)['sigma'].reshape(-1).detach() + sigmas *= self.density_scale + # assign + tmp_grid[t, cas, indices] = sigmas + + # partial update (half the computation) + # just update 100 times should be enough... too time consuming. + elif self.iter_density < 100: + N = self.grid_size ** 3 // 4 # T * C * H * H * H / 4 + for t, time in enumerate(self.times): + for cas in range(self.cascade): + # random sample some positions + coords = torch.randint(0, self.grid_size, (N, 3), device=self.density_bitfield.device) # [N, 3], in [0, 128) + indices = raymarching.morton3D(coords).long() # [N] + # random sample occupied positions + occ_indices = torch.nonzero(self.density_grid[t, cas] > 0).squeeze(-1) # [Nz] + rand_mask = torch.randint(0, occ_indices.shape[0], [N], dtype=torch.long, device=self.density_bitfield.device) + occ_indices = occ_indices[rand_mask] # [Nz] --> [N], allow for duplication + occ_coords = raymarching.morton3D_invert(occ_indices) # [N, 3] + # concat + indices = torch.cat([indices, occ_indices], dim=0) + coords = torch.cat([coords, occ_coords], dim=0) + # same below + xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1] + bound = min(2 ** cas, self.bound) + half_grid_size = bound / self.grid_size + half_time_size = 0.5 / self.time_size + # scale to current cascade's resolution + cas_xyzs = xyzs * (bound - half_grid_size) + # add noise in [-hgs, hgs] + cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size + # add noise in time [-hts, hts] + time_perturb = time + (torch.rand_like(time) * 2 - 1) * half_time_size + # query density + sigmas = self.density(cas_xyzs, time_perturb)['sigma'].reshape(-1).detach() + sigmas *= self.density_scale + # assign + tmp_grid[t, cas, indices] = sigmas + + ## max-pool on tmp_grid for less aggressive culling [No significant improvement...] + # invalid_mask = tmp_grid < 0 + # tmp_grid = F.max_pool3d(tmp_grid.view(self.cascade, 1, self.grid_size, self.grid_size, self.grid_size), kernel_size=3, stride=1, padding=1).view(self.cascade, -1) + # tmp_grid[invalid_mask] = -1 + + # ema update + valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0) + self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask]) + self.mean_density = torch.mean(self.density_grid.clamp(min=0)).item() # -1 non-training regions are viewed as 0 density. + self.iter_density += 1 + + # convert to bitfield + density_thresh = min(self.mean_density, self.density_thresh) + for t in range(self.time_size): + raymarching.packbits(self.density_grid[t], density_thresh, self.density_bitfield[t]) + + ### update step counter + total_step = min(16, self.local_step) + if total_step > 0: + self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step) + self.local_step = 0 + + #print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > 0.01).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}') + + + def render(self, rays_o, rays_d, time, staged=False, max_ray_batch=4096, **kwargs): + # rays_o, rays_d: [B, N, 3], assumes B == 1 + # return: pred_rgb: [B, N, 3] + + if self.cuda_ray: + _run = self.run_cuda + else: + _run = self.run + + B, N = rays_o.shape[:2] + device = rays_o.device + + # never stage when cuda_ray + if staged and not self.cuda_ray: + depth = torch.empty((B, N), device=device) + image = torch.empty((B, N, 3), device=device) + + for b in range(B): + head = 0 + while head < N: + tail = min(head + max_ray_batch, N) + results_ = _run(rays_o[b:b+1, head:tail], rays_d[b:b+1, head:tail], time[b:b+1], **kwargs) + depth[b:b+1, head:tail] = results_['depth'] + image[b:b+1, head:tail] = results_['image'] + head += max_ray_batch + + results = {} + results['depth'] = depth + results['image'] = image + + else: + results = _run(rays_o, rays_d, time, **kwargs) + + return results \ No newline at end of file diff --git a/torch-ngp/dnerf/utils.py b/torch-ngp/dnerf/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..99b5b439cd92770c7d5d02d99f0c65f3f309a21b --- /dev/null +++ b/torch-ngp/dnerf/utils.py @@ -0,0 +1,243 @@ +from nerf.utils import * +from nerf.utils import Trainer as _Trainer + + +class Trainer(_Trainer): + def __init__(self, + name, # name of this experiment + opt, # extra conf + model, # network + criterion=None, # loss function, if None, assume inline implementation in train_step + optimizer=None, # optimizer + ema_decay=None, # if use EMA, set the decay + lr_scheduler=None, # scheduler + metrics=[], # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric. + local_rank=0, # which GPU am I + world_size=1, # total num of GPUs + device=None, # device to use, usually setting to None is OK. (auto choose device) + mute=False, # whether to mute all print + fp16=False, # amp optimize level + eval_interval=1, # eval once every $ epoch + max_keep_ckpt=2, # max num of saved ckpts in disk + workspace='workspace', # workspace to save logs & ckpts + best_mode='min', # the smaller/larger result, the better + use_loss_as_metric=True, # use loss as the first metric + report_metric_at_train=False, # also report metrics at training + use_checkpoint="latest", # which ckpt to use at init time + use_tensorboardX=True, # whether to use tensorboard for logging + scheduler_update_every_step=False, # whether to call scheduler.step() after every train step + ): + + self.optimizer_fn = optimizer + self.lr_scheduler_fn = lr_scheduler + + super().__init__(name, opt, model, criterion, optimizer, ema_decay, lr_scheduler, metrics, local_rank, world_size, device, mute, fp16, eval_interval, max_keep_ckpt, workspace, best_mode, use_loss_as_metric, report_metric_at_train, use_checkpoint, use_tensorboardX, scheduler_update_every_step) + + ### ------------------------------ + + def train_step(self, data): + + rays_o = data['rays_o'] # [B, N, 3] + rays_d = data['rays_d'] # [B, N, 3] + time = data['time'] # [B, 1] + + # if there is no gt image, we train with CLIP loss. + if 'images' not in data: + + B, N = rays_o.shape[:2] + H, W = data['H'], data['W'] + + # currently fix white bg, MUST force all rays! + outputs = self.model.render(rays_o, rays_d, time, staged=False, bg_color=None, perturb=True, force_all_rays=True, **vars(self.opt)) + pred_rgb = outputs['image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous() + + # [debug] uncomment to plot the images used in train_step + #torch_vis_2d(pred_rgb[0]) + + loss = self.clip_loss(pred_rgb) + + return pred_rgb, None, loss + + images = data['images'] # [B, N, 3/4] + + B, N, C = images.shape + + if self.opt.color_space == 'linear': + images[..., :3] = srgb_to_linear(images[..., :3]) + + if C == 3 or self.model.bg_radius > 0: + bg_color = 1 + # train with random background color if not using a bg model and has alpha channel. + else: + #bg_color = torch.ones(3, device=self.device) # [3], fixed white background + #bg_color = torch.rand(3, device=self.device) # [3], frame-wise random. + bg_color = torch.rand_like(images[..., :3]) # [N, 3], pixel-wise random. + + if C == 4: + gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (1 - images[..., 3:]) + else: + gt_rgb = images + + outputs = self.model.render(rays_o, rays_d, time, staged=False, bg_color=bg_color, perturb=True, force_all_rays=False, **vars(self.opt)) + + pred_rgb = outputs['image'] + + loss = self.criterion(pred_rgb, gt_rgb).mean(-1) # [B, N, 3] --> [B, N] + + # special case for CCNeRF's rank-residual training + if len(loss.shape) == 3: # [K, B, N] + loss = loss.mean(0) + + # update error_map + if self.error_map is not None: + index = data['index'] # [B] + inds = data['inds_coarse'] # [B, N] + + # take out, this is an advanced indexing and the copy is unavoidable. + error_map = self.error_map[index] # [B, H * W] + + # [debug] uncomment to save and visualize error map + # if self.global_step % 1001 == 0: + # tmp = error_map[0].view(128, 128).cpu().numpy() + # print(f'[write error map] {tmp.shape} {tmp.min()} ~ {tmp.max()}') + # tmp = (tmp - tmp.min()) / (tmp.max() - tmp.min()) + # cv2.imwrite(os.path.join(self.workspace, f'{self.global_step}.jpg'), (tmp * 255).astype(np.uint8)) + + error = loss.detach().to(error_map.device) # [B, N], already in [0, 1] + + # ema update + ema_error = 0.1 * error_map.gather(1, inds) + 0.9 * error + error_map.scatter_(1, inds, ema_error) + + # put back + self.error_map[index] = error_map + + loss = loss.mean() + + # deform regularization + if 'deform' in outputs and outputs['deform'] is not None: + loss = loss + 1e-3 * outputs['deform'].abs().mean() + + return pred_rgb, gt_rgb, loss + + def eval_step(self, data): + + rays_o = data['rays_o'] # [B, N, 3] + rays_d = data['rays_d'] # [B, N, 3] + time = data['time'] # [B, 1] + images = data['images'] # [B, H, W, 3/4] + B, H, W, C = images.shape + + if self.opt.color_space == 'linear': + images[..., :3] = srgb_to_linear(images[..., :3]) + + # eval with fixed background color + bg_color = 1 + if C == 4: + gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (1 - images[..., 3:]) + else: + gt_rgb = images + + outputs = self.model.render(rays_o, rays_d, time, staged=True, bg_color=bg_color, perturb=False, **vars(self.opt)) + + pred_rgb = outputs['image'].reshape(B, H, W, 3) + pred_depth = outputs['depth'].reshape(B, H, W) + + loss = self.criterion(pred_rgb, gt_rgb).mean() + + return pred_rgb, pred_depth, gt_rgb, loss + + # moved out bg_color and perturb for more flexible control... + def test_step(self, data, bg_color=None, perturb=False): + + rays_o = data['rays_o'] # [B, N, 3] + rays_d = data['rays_d'] # [B, N, 3] + time = data['time'] # [B, 1] + H, W = data['H'], data['W'] + + if bg_color is not None: + bg_color = bg_color.to(self.device) + + outputs = self.model.render(rays_o, rays_d, time, staged=True, bg_color=bg_color, perturb=perturb, **vars(self.opt)) + + pred_rgb = outputs['image'].reshape(-1, H, W, 3) + pred_depth = outputs['depth'].reshape(-1, H, W) + + return pred_rgb, pred_depth + + # [GUI] test on a single image + def test_gui(self, pose, intrinsics, W, H, time=0, bg_color=None, spp=1, downscale=1): + + # render resolution (may need downscale to for better frame rate) + rH = int(H * downscale) + rW = int(W * downscale) + intrinsics = intrinsics * downscale + + pose = torch.from_numpy(pose).unsqueeze(0).to(self.device) + + rays = get_rays(pose, intrinsics, rH, rW, -1) + + data = { + 'time': torch.FloatTensor([[time]]).to(self.device), # from scalar to [1, 1] tensor. + 'rays_o': rays['rays_o'], + 'rays_d': rays['rays_d'], + 'H': rH, + 'W': rW, + } + + self.model.eval() + + if self.ema is not None: + self.ema.store() + self.ema.copy_to() + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=self.fp16): + # here spp is used as perturb random seed! + preds, preds_depth = self.test_step(data, bg_color=bg_color, perturb=spp) + + if self.ema is not None: + self.ema.restore() + + # interpolation to the original resolution + if downscale != 1: + # TODO: have to permute twice with torch... + preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(H, W), mode='nearest').permute(0, 2, 3, 1).contiguous() + preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1) + + if self.opt.color_space == 'linear': + preds = linear_to_srgb(preds) + + pred = preds[0].detach().cpu().numpy() + pred_depth = preds_depth[0].detach().cpu().numpy() + + outputs = { + 'image': pred, + 'depth': pred_depth, + } + + return outputs + + def save_mesh(self, time, save_path=None, resolution=256, threshold=10): + # time: scalar in [0, 1] + time = torch.FloatTensor([[time]]).to(self.device) + + if save_path is None: + save_path = os.path.join(self.workspace, 'meshes', f'{self.name}_{self.epoch}.ply') + + self.log(f"==> Saving mesh to {save_path}") + + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + def query_func(pts): + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=self.fp16): + sigma = self.model.density(pts.to(self.device), time)['sigma'] + return sigma + + vertices, triangles = extract_geometry(self.model.aabb_infer[:3], self.model.aabb_infer[3:], resolution=resolution, threshold=threshold, query_func=query_func) + + mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault... + mesh.export(save_path) + + self.log(f"==> Finished saving mesh.") \ No newline at end of file diff --git a/torch-ngp/encoding.py b/torch-ngp/encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..17b1d7355d0be31b2bfc72b9f1de4adea0e095cb --- /dev/null +++ b/torch-ngp/encoding.py @@ -0,0 +1,78 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class FreqEncoder(nn.Module): + def __init__(self, input_dim, max_freq_log2, N_freqs, + log_sampling=True, include_input=True, + periodic_fns=(torch.sin, torch.cos)): + + super().__init__() + + self.input_dim = input_dim + self.include_input = include_input + self.periodic_fns = periodic_fns + + self.output_dim = 0 + if self.include_input: + self.output_dim += self.input_dim + + self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns) + + if log_sampling: + self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs) + else: + self.freq_bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, N_freqs) + + self.freq_bands = self.freq_bands.numpy().tolist() + + def forward(self, input, **kwargs): + + out = [] + if self.include_input: + out.append(input) + + for i in range(len(self.freq_bands)): + freq = self.freq_bands[i] + for p_fn in self.periodic_fns: + out.append(p_fn(input * freq)) + + out = torch.cat(out, dim=-1) + + + return out + +def get_encoder(encoding, input_dim=3, + multires=6, + degree=4, + num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, + **kwargs): + + if encoding == 'None': + return lambda x, **kwargs: x, input_dim + + elif encoding == 'frequency': + #encoder = FreqEncoder(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True) + from freqencoder import FreqEncoder + encoder = FreqEncoder(input_dim=input_dim, degree=multires) + + elif encoding == 'sphere_harmonics': + from shencoder import SHEncoder + encoder = SHEncoder(input_dim=input_dim, degree=degree) + + elif encoding == 'hashgrid': + from gridencoder import GridEncoder + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners) + + elif encoding == 'tiledgrid': + from gridencoder import GridEncoder + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners) + + elif encoding == 'ash': + from ashencoder import AshEncoder + encoder = AshEncoder(input_dim=input_dim, output_dim=16, log2_hashmap_size=log2_hashmap_size, resolution=desired_resolution) + + else: + raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]') + + return encoder, encoder.output_dim \ No newline at end of file diff --git a/torch-ngp/environment.yml b/torch-ngp/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..63ecc467e9a9721d532cc282558d033a3a4db22a --- /dev/null +++ b/torch-ngp/environment.yml @@ -0,0 +1,27 @@ +name: torch-ngp +channels: + - pytorch + - conda-forge +dependencies: + - python + - cudatoolkit=11.3 + - ninja + - trimesh + - opencv + - tensorboardX + - pytorch + - numpy + - pandas + - tqdm + - matplotlib + - rich + - packaging + - scipy + - pip: + - imageio + - lpips + - torch-ema + - PyMCubes + - pysdf + - dearpygui + - git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch diff --git a/torch-ngp/ffmlp/__init__.py b/torch-ngp/ffmlp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7559bad2bfb7e524a6dc746b7f2834fc1fbb5aca --- /dev/null +++ b/torch-ngp/ffmlp/__init__.py @@ -0,0 +1 @@ +from .ffmlp import FFMLP \ No newline at end of file diff --git a/torch-ngp/ffmlp/backend.py b/torch-ngp/ffmlp/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..c33724a1728e6e8b4a29a2b39130ed2e327d60ff --- /dev/null +++ b/torch-ngp/ffmlp/backend.py @@ -0,0 +1,46 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '--expt-extended-lambda', '--expt-relaxed-constexpr', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + nvcc_flags += ['-Xcompiler=-mf16c', '-Xcompiler=-Wno-float-conversion', '-Xcompiler=-fno-strict-aliasing'] + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_ffmlp', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + extra_include_paths=[ + os.path.join(_src_path, 'dependencies/cutlass/include'), + os.path.join(_src_path, 'dependencies/cutlass/tools/util/include'), + ], + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'ffmlp.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/torch-ngp/ffmlp/ffmlp.py b/torch-ngp/ffmlp/ffmlp.py new file mode 100644 index 0000000000000000000000000000000000000000..5e211b651d770d437ef98fb3128a213b10ed9074 --- /dev/null +++ b/torch-ngp/ffmlp/ffmlp.py @@ -0,0 +1,169 @@ +import math +from turtle import backward, forward + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd +import atexit + +try: + import _ffmlp as _backend +except ImportError: + from .backend import _backend + +class _ffmlp_forward(Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.half) + def forward(ctx, inputs, weights, input_dim, output_dim, hidden_dim, num_layers, activation, output_activation, inference=False, calc_grad_inputs=False): + + B = inputs.shape[0] + + inputs = inputs.contiguous() + weights = weights.contiguous() + + # print('[inputs]', torch.any(torch.isnan(inputs)), inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) + # print('[weights]', torch.any(torch.isnan(weights)), weights.shape, weights.dtype, weights.min().item(), weights.max().item()) + + # allocate output + outputs = torch.empty(B, output_dim, device=inputs.device, dtype=inputs.dtype) + + if not inference: + forward_buffer = torch.empty(num_layers, B, hidden_dim, device=inputs.device, dtype=inputs.dtype) + _backend.ffmlp_forward(inputs, weights, B, input_dim, output_dim, hidden_dim, num_layers, activation, output_activation, forward_buffer, outputs) + ctx.save_for_backward(inputs, weights, outputs, forward_buffer) + ctx.dims = (input_dim, output_dim, hidden_dim, num_layers, activation, output_activation, calc_grad_inputs) + + # print('[outputs]', torch.any(torch.isnan(outputs)), outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) + # print('[forward_buffer]', torch.any(torch.isnan(forward_buffer)), forward_buffer.shape, forward_buffer.dtype, forward_buffer.min().item(), forward_buffer.max().item()) + else: + inference_buffer = torch.empty(B, hidden_dim, device=inputs.device, dtype=inputs.dtype) + _backend.ffmlp_inference(inputs, weights, B, input_dim, output_dim, hidden_dim, num_layers, activation, output_activation, inference_buffer, outputs) + + # print('[outputs]', torch.any(torch.isnan(outputs)), outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) + # print('[inference_buffer]', torch.any(torch.isnan(inference_buffer)), inference_buffer.shape, inference_buffer.dtype, inference_buffer.min().item(), inference_buffer.max().item()) + + + return outputs + + @staticmethod + @custom_bwd + def backward(ctx, grad): + # grad: [B, output_dim] + + B = grad.shape[0] + + grad = grad.contiguous() + + # print('[grad]', torch.any(torch.isnan(grad)), grad.shape, grad.dtype, grad.min().item(), grad.max().item()) + # print(grad) + + inputs, weights, outputs, forward_buffer = ctx.saved_tensors + + input_dim, output_dim, hidden_dim, num_layers, activation, output_activation, calc_grad_inputs = ctx.dims + + # allocate outputs + if calc_grad_inputs: + grad_inputs = torch.zeros_like(inputs) + else: + grad_inputs = torch.zeros(1, device=grad.device, dtype=grad.dtype) # dummy + + grad_weights = torch.zeros_like(weights) + backward_buffer = torch.zeros(num_layers, B, hidden_dim, device=grad.device, dtype=grad.dtype) + + _backend.ffmlp_backward(grad, inputs, weights, forward_buffer, B, input_dim, output_dim, hidden_dim, num_layers, activation, output_activation, calc_grad_inputs, backward_buffer, grad_inputs, grad_weights) + + # print('[grad_inputs]', grad_inputs.shape, grad_inputs.dtype, grad_inputs.min().item(), grad_inputs.max().item()) + # print('[grad_weights]', grad_weights.shape, grad_weights.dtype, grad_weights.min().item(), grad_weights.max().item()) + # print('[backward_buffer]', backward_buffer.shape, backward_buffer.dtype, backward_buffer.min().item(), backward_buffer.max().item()) + if calc_grad_inputs: + return grad_inputs, grad_weights, None, None, None, None, None, None, None, None + else: + return None, grad_weights, None, None, None, None, None, None, None, None + + +ffmlp_forward = _ffmlp_forward.apply + + +def convert_activation(act): + if act == 'relu': return 0 + elif act == 'exponential': return 1 + elif act == 'sine': return 2 + elif act == 'sigmoid': return 3 + elif act == 'squareplus': return 4 + elif act == 'softplus': return 5 + else: return 6 + + +class FFMLP(nn.Module): + def __init__(self, input_dim, output_dim, hidden_dim, num_layers, activation='relu'): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.activation = convert_activation(activation) + self.output_activation = convert_activation('none') # not supported currently + + self.tensorcore_width = 16 + + assert hidden_dim in [16, 32, 64, 128, 256], f"FFMLP only support hidden_dim in [16, 32, 64, 128, 256], but got {hidden_dim}" + assert input_dim > 0 and input_dim % 16 == 0, f"FFMLP input_dim should be 16 * m (m > 0), but got {input_dim}" + assert output_dim <= 16, f"FFMLP current only supports output dim <= 16, but got {output_dim}" + assert num_layers >= 2, f"FFMLP num_layers should be larger than 2 (3 matmuls), but got {num_layers}" + + # pad output + self.padded_output_dim = int(math.ceil(output_dim / 16)) * 16 + + # parameters (continuous in memory) + self.num_parameters = hidden_dim * (input_dim + hidden_dim * (num_layers - 1) + self.padded_output_dim) + self.weights = nn.Parameter(torch.zeros(self.num_parameters)) + self.reset_parameters() + + # allocate streams + _backend.allocate_splitk(self.num_layers + 1) + + # register destructor + #atexit.register(self.cleanup) # how to correctly clean? this gives CUDA Error: cudaEventDestroy(events[i]) failed with error context is destroyed + + + def cleanup(self): + # destroy streams + _backend.free_splitk() + + + def __repr__(self): + return f"FFMLP: input_dim={self.input_dim} output_dim={self.output_dim} hidden_dim={self.hidden_dim} num_layers={self.num_layers} activation={self.activation}" + + + def reset_parameters(self): + torch.manual_seed(42) + std = math.sqrt(3 / self.hidden_dim) + self.weights.data.uniform_(-std, std) + + + def forward(self, inputs): + # inputs: [B, input_dim] + # return: [B, outupt_dim] + + #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item(), inputs.requires_grad) + + B, C = inputs.shape + #assert B >= 128 and B % 128 == 0, f"ffmlp batch size must be 128 * m (m > 0), but got {B}." + + # pad input + pad = 128 - (B % 128) + if pad > 0: + inputs = torch.cat([inputs, torch.zeros(pad, C, dtype=inputs.dtype, device=inputs.device)], dim=0) + + outputs = ffmlp_forward(inputs, self.weights, self.input_dim, self.padded_output_dim, self.hidden_dim, self.num_layers, self.activation, self.output_activation, not self.training, inputs.requires_grad) + + # unpad output + if B != outputs.shape[0] or self.padded_output_dim != self.output_dim: + outputs = outputs[:B, :self.output_dim] + + #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) + + return outputs \ No newline at end of file diff --git a/torch-ngp/ffmlp/setup.py b/torch-ngp/ffmlp/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..a0f0cdb2de2bdd1689507c90267c50258f6d90f9 --- /dev/null +++ b/torch-ngp/ffmlp/setup.py @@ -0,0 +1,56 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '--expt-extended-lambda', '--expt-relaxed-constexpr', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + nvcc_flags += ['-Xcompiler=-mf16c', '-Xcompiler=-Wno-float-conversion', '-Xcompiler=-fno-strict-aliasing'] + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='ffmlp', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_ffmlp', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'ffmlp.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + }, + include_dirs=[ + os.path.join(_src_path, 'dependencies/cutlass/include'), + os.path.join(_src_path, 'dependencies/cutlass/tools/util/include'), + ], + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/torch-ngp/ffmlp/src/bindings.cpp b/torch-ngp/ffmlp/src/bindings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f094c7274f1ba47ca1d03729bbf7dccc50f9f347 --- /dev/null +++ b/torch-ngp/ffmlp/src/bindings.cpp @@ -0,0 +1,11 @@ +#include + +#include "ffmlp.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ffmlp_forward", &ffmlp_forward, "ffmlp_forward (CUDA)"); + m.def("ffmlp_inference", &ffmlp_inference, "ffmlp_inference (CUDA)"); + m.def("ffmlp_backward", &ffmlp_backward, "ffmlp_backward (CUDA)"); + m.def("allocate_splitk", &allocate_splitk, "allocate_splitk (CUDA)"); + m.def("free_splitk", &free_splitk, "free_splitk (CUDA)"); +} \ No newline at end of file diff --git a/torch-ngp/ffmlp/src/cutlass_matmul.h b/torch-ngp/ffmlp/src/cutlass_matmul.h new file mode 100644 index 0000000000000000000000000000000000000000..be2a65f48cc6aed7a6b36778d8a73261f8a29dbb --- /dev/null +++ b/torch-ngp/ffmlp/src/cutlass_matmul.h @@ -0,0 +1,493 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + *//* + */ + +/** @file cutlass_matmul.h + * @author Thomas Müller, NVIDIA + * @brief Matrix multiplication wrappers that call into CUTLASS (plus some custom modifications). + * The parameters are optimized to give optimal performance in a variety of situations. + * Parts of this file were adapted by starting from the CUTLASS sample code (see its BSD 3-clause license). + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +#include "utils.h" + +//#define TCNN_VERBOSE_MEMORY_ALLOCS + +#define CUTLASS_CHECK(status) \ +{ \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ +} + +#define CUDA_CHECK(status) \ +{ \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ +} + +using SmArch = std::conditional_t= 80, + std::conditional_t::value, cutlass::arch::Sm75, cutlass::arch::Sm80>, + std::conditional_t= 75, + cutlass::arch::Sm75, + cutlass::arch::Sm70 + > +>; + +using TypeAccumulator = std::conditional_t::value, float, cutlass::half_t>; +using TypeCompute = std::conditional_t::value, float, cutlass::half_t>; + +template +using MMAOp = typename std::conditional< + std::is_same::value, + cutlass::arch::OpClassSimt, + cutlass::arch::OpClassTensorOp +>::type; + +template +using ShapeMMAOp = typename std::conditional< + std::is_same, cutlass::arch::OpClassTensorOp>::value, + typename std::conditional< + std::is_same::value || std::is_same::value, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::gemm::GemmShape<8, 8, 4> + >::type, + cutlass::gemm::GemmShape<1, 1, 1> +>::type; + +template +struct LayerConfig { + using k_thread_block = thread_block; + using k_warp = warp; +}; + +using FullLayerK = typename std::conditional< + std::is_same, cutlass::arch::OpClassSimt>::value, + LayerConfig, cutlass::gemm::GemmShape<32, 64, 8>>, + LayerConfig, cutlass::gemm::GemmShape<32, 32, 32>> +>::type; +using LastLayerK = FullLayerK; + +using FullLayer = typename std::conditional< + std::is_same, cutlass::arch::OpClassSimt>::value, + LayerConfig, cutlass::gemm::GemmShape<32, 64, 8>>, + LayerConfig, cutlass::gemm::GemmShape<64, 64, 32>> +>::type; +using LastLayer = FullLayer; + +// This code section describes how threadblocks are scheduled on GPU +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + +// This code section describes the epilogue part of the kernel + +template +struct CutlassFragmentWrapper { + static const uint32_t num_elements = V::kElements; + V x; +}; + +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + cutlass::FloatRoundStyle Round = cutlass::FloatRoundStyle::round_to_nearest +> +class ActivationEpilogue { +public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + + using FragmentOutput = cutlass::Array; + using FragmentAccumulator = cutlass::Array; + using ComputeFragment = cutlass::Array; + + static cutlass::FloatRoundStyle const kRound = Round; + + struct Params { + Activation activation; + bool sum_source; + }; + +public: + CUTLASS_HOST_DEVICE + ActivationEpilogue(Params const ¶ms) : m_activation{params.activation}, m_sum_source{params.sum_source} { } + + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return m_sum_source; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { } + + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const &accumulator) const { + cutlass::NumericArrayConverter accumulator_converter; + + auto intermediate = CutlassFragmentWrapper{accumulator_converter(accumulator)}; + intermediate = warp_activation(m_activation, intermediate); + + cutlass::NumericArrayConverter destination_converter; + return destination_converter(intermediate.x); + } + + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const &accumulator, FragmentOutput const &source) const { + cutlass::NumericArrayConverter source_converter; + cutlass::NumericArrayConverter accumulator_converter; + + cutlass::plus plus_op; + auto intermediate = CutlassFragmentWrapper{accumulator_converter(accumulator)}; + if (m_sum_source) { + intermediate.x = plus_op(intermediate.x, source_converter(source)); + } + intermediate = warp_activation(m_activation, intermediate); + + cutlass::NumericArrayConverter destination_converter; + return destination_converter(intermediate.x); + } + +private: + Activation m_activation; + bool m_sum_source; +}; + +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + cutlass::FloatRoundStyle Round = cutlass::FloatRoundStyle::round_to_nearest +> +class ActivationTransferEpilogue { +public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + + using FragmentOutput = cutlass::Array; + using FragmentAccumulator = cutlass::Array; + using ComputeFragment = cutlass::Array; + + static cutlass::FloatRoundStyle const kRound = Round; + + /// Host-constructable parameters structure + struct Params { + Activation activation; + }; + +public: + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + ActivationTransferEpilogue(Params const ¶ms) : m_activation{params.activation} { } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { } + + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator, + FragmentOutput const &source) const { + + cutlass::NumericArrayConverter source_converter; + cutlass::NumericArrayConverter accumulator_converter; + + auto converted_source = CutlassFragmentWrapper{source_converter(source)}; + auto intermediate = CutlassFragmentWrapper{accumulator_converter(accumulator)}; + + intermediate = warp_activation_backward(m_activation, intermediate, converted_source); + + cutlass::NumericArrayConverter destination_converter; + return destination_converter(intermediate.x); + } + + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator) const { + + cutlass::NumericArrayConverter accumulator_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + cutlass::NumericArrayConverter destination_converter; + + return destination_converter(converted_accumulator); + } + +private: + Activation m_activation; +}; + + +template +static constexpr int n_vectorized_elements = std::is_same, cutlass::arch::OpClassTensorOp>::value ? (128 / cutlass::sizeof_bits::value) : 1; + +template +using SumOp = cutlass::epilogue::thread::LinearCombination, TypeAccumulator, TypeCompute>; + +template +using IntermediateActivationOp = ActivationEpilogue; + +template +using IntermediateActivationTransferOp = ActivationTransferEpilogue; + +template +using ActivationOp = ActivationEpilogue, TypeAccumulator, TypeCompute>; + +template +using ActivationTransferOp = ActivationTransferEpilogue, TypeAccumulator, TypeCompute>; + + +template +using OurGemm = cutlass::gemm::device::Gemm< + TypeA, + LayoutA, + TypeB, + LayoutB, + TypeOutput, + LayoutOutput, + TypeAccumulator, + MMAOp, + SmArch, + typename LayerConfig::k_thread_block, + typename LayerConfig::k_warp, + ShapeMMAOp, + EPILOGUE, + SwizzleThreadBlock, + 2 +>; + +template +using SplitKGemm = cutlass::gemm::device::GemmSplitKParallel< + TypeA, + LayoutA, + TypeB, + LayoutB, + TypeOutput, + LayoutOutput, + TypeAccumulator, + MMAOp, + SmArch, + typename LayerConfig::k_thread_block, + typename LayerConfig::k_warp, + ShapeMMAOp, + EPILOGUE +>; + +inline std::map>& cutlass_workspaces() { + static std::map> s_workspaces; + return s_workspaces; +} + +inline uint8_t* cutlass_get_workspace(size_t size, cudaStream_t stream) { + GPUMemory& workspace = cutlass_workspaces()[stream]; + if (size > workspace.size()) { + size *= 2; +#ifdef TCNN_VERBOSE_MEMORY_ALLOCS + std::cout << "CUTLASS GEMM: Allocating temporary workspace of " << bytes_to_string(size) << "." << std::endl; +#endif + + // Allocate twice the requested size to make sure we're not constantly allocating small increments. + workspace.resize(size); + } + return workspace.data(); +} + +inline void cutlass_free_workspace(cudaStream_t stream) { + if (cutlass_workspaces().count(stream) == 0) { + return; + } + +#ifdef TCNN_VERBOSE_MEMORY_ALLOCS + std::cout << "CUTLASS GEMM: Freeing temporary workspace of " << bytes_to_string(cutlass_workspaces().at(stream).size()) << "." << std::endl; +#endif + cutlass_workspaces().erase(stream); +} + + +template +void fc_multiply_impl(cudaStream_t stream, const typename Gemm::Arguments& args) { + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(args); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm_op; + + // Initialize CUTLASS kernel with arguments and workspace pointer + cutlass::Status status = gemm_op.initialize(args, cutlass_get_workspace(workspace_size, stream), stream); + CUTLASS_CHECK(status); + + // Launch initialized CUTLASS kernel + status = gemm_op(stream); + CUTLASS_CHECK(status); +} + +template +void fc_multiply_split_k_impl(cudaStream_t stream, const typename Gemm::Arguments& args) { + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(args); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm_op; + + // Initialize CUTLASS kernel with arguments and workspace pointer + cutlass::Status status = gemm_op.initialize(args, cutlass_get_workspace(workspace_size, stream)); + CUTLASS_CHECK(status); + + // Launch initialized CUTLASS kernel + status = gemm_op(stream); + CUTLASS_CHECK(status); +} + +////////////////////////////////////////////////////////////////////////////////// +//////////////////////////// modified /////////////////////////////// +////////////////////////////////////////////////////////////////////////////////// + +template +void fc_multiply(cudaStream_t stream, const int M, const int K, const int N, const __half* A, const __half* B, const __half* C, __half* D, Activation act = Activation::None, bool transfer = false, bool sum_source = false) { + // compute D = A @ B + C + // A: [M, K] + // B: [K, N] + // C, D: [M, N] + using CutlassLayoutA = typename std::conditional::type; + using CutlassLayoutB = typename std::conditional::type; + using CutlassLayoutC = typename std::conditional::type; + + using MatmulTypeCompute = cutlass::half_t; + using MatmulTypeAccumulator = cutlass::half_t; + + const int lda = RM_A ? K : M; + const int ldb = RM_B ? N : K; + const int ldc = RM_C ? N : M; + const int ldd = RM_C ? N : M; + + if (transfer) { + using Gemm = OurGemm, config, MatmulTypeCompute, CutlassLayoutA, MatmulTypeCompute, CutlassLayoutB, MatmulTypeAccumulator, CutlassLayoutC>; + typename Gemm::Arguments arguments{ + {M, N, K}, + {(MatmulTypeCompute*)A, lda}, + {(MatmulTypeCompute*)B, ldb}, + {(MatmulTypeAccumulator*)C, ldc}, + {(MatmulTypeAccumulator*)D, ldd}, + {act}, + 1 + }; + + fc_multiply_impl(stream, arguments); + } else { + using Gemm = OurGemm, config, MatmulTypeCompute, CutlassLayoutA, MatmulTypeCompute, CutlassLayoutB, MatmulTypeAccumulator, CutlassLayoutC>; + typename Gemm::Arguments arguments{ + {M, N, K}, + {(MatmulTypeCompute*)A, lda}, + {(MatmulTypeCompute*)B, ldb}, + {(MatmulTypeAccumulator*)C, ldc}, + {(MatmulTypeAccumulator*)D, ldd}, + {act, sum_source}, + 1 + }; + + fc_multiply_impl(stream, arguments); + } +} + + +template +void fc_multiply(cudaStream_t stream, const int M, const int K, const int N, const __half* A, const __half* B, __half* D, Activation act = Activation::None) { + fc_multiply(stream, M, K, N, A, B, D, D, act); +} + + +template +void fc_multiply_split_k(cudaStream_t stream, const int M, const int K, const int N, const __half* A, const __half* B, const __half* C, __half* D, int split_k_slices = 1) { + // A: [M, K] + // B: [K, N] + // C, D: [M, N] + using CutlassLayoutA = typename std::conditional::type; + using CutlassLayoutB = typename std::conditional::type; + using CutlassLayoutC = typename std::conditional::type; + + using MatmulTypeCompute = cutlass::half_t; + using MatmulTypeAccumulator = cutlass::half_t; + + const int lda = RM_A ? K : M; + const int ldb = RM_B ? N : K; + const int ldc = RM_C ? N : M; + const int ldd = RM_C ? N : M; + + using Gemm = SplitKGemm, config, MatmulTypeCompute, CutlassLayoutA, MatmulTypeCompute, CutlassLayoutB, MatmulTypeAccumulator, CutlassLayoutC>; + + typename Gemm::Arguments arguments{ + {M, N, K}, + {(MatmulTypeCompute*)A, lda}, + {(MatmulTypeCompute*)B, ldb}, + {(MatmulTypeAccumulator*)C, ldc}, + {(MatmulTypeAccumulator*)D, ldd}, + {(TypeCompute)1.0f, (TypeCompute)0.0f}, + split_k_slices + }; + + fc_multiply_split_k_impl(stream, arguments); +} + +template +void fc_multiply_split_k(cudaStream_t stream, const int M, const int K, const int N, const __half* A, const __half* B, __half* D, int split_k_slices = 1) { + fc_multiply_split_k(stream, M, K, N, A, B, D, D, split_k_slices); +} diff --git a/torch-ngp/ffmlp/src/ffmlp.cu b/torch-ngp/ffmlp/src/ffmlp.cu new file mode 100644 index 0000000000000000000000000000000000000000..48364bfc795e858ceced4fda89ed2ba657b01628 --- /dev/null +++ b/torch-ngp/ffmlp/src/ffmlp.cu @@ -0,0 +1,895 @@ +#include + +#include +#include +#include + +#include +#include + +#include +#include + +#include +#include + +#include + +#include "utils.h" +#include "cutlass_matmul.h" + + +__host__ __device__ Activation convert_activation(const uint32_t activation) { + switch (activation) { + case 0: return Activation::ReLU; + case 1: return Activation::Exponential; + case 2: return Activation::Sine; + case 3: return Activation::Sigmoid; + case 4: return Activation::Squareplus; + case 5: return Activation::Softplus; + case 6: return Activation::None; + default: return Activation::None; + } +} + +template +__host__ __device__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +void check_shmem_error(cudaError_t error) { + if (error != cudaSuccess) { + throw std::runtime_error{"FullyFusedMLP: insufficient shared memory available on the GPU. Reduce `n_neurons` or use `CutlassMLP` (better compatibility but slower) instead."}; + } +} + + +template +__device__ void threadblock_layer(Activation activation, __half* __restrict__ act_shmem, const __half* __restrict__ weights_this_layer, OUT_T* __restrict__ out_intermediate_threadblock_this_layer, const OUT_T* __restrict__ activation_aux = nullptr) { + // act_shmem contains the intermediate activations (shared memory) of the thread block's chunk of the batch. + // Can be forward activations or backward activations, depending on caller. + // weights_this_layer points to the weight matrix of the current layer. + // out_intermediate_threadblock_this_layer points to the location where intermediate activations produced by the thread block should be written to. + // Can be nullptr if nothing should be written. + // activation_aux points to additional arguments that the activation function may depend on. Points to the hidden forward activations when computing backward activations. + + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; + constexpr uint32_t N_BLOCKS = WIDTH / 16; + + using namespace nvcuda; + + // If we're performing the backward pass, weights must be loaded in transposed form, which + // is achieved by interpreting the memory in row_major instead of col_major order. + using weights_layout_t = std::conditional_t; + + // Fragments + wmma::fragment act_frag; + wmma::fragment weights_frag[N_BLOCKS]; + wmma::fragment result_frag[N_ITERS]; + + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + + const uint32_t lane_offset = (8 * li) % WIDTH; + const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; + + const uint32_t weights_col = 16 * wi; + + __syncthreads(); + + // Load N_BLOCKS chunks of weights from global memory into registers. + #pragma unroll + for (uint32_t i = 0; i < N_BLOCKS; ++i) { + if (BACKWARD) { + // If we're performing the backward pass, additional index swizzling is needed to + // load the weights in transposed form. + wmma::load_matrix_sync(weights_frag[i], weights_this_layer + 16 * i * WIDTH + weights_col, WIDTH); + } else { + wmma::load_matrix_sync(weights_frag[i], weights_this_layer + 16 * i + weights_col * WIDTH, WIDTH); + } + } + + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + wmma::fill_fragment(result_frag[l], 0.0f); + + #pragma unroll + for (uint32_t i = 0; i < N_BLOCKS; ++i) { + // Load a chunk of intermediate activations from shared memory and multiply with chunk of weights + wmma::load_matrix_sync(act_frag, act_shmem + 16 * i + (16 * (threadIdx.z + l * BLOCK_DIM_Z)) * (WIDTH + SKEW), WIDTH + SKEW); + wmma::mma_sync(result_frag[l], act_frag, weights_frag[i], result_frag[l]); + } + + // Activation + if (BACKWARD) { + // Load the temporary forward matrix for the relu transfer + wmma::load_matrix_sync(act_frag, activation_aux + weights_col + (threadIdx.z + l * BLOCK_DIM_Z) * 16 * WIDTH, WIDTH); + warp_activation_backward<__half>(activation, result_frag[l], act_frag, result_frag[l]); + } else { + warp_activation<__half>(activation, result_frag[l], result_frag[l]); + } + } + + __syncthreads(); + + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + wmma::store_matrix_sync(act_shmem + weights_col + (threadIdx.z + l * BLOCK_DIM_Z) * 16 * (WIDTH + SKEW), result_frag[l], WIDTH + SKEW, wmma::mem_row_major); + } + + if (out_intermediate_threadblock_this_layer != nullptr) { + __syncthreads(); + + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + *(int4*)&out_intermediate_threadblock_this_layer[lane_offset + (row + 16 * (threadIdx.z + l * BLOCK_DIM_Z)) * WIDTH] = *(int4*)&act_shmem[lane_offset + (row + 16 * (threadIdx.z + l * BLOCK_DIM_Z)) * (WIDTH + SKEW)]; + } + } +} + +template +__device__ void threadblock_load_input_static(__half* __restrict__ act_shmem, const __half* __restrict__ input_threadblock) { + // act_shmem will be filled by the thread block's chunk of input_threadblock + + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; + + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + + const uint32_t lane_offset = (8 * li) % WIDTH; + const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; + + #pragma unroll + for (int i = 0; i < N_ITERS; ++i) { + *(int4*)&act_shmem[lane_offset + (row + 16 * (threadIdx.z + i * BLOCK_DIM_Z)) * (WIDTH + SKEW)] = *(int4*)&input_threadblock[lane_offset + (row + 16 * (threadIdx.z + i * BLOCK_DIM_Z)) * WIDTH]; + } +} + + +template +__device__ void threadblock_input_layer_forward_dynamic(Activation activation, __half* __restrict__ act_shmem, const __half* __restrict__ input_threadblock, const __half* __restrict__ weights_this_layer, OUT_T* __restrict__ out_intermediate_threadblock_this_layer, const uint32_t in_width) { + // act_shmem contains the intermediate activations (shared memory) of the thread block's chunk of the batch + // input_threadblock points to the thread block's chunk of the input batch in global memory + // weights_this_layer points to the weight matrix of the current layer + // out_intermediate_threadblock_this_layer points to the location where intermediate activations produced by the thread block should be written to. + // Can be nullptr if nothing should be written. + // in_width is the dynamic width of the input layer + + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; + constexpr uint32_t INPUT_SKEW = 8; + constexpr uint32_t N_BLOCKS = WIDTH / 16; + + using namespace nvcuda; + + // Fragments + wmma::fragment act_frag; + wmma::fragment weights_frag; + wmma::fragment result_frag[N_ITERS]; + + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + + const uint32_t lane_offset = (8 * li) % WIDTH; + const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; + + const uint32_t weights_col = 16 * wi; + + __half* __restrict__ weights_shmem = act_shmem + BLOCK_DIM_Z * 16 * (in_width + INPUT_SKEW); + + // Load input weight matrix (fits completely into shared memory) + // Each thread can load 8 fp16 elements (16 bytes) at once; we have N_BLOCKS*BLOCK_DIM_Z warps + const uint32_t n_elems_per_load = N_BLOCKS * 32 * BLOCK_DIM_Z * 8; + const uint32_t thread_elem_idx = (li + wi * 32 + threadIdx.z * N_BLOCKS * 32) * 8; + + const uint32_t n_elems_b = WIDTH * in_width; + + #pragma unroll + for (uint32_t idx = thread_elem_idx; idx < n_elems_b; idx += n_elems_per_load) { + const uint32_t idx_skewed = idx + idx / in_width * INPUT_SKEW; + *(int4*)&weights_shmem[idx_skewed] = *(int4*)&weights_this_layer[idx]; + } + + const uint32_t n_tensor_ops = in_width / 16; + + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + // Load chunk of inputs into shmem. + // This is faster than loading it from gmem directly, even though it is only used once. + // (Possibly due to latency hiding through staging.) + const uint32_t n_elems_a = BLOCK_DIM_Z * 16 * in_width; + + #pragma unroll + for (uint32_t idx = thread_elem_idx; idx < n_elems_a; idx += n_elems_per_load) { + const uint32_t idx_skewed = idx + idx / in_width * INPUT_SKEW; + *(int4*)&act_shmem[idx_skewed] = *(int4*)&input_threadblock[l * n_elems_a + idx]; + } + + __syncthreads(); + + wmma::fill_fragment(result_frag[l], 0.0f); + #pragma unroll + for (uint32_t i = 0; i < n_tensor_ops; ++i) { + // Load chunk of inputs and weights from shared memory and multiply them + wmma::load_matrix_sync(act_frag, act_shmem + 16 * i + (16 * threadIdx.z) * (in_width + INPUT_SKEW), in_width + INPUT_SKEW); + wmma::load_matrix_sync(weights_frag, weights_shmem + 16 * i + weights_col * (in_width + INPUT_SKEW), in_width + INPUT_SKEW); + wmma::mma_sync(result_frag[l], act_frag, weights_frag, result_frag[l]); + } + + __syncthreads(); + + warp_activation<__half>(activation, result_frag[l], result_frag[l]); + } + + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + wmma::store_matrix_sync(act_shmem + weights_col + (16 * (threadIdx.z + l * BLOCK_DIM_Z)) * (WIDTH + SKEW), result_frag[l], WIDTH + SKEW, wmma::mem_row_major); + } + + if (out_intermediate_threadblock_this_layer != nullptr) { + __syncthreads(); + + #pragma unroll + for (int i = 0; i < N_ITERS; ++i) { + *(int4*)&out_intermediate_threadblock_this_layer[lane_offset + (row + 16 * (threadIdx.z + i * BLOCK_DIM_Z)) * WIDTH] = *(int4*)&act_shmem[lane_offset + (row + 16 * (threadIdx.z + i * BLOCK_DIM_Z)) * (WIDTH + SKEW)]; + } + } +} + +template +__device__ void threadblock_last_layer_forward(Activation activation, __half* __restrict__ act_shmem, const __half* __restrict__ weights_this_layer, OUT_T* __restrict__ out, const uint32_t batch_size, const nvcuda::wmma::layout_t output_layout) { + // act_shmem contains the intermediate activations (shared memory) of the thread block's chunk of the batch + // weights_this_layer points to the weight matrix of the current layer + // out points to the location where the result produced by the thread block should be written to. + // Can be nullptr if nothing should be written. + + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; + constexpr uint32_t N_BLOCKS = WIDTH / 16; + + using namespace nvcuda; + + // Fragments + wmma::fragment act_frag; + wmma::fragment weights_frag[N_BLOCKS]; + wmma::fragment result_frag; + + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + + __half* __restrict__ weights_shmem = act_shmem + N_ITERS * BLOCK_DIM_Z * 16 * (WIDTH + SKEW); + + const uint32_t weights_row = (8 * li) % WIDTH; + const uint32_t weights_col = (8 * li + 8 * 32 * wi) / WIDTH; + + // Load weight matrix into shared memory for the last multiplication. + // Loading into shared memory as opposed to directly into registers is faster + // because unlike in the previous layers, each warp uses the same entries of the weight matrix. + if (threadIdx.z == 0) { + *(int4*)&weights_shmem[weights_row + weights_col * (WIDTH + SKEW)] = *(int4*)&weights_this_layer[weights_row + weights_col * WIDTH]; + //printf("[last forward] base=%d, shmem=%d, weight=%d\n", N_ITERS * BLOCK_DIM_Z * 16 * (WIDTH + SKEW), weights_row + weights_col * (WIDTH + SKEW), weights_row + weights_col * WIDTH); + } + + __syncthreads(); + + #pragma unroll + for (uint32_t i = 0; i < N_BLOCKS; ++i) + wmma::load_matrix_sync(weights_frag[i], weights_shmem + 16 * i, WIDTH + SKEW); + + // Perform last layer by parallelizing over iters + for (uint32_t idx = wi; idx < N_ITERS; idx += N_BLOCKS) { + wmma::fill_fragment(result_frag, 0.0f); + + #pragma unroll + for (uint32_t i = 0; i < N_BLOCKS; ++i) { + // Load a chunk of intermediate activations from shared memory and multiply with chunk of the weight matrix + wmma::load_matrix_sync(act_frag, act_shmem + 16 * i + (16 * (threadIdx.z + idx * BLOCK_DIM_Z)) * (WIDTH + SKEW), WIDTH + SKEW); + wmma::mma_sync(result_frag, act_frag, weights_frag[i], result_frag); + } + + warp_activation<__half>(activation, result_frag, result_frag); + + if (output_layout == wmma::mem_row_major) { + wmma::store_matrix_sync(out + (threadIdx.z + idx * BLOCK_DIM_Z) * 16 * 16, result_frag, 16, output_layout); + //printf("[last forward] RM write out %d, batch %d\n", (threadIdx.z + idx * BLOCK_DIM_Z) * 16 * 16, 16); + } else { + wmma::store_matrix_sync(out + (threadIdx.z + idx * BLOCK_DIM_Z) * 16, result_frag, batch_size, output_layout); + //printf("[last forward] CM write out %d, batch %d\n", (threadIdx.z + idx * BLOCK_DIM_Z) * 16, batch_size); + } + } +} + +template +__device__ void threadblock_write_output_static(const __half* __restrict__ act_shmem, __half* __restrict__ output_threadblock) { + // output_threadblock will be filled by the thread block's act_shmem + + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; + + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + + const uint32_t lane_offset = (8 * li) % WIDTH; + const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; + + __syncthreads(); + + #pragma unroll + for (int i = 0; i < N_ITERS; ++i) { + *(int4*)&output_threadblock[lane_offset + (row + 16 * (threadIdx.z + i * BLOCK_DIM_Z)) * WIDTH] = *(int4*)&act_shmem[lane_offset + (row + 16 * (threadIdx.z + i * BLOCK_DIM_Z)) * (WIDTH + SKEW)]; + } +} + + +/////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////// + + +template +__global__ void kernel_mlp_fused( + const Activation activation, + const Activation output_activation, + const __half* __restrict__ input, + const __half* __restrict__ weights, + OUT_T* __restrict__ out_intermediate, + OUT_T* __restrict__ out, + const uint32_t batch_size, + const uint32_t in_width, + const uint32_t out_width, + const uint32_t n_hidden_matmuls, + const nvcuda::wmma::layout_t output_layout = nvcuda::wmma::mem_row_major +) { + // `input` points to the input matrix. Can be any width. + // `weights` points to the weight matrices (contiguous in memory). + // `out_intermediate` points to the memory where intermediate activations should be written. When performing inference, a value of nullptr is expected (intermediate results are not written). + // `out` points to the memory where the network output should be written. (Output width is assumed to be 16 neurons.) + + // if (threadIdx.x == 0) printf("[forward] call kernel_mlp_fused\n"); + // if (threadIdx.x == 0) printf("[forward] inputs=%f\n", (float)input[0]); + // if (threadIdx.x == 0) printf("[forward] weights=%f\n", (float)weights[0]); + + //if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n", (float)out_intermediate[0]); + + // Shared memory contains the intermediate activations of blockDim.y*16 elements. + // In some cases, it also contains the weight matrix for the first and last layer. + extern __shared__ __half shmem[]; + __half* act_shmem = shmem; + + // Each block computes exactly one 16-element chunk of the batch. + const uint32_t elem_idx = 16 * blockIdx.x * N_ITERS * BLOCK_DIM_Z; + + // First layer + if (in_width == WIDTH) { + // If the input has the same width as the network, we can simply use the network's regular layer routine (with static size) + // instead of using the slower dynamic input layer routine. + threadblock_load_input_static(act_shmem, input + elem_idx * WIDTH); + threadblock_layer(activation, act_shmem, weights, !INFERENCE ? (out_intermediate + elem_idx * WIDTH) : nullptr); + } else { + threadblock_input_layer_forward_dynamic(activation, act_shmem, input + elem_idx * in_width, weights, !INFERENCE ? (out_intermediate + elem_idx * WIDTH) : nullptr, in_width); + } + + // if (threadIdx.x == 0) printf("[forward] kernel_mlp_fused: passed first layer\n"); + //if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n", (float)out_intermediate[0]); + + const uint32_t first_layer_size = WIDTH * in_width; + const uint32_t layer_stride = WIDTH * WIDTH; + const uint32_t output_stride = WIDTH * batch_size; + + // Hidden layers + for (uint32_t k = 0; k < n_hidden_matmuls; ++k) { + threadblock_layer(activation, act_shmem, weights + first_layer_size + layer_stride * k, !INFERENCE ? (out_intermediate + output_stride * (k + 1) + elem_idx * WIDTH) : nullptr); + // if (threadIdx.x == 0) printf("[forward] kernel_mlp_fused: passed %d layer\n", k + 1); + //if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n", (float)out_intermediate[0]); + } + + if (out_width > 16) { + // In the forward pass, intermediate activations are already written out. + if (INFERENCE) { + threadblock_write_output_static(act_shmem, out_intermediate + elem_idx * WIDTH); + } + } else if (out) { + // Last layer + if (output_layout == nvcuda::wmma::mem_row_major) { + //printf("[last layer] RM write to out %d\n", elem_idx * 16); + //if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n", (float)out_intermediate[0]); + threadblock_last_layer_forward(output_activation, act_shmem, weights + first_layer_size + layer_stride * n_hidden_matmuls, out + elem_idx * 16, 16, output_layout); + //if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n", (float)out_intermediate[0]); + } else { + //printf("[last layer] CM write to out %d\n", elem_idx); + //if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n", (float)out_intermediate[0]); + threadblock_last_layer_forward(output_activation, act_shmem, weights + first_layer_size + layer_stride * n_hidden_matmuls, out + elem_idx, batch_size, output_layout); + //if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n", (float)out_intermediate[0]); + } + } +} + + +template +__global__ void kernel_mlp_fused_backward( + const Activation activation, + const __half* __restrict__ dL_doutput, + const __half* __restrict__ weights, + __half* __restrict__ out_intermediate, + const __half* __restrict__ forward, + __half* __restrict__ dL_dinput, + const __half* __restrict__ weights_first_layer, + const uint32_t batch_size, + const uint32_t out_width, + const uint32_t n_hidden_matmuls +) { + // `dL_doutput` points to the input matrix of the backward pass, i.e. the loss gradients. Assumed to be 16 neurons wide. + // `weights` points to the weight matrices (contiguous in memory). + // `out_intermediate` points to the memory where backpropagated activation gradients should be written. + // `forward` points to the memory where the intermediate activations of the forward pass are located. (needed for activation backprop) + + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; + + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + const uint32_t bi = blockIdx.x; // block index + + // Shared memory contains the intermediate activations of blockDim.y*16 elements. + // A skew is applied to the matrix storage to avoid bank conflicts. + extern __shared__ __half shmem[]; + __half* act_shmem = shmem; + + const uint32_t lane_offset = (8 * li) % WIDTH; + const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; + + // Multipying one 16-row chunk of intermediate activations with the weight matrix requires all warps of the block. + // Thus, each block computes exactly one 16-row chunk of the next layer's intermediate activations. + const uint32_t elem_idx_base = 16 * bi * N_ITERS * BLOCK_DIM_Z; + const uint32_t elem_idx = elem_idx_base + 16 * threadIdx.z; + + const uint32_t layer_stride = WIDTH * WIDTH; + const uint32_t output_stride = WIDTH * batch_size; + + // Backprop through last layer + if (out_width <= 16) { + using namespace nvcuda; + + // Fragments in registers + wmma::fragment act_frag; + wmma::fragment weights_frag; + wmma::fragment result_frag[N_ITERS]; + + // Load the relevant chunk of the last layer's weight matrix from global memory into registers + const uint32_t weights_col = 16 * wi; + + wmma::load_matrix_sync(weights_frag, weights + layer_stride * n_hidden_matmuls + weights_col, WIDTH); + + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + wmma::fill_fragment(result_frag[l], 0.0f); + + // Load a chunk of output gradients from shared memory and multiply with previously loaded weights + if (std::is_same::value) { + wmma::load_matrix_sync(act_frag, dL_doutput + (elem_idx + 16 * (threadIdx.z + l * BLOCK_DIM_Z)) * 16, 16); + } else { + wmma::load_matrix_sync(act_frag, dL_doutput + (elem_idx + 16 * (threadIdx.z + l * BLOCK_DIM_Z)), batch_size); + } + + // NOTE: activation transfer of the _output_ activation is expected to be done _prior_ to calling this kernel + // in a separate pass, because the tranfered activation gradient is also needed to compute the weight + // gradient of the last weight matrix (see backward()). + wmma::mma_sync(result_frag[l], act_frag, weights_frag, result_frag[l]); + + // Load the temporary forward matrix for the relu transfer + wmma::fragment forward_frag; + wmma::load_matrix_sync(forward_frag, forward + output_stride * n_hidden_matmuls + weights_col + (elem_idx + l * BLOCK_DIM_Z * 16) * WIDTH, WIDTH); + + warp_activation_backward<__half>(activation, result_frag[l], forward_frag, result_frag[l]); + } + + __syncthreads(); + + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + wmma::store_matrix_sync(act_shmem + weights_col + (16 * (threadIdx.z + l * BLOCK_DIM_Z)) * (WIDTH + SKEW), result_frag[l], WIDTH + SKEW, wmma::mem_row_major); + } + + __syncthreads(); + + #pragma unroll + for (int i = 0; i < N_ITERS; ++i) { + *(int4*)&out_intermediate[lane_offset + (row + elem_idx + i * BLOCK_DIM_Z * 16) * WIDTH] = *(int4*)&act_shmem[lane_offset + (row + 16 * (threadIdx.z + i * BLOCK_DIM_Z)) * (WIDTH + SKEW)]; + } + } else { + // If the output width is larger than 16, we will have used CUTLASS for backpropping through the last layer. + // Load the resulting gradients. + threadblock_load_input_static(act_shmem, out_intermediate + elem_idx * WIDTH); + } + + // Backprop through hidden layers + for (uint32_t k = 0; k < n_hidden_matmuls; ++k) { + threadblock_layer(activation, act_shmem, weights + layer_stride * (n_hidden_matmuls - k - 1), out_intermediate + output_stride * (k + 1) + elem_idx_base * WIDTH, forward + output_stride * (n_hidden_matmuls - k - 1) + elem_idx_base * WIDTH); + } + + // Compute loss gradients w.r.t. input if desired. + // THIS CODE ASSUMES THAT THE INPUT WIDTH IS THE SAME AS THE NETWORK WIDTH. + // DON'T PASS A NON-NULL dL_dinput IF THIS REQUIREMENT IS NOT MET. + if (dL_dinput != nullptr) { + threadblock_layer(Activation::None, act_shmem, weights_first_layer, dL_dinput + elem_idx_base * WIDTH); + } +} + +////////////////////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////////////////////// + + +template // WIDTH is hidden_dim +void ffmlp_forward_cuda( + const __half *inputs, + const __half *weights, + const uint32_t B, + const uint32_t input_dim, + const uint32_t output_dim, + const uint32_t num_layers, + const Activation activation, + const Activation output_activation, + __half *forward_buffer, + __half *outputs +) { + + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; // <- always going to be 8 as we only support multiple-of-16 widths + constexpr uint32_t INPUT_SKEW = 8; // <- likewise with inputs + constexpr uint32_t N_BLOCK_ROWS = WIDTH / 16; + + const int N_ITERS = WIDTH >= 256 ? 2 : 8; + const uint32_t BLOCK_DIM_Z = (INFERENCE && WIDTH == 128) ? 2 : 1; + + const dim3 threads = { 32u, N_BLOCK_ROWS, BLOCK_DIM_Z }; // 32 threads = 1 warp, N_BLOCK_ROWS warps per block for 16 rows, up to 2x 8 warps can share input (does not help vs. 1) + + uint32_t n_elems_per_block = 16 * BLOCK_DIM_Z * N_ITERS; + uint32_t n_blocks = div_round_up(B, n_elems_per_block); + + size_t shmem_size = sizeof(__half) * (16 + 16 * BLOCK_DIM_Z * N_ITERS) * (WIDTH + SKEW); // 16*WIDTH rows of weights (for the last layer; others are in registers only) + 16*WIDTH*BLOCK_DIM_Z*N_ITERS rows of intermediate activations + + // If the input width is dynamic, the input weight matrix as well as part of the input will live in extra shared memory + if (input_dim != WIDTH) { + shmem_size = std::max(shmem_size, sizeof(__half) * (WIDTH + 16 * BLOCK_DIM_Z) * (input_dim + INPUT_SKEW)); + } + + //printf("[ffmlp_forward_cuda] shmem size = %d\n", shmem_size); + + const dim3 blocks = { n_blocks, 1u, 1u }; + + check_shmem_error(cudaFuncSetAttribute(kernel_mlp_fused, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)shmem_size)); + + kernel_mlp_fused<<>>( + activation, + output_activation, + inputs, // CM + weights, // RM + forward_buffer, // CM + outputs, // CM + B, + input_dim, + output_dim, + num_layers - 1, + nvcuda::wmma::mem_row_major // reversed outputs's layout + ); +} + + +template // WIDTH is hidden_dim +void ffmlp_backward_cuda( + const __half *grad, + const __half *weights, + const uint32_t B, + const uint32_t input_dim, + const uint32_t output_dim, + const uint32_t num_layers, + const Activation activation, + const __half *forward_buffer, + __half *backward_buffer, + __half *grad_inputs +) { + + // locate + const __half * weights_first = weights; + const __half * weights_second = weights + input_dim * WIDTH; + + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; // <- always going to be 8 as we only support multiple-of-16 widths + constexpr uint32_t N_BLOCKS = WIDTH / 16; + + const int N_ITERS = WIDTH >= 256 ? 2 : 8; + const uint32_t BLOCK_DIM_Z = 1; + + const dim3 threads = { 32u, N_BLOCKS, BLOCK_DIM_Z }; // 32 threads = 1 warp, N_BLOCK_ROWS warps per block for 16 rows, up to 2x 8 warps can share input (does not help vs. 1) + + uint32_t n_elems_per_block = 16 * BLOCK_DIM_Z * N_ITERS; + uint32_t n_blocks = div_round_up(B, n_elems_per_block); + + size_t shmem_size = sizeof(__half) * ((16 * BLOCK_DIM_Z * N_ITERS) * (WIDTH + SKEW)); // WIDTH rows of input and 16 * threads.z rows of weights + + const dim3 blocks = { n_blocks, 1u, 1u }; + + // The kernels operate with transposed layouts compared with the MLP code + check_shmem_error(cudaFuncSetAttribute(kernel_mlp_fused_backward, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + kernel_mlp_fused_backward<<>>( + activation, + grad, // CM + weights_second, // RM + backward_buffer, // CM + forward_buffer, // CM + grad_inputs, // CM + weights_first, // RM + B, + output_dim, + num_layers - 1 + ); +} + + +// inputs: col-major [input_dim, B] +// weights: row-major [hidden_dim * input_dim] + [hidden_dim * hidden_dim * (num_layers - 1)] + [output_dim * hidden_dim] +// forward_buffer: col-major [num_layers, hidden_dim, B] +// outputs: col-major [output_dim, B] +void ffmlp_forward(const at::Tensor inputs, const at::Tensor weights, const uint32_t B, const uint32_t input_dim, const uint32_t output_dim, const uint32_t hidden_dim, const uint32_t num_layers, const uint32_t activation_, const uint32_t output_activation_, at::Tensor forward_buffer, at::Tensor outputs) { + CHECK_CUDA(inputs); + CHECK_CONTIGUOUS(inputs); + CHECK_IS_HALF(inputs); + + CHECK_CUDA(weights); + CHECK_CONTIGUOUS(weights); + CHECK_IS_HALF(weights); + + Activation activation = convert_activation(activation_); + Activation output_activation = convert_activation(output_activation_); + + auto inputs_ptr = reinterpret_cast<__half*>(inputs.data_ptr()); + auto weights_ptr = reinterpret_cast<__half*>(weights.data_ptr()); + auto forward_buffer_ptr = reinterpret_cast<__half*>(forward_buffer.data_ptr()); + auto outputs_ptr = reinterpret_cast<__half*>(outputs.data_ptr()); + + switch (hidden_dim) { + case 16: ffmlp_forward_cuda<16, false>(inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers, activation, output_activation, forward_buffer_ptr, outputs_ptr); break; + case 32: ffmlp_forward_cuda<32, false>(inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers, activation, output_activation, forward_buffer_ptr, outputs_ptr); break; + case 64: ffmlp_forward_cuda<64, false>(inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers, activation, output_activation, forward_buffer_ptr, outputs_ptr); break; + case 128: ffmlp_forward_cuda<128, false>(inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers, activation, output_activation, forward_buffer_ptr, outputs_ptr); break; + case 256: ffmlp_forward_cuda<256, false>(inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers, activation, output_activation, forward_buffer_ptr, outputs_ptr); break; + default: throw std::runtime_error{"hidden_dim should in [16, 32, 64, 128, 256]"}; + } + + // for output_dim > 16 + if (output_dim > 16) { + fc_multiply(0, + output_dim, hidden_dim, B, + (weights_ptr + hidden_dim * input_dim + (num_layers - 1) * hidden_dim * hidden_dim), // row-major, [output_dim, hidden_dim] + (forward_buffer_ptr + (num_layers - 1) * hidden_dim * B), // col-major [hidden_dim, B] + outputs_ptr, // col-major [outupt_dim, B] + output_activation + ); + } +} + +void ffmlp_inference(const at::Tensor inputs, const at::Tensor weights, const uint32_t B, const uint32_t input_dim, const uint32_t output_dim, const uint32_t hidden_dim, const uint32_t num_layers, const uint32_t activation_, const uint32_t output_activation_, at::Tensor inference_buffer, at::Tensor outputs) { + CHECK_CUDA(inputs); + CHECK_CONTIGUOUS(inputs); + CHECK_IS_HALF(inputs); + + CHECK_CUDA(weights); + CHECK_CONTIGUOUS(weights); + CHECK_IS_HALF(weights); + + Activation activation = convert_activation(activation_); + Activation output_activation = convert_activation(output_activation_); + + auto inputs_ptr = reinterpret_cast<__half*>(inputs.data_ptr()); + auto weights_ptr = reinterpret_cast<__half*>(weights.data_ptr()); + auto inference_buffer_ptr = reinterpret_cast<__half*>(inference_buffer.data_ptr()); + auto outputs_ptr = reinterpret_cast<__half*>(outputs.data_ptr()); + + switch (hidden_dim) { + case 16: ffmlp_forward_cuda<16, true>(inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers, activation, output_activation, inference_buffer_ptr, outputs_ptr); break; + case 32: ffmlp_forward_cuda<32, true>(inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers, activation, output_activation, inference_buffer_ptr, outputs_ptr); break; + case 64: ffmlp_forward_cuda<64, true>(inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers, activation, output_activation, inference_buffer_ptr, outputs_ptr); break; + case 128: ffmlp_forward_cuda<128, true>(inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers, activation, output_activation, inference_buffer_ptr, outputs_ptr); break; + case 256: ffmlp_forward_cuda<256, true>(inputs_ptr, weights_ptr, B, input_dim, output_dim, num_layers, activation, output_activation, inference_buffer_ptr, outputs_ptr); break; + default: throw std::runtime_error{"hidden_dim should in [16, 32, 64, 128, 256]"}; + } + + // for output_dim > 16 + if (output_dim > 16) { + fc_multiply(0, + output_dim, hidden_dim, B, + (weights_ptr + hidden_dim * input_dim + (num_layers - 1) * hidden_dim * hidden_dim), // row-major, [output_dim, hidden_dim] + inference_buffer_ptr, // col-major [hidden_dim, B] + outputs_ptr, // col-major [outupt_dim, B] + output_activation + ); + } +} + +inline std::vector& streams_splitk() { + static std::vector res; + return res; +} + +inline std::vector& events_splitk() { + static std::vector res; + return res; +} + +void allocate_splitk(size_t size) { + auto& streams = streams_splitk(); + auto& events = events_splitk(); + streams.resize(size); + events.resize(size); + for (size_t i = 0; i < size; i++) { + CUDA_CHECK_THROW(cudaStreamCreate(&streams[i])); + CUDA_CHECK_THROW(cudaEventCreate(&events[i])); + } +} + +void free_splitk() { + auto& streams = streams_splitk(); + auto& events = events_splitk(); + for (size_t i = 0; i < streams.size(); i++) { + cutlass_free_workspace(streams[i]); + CUDA_CHECK_PRINT(cudaStreamDestroy(streams[i])); + CUDA_CHECK_PRINT(cudaEventDestroy(events[i])); + } +} + +// grad: col-major [output_dim, B] +// inputs: col-major [input_dim, B] +// weights: row-major [hidden_dim * input_dim] + [hidden_dim * hidden_dim * (num_layers - 1)] + [output_dim * hidden_dim] +// forward_buffer: col-major [num_layers, hidden_dim, B] +// backward_buffer: col-major [num_layers, hidden_dim, B] +// grad_inputs: col-major [input_dim, B] +// grad_weights: row-major [hidden_dim * input_dim] + [hidden_dim * hidden_dim * (num_layers - 1)] + [output_dim * hidden_dim] +void ffmlp_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor weights, const at::Tensor forward_buffer, const uint32_t B, const uint32_t input_dim, const uint32_t output_dim, const uint32_t hidden_dim, const uint32_t num_layers, const uint32_t activation_, const uint32_t output_activation_, const bool calc_grad_inputs, at::Tensor backward_buffer, at::Tensor grad_inputs, at::Tensor grad_weights) { + CHECK_CUDA(grad); + CHECK_CONTIGUOUS(grad); + CHECK_IS_HALF(grad); + + CHECK_CUDA(inputs); + CHECK_CONTIGUOUS(inputs); + CHECK_IS_HALF(inputs); + + CHECK_CUDA(weights); + CHECK_CONTIGUOUS(weights); + CHECK_IS_HALF(weights); + + CHECK_CUDA(forward_buffer); + CHECK_CONTIGUOUS(forward_buffer); + CHECK_IS_HALF(forward_buffer); + + CHECK_CUDA(backward_buffer); + CHECK_CONTIGUOUS(backward_buffer); + CHECK_IS_HALF(backward_buffer); + + CHECK_CUDA(grad_weights); + CHECK_CONTIGUOUS(grad_weights); + CHECK_IS_HALF(grad_weights); + + CHECK_CUDA(grad_inputs); + CHECK_CONTIGUOUS(grad_inputs); + CHECK_IS_HALF(grad_inputs); + + Activation activation = convert_activation(activation_); + Activation output_activation = convert_activation(output_activation_); + + // activation_backward_output_gpu (I gonna discard output_activation ...) + + int split_k_factor = B / std::min((uint32_t)(1 << 12), B); + + uint32_t forward_index = num_layers - 1; + uint32_t backward_index = 0; + + auto backward_buffer_ptr = reinterpret_cast<__half*>(backward_buffer.data_ptr()); + auto forward_buffer_ptr = reinterpret_cast<__half*>(forward_buffer.data_ptr()); + auto grad_ptr = reinterpret_cast<__half*>(grad.data_ptr()); + auto inputs_ptr = reinterpret_cast<__half*>(inputs.data_ptr()); + auto weights_ptr = reinterpret_cast<__half*>(weights.data_ptr()); + auto grad_weights_ptr = reinterpret_cast<__half*>(grad_weights.data_ptr()); + + auto grad_inputs_ptr = calc_grad_inputs ? reinterpret_cast<__half*>(grad_inputs.data_ptr()) : nullptr; + auto grad_inputs_fused_ptr = input_dim == hidden_dim ? grad_inputs_ptr : nullptr; + + + + // calc output layer, grad_weights + cudaEventRecord(events_splitk().at(backward_index), 0); + cudaStreamWaitEvent(streams_splitk().at(backward_index), events_splitk().at(backward_index), 0); + + fc_multiply_split_k(streams_splitk().at(backward_index), + output_dim, B, hidden_dim, + grad_ptr, // col-major, [output_dim, B] + (forward_buffer_ptr + forward_index * hidden_dim * B), // row-major, [B, hidden_dim] + (grad_weights_ptr + hidden_dim * input_dim + (num_layers - 1) * hidden_dim * hidden_dim), // row-major, [output_dim, hidden_dim] + split_k_factor + ); + + cudaEventRecord(events_splitk().at(backward_index), streams_splitk().at(backward_index)); + + // prepare the last backward_buffer if output_dim > 16 + if (output_dim > 16) { + fc_multiply(0, + hidden_dim, output_dim, B, + (grad_weights_ptr + hidden_dim * input_dim + (num_layers - 1) * hidden_dim * hidden_dim), // col-major, [hidden_dim, output_dim] + grad_ptr, // col-major, [output_dim, B] + (forward_buffer_ptr + forward_index * hidden_dim * B), // col-major, [hidden_dim, B] + (backward_buffer_ptr + backward_index * hidden_dim * B), // col-major [hidden_dim, B] + activation, + true + ); + } + + // prepare backward_buffer + // calc grad_inputs if input_dim == hidden_dim + switch (hidden_dim) { + case 16: ffmlp_backward_cuda<16>(grad_ptr, weights_ptr, B, input_dim, output_dim, num_layers, activation, forward_buffer_ptr, backward_buffer_ptr, grad_inputs_fused_ptr); break; + case 32: ffmlp_backward_cuda<32>(grad_ptr, weights_ptr, B, input_dim, output_dim, num_layers, activation, forward_buffer_ptr, backward_buffer_ptr, grad_inputs_fused_ptr); break; + case 64: ffmlp_backward_cuda<64>(grad_ptr, weights_ptr, B, input_dim, output_dim, num_layers, activation, forward_buffer_ptr, backward_buffer_ptr, grad_inputs_fused_ptr); break; + case 128: ffmlp_backward_cuda<128>(grad_ptr, weights_ptr, B, input_dim, output_dim, num_layers, activation, forward_buffer_ptr, backward_buffer_ptr, grad_inputs_fused_ptr); break; + case 256: ffmlp_backward_cuda<256>(grad_ptr, weights_ptr, B, input_dim, output_dim, num_layers, activation, forward_buffer_ptr, backward_buffer_ptr, grad_inputs_fused_ptr); break; + default: throw std::runtime_error{"hidden_dim should in [16, 32, 64, 128, 256]"}; + } + + //printf("[backward] finished backward kernel\n"); + + forward_index--; + backward_index++; + + // calc middle layer's grad_weights + for (uint32_t i = 0; i < num_layers - 1; i++) { + + uint32_t matrix_index = num_layers - 2 - i; + + cudaEventRecord(events_splitk().at(backward_index), 0); + cudaStreamWaitEvent(streams_splitk().at(backward_index), events_splitk().at(backward_index), 0); + + fc_multiply_split_k(streams_splitk().at(backward_index), + hidden_dim, B, hidden_dim, + (backward_buffer_ptr + (backward_index - 1) * hidden_dim * B), // col-major [hidden_dim, B] + (forward_buffer_ptr + forward_index * hidden_dim * B), // row-major [B, hidden_dim] + (grad_weights_ptr + hidden_dim * input_dim + matrix_index * hidden_dim * hidden_dim), // row-major, [hidden_dim, hidden_dim] + split_k_factor + ); + + cudaEventRecord(events_splitk().at(backward_index), streams_splitk().at(backward_index)); + + forward_index--; + backward_index++; + } + + // calc input layer's grad_weights + cudaEventRecord(events_splitk().at(backward_index), 0); + cudaStreamWaitEvent(streams_splitk().at(backward_index), events_splitk().at(backward_index), 0); + + fc_multiply_split_k(streams_splitk().at(backward_index), + hidden_dim, B, input_dim, + (backward_buffer_ptr + (backward_index - 1) * hidden_dim * B), // col-major [hidden_dim, B] + inputs_ptr, // row-major, [B, input_dim] + grad_weights_ptr, // row-major, [hidden_dim, input_dim] + split_k_factor + ); + + cudaEventRecord(events_splitk().at(backward_index), streams_splitk().at(backward_index)); + + // calc grad_inputs if input_dim != hidden_dim + if (calc_grad_inputs && grad_inputs_fused_ptr == nullptr) { + fc_multiply(0, + input_dim, hidden_dim, B, + weights_ptr, // col-major [input_dim, hidden_dim] + (backward_buffer_ptr + (backward_index - 1) * hidden_dim * B), // col-major [hidden_dim, B] + grad_inputs_ptr // col-major [input_dim, B] + ); + } + + // All the per-layer split-k matrix multiplications summing over + // the batch are computed in parallel streams to the actual + // backpropagation. Here, we need to wait for all of these to complete. + for (auto& event : events_splitk()) { + cudaStreamWaitEvent(0, event, 0); + } +} \ No newline at end of file diff --git a/torch-ngp/ffmlp/src/ffmlp.h b/torch-ngp/ffmlp/src/ffmlp.h new file mode 100644 index 0000000000000000000000000000000000000000..0791d1ad3f7128b96d4968e2036028ba1d5ad1fc --- /dev/null +++ b/torch-ngp/ffmlp/src/ffmlp.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + + +// activation: should have been enum, here we just use int. +void ffmlp_forward(const at::Tensor inputs, const at::Tensor weights, const uint32_t B, const uint32_t input_dim, const uint32_t output_dim, const uint32_t hidden_dim, const uint32_t num_layers, const uint32_t activation_, const uint32_t output_activation_, at::Tensor forward_buffer, at::Tensor outputs); +void ffmlp_inference(const at::Tensor inputs, const at::Tensor weights, const uint32_t B, const uint32_t input_dim, const uint32_t output_dim, const uint32_t hidden_dim, const uint32_t num_layers, const uint32_t activation_, const uint32_t output_activation_, at::Tensor inference_buffer, at::Tensor outputs); + +void ffmlp_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor weights, const at::Tensor forward_buffer, const uint32_t B, const uint32_t input_dim, const uint32_t output_dim, const uint32_t hidden_dim, const uint32_t num_layers, const uint32_t activation, const uint32_t output_activation, const bool calc_grad_inputs, at::Tensor backward_buffer, at::Tensor grad_inputs, at::Tensor grad_weights); + +void allocate_splitk(size_t size); +void free_splitk(); \ No newline at end of file diff --git a/torch-ngp/ffmlp/src/utils.h b/torch-ngp/ffmlp/src/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..6e5cfb2163bda939b5e337f19333753c4d875377 --- /dev/null +++ b/torch-ngp/ffmlp/src/utils.h @@ -0,0 +1,589 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") +#define CHECK_IS_HALF(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Half, #x " must be a Half tensor") + +static constexpr uint32_t MIN_GPU_ARCH = 70; + +using network_precision_t = __half; + +enum class Activation { + ReLU, + Exponential, + Sine, + Sigmoid, + Squareplus, + Softplus, + None, +}; + +static constexpr float PI = 3.14159265358979323846f; +static constexpr float SQRT2 = 1.41421356237309504880f; +static constexpr float K_ACT = 10.0f; + +__host__ __device__ inline float logistic(const float x) { + return 1.0f / (1.0f + expf(-x)); +} + +__host__ __device__ inline float logit(const float x) { + return -logf(1.0f / (fminf(fmaxf(x, 1e-9f), 1.0f - 1e-9f)) - 1.0f); +} + +inline std::atomic& total_n_bytes_allocated() { + static std::atomic s_total_n_bytes_allocated{0}; + return s_total_n_bytes_allocated; +} + +/// Checks the result of a cudaXXXXXX call and throws an error on failure +#define CUDA_CHECK_THROW(x) \ + do { \ + cudaError_t result = x; \ + if (result != cudaSuccess) \ + throw std::runtime_error(std::string("CUDA Error: " #x " failed with error ") + cudaGetErrorString(result)); \ + } while(0) + +/// Checks the result of a cudaXXXXXX call and prints an error on failure +#define CUDA_CHECK_PRINT(x) \ + do { \ + cudaError_t result = x; \ + if (result != cudaSuccess) \ + std::cout << "CUDA Error: " #x " failed with error " << cudaGetErrorString(result) << std::endl; \ + } while(0) + +#define DEBUG_GUARD_SIZE 0 + +/// Managed memory on the Device +template +class GPUMemory { +private: + T* m_data = nullptr; + size_t m_size = 0; // Number of elements + bool m_owned = true; + +public: + GPUMemory() {} + + GPUMemory& operator=(GPUMemory&& other) { + std::swap(m_data, other.m_data); + std::swap(m_size, other.m_size); + return *this; + } + + GPUMemory(GPUMemory&& other) { + *this = std::move(other); + } + + __host__ __device__ GPUMemory(const GPUMemory &other) : m_data{other.m_data}, m_size{other.m_size}, m_owned{false} {} + + void check_guards() const { +#if DEBUG_GUARD_SIZE > 0 + if (!m_data) + return; + uint8_t buf[DEBUG_GUARD_SIZE]; + const uint8_t *rawptr=(const uint8_t *)m_data; + cudaMemcpy(buf, rawptr-DEBUG_GUARD_SIZE, DEBUG_GUARD_SIZE, cudaMemcpyDeviceToHost); + for (int i=0;i 0 + CUDA_CHECK_THROW(cudaMemset(rawptr , 0xff, DEBUG_GUARD_SIZE)); + CUDA_CHECK_THROW(cudaMemset(rawptr+n_bytes+DEBUG_GUARD_SIZE , 0xfe, DEBUG_GUARD_SIZE)); +#endif + if (rawptr) rawptr+=DEBUG_GUARD_SIZE; + m_data=(T*)(rawptr); + total_n_bytes_allocated() += n_bytes; + } + + void free_memory() { + if (!m_data) { + return; + } + + uint8_t *rawptr = (uint8_t*)m_data; + if (rawptr) rawptr-=DEBUG_GUARD_SIZE; + CUDA_CHECK_THROW(cudaFree(rawptr)); + + total_n_bytes_allocated() -= get_bytes(); + + m_data = nullptr; + } + + /// Allocates memory for size items of type T + GPUMemory(const size_t size) { + resize(size); + } + + /// Frees memory again + __host__ __device__ ~GPUMemory() { +#ifndef __CUDA_ARCH__ + if (!m_owned) { + return; + } + + try { + if (m_data) { + free_memory(); + m_size = 0; + } + } catch (std::runtime_error error) { + // Don't need to report on memory-free problems when the driver is shutting down. + if (std::string{error.what()}.find("driver shutting down") == std::string::npos) { + fprintf(stderr, "Could not free memory: %s\n", error.what()); + } + } +#endif + } + + /** @name Resizing/enlargement + * @{ + */ + /// Resizes the array to the exact new size, even if it is already larger + void resize(const size_t size) { + if (!m_owned) { + throw std::runtime_error("Cannot resize non-owned memory."); + } + + if (m_size != size) { + if (m_size) { + try { + free_memory(); + } catch (std::runtime_error error) { + throw std::runtime_error(std::string("Could not free memory: ") + error.what()); + } + } + + if (size > 0) { + try { + allocate_memory(size * sizeof(T)); + } catch (std::runtime_error error) { + throw std::runtime_error(std::string("Could not allocate memory: ") + error.what()); + } + } + + m_size = size; + } + } + + /// Enlarges the array if its size is smaller + void enlarge(const size_t size) { + if (size > m_size) { + resize(size); + } + } + /** @} */ + + /** @name Memset + * @{ + */ + /// Sets the memory of the first num_elements to value + void memset(const int value, const size_t num_elements, const size_t offset = 0) { + if (num_elements + offset > m_size) { + throw std::runtime_error("Could not set memory: Number of elements larger than allocated memory"); + } + + try { + CUDA_CHECK_THROW(cudaMemset(m_data + offset, value, num_elements * sizeof(T))); + } catch (std::runtime_error error) { + throw std::runtime_error(std::string("Could not set memory: ") + error.what()); + } + } + + /// Sets the memory of the all elements to value + void memset(const int value) { + memset(value, m_size); + } + /** @} */ + + /** @name Copy operations + * @{ + */ + /// Copy data of num_elements from the raw pointer on the host + void copy_from_host(const T* host_data, const size_t num_elements) { + try { + CUDA_CHECK_THROW(cudaMemcpy(data(), host_data, num_elements * sizeof(T), cudaMemcpyHostToDevice)); + } catch (std::runtime_error error) { + throw std::runtime_error(std::string("Could not copy from host: ") + error.what()); + } + } + + /// Copy num_elements from the host vector + void copy_from_host(const std::vector& data, const size_t num_elements) { + if (data.size() < num_elements) { + throw std::runtime_error(std::string("Trying to copy ") + std::to_string(num_elements) + std::string(" elements, but vector size is only ") + std::to_string(data.size())); + } + copy_from_host(data.data(), num_elements); + } + + /// Copies data from the raw host pointer to fill the entire array + void copy_from_host(const T* data) { + copy_from_host(data, m_size); + } + + /// Copies num_elements of data from the raw host pointer after enlarging the array so that everything fits in + void enlarge_and_copy_from_host(const T* data, const size_t num_elements) { + enlarge(num_elements); + copy_from_host(data, num_elements); + } + + /// Copies num_elements from the host vector after enlarging the array so that everything fits in + void enlarge_and_copy_from_host(const std::vector& data, const size_t num_elements) { + enlarge_and_copy_from_host(data.data(), num_elements); + } + + /// Copies the entire host vector after enlarging the array so that everything fits in + void enlarge_and_copy_from_host(const std::vector& data) { + enlarge_and_copy_from_host(data.data(), data.size()); + } + + /// Copies num_elements of data from the raw host pointer after resizing the array + void resize_and_copy_from_host(const T* data, const size_t num_elements) { + resize(num_elements); + copy_from_host(data, num_elements); + } + + /// Copies num_elements from the host vector after resizing the array + void resize_and_copy_from_host(const std::vector& data, const size_t num_elements) { + resize_and_copy_from_host(data.data(), num_elements); + } + + /// Copies the entire host vector after resizing the array + void resize_and_copy_from_host(const std::vector& data) { + resize_and_copy_from_host(data.data(), data.size()); + } + + /// Copies the entire host vector to the device. Fails if there is not enough space available. + void copy_from_host(const std::vector& data) { + if (data.size() < m_size) { + throw std::runtime_error(std::string("Trying to copy ") + std::to_string(m_size) + std::string(" elements, but vector size is only ") + std::to_string(data.size())); + } + copy_from_host(data.data(), m_size); + } + + /// Copies num_elements of data from the raw host pointer to the device. Fails if there is not enough space available. + void copy_to_host(T* host_data, const size_t num_elements) const { + if (num_elements > m_size) { + throw std::runtime_error(std::string("Trying to copy ") + std::to_string(num_elements) + std::string(" elements, but vector size is only ") + std::to_string(m_size)); + } + try { + CUDA_CHECK_THROW(cudaMemcpy(host_data, data(), num_elements * sizeof(T), cudaMemcpyDeviceToHost)); + } catch (std::runtime_error error) { + throw std::runtime_error(std::string("Could not copy to host: ") + error.what()); + } + } + + /// Copies num_elements from the device to a vector on the host + void copy_to_host(std::vector& data, const size_t num_elements) const { + if (data.size() < num_elements) { + throw std::runtime_error(std::string("Trying to copy ") + std::to_string(num_elements) + std::string(" elements, but vector size is only ") + std::to_string(data.size())); + } + copy_to_host(data.data(), num_elements); + } + + /// Copies num_elements from the device to a raw pointer on the host + void copy_to_host(T* data) const { + copy_to_host(data, m_size); + } + + /// Copies all elements from the device to a vector on the host + void copy_to_host(std::vector& data) const { + if (data.size() < m_size) { + throw std::runtime_error(std::string("Trying to copy ") + std::to_string(m_size) + std::string(" elements, but vector size is only ") + std::to_string(data.size())); + } + copy_to_host(data.data(), m_size); + } + + /// Copies data from another device array to this one, automatically resizing it + void copy_from_device(const GPUMemory &other) { + if (m_size != other.m_size) { + resize(other.m_size); + } + + try { + CUDA_CHECK_THROW(cudaMemcpy(m_data, other.m_data, m_size * sizeof(T), cudaMemcpyDeviceToDevice)); + } catch (std::runtime_error error) { + throw std::runtime_error(std::string("Could not copy from device: ") + error.what()); + } + } + + /// Copies size elements from another device array to this one, automatically resizing it + void copy_from_device(const GPUMemory &other, const size_t size) { + if (m_size < size) { + resize(size); + } + + try { + CUDA_CHECK_THROW(cudaMemcpy(m_data, other.m_data, size * sizeof(T), cudaMemcpyDeviceToDevice)); + } catch (std::runtime_error error) { + throw std::runtime_error(std::string("Could not copy from device: ") + error.what()); + } + } + + // Created an (owned) copy of the data + GPUMemory copy() const { + GPUMemory result{m_size}; + result.copy_from_device(*this); + return result; + } + + T* data() const { + check_guards(); + return m_data; + } + + __host__ __device__ T& operator[](size_t idx) const { +#ifdef DEBUG_BUFFER_OVERRUN + if (idx > m_size) { + printf("WARNING: buffer overrun of %p at idx %zu\n", idx); + } +#endif + return m_data[idx]; + } + + __host__ __device__ T& operator[](uint32_t idx) const { +#ifdef DEBUG_BUFFER_OVERRUN + if (idx > m_size) { + printf("WARNING: buffer overrun of %p at idx %u\n", idx); + } +#endif + return m_data[idx]; + } + + size_t get_num_elements() const { + return m_size; + } + + size_t size() const { + return get_num_elements(); + } + + size_t get_bytes() const { + return m_size * sizeof(T); + } + + size_t bytes() const { + return get_bytes(); + } +}; + + +inline std::string bytes_to_string(size_t bytes) { + std::array suffixes = {{ "B", "KB", "MB", "GB", "TB", "PB", "EB" }}; + + double count = (double)bytes; + uint32_t i = 0; + for (; i < suffixes.size() && count >= 1024; ++i) { + count /= 1024; + } + + std::ostringstream oss; + oss.precision(3); + oss << count << " " << suffixes[i]; + return oss.str(); +} + + +template +__host__ __device__ void warp_activation(Activation activation, const fragment_t& frag, fragment_t& result) { + switch (activation) { + case Activation::ReLU: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)((T)frag.x[t] > (T)0.0f); + } + return; + case Activation::Exponential: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + result.x[t] = (T)(expf((float)frag.x[t])); + } + return; + case Activation::Sine: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + result.x[t] = (T)(sinf((float)frag.x[t])); + } + return; + case Activation::Sigmoid: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + result.x[t] = (T)(logistic((float)frag.x[t])); + } + return; + case Activation::Squareplus: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + float x = (float)frag.x[t] * K_ACT; + result.x[t] = (T)(0.5f * (x + sqrtf(x * x + 4)) / K_ACT); + } + return; + case Activation::Softplus: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + result.x[t] = (T)(logf(expf((float)frag.x[t] * K_ACT) + 1.0f) / K_ACT); + } + return; + case Activation::None: result = frag; return; + default: + // Unsupported activation + // assert(false); // Commented out due to isolated strange side-effects on Windows + return; + } +} + +template +__host__ __device__ fragment_t warp_activation(Activation activation, const fragment_t& frag) { + fragment_t result; + warp_activation(activation, frag, result); + return result; +} + +template +__host__ __device__ void warp_activation_backward_in(Activation activation, const fragment_t& frag, const forward_fragment_t& forward_frag_in, fragment_t& result) { + switch (activation) { + case Activation::ReLU: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)(forward_frag_in.x[t] > (T)0.0f); + } + return; + case Activation::Exponential: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)(expf(forward_frag_in.x[t])); + } + return; + case Activation::Sine: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)(cosf(forward_frag_in.x[t])); + } + return; + case Activation::Sigmoid: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + float x = logistic(forward_frag_in.x[t]); + result.x[t] = frag.x[t] * (T)(x * (1.0f - x)); + } + return; + case Activation::Squareplus: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + float x = (float)forward_frag_in.x[t] * K_ACT; + float y = 0.5f * (x + sqrtf(x * x + 4)); + result.x[t] = frag.x[t] * (T)(y * y / (y * y + 1)); + } + return; + case Activation::Softplus: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + float tmp = expf((float)frag.x[t] * K_ACT); + result.x[t] = frag.x[t] * (T)(tmp / (tmp + 1)); + } + return; + case Activation::None: result = frag; return; + default: + // Unsupported activation + // assert(false); // Commented out due to isolated strange side-effects on Windows + return; + } +} + +template +__host__ __device__ fragment_t warp_activation_backward_in(Activation activation, const fragment_t& frag, const forward_fragment_t& forward_frag_in) { + fragment_t result; + warp_activation_backward_in(activation, frag, forward_frag_in, result); + return result; +} + +template +__host__ __device__ void warp_activation_backward(Activation activation, const fragment_t& frag, const forward_fragment_t& forward_frag, fragment_t& result) { + switch (activation) { + case Activation::ReLU: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)(forward_frag.x[t] > (T)0.0f); + } + return; + case Activation::Exponential: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * forward_frag.x[t]; + } + return; + case Activation::Sine: + // Sine requires stored pre-activations, which we don't have. We only + // write out the post-activations. + // assert(false); // Commented out due to isolated strange side-effects on Windows + return; + case Activation::Sigmoid: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)(forward_frag.x[t] * ((T)1.0f - forward_frag.x[t])); + } + return; + case Activation::Squareplus: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + float y = (float)forward_frag.x[t] * K_ACT; + result.x[t] = frag.x[t] * (T)(y * y / (y * y + 1)); + } + return; + case Activation::Softplus: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)(1.0f - expf(-(float)forward_frag.x[t] * K_ACT)); + } + return; + case Activation::None: result = frag; return; + default: + // Unsupported activation + // assert(false); // Commented out due to isolated strange side-effects on Windows + return; + } +} + +template +__host__ __device__ fragment_t warp_activation_backward(Activation activation, const fragment_t& frag, const forward_fragment_t& forward_frag) { + fragment_t result; + warp_activation_backward(activation, frag, forward_frag, result); + return result; +} \ No newline at end of file diff --git a/torch-ngp/freqencoder/__init__.py b/torch-ngp/freqencoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69ec49cf6e87b396bfe9730fe7580097c0670a6c --- /dev/null +++ b/torch-ngp/freqencoder/__init__.py @@ -0,0 +1 @@ +from .freq import FreqEncoder \ No newline at end of file diff --git a/torch-ngp/freqencoder/backend.py b/torch-ngp/freqencoder/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..3bd9131a779582dcd5261ed3789911ea29a0353c --- /dev/null +++ b/torch-ngp/freqencoder/backend.py @@ -0,0 +1,41 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', + '-use_fast_math' +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_freqencoder', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'freqencoder.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/torch-ngp/freqencoder/freq.py b/torch-ngp/freqencoder/freq.py new file mode 100644 index 0000000000000000000000000000000000000000..5cba1e660f339ffde62b6b2aac6013d6e6795f0d --- /dev/null +++ b/torch-ngp/freqencoder/freq.py @@ -0,0 +1,77 @@ +import numpy as np + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _freqencoder as _backend +except ImportError: + from .backend import _backend + + +class _freq_encoder(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision + def forward(ctx, inputs, degree, output_dim): + # inputs: [B, input_dim], float + # RETURN: [B, F], float + + if not inputs.is_cuda: inputs = inputs.cuda() + inputs = inputs.contiguous() + + B, input_dim = inputs.shape # batch size, coord dim + + outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) + + _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) + + ctx.save_for_backward(inputs, outputs) + ctx.dims = [B, input_dim, degree, output_dim] + + return outputs + + @staticmethod + #@once_differentiable + @custom_bwd + def backward(ctx, grad): + # grad: [B, C * C] + + grad = grad.contiguous() + inputs, outputs = ctx.saved_tensors + B, input_dim, degree, output_dim = ctx.dims + + grad_inputs = torch.zeros_like(inputs) + _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) + + return grad_inputs, None, None + + +freq_encode = _freq_encoder.apply + + +class FreqEncoder(nn.Module): + def __init__(self, input_dim=3, degree=4): + super().__init__() + + self.input_dim = input_dim + self.degree = degree + self.output_dim = input_dim + input_dim * 2 * degree + + def __repr__(self): + return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}" + + def forward(self, inputs, **kwargs): + # inputs: [..., input_dim] + # return: [..., ] + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.reshape(-1, self.input_dim) + + outputs = freq_encode(inputs, self.degree, self.output_dim) + + outputs = outputs.reshape(prefix_shape + [self.output_dim]) + + return outputs \ No newline at end of file diff --git a/torch-ngp/freqencoder/setup.py b/torch-ngp/freqencoder/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..3eb4af77cee88d99f5930e60d3e5301a2d94a056 --- /dev/null +++ b/torch-ngp/freqencoder/setup.py @@ -0,0 +1,51 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', + '-use_fast_math' +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='freqencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_freqencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'freqencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/torch-ngp/freqencoder/src/bindings.cpp b/torch-ngp/freqencoder/src/bindings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bb5f285a97d9ff6426add0f9e55c7781fbd313e6 --- /dev/null +++ b/torch-ngp/freqencoder/src/bindings.cpp @@ -0,0 +1,8 @@ +#include + +#include "freqencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)"); + m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)"); +} \ No newline at end of file diff --git a/torch-ngp/freqencoder/src/freqencoder.cu b/torch-ngp/freqencoder/src/freqencoder.cu new file mode 100644 index 0000000000000000000000000000000000000000..072da7499932231ea49da6cf05964162e22f704b --- /dev/null +++ b/torch-ngp/freqencoder/src/freqencoder.cu @@ -0,0 +1,129 @@ +#include + +#include +#include +#include + +#include +#include + +#include +#include + +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + +inline constexpr __device__ float PI() { return 3.141592653589793f; } + +template +__host__ __device__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +// inputs: [B, D] +// outputs: [B, C], C = D + D * deg * 2 +__global__ void kernel_freq( + const float * __restrict__ inputs, + uint32_t B, uint32_t D, uint32_t deg, uint32_t C, + float * outputs +) { + // parallel on per-element + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * C) return; + + // get index + const uint32_t b = t / C; + const uint32_t c = t - b * C; // t % C; + + // locate + inputs += b * D; + outputs += t; + + // write self + if (c < D) { + outputs[0] = inputs[c]; + // write freq + } else { + const uint32_t col = c / D - 1; + const uint32_t d = c % D; + const uint32_t freq = col / 2; + const float phase_shift = (col % 2) * (PI() / 2); + outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift); + } +} + +// grad: [B, C], C = D + D * deg * 2 +// outputs: [B, C] +// grad_inputs: [B, D] +__global__ void kernel_freq_backward( + const float * __restrict__ grad, + const float * __restrict__ outputs, + uint32_t B, uint32_t D, uint32_t deg, uint32_t C, + float * grad_inputs +) { + // parallel on per-element + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * D) return; + + const uint32_t b = t / D; + const uint32_t d = t - b * D; // t % D; + + // locate + grad += b * C; + outputs += b * C; + grad_inputs += t; + + // register + float result = grad[d]; + grad += D; + outputs += D; + + for (uint32_t f = 0; f < deg; f++) { + result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]); + grad += 2 * D; + outputs += 2 * D; + } + + // write + grad_inputs[0] = result; +} + + +void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) { + CHECK_CUDA(inputs); + CHECK_CUDA(outputs); + + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(outputs); + + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(outputs); + + static constexpr uint32_t N_THREADS = 128; + + kernel_freq<<>>(inputs.data_ptr(), B, D, deg, C, outputs.data_ptr()); +} + + +void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) { + CHECK_CUDA(grad); + CHECK_CUDA(outputs); + CHECK_CUDA(grad_inputs); + + CHECK_CONTIGUOUS(grad); + CHECK_CONTIGUOUS(outputs); + CHECK_CONTIGUOUS(grad_inputs); + + CHECK_IS_FLOATING(grad); + CHECK_IS_FLOATING(outputs); + CHECK_IS_FLOATING(grad_inputs); + + static constexpr uint32_t N_THREADS = 128; + + kernel_freq_backward<<>>(grad.data_ptr(), outputs.data_ptr(), B, D, deg, C, grad_inputs.data_ptr()); +} \ No newline at end of file diff --git a/torch-ngp/freqencoder/src/freqencoder.h b/torch-ngp/freqencoder/src/freqencoder.h new file mode 100644 index 0000000000000000000000000000000000000000..34f28c79469b0ba639c742bf8697ca0515563b1d --- /dev/null +++ b/torch-ngp/freqencoder/src/freqencoder.h @@ -0,0 +1,10 @@ +# pragma once + +#include +#include + +// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) +void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs); + +// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) +void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs); \ No newline at end of file diff --git a/torch-ngp/gridencoder/__init__.py b/torch-ngp/gridencoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1476cef5314e0918b963d1ac64ee0613a7743d5 --- /dev/null +++ b/torch-ngp/gridencoder/__init__.py @@ -0,0 +1 @@ +from .grid import GridEncoder \ No newline at end of file diff --git a/torch-ngp/gridencoder/backend.py b/torch-ngp/gridencoder/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..d99acb1f4353786e16468948780f377008d94872 --- /dev/null +++ b/torch-ngp/gridencoder/backend.py @@ -0,0 +1,40 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_grid_encoder', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'gridencoder.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/torch-ngp/gridencoder/grid.py b/torch-ngp/gridencoder/grid.py new file mode 100644 index 0000000000000000000000000000000000000000..32b8bead0d9d0575b0988302afb1794dffcfe72d --- /dev/null +++ b/torch-ngp/gridencoder/grid.py @@ -0,0 +1,185 @@ +import numpy as np + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _gridencoder as _backend +except ImportError: + from .backend import _backend + +_gridtype_to_id = { + 'hash': 0, + 'tiled': 1, +} + +_interp_to_id = { + 'linear': 0, + 'smoothstep': 1, +} + +class _grid_encode(Function): + @staticmethod + @custom_fwd + def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0): + # inputs: [B, D], float in [0, 1] + # embeddings: [sO, C], float + # offsets: [L + 1], int + # RETURN: [B, F], float + + inputs = inputs.contiguous() + + B, D = inputs.shape # batch size, coord dim + L = offsets.shape[0] - 1 # level + C = embeddings.shape[1] # embedding dim for each level + S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f + H = base_resolution # base resolution + + # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision) + # if C % 2 != 0, force float, since half for atomicAdd is very slow. + if torch.is_autocast_enabled() and C % 2 == 0: + embeddings = embeddings.to(torch.half) + + # L first, optimize cache for cuda kernel, but needs an extra permute later + outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype) + + if calc_grad_inputs: + dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype) + else: + dy_dx = None + + _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners, interpolation) + + # permute back to [B, L * C] + outputs = outputs.permute(1, 0, 2).reshape(B, L * C) + + ctx.save_for_backward(inputs, embeddings, offsets, dy_dx) + ctx.dims = [B, D, C, L, S, H, gridtype, interpolation] + ctx.align_corners = align_corners + + return outputs + + @staticmethod + #@once_differentiable + @custom_bwd + def backward(ctx, grad): + + inputs, embeddings, offsets, dy_dx = ctx.saved_tensors + B, D, C, L, S, H, gridtype, interpolation = ctx.dims + align_corners = ctx.align_corners + + # grad: [B, L * C] --> [L, B, C] + grad = grad.view(B, L, C).permute(1, 0, 2).contiguous() + + grad_embeddings = torch.zeros_like(embeddings) + + if dy_dx is not None: + grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype) + else: + grad_inputs = None + + _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interpolation) + + if dy_dx is not None: + grad_inputs = grad_inputs.to(inputs.dtype) + + return grad_inputs, grad_embeddings, None, None, None, None, None, None, None + + + +grid_encode = _grid_encode.apply + + +class GridEncoder(nn.Module): + def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, interpolation='linear'): + super().__init__() + + # the finest resolution desired at the last level, if provided, overridee per_level_scale + if desired_resolution is not None: + per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1)) + + self.input_dim = input_dim # coord dims, 2 or 3 + self.num_levels = num_levels # num levels, each level multiply resolution by 2 + self.level_dim = level_dim # encode channels per level + self.per_level_scale = per_level_scale # multiply resolution by this scale at each level. + self.log2_hashmap_size = log2_hashmap_size + self.base_resolution = base_resolution + self.output_dim = num_levels * level_dim + self.gridtype = gridtype + self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" + self.interpolation = interpolation + self.interp_id = _interp_to_id[interpolation] # "linear" or "smoothstep" + self.align_corners = align_corners + + # allocate parameters + offsets = [] + offset = 0 + self.max_params = 2 ** log2_hashmap_size + for i in range(num_levels): + resolution = int(np.ceil(base_resolution * per_level_scale ** i)) + params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number + params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible + offsets.append(offset) + offset += params_in_level + offsets.append(offset) + offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) + self.register_buffer('offsets', offsets) + + self.n_params = offsets[-1] * level_dim + + # parameters + self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) + + self.reset_parameters() + + def reset_parameters(self): + std = 1e-4 + self.embeddings.data.uniform_(-std, std) + + def __repr__(self): + return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}" + + def forward(self, inputs, bound=1): + # inputs: [..., input_dim], normalized real world positions in [-bound, bound] + # return: [..., num_levels * level_dim] + + inputs = (inputs + bound) / (2 * bound) # map to [0, 1] + + #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.view(-1, self.input_dim) + + outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id) + outputs = outputs.view(prefix_shape + [self.output_dim]) + + #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) + + return outputs + + # always run in float precision! + @torch.cuda.amp.autocast(enabled=False) + def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000): + # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss. + + D = self.input_dim + C = self.embeddings.shape[1] # embedding dim for each level + L = self.offsets.shape[0] - 1 # level + S = np.log2(self.per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f + H = self.base_resolution # base resolution + + if inputs is None: + # randomized in [0, 1] + inputs = torch.rand(B, self.input_dim, device=self.embeddings.device) + else: + inputs = (inputs + bound) / (2 * bound) # map to [0, 1] + inputs = inputs.view(-1, self.input_dim) + B = inputs.shape[0] + + if self.embeddings.grad is None: + raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!') + + _backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners) \ No newline at end of file diff --git a/torch-ngp/gridencoder/setup.py b/torch-ngp/gridencoder/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..714bf1cad7880fe25dca319414748c15e86cc48e --- /dev/null +++ b/torch-ngp/gridencoder/setup.py @@ -0,0 +1,50 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='gridencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_gridencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'gridencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/torch-ngp/gridencoder/src/bindings.cpp b/torch-ngp/gridencoder/src/bindings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..93dea943c939cffc7ec73c76410aeff7afddc1f9 --- /dev/null +++ b/torch-ngp/gridencoder/src/bindings.cpp @@ -0,0 +1,9 @@ +#include + +#include "gridencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); + m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); + m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)"); +} \ No newline at end of file diff --git a/torch-ngp/gridencoder/src/gridencoder.cu b/torch-ngp/gridencoder/src/gridencoder.cu new file mode 100644 index 0000000000000000000000000000000000000000..cba5e94f5f4ca6b728bc9006c79e80cb0fce62dd --- /dev/null +++ b/torch-ngp/gridencoder/src/gridencoder.cu @@ -0,0 +1,645 @@ +#include +#include +#include + +#include +#include + +#include +#include + +#include +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + + +// just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... program will never reach here! + __device__ inline at::Half atomicAdd(at::Half *address, at::Half val) { + // requires CUDA >= 10 and ARCH >= 70 + // this is very slow compared to float or __half2, never use it. + //return atomicAdd(reinterpret_cast<__half*>(address), val); +} + + +template +__host__ __device__ inline T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +template +__host__ __device__ inline T clamp(const T v, const T2 lo, const T2 hi) { + return min(max(v, lo), hi); +} + +template +__device__ inline T smoothstep(T val) { + return val*val*(3.0f - 2.0f * val); +} + +template +__device__ inline T smoothstep_derivative(T val) { + return 6*val*(1.0f - val); +} + + +template +__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) { + + // coherent type of hashing + constexpr uint32_t primes[7] = { 1u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u }; + + uint32_t result = 0; + #pragma unroll + for (uint32_t i = 0; i < D; ++i) { + result ^= pos_grid[i] * primes[i]; + } + + return result; +} + + +template +__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) { + uint32_t stride = 1; + uint32_t index = 0; + + #pragma unroll + for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) { + index += pos_grid[d] * stride; + stride *= align_corners ? resolution: (resolution + 1); + } + + // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97. + // gridtype: 0 == hash, 1 == tiled + if (gridtype == 0 && stride > hashmap_size) { + index = fast_hash(pos_grid); + } + + return (index % hashmap_size) * C + ch; +} + + +template +__global__ void kernel_grid( + const float * __restrict__ inputs, + const scalar_t * __restrict__ grid, + const int * __restrict__ offsets, + scalar_t * __restrict__ outputs, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + scalar_t * __restrict__ dy_dx, + const uint32_t gridtype, + const bool align_corners, + const uint32_t interp +) { + const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; + + if (b >= B) return; + + const uint32_t level = blockIdx.y; + + // locate + grid += (uint32_t)offsets[level] * C; + inputs += b * D; + outputs += level * B * C + b * C; + + // check input range (should be in [0, 1]) + bool flag_oob = false; + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + flag_oob = true; + } + } + // if input out of bound, just set output to 0 + if (flag_oob) { + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + outputs[ch] = 0; + } + if (dy_dx) { + dy_dx += b * D * L * C + level * D * C; // B L D C + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + dy_dx[d * C + ch] = 0; + } + } + } + return; + } + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // calculate coordinate (always use float for precision!) + float pos[D]; + float pos_deriv[D]; + uint32_t pos_grid[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + pos[d] -= (float)pos_grid[d]; + // smoothstep instead of linear + if (interp == 1) { + pos_deriv[d] = smoothstep_derivative(pos[d]); + pos[d] = smoothstep(pos[d]); + } else { + pos_deriv[d] = 1.0f; // linear deriv is default to 1 + } + + } + + //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]); + + // interpolate + scalar_t results[C] = {0}; // temp results in register + + #pragma unroll + for (uint32_t idx = 0; idx < (1 << D); idx++) { + float w = 1; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if ((idx & (1 << d)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + + // writing to register (fast) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + results[ch] += w * grid[index + ch]; + } + + //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]); + } + + // writing to global memory (slow) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + outputs[ch] = results[ch]; + } + + // prepare dy_dx + // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9 + if (dy_dx) { + + dy_dx += b * D * L * C + level * D * C; // B L D C + + #pragma unroll + for (uint32_t gd = 0; gd < D; gd++) { + + scalar_t results_grad[C] = {0}; + + #pragma unroll + for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) { + float w = scale; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t nd = 0; nd < D - 1; nd++) { + const uint32_t d = (nd >= gd) ? (nd + 1) : nd; + + if ((idx & (1 << nd)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + pos_grid_local[gd] = pos_grid[gd]; + uint32_t index_left = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + pos_grid_local[gd] = pos_grid[gd] + 1; + uint32_t index_right = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]) * pos_deriv[gd]; + } + } + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + dy_dx[gd * C + ch] = results_grad[ch]; + } + } + } +} + + +template +__global__ void kernel_grid_backward( + const scalar_t * __restrict__ grad, + const float * __restrict__ inputs, + const scalar_t * __restrict__ grid, + const int * __restrict__ offsets, + scalar_t * __restrict__ grad_grid, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + const uint32_t gridtype, + const bool align_corners, + const uint32_t interp +) { + const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C; + if (b >= B) return; + + const uint32_t level = blockIdx.y; + const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C; + + // locate + grad_grid += offsets[level] * C; + inputs += b * D; + grad += level * B * C + b * C + ch; // L, B, C + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // check input range (should be in [0, 1]) + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + return; // grad is init as 0, so we simply return. + } + } + + // calculate coordinate + float pos[D]; + uint32_t pos_grid[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + pos[d] -= (float)pos_grid[d]; + // smoothstep instead of linear + if (interp == 1) { + pos[d] = smoothstep(pos[d]); + } + } + + scalar_t grad_cur[N_C] = {0}; // fetch to register + #pragma unroll + for (uint32_t c = 0; c < N_C; c++) { + grad_cur[c] = grad[c]; + } + + // interpolate + #pragma unroll + for (uint32_t idx = 0; idx < (1 << D); idx++) { + float w = 1; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if ((idx & (1 << d)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + uint32_t index = get_grid_index(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local); + + // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0 + // TODO: use float which is better than __half, if N_C % 2 != 0 + if (std::is_same::value && N_C % 2 == 0) { + #pragma unroll + for (uint32_t c = 0; c < N_C; c += 2) { + // process two __half at once (by interpreting as a __half2) + __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])}; + atomicAdd((__half2*)&grad_grid[index + c], v); + } + // float, or __half when N_C % 2 != 0 (which means C == 1) + } else { + #pragma unroll + for (uint32_t c = 0; c < N_C; c++) { + atomicAdd(&grad_grid[index + c], w * grad_cur[c]); + } + } + } +} + + +template +__global__ void kernel_input_backward( + const scalar_t * __restrict__ grad, + const scalar_t * __restrict__ dy_dx, + scalar_t * __restrict__ grad_inputs, + uint32_t B, uint32_t L +) { + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * D) return; + + const uint32_t b = t / D; + const uint32_t d = t - b * D; + + dy_dx += b * L * D * C; + + scalar_t result = 0; + + # pragma unroll + for (int l = 0; l < L; l++) { + # pragma unroll + for (int ch = 0; ch < C; ch++) { + result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch]; + } + } + + grad_inputs[t] = result; +} + + +template +void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + static constexpr uint32_t N_THREAD = 512; + const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 }; + switch (C) { + case 1: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 2: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 4: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 8: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.) +// H: base resolution +// dy_dx: [B, L * D * C] +template +void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + switch (D) { + case 2: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 3: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 4: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 5: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + +template +void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + static constexpr uint32_t N_THREAD = 256; + const uint32_t N_C = std::min(2u, C); // n_features_per_thread + const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 }; + switch (C) { + case 1: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 2: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 4: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 8: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + +// grad: [L, B, C], float +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// grad_embeddings: [sO, C] +// H: base resolution +template +void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + switch (D) { + case 2: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break; + case 3: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break; + case 4: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break; + case 5: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + + +void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + CHECK_CUDA(inputs); + CHECK_CUDA(embeddings); + CHECK_CUDA(offsets); + CHECK_CUDA(outputs); + // CHECK_CUDA(dy_dx); + + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(embeddings); + CHECK_CONTIGUOUS(offsets); + CHECK_CONTIGUOUS(outputs); + // CHECK_CONTIGUOUS(dy_dx); + + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(embeddings); + CHECK_IS_INT(offsets); + CHECK_IS_FLOATING(outputs); + // CHECK_IS_FLOATING(dy_dx); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + embeddings.scalar_type(), "grid_encode_forward", ([&] { + grid_encode_forward_cuda(inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), outputs.data_ptr(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, gridtype, align_corners, interp); + })); +} + +void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + CHECK_CUDA(grad); + CHECK_CUDA(inputs); + CHECK_CUDA(embeddings); + CHECK_CUDA(offsets); + CHECK_CUDA(grad_embeddings); + // CHECK_CUDA(dy_dx); + // CHECK_CUDA(grad_inputs); + + CHECK_CONTIGUOUS(grad); + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(embeddings); + CHECK_CONTIGUOUS(offsets); + CHECK_CONTIGUOUS(grad_embeddings); + // CHECK_CONTIGUOUS(dy_dx); + // CHECK_CONTIGUOUS(grad_inputs); + + CHECK_IS_FLOATING(grad); + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(embeddings); + CHECK_IS_INT(offsets); + CHECK_IS_FLOATING(grad_embeddings); + // CHECK_IS_FLOATING(dy_dx); + // CHECK_IS_FLOATING(grad_inputs); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "grid_encode_backward", ([&] { + grid_encode_backward_cuda(grad.data_ptr(), inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), grad_embeddings.data_ptr(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr() : nullptr, gridtype, align_corners, interp); + })); + +} + + +template +__global__ void kernel_grad_tv( + const scalar_t * __restrict__ inputs, + const scalar_t * __restrict__ grid, + scalar_t * __restrict__ grad, + const int * __restrict__ offsets, + const float weight, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + const uint32_t gridtype, + const bool align_corners +) { + const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; + + if (b >= B) return; + + const uint32_t level = blockIdx.y; + + // locate + inputs += b * D; + grid += (uint32_t)offsets[level] * C; + grad += (uint32_t)offsets[level] * C; + + // check input range (should be in [0, 1]) + bool flag_oob = false; + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + flag_oob = true; + } + } + + // if input out of bound, do nothing + if (flag_oob) return; + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // calculate coordinate + float pos[D]; + uint32_t pos_grid[D]; // [0, resolution] + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + // pos[d] -= (float)pos_grid[d]; // not used + } + + //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]); + + // total variation on pos_grid + scalar_t results[C] = {0}; // temp results in register + scalar_t idelta[C] = {0}; + + uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid); + + scalar_t w = weight / (2 * D); + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + + uint32_t cur_d = pos_grid[d]; + scalar_t grad_val; + + // right side + if (cur_d < resolution) { + pos_grid[d] = cur_d + 1; + uint32_t index_right = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + // results[ch] += w * clamp(grid[index + ch] - grid[index_right + ch], -1.0f, 1.0f); + grad_val = (grid[index + ch] - grid[index_right + ch]); + results[ch] += grad_val; + idelta[ch] += grad_val * grad_val; + } + } + + // left side + if (cur_d > 0) { + pos_grid[d] = cur_d - 1; + uint32_t index_left = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + // results[ch] += w * clamp(grid[index + ch] - grid[index_left + ch], -1.0f, 1.0f); + grad_val = (grid[index + ch] - grid[index_left + ch]); + results[ch] += grad_val; + idelta[ch] += grad_val * grad_val; + } + } + + // reset + pos_grid[d] = cur_d; + } + + // writing to global memory (slow) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + // index may collide, so use atomic! + atomicAdd(&grad[index + ch], w * results[ch] * rsqrtf(idelta[ch] + 1e-9f)); + } + +} + + +template +void kernel_grad_tv_wrapper(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) { + static constexpr uint32_t N_THREAD = 512; + const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 }; + switch (C) { + case 1: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + case 2: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + case 4: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + case 8: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + +template +void grad_total_variation_cuda(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) { + switch (D) { + case 2: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + case 3: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + case 4: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + case 5: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + +void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) { + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + embeddings.scalar_type(), "grad_total_variation", ([&] { + grad_total_variation_cuda(inputs.data_ptr(), embeddings.data_ptr(), grad.data_ptr(), offsets.data_ptr(), weight, B, D, C, L, S, H, gridtype, align_corners); + })); +} \ No newline at end of file diff --git a/torch-ngp/gridencoder/src/gridencoder.h b/torch-ngp/gridencoder/src/gridencoder.h new file mode 100644 index 0000000000000000000000000000000000000000..1b385755d13711b04df4866dd654e88b48054554 --- /dev/null +++ b/torch-ngp/gridencoder/src/gridencoder.h @@ -0,0 +1,17 @@ +#ifndef _HASH_ENCODE_H +#define _HASH_ENCODE_H + +#include +#include + +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// outputs: [B, L * C], float +// H: base resolution +void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp); +void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp); + +void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners); + +#endif \ No newline at end of file diff --git a/torch-ngp/loss.py b/torch-ngp/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..68c665a637ae44701ebe6c4d381c9fb58f6eafff --- /dev/null +++ b/torch-ngp/loss.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np + +def mape_loss(pred, target, reduction='mean'): + # pred, target: [B, 1], torch tenspr + difference = (pred - target).abs() + scale = 1 / (target.abs() + 1e-2) + loss = difference * scale + + if reduction == 'mean': + loss = loss.mean() + + return loss + +def huber_loss(pred, target, delta=0.1, reduction='mean'): + rel = (pred - target).abs() + sqr = 0.5 / delta * rel * rel + loss = torch.where(rel > delta, rel - 0.5 * delta, sqr) + + if reduction == 'mean': + loss = loss.mean() + + return loss + + +# ref: https://github.com/sunset1995/torch_efficient_distloss/blob/main/torch_efficient_distloss/eff_distloss.py +class EffDistLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, w, m, interval): + ''' + Efficient O(N) realization of distortion loss. + There are B rays each with N sampled points. + w: Float tensor in shape [B,N]. Volume rendering weights of each point. + m: Float tensor in shape [B,N]. Midpoint distance to camera of each point. + interval: Scalar or float tensor in shape [B,N]. The query interval of each point. + ''' + n_rays = np.prod(w.shape[:-1]) + wm = (w * m) + w_cumsum = w.cumsum(dim=-1) + wm_cumsum = wm.cumsum(dim=-1) + + w_total = w_cumsum[..., [-1]] + wm_total = wm_cumsum[..., [-1]] + w_prefix = torch.cat([torch.zeros_like(w_total), w_cumsum[..., :-1]], dim=-1) + wm_prefix = torch.cat([torch.zeros_like(wm_total), wm_cumsum[..., :-1]], dim=-1) + loss_uni = (1/3) * interval * w.pow(2) + loss_bi = 2 * w * (m * w_prefix - wm_prefix) + if torch.is_tensor(interval): + ctx.save_for_backward(w, m, wm, w_prefix, w_total, wm_prefix, wm_total, interval) + ctx.interval = None + else: + ctx.save_for_backward(w, m, wm, w_prefix, w_total, wm_prefix, wm_total) + ctx.interval = interval + ctx.n_rays = n_rays + return (loss_bi.sum() + loss_uni.sum()) / n_rays + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_back): + interval = ctx.interval + n_rays = ctx.n_rays + if interval is None: + w, m, wm, w_prefix, w_total, wm_prefix, wm_total, interval = ctx.saved_tensors + else: + w, m, wm, w_prefix, w_total, wm_prefix, wm_total = ctx.saved_tensors + grad_uni = (1/3) * interval * 2 * w + w_suffix = w_total - (w_prefix + w) + wm_suffix = wm_total - (wm_prefix + wm) + grad_bi = 2 * (m * (w_prefix - w_suffix) + (wm_suffix - wm_prefix)) + grad = grad_back * (grad_bi + grad_uni) / n_rays + return grad, None, None, None + +eff_distloss = EffDistLoss.apply diff --git a/torch-ngp/main_CCNeRF.py b/torch-ngp/main_CCNeRF.py new file mode 100644 index 0000000000000000000000000000000000000000..09d89b2ed871abb6d275bb71f5c0fe3936536a96 --- /dev/null +++ b/torch-ngp/main_CCNeRF.py @@ -0,0 +1,228 @@ +import torch +import argparse + +from nerf.provider import NeRFDataset +from nerf.gui import NeRFGUI +from tensoRF.utils import * + +from scipy.spatial.transform import Rotation as Rot + +#torch.autograd.set_detect_anomaly(True) + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('path', type=str) + parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --preload") + parser.add_argument('--test', action='store_true', help="test mode") + parser.add_argument('--compose', action='store_true', help="compose mode") + parser.add_argument('--workspace', type=str, default='workspace') + parser.add_argument('--seed', type=int, default=0) + ### training options + parser.add_argument('--iters', type=int, default=30000, help="training iters") + parser.add_argument('--lr0', type=float, default=2e-2, help="initial learning rate for embeddings") + parser.add_argument('--lr1', type=float, default=1e-3, help="initial learning rate for networks") + parser.add_argument('--ckpt', type=str, default='latest') + parser.add_argument('--num_rays', type=int, default=4096, help="num rays sampled per image for each training step") + parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch") + parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)") + parser.add_argument('--num_steps', type=int, default=512, help="num steps sampled per ray (only valid when NOT using --cuda_ray)") + parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)") + parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)") + parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)") + parser.add_argument('--l1_reg_weight', type=float, default=1e-5) + + ### network backbone options + parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") + parser.add_argument('--resolution0', type=int, default=128) + parser.add_argument('--resolution1', type=int, default=300) + parser.add_argument("--upsample_model_steps", type=int, action="append", default=[2000, 3000, 4000, 5500, 7000]) + + ### dataset options + parser.add_argument('--color_space', type=str, default='linear', help="Color space, supports (linear, srgb)") + parser.add_argument('--preload', action='store_true', help="preload all data into GPU, accelerate training but use more GPU memory") + parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.") + parser.add_argument('--scale', type=float, default=0.33, help="scale camera location into box[-bound, bound]^3") + parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location") + parser.add_argument('--dt_gamma', type=float, default=0, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") + parser.add_argument('--min_near', type=float, default=0.2, help="minimum near distance for camera") + parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied") + parser.add_argument('--bg_radius', type=float, default=-1, help="if positive, use a background model at sphere(bg_radius)") + parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable") + + ### GUI options + parser.add_argument('--gui', action='store_true', help="start a GUI") + parser.add_argument('--W', type=int, default=1920, help="GUI width") + parser.add_argument('--H', type=int, default=1080, help="GUI height") + parser.add_argument('--radius', type=float, default=5, help="default GUI camera radius from center") + parser.add_argument('--fovy', type=float, default=50, help="default GUI camera fovy") + parser.add_argument('--max_spp', type=int, default=64, help="GUI rendering max sample per pixel") + + ### experimental + parser.add_argument('--error_map', action='store_true', help="use error map to sample rays") + parser.add_argument('--clip_text', type=str, default='', help="text input for CLIP guidance") + parser.add_argument('--rand_pose', type=int, default=-1, help="<0 uses no rand pose, =0 only uses rand pose, >0 sample one rand pose every $ known poses") + + opt = parser.parse_args() + + if opt.O: + opt.fp16 = True + opt.cuda_ray = True + opt.preload = True + + if opt.patch_size > 1: + opt.error_map = False # do not use error_map if use patch-based training + # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss." + assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays." + + print(opt) + seed_everything(opt.seed) + + assert opt.cuda_ray, 'CCNeRF only supports CUDA raymarching mode for now.' + + from tensoRF.network_cc import NeRFNetwork as CCNeRF + + criterion = torch.nn.MSELoss(reduction='none') + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # compose mode + if opt.compose: + + # init an empty scene. (necessary!) + model = CCNeRF( + rank_vec_density=[1], + rank_mat_density=[1], + rank_vec=[1], + rank_mat=[1], + resolution=[1] * 3, # fake resolution + bound=opt.bound, # a large bound is needed + cuda_ray=opt.cuda_ray, + density_scale=1, + min_near=opt.min_near, + density_thresh=opt.density_thresh, + bg_radius=opt.bg_radius, + ).to(device) + + # helper function to load a single model + def load_model(path): + checkpoint_dict = torch.load(path, map_location=device) + model = CCNeRF( + rank_vec_density=checkpoint_dict['rank_vec_density'], + rank_mat_density=checkpoint_dict['rank_mat_density'], + rank_vec=checkpoint_dict['rank_vec'], + rank_mat=checkpoint_dict['rank_mat'], + resolution=checkpoint_dict['resolution'], + bound=opt.bound, + cuda_ray=opt.cuda_ray, + density_scale=1, + min_near=opt.min_near, + density_thresh=opt.density_thresh, + bg_radius=opt.bg_radius, + ).to(device) + + model.load_state_dict(checkpoint_dict['model'], strict=False) + return model + + # compose example + hotdog = load_model('trial_cc_hotdog/checkpoints/64_16-64_64.pth') + chair = load_model('trial_cc_chair/checkpoints/64_16-64_64.pth') + ficus = load_model('trial_cc_ficus/checkpoints/64_16-64_64.pth') + + model.compose(hotdog, s=0.4, t=np.array([0, 0.2, 0])) + model.compose(ficus, s=0.6, t=np.array([0, 0, -0.5]), R=Rot.from_euler('zyx', [0, 0, 0], degrees=True).as_matrix()) + model.compose(chair, s=0.6, t=np.array([0, 0, 0.5]), R=Rot.from_euler('zyx', [0, -90, 0], degrees=True).as_matrix()) + model.compose(chair, s=0.6, t=np.array([-0.5, 0, 0]), R=Rot.from_euler('zyx', [0, 180, 0], degrees=True).as_matrix()) + model.compose(chair, s=0.6, t=np.array([0.5, 0, 0]), R=Rot.from_euler('zyx', [0, 0, 0], degrees=True).as_matrix()) + + # tell trainer not to load ckpt again + opt.ckpt = 'scratch' + + + # single model mode + else: + model = CCNeRF( + resolution=[opt.resolution0] * 3, + bound=opt.bound, + cuda_ray=opt.cuda_ray, + density_scale=1, + min_near=opt.min_near, + density_thresh=opt.density_thresh, + bg_radius=opt.bg_radius, + ).to(device) + + print(model) + + if opt.test: + + trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=[PSNRMeter()], use_checkpoint=opt.ckpt) + + if opt.gui: + gui = NeRFGUI(opt, trainer) + gui.render() + + else: + test_loader = NeRFDataset(opt, device=device, type='test').dataloader() + + # compose mode have no gt, do not evaulate + if opt.compose: + trainer.test(test_loader, save_path=os.path.join(opt.workspace, 'compose')) + elif test_loader.has_gt: + trainer.evaluate(test_loader) # blender has gt, so evaluate it. + + trainer.test(test_loader, write_video=True) + + #trainer.save_mesh(resolution=256, threshold=0.1) + + else: + + optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr0, opt.lr1), betas=(0.9, 0.99), eps=1e-15) + + train_loader = NeRFDataset(opt, device=device, type='train').dataloader() + + # decay to 0.1 * init_lr at last iter step + scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1)) + + trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=None, fp16=opt.fp16, lr_scheduler=scheduler, scheduler_update_every_step=True, metrics=[PSNRMeter()], use_checkpoint=opt.ckpt, eval_interval=50) + + # calc upsample target resolutions + upsample_resolutions = (np.round(np.exp(np.linspace(np.log(opt.resolution0), np.log(opt.resolution1), len(opt.upsample_model_steps) + 1)))).astype(np.int32).tolist()[1:] + print('upsample_resolutions:', upsample_resolutions) + trainer.upsample_resolutions = upsample_resolutions + + if opt.gui: + gui = NeRFGUI(opt, trainer, train_loader) + gui.render() + + else: + valid_loader = NeRFDataset(opt, device=device, type='val', downscale=1).dataloader() + + max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32) + trainer.train(train_loader, valid_loader, max_epoch) + + # also test + test_loader = NeRFDataset(opt, device=device, type='test').dataloader() + + # save and test at multiple compression levels + K = model.K[0] + rank_vec_density = model.rank_vec_density[0][::-1] + rank_mat_density = model.rank_mat_density[0][::-1] + rank_vec = model.rank_vec[0][::-1] + rank_mat = model.rank_mat[0][::-1] + + model.finalize() + print(f'[INFO] ===== finalized model =====') + print(model) + + for k in range(K): + model.compress((rank_vec_density[k], rank_mat_density[k], rank_vec[k], rank_mat[k])) + name = f'{rank_vec_density[k]}_{rank_mat_density[k]}-{rank_vec[k]}_{rank_mat[k]}' + print(f'[INFO] ===== compressed at {name} =====') + print(model) + trainer.save_checkpoint(name, full=False, remove_old=False) + + if test_loader.has_gt: + trainer.evaluate(test_loader, name=name) # blender has gt, so evaluate it. + + trainer.test(test_loader, name=name) # test and save video + diff --git a/torch-ngp/main_dnerf.py b/torch-ngp/main_dnerf.py new file mode 100644 index 0000000000000000000000000000000000000000..52320036bdb0817071687d5288888fc33a3b163c --- /dev/null +++ b/torch-ngp/main_dnerf.py @@ -0,0 +1,156 @@ +import torch +import argparse + +from dnerf.provider import NeRFDataset +from dnerf.gui import NeRFGUI +from dnerf.utils import * + +from functools import partial +from loss import huber_loss + +#torch.autograd.set_detect_anomaly(True) + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('path', type=str) + parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --preload") + parser.add_argument('--test', action='store_true', help="test mode") + parser.add_argument('--workspace', type=str, default='workspace') + parser.add_argument('--seed', type=int, default=0) + + ### training options + parser.add_argument('--iters', type=int, default=30000, help="training iters") + parser.add_argument('--lr', type=float, default=1e-2, help="initial learning rate") + parser.add_argument('--lr_net', type=float, default=1e-3, help="initial learning rate") + parser.add_argument('--ckpt', type=str, default='latest') + parser.add_argument('--num_rays', type=int, default=4096, help="num rays sampled per image for each training step") + parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch") + parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)") + parser.add_argument('--update_extra_interval', type=int, default=100, help="iter interval to update extra status (only valid when using --cuda_ray)") + parser.add_argument('--num_steps', type=int, default=128, help="num steps sampled per ray (only valid when NOT using --cuda_ray)") + parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)") + parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)") + parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable") + + ### network backbone options + parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") + parser.add_argument('--basis', action='store_true', help="[experimental] use temporal basis instead of deformation to model dynamic scene (check Fourier PlenOctree and NeuVV)") + parser.add_argument('--hyper', action='store_true', help="[experimental] use hyper-nerf like ambient dim instead of deformation to model dynamic scene") + # parser.add_argument('--ff', action='store_true', help="use fully-fused MLP") + # parser.add_argument('--tcnn', action='store_true', help="use TCNN backend") + + ### dataset options + parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)") + parser.add_argument('--preload', action='store_true', help="preload all data into GPU, accelerate training but use more GPU memory") + # (the default value is for the fox dataset) + parser.add_argument('--bound', type=float, default=2, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.") + parser.add_argument('--scale', type=float, default=0.33, help="scale camera location into box[-bound, bound]^3") + parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location") + parser.add_argument('--dt_gamma', type=float, default=1/128, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") + parser.add_argument('--min_near', type=float, default=0.2, help="minimum near distance for camera") + parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied") + parser.add_argument('--bg_radius', type=float, default=-1, help="if positive, use a background model at sphere(bg_radius)") + + ### GUI options + parser.add_argument('--gui', action='store_true', help="start a GUI") + parser.add_argument('--W', type=int, default=1920, help="GUI width") + parser.add_argument('--H', type=int, default=1080, help="GUI height") + parser.add_argument('--radius', type=float, default=5, help="default GUI camera radius from center") + parser.add_argument('--fovy', type=float, default=50, help="default GUI camera fovy") + parser.add_argument('--max_spp', type=int, default=64, help="GUI rendering max sample per pixel") + + ### experimental + parser.add_argument('--error_map', action='store_true', help="use error map to sample rays") + parser.add_argument('--clip_text', type=str, default='', help="text input for CLIP guidance") + parser.add_argument('--rand_pose', type=int, default=-1, help="<0 uses no rand pose, =0 only uses rand pose, >0 sample one rand pose every $ known poses") + + opt = parser.parse_args() + + if opt.O: + opt.fp16 = True + opt.cuda_ray = True + opt.preload = True + + if opt.patch_size > 1: + opt.error_map = False # do not use error_map if use patch-based training + # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss." + assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays." + + if opt.basis: + assert opt.cuda_ray, "Non-cuda-ray mode is temporarily broken with temporal basis mode" + from dnerf.network_basis import NeRFNetwork + elif opt.hyper: + from dnerf.network_hyper import NeRFNetwork + else: + from dnerf.network import NeRFNetwork + + print(opt) + + seed_everything(opt.seed) + + model = NeRFNetwork( + bound=opt.bound, + cuda_ray=opt.cuda_ray, + density_scale=1, + min_near=opt.min_near, + density_thresh=opt.density_thresh, + bg_radius=opt.bg_radius, + ) + + print(model) + + criterion = torch.nn.MSELoss(reduction='none') + #criterion = partial(huber_loss, reduction='none') + #criterion = torch.nn.HuberLoss(reduction='none', beta=0.1) # only available after torch 1.10 ? + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + if opt.test: + + trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=[PSNRMeter()], use_checkpoint=opt.ckpt) + + if opt.gui: + gui = NeRFGUI(opt, trainer) + gui.render() + + else: + test_loader = NeRFDataset(opt, device=device, type='test').dataloader() + + if test_loader.has_gt: + trainer.evaluate(test_loader) # blender has gt, so evaluate it. + + trainer.test(test_loader, write_video=True) # test and save video + + #trainer.save_mesh(resolution=256, threshold=10) + + else: + + optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr, opt.lr_net), betas=(0.9, 0.99), eps=1e-15) + + train_loader = NeRFDataset(opt, device=device, type='train').dataloader() + + # decay to 0.1 * init_lr at last iter step + scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1)) + + trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, scheduler_update_every_step=True, metrics=[PSNRMeter()], use_checkpoint=opt.ckpt, eval_interval=50) + + if opt.gui: + gui = NeRFGUI(opt, trainer, train_loader) + gui.render() + + else: + valid_loader = NeRFDataset(opt, device=device, type='val', downscale=1).dataloader() + + max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32) + trainer.train(train_loader, valid_loader, max_epoch) + + # also test + test_loader = NeRFDataset(opt, device=device, type='test').dataloader() + + if test_loader.has_gt: + trainer.evaluate(test_loader) # blender has gt, so evaluate it. + + trainer.test(test_loader, write_video=True) # test and save video + + #trainer.save_mesh(resolution=256, threshold=10) \ No newline at end of file diff --git a/torch-ngp/main_nerf.py b/torch-ngp/main_nerf.py new file mode 100644 index 0000000000000000000000000000000000000000..a4af7ddcab98ff4f45193e4099a63c43f4d9d3c9 --- /dev/null +++ b/torch-ngp/main_nerf.py @@ -0,0 +1,160 @@ +import torch +import argparse + +from nerf.provider import NeRFDataset +from nerf.gui import NeRFGUI +from nerf.utils import * + +from functools import partial +from loss import huber_loss + +#torch.autograd.set_detect_anomaly(True) + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('path', type=str) + parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --preload") + parser.add_argument('--test', action='store_true', help="test mode") + parser.add_argument('--workspace', type=str, default='workspace') + parser.add_argument('--seed', type=int, default=0) + + ### training options + parser.add_argument('--iters', type=int, default=30000, help="training iters") + parser.add_argument('--lr', type=float, default=1e-2, help="initial learning rate") + parser.add_argument('--ckpt', type=str, default='latest') + parser.add_argument('--num_rays', type=int, default=4096, help="num rays sampled per image for each training step") + parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch") + parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)") + parser.add_argument('--num_steps', type=int, default=512, help="num steps sampled per ray (only valid when NOT using --cuda_ray)") + parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)") + parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)") + parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)") + parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable") + + ### network backbone options + parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") + parser.add_argument('--ff', action='store_true', help="use fully-fused MLP") + parser.add_argument('--tcnn', action='store_true', help="use TCNN backend") + + ### dataset options + parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)") + parser.add_argument('--preload', action='store_true', help="preload all data into GPU, accelerate training but use more GPU memory") + # (the default value is for the fox dataset) + parser.add_argument('--bound', type=float, default=2, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.") + parser.add_argument('--scale', type=float, default=0.33, help="scale camera location into box[-bound, bound]^3") + parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location") + parser.add_argument('--dt_gamma', type=float, default=1/128, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") + parser.add_argument('--min_near', type=float, default=0.2, help="minimum near distance for camera") + parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied") + parser.add_argument('--bg_radius', type=float, default=-1, help="if positive, use a background model at sphere(bg_radius)") + + ### GUI options + parser.add_argument('--gui', action='store_true', help="start a GUI") + parser.add_argument('--W', type=int, default=1920, help="GUI width") + parser.add_argument('--H', type=int, default=1080, help="GUI height") + parser.add_argument('--radius', type=float, default=5, help="default GUI camera radius from center") + parser.add_argument('--fovy', type=float, default=50, help="default GUI camera fovy") + parser.add_argument('--max_spp', type=int, default=64, help="GUI rendering max sample per pixel") + + ### experimental + parser.add_argument('--error_map', action='store_true', help="use error map to sample rays") + parser.add_argument('--clip_text', type=str, default='', help="text input for CLIP guidance") + parser.add_argument('--rand_pose', type=int, default=-1, help="<0 uses no rand pose, =0 only uses rand pose, >0 sample one rand pose every $ known poses") + + opt = parser.parse_args() + + if opt.O: + opt.fp16 = True + opt.cuda_ray = True + opt.preload = True + + if opt.patch_size > 1: + opt.error_map = False # do not use error_map if use patch-based training + # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss." + assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays." + + + if opt.ff: + opt.fp16 = True + assert opt.bg_radius <= 0, "background model is not implemented for --ff" + from nerf.network_ff import NeRFNetwork + elif opt.tcnn: + opt.fp16 = True + assert opt.bg_radius <= 0, "background model is not implemented for --tcnn" + from nerf.network_tcnn import NeRFNetwork + else: + from nerf.network import NeRFNetwork + + print(opt) + + seed_everything(opt.seed) + + model = NeRFNetwork( + encoding="hashgrid", + bound=opt.bound, + cuda_ray=opt.cuda_ray, + density_scale=1, + min_near=opt.min_near, + density_thresh=opt.density_thresh, + bg_radius=opt.bg_radius, + ) + + print(model) + + criterion = torch.nn.MSELoss(reduction='none') + #criterion = partial(huber_loss, reduction='none') + #criterion = torch.nn.HuberLoss(reduction='none', beta=0.1) # only available after torch 1.10 ? + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + if opt.test: + + metrics = [PSNRMeter(), LPIPSMeter(device=device)] + trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt) + + if opt.gui: + gui = NeRFGUI(opt, trainer) + gui.render() + + else: + test_loader = NeRFDataset(opt, device=device, type='test').dataloader() + + if test_loader.has_gt: + trainer.evaluate(test_loader) # blender has gt, so evaluate it. + + trainer.test(test_loader, write_video=True) # test and save video + + trainer.save_mesh(resolution=256, threshold=10) + + else: + + optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15) + + train_loader = NeRFDataset(opt, device=device, type='train').dataloader() + + # decay to 0.1 * init_lr at last iter step + scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1)) + + metrics = [PSNRMeter(), LPIPSMeter(device=device)] + trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, scheduler_update_every_step=True, metrics=metrics, use_checkpoint=opt.ckpt, eval_interval=50) + + if opt.gui: + gui = NeRFGUI(opt, trainer, train_loader) + gui.render() + + else: + valid_loader = NeRFDataset(opt, device=device, type='val', downscale=1).dataloader() + + max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32) + trainer.train(train_loader, valid_loader, max_epoch) + + # also test + test_loader = NeRFDataset(opt, device=device, type='test').dataloader() + + if test_loader.has_gt: + trainer.evaluate(test_loader) # blender has gt, so evaluate it. + + trainer.test(test_loader, write_video=True) # test and save video + + trainer.save_mesh(resolution=256, threshold=10) \ No newline at end of file diff --git a/torch-ngp/main_sdf.py b/torch-ngp/main_sdf.py new file mode 100644 index 0000000000000000000000000000000000000000..7de7f0886275bd18ca9a04e56bdce82626ce22a6 --- /dev/null +++ b/torch-ngp/main_sdf.py @@ -0,0 +1,63 @@ +import torch +import argparse + +from sdf.utils import * + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('path', type=str) + parser.add_argument('--test', action='store_true', help="test mode") + parser.add_argument('--workspace', type=str, default='workspace') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--lr', type=float, default=1e-4, help="initial learning rate") + parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") + parser.add_argument('--ff', action='store_true', help="use fully-fused MLP") + parser.add_argument('--tcnn', action='store_true', help="use TCNN backend") + + opt = parser.parse_args() + print(opt) + + seed_everything(opt.seed) + + if opt.ff: + assert opt.fp16, "fully-fused mode must be used with fp16 mode" + from sdf.netowrk_ff import SDFNetwork + elif opt.tcnn: + assert opt.fp16, "tcnn mode must be used with fp16 mode" + from sdf.network_tcnn import SDFNetwork + else: + from sdf.netowrk import SDFNetwork + + model = SDFNetwork(encoding="hashgrid") + print(model) + + if opt.test: + trainer = Trainer('ngp', model, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint='best', eval_interval=1) + trainer.save_mesh(os.path.join(opt.workspace, 'results', 'output.ply'), 1024) + + else: + from sdf.provider import SDFDataset + from loss import mape_loss + + train_dataset = SDFDataset(opt.path, size=100, num_samples=2**18) + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True) + + valid_dataset = SDFDataset(opt.path, size=1, num_samples=2**18) # just a dummy + valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=1) + + criterion = mape_loss # torch.nn.L1Loss() + + optimizer = lambda model: torch.optim.Adam([ + {'name': 'encoding', 'params': model.encoder.parameters()}, + {'name': 'net', 'params': model.backbone.parameters(), 'weight_decay': 1e-6}, + ], lr=opt.lr, betas=(0.9, 0.99), eps=1e-15) + + scheduler = lambda optimizer: optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) + + trainer = Trainer('ngp', model, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint='latest', eval_interval=1) + + trainer.train(train_loader, valid_loader, 20) + + # also test + trainer.save_mesh(os.path.join(opt.workspace, 'results', 'output.ply'), 1024) diff --git a/torch-ngp/main_tensoRF.py b/torch-ngp/main_tensoRF.py new file mode 100644 index 0000000000000000000000000000000000000000..37a094379d65fd5aa1f298b94ff9ab797a4fd7ac --- /dev/null +++ b/torch-ngp/main_tensoRF.py @@ -0,0 +1,154 @@ +import torch +import argparse + +from nerf.provider import NeRFDataset +from nerf.gui import NeRFGUI +from tensoRF.utils import * + +#torch.autograd.set_detect_anomaly(True) + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('path', type=str) + parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --preload") + parser.add_argument('--test', action='store_true', help="test mode") + parser.add_argument('--workspace', type=str, default='workspace') + parser.add_argument('--seed', type=int, default=0) + ### training options + parser.add_argument('--iters', type=int, default=30000, help="training iters") + parser.add_argument('--lr0', type=float, default=2e-2, help="initial learning rate for embeddings") + parser.add_argument('--lr1', type=float, default=1e-3, help="initial learning rate for networks") + parser.add_argument('--ckpt', type=str, default='latest') + parser.add_argument('--num_rays', type=int, default=4096, help="num rays sampled per image for each training step") + parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch") + parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)") + parser.add_argument('--num_steps', type=int, default=512, help="num steps sampled per ray (only valid when NOT using --cuda_ray)") + parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)") + parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)") + parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)") + parser.add_argument('--l1_reg_weight', type=float, default=1e-4) + parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable") + + ### network backbone options + parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") + parser.add_argument('--cp', action='store_true', help="use TensorCP") + parser.add_argument('--resolution0', type=int, default=128) + parser.add_argument('--resolution1', type=int, default=300) + parser.add_argument("--upsample_model_steps", type=int, action="append", default=[2000, 3000, 4000, 5500, 7000]) + + ### dataset options + parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)") + parser.add_argument('--preload', action='store_true', help="preload all data into GPU, accelerate training but use more GPU memory") + # (the default value is for the fox dataset) + parser.add_argument('--bound', type=float, default=2, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.") + parser.add_argument('--scale', type=float, default=0.33, help="scale camera location into box[-bound, bound]^3") + parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location") + parser.add_argument('--dt_gamma', type=float, default=1/128, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") + parser.add_argument('--min_near', type=float, default=0.2, help="minimum near distance for camera") + parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied") + parser.add_argument('--bg_radius', type=float, default=-1, help="if positive, use a background model at sphere(bg_radius)") + + ### GUI options + parser.add_argument('--gui', action='store_true', help="start a GUI") + parser.add_argument('--W', type=int, default=1920, help="GUI width") + parser.add_argument('--H', type=int, default=1080, help="GUI height") + parser.add_argument('--radius', type=float, default=5, help="default GUI camera radius from center") + parser.add_argument('--fovy', type=float, default=50, help="default GUI camera fovy") + parser.add_argument('--max_spp', type=int, default=64, help="GUI rendering max sample per pixel") + + ### experimental + parser.add_argument('--error_map', action='store_true', help="use error map to sample rays") + parser.add_argument('--clip_text', type=str, default='', help="text input for CLIP guidance") + parser.add_argument('--rand_pose', type=int, default=-1, help="<0 uses no rand pose, =0 only uses rand pose, >0 sample one rand pose every $ known poses") + + opt = parser.parse_args() + + if opt.O: + opt.fp16 = True + opt.cuda_ray = True + opt.preload = True + + if opt.patch_size > 1: + opt.error_map = False # do not use error_map if use patch-based training + # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss." + assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays." + + print(opt) + seed_everything(opt.seed) + + if opt.cp: + assert opt.bg_radius <= 0, "background model is not implemented for --cp" + from tensoRF.network_cp import NeRFNetwork + else: + from tensoRF.network import NeRFNetwork + + model = NeRFNetwork( + resolution=[opt.resolution0] * 3, + bound=opt.bound, + cuda_ray=opt.cuda_ray, + density_scale=1, + min_near=opt.min_near, + density_thresh=opt.density_thresh, + bg_radius=opt.bg_radius, + ) + + print(model) + + criterion = torch.nn.MSELoss(reduction='none') + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + if opt.test: + + trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=[PSNRMeter()], use_checkpoint=opt.ckpt) + + if opt.gui: + gui = NeRFGUI(opt, trainer) + gui.render() + + else: + test_loader = NeRFDataset(opt, device=device, type='test').dataloader() + + if test_loader.has_gt: + trainer.evaluate(test_loader) # blender has gt, so evaluate it. + + trainer.test(test_loader, write_video=True) # test and save video + + #trainer.save_mesh(resolution=256, threshold=0.1) + + else: + + optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr0, opt.lr1), betas=(0.9, 0.99), eps=1e-15) + + train_loader = NeRFDataset(opt, device=device, type='train').dataloader() + + # decay to 0.1 * init_lr at last iter step + scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1)) + + trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=None, fp16=opt.fp16, lr_scheduler=scheduler, scheduler_update_every_step=True, metrics=[PSNRMeter()], use_checkpoint=opt.ckpt, eval_interval=50) + + # calc upsample target resolutions + upsample_resolutions = (np.round(np.exp(np.linspace(np.log(opt.resolution0), np.log(opt.resolution1), len(opt.upsample_model_steps) + 1)))).astype(np.int32).tolist()[1:] + print('upsample_resolutions:', upsample_resolutions) + trainer.upsample_resolutions = upsample_resolutions + + if opt.gui: + gui = NeRFGUI(opt, trainer, train_loader) + gui.render() + + else: + valid_loader = NeRFDataset(opt, device=device, type='val', downscale=1).dataloader() + + max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32) + trainer.train(train_loader, valid_loader, max_epoch) + + # also test + test_loader = NeRFDataset(opt, device=device, type='test').dataloader() + + if test_loader.has_gt: + trainer.evaluate(test_loader) # blender has gt, so evaluate it. + + trainer.test(test_loader, write_video=True) # test and save video + + #trainer.save_mesh(resolution=256, threshold=0.1) \ No newline at end of file diff --git a/torch-ngp/nerf/clip_utils.py b/torch-ngp/nerf/clip_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c195eb979028c489524a6840a62ad6fa3f155806 --- /dev/null +++ b/torch-ngp/nerf/clip_utils.py @@ -0,0 +1,64 @@ +import random +import torch +import torch.nn as nn +import torch.nn.functional as F + +import torchvision.transforms as T +import torchvision.transforms.functional as TF + +import clip + +class CLIPLoss: + def __init__(self, device, name='ViT-B/16'): + self.device = device + self.name = name + self.clip_model, self.transform_PIL = clip.load(self.name, device=self.device, jit=False) + + # disable training + self.clip_model.eval() + for p in self.clip_model.parameters(): + p.requires_grad = False + + # image augmentation + self.transform = T.Compose([ + T.Resize((224, 224)), + T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + # placeholder + self.text_zs = None + self.image_zs = None + + def normalize(self, x): + return x / x.norm(dim=-1, keepdim=True) + + # image-text (e.g., dreamfields) + def prepare_text(self, texts): + # texts: list of strings. + texts = clip.tokenize(texts).to(self.device) + self.text_zs = self.normalize(self.clip_model.encode_text(texts)) + print(f'[INFO] prepared CLIP text feature: {self.text_zs.shape}') + + def __call__(self, images, mode='text'): + + images = self.transform(images) + image_zs = self.normalize(self.clip_model.encode_image(images)) + + if mode == 'text': + # if more than one string, randomly choose one. + if self.text_zs.shape[0] > 1: + idx = random.randint(0, self.text_zs.shape[0] - 1) + text_zs = self.text_zs[[idx]] + else: + text_zs = self.text_zs + # broadcast text_zs to all image_zs + loss = - (image_zs * text_zs).sum(-1).mean() + else: + raise NotImplementedError + + return loss + + # image-image (e.g., diet-nerf) + def prepare_image(self, dataset): + # images: a nerf dataset (we need both poses and images!) + pass \ No newline at end of file diff --git a/torch-ngp/nerf/gui.py b/torch-ngp/nerf/gui.py new file mode 100644 index 0000000000000000000000000000000000000000..cb524f40bab498b8a1ece415ebc9bf9e08b6f7b9 --- /dev/null +++ b/torch-ngp/nerf/gui.py @@ -0,0 +1,436 @@ +import math +import torch +import numpy as np +import dearpygui.dearpygui as dpg +from scipy.spatial.transform import Rotation as R + +from .utils import * + + +class OrbitCamera: + def __init__(self, W, H, r=2, fovy=60): + self.W = W + self.H = H + self.radius = r # camera distance from center + self.fovy = fovy # in degree + self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point + self.rot = R.from_quat([1, 0, 0, 0]) # init camera matrix: [[1, 0, 0], [0, -1, 0], [0, 0, 1]] (to suit ngp convention) + self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized! + + # pose + @property + def pose(self): + # first move camera to radius + res = np.eye(4, dtype=np.float32) + res[2, 3] -= self.radius + # rotate + rot = np.eye(4, dtype=np.float32) + rot[:3, :3] = self.rot.as_matrix() + res = rot @ res + # translate + res[:3, 3] -= self.center + return res + + # intrinsics + @property + def intrinsics(self): + focal = self.H / (2 * np.tan(np.radians(self.fovy) / 2)) + return np.array([focal, focal, self.W // 2, self.H // 2]) + + def orbit(self, dx, dy): + # rotate along camera up/side axis! + side = self.rot.as_matrix()[:3, 0] # why this is side --> ? # already normalized. + rotvec_x = self.up * np.radians(-0.1 * dx) + rotvec_y = side * np.radians(-0.1 * dy) + self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot + + def scale(self, delta): + self.radius *= 1.1 ** (-delta) + + def pan(self, dx, dy, dz=0): + # pan in camera coordinate system (careful on the sensitivity!) + self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([dx, dy, dz]) + + +class NeRFGUI: + def __init__(self, opt, trainer, train_loader=None, debug=True): + self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. + self.W = opt.W + self.H = opt.H + self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy) + self.debug = debug + self.bg_color = torch.ones(3, dtype=torch.float32) # default white bg + self.training = False + self.step = 0 # training step + + self.trainer = trainer + self.train_loader = train_loader + if train_loader is not None: + self.trainer.error_map = train_loader._data.error_map + + self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32) + self.need_update = True # camera moved, should reset accumulation + self.spp = 1 # sample per pixel + self.mode = 'image' # choose from ['image', 'depth'] + + self.dynamic_resolution = True + self.downscale = 1 + self.train_steps = 16 + + dpg.create_context() + self.register_dpg() + self.test_step() + + + def __del__(self): + dpg.destroy_context() + + + def train_step(self): + + starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + starter.record() + + outputs = self.trainer.train_gui(self.train_loader, step=self.train_steps) + + ender.record() + torch.cuda.synchronize() + t = starter.elapsed_time(ender) + + self.step += self.train_steps + self.need_update = True + + dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)') + dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}') + + # dynamic train steps + # max allowed train time per-frame is 500 ms + full_t = t / self.train_steps * 16 + train_steps = min(16, max(4, int(16 * 500 / full_t))) + if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8: + self.train_steps = train_steps + + def prepare_buffer(self, outputs): + if self.mode == 'image': + return outputs['image'] + else: + return np.expand_dims(outputs['depth'], -1).repeat(3, -1) + + + def test_step(self): + # TODO: seems we have to move data from GPU --> CPU --> GPU? + + if self.need_update or self.spp < self.opt.max_spp: + + starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + starter.record() + + outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, self.bg_color, self.spp, self.downscale) + + ender.record() + torch.cuda.synchronize() + t = starter.elapsed_time(ender) + + # update dynamic resolution + if self.dynamic_resolution: + # max allowed infer time per-frame is 200 ms + full_t = t / (self.downscale ** 2) + downscale = min(1, max(1/4, math.sqrt(200 / full_t))) + if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8: + self.downscale = downscale + + if self.need_update: + self.render_buffer = self.prepare_buffer(outputs) + self.spp = 1 + self.need_update = False + else: + self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1) + self.spp += 1 + + dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)') + dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}') + dpg.set_value("_log_spp", self.spp) + dpg.set_value("_texture", self.render_buffer) + + + def register_dpg(self): + + ### register texture + + with dpg.texture_registry(show=False): + dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture") + + ### register window + + # the rendered image, as the primary window + with dpg.window(tag="_primary_window", width=self.W, height=self.H): + + # add the texture + dpg.add_image("_texture") + + dpg.set_primary_window("_primary_window", True) + + # control window + with dpg.window(label="Control", tag="_control_window", width=400, height=300): + + # button theme + with dpg.theme() as theme_button: + with dpg.theme_component(dpg.mvButton): + dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18)) + dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47)) + dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83)) + dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5) + dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3) + + # time + if not self.opt.test: + with dpg.group(horizontal=True): + dpg.add_text("Train time: ") + dpg.add_text("no data", tag="_log_train_time") + + with dpg.group(horizontal=True): + dpg.add_text("Infer time: ") + dpg.add_text("no data", tag="_log_infer_time") + + with dpg.group(horizontal=True): + dpg.add_text("SPP: ") + dpg.add_text("1", tag="_log_spp") + + # train button + if not self.opt.test: + with dpg.collapsing_header(label="Train", default_open=True): + + # train / stop + with dpg.group(horizontal=True): + dpg.add_text("Train: ") + + def callback_train(sender, app_data): + if self.training: + self.training = False + dpg.configure_item("_button_train", label="start") + else: + self.training = True + dpg.configure_item("_button_train", label="stop") + + dpg.add_button(label="start", tag="_button_train", callback=callback_train) + dpg.bind_item_theme("_button_train", theme_button) + + def callback_reset(sender, app_data): + @torch.no_grad() + def weight_reset(m: nn.Module): + reset_parameters = getattr(m, "reset_parameters", None) + if callable(reset_parameters): + m.reset_parameters() + self.trainer.model.apply(fn=weight_reset) + self.trainer.model.reset_extra_state() # for cuda_ray density_grid and step_counter + self.need_update = True + + dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset) + dpg.bind_item_theme("_button_reset", theme_button) + + # save ckpt + with dpg.group(horizontal=True): + dpg.add_text("Checkpoint: ") + + def callback_save(sender, app_data): + self.trainer.save_checkpoint(full=True, best=False) + dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1])) + self.trainer.epoch += 1 # use epoch to indicate different calls. + + dpg.add_button(label="save", tag="_button_save", callback=callback_save) + dpg.bind_item_theme("_button_save", theme_button) + + dpg.add_text("", tag="_log_ckpt") + + # save mesh + with dpg.group(horizontal=True): + dpg.add_text("Marching Cubes: ") + + def callback_mesh(sender, app_data): + self.trainer.save_mesh(resolution=256, threshold=10) + dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply') + self.trainer.epoch += 1 # use epoch to indicate different calls. + + dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh) + dpg.bind_item_theme("_button_mesh", theme_button) + + dpg.add_text("", tag="_log_mesh") + + with dpg.group(horizontal=True): + dpg.add_text("", tag="_log_train_log") + + + # rendering options + with dpg.collapsing_header(label="Options", default_open=True): + + # dynamic rendering resolution + with dpg.group(horizontal=True): + + def callback_set_dynamic_resolution(sender, app_data): + if self.dynamic_resolution: + self.dynamic_resolution = False + self.downscale = 1 + else: + self.dynamic_resolution = True + self.need_update = True + + dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution) + dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution") + + # mode combo + def callback_change_mode(sender, app_data): + self.mode = app_data + self.need_update = True + + dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode) + + # bg_color picker + def callback_change_bg(sender, app_data): + self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32) # only need RGB in [0, 1] + self.need_update = True + + dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg) + + # fov slider + def callback_set_fovy(sender, app_data): + self.cam.fovy = app_data + self.need_update = True + + dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy) + + # dt_gamma slider + def callback_set_dt_gamma(sender, app_data): + self.opt.dt_gamma = app_data + self.need_update = True + + dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma) + + # max_steps slider + def callback_set_max_steps(sender, app_data): + self.opt.max_steps = app_data + self.need_update = True + + dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps) + + # aabb slider + def callback_set_aabb(sender, app_data, user_data): + # user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax) + self.trainer.model.aabb_infer[user_data] = app_data + + # also change train aabb ? [better not...] + #self.trainer.model.aabb_train[user_data] = app_data + + self.need_update = True + + dpg.add_separator() + dpg.add_text("Axis-aligned bounding box:") + + with dpg.group(horizontal=True): + dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0) + dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3) + + with dpg.group(horizontal=True): + dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1) + dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4) + + with dpg.group(horizontal=True): + dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2) + dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5) + + + # debug info + if self.debug: + with dpg.collapsing_header(label="Debug"): + # pose + dpg.add_separator() + dpg.add_text("Camera Pose:") + dpg.add_text(str(self.cam.pose), tag="_log_pose") + + + ### register camera handler + + def callback_camera_drag_rotate(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + dx = app_data[1] + dy = app_data[2] + + self.cam.orbit(dx, dy) + self.need_update = True + + if self.debug: + dpg.set_value("_log_pose", str(self.cam.pose)) + + + def callback_camera_wheel_scale(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + delta = app_data + + self.cam.scale(delta) + self.need_update = True + + if self.debug: + dpg.set_value("_log_pose", str(self.cam.pose)) + + + def callback_camera_drag_pan(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + dx = app_data[1] + dy = app_data[2] + + self.cam.pan(dx, dy) + self.need_update = True + + if self.debug: + dpg.set_value("_log_pose", str(self.cam.pose)) + + + with dpg.handler_registry(): + dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate) + dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale) + dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan) + + + dpg.create_viewport(title='torch-ngp', width=self.W, height=self.H, resizable=False) + + # TODO: seems dearpygui doesn't support resizing texture... + # def callback_resize(sender, app_data): + # self.W = app_data[0] + # self.H = app_data[1] + # # how to reload texture ??? + + # dpg.set_viewport_resize_callback(callback_resize) + + ### global theme + with dpg.theme() as theme_no_padding: + with dpg.theme_component(dpg.mvAll): + # set all padding to 0 to avoid scroll bar + dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core) + dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core) + dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core) + + dpg.bind_item_theme("_primary_window", theme_no_padding) + + dpg.setup_dearpygui() + + #dpg.show_metrics() + + dpg.show_viewport() + + + def render(self): + + while dpg.is_dearpygui_running(): + # update texture every frame + if self.training: + self.train_step() + self.test_step() + dpg.render_dearpygui_frame() \ No newline at end of file diff --git a/torch-ngp/nerf/network.py b/torch-ngp/nerf/network.py new file mode 100644 index 0000000000000000000000000000000000000000..28a763af32de7fd560d31edc9a2db0104bbb8a4c --- /dev/null +++ b/torch-ngp/nerf/network.py @@ -0,0 +1,206 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from encoding import get_encoder +from activation import trunc_exp +from .renderer import NeRFRenderer + + +class NeRFNetwork(NeRFRenderer): + def __init__(self, + encoding="hashgrid", + encoding_dir="sphere_harmonics", + encoding_bg="hashgrid", + num_layers=2, + hidden_dim=64, + geo_feat_dim=15, + num_layers_color=3, + hidden_dim_color=64, + num_layers_bg=2, + hidden_dim_bg=64, + bound=1, + **kwargs, + ): + super().__init__(bound, **kwargs) + + # sigma network + self.num_layers = num_layers + self.hidden_dim = hidden_dim + self.geo_feat_dim = geo_feat_dim + self.encoder, self.in_dim = get_encoder(encoding, desired_resolution=2048 * bound) + + sigma_net = [] + for l in range(num_layers): + if l == 0: + in_dim = self.in_dim + else: + in_dim = hidden_dim + + if l == num_layers - 1: + out_dim = 1 + self.geo_feat_dim # 1 sigma + 15 SH features for color + else: + out_dim = hidden_dim + + sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) + + self.sigma_net = nn.ModuleList(sigma_net) + + # color network + self.num_layers_color = num_layers_color + self.hidden_dim_color = hidden_dim_color + self.encoder_dir, self.in_dim_dir = get_encoder(encoding_dir) + + color_net = [] + for l in range(num_layers_color): + if l == 0: + in_dim = self.in_dim_dir + self.geo_feat_dim + else: + in_dim = hidden_dim_color + + if l == num_layers_color - 1: + out_dim = 3 # 3 rgb + else: + out_dim = hidden_dim_color + + color_net.append(nn.Linear(in_dim, out_dim, bias=False)) + + self.color_net = nn.ModuleList(color_net) + + # background network + if self.bg_radius > 0: + self.num_layers_bg = num_layers_bg + self.hidden_dim_bg = hidden_dim_bg + self.encoder_bg, self.in_dim_bg = get_encoder(encoding_bg, input_dim=2, num_levels=4, log2_hashmap_size=19, desired_resolution=2048) # much smaller hashgrid + + bg_net = [] + for l in range(num_layers_bg): + if l == 0: + in_dim = self.in_dim_bg + self.in_dim_dir + else: + in_dim = hidden_dim_bg + + if l == num_layers_bg - 1: + out_dim = 3 # 3 rgb + else: + out_dim = hidden_dim_bg + + bg_net.append(nn.Linear(in_dim, out_dim, bias=False)) + + self.bg_net = nn.ModuleList(bg_net) + else: + self.bg_net = None + + + def forward(self, x, d): + # x: [N, 3], in [-bound, bound] + # d: [N, 3], nomalized in [-1, 1] + + # sigma + x = self.encoder(x, bound=self.bound) + + h = x + for l in range(self.num_layers): + h = self.sigma_net[l](h) + if l != self.num_layers - 1: + h = F.relu(h, inplace=True) + + #sigma = F.relu(h[..., 0]) + sigma = trunc_exp(h[..., 0]) + geo_feat = h[..., 1:] + + # color + + d = self.encoder_dir(d) + h = torch.cat([d, geo_feat], dim=-1) + for l in range(self.num_layers_color): + h = self.color_net[l](h) + if l != self.num_layers_color - 1: + h = F.relu(h, inplace=True) + + # sigmoid activation for rgb + color = torch.sigmoid(h) + + return sigma, color + + def density(self, x): + # x: [N, 3], in [-bound, bound] + + x = self.encoder(x, bound=self.bound) + h = x + for l in range(self.num_layers): + h = self.sigma_net[l](h) + if l != self.num_layers - 1: + h = F.relu(h, inplace=True) + + #sigma = F.relu(h[..., 0]) + sigma = trunc_exp(h[..., 0]) + geo_feat = h[..., 1:] + + return { + 'sigma': sigma, + 'geo_feat': geo_feat, + } + + def background(self, x, d): + # x: [N, 2], in [-1, 1] + + h = self.encoder_bg(x) # [N, C] + d = self.encoder_dir(d) + + h = torch.cat([d, h], dim=-1) + for l in range(self.num_layers_bg): + h = self.bg_net[l](h) + if l != self.num_layers_bg - 1: + h = F.relu(h, inplace=True) + + # sigmoid activation for rgb + rgbs = torch.sigmoid(h) + + return rgbs + + # allow masked inference + def color(self, x, d, mask=None, geo_feat=None, **kwargs): + # x: [N, 3] in [-bound, bound] + # mask: [N,], bool, indicates where we actually needs to compute rgb. + + if mask is not None: + rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] + # in case of empty mask + if not mask.any(): + return rgbs + x = x[mask] + d = d[mask] + geo_feat = geo_feat[mask] + + d = self.encoder_dir(d) + h = torch.cat([d, geo_feat], dim=-1) + for l in range(self.num_layers_color): + h = self.color_net[l](h) + if l != self.num_layers_color - 1: + h = F.relu(h, inplace=True) + + # sigmoid activation for rgb + h = torch.sigmoid(h) + + if mask is not None: + rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 + else: + rgbs = h + + return rgbs + + # optimizer utils + def get_params(self, lr): + + params = [ + {'params': self.encoder.parameters(), 'lr': lr}, + {'params': self.sigma_net.parameters(), 'lr': lr}, + {'params': self.encoder_dir.parameters(), 'lr': lr}, + {'params': self.color_net.parameters(), 'lr': lr}, + ] + if self.bg_radius > 0: + params.append({'params': self.encoder_bg.parameters(), 'lr': lr}) + params.append({'params': self.bg_net.parameters(), 'lr': lr}) + + return params diff --git a/torch-ngp/nerf/network_ff.py b/torch-ngp/nerf/network_ff.py new file mode 100644 index 0000000000000000000000000000000000000000..a89a3f20a6e27525d96f4fc82076bb8831959d11 --- /dev/null +++ b/torch-ngp/nerf/network_ff.py @@ -0,0 +1,149 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from encoding import get_encoder +from activation import trunc_exp +from ffmlp import FFMLP + +from .renderer import NeRFRenderer + +class NeRFNetwork(NeRFRenderer): + def __init__(self, + encoding="hashgrid", + encoding_dir="sphere_harmonics", + num_layers=2, + hidden_dim=64, + geo_feat_dim=15, + num_layers_color=3, + hidden_dim_color=64, + bound=1, + **kwargs + ): + super().__init__(bound, **kwargs) + + # sigma network + self.num_layers = num_layers + self.hidden_dim = hidden_dim + self.geo_feat_dim = geo_feat_dim + self.encoder, self.in_dim = get_encoder(encoding, desired_resolution=2048 * bound) + + self.sigma_net = FFMLP( + input_dim=self.in_dim, + output_dim=1 + self.geo_feat_dim, + hidden_dim=self.hidden_dim, + num_layers=self.num_layers, + ) + + # color network + self.num_layers_color = num_layers_color + self.hidden_dim_color = hidden_dim_color + self.encoder_dir, self.in_dim_color = get_encoder(encoding_dir) + self.in_dim_color += self.geo_feat_dim + 1 # a manual fixing to make it 32, as done in nerf_network.h#178 + + self.color_net = FFMLP( + input_dim=self.in_dim_color, + output_dim=3, + hidden_dim=self.hidden_dim_color, + num_layers=self.num_layers_color, + ) + + def forward(self, x, d): + # x: [N, 3], in [-bound, bound] + # d: [N, 3], nomalized in [-1, 1] + + # sigma + x = self.encoder(x, bound=self.bound) + h = self.sigma_net(x) + + #sigma = F.relu(h[..., 0]) + sigma = trunc_exp(h[..., 0]) + geo_feat = h[..., 1:] + + # color + d = self.encoder_dir(d) + + # TODO: preallocate space and avoid this cat? + p = torch.zeros_like(geo_feat[..., :1]) # manual input padding + h = torch.cat([d, geo_feat, p], dim=-1) + h = self.color_net(h) + + # sigmoid activation for rgb + rgb = torch.sigmoid(h) + + return sigma, rgb + + def density(self, x): + # x: [N, 3], in [-bound, bound] + + x = self.encoder(x, bound=self.bound) + h = self.sigma_net(x) + + #sigma = F.relu(h[..., 0]) + sigma = trunc_exp(h[..., 0]) + geo_feat = h[..., 1:] + + return { + 'sigma': sigma, + 'geo_feat': geo_feat, + } + + # allow masked inference + def color(self, x, d, mask=None, geo_feat=None, **kwargs): + # x: [N, 3] in [-bound, bound] + # mask: [N,], bool, indicates where we actually needs to compute rgb. + + #starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + #starter.record() + + if mask is not None: + rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] + # in case of empty mask + if not mask.any(): + return rgbs + x = x[mask] + d = d[mask] + geo_feat = geo_feat[mask] + + #print(x.shape, rgbs.shape) + + #ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'mask = {curr_time}') + #starter.record() + + d = self.encoder_dir(d) + + p = torch.zeros_like(geo_feat[..., :1]) # manual input padding + h = torch.cat([d, geo_feat, p], dim=-1) + + h = self.color_net(h) + + # sigmoid activation for rgb + h = torch.sigmoid(h) + + #ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'call = {curr_time}') + #starter.record() + + if mask is not None: + rgbs[mask] = h.to(rgbs.dtype) + else: + rgbs = h + + #ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'unmask = {curr_time}') + #starter.record() + + return rgbs + + # optimizer utils + def get_params(self, lr): + + params = [ + {'params': self.encoder.parameters(), 'lr': lr}, + {'params': self.sigma_net.parameters(), 'lr': lr}, + {'params': self.encoder_dir.parameters(), 'lr': lr}, + {'params': self.color_net.parameters(), 'lr': lr}, + ] + if self.bg_radius > 0: + params.append({'params': self.encoder_bg.parameters(), 'lr': lr}) + params.append({'params': self.bg_net.parameters(), 'lr': lr}) + + return params \ No newline at end of file diff --git a/torch-ngp/nerf/network_tcnn.py b/torch-ngp/nerf/network_tcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..d76cb0328eb46f5aa418845577e949cab532d3ab --- /dev/null +++ b/torch-ngp/nerf/network_tcnn.py @@ -0,0 +1,173 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np + +import tinycudann as tcnn +from activation import trunc_exp +from .renderer import NeRFRenderer + + +class NeRFNetwork(NeRFRenderer): + def __init__(self, + encoding="HashGrid", + encoding_dir="SphericalHarmonics", + num_layers=2, + hidden_dim=64, + geo_feat_dim=15, + num_layers_color=3, + hidden_dim_color=64, + bound=1, + **kwargs + ): + super().__init__(bound, **kwargs) + + # sigma network + self.num_layers = num_layers + self.hidden_dim = hidden_dim + self.geo_feat_dim = geo_feat_dim + + per_level_scale = np.exp2(np.log2(2048 * bound / 16) / (16 - 1)) + + self.encoder = tcnn.Encoding( + n_input_dims=3, + encoding_config={ + "otype": "HashGrid", + "n_levels": 16, + "n_features_per_level": 2, + "log2_hashmap_size": 19, + "base_resolution": 16, + "per_level_scale": per_level_scale, + }, + ) + + self.sigma_net = tcnn.Network( + n_input_dims=32, + n_output_dims=1 + self.geo_feat_dim, + network_config={ + "otype": "FullyFusedMLP", + "activation": "ReLU", + "output_activation": "None", + "n_neurons": hidden_dim, + "n_hidden_layers": num_layers - 1, + }, + ) + + # color network + self.num_layers_color = num_layers_color + self.hidden_dim_color = hidden_dim_color + + self.encoder_dir = tcnn.Encoding( + n_input_dims=3, + encoding_config={ + "otype": "SphericalHarmonics", + "degree": 4, + }, + ) + + self.in_dim_color = self.encoder_dir.n_output_dims + self.geo_feat_dim + + self.color_net = tcnn.Network( + n_input_dims=self.in_dim_color, + n_output_dims=3, + network_config={ + "otype": "FullyFusedMLP", + "activation": "ReLU", + "output_activation": "None", + "n_neurons": hidden_dim_color, + "n_hidden_layers": num_layers_color - 1, + }, + ) + + + def forward(self, x, d): + # x: [N, 3], in [-bound, bound] + # d: [N, 3], nomalized in [-1, 1] + + + # sigma + x = (x + self.bound) / (2 * self.bound) # to [0, 1] + x = self.encoder(x) + h = self.sigma_net(x) + + #sigma = F.relu(h[..., 0]) + sigma = trunc_exp(h[..., 0]) + geo_feat = h[..., 1:] + + # color + d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1] + d = self.encoder_dir(d) + + #p = torch.zeros_like(geo_feat[..., :1]) # manual input padding + h = torch.cat([d, geo_feat], dim=-1) + h = self.color_net(h) + + # sigmoid activation for rgb + color = torch.sigmoid(h) + + return sigma, color + + def density(self, x): + # x: [N, 3], in [-bound, bound] + + x = (x + self.bound) / (2 * self.bound) # to [0, 1] + x = self.encoder(x) + h = self.sigma_net(x) + + #sigma = F.relu(h[..., 0]) + sigma = trunc_exp(h[..., 0]) + geo_feat = h[..., 1:] + + return { + 'sigma': sigma, + 'geo_feat': geo_feat, + } + + # allow masked inference + def color(self, x, d, mask=None, geo_feat=None, **kwargs): + # x: [N, 3] in [-bound, bound] + # mask: [N,], bool, indicates where we actually needs to compute rgb. + + x = (x + self.bound) / (2 * self.bound) # to [0, 1] + + if mask is not None: + rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] + # in case of empty mask + if not mask.any(): + return rgbs + x = x[mask] + d = d[mask] + geo_feat = geo_feat[mask] + + # color + d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1] + d = self.encoder_dir(d) + + h = torch.cat([d, geo_feat], dim=-1) + h = self.color_net(h) + + # sigmoid activation for rgb + h = torch.sigmoid(h) + + if mask is not None: + rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 + else: + rgbs = h + + return rgbs + + # optimizer utils + def get_params(self, lr): + + params = [ + {'params': self.encoder.parameters(), 'lr': lr}, + {'params': self.sigma_net.parameters(), 'lr': lr}, + {'params': self.encoder_dir.parameters(), 'lr': lr}, + {'params': self.color_net.parameters(), 'lr': lr}, + ] + if self.bg_radius > 0: + params.append({'params': self.encoder_bg.parameters(), 'lr': lr}) + params.append({'params': self.bg_net.parameters(), 'lr': lr}) + + return params \ No newline at end of file diff --git a/torch-ngp/nerf/provider.py b/torch-ngp/nerf/provider.py new file mode 100644 index 0000000000000000000000000000000000000000..2339339eec88c022b79734fcdeffe4fee14bef80 --- /dev/null +++ b/torch-ngp/nerf/provider.py @@ -0,0 +1,332 @@ +import os +import cv2 +import glob +import json +from cv2 import transform +import tqdm +import numpy as np +from scipy.spatial.transform import Slerp, Rotation + +import trimesh + +import torch +from torch.utils.data import DataLoader + +from .utils import get_rays + + +# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50 +def nerf_matrix_to_ngp(pose, scale=0.33, offset=[0, 0, 0]): + # for the fox dataset, 0.33 scales camera radius to ~ 2 + new_pose = np.array([ + [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale + offset[0]], + [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale + offset[1]], + [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale + offset[2]], + [0, 0, 0, 1], + ], dtype=np.float32) + return new_pose + + +def visualize_poses(poses, size=0.1): + # poses: [B, 4, 4] + + axes = trimesh.creation.axis(axis_length=4) + box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline() + box.colors = np.array([[128, 128, 128]] * len(box.entities)) + objects = [axes, box] + + for pose in poses: + # a camera is visualized with 8 line segments. + pos = pose[:3, 3] + a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] + b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] + c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] + d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] + + dir = (a + b + c + d) / 4 - pos + dir = dir / (np.linalg.norm(dir) + 1e-8) + o = pos + dir * 3 + + segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a], [pos, o]]) + segs = trimesh.load_path(segs) + objects.append(segs) + + trimesh.Scene(objects).show() + + +def rand_poses(size, device, radius=1, theta_range=[np.pi/3, 2*np.pi/3], phi_range=[0, 2*np.pi]): + ''' generate random poses from an orbit camera + Args: + size: batch size of generated poses. + device: where to allocate the output. + radius: camera radius + theta_range: [min, max], should be in [0, \pi] + phi_range: [min, max], should be in [0, 2\pi] + Return: + poses: [size, 4, 4] + ''' + + def normalize(vectors): + return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10) + + thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0] + phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0] + + centers = torch.stack([ + radius * torch.sin(thetas) * torch.sin(phis), + radius * torch.cos(thetas), + radius * torch.sin(thetas) * torch.cos(phis), + ], dim=-1) # [B, 3] + + # lookat + forward_vector = - normalize(centers) + up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1) # confused at the coordinate system... + right_vector = normalize(torch.cross(forward_vector, up_vector, dim=-1)) + up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1)) + + poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1) + poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1) + poses[:, :3, 3] = centers + + return poses + + +class NeRFDataset: + def __init__(self, opt, device, type='train', downscale=1, n_test=10): + super().__init__() + + self.opt = opt + self.device = device + self.type = type # train, val, test + self.downscale = downscale + self.root_path = opt.path + self.preload = opt.preload # preload data into GPU + self.scale = opt.scale # camera radius scale to make sure camera are inside the bounding box. + self.offset = opt.offset # camera offset + self.bound = opt.bound # bounding box half length, also used as the radius to random sample poses. + self.fp16 = opt.fp16 # if preload, load into fp16. + + self.training = self.type in ['train', 'all', 'trainval'] + self.num_rays = self.opt.num_rays if self.training else -1 + + self.rand_pose = opt.rand_pose + + # auto-detect transforms.json and split mode. + if os.path.exists(os.path.join(self.root_path, 'transforms.json')): + self.mode = 'colmap' # manually split, use view-interpolation for test. + elif os.path.exists(os.path.join(self.root_path, 'transforms_train.json')): + self.mode = 'blender' # provided split + else: + raise NotImplementedError(f'[NeRFDataset] Cannot find transforms*.json under {self.root_path}') + + # load nerf-compatible format data. + if self.mode == 'colmap': + with open(os.path.join(self.root_path, 'transforms.json'), 'r') as f: + transform = json.load(f) + elif self.mode == 'blender': + # load all splits (train/valid/test), this is what instant-ngp in fact does... + if type == 'all': + transform_paths = glob.glob(os.path.join(self.root_path, '*.json')) + transform = None + for transform_path in transform_paths: + with open(transform_path, 'r') as f: + tmp_transform = json.load(f) + if transform is None: + transform = tmp_transform + else: + transform['frames'].extend(tmp_transform['frames']) + # load train and val split + elif type == 'trainval': + with open(os.path.join(self.root_path, f'transforms_train.json'), 'r') as f: + transform = json.load(f) + with open(os.path.join(self.root_path, f'transforms_val.json'), 'r') as f: + transform_val = json.load(f) + transform['frames'].extend(transform_val['frames']) + # only load one specified split + else: + with open(os.path.join(self.root_path, f'transforms_{type}.json'), 'r') as f: + transform = json.load(f) + + else: + raise NotImplementedError(f'unknown dataset mode: {self.mode}') + + # load image size + if 'h' in transform and 'w' in transform: + self.H = int(transform['h']) // downscale + self.W = int(transform['w']) // downscale + else: + # we have to actually read an image to get H and W later. + self.H = self.W = None + + # read images + frames = transform["frames"] + #frames = sorted(frames, key=lambda d: d['file_path']) # why do I sort... + + # for colmap, manually interpolate a test set. + if self.mode == 'colmap' and type == 'test': + + # choose two random poses, and interpolate between. + f0, f1 = np.random.choice(frames, 2, replace=False) + pose0 = nerf_matrix_to_ngp(np.array(f0['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4] + pose1 = nerf_matrix_to_ngp(np.array(f1['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4] + rots = Rotation.from_matrix(np.stack([pose0[:3, :3], pose1[:3, :3]])) + slerp = Slerp([0, 1], rots) + + self.poses = [] + self.images = None + for i in range(n_test + 1): + ratio = np.sin(((i / n_test) - 0.5) * np.pi) * 0.5 + 0.5 + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = slerp(ratio).as_matrix() + pose[:3, 3] = (1 - ratio) * pose0[:3, 3] + ratio * pose1[:3, 3] + self.poses.append(pose) + + else: + # for colmap, manually split a valid set (the first frame). + if self.mode == 'colmap': + if type == 'train': + frames = frames[1:] + elif type == 'val': + frames = frames[:1] + # else 'all' or 'trainval' : use all frames + + self.poses = [] + self.images = [] + for f in tqdm.tqdm(frames, desc=f'Loading {type} data'): + f_path = os.path.join(self.root_path, f['file_path']) + if self.mode == 'blender' and '.' not in os.path.basename(f_path): + f_path += '.png' # so silly... + + # there are non-exist paths in fox... + if not os.path.exists(f_path): + continue + + pose = np.array(f['transform_matrix'], dtype=np.float32) # [4, 4] + pose = nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset) + + image = cv2.imread(f_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] o [H, W, 4] + if self.H is None or self.W is None: + self.H = image.shape[0] // downscale + self.W = image.shape[1] // downscale + + # add support for the alpha channel as a mask. + if image.shape[-1] == 3: + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + else: + image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) + + if image.shape[0] != self.H or image.shape[1] != self.W: + image = cv2.resize(image, (self.W, self.H), interpolation=cv2.INTER_AREA) + + image = image.astype(np.float32) / 255 # [H, W, 3/4] + + self.poses.append(pose) + self.images.append(image) + + self.poses = torch.from_numpy(np.stack(self.poses, axis=0)) # [N, 4, 4] + if self.images is not None: + self.images = torch.from_numpy(np.stack(self.images, axis=0)) # [N, H, W, C] + + # calculate mean radius of all camera poses + self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item() + #print(f'[INFO] dataset camera poses: radius = {self.radius:.4f}, bound = {self.bound}') + + # initialize error_map + if self.training and self.opt.error_map: + self.error_map = torch.ones([self.images.shape[0], 128 * 128], dtype=torch.float) # [B, 128 * 128], flattened for easy indexing, fixed resolution... + else: + self.error_map = None + + # [debug] uncomment to view all training poses. + # visualize_poses(self.poses.numpy()) + + # [debug] uncomment to view examples of randomly generated poses. + # visualize_poses(rand_poses(100, self.device, radius=self.radius).cpu().numpy()) + + if self.preload: + self.poses = self.poses.to(self.device) + if self.images is not None: + # TODO: linear use pow, but pow for half is only available for torch >= 1.10 ? + if self.fp16 and self.opt.color_space != 'linear': + dtype = torch.half + else: + dtype = torch.float + self.images = self.images.to(dtype).to(self.device) + if self.error_map is not None: + self.error_map = self.error_map.to(self.device) + + # load intrinsics + if 'fl_x' in transform or 'fl_y' in transform: + fl_x = (transform['fl_x'] if 'fl_x' in transform else transform['fl_y']) / downscale + fl_y = (transform['fl_y'] if 'fl_y' in transform else transform['fl_x']) / downscale + elif 'camera_angle_x' in transform or 'camera_angle_y' in transform: + # blender, assert in radians. already downscaled since we use H/W + fl_x = self.W / (2 * np.tan(transform['camera_angle_x'] / 2)) if 'camera_angle_x' in transform else None + fl_y = self.H / (2 * np.tan(transform['camera_angle_y'] / 2)) if 'camera_angle_y' in transform else None + if fl_x is None: fl_x = fl_y + if fl_y is None: fl_y = fl_x + else: + raise RuntimeError('Failed to load focal length, please check the transforms.json!') + + cx = (transform['cx'] / downscale) if 'cx' in transform else (self.W / 2) + cy = (transform['cy'] / downscale) if 'cy' in transform else (self.H / 2) + + self.intrinsics = np.array([fl_x, fl_y, cx, cy]) + + + def collate(self, index): + + B = len(index) # a list of length 1 + + # random pose without gt images. + if self.rand_pose == 0 or index[0] >= len(self.poses): + + poses = rand_poses(B, self.device, radius=self.radius) + + # sample a low-resolution but full image for CLIP + s = np.sqrt(self.H * self.W / self.num_rays) # only in training, assert num_rays > 0 + rH, rW = int(self.H / s), int(self.W / s) + rays = get_rays(poses, self.intrinsics / s, rH, rW, -1) + + return { + 'H': rH, + 'W': rW, + 'rays_o': rays['rays_o'], + 'rays_d': rays['rays_d'], + } + + poses = self.poses[index].to(self.device) # [B, 4, 4] + + error_map = None if self.error_map is None else self.error_map[index] + + rays = get_rays(poses, self.intrinsics, self.H, self.W, self.num_rays, error_map, self.opt.patch_size) + + results = { + 'H': self.H, + 'W': self.W, + 'rays_o': rays['rays_o'], + 'rays_d': rays['rays_d'], + } + + if self.images is not None: + images = self.images[index].to(self.device) # [B, H, W, 3/4] + if self.training: + C = images.shape[-1] + images = torch.gather(images.view(B, -1, C), 1, torch.stack(C * [rays['inds']], -1)) # [B, N, 3/4] + results['images'] = images + + # need inds to update error_map + if error_map is not None: + results['index'] = index + results['inds_coarse'] = rays['inds_coarse'] + + return results + + def dataloader(self): + size = len(self.poses) + if self.training and self.rand_pose > 0: + size += size // self.rand_pose # index >= size means we use random pose. + loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0) + loader._data = self # an ugly fix... we need to access error_map & poses in trainer. + loader.has_gt = self.images is not None + return loader \ No newline at end of file diff --git a/torch-ngp/nerf/renderer.py b/torch-ngp/nerf/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..9ae4d6f66bdd9822fa63fb174bea99ffad5871b9 --- /dev/null +++ b/torch-ngp/nerf/renderer.py @@ -0,0 +1,574 @@ +import math +import trimesh +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import raymarching +from .utils import custom_meshgrid + +def sample_pdf(bins, weights, n_samples, det=False): + # This implementation is from NeRF + # bins: [B, T], old_z_vals + # weights: [B, T - 1], bin weights. + # return: [B, n_samples], new_z_vals + + # Get pdf + weights = weights + 1e-5 # prevent nans + pdf = weights / torch.sum(weights, -1, keepdim=True) + cdf = torch.cumsum(pdf, -1) + cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) + # Take uniform samples + if det: + u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device) + u = u.expand(list(cdf.shape[:-1]) + [n_samples]) + else: + u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device) + + # Invert CDF + u = u.contiguous() + inds = torch.searchsorted(cdf, u, right=True) + below = torch.max(torch.zeros_like(inds - 1), inds - 1) + above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) + inds_g = torch.stack([below, above], -1) # (B, n_samples, 2) + + matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] + cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) + bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) + + denom = (cdf_g[..., 1] - cdf_g[..., 0]) + denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) + t = (u - cdf_g[..., 0]) / denom + samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) + + return samples + + +def plot_pointcloud(pc, color=None): + # pc: [N, 3] + # color: [N, 3/4] + print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0)) + pc = trimesh.PointCloud(pc, color) + # axis + axes = trimesh.creation.axis(axis_length=4) + # sphere + sphere = trimesh.creation.icosphere(radius=1) + trimesh.Scene([pc, axes, sphere]).show() + + +class NeRFRenderer(nn.Module): + def __init__(self, + bound=1, + cuda_ray=False, + density_scale=1, # scale up deltas (or sigmas), to make the density grid more sharp. larger value than 1 usually improves performance. + min_near=0.2, + density_thresh=0.01, + bg_radius=-1, + ): + super().__init__() + + self.bound = bound + self.cascade = 1 + math.ceil(math.log2(bound)) + self.grid_size = 128 + self.density_scale = density_scale + self.min_near = min_near + self.density_thresh = density_thresh + self.bg_radius = bg_radius # radius of the background sphere. + + # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax) + # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing. + aabb_train = torch.FloatTensor([-bound, -bound, -bound, bound, bound, bound]) + aabb_infer = aabb_train.clone() + self.register_buffer('aabb_train', aabb_train) + self.register_buffer('aabb_infer', aabb_infer) + + # extra state for cuda raymarching + self.cuda_ray = cuda_ray + if cuda_ray: + # density grid + density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H] + density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8] + self.register_buffer('density_grid', density_grid) + self.register_buffer('density_bitfield', density_bitfield) + self.mean_density = 0 + self.iter_density = 0 + # step counter + step_counter = torch.zeros(16, 2, dtype=torch.int32) # 16 is hardcoded for averaging... + self.register_buffer('step_counter', step_counter) + self.mean_count = 0 + self.local_step = 0 + + def forward(self, x, d): + raise NotImplementedError() + + # separated density and color query (can accelerate non-cuda-ray mode.) + def density(self, x): + raise NotImplementedError() + + def color(self, x, d, mask=None, **kwargs): + raise NotImplementedError() + + def reset_extra_state(self): + if not self.cuda_ray: + return + # density grid + self.density_grid.zero_() + self.mean_density = 0 + self.iter_density = 0 + # step counter + self.step_counter.zero_() + self.mean_count = 0 + self.local_step = 0 + + def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, bg_color=None, perturb=False, **kwargs): + # rays_o, rays_d: [B, N, 3], assumes B == 1 + # bg_color: [3] in range [0, 1] + # return: image: [B, N, 3], depth: [B, N] + + prefix = rays_o.shape[:-1] + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # N = B * N, in fact + device = rays_o.device + + # choose aabb + aabb = self.aabb_train if self.training else self.aabb_infer + + # sample steps + nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near) + nears.unsqueeze_(-1) + fars.unsqueeze_(-1) + + #print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}') + + z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0) # [1, T] + z_vals = z_vals.expand((N, num_steps)) # [N, T] + z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars] + + # perturb z_vals + sample_dist = (fars - nears) / num_steps + if perturb: + z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist + #z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs. + + # generate xyzs + xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3] + xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip. + + #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy()) + + # query SDF and RGB + density_outputs = self.density(xyzs.reshape(-1, 3)) + + #sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T] + for k, v in density_outputs.items(): + density_outputs[k] = v.view(N, num_steps, -1) + + # upsample z_vals (nerf-like) + if upsample_steps > 0: + with torch.no_grad(): + + deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1] + deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1) + + alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1)) # [N, T] + alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+1] + weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T] + + # sample new z_vals + z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1]) # [N, T-1] + new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach() # [N, t] + + new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3] + new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:]) # a manual clip. + + # only forward new points to save computation + new_density_outputs = self.density(new_xyzs.reshape(-1, 3)) + #new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t] + for k, v in new_density_outputs.items(): + new_density_outputs[k] = v.view(N, upsample_steps, -1) + + # re-order + z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t] + z_vals, z_index = torch.sort(z_vals, dim=1) + + xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3] + xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs)) + + for k in density_outputs: + tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1) + density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output)) + + deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1] + deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1) + alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1)) # [N, T+t] + alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+t+1] + weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t] + + dirs = rays_d.view(-1, 1, 3).expand_as(xyzs) + for k, v in density_outputs.items(): + density_outputs[k] = v.view(-1, v.shape[-1]) + + mask = weights > 1e-4 # hard coded + rgbs = self.color(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), mask=mask.reshape(-1), **density_outputs) + rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3] + + #print(xyzs.shape, 'valid_rgb:', mask.sum().item()) + + # calculate weight_sum (mask) + weights_sum = weights.sum(dim=-1) # [N] + + # calculate depth + ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1) + depth = torch.sum(weights * ori_z_vals, dim=-1) + + # calculate color + image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1] + + # mix background color + if self.bg_radius > 0: + # use the bg model to calculate bg_color + sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1] + bg_color = self.background(sph, rays_d.reshape(-1, 3)) # [N, 3] + elif bg_color is None: + bg_color = 1 + + image = image + (1 - weights_sum).unsqueeze(-1) * bg_color + + image = image.view(*prefix, 3) + depth = depth.view(*prefix) + + # tmp: reg loss in mip-nerf 360 + # z_vals_shifted = torch.cat([z_vals[..., 1:], sample_dist * torch.ones_like(z_vals[..., :1])], dim=-1) + # mid_zs = (z_vals + z_vals_shifted) / 2 # [N, T] + # loss_dist = (torch.abs(mid_zs.unsqueeze(1) - mid_zs.unsqueeze(2)) * (weights.unsqueeze(1) * weights.unsqueeze(2))).sum() + 1/3 * ((z_vals_shifted - z_vals_shifted) * (weights ** 2)).sum() + + return { + 'depth': depth, + 'image': image, + 'weights_sum': weights_sum, + } + + + def run_cuda(self, rays_o, rays_d, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs): + # rays_o, rays_d: [B, N, 3], assumes B == 1 + # return: image: [B, N, 3], depth: [B, N] + + prefix = rays_o.shape[:-1] + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # N = B * N, in fact + device = rays_o.device + + # pre-calculate near far + nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near) + + # mix background color + if self.bg_radius > 0: + # use the bg model to calculate bg_color + sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1] + bg_color = self.background(sph, rays_d) # [N, 3] + elif bg_color is None: + bg_color = 1 + + results = {} + + if self.training: + # setup counter + counter = self.step_counter[self.local_step % 16] + counter.zero_() # set to 0 + self.local_step += 1 + + xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps) + + #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy()) + + sigmas, rgbs = self(xyzs, dirs) + # density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb. + # sigmas = density_outputs['sigma'] + # rgbs = self.color(xyzs, dirs, **density_outputs) + sigmas = self.density_scale * sigmas + + #print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})') + + # special case for CCNeRF's residual learning + if len(sigmas.shape) == 2: + K = sigmas.shape[0] + depths = [] + images = [] + for k in range(K): + weights_sum, depth, image = raymarching.composite_rays_train(sigmas[k], rgbs[k], deltas, rays, T_thresh) + image = image + (1 - weights_sum).unsqueeze(-1) * bg_color + depth = torch.clamp(depth - nears, min=0) / (fars - nears) + images.append(image.view(*prefix, 3)) + depths.append(depth.view(*prefix)) + + depth = torch.stack(depths, axis=0) # [K, B, N] + image = torch.stack(images, axis=0) # [K, B, N, 3] + + else: + + weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays, T_thresh) + image = image + (1 - weights_sum).unsqueeze(-1) * bg_color + depth = torch.clamp(depth - nears, min=0) / (fars - nears) + image = image.view(*prefix, 3) + depth = depth.view(*prefix) + + results['weights_sum'] = weights_sum + + else: + + # allocate outputs + # if use autocast, must init as half so it won't be autocasted and lose reference. + #dtype = torch.half if torch.is_autocast_enabled() else torch.float32 + # output should always be float32! only network inference uses half. + dtype = torch.float32 + + weights_sum = torch.zeros(N, dtype=dtype, device=device) + depth = torch.zeros(N, dtype=dtype, device=device) + image = torch.zeros(N, 3, dtype=dtype, device=device) + + n_alive = N + rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N] + rays_t = nears.clone() # [N] + + step = 0 + + while step < max_steps: + + # count alive rays + n_alive = rays_alive.shape[0] + + # exit loop + if n_alive <= 0: + break + + # decide compact_steps + n_step = max(min(N // n_alive, 8), 1) + + xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps) + + sigmas, rgbs = self(xyzs, dirs) + # density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb. + # sigmas = density_outputs['sigma'] + # rgbs = self.color(xyzs, dirs, **density_outputs) + sigmas = self.density_scale * sigmas + + raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh) + + rays_alive = rays_alive[rays_alive >= 0] + + #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}') + + step += n_step + + image = image + (1 - weights_sum).unsqueeze(-1) * bg_color + depth = torch.clamp(depth - nears, min=0) / (fars - nears) + image = image.view(*prefix, 3) + depth = depth.view(*prefix) + + results['depth'] = depth + results['image'] = image + + return results + + @torch.no_grad() + def mark_untrained_grid(self, poses, intrinsic, S=64): + # poses: [B, 4, 4] + # intrinsic: [3, 3] + + if not self.cuda_ray: + return + + if isinstance(poses, np.ndarray): + poses = torch.from_numpy(poses) + + B = poses.shape[0] + + fx, fy, cx, cy = intrinsic + + X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + + count = torch.zeros_like(self.density_grid) + poses = poses.to(count.device) + + # 5-level loop, forgive me... + + for xs in X: + for ys in Y: + for zs in Z: + + # construct points + xx, yy, zz = custom_meshgrid(xs, ys, zs) + coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128) + indices = raymarching.morton3D(coords).long() # [N] + world_xyzs = (2 * coords.float() / (self.grid_size - 1) - 1).unsqueeze(0) # [1, N, 3] in [-1, 1] + + # cascading + for cas in range(self.cascade): + bound = min(2 ** cas, self.bound) + half_grid_size = bound / self.grid_size + # scale to current cascade's resolution + cas_world_xyzs = world_xyzs * (bound - half_grid_size) + + # split batch to avoid OOM + head = 0 + while head < B: + tail = min(head + S, B) + + # world2cam transform (poses is c2w, so we need to transpose it. Another transpose is needed for batched matmul, so the final form is without transpose.) + cam_xyzs = cas_world_xyzs - poses[head:tail, :3, 3].unsqueeze(1) + cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3] # [S, N, 3] + + # query if point is covered by any camera + mask_z = cam_xyzs[:, :, 2] > 0 # [S, N] + mask_x = torch.abs(cam_xyzs[:, :, 0]) < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2 + mask_y = torch.abs(cam_xyzs[:, :, 1]) < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2 + mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1) # [N] + + # update count + count[cas, indices] += mask + head += S + + # mark untrained grid as -1 + self.density_grid[count == 0] = -1 + + print(f'[mark untrained grid] {(count == 0).sum()} from {self.grid_size ** 3 * self.cascade}') + + @torch.no_grad() + def update_extra_state(self, decay=0.95, S=128): + # call before each epoch to update extra states. + + if not self.cuda_ray: + return + + ### update density grid + + tmp_grid = - torch.ones_like(self.density_grid) + + # full update. + if self.iter_density < 16: + #if True: + X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + + for xs in X: + for ys in Y: + for zs in Z: + + # construct points + xx, yy, zz = custom_meshgrid(xs, ys, zs) + coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128) + indices = raymarching.morton3D(coords).long() # [N] + xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1] + + # cascading + for cas in range(self.cascade): + bound = min(2 ** cas, self.bound) + half_grid_size = bound / self.grid_size + # scale to current cascade's resolution + cas_xyzs = xyzs * (bound - half_grid_size) + # add noise in [-hgs, hgs] + cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size + # query density + sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach() + sigmas *= self.density_scale + # assign + tmp_grid[cas, indices] = sigmas + + # partial update (half the computation) + # TODO: why no need of maxpool ? + else: + N = self.grid_size ** 3 // 4 # H * H * H / 4 + for cas in range(self.cascade): + # random sample some positions + coords = torch.randint(0, self.grid_size, (N, 3), device=self.density_bitfield.device) # [N, 3], in [0, 128) + indices = raymarching.morton3D(coords).long() # [N] + # random sample occupied positions + occ_indices = torch.nonzero(self.density_grid[cas] > 0).squeeze(-1) # [Nz] + rand_mask = torch.randint(0, occ_indices.shape[0], [N], dtype=torch.long, device=self.density_bitfield.device) + occ_indices = occ_indices[rand_mask] # [Nz] --> [N], allow for duplication + occ_coords = raymarching.morton3D_invert(occ_indices) # [N, 3] + # concat + indices = torch.cat([indices, occ_indices], dim=0) + coords = torch.cat([coords, occ_coords], dim=0) + # same below + xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1] + bound = min(2 ** cas, self.bound) + half_grid_size = bound / self.grid_size + # scale to current cascade's resolution + cas_xyzs = xyzs * (bound - half_grid_size) + # add noise in [-hgs, hgs] + cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size + # query density + sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach() + sigmas *= self.density_scale + # assign + tmp_grid[cas, indices] = sigmas + + ## max-pool on tmp_grid for less aggressive culling [No significant improvement...] + # invalid_mask = tmp_grid < 0 + # tmp_grid = F.max_pool3d(tmp_grid.view(self.cascade, 1, self.grid_size, self.grid_size, self.grid_size), kernel_size=3, stride=1, padding=1).view(self.cascade, -1) + # tmp_grid[invalid_mask] = -1 + + # ema update + valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0) + self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask]) + self.mean_density = torch.mean(self.density_grid.clamp(min=0)).item() # -1 regions are viewed as 0 density. + #self.mean_density = torch.mean(self.density_grid[self.density_grid > 0]).item() # do not count -1 regions + self.iter_density += 1 + + # convert to bitfield + density_thresh = min(self.mean_density, self.density_thresh) + self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield) + + ### update step counter + total_step = min(16, self.local_step) + if total_step > 0: + self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step) + self.local_step = 0 + + #print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > 0.01).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}') + + + def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs): + # rays_o, rays_d: [B, N, 3], assumes B == 1 + # return: pred_rgb: [B, N, 3] + + if self.cuda_ray: + _run = self.run_cuda + else: + _run = self.run + + B, N = rays_o.shape[:2] + device = rays_o.device + + # never stage when cuda_ray + if staged and not self.cuda_ray: + depth = torch.empty((B, N), device=device) + image = torch.empty((B, N, 3), device=device) + + for b in range(B): + head = 0 + while head < N: + tail = min(head + max_ray_batch, N) + results_ = _run(rays_o[b:b+1, head:tail], rays_d[b:b+1, head:tail], **kwargs) + depth[b:b+1, head:tail] = results_['depth'] + image[b:b+1, head:tail] = results_['image'] + head += max_ray_batch + + results = {} + results['depth'] = depth + results['image'] = image + + else: + results = _run(rays_o, rays_d, **kwargs) + + return results \ No newline at end of file diff --git a/torch-ngp/nerf/utils.py b/torch-ngp/nerf/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3d99ffaaedab2e4462537163a4163996dd65e9c4 --- /dev/null +++ b/torch-ngp/nerf/utils.py @@ -0,0 +1,1137 @@ +import os +import glob +import tqdm +import math +import imageio +import random +import warnings +import tensorboardX + +import numpy as np +import pandas as pd + +import time +from datetime import datetime + +import cv2 +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import torch.distributed as dist +from torch.utils.data import Dataset, DataLoader + +import trimesh +import mcubes +from rich.console import Console +from torch_ema import ExponentialMovingAverage + +from packaging import version as pver +import lpips +from torchmetrics.functional import structural_similarity_index_measure + +def custom_meshgrid(*args): + # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid + if pver.parse(torch.__version__) < pver.parse('1.10'): + return torch.meshgrid(*args) + else: + return torch.meshgrid(*args, indexing='ij') + + +@torch.jit.script +def linear_to_srgb(x): + return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055) + + +@torch.jit.script +def srgb_to_linear(x): + return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4) + + +@torch.cuda.amp.autocast(enabled=False) +def get_rays(poses, intrinsics, H, W, N=-1, error_map=None, patch_size=1): + ''' get rays + Args: + poses: [B, 4, 4], cam2world + intrinsics: [4] + H, W, N: int + error_map: [B, 128 * 128], sample probability based on training error + Returns: + rays_o, rays_d: [B, N, 3] + inds: [B, N] + ''' + + device = poses.device + B = poses.shape[0] + fx, fy, cx, cy = intrinsics + + i, j = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device)) # float + i = i.t().reshape([1, H*W]).expand([B, H*W]) + 0.5 + j = j.t().reshape([1, H*W]).expand([B, H*W]) + 0.5 + + results = {} + + if N > 0: + N = min(N, H*W) + + # if use patch-based sampling, ignore error_map + if patch_size > 1: + + # random sample left-top cores. + # NOTE: this impl will lead to less sampling on the image corner pixels... but I don't have other ideas. + num_patch = N // (patch_size ** 2) + inds_x = torch.randint(0, H - patch_size, size=[num_patch], device=device) + inds_y = torch.randint(0, W - patch_size, size=[num_patch], device=device) + inds = torch.stack([inds_x, inds_y], dim=-1) # [np, 2] + + # create meshgrid for each patch + pi, pj = custom_meshgrid(torch.arange(patch_size, device=device), torch.arange(patch_size, device=device)) + offsets = torch.stack([pi.reshape(-1), pj.reshape(-1)], dim=-1) # [p^2, 2] + + inds = inds.unsqueeze(1) + offsets.unsqueeze(0) # [np, p^2, 2] + inds = inds.view(-1, 2) # [N, 2] + inds = inds[:, 0] * W + inds[:, 1] # [N], flatten + + inds = inds.expand([B, N]) + + elif error_map is None: + inds = torch.randint(0, H*W, size=[N], device=device) # may duplicate + inds = inds.expand([B, N]) + else: + + # weighted sample on a low-reso grid + inds_coarse = torch.multinomial(error_map.to(device), N, replacement=False) # [B, N], but in [0, 128*128) + + # map to the original resolution with random perturb. + inds_x, inds_y = inds_coarse // 128, inds_coarse % 128 # `//` will throw a warning in torch 1.10... anyway. + sx, sy = H / 128, W / 128 + inds_x = (inds_x * sx + torch.rand(B, N, device=device) * sx).long().clamp(max=H - 1) + inds_y = (inds_y * sy + torch.rand(B, N, device=device) * sy).long().clamp(max=W - 1) + inds = inds_x * W + inds_y + + results['inds_coarse'] = inds_coarse # need this when updating error_map + + i = torch.gather(i, -1, inds) + j = torch.gather(j, -1, inds) + + results['inds'] = inds + + else: + inds = torch.arange(H*W, device=device).expand([B, H*W]) + + zs = torch.ones_like(i) + xs = (i - cx) / fx * zs + ys = (j - cy) / fy * zs + directions = torch.stack((xs, ys, zs), dim=-1) + directions = directions / torch.norm(directions, dim=-1, keepdim=True) + rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3) + + rays_o = poses[..., :3, 3] # [B, 3] + rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3] + + results['rays_o'] = rays_o + results['rays_d'] = rays_d + + return results + + +def seed_everything(seed): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + #torch.backends.cudnn.deterministic = True + #torch.backends.cudnn.benchmark = True + + +def torch_vis_2d(x, renormalize=False): + # x: [3, H, W] or [1, H, W] or [H, W] + import matplotlib.pyplot as plt + import numpy as np + import torch + + if isinstance(x, torch.Tensor): + if len(x.shape) == 3: + x = x.permute(1,2,0).squeeze() + x = x.detach().cpu().numpy() + + print(f'[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}') + + x = x.astype(np.float32) + + # renormalize + if renormalize: + x = (x - x.min(axis=0, keepdims=True)) / (x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8) + + plt.imshow(x) + plt.show() + + +def extract_fields(bound_min, bound_max, resolution, query_func, S=128): + + X = torch.linspace(bound_min[0], bound_max[0], resolution).split(S) + Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(S) + Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(S) + + u = np.zeros([resolution, resolution, resolution], dtype=np.float32) + with torch.no_grad(): + for xi, xs in enumerate(X): + for yi, ys in enumerate(Y): + for zi, zs in enumerate(Z): + xx, yy, zz = custom_meshgrid(xs, ys, zs) + pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3] + val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z] + u[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val + return u + + +def extract_geometry(bound_min, bound_max, resolution, threshold, query_func): + #print('threshold: {}'.format(threshold)) + u = extract_fields(bound_min, bound_max, resolution, query_func) + + #print(u.shape, u.max(), u.min(), np.percentile(u, 50)) + + vertices, triangles = mcubes.marching_cubes(u, threshold) + + b_max_np = bound_max.detach().cpu().numpy() + b_min_np = bound_min.detach().cpu().numpy() + + vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :] + return vertices, triangles + + +class PSNRMeter: + def __init__(self): + self.V = 0 + self.N = 0 + + def clear(self): + self.V = 0 + self.N = 0 + + def prepare_inputs(self, *inputs): + outputs = [] + for i, inp in enumerate(inputs): + if torch.is_tensor(inp): + inp = inp.detach().cpu().numpy() + outputs.append(inp) + + return outputs + + def update(self, preds, truths): + preds, truths = self.prepare_inputs(preds, truths) # [B, N, 3] or [B, H, W, 3], range[0, 1] + + # simplified since max_pixel_value is 1 here. + psnr = -10 * np.log10(np.mean((preds - truths) ** 2)) + + self.V += psnr + self.N += 1 + + def measure(self): + return self.V / self.N + + def write(self, writer, global_step, prefix=""): + writer.add_scalar(os.path.join(prefix, "PSNR"), self.measure(), global_step) + + def report(self): + return f'PSNR = {self.measure():.6f}' + + +class SSIMMeter: + def __init__(self, device=None): + self.V = 0 + self.N = 0 + + self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + def clear(self): + self.V = 0 + self.N = 0 + + def prepare_inputs(self, *inputs): + outputs = [] + for i, inp in enumerate(inputs): + inp = inp.permute(0, 3, 1, 2).contiguous() # [B, 3, H, W] + inp = inp.to(self.device) + outputs.append(inp) + return outputs + + def update(self, preds, truths): + preds, truths = self.prepare_inputs(preds, truths) # [B, H, W, 3] --> [B, 3, H, W], range in [0, 1] + + ssim = structural_similarity_index_measure(preds, truths) + + self.V += ssim + self.N += 1 + + def measure(self): + return self.V / self.N + + def write(self, writer, global_step, prefix=""): + writer.add_scalar(os.path.join(prefix, "SSIM"), self.measure(), global_step) + + def report(self): + return f'SSIM = {self.measure():.6f}' + + +class LPIPSMeter: + def __init__(self, net='alex', device=None): + self.V = 0 + self.N = 0 + self.net = net + + self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.fn = lpips.LPIPS(net=net).eval().to(self.device) + + def clear(self): + self.V = 0 + self.N = 0 + + def prepare_inputs(self, *inputs): + outputs = [] + for i, inp in enumerate(inputs): + inp = inp.permute(0, 3, 1, 2).contiguous() # [B, 3, H, W] + inp = inp.to(self.device) + outputs.append(inp) + return outputs + + def update(self, preds, truths): + preds, truths = self.prepare_inputs(preds, truths) # [B, H, W, 3] --> [B, 3, H, W], range in [0, 1] + v = self.fn(truths, preds, normalize=True).item() # normalize=True: [0, 1] to [-1, 1] + self.V += v + self.N += 1 + + def measure(self): + return self.V / self.N + + def write(self, writer, global_step, prefix=""): + writer.add_scalar(os.path.join(prefix, f"LPIPS ({self.net})"), self.measure(), global_step) + + def report(self): + return f'LPIPS ({self.net}) = {self.measure():.6f}' + +class Trainer(object): + def __init__(self, + name, # name of this experiment + opt, # extra conf + model, # network + criterion=None, # loss function, if None, assume inline implementation in train_step + optimizer=None, # optimizer + ema_decay=None, # if use EMA, set the decay + lr_scheduler=None, # scheduler + metrics=[], # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric. + local_rank=0, # which GPU am I + world_size=1, # total num of GPUs + device=None, # device to use, usually setting to None is OK. (auto choose device) + mute=False, # whether to mute all print + fp16=False, # amp optimize level + eval_interval=1, # eval once every $ epoch + max_keep_ckpt=2, # max num of saved ckpts in disk + workspace='workspace', # workspace to save logs & ckpts + best_mode='min', # the smaller/larger result, the better + use_loss_as_metric=True, # use loss as the first metric + report_metric_at_train=False, # also report metrics at training + use_checkpoint="latest", # which ckpt to use at init time + use_tensorboardX=True, # whether to use tensorboard for logging + scheduler_update_every_step=False, # whether to call scheduler.step() after every train step + ): + + self.name = name + self.opt = opt + self.mute = mute + self.metrics = metrics + self.local_rank = local_rank + self.world_size = world_size + self.workspace = workspace + self.ema_decay = ema_decay + self.fp16 = fp16 + self.best_mode = best_mode + self.use_loss_as_metric = use_loss_as_metric + self.report_metric_at_train = report_metric_at_train + self.max_keep_ckpt = max_keep_ckpt + self.eval_interval = eval_interval + self.use_checkpoint = use_checkpoint + self.use_tensorboardX = use_tensorboardX + self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S") + self.scheduler_update_every_step = scheduler_update_every_step + self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu') + self.console = Console() + + model.to(self.device) + if self.world_size > 1: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank]) + self.model = model + + if isinstance(criterion, nn.Module): + criterion.to(self.device) + self.criterion = criterion + + # optionally use LPIPS loss for patch-based training + if self.opt.patch_size > 1: + import lpips + self.criterion_lpips = lpips.LPIPS(net='alex').to(self.device) + + if optimizer is None: + self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam + else: + self.optimizer = optimizer(self.model) + + if lr_scheduler is None: + self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler + else: + self.lr_scheduler = lr_scheduler(self.optimizer) + + if ema_decay is not None: + self.ema = ExponentialMovingAverage(self.model.parameters(), decay=ema_decay) + else: + self.ema = None + + self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16) + + # variable init + self.epoch = 0 + self.global_step = 0 + self.local_step = 0 + self.stats = { + "loss": [], + "valid_loss": [], + "results": [], # metrics[0], or valid_loss + "checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt + "best_result": None, + } + + # auto fix + if len(metrics) == 0 or self.use_loss_as_metric: + self.best_mode = 'min' + + # workspace prepare + self.log_ptr = None + if self.workspace is not None: + os.makedirs(self.workspace, exist_ok=True) + self.log_path = os.path.join(workspace, f"log_{self.name}.txt") + self.log_ptr = open(self.log_path, "a+") + + self.ckpt_path = os.path.join(self.workspace, 'checkpoints') + self.best_path = f"{self.ckpt_path}/{self.name}.pth" + os.makedirs(self.ckpt_path, exist_ok=True) + + self.log(f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}') + self.log(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}') + + if self.workspace is not None: + if self.use_checkpoint == "scratch": + self.log("[INFO] Training from scratch ...") + elif self.use_checkpoint == "latest": + self.log("[INFO] Loading latest checkpoint ...") + self.load_checkpoint() + elif self.use_checkpoint == "latest_model": + self.log("[INFO] Loading latest checkpoint (model only)...") + self.load_checkpoint(model_only=True) + elif self.use_checkpoint == "best": + if os.path.exists(self.best_path): + self.log("[INFO] Loading best checkpoint ...") + self.load_checkpoint(self.best_path) + else: + self.log(f"[INFO] {self.best_path} not found, loading latest ...") + self.load_checkpoint() + else: # path to ckpt + self.log(f"[INFO] Loading {self.use_checkpoint} ...") + self.load_checkpoint(self.use_checkpoint) + + # clip loss prepare + if opt.rand_pose >= 0: # =0 means only using CLIP loss, >0 means a hybrid mode. + from nerf.clip_utils import CLIPLoss + self.clip_loss = CLIPLoss(self.device) + self.clip_loss.prepare_text([self.opt.clip_text]) # only support one text prompt now... + + + def __del__(self): + if self.log_ptr: + self.log_ptr.close() + + + def log(self, *args, **kwargs): + if self.local_rank == 0: + if not self.mute: + #print(*args) + self.console.print(*args, **kwargs) + if self.log_ptr: + print(*args, file=self.log_ptr) + self.log_ptr.flush() # write immediately to file + + ### ------------------------------ + + def train_step(self, data): + + rays_o = data['rays_o'] # [B, N, 3] + rays_d = data['rays_d'] # [B, N, 3] + + # if there is no gt image, we train with CLIP loss. + if 'images' not in data: + + B, N = rays_o.shape[:2] + H, W = data['H'], data['W'] + + # currently fix white bg, MUST force all rays! + outputs = self.model.render(rays_o, rays_d, staged=False, bg_color=None, perturb=True, force_all_rays=True, **vars(self.opt)) + pred_rgb = outputs['image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous() + + # [debug] uncomment to plot the images used in train_step + #torch_vis_2d(pred_rgb[0]) + + loss = self.clip_loss(pred_rgb) + + return pred_rgb, None, loss + + images = data['images'] # [B, N, 3/4] + + B, N, C = images.shape + + if self.opt.color_space == 'linear': + images[..., :3] = srgb_to_linear(images[..., :3]) + + if C == 3 or self.model.bg_radius > 0: + bg_color = 1 + # train with random background color if not using a bg model and has alpha channel. + else: + #bg_color = torch.ones(3, device=self.device) # [3], fixed white background + #bg_color = torch.rand(3, device=self.device) # [3], frame-wise random. + bg_color = torch.rand_like(images[..., :3]) # [N, 3], pixel-wise random. + + if C == 4: + gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (1 - images[..., 3:]) + else: + gt_rgb = images + + outputs = self.model.render(rays_o, rays_d, staged=False, bg_color=bg_color, perturb=True, force_all_rays=False if self.opt.patch_size == 1 else True, **vars(self.opt)) + # outputs = self.model.render(rays_o, rays_d, staged=False, bg_color=bg_color, perturb=True, force_all_rays=True, **vars(self.opt)) + + pred_rgb = outputs['image'] + + # MSE loss + loss = self.criterion(pred_rgb, gt_rgb).mean(-1) # [B, N, 3] --> [B, N] + + # patch-based rendering + if self.opt.patch_size > 1: + gt_rgb = gt_rgb.view(-1, self.opt.patch_size, self.opt.patch_size, 3).permute(0, 3, 1, 2).contiguous() + pred_rgb = pred_rgb.view(-1, self.opt.patch_size, self.opt.patch_size, 3).permute(0, 3, 1, 2).contiguous() + + # torch_vis_2d(gt_rgb[0]) + # torch_vis_2d(pred_rgb[0]) + + # LPIPS loss [not useful...] + loss = loss + 1e-3 * self.criterion_lpips(pred_rgb, gt_rgb) + + # special case for CCNeRF's rank-residual training + if len(loss.shape) == 3: # [K, B, N] + loss = loss.mean(0) + + # update error_map + if self.error_map is not None: + index = data['index'] # [B] + inds = data['inds_coarse'] # [B, N] + + # take out, this is an advanced indexing and the copy is unavoidable. + error_map = self.error_map[index] # [B, H * W] + + # [debug] uncomment to save and visualize error map + # if self.global_step % 1001 == 0: + # tmp = error_map[0].view(128, 128).cpu().numpy() + # print(f'[write error map] {tmp.shape} {tmp.min()} ~ {tmp.max()}') + # tmp = (tmp - tmp.min()) / (tmp.max() - tmp.min()) + # cv2.imwrite(os.path.join(self.workspace, f'{self.global_step}.jpg'), (tmp * 255).astype(np.uint8)) + + error = loss.detach().to(error_map.device) # [B, N], already in [0, 1] + + # ema update + ema_error = 0.1 * error_map.gather(1, inds) + 0.9 * error + error_map.scatter_(1, inds, ema_error) + + # put back + self.error_map[index] = error_map + + loss = loss.mean() + + # extra loss + # pred_weights_sum = outputs['weights_sum'] + 1e-8 + # loss_ws = - 1e-1 * pred_weights_sum * torch.log(pred_weights_sum) # entropy to encourage weights_sum to be 0 or 1. + # loss = loss + loss_ws.mean() + + return pred_rgb, gt_rgb, loss + + def eval_step(self, data): + + rays_o = data['rays_o'] # [B, N, 3] + rays_d = data['rays_d'] # [B, N, 3] + images = data['images'] # [B, H, W, 3/4] + B, H, W, C = images.shape + + if self.opt.color_space == 'linear': + images[..., :3] = srgb_to_linear(images[..., :3]) + + # eval with fixed background color + bg_color = 1 + if C == 4: + gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (1 - images[..., 3:]) + else: + gt_rgb = images + + outputs = self.model.render(rays_o, rays_d, staged=True, bg_color=bg_color, perturb=False, **vars(self.opt)) + + pred_rgb = outputs['image'].reshape(B, H, W, 3) + pred_depth = outputs['depth'].reshape(B, H, W) + + loss = self.criterion(pred_rgb, gt_rgb).mean() + + return pred_rgb, pred_depth, gt_rgb, loss + + # moved out bg_color and perturb for more flexible control... + def test_step(self, data, bg_color=None, perturb=False): + + rays_o = data['rays_o'] # [B, N, 3] + rays_d = data['rays_d'] # [B, N, 3] + H, W = data['H'], data['W'] + + if bg_color is not None: + bg_color = bg_color.to(self.device) + + outputs = self.model.render(rays_o, rays_d, staged=True, bg_color=bg_color, perturb=perturb, **vars(self.opt)) + + pred_rgb = outputs['image'].reshape(-1, H, W, 3) + pred_depth = outputs['depth'].reshape(-1, H, W) + + return pred_rgb, pred_depth + + + def save_mesh(self, save_path=None, resolution=256, threshold=10): + + if save_path is None: + save_path = os.path.join(self.workspace, 'meshes', f'{self.name}_{self.epoch}.ply') + + self.log(f"==> Saving mesh to {save_path}") + + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + def query_func(pts): + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=self.fp16): + sigma = self.model.density(pts.to(self.device))['sigma'] + return sigma + + vertices, triangles = extract_geometry(self.model.aabb_infer[:3], self.model.aabb_infer[3:], resolution=resolution, threshold=threshold, query_func=query_func) + + mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault... + mesh.export(save_path) + + self.log(f"==> Finished saving mesh.") + + ### ------------------------------ + + def train(self, train_loader, valid_loader, max_epochs): + if self.use_tensorboardX and self.local_rank == 0: + self.writer = tensorboardX.SummaryWriter(os.path.join(self.workspace, "run", self.name)) + + # mark untrained region (i.e., not covered by any camera from the training dataset) + if self.model.cuda_ray: + self.model.mark_untrained_grid(train_loader._data.poses, train_loader._data.intrinsics) + + # get a ref to error_map + self.error_map = train_loader._data.error_map + + for epoch in range(self.epoch + 1, max_epochs + 1): + self.epoch = epoch + + self.train_one_epoch(train_loader) + + if self.workspace is not None and self.local_rank == 0: + self.save_checkpoint(full=True, best=False) + + if self.epoch % self.eval_interval == 0: + self.evaluate_one_epoch(valid_loader) + self.save_checkpoint(full=False, best=True) + + if self.use_tensorboardX and self.local_rank == 0: + self.writer.close() + + def evaluate(self, loader, name=None): + self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX + self.evaluate_one_epoch(loader, name) + self.use_tensorboardX = use_tensorboardX + + def test(self, loader, save_path=None, name=None, write_video=True): + + if save_path is None: + save_path = os.path.join(self.workspace, 'results') + + if name is None: + name = f'{self.name}_ep{self.epoch:04d}' + + os.makedirs(save_path, exist_ok=True) + + self.log(f"==> Start Test, save results to {save_path}") + + pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') + self.model.eval() + + if write_video: + all_preds = [] + all_preds_depth = [] + + with torch.no_grad(): + + for i, data in enumerate(loader): + + with torch.cuda.amp.autocast(enabled=self.fp16): + preds, preds_depth = self.test_step(data) + + if self.opt.color_space == 'linear': + preds = linear_to_srgb(preds) + + pred = preds[0].detach().cpu().numpy() + pred = (pred * 255).astype(np.uint8) + + pred_depth = preds_depth[0].detach().cpu().numpy() + pred_depth = (pred_depth * 255).astype(np.uint8) + + if write_video: + all_preds.append(pred) + all_preds_depth.append(pred_depth) + else: + cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_rgb.png'), cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)) + cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_depth.png'), pred_depth) + + pbar.update(loader.batch_size) + + if write_video: + all_preds = np.stack(all_preds, axis=0) + all_preds_depth = np.stack(all_preds_depth, axis=0) + imageio.mimwrite(os.path.join(save_path, f'{name}_rgb.mp4'), all_preds, fps=25, quality=8, macro_block_size=1) + imageio.mimwrite(os.path.join(save_path, f'{name}_depth.mp4'), all_preds_depth, fps=25, quality=8, macro_block_size=1) + + self.log(f"==> Finished Test.") + + # [GUI] just train for 16 steps, without any other overhead that may slow down rendering. + def train_gui(self, train_loader, step=16): + + self.model.train() + + total_loss = torch.tensor([0], dtype=torch.float32, device=self.device) + + loader = iter(train_loader) + + # mark untrained grid + if self.global_step == 0: + self.model.mark_untrained_grid(train_loader._data.poses, train_loader._data.intrinsics) + + for _ in range(step): + + # mimic an infinite loop dataloader (in case the total dataset is smaller than step) + try: + data = next(loader) + except StopIteration: + loader = iter(train_loader) + data = next(loader) + + # update grid every 16 steps + if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0: + with torch.cuda.amp.autocast(enabled=self.fp16): + self.model.update_extra_state() + + self.global_step += 1 + + self.optimizer.zero_grad() + + with torch.cuda.amp.autocast(enabled=self.fp16): + preds, truths, loss = self.train_step(data) + + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + + if self.scheduler_update_every_step: + self.lr_scheduler.step() + + total_loss += loss.detach() + + if self.ema is not None: + self.ema.update() + + average_loss = total_loss.item() / step + + if not self.scheduler_update_every_step: + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step(average_loss) + else: + self.lr_scheduler.step() + + outputs = { + 'loss': average_loss, + 'lr': self.optimizer.param_groups[0]['lr'], + } + + return outputs + + + # [GUI] test on a single image + def test_gui(self, pose, intrinsics, W, H, bg_color=None, spp=1, downscale=1): + + # render resolution (may need downscale to for better frame rate) + rH = int(H * downscale) + rW = int(W * downscale) + intrinsics = intrinsics * downscale + + pose = torch.from_numpy(pose).unsqueeze(0).to(self.device) + + rays = get_rays(pose, intrinsics, rH, rW, -1) + + data = { + 'rays_o': rays['rays_o'], + 'rays_d': rays['rays_d'], + 'H': rH, + 'W': rW, + } + + self.model.eval() + + if self.ema is not None: + self.ema.store() + self.ema.copy_to() + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=self.fp16): + # here spp is used as perturb random seed! (but not perturb the first sample) + preds, preds_depth = self.test_step(data, bg_color=bg_color, perturb=False if spp == 1 else spp) + + if self.ema is not None: + self.ema.restore() + + # interpolation to the original resolution + if downscale != 1: + # TODO: have to permute twice with torch... + preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(H, W), mode='nearest').permute(0, 2, 3, 1).contiguous() + preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1) + + if self.opt.color_space == 'linear': + preds = linear_to_srgb(preds) + + pred = preds[0].detach().cpu().numpy() + pred_depth = preds_depth[0].detach().cpu().numpy() + + outputs = { + 'image': pred, + 'depth': pred_depth, + } + + return outputs + + def train_one_epoch(self, loader): + self.log(f"==> Start Training Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...") + + total_loss = 0 + if self.local_rank == 0 and self.report_metric_at_train: + for metric in self.metrics: + metric.clear() + + self.model.train() + + # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs + # ref: https://pytorch.org/docs/stable/data.html + if self.world_size > 1: + loader.sampler.set_epoch(self.epoch) + + if self.local_rank == 0: + pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') + + self.local_step = 0 + + for data in loader: + + # update grid every 16 steps + if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0: + with torch.cuda.amp.autocast(enabled=self.fp16): + self.model.update_extra_state() + + self.local_step += 1 + self.global_step += 1 + + self.optimizer.zero_grad() + + with torch.cuda.amp.autocast(enabled=self.fp16): + preds, truths, loss = self.train_step(data) + + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + + if self.scheduler_update_every_step: + self.lr_scheduler.step() + + loss_val = loss.item() + total_loss += loss_val + + if self.local_rank == 0: + if self.report_metric_at_train: + for metric in self.metrics: + metric.update(preds, truths) + + if self.use_tensorboardX: + self.writer.add_scalar("train/loss", loss_val, self.global_step) + self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], self.global_step) + + if self.scheduler_update_every_step: + pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}") + else: + pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})") + pbar.update(loader.batch_size) + + if self.ema is not None: + self.ema.update() + + average_loss = total_loss / self.local_step + self.stats["loss"].append(average_loss) + + if self.local_rank == 0: + pbar.close() + if self.report_metric_at_train: + for metric in self.metrics: + self.log(metric.report(), style="red") + if self.use_tensorboardX: + metric.write(self.writer, self.epoch, prefix="train") + metric.clear() + + if not self.scheduler_update_every_step: + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step(average_loss) + else: + self.lr_scheduler.step() + + self.log(f"==> Finished Epoch {self.epoch}.") + + + def evaluate_one_epoch(self, loader, name=None): + self.log(f"++> Evaluate at epoch {self.epoch} ...") + + if name is None: + name = f'{self.name}_ep{self.epoch:04d}' + + total_loss = 0 + if self.local_rank == 0: + for metric in self.metrics: + metric.clear() + + self.model.eval() + + if self.ema is not None: + self.ema.store() + self.ema.copy_to() + + if self.local_rank == 0: + pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') + + with torch.no_grad(): + self.local_step = 0 + + for data in loader: + self.local_step += 1 + + with torch.cuda.amp.autocast(enabled=self.fp16): + preds, preds_depth, truths, loss = self.eval_step(data) + + # all_gather/reduce the statistics (NCCL only support all_*) + if self.world_size > 1: + dist.all_reduce(loss, op=dist.ReduceOp.SUM) + loss = loss / self.world_size + + preds_list = [torch.zeros_like(preds).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...] + dist.all_gather(preds_list, preds) + preds = torch.cat(preds_list, dim=0) + + preds_depth_list = [torch.zeros_like(preds_depth).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...] + dist.all_gather(preds_depth_list, preds_depth) + preds_depth = torch.cat(preds_depth_list, dim=0) + + truths_list = [torch.zeros_like(truths).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...] + dist.all_gather(truths_list, truths) + truths = torch.cat(truths_list, dim=0) + + loss_val = loss.item() + total_loss += loss_val + + # only rank = 0 will perform evaluation. + if self.local_rank == 0: + + for metric in self.metrics: + metric.update(preds, truths) + + # save image + save_path = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_rgb.png') + save_path_depth = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_depth.png') + + #self.log(f"==> Saving validation image to {save_path}") + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + if self.opt.color_space == 'linear': + preds = linear_to_srgb(preds) + + pred = preds[0].detach().cpu().numpy() + pred = (pred * 255).astype(np.uint8) + + pred_depth = preds_depth[0].detach().cpu().numpy() + pred_depth = (pred_depth * 255).astype(np.uint8) + + cv2.imwrite(save_path, cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)) + cv2.imwrite(save_path_depth, pred_depth) + + pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})") + pbar.update(loader.batch_size) + + + average_loss = total_loss / self.local_step + self.stats["valid_loss"].append(average_loss) + + if self.local_rank == 0: + pbar.close() + if not self.use_loss_as_metric and len(self.metrics) > 0: + result = self.metrics[0].measure() + self.stats["results"].append(result if self.best_mode == 'min' else - result) # if max mode, use -result + else: + self.stats["results"].append(average_loss) # if no metric, choose best by min loss + + for metric in self.metrics: + self.log(metric.report(), style="blue") + if self.use_tensorboardX: + metric.write(self.writer, self.epoch, prefix="evaluate") + metric.clear() + + if self.ema is not None: + self.ema.restore() + + self.log(f"++> Evaluate epoch {self.epoch} Finished.") + + def save_checkpoint(self, name=None, full=False, best=False, remove_old=True): + + if name is None: + name = f'{self.name}_ep{self.epoch:04d}' + + state = { + 'epoch': self.epoch, + 'global_step': self.global_step, + 'stats': self.stats, + } + + if self.model.cuda_ray: + state['mean_count'] = self.model.mean_count + state['mean_density'] = self.model.mean_density + + if full: + state['optimizer'] = self.optimizer.state_dict() + state['lr_scheduler'] = self.lr_scheduler.state_dict() + state['scaler'] = self.scaler.state_dict() + if self.ema is not None: + state['ema'] = self.ema.state_dict() + + if not best: + + state['model'] = self.model.state_dict() + + file_path = f"{self.ckpt_path}/{name}.pth" + + if remove_old: + self.stats["checkpoints"].append(file_path) + + if len(self.stats["checkpoints"]) > self.max_keep_ckpt: + old_ckpt = self.stats["checkpoints"].pop(0) + if os.path.exists(old_ckpt): + os.remove(old_ckpt) + + torch.save(state, file_path) + + else: + if len(self.stats["results"]) > 0: + if self.stats["best_result"] is None or self.stats["results"][-1] < self.stats["best_result"]: + self.log(f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}") + self.stats["best_result"] = self.stats["results"][-1] + + # save ema results + if self.ema is not None: + self.ema.store() + self.ema.copy_to() + + state['model'] = self.model.state_dict() + + # we don't consider continued training from the best ckpt, so we discard the unneeded density_grid to save some storage (especially important for dnerf) + if 'density_grid' in state['model']: + del state['model']['density_grid'] + + if self.ema is not None: + self.ema.restore() + + torch.save(state, self.best_path) + else: + self.log(f"[WARN] no evaluated results found, skip saving best checkpoint.") + + def load_checkpoint(self, checkpoint=None, model_only=False): + if checkpoint is None: + checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/{self.name}_ep*.pth')) + if checkpoint_list: + checkpoint = checkpoint_list[-1] + self.log(f"[INFO] Latest checkpoint is {checkpoint}") + else: + self.log("[WARN] No checkpoint found, model randomly initialized.") + return + + checkpoint_dict = torch.load(checkpoint, map_location=self.device) + + if 'model' not in checkpoint_dict: + self.model.load_state_dict(checkpoint_dict) + self.log("[INFO] loaded model.") + return + + missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint_dict['model'], strict=False) + self.log("[INFO] loaded model.") + if len(missing_keys) > 0: + self.log(f"[WARN] missing keys: {missing_keys}") + if len(unexpected_keys) > 0: + self.log(f"[WARN] unexpected keys: {unexpected_keys}") + + if self.ema is not None and 'ema' in checkpoint_dict: + self.ema.load_state_dict(checkpoint_dict['ema']) + + if self.model.cuda_ray: + if 'mean_count' in checkpoint_dict: + self.model.mean_count = checkpoint_dict['mean_count'] + if 'mean_density' in checkpoint_dict: + self.model.mean_density = checkpoint_dict['mean_density'] + + if model_only: + return + + self.stats = checkpoint_dict['stats'] + self.epoch = checkpoint_dict['epoch'] + self.global_step = checkpoint_dict['global_step'] + self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}") + + if self.optimizer and 'optimizer' in checkpoint_dict: + try: + self.optimizer.load_state_dict(checkpoint_dict['optimizer']) + self.log("[INFO] loaded optimizer.") + except: + self.log("[WARN] Failed to load optimizer.") + + if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict: + try: + self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler']) + self.log("[INFO] loaded scheduler.") + except: + self.log("[WARN] Failed to load scheduler.") + + if self.scaler and 'scaler' in checkpoint_dict: + try: + self.scaler.load_state_dict(checkpoint_dict['scaler']) + self.log("[INFO] loaded scaler.") + except: + self.log("[WARN] Failed to load scaler.") \ No newline at end of file diff --git a/torch-ngp/raymarching/__init__.py b/torch-ngp/raymarching/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..26d3cc6d4430c112603bba68bbd1bedd0ccbc7ac --- /dev/null +++ b/torch-ngp/raymarching/__init__.py @@ -0,0 +1 @@ +from .raymarching import * \ No newline at end of file diff --git a/torch-ngp/raymarching/backend.py b/torch-ngp/raymarching/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..a6a9a03227b3ad718d622a653bf33bfb11e88218 --- /dev/null +++ b/torch-ngp/raymarching/backend.py @@ -0,0 +1,40 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_raymarching', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'raymarching.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/torch-ngp/raymarching/raymarching.py b/torch-ngp/raymarching/raymarching.py new file mode 100644 index 0000000000000000000000000000000000000000..80db197ee281fb17781c9149b2d0a6c3c4842078 --- /dev/null +++ b/torch-ngp/raymarching/raymarching.py @@ -0,0 +1,373 @@ +import numpy as np +import time + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _raymarching as _backend +except ImportError: + from .backend import _backend + + +# ---------------------------------------- +# utils +# ---------------------------------------- + +class _near_far_from_aabb(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, rays_o, rays_d, aabb, min_near=0.2): + ''' near_far_from_aabb, CUDA implementation + Calculate rays' intersection time (near and far) with aabb + Args: + rays_o: float, [N, 3] + rays_d: float, [N, 3] + aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax) + min_near: float, scalar + Returns: + nears: float, [N] + fars: float, [N] + ''' + if not rays_o.is_cuda: rays_o = rays_o.cuda() + if not rays_d.is_cuda: rays_d = rays_d.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # num rays + + nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) + fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) + + _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars) + + return nears, fars + +near_far_from_aabb = _near_far_from_aabb.apply + + +class _sph_from_ray(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, rays_o, rays_d, radius): + ''' sph_from_ray, CUDA implementation + get spherical coordinate on the background sphere from rays. + Assume rays_o are inside the Sphere(radius). + Args: + rays_o: [N, 3] + rays_d: [N, 3] + radius: scalar, float + Return: + coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface) + ''' + if not rays_o.is_cuda: rays_o = rays_o.cuda() + if not rays_d.is_cuda: rays_d = rays_d.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # num rays + + coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device) + + _backend.sph_from_ray(rays_o, rays_d, radius, N, coords) + + return coords + +sph_from_ray = _sph_from_ray.apply + + +class _morton3D(Function): + @staticmethod + def forward(ctx, coords): + ''' morton3D, CUDA implementation + Args: + coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...) + TODO: check if the coord range is valid! (current 128 is safe) + Returns: + indices: [N], int32, in [0, 128^3) + + ''' + if not coords.is_cuda: coords = coords.cuda() + + N = coords.shape[0] + + indices = torch.empty(N, dtype=torch.int32, device=coords.device) + + _backend.morton3D(coords.int(), N, indices) + + return indices + +morton3D = _morton3D.apply + +class _morton3D_invert(Function): + @staticmethod + def forward(ctx, indices): + ''' morton3D_invert, CUDA implementation + Args: + indices: [N], int32, in [0, 128^3) + Returns: + coords: [N, 3], int32, in [0, 128) + + ''' + if not indices.is_cuda: indices = indices.cuda() + + N = indices.shape[0] + + coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device) + + _backend.morton3D_invert(indices.int(), N, coords) + + return coords + +morton3D_invert = _morton3D_invert.apply + + +class _packbits(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, grid, thresh, bitfield=None): + ''' packbits, CUDA implementation + Pack up the density grid into a bit field to accelerate ray marching. + Args: + grid: float, [C, H * H * H], assume H % 2 == 0 + thresh: float, threshold + Returns: + bitfield: uint8, [C, H * H * H / 8] + ''' + if not grid.is_cuda: grid = grid.cuda() + grid = grid.contiguous() + + C = grid.shape[0] + H3 = grid.shape[1] + N = C * H3 // 8 + + if bitfield is None: + bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device) + + _backend.packbits(grid, N, thresh, bitfield) + + return bitfield + +packbits = _packbits.apply + +# ---------------------------------------- +# train functions +# ---------------------------------------- + +class _march_rays_train(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024): + ''' march rays to generate points (forward only) + Args: + rays_o/d: float, [N, 3] + bound: float, scalar + density_bitfield: uint8: [CHHH // 8] + C: int + H: int + nears/fars: float, [N] + step_counter: int32, (2), used to count the actual number of generated points. + mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.) + perturb: bool + align: int, pad output so its size is dividable by align, set to -1 to disable. + force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays. + dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) + max_steps: int, max number of sampled points along each ray, also affect min_stepsize. + Returns: + xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray) + dirs: float, [M, 3], all generated points' view dirs. + deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for Depth) + rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 2]] --> points belonging to rays[i, 0] + ''' + + if not rays_o.is_cuda: rays_o = rays_o.cuda() + if not rays_d.is_cuda: rays_d = rays_d.cuda() + if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + density_bitfield = density_bitfield.contiguous() + + N = rays_o.shape[0] # num rays + M = N * max_steps # init max points number in total + + # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp) + # It estimate the max points number to enable faster training, but will lead to random ignored rays if underestimated. + if not force_all_rays and mean_count > 0: + if align > 0: + mean_count += align - mean_count % align + M = mean_count + + xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) + rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps + + if step_counter is None: + step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter + + if perturb: + noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device) + else: + noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device) + + _backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the actually used points number + + #print(step_counter, M) + + # only used at the first (few) epochs. + if force_all_rays or mean_count <= 0: + m = step_counter[0].item() # D2H copy + if align > 0: + m += align - m % align + xyzs = xyzs[:m] + dirs = dirs[:m] + deltas = deltas[:m] + + torch.cuda.empty_cache() + + return xyzs, dirs, deltas, rays + +march_rays_train = _march_rays_train.apply + + +class _composite_rays_train(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, sigmas, rgbs, deltas, rays, T_thresh=1e-4): + ''' composite rays' rgbs, according to the ray marching formula. + Args: + rgbs: float, [M, 3] + sigmas: float, [M,] + deltas: float, [M, 2] + rays: int32, [N, 3] + Returns: + weights_sum: float, [N,], the alpha channel + depth: float, [N, ], the Depth + image: float, [N, 3], the RGB channel (after multiplying alpha!) + ''' + + sigmas = sigmas.contiguous() + rgbs = rgbs.contiguous() + + M = sigmas.shape[0] + N = rays.shape[0] + + weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device) + + _backend.composite_rays_train_forward(sigmas, rgbs, deltas, rays, M, N, T_thresh, weights_sum, depth, image) + + ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image) + ctx.dims = [M, N, T_thresh] + + return weights_sum, depth, image + + @staticmethod + @custom_bwd + def backward(ctx, grad_weights_sum, grad_depth, grad_image): + + # NOTE: grad_depth is not used now! It won't be propagated to sigmas. + + grad_weights_sum = grad_weights_sum.contiguous() + grad_image = grad_image.contiguous() + + sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors + M, N, T_thresh = ctx.dims + + grad_sigmas = torch.zeros_like(sigmas) + grad_rgbs = torch.zeros_like(rgbs) + + _backend.composite_rays_train_backward(grad_weights_sum, grad_image, sigmas, rgbs, deltas, rays, weights_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs) + + return grad_sigmas, grad_rgbs, None, None, None + + +composite_rays_train = _composite_rays_train.apply + +# ---------------------------------------- +# infer functions +# ---------------------------------------- + +class _march_rays(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024): + ''' march rays to generate points (forward only, for inference) + Args: + n_alive: int, number of alive rays + n_step: int, how many steps we march + rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive) + rays_t: float, [N], the alive rays' time, we only use the first n_alive. + rays_o/d: float, [N, 3] + bound: float, scalar + density_bitfield: uint8: [CHHH // 8] + C: int + H: int + nears/fars: float, [N] + align: int, pad output so its size is dividable by align, set to -1 to disable. + perturb: bool/int, int > 0 is used as the random seed. + dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) + max_steps: int, max number of sampled points along each ray, also affect min_stepsize. + Returns: + xyzs: float, [n_alive * n_step, 3], all generated points' coords + dirs: float, [n_alive * n_step, 3], all generated points' view dirs. + deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). + ''' + + if not rays_o.is_cuda: rays_o = rays_o.cuda() + if not rays_d.is_cuda: rays_d = rays_d.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + M = n_alive * n_step + + if align > 0: + M += align - (M % align) + + xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth + + if perturb: + # torch.manual_seed(perturb) # test_gui uses spp index as seed + noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device) + else: + noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device) + + _backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises) + + return xyzs, dirs, deltas + +march_rays = _march_rays.apply + + +class _composite_rays(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float + def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2): + ''' composite rays' rgbs, according to the ray marching formula. (for inference) + Args: + n_alive: int, number of alive rays + n_step: int, how many steps we march + rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive) + rays_t: float, [N], the alive rays' time + sigmas: float, [n_alive * n_step,] + rgbs: float, [n_alive * n_step, 3] + deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). + In-place Outputs: + weights_sum: float, [N,], the alpha channel + depth: float, [N,], the depth value + image: float, [N, 3], the RGB channel (after multiplying alpha!) + ''' + _backend.composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image) + return tuple() + + +composite_rays = _composite_rays.apply \ No newline at end of file diff --git a/torch-ngp/raymarching/setup.py b/torch-ngp/raymarching/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..d97449970ad3381d98fe74535ab7b6ca106bcbbc --- /dev/null +++ b/torch-ngp/raymarching/setup.py @@ -0,0 +1,62 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +''' +Usage: + +python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory) + +python setup.py install # build extensions and install (copy) to PATH. +pip install . # ditto but better (e.g., dependency & metadata handling) + +python setup.py develop # build extensions and install (symbolic) to PATH. +pip install -e . # ditto but better (e.g., dependency & metadata handling) + +''' +setup( + name='raymarching', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_raymarching', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'raymarching.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/torch-ngp/raymarching/src/bindings.cpp b/torch-ngp/raymarching/src/bindings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..47920bc7cd44813f6cee2ba47c9693e1ad25adce --- /dev/null +++ b/torch-ngp/raymarching/src/bindings.cpp @@ -0,0 +1,19 @@ +#include + +#include "raymarching.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // utils + m.def("packbits", &packbits, "packbits (CUDA)"); + m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)"); + m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)"); + m.def("morton3D", &morton3D, "morton3D (CUDA)"); + m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)"); + // train + m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)"); + m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)"); + m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)"); + // infer + m.def("march_rays", &march_rays, "march rays (CUDA)"); + m.def("composite_rays", &composite_rays, "composite rays (CUDA)"); +} \ No newline at end of file diff --git a/torch-ngp/raymarching/src/raymarching.cu b/torch-ngp/raymarching/src/raymarching.cu new file mode 100644 index 0000000000000000000000000000000000000000..16065033cfb2e3caed9d5fc8083a6c25da9e0be5 --- /dev/null +++ b/torch-ngp/raymarching/src/raymarching.cu @@ -0,0 +1,914 @@ +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + + +inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; } +inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; } +inline constexpr __device__ float PI() { return 3.141592653589793f; } +inline constexpr __device__ float RPI() { return 0.3183098861837907f; } + + +template +inline __host__ __device__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +inline __host__ __device__ float signf(const float x) { + return copysignf(1.0, x); +} + +inline __host__ __device__ float clamp(const float x, const float min, const float max) { + return fminf(max, fmaxf(min, x)); +} + +inline __host__ __device__ void swapf(float& a, float& b) { + float c = a; a = b; b = c; +} + +inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) { + const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z))); + int exponent; + frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ... + return fminf(max_cascade - 1, fmaxf(0, exponent)); +} + +inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) { + const float mx = dt * H * 0.5; + int exponent; + frexpf(mx, &exponent); + return fminf(max_cascade - 1, fmaxf(0, exponent)); +} + +inline __host__ __device__ uint32_t __expand_bits(uint32_t v) +{ + v = (v * 0x00010001u) & 0xFF0000FFu; + v = (v * 0x00000101u) & 0x0F00F00Fu; + v = (v * 0x00000011u) & 0xC30C30C3u; + v = (v * 0x00000005u) & 0x49249249u; + return v; +} + +inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z) +{ + uint32_t xx = __expand_bits(x); + uint32_t yy = __expand_bits(y); + uint32_t zz = __expand_bits(z); + return xx | (yy << 1) | (zz << 2); +} + +inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x) +{ + x = x & 0x49249249; + x = (x | (x >> 2)) & 0xc30c30c3; + x = (x | (x >> 4)) & 0x0f00f00f; + x = (x | (x >> 8)) & 0xff0000ff; + x = (x | (x >> 16)) & 0x0000ffff; + return x; +} + + +//////////////////////////////////////////////////// +///////////// utils ///////////// +//////////////////////////////////////////////////// + +// rays_o/d: [N, 3] +// nears/fars: [N] +// scalar_t should always be float in use. +template +__global__ void kernel_near_far_from_aabb( + const scalar_t * __restrict__ rays_o, + const scalar_t * __restrict__ rays_d, + const scalar_t * __restrict__ aabb, + const uint32_t N, + const float min_near, + scalar_t * nears, scalar_t * fars +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + rays_o += n * 3; + rays_d += n * 3; + + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + + // get near far (assume cube scene) + float near = (aabb[0] - ox) * rdx; + float far = (aabb[3] - ox) * rdx; + if (near > far) swapf(near, far); + + float near_y = (aabb[1] - oy) * rdy; + float far_y = (aabb[4] - oy) * rdy; + if (near_y > far_y) swapf(near_y, far_y); + + if (near > far_y || near_y > far) { + nears[n] = fars[n] = std::numeric_limits::max(); + return; + } + + if (near_y > near) near = near_y; + if (far_y < far) far = far_y; + + float near_z = (aabb[2] - oz) * rdz; + float far_z = (aabb[5] - oz) * rdz; + if (near_z > far_z) swapf(near_z, far_z); + + if (near > far_z || near_z > far) { + nears[n] = fars[n] = std::numeric_limits::max(); + return; + } + + if (near_z > near) near = near_z; + if (far_z < far) far = far_z; + + if (near < min_near) near = min_near; + + nears[n] = near; + fars[n] = far; +} + + +void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rays_o.scalar_type(), "near_far_from_aabb", ([&] { + kernel_near_far_from_aabb<<>>(rays_o.data_ptr(), rays_d.data_ptr(), aabb.data_ptr(), N, min_near, nears.data_ptr(), fars.data_ptr()); + })); +} + + +// rays_o/d: [N, 3] +// radius: float +// coords: [N, 2] +template +__global__ void kernel_sph_from_ray( + const scalar_t * __restrict__ rays_o, + const scalar_t * __restrict__ rays_d, + const float radius, + const uint32_t N, + scalar_t * coords +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + rays_o += n * 3; + rays_d += n * 3; + coords += n * 2; + + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + + // solve t from || o + td || = radius + const float A = dx * dx + dy * dy + dz * dz; + const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2 + const float C = ox * ox + oy * oy + oz * oz - radius * radius; + + const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive) + + // solve theta, phi (assume y is the up axis) + const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz; + const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI) + const float phi = atan2(z, x); // [-PI, PI) + + // normalize to [-1, 1] + coords[0] = 2 * theta * RPI() - 1; + coords[1] = phi * RPI(); +} + + +void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rays_o.scalar_type(), "sph_from_ray", ([&] { + kernel_sph_from_ray<<>>(rays_o.data_ptr(), rays_d.data_ptr(), radius, N, coords.data_ptr()); + })); +} + + +// coords: int32, [N, 3] +// indices: int32, [N] +__global__ void kernel_morton3D( + const int * __restrict__ coords, + const uint32_t N, + int * indices +) { + // parallel + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + coords += n * 3; + indices[n] = __morton3D(coords[0], coords[1], coords[2]); +} + + +void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) { + static constexpr uint32_t N_THREAD = 128; + kernel_morton3D<<>>(coords.data_ptr(), N, indices.data_ptr()); +} + + +// indices: int32, [N] +// coords: int32, [N, 3] +__global__ void kernel_morton3D_invert( + const int * __restrict__ indices, + const uint32_t N, + int * coords +) { + // parallel + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + coords += n * 3; + + const int ind = indices[n]; + + coords[0] = __morton3D_invert(ind >> 0); + coords[1] = __morton3D_invert(ind >> 1); + coords[2] = __morton3D_invert(ind >> 2); +} + + +void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) { + static constexpr uint32_t N_THREAD = 128; + kernel_morton3D_invert<<>>(indices.data_ptr(), N, coords.data_ptr()); +} + + +// grid: float, [C, H, H, H] +// N: int, C * H * H * H / 8 +// density_thresh: float +// bitfield: uint8, [N] +template +__global__ void kernel_packbits( + const scalar_t * __restrict__ grid, + const uint32_t N, + const float density_thresh, + uint8_t * bitfield +) { + // parallel per byte + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + grid += n * 8; + + uint8_t bits = 0; + + #pragma unroll + for (uint8_t i = 0; i < 8; i++) { + bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0; + } + + bitfield[n] = bits; +} + + +void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grid.scalar_type(), "packbits", ([&] { + kernel_packbits<<>>(grid.data_ptr(), N, density_thresh, bitfield.data_ptr()); + })); +} + +//////////////////////////////////////////////////// +///////////// training ///////////// +//////////////////////////////////////////////////// + +// rays_o/d: [N, 3] +// grid: [CHHH / 8] +// xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2] +// dirs: [M, 3] +// rays: [N, 3], idx, offset, num_steps +template +__global__ void kernel_march_rays_train( + const scalar_t * __restrict__ rays_o, + const scalar_t * __restrict__ rays_d, + const uint8_t * __restrict__ grid, + const float bound, + const float dt_gamma, const uint32_t max_steps, + const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, + const scalar_t* __restrict__ nears, + const scalar_t* __restrict__ fars, + scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas, + int * rays, + int * counter, + const scalar_t* __restrict__ noises +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + rays_o += n * 3; + rays_d += n * 3; + + // ray marching + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + const float rH = 1 / (float)H; + const float H3 = H * H * H; + + const float near = nears[n]; + const float far = fars[n]; + const float noise = noises[n]; + + const float dt_min = 2 * SQRT3() / max_steps; + const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H; + + float t0 = near; + + // perturb + t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise; + + // first pass: estimation of num_steps + float t = t0; + uint32_t num_steps = 0; + + //if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far); + + while (t < far && num_steps < max_steps) { + // current point + const float x = clamp(ox + t * dx, -bound, bound); + const float y = clamp(oy + t * dy, -bound, bound); + const float z = clamp(oz + t * dz, -bound, bound); + + const float dt = clamp(t * dt_gamma, dt_min, dt_max); + + // get mip level + const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] + + const float mip_bound = fminf(scalbnf(1.0f, level), bound); + const float mip_rbound = 1 / mip_bound; + + // convert to nearest grid position + const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + + const uint32_t index = level * H3 + __morton3D(nx, ny, nz); + const bool occ = grid[index / 8] & (1 << (index % 8)); + + // if occpuied, advance a small step, and write to output + //if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, num_steps); + + if (occ) { + num_steps++; + t += dt; + // else, skip a large step (basically skip a voxel grid) + } else { + // calc distance to next voxel + const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx; + const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy; + const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz; + + const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); + // step until next voxel + do { + t += clamp(t * dt_gamma, dt_min, dt_max); + } while (t < tt); + } + } + + //printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n, num_steps, near, far, dt_min, (far - near) / dt_min); + + // second pass: really locate and write points & dirs + uint32_t point_index = atomicAdd(counter, num_steps); + uint32_t ray_index = atomicAdd(counter + 1, 1); + + //printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n, num_steps, point_index, ray_index); + + // write rays + rays[ray_index * 3] = n; + rays[ray_index * 3 + 1] = point_index; + rays[ray_index * 3 + 2] = num_steps; + + if (num_steps == 0) return; + if (point_index + num_steps > M) return; + + xyzs += point_index * 3; + dirs += point_index * 3; + deltas += point_index * 2; + + t = t0; + uint32_t step = 0; + + float last_t = t; + + while (t < far && step < num_steps) { + // current point + const float x = clamp(ox + t * dx, -bound, bound); + const float y = clamp(oy + t * dy, -bound, bound); + const float z = clamp(oz + t * dz, -bound, bound); + + const float dt = clamp(t * dt_gamma, dt_min, dt_max); + + // get mip level + const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] + + const float mip_bound = fminf(scalbnf(1.0f, level), bound); + const float mip_rbound = 1 / mip_bound; + + // convert to nearest grid position + const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + + // query grid + const uint32_t index = level * H3 + __morton3D(nx, ny, nz); + const bool occ = grid[index / 8] & (1 << (index % 8)); + + // if occpuied, advance a small step, and write to output + if (occ) { + // write step + xyzs[0] = x; + xyzs[1] = y; + xyzs[2] = z; + dirs[0] = dx; + dirs[1] = dy; + dirs[2] = dz; + t += dt; + deltas[0] = dt; + deltas[1] = t - last_t; // used to calc depth + last_t = t; + xyzs += 3; + dirs += 3; + deltas += 2; + step++; + // else, skip a large step (basically skip a voxel grid) + } else { + // calc distance to next voxel + const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx; + const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy; + const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz; + const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); + // step until next voxel + do { + t += clamp(t * dt_gamma, dt_min, dt_max); + } while (t < tt); + } + } +} + +void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rays_o.scalar_type(), "march_rays_train", ([&] { + kernel_march_rays_train<<>>(rays_o.data_ptr(), rays_d.data_ptr(), grid.data_ptr(), bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr(), fars.data_ptr(), xyzs.data_ptr(), dirs.data_ptr(), deltas.data_ptr(), rays.data_ptr(), counter.data_ptr(), noises.data_ptr()); + })); +} + + +// sigmas: [M] +// rgbs: [M, 3] +// deltas: [M, 2] +// rays: [N, 3], idx, offset, num_steps +// weights_sum: [N], final pixel alpha +// depth: [N,] +// image: [N, 3] +template +__global__ void kernel_composite_rays_train_forward( + const scalar_t * __restrict__ sigmas, + const scalar_t * __restrict__ rgbs, + const scalar_t * __restrict__ deltas, + const int * __restrict__ rays, + const uint32_t M, const uint32_t N, const float T_thresh, + scalar_t * weights_sum, + scalar_t * depth, + scalar_t * image +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + uint32_t index = rays[n * 3]; + uint32_t offset = rays[n * 3 + 1]; + uint32_t num_steps = rays[n * 3 + 2]; + + // empty ray, or ray that exceed max step count. + if (num_steps == 0 || offset + num_steps > M) { + weights_sum[index] = 0; + depth[index] = 0; + image[index * 3] = 0; + image[index * 3 + 1] = 0; + image[index * 3 + 2] = 0; + return; + } + + sigmas += offset; + rgbs += offset * 3; + deltas += offset * 2; + + // accumulate + uint32_t step = 0; + + scalar_t T = 1.0f; + scalar_t r = 0, g = 0, b = 0, ws = 0, t = 0, d = 0; + + while (step < num_steps) { + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + const scalar_t weight = alpha * T; + + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + + t += deltas[1]; // real delta + d += weight * t; + + ws += weight; + + T *= 1.0f - alpha; + + // minimal remained transmittence + if (T < T_thresh) break; + + //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); + + // locate + sigmas++; + rgbs += 3; + deltas += 2; + + step++; + } + + //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); + + // write + weights_sum[index] = ws; // weights_sum + depth[index] = d; + image[index * 3] = r; + image[index * 3 + 1] = g; + image[index * 3 + 2] = b; +} + + +void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + sigmas.scalar_type(), "composite_rays_train_forward", ([&] { + kernel_composite_rays_train_forward<<>>(sigmas.data_ptr(), rgbs.data_ptr(), deltas.data_ptr(), rays.data_ptr(), M, N, T_thresh, weights_sum.data_ptr(), depth.data_ptr(), image.data_ptr()); + })); +} + + +// grad_weights_sum: [N,] +// grad: [N, 3] +// sigmas: [M] +// rgbs: [M, 3] +// deltas: [M, 2] +// rays: [N, 3], idx, offset, num_steps +// weights_sum: [N,], weights_sum here +// image: [N, 3] +// grad_sigmas: [M] +// grad_rgbs: [M, 3] +template +__global__ void kernel_composite_rays_train_backward( + const scalar_t * __restrict__ grad_weights_sum, + const scalar_t * __restrict__ grad_image, + const scalar_t * __restrict__ sigmas, + const scalar_t * __restrict__ rgbs, + const scalar_t * __restrict__ deltas, + const int * __restrict__ rays, + const scalar_t * __restrict__ weights_sum, + const scalar_t * __restrict__ image, + const uint32_t M, const uint32_t N, const float T_thresh, + scalar_t * grad_sigmas, + scalar_t * grad_rgbs +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + uint32_t index = rays[n * 3]; + uint32_t offset = rays[n * 3 + 1]; + uint32_t num_steps = rays[n * 3 + 2]; + + if (num_steps == 0 || offset + num_steps > M) return; + + grad_weights_sum += index; + grad_image += index * 3; + weights_sum += index; + image += index * 3; + sigmas += offset; + rgbs += offset * 3; + deltas += offset * 2; + grad_sigmas += offset; + grad_rgbs += offset * 3; + + // accumulate + uint32_t step = 0; + + scalar_t T = 1.0f; + const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0]; + scalar_t r = 0, g = 0, b = 0, ws = 0; + + while (step < num_steps) { + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + const scalar_t weight = alpha * T; + + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + ws += weight; + + T *= 1.0f - alpha; + + // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation. + // write grad_rgbs + grad_rgbs[0] = grad_image[0] * weight; + grad_rgbs[1] = grad_image[1] * weight; + grad_rgbs[2] = grad_image[2] * weight; + + // write grad_sigmas + grad_sigmas[0] = deltas[0] * ( + grad_image[0] * (T * rgbs[0] - (r_final - r)) + + grad_image[1] * (T * rgbs[1] - (g_final - g)) + + grad_image[2] * (T * rgbs[2] - (b_final - b)) + + grad_weights_sum[0] * (1 - ws_final) + ); + + //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r); + // minimal remained transmittence + if (T < T_thresh) break; + + // locate + sigmas++; + rgbs += 3; + deltas += 2; + grad_sigmas++; + grad_rgbs += 3; + + step++; + } +} + + +void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_image.scalar_type(), "composite_rays_train_backward", ([&] { + kernel_composite_rays_train_backward<<>>(grad_weights_sum.data_ptr(), grad_image.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), deltas.data_ptr(), rays.data_ptr(), weights_sum.data_ptr(), image.data_ptr(), M, N, T_thresh, grad_sigmas.data_ptr(), grad_rgbs.data_ptr()); + })); +} + + +//////////////////////////////////////////////////// +///////////// infernce ///////////// +//////////////////////////////////////////////////// + +template +__global__ void kernel_march_rays( + const uint32_t n_alive, + const uint32_t n_step, + const int* __restrict__ rays_alive, + const scalar_t* __restrict__ rays_t, + const scalar_t* __restrict__ rays_o, + const scalar_t* __restrict__ rays_d, + const float bound, + const float dt_gamma, const uint32_t max_steps, + const uint32_t C, const uint32_t H, + const uint8_t * __restrict__ grid, + const scalar_t* __restrict__ nears, + const scalar_t* __restrict__ fars, + scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas, + const scalar_t* __restrict__ noises +) { + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= n_alive) return; + + const int index = rays_alive[n]; // ray id + const float noise = noises[n]; + + // locate + rays_o += index * 3; + rays_d += index * 3; + xyzs += n * n_step * 3; + dirs += n * n_step * 3; + deltas += n * n_step * 2; + + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + const float rH = 1 / (float)H; + const float H3 = H * H * H; + + float t = rays_t[index]; // current ray's t + const float near = nears[index], far = fars[index]; + + const float dt_min = 2 * SQRT3() / max_steps; + const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H; + + // march for n_step steps, record points + uint32_t step = 0; + + // introduce some randomness + t += clamp(t * dt_gamma, dt_min, dt_max) * noise; + + float last_t = t; + + while (t < far && step < n_step) { + // current point + const float x = clamp(ox + t * dx, -bound, bound); + const float y = clamp(oy + t * dy, -bound, bound); + const float z = clamp(oz + t * dz, -bound, bound); + + const float dt = clamp(t * dt_gamma, dt_min, dt_max); + + // get mip level + const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] + + const float mip_bound = fminf(scalbnf(1, level), bound); + const float mip_rbound = 1 / mip_bound; + + // convert to nearest grid position + const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + + const uint32_t index = level * H3 + __morton3D(nx, ny, nz); + const bool occ = grid[index / 8] & (1 << (index % 8)); + + // if occpuied, advance a small step, and write to output + if (occ) { + // write step + xyzs[0] = x; + xyzs[1] = y; + xyzs[2] = z; + dirs[0] = dx; + dirs[1] = dy; + dirs[2] = dz; + // calc dt + t += dt; + deltas[0] = dt; + deltas[1] = t - last_t; // used to calc depth + last_t = t; + // step + xyzs += 3; + dirs += 3; + deltas += 2; + step++; + + // else, skip a large step (basically skip a voxel grid) + } else { + // calc distance to next voxel + const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx; + const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy; + const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz; + const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); + // step until next voxel + do { + t += clamp(t * dt_gamma, dt_min, dt_max); + } while (t < tt); + } + } +} + + +void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises) { + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rays_o.scalar_type(), "march_rays", ([&] { + kernel_march_rays<<>>(n_alive, n_step, rays_alive.data_ptr(), rays_t.data_ptr(), rays_o.data_ptr(), rays_d.data_ptr(), bound, dt_gamma, max_steps, C, H, grid.data_ptr(), near.data_ptr(), far.data_ptr(), xyzs.data_ptr(), dirs.data_ptr(), deltas.data_ptr(), noises.data_ptr()); + })); +} + + +template +__global__ void kernel_composite_rays( + const uint32_t n_alive, + const uint32_t n_step, + const float T_thresh, + int* rays_alive, + scalar_t* rays_t, + const scalar_t* __restrict__ sigmas, + const scalar_t* __restrict__ rgbs, + const scalar_t* __restrict__ deltas, + scalar_t* weights_sum, scalar_t* depth, scalar_t* image +) { + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= n_alive) return; + + const int index = rays_alive[n]; // ray id + + // locate + sigmas += n * n_step; + rgbs += n * n_step * 3; + deltas += n * n_step * 2; + + rays_t += index; + weights_sum += index; + depth += index; + image += index * 3; + + scalar_t t = rays_t[0]; // current ray's t + + scalar_t weight_sum = weights_sum[0]; + scalar_t d = depth[0]; + scalar_t r = image[0]; + scalar_t g = image[1]; + scalar_t b = image[2]; + + // accumulate + uint32_t step = 0; + while (step < n_step) { + + // ray is terminated if delta == 0 + if (deltas[0] == 0) break; + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + + /* + T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j) + w_i = alpha_i * T_i + --> + T_i = 1 - \sum_{j=0}^{i-1} w_j + */ + const scalar_t T = 1 - weight_sum; + const scalar_t weight = alpha * T; + weight_sum += weight; + + t += deltas[1]; // real delta + d += weight * t; + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + + //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); + + // ray is terminated if T is too small + // use a larger bound to further accelerate inference + if (T < T_thresh) break; + + // locate + sigmas++; + rgbs += 3; + deltas += 2; + step++; + } + + //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); + + // rays_alive = -1 means ray is terminated early. + if (step < n_step) { + rays_alive[n] = -1; + } else { + rays_t[0] = t; + } + + weights_sum[0] = weight_sum; // this is the thing I needed! + depth[0] = d; + image[0] = r; + image[1] = g; + image[2] = b; +} + + +void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) { + static constexpr uint32_t N_THREAD = 128; + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + image.scalar_type(), "composite_rays", ([&] { + kernel_composite_rays<<>>(n_alive, n_step, T_thresh, rays_alive.data_ptr(), rays_t.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), deltas.data_ptr(), weights.data_ptr(), depth.data_ptr(), image.data_ptr()); + })); +} \ No newline at end of file diff --git a/torch-ngp/raymarching/src/raymarching.h b/torch-ngp/raymarching/src/raymarching.h new file mode 100644 index 0000000000000000000000000000000000000000..3a2e692cfb8f6fbdd7fbd7a7e89b7deb05d09d42 --- /dev/null +++ b/torch-ngp/raymarching/src/raymarching.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + + +void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars); +void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords); +void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices); +void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords); +void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield); + +void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises); +void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); +void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs); + +void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises); +void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); \ No newline at end of file diff --git a/torch-ngp/readme.md b/torch-ngp/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..cd88cbc6a2ef96841ab8bb3a37b9d08f8347f5a5 --- /dev/null +++ b/torch-ngp/readme.md @@ -0,0 +1,308 @@ +# torch-ngp + +This repository contains: +* A pytorch implementation of the SDF and NeRF part (grid encoder, density grid ray sampler) in [instant-ngp](https://github.com/NVlabs/instant-ngp), as described in [_Instant Neural Graphics Primitives with a Multiresolution Hash Encoding_](https://nvlabs.github.io/instant-ngp/assets/mueller2022instant.pdf). +* A pytorch implementation of [TensoRF](https://github.com/apchenstu/TensoRF), as described in [_TensoRF: Tensorial Radiance Fields_](https://arxiv.org/abs/2203.09517), adapted to instant-ngp's NeRF framework. +* A pytorch implementation of [CCNeRF](https://github.com/ashawkey/CCNeRF), as described in [_Compressible-composable NeRF via Rank-residual Decomposition_](https://arxiv.org/abs/2205.14870). +* [New!] An implementation of [D-NeRF](https://github.com/albertpumarola/D-NeRF) adapted to instant-ngp's framework, as described in [_D-NeRF: Neural Radiance Fields for Dynamic Scenes_](https://openaccess.thecvf.com/content/CVPR2021/papers/Pumarola_D-NeRF_Neural_Radiance_Fields_for_Dynamic_Scenes_CVPR_2021_paper.pdf). +* Some experimental features in the NeRF framework (e.g., text-guided NeRF editig similar to [CLIP-NeRF](https://arxiv.org/abs/2112.05139)). +* A GUI for training/visualizing NeRF! + +**News**: A clean and improved version focusing on static NeRF reconstruction of realistic scenes has been separated into [nerf_template](https://github.com/ashawkey/nerf_template), as this repository has been hard to maintain. + +### [Gallery](assets/gallery.md) | [Update Logs](assets/update_logs.md) + +Instant-ngp interactive training/rendering on lego: + +https://user-images.githubusercontent.com/25863658/176174011-e7b7c4ab-9b6f-4f65-9952-7eceafe609b7.mp4 + +Also the first interactive deformable-nerf implementation: + +https://user-images.githubusercontent.com/25863658/175821784-63ba79f6-29be-47b5-b3fc-dab5282fce7a.mp4 + + +### Other related projects + +* [ngp_pl](https://github.com/kwea123/ngp_pl): PyTorch+CUDA trained with pytorch-lightning. + +* [JNeRF](https://github.com/Jittor/JNeRF): An NeRF benchmark based on Jittor. + +* [HashNeRF-pytorch](https://github.com/yashbhalgat/HashNeRF-pytorch): A pure PyTorch implementation. + +* [dreamfields-torch](https://github.com/ashawkey/dreamfields-torch): PyTorch+CUDA implementation of [_Zero-Shot Text-Guided Object Generation with Dream Fields_](https://arxiv.org/abs/2112.01455) based on this repository. + +# Install +```bash +git clone --recursive https://github.com/ashawkey/torch-ngp.git +cd torch-ngp +``` + +### Install with pip +```bash +pip install -r requirements.txt + +# (optional) install the tcnn backbone +pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch +``` + +### Install with conda +```bash +conda env create -f environment.yml +conda activate torch-ngp +``` + +### Build extension (optional) +By default, we use [`load`](https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load) to build the extension at runtime. +However, this may be inconvenient sometimes. +Therefore, we also provide the `setup.py` to build each extension: +```bash +# install all extension modules +bash scripts/install_ext.sh + +# if you want to install manually, here is an example: +cd raymarching +python setup.py build_ext --inplace # build ext only, do not install (only can be used in the parent directory) +pip install . # install to python path (you still need the raymarching/ folder, since this only install the built extension.) +``` + +### Tested environments +* Ubuntu 20 with torch 1.10 & CUDA 11.3 on a TITAN RTX. +* Ubuntu 16 with torch 1.8 & CUDA 10.1 on a V100. +* Windows 10 with torch 1.11 & CUDA 11.3 on a RTX 3070. + +Currently, `--ff` only supports GPUs with CUDA architecture `>= 70`. +For GPUs with lower architecture, `--tcnn` can still be used, but the speed will be slower compared to more recent GPUs. + + +# Usage + +We use the same data format as instant-ngp, e.g., [armadillo](https://github.com/NVlabs/instant-ngp/blob/master/data/sdf/armadillo.obj) and [fox](https://github.com/NVlabs/instant-ngp/tree/master/data/nerf/fox). +Please download and put them under `./data`. + +We also support self-captured dataset and converting other formats (e.g., LLFF, Tanks&Temples, Mip-NeRF 360) to the nerf-compatible format, with details in the following code block. + +
+ Supported datasets + + * [nerf_synthetic](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) + + * [Tanks&Temples](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip): [[conversion script]](./scripts/tanks2nerf.py) + + * [LLFF](https://drive.google.com/drive/folders/14boI-o5hGO9srnWaaogTU5_ji7wkX2S7): [[conversion script]](./scripts/llff2nerf.py) + + * [Mip-NeRF 360](http://storage.googleapis.com/gresearch/refraw360/360_v2.zip): [[conversion script]](./scripts/llff2nerf.py) + + * (dynamic) [D-NeRF](https://www.dropbox.com/s/0bf6fl0ye2vz3vr/data.zip?dl=0) + + * (dynamic) [Hyper-NeRF](https://github.com/google/hypernerf/releases/tag/v0.1): [[conversion script]](./scripts/hyper2nerf.py) + +
+ +First time running will take some time to compile the CUDA extensions. + +```bash +### Instant-ngp NeRF +# train with different backbones (with slower pytorch ray marching) +# for the colmap dataset, the default dataset setting `--bound 2 --scale 0.33` is used. +python main_nerf.py data/fox --workspace trial_nerf # fp32 mode +python main_nerf.py data/fox --workspace trial_nerf --fp16 # fp16 mode (pytorch amp) +python main_nerf.py data/fox --workspace trial_nerf --fp16 --ff # fp16 mode + FFMLP (this repo's implementation) +python main_nerf.py data/fox --workspace trial_nerf --fp16 --tcnn # fp16 mode + official tinycudann's encoder & MLP + +# use CUDA to accelerate ray marching (much more faster!) +python main_nerf.py data/fox --workspace trial_nerf --fp16 --cuda_ray # fp16 mode + cuda raymarching + +# preload data into GPU, accelerate training but use more GPU memory. +python main_nerf.py data/fox --workspace trial_nerf --fp16 --preload + +# one for all: -O means --fp16 --cuda_ray --preload, which usually gives the best results balanced on speed & performance. +python main_nerf.py data/fox --workspace trial_nerf -O + +# test mode +python main_nerf.py data/fox --workspace trial_nerf -O --test + +# construct an error_map for each image, and sample rays based on the training error (slow down training but get better performance with the same number of training steps) +python main_nerf.py data/fox --workspace trial_nerf -O --error_map + +# use a background model (e.g., a sphere with radius = 32), can supress noises for real-world 360 dataset +python main_nerf.py data/firekeeper --workspace trial_nerf -O --bg_radius 32 + +# start a GUI for NeRF training & visualization +# always use with `--fp16 --cuda_ray` for an acceptable framerate! +python main_nerf.py data/fox --workspace trial_nerf -O --gui + +# test mode for GUI +python main_nerf.py data/fox --workspace trial_nerf -O --gui --test + +# for the blender dataset, you should add `--bound 1.0 --scale 0.8 --dt_gamma 0` +# --bound means the scene is assumed to be inside box[-bound, bound] +# --scale adjusts the camera locaction to make sure it falls inside the above bounding box. +# --dt_gamma controls the adaptive ray marching speed, set to 0 turns it off. +python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf -O --bound 1.0 --scale 0.8 --dt_gamma 0 +python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf -O --bound 1.0 --scale 0.8 --dt_gamma 0 --gui + +# for the LLFF dataset, you should first convert it to nerf-compatible format: +python scripts/llff2nerf.py data/nerf_llff_data/fern # by default it use full-resolution images, and write `transforms.json` to the folder +python scripts/llff2nerf.py data/nerf_llff_data/fern --images images_4 --downscale 4 # if you prefer to use the low-resolution images +# then you can train as a colmap dataset (you'll need to tune the scale & bound if necessary): +python main_nerf.py data/nerf_llff_data/fern --workspace trial_nerf -O +python main_nerf.py data/nerf_llff_data/fern --workspace trial_nerf -O --gui + +# for the Tanks&Temples dataset, you should first convert it to nerf-compatible format: +python scripts/tanks2nerf.py data/TanksAndTemple/Family # write `trainsforms_{split}.json` for [train, val, test] +# then you can train as a blender dataset (you'll need to tune the scale & bound if necessary) +python main_nerf.py data/TanksAndTemple/Family --workspace trial_nerf_family -O --bound 1.0 --scale 0.33 --dt_gamma 0 +python main_nerf.py data/TanksAndTemple/Family --workspace trial_nerf_family -O --bound 1.0 --scale 0.33 --dt_gamma 0 --gui + +# for custom dataset, you should: +# 1. take a video / many photos from different views +# 2. put the video under a path like ./data/custom/video.mp4 or the images under ./data/custom/images/*.jpg. +# 3. call the preprocess code: (should install ffmpeg and colmap first! refer to the file for more options) +python scripts/colmap2nerf.py --video ./data/custom/video.mp4 --run_colmap # if use video +python scripts/colmap2nerf.py --images ./data/custom/images/ --run_colmap # if use images +python scripts/colmap2nerf.py --video ./data/custom/video.mp4 --run_colmap --dynamic # if the scene is dynamic (for D-NeRF settings), add the time for each frame. +# 4. it should create the transform.json, and you can train with: (you'll need to try with different scale & bound & dt_gamma to make the object correctly located in the bounding box and render fluently.) +python main_nerf.py data/custom --workspace trial_nerf_custom -O --gui --scale 2.0 --bound 1.0 --dt_gamma 0.02 + +### Instant-ngp SDF +python main_sdf.py data/armadillo.obj --workspace trial_sdf +python main_sdf.py data/armadillo.obj --workspace trial_sdf --fp16 +python main_sdf.py data/armadillo.obj --workspace trial_sdf --fp16 --ff +python main_sdf.py data/armadillo.obj --workspace trial_sdf --fp16 --tcnn + +python main_sdf.py data/armadillo.obj --workspace trial_sdf --fp16 --test + +### TensoRF +# almost the same as Instant-ngp NeRF, just replace the main script. +python main_tensoRF.py data/fox --workspace trial_tensoRF -O +python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF -O --bound 1.0 --scale 0.8 --dt_gamma 0 + +### CCNeRF +# training on single objects, turn on --error_map for better quality. +python main_CCNeRF.py data/nerf_synthetic/chair --workspace trial_cc_chair -O --bound 1.0 --scale 0.67 --dt_gamma 0 --error_map +python main_CCNeRF.py data/nerf_synthetic/ficus --workspace trial_cc_ficus -O --bound 1.0 --scale 0.67 --dt_gamma 0 --error_map +python main_CCNeRF.py data/nerf_synthetic/hotdog --workspace trial_cc_hotdog -O --bound 1.0 --scale 0.67 --dt_gamma 0 --error_map +# compose, use a larger bound and more samples per ray for better quality. +python main_CCNeRF.py data/nerf_synthetic/hotdog --workspace trial_cc_hotdog -O --bound 2.0 --scale 0.67 --dt_gamma 0 --max_steps 2048 --test --compose +# compose + gui, only about 1 FPS without dynamic resolution... just for quick verification of composition results. +python main_CCNeRF.py data/nerf_synthetic/hotdog --workspace trial_cc_hotdog -O --bound 2.0 --scale 0.67 --dt_gamma 0 --test --compose --gui + +### D-NeRF +# almost the same as Instant-ngp NeRF, just replace the main script. +# use deformation to model dynamic scene +python main_dnerf.py data/dnerf/jumpingjacks --workspace trial_dnerf_jumpingjacks -O --bound 1.0 --scale 0.8 --dt_gamma 0 +python main_dnerf.py data/dnerf/jumpingjacks --workspace trial_dnerf_jumpingjacks -O --bound 1.0 --scale 0.8 --dt_gamma 0 --gui +# use temporal basis to model dynamic scene +python main_dnerf.py data/dnerf/jumpingjacks --workspace trial_dnerf_basis_jumpingjacks -O --bound 1.0 --scale 0.8 --dt_gamma 0 --basis +python main_dnerf.py data/dnerf/jumpingjacks --workspace trial_dnerf_basis_jumpingjacks -O --bound 1.0 --scale 0.8 --dt_gamma 0 --basis --gui +# for the hypernerf dataset, first convert it into nerf-compatible format: +python scripts/hyper2nerf.py data/split-cookie --downscale 2 # will generate transforms*.json +python main_dnerf.py data/split-cookie/ --workspace trial_dnerf_cookies -O --bound 1 --scale 0.3 --dt_gamma 0 +``` + +check the `scripts` directory for more provided examples. + +# Performance Reference + +Tested with the default settings on the Lego dataset. +Here the speed refers to the `iterations per second` on a V100. + +| Model | Split | PSNR | Train Speed | Test Speed | +| - | - | - | - | - | +| instant-ngp (paper) | trainval? | 36.39 | - | - | +| instant-ngp (`-O`) | train (30K steps) | 34.15 | 97 | 7.8 | +| instant-ngp (`-O --error_map`) | train (30K steps) | 34.88 | 50 | 7.8 | +| instant-ngp (`-O`) | trainval (40k steps) | 35.22 | 97 | 7.8 | +| instant-ngp (`-O --error_map`) | trainval (40k steps) | 36.00 | 50 | 7.8 | +| TensoRF (paper) | train (30K steps) | 36.46 | - | - | +| TensoRF (`-O`) | train (30K steps) | 35.05 | 51 | 2.8 | +| TensoRF (`-O --error_map`) | train (30K steps) | 35.84 | 14 | 2.8 | + +# Tips + +**Q**: How to choose the network backbone? + +**A**: The `-O` flag which uses pytorch's native mixed precision is suitable for most cases. I don't find very significant improvement for `--tcnn` and `--ff`, and they require extra building. Also, some new features may only be available for the default `-O` mode. + +**Q**: CUDA Out Of Memory for my dataset. + +**A**: You could try to turn off `--preload` which loads all images in to GPU for acceleration (if use `-O`, change it to `--fp16 --cuda_ray`). Another solution is to manually set `downscale` in `NeRFDataset` to lower the image resolution. + +**Q**: How to adjust `bound` and `scale`? + +**A**: You could start with a large `bound` (e.g., 16) or a small `scale` (e.g., 0.3) to make sure the object falls into the bounding box. The GUI mode can be used to interactively shrink the `bound` to find the suitable value. Uncommenting [this line](https://github.com/ashawkey/torch-ngp/blob/main/nerf/provider.py#L219) will visualize the camera poses, and some good examples can be found in [this issue](https://github.com/ashawkey/torch-ngp/issues/59). + +**Q**: Noisy novel views for realistic datasets. + +**A**: You could try setting `bg_radius` to a large value, e.g., 32. It trains an extra environment map to model the background in realistic photos. A larger `bound` will also help. +An example for `bg_radius` in the [firekeeper](https://drive.google.com/file/d/19C0K6_crJ5A9ftHijUmJysxmY-G4DMzq/view?usp=sharing) dataset: +![bg_model](./assets/bg_model.jpg) + + +# Difference from the original implementation + +* Instead of assuming the scene is bounded in the unit box `[0, 1]` and centered at `(0.5, 0.5, 0.5)`, this repo assumes **the scene is bounded in box `[-bound, bound]`, and centered at `(0, 0, 0)`**. Therefore, the functionality of `aabb_scale` is replaced by `bound` here. +* For the hashgrid encoder, this repo only implements the linear interpolation mode. +* For TensoRF, we don't implement regularizations other than L1, and use `trunc_exp` as the density activation instead of `softplus`. The alpha mask pruning is replaced by the density grid sampler from instant-ngp, which shares the same logic for acceleration. + + +# Citation + +If you find this work useful, a citation will be appreciated via: +``` +@misc{torch-ngp, + Author = {Jiaxiang Tang}, + Year = {2022}, + Note = {https://github.com/ashawkey/torch-ngp}, + Title = {Torch-ngp: a PyTorch implementation of instant-ngp} +} + +@article{tang2022compressible, + title = {Compressible-composable NeRF via Rank-residual Decomposition}, + author = {Tang, Jiaxiang and Chen, Xiaokang and Wang, Jingbo and Zeng, Gang}, + journal = {arXiv preprint arXiv:2205.14870}, + year = {2022} +} +``` + +# Acknowledgement + +* Credits to [Thomas Müller](https://tom94.net/) for the amazing [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn) and [instant-ngp](https://github.com/NVlabs/instant-ngp): + ``` + @misc{tiny-cuda-nn, + Author = {Thomas M\"uller}, + Year = {2021}, + Note = {https://github.com/nvlabs/tiny-cuda-nn}, + Title = {Tiny {CUDA} Neural Network Framework} + } + + @article{mueller2022instant, + title = {Instant Neural Graphics Primitives with a Multiresolution Hash Encoding}, + author = {Thomas M\"uller and Alex Evans and Christoph Schied and Alexander Keller}, + journal = {arXiv:2201.05989}, + year = {2022}, + month = jan + } + ``` + +* The framework of NeRF is adapted from [nerf_pl](https://github.com/kwea123/nerf_pl): + ``` + @misc{queianchen_nerf, + author = {Quei-An, Chen}, + title = {Nerf_pl: a pytorch-lightning implementation of NeRF}, + url = {https://github.com/kwea123/nerf_pl/}, + year = {2020}, + } + ``` + +* The official TensoRF [implementation](https://github.com/apchenstu/TensoRF): + ``` + @article{TensoRF, + title={TensoRF: Tensorial Radiance Fields}, + author={Chen, Anpei and Xu, Zexiang and Geiger, Andreas and Yu, Jingyi and Su, Hao}, + journal={arXiv preprint arXiv:2203.09517}, + year={2022} + } + ``` + +* The NeRF GUI is developed with [DearPyGui](https://github.com/hoffstadt/DearPyGui). diff --git a/torch-ngp/requirements.txt b/torch-ngp/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..bbf1fb86d45fc6740ca974097c98f6504bc4fc90 --- /dev/null +++ b/torch-ngp/requirements.txt @@ -0,0 +1,19 @@ +torch-ema +ninja +trimesh +opencv-python +tensorboardX +torch +numpy +pandas +tqdm +matplotlib +PyMCubes +rich +pysdf +dearpygui +packaging +scipy +lpips +imageio +torchmetrics \ No newline at end of file diff --git a/torch-ngp/scripts/colmap2nerf.py b/torch-ngp/scripts/colmap2nerf.py new file mode 100644 index 0000000000000000000000000000000000000000..b6c1b3d44e6c8841f60bc9bfa4f35b6a5f4905d4 --- /dev/null +++ b/torch-ngp/scripts/colmap2nerf.py @@ -0,0 +1,368 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import argparse +import os +from pathlib import Path + +import numpy as np +import json +import sys +import math +import cv2 +import os +import shutil + +def parse_args(): + parser = argparse.ArgumentParser(description="convert a text colmap export to nerf format transforms.json; optionally convert video to images, and optionally run colmap in the first place") + + parser.add_argument("--video", default="", help="input path to the video") + parser.add_argument("--images", default="", help="input path to the images folder, ignored if --video is provided") + parser.add_argument("--run_colmap", action="store_true", help="run colmap first on the image folder") + + parser.add_argument("--dynamic", action="store_true", help="for dynamic scene, extraly save time calculated from frame index.") + parser.add_argument("--estimate_affine_shape", action="store_true", help="colmap SiftExtraction option, may yield better results, yet can only be run on CPU.") + parser.add_argument('--hold', type=int, default=8, help="hold out for validation every $ images") + + parser.add_argument("--video_fps", default=3) + parser.add_argument("--time_slice", default="", help="time (in seconds) in the format t1,t2 within which the images should be generated from the video. eg: \"--time_slice '10,300'\" will generate images only from 10th second to 300th second of the video") + + parser.add_argument("--colmap_matcher", default="exhaustive", choices=["exhaustive","sequential","spatial","transitive","vocab_tree"], help="select which matcher colmap should use. sequential for videos, exhaustive for adhoc images") + parser.add_argument("--skip_early", default=0, help="skip this many images from the start") + + parser.add_argument("--colmap_text", default="colmap_text", help="input path to the colmap text files (set automatically if run_colmap is used)") + parser.add_argument("--colmap_db", default="colmap.db", help="colmap database filename") + + args = parser.parse_args() + return args + +def do_system(arg): + print(f"==== running: {arg}") + err = os.system(arg) + if err: + print("FATAL: command failed") + sys.exit(err) + +def run_ffmpeg(args): + video = args.video + images = args.images + fps = float(args.video_fps) or 1.0 + + print(f"running ffmpeg with input video file={video}, output image folder={images}, fps={fps}.") + if (input(f"warning! folder '{images}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y": + sys.exit(1) + + try: + shutil.rmtree(images) + except: + pass + + do_system(f"mkdir {images}") + + time_slice_value = "" + time_slice = args.time_slice + if time_slice: + start, end = time_slice.split(",") + time_slice_value = f",select='between(t\,{start}\,{end})'" + + do_system(f"ffmpeg -i {video} -qscale:v 1 -qmin 1 -vf \"fps={fps}{time_slice_value}\" {images}/%04d.jpg") + +def run_colmap(args): + db = args.colmap_db + images = args.images + text = args.colmap_text + flag_EAS = int(args.estimate_affine_shape) # 0 / 1 + + db_noext = str(Path(db).with_suffix("")) + sparse = db_noext + "_sparse" + + print(f"running colmap with:\n\tdb={db}\n\timages={images}\n\tsparse={sparse}\n\ttext={text}") + if (input(f"warning! folders '{sparse}' and '{text}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y": + sys.exit(1) + if os.path.exists(db): + os.remove(db) + do_system(f"colmap feature_extractor --ImageReader.camera_model OPENCV --SiftExtraction.estimate_affine_shape {flag_EAS} --SiftExtraction.domain_size_pooling {flag_EAS} --ImageReader.single_camera 1 --database_path {db} --image_path {images}") + do_system(f"colmap {args.colmap_matcher}_matcher --SiftMatching.guided_matching {flag_EAS} --database_path {db}") + try: + shutil.rmtree(sparse) + except: + pass + do_system(f"mkdir {sparse}") + do_system(f"colmap mapper --database_path {db} --image_path {images} --output_path {sparse}") + do_system(f"colmap bundle_adjuster --input_path {sparse}/0 --output_path {sparse}/0 --BundleAdjustment.refine_principal_point 1") + try: + shutil.rmtree(text) + except: + pass + do_system(f"mkdir {text}") + do_system(f"colmap model_converter --input_path {sparse}/0 --output_path {text} --output_type TXT") + +def variance_of_laplacian(image): + return cv2.Laplacian(image, cv2.CV_64F).var() + +def sharpness(imagePath): + image = cv2.imread(imagePath) + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + fm = variance_of_laplacian(gray) + return fm + +def qvec2rotmat(qvec): + return np.array([ + [ + 1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2] + ], [ + 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1] + ], [ + 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2 + ] + ]) + +def rotmat(a, b): + a, b = a / np.linalg.norm(a), b / np.linalg.norm(b) + v = np.cross(a, b) + c = np.dot(a, b) + # handle exception for the opposite direction input + if c < -1 + 1e-10: + return rotmat(a + np.random.uniform(-1e-2, 1e-2, 3), b) + s = np.linalg.norm(v) + kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) + return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10)) + +def closest_point_2_lines(oa, da, ob, db): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel + da = da / np.linalg.norm(da) + db = db / np.linalg.norm(db) + c = np.cross(da, db) + denom = np.linalg.norm(c)**2 + t = ob - oa + ta = np.linalg.det([t, db, c]) / (denom + 1e-10) + tb = np.linalg.det([t, da, c]) / (denom + 1e-10) + if ta > 0: + ta = 0 + if tb > 0: + tb = 0 + return (oa+ta*da+ob+tb*db) * 0.5, denom + +if __name__ == "__main__": + args = parse_args() + + if args.video != "": + root_dir = os.path.dirname(args.video) + args.images = os.path.join(root_dir, "images") # override args.images + run_ffmpeg(args) + else: + args.images = args.images[:-1] if args.images[-1] == '/' else args.images # remove trailing / (./a/b/ --> ./a/b) + root_dir = os.path.dirname(args.images) + + args.colmap_db = os.path.join(root_dir, args.colmap_db) + args.colmap_text = os.path.join(root_dir, args.colmap_text) + + if args.run_colmap: + run_colmap(args) + + SKIP_EARLY = int(args.skip_early) + TEXT_FOLDER = args.colmap_text + + with open(os.path.join(TEXT_FOLDER, "cameras.txt"), "r") as f: + angle_x = math.pi / 2 + for line in f: + # 1 SIMPLE_RADIAL 2048 1536 1580.46 1024 768 0.0045691 + # 1 OPENCV 3840 2160 3178.27 3182.09 1920 1080 0.159668 -0.231286 -0.00123982 0.00272224 + # 1 RADIAL 1920 1080 1665.1 960 540 0.0672856 -0.0761443 + if line[0] == "#": + continue + els = line.split(" ") + w = float(els[2]) + h = float(els[3]) + fl_x = float(els[4]) + fl_y = float(els[4]) + k1 = 0 + k2 = 0 + p1 = 0 + p2 = 0 + cx = w / 2 + cy = h / 2 + if els[1] == "SIMPLE_PINHOLE": + cx = float(els[5]) + cy = float(els[6]) + elif els[1] == "PINHOLE": + fl_y = float(els[5]) + cx = float(els[6]) + cy = float(els[7]) + elif els[1] == "SIMPLE_RADIAL": + cx = float(els[5]) + cy = float(els[6]) + k1 = float(els[7]) + elif els[1] == "RADIAL": + cx = float(els[5]) + cy = float(els[6]) + k1 = float(els[7]) + k2 = float(els[8]) + elif els[1] == "OPENCV": + fl_y = float(els[5]) + cx = float(els[6]) + cy = float(els[7]) + k1 = float(els[8]) + k2 = float(els[9]) + p1 = float(els[10]) + p2 = float(els[11]) + else: + print("unknown camera model ", els[1]) + # fl = 0.5 * w / tan(0.5 * angle_x); + angle_x = math.atan(w / (fl_x * 2)) * 2 + angle_y = math.atan(h / (fl_y * 2)) * 2 + fovx = angle_x * 180 / math.pi + fovy = angle_y * 180 / math.pi + + print(f"camera:\n\tres={w,h}\n\tcenter={cx,cy}\n\tfocal={fl_x,fl_y}\n\tfov={fovx,fovy}\n\tk={k1,k2} p={p1,p2} ") + + with open(os.path.join(TEXT_FOLDER, "images.txt"), "r") as f: + i = 0 + + bottom = np.array([0.0, 0.0, 0.0, 1.0]).reshape([1, 4]) + + frames = [] + + up = np.zeros(3) + for line in f: + line = line.strip() + + if line[0] == "#": + continue + + i = i + 1 + if i < SKIP_EARLY*2: + continue + + if i % 2 == 1: + elems = line.split(" ") # 1-4 is quat, 5-7 is trans, 9ff is filename (9, if filename contains no spaces) + + name = '_'.join(elems[9:]) + full_name = os.path.join(args.images, name) + rel_name = full_name[len(root_dir) + 1:] + + b = sharpness(full_name) + # print(name, "sharpness =",b) + + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + R = qvec2rotmat(-qvec) + t = tvec.reshape([3, 1]) + m = np.concatenate([np.concatenate([R, t], 1), bottom], 0) + c2w = np.linalg.inv(m) + + c2w[0:3, 2] *= -1 # flip the y and z axis + c2w[0:3, 1] *= -1 + c2w = c2w[[1, 0, 2, 3],:] # swap y and z + c2w[2, :] *= -1 # flip whole world upside down + + up += c2w[0:3, 1] + + frame = { + "file_path": rel_name, + "sharpness": b, + "transform_matrix": c2w + } + + frames.append(frame) + + N = len(frames) + up = up / np.linalg.norm(up) + + print("[INFO] up vector was", up) + + R = rotmat(up, [0, 0, 1]) # rotate up vector to [0,0,1] + R = np.pad(R, [0, 1]) + R[-1, -1] = 1 + + for f in frames: + f["transform_matrix"] = np.matmul(R, f["transform_matrix"]) # rotate up to be the z axis + + # find a central point they are all looking at + print("[INFO] computing center of attention...") + totw = 0.0 + totp = np.array([0.0, 0.0, 0.0]) + for f in frames: + mf = f["transform_matrix"][0:3,:] + for g in frames: + mg = g["transform_matrix"][0:3,:] + p, weight = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2]) + if weight > 0.01: + totp += p * weight + totw += weight + totp /= totw + for f in frames: + f["transform_matrix"][0:3,3] -= totp + avglen = 0. + for f in frames: + avglen += np.linalg.norm(f["transform_matrix"][0:3,3]) + avglen /= N + print("[INFO] avg camera distance from origin", avglen) + for f in frames: + f["transform_matrix"][0:3,3] *= 4.0 / avglen # scale to "nerf sized" + + # sort frames by id + frames.sort(key=lambda d: d['file_path']) + + # add time if scene is dynamic + if args.dynamic: + for i, f in enumerate(frames): + f['time'] = i / N + + for f in frames: + f["transform_matrix"] = f["transform_matrix"].tolist() + + # construct frames + + def write_json(filename, frames): + + out = { + "camera_angle_x": angle_x, + "camera_angle_y": angle_y, + "fl_x": fl_x, + "fl_y": fl_y, + "k1": k1, + "k2": k2, + "p1": p1, + "p2": p2, + "cx": cx, + "cy": cy, + "w": w, + "h": h, + "frames": frames, + } + + output_path = os.path.join(root_dir, filename) + print(f"[INFO] writing {len(frames)} frames to {output_path}") + with open(output_path, "w") as outfile: + json.dump(out, outfile, indent=2) + + # just one transforms.json, don't do data split + if args.hold <= 0: + + write_json('transforms.json', frames) + + else: + all_ids = np.arange(N) + test_ids = all_ids[::args.hold] + train_ids = np.array([i for i in all_ids if i not in test_ids]) + + frames_train = [f for i, f in enumerate(frames) if i in train_ids] + frames_test = [f for i, f in enumerate(frames) if i in test_ids] + + write_json('transforms_train.json', frames_train) + write_json('transforms_val.json', frames_test[::10]) + write_json('transforms_test.json', frames_test) \ No newline at end of file diff --git a/torch-ngp/scripts/hyper2nerf.py b/torch-ngp/scripts/hyper2nerf.py new file mode 100644 index 0000000000000000000000000000000000000000..5a868703362dd0792250ba1ae1c997f50a519c3c --- /dev/null +++ b/torch-ngp/scripts/hyper2nerf.py @@ -0,0 +1,224 @@ +import os +import numpy as np +import math +import json +import argparse +import trimesh + + +def visualize_poses(poses, size=0.1): + # poses: [B, 4, 4] + + axes = trimesh.creation.axis(axis_length=4) + box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline() + box.colors = np.array([[128, 128, 128]] * len(box.entities)) + objects = [axes, box] + + for pose in poses: + # a camera is visualized with 8 line segments. + pos = pose[:3, 3] + a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] + b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] + c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] + d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] + + dir = (a + b + c + d) / 4 - pos + dir = dir / (np.linalg.norm(dir) + 1e-8) + o = pos + dir * 3 + + segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a], [pos, o]]) + segs = trimesh.load_path(segs) + objects.append(segs) + + trimesh.Scene(objects).show() + +# returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel +def closest_point_2_lines(oa, da, ob, db): + da = da / np.linalg.norm(da) + db = db / np.linalg.norm(db) + c = np.cross(da, db) + denom = np.linalg.norm(c)**2 + t = ob - oa + ta = np.linalg.det([t, db, c]) / (denom + 1e-10) + tb = np.linalg.det([t, da, c]) / (denom + 1e-10) + if ta > 0: + ta = 0 + if tb > 0: + tb = 0 + return (oa+ta*da+ob+tb*db) * 0.5, denom + +def rotmat(a, b): + a, b = a / np.linalg.norm(a), b / np.linalg.norm(b) + v = np.cross(a, b) + c = np.dot(a, b) + # handle exception for the opposite direction input + if c < -1 + 1e-10: + return rotmat(a + np.random.uniform(-1e-2, 1e-2, 3), b) + s = np.linalg.norm(v) + kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) + return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10)) + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('path', type=str, help="root directory to the HyperNeRF dataset (contains camera/, rgb/, dataset.json, scene.json)") + parser.add_argument('--downscale', type=int, default=2, help="image size down scale, choose from [2, 4, 8, 16], e.g., 8") + parser.add_argument('--interval', type=int, default=4, help="used for interp dataset's train/val split, should > 2 and be even") + + opt = parser.parse_args() + + print(f'[INFO] process {opt.path}') + + # load data + with open(os.path.join(opt.path, 'dataset.json'), 'r') as f: + json_dataset = json.load(f) + + names = json_dataset['ids'] + val_names = json_dataset['val_ids'] + + # data split mode following hypernerf (vrig / interp) + if len(val_names) > 0: + train_names = json_dataset['train_ids'] + val_ids = [] + train_ids = [] + for i, name in enumerate(names): + if name in val_names: + val_ids.append(i) + elif name in train_names: + train_ids.append(i) + else: + all_ids = np.arange(len(names)) + train_ids = all_ids[::opt.interval] + val_ids = (train_ids[:-1] + train_ids[1:]) // 2 + + print(f'[INFO] train_ids: {len(train_ids)}, val_ids: {len(val_ids)}') + + with open(os.path.join(opt.path, 'scene.json'), 'r') as f: + json_scene = json.load(f) + + scale = json_scene['scale'] + center = json_scene['center'] + + with open(os.path.join(opt.path, 'metadata.json'), 'r') as f: + json_meta = json.load(f) + + images = [] + times = [] + poses = [] + H, W, f, cx, cy = None, None, None, None, None + + for name in names: + + # load image + images.append(os.path.join('rgb', f'{opt.downscale}x', f'{name}.png')) + + # load time + times.append(json_meta[name]['time_id']) + + # load pose + with open(os.path.join(opt.path, 'camera', f'{name}.json'), 'r') as f: + cam = json.load(f) + + # TODO: we use a simplified pinhole camera model rather than the original openCV camera model... hope it won't influence results seriously... + + pose = np.eye(4, 4) + pose[:3, :3] = np.array(cam['orientation']).T # it works... + #pose[:3, 3] = (np.array(cam['position']) - center) * scale * 4 + pose[:3, 3] = np.array(cam['position']) + + # CHECK: simply assume all intrinsic are same ? + W, H = cam['image_size'] # before scale + cx, cy = cam['principal_point'] + fl = cam['focal_length'] + + poses.append(pose) + + poses = np.stack(poses, axis=0) # [N, 4, 4] + times = np.asarray(times, dtype=np.float32) # [N] + times = times / times.max() # normalize to [0, 1] + + N = len(images) + + W = W // opt.downscale + H = H // opt.downscale + cx = cx / opt.downscale + cy = cy / opt.downscale + fl = fl / opt.downscale + + print(f'[INFO] H = {H}, W = {W}, fl = {fl} (downscale = {opt.downscale})') + + # visualize_poses(poses) + + # the following stuff are from colmap2nerf... + poses[:, 0:3, 1] *= -1 + poses[:, 0:3, 2] *= -1 + poses = poses[:, [1, 0, 2, 3], :] # swap y and z + poses[:, 2, :] *= -1 # flip whole world upside down + + up = poses[:, 0:3, 1].sum(0) + up = up / np.linalg.norm(up) + R = rotmat(up, [0, 0, 1]) # rotate up vector to [0,0,1] + R = np.pad(R, [0, 1]) + R[-1, -1] = 1 + + poses = R @ poses + + totw = 0.0 + totp = np.array([0.0, 0.0, 0.0]) + for i in range(N): + mf = poses[i, :3, :] + for j in range(i + 1, N): + mg = poses[j, :3, :] + p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2]) + #print(i, j, p, w) + if w > 0.01: + totp += p * w + totw += w + totp /= totw + print(f'[INFO] totp = {totp}') + poses[:, :3, 3] -= totp + avglen = np.linalg.norm(poses[:, :3, 3], axis=-1).mean() + poses[:, :3, 3] *= 4.0 / avglen + print(f'[INFO] average radius = {avglen}') + + # visualize_poses(poses) + + # construct frames + frames_train = [] + for i in train_ids: + frames_train.append({ + 'file_path': images[i], + 'time': float(times[i]), + 'transform_matrix': poses[i].tolist(), + }) + + frames_val = [] + for i in val_ids: + frames_val.append({ + 'file_path': images[i], + 'time': float(times[i]), + 'transform_matrix': poses[i].tolist(), + }) + + def write_json(filename, frames): + + # construct a transforms.json + out = { + 'w': W, + 'h': H, + 'fl_x': fl, + 'fl_y': fl, + 'cx': cx, + 'cy': cy, + 'frames': frames, + } + + # write + output_path = os.path.join(opt.path, filename) + print(f'[INFO] write {len(frames)} images to {output_path}') + with open(output_path, 'w') as f: + json.dump(out, f, indent=2) + + write_json('transforms_train.json', frames_train) + write_json('transforms_val.json', frames_val[::10]) + write_json('transforms_test.json', frames_val) \ No newline at end of file diff --git a/torch-ngp/scripts/install_ext.sh b/torch-ngp/scripts/install_ext.sh new file mode 100644 index 0000000000000000000000000000000000000000..e34808ab8f8af19c8cec69ec97ae232e42efe85e --- /dev/null +++ b/torch-ngp/scripts/install_ext.sh @@ -0,0 +1,10 @@ +pip install ./raymarching + +pip install ./gridencoder + +pip install ./shencoder + +pip install ./freqencoder + +# turned off by default, very slow to compile, and performance is not good enough. +#pip install ./ffmlp \ No newline at end of file diff --git a/torch-ngp/scripts/llff2nerf.py b/torch-ngp/scripts/llff2nerf.py new file mode 100644 index 0000000000000000000000000000000000000000..ac39a72b363583922267f1d653d044cd93fd3fa9 --- /dev/null +++ b/torch-ngp/scripts/llff2nerf.py @@ -0,0 +1,183 @@ +import os +import glob +import numpy as np +import math +import json +import trimesh +import argparse + +# returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel +def closest_point_2_lines(oa, da, ob, db): + da = da / np.linalg.norm(da) + db = db / np.linalg.norm(db) + c = np.cross(da, db) + denom = np.linalg.norm(c)**2 + t = ob - oa + ta = np.linalg.det([t, db, c]) / (denom + 1e-10) + tb = np.linalg.det([t, da, c]) / (denom + 1e-10) + if ta > 0: + ta = 0 + if tb > 0: + tb = 0 + return (oa+ta*da+ob+tb*db) * 0.5, denom + +def rotmat(a, b): + a, b = a / np.linalg.norm(a), b / np.linalg.norm(b) + v = np.cross(a, b) + c = np.dot(a, b) + # handle exception for the opposite direction input + if c < -1 + 1e-10: + return rotmat(a + np.random.uniform(-1e-2, 1e-2, 3), b) + s = np.linalg.norm(v) + kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) + return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10)) + + +def visualize_poses(poses, size=0.1): + # poses: [B, 4, 4] + + axes = trimesh.creation.axis(axis_length=4) + box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline() + box.colors = np.array([[128, 128, 128]] * len(box.entities)) + objects = [axes, box] + + for pose in poses: + # a camera is visualized with 8 line segments. + pos = pose[:3, 3] + a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] + b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] + c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] + d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] + + dir = (a + b + c + d) / 4 - pos + dir = dir / (np.linalg.norm(dir) + 1e-8) + o = pos + dir * 3 + + segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a], [pos, o]]) + segs = trimesh.load_path(segs) + objects.append(segs) + + trimesh.Scene(objects).show() + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('path', type=str, help="root directory to the LLFF dataset (contains images/ and pose_bounds.npy)") + parser.add_argument('--images', type=str, default='images_8', help="images folder (do not include full path, e.g., just use `images_4`)") + parser.add_argument('--downscale', type=float, default=8, help="image size down scale, e.g., 4") + parser.add_argument('--hold', type=int, default=8, help="hold out for validation every $ images") + + opt = parser.parse_args() + print(f'[INFO] process {opt.path}') + + # path must end with / to make sure image path is relative + if opt.path[-1] != '/': + opt.path += '/' + + # load data + images = [f[len(opt.path):] for f in sorted(glob.glob(os.path.join(opt.path, opt.images, "*"))) if f.lower().endswith('png') or f.lower().endswith('jpg') or f.lower().endswith('jpeg')] + + poses_bounds = np.load(os.path.join(opt.path, 'poses_bounds.npy')) + N = poses_bounds.shape[0] + + print(f'[INFO] loaded {len(images)} images, {N} poses_bounds as {poses_bounds.shape}') + + assert N == len(images) + + poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N, 3, 5) + bounds = poses_bounds[:, -2:] # (N, 2) + + H, W, fl = poses[0, :, -1] + + H = H // opt.downscale + W = W // opt.downscale + fl = fl / opt.downscale + + print(f'[INFO] H = {H}, W = {W}, fl = {fl} (downscale = {opt.downscale})') + + # inversion of this: https://github.com/Fyusion/LLFF/blob/c6e27b1ee59cb18f054ccb0f87a90214dbe70482/llff/poses/pose_utils.py#L51 + poses = np.concatenate([poses[..., 1:2], poses[..., 0:1], -poses[..., 2:3], poses[..., 3:4]], -1) # (N, 3, 4) + + # to homogeneous + last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N, 1, 4) + poses = np.concatenate([poses, last_row], axis=1) # (N, 4, 4) + + # visualize_poses(poses) + + # the following stuff are from colmap2nerf... [flower fails, the camera must be in-ward...] + poses[:, 0:3, 1] *= -1 + poses[:, 0:3, 2] *= -1 + poses = poses[:, [1, 0, 2, 3], :] # swap y and z + poses[:, 2, :] *= -1 # flip whole world upside down + + up = poses[:, 0:3, 1].sum(0) + up = up / np.linalg.norm(up) + R = rotmat(up, [0, 0, 1]) # rotate up vector to [0,0,1] + R = np.pad(R, [0, 1]) + R[-1, -1] = 1 + + poses = R @ poses + + totw = 0.0 + totp = np.array([0.0, 0.0, 0.0]) + for i in range(N): + mf = poses[i, :3, :] + for j in range(i + 1, N): + mg = poses[j, :3, :] + p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2]) + #print(i, j, p, w) + if w > 0.01: + totp += p * w + totw += w + totp /= totw + print(f'[INFO] totp = {totp}') + poses[:, :3, 3] -= totp + avglen = np.linalg.norm(poses[:, :3, 3], axis=-1).mean() + poses[:, :3, 3] *= 4.0 / avglen + print(f'[INFO] average radius = {avglen}') + + # visualize_poses(poses) + + # construct frames + + all_ids = np.arange(N) + test_ids = all_ids[::opt.hold] + train_ids = np.array([i for i in all_ids if i not in test_ids]) + + frames_train = [] + frames_test = [] + for i in train_ids: + frames_train.append({ + 'file_path': images[i], + 'transform_matrix': poses[i].tolist(), + }) + for i in test_ids: + frames_test.append({ + 'file_path': images[i], + 'transform_matrix': poses[i].tolist(), + }) + + def write_json(filename, frames): + + # construct a transforms.json + out = { + 'w': W, + 'h': H, + 'fl_x': fl, + 'fl_y': fl, + 'cx': W // 2, + 'cy': H // 2, + 'aabb_scale': 2, + 'frames': frames, + } + + # write + output_path = os.path.join(opt.path, filename) + print(f'[INFO] write {len(frames)} images to {output_path}') + with open(output_path, 'w') as f: + json.dump(out, f, indent=2) + + write_json('transforms_train.json', frames_train) + write_json('transforms_val.json', frames_test[::10]) + write_json('transforms_test.json', frames_test) + diff --git a/torch-ngp/scripts/run_ccnerf.sh b/torch-ngp/scripts/run_ccnerf.sh new file mode 100644 index 0000000000000000000000000000000000000000..5caec7898f718dabeababbddba0f5bc27bfdfb2e --- /dev/null +++ b/torch-ngp/scripts/run_ccnerf.sh @@ -0,0 +1,10 @@ +#! /bin/bash + +# train single objects +OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=7 python main_CCNeRF.py data/nerf_synthetic/ficus --workspace trial_cc_ficus -O --bound 1.0 --scale 0.67 --dt_gamma 0 --error_map +OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=7 python main_CCNeRF.py data/nerf_synthetic/chair --workspace trial_cc_chair -O --bound 1.0 --scale 0.67 --dt_gamma 0 --error_map +OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=7 python main_CCNeRF.py data/nerf_synthetic/hotdog --workspace trial_cc_hotdog -O --bound 1.0 --scale 0.67 --dt_gamma 0 --error_map + +# compose +# use more samples per ray (--max_steps) and a larger bound for better results +OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=7 python main_CCNeRF.py data/nerf_synthetic/hotdog --workspace trial_cc_hotdog -O --bound 2.0 --scale 0.67 --dt_gamma 0 --max_steps 2048 --test --compose \ No newline at end of file diff --git a/torch-ngp/scripts/run_dnerf.sh b/torch-ngp/scripts/run_dnerf.sh new file mode 100644 index 0000000000000000000000000000000000000000..68d3c850b75eec349d38ade446820a03a88c8167 --- /dev/null +++ b/torch-ngp/scripts/run_dnerf.sh @@ -0,0 +1,11 @@ +#! /bin/bash + +OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=6 python main_dnerf.py data/dnerf/bouncingballs --workspace trial_dnerf_bouncingballs -O --bound 1 --scale 0.8 --dt_gamma 0 #--gui --test +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=6 python main_dnerf.py data/dnerf/bouncingballs --workspace trial_dnerf_basis_bouncingballs -O --bound 1 --scale 0.8 --dt_gamma 0 --basis #--gui --test + +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=6 python main_dnerf.py data/dnerf/standup --workspace trial_dnerf_standup -O --bound 1 --scale 0.8 --dt_gamma 0 #--gui --test + +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_dnerf.py data/split-cookie/ --workspace trial_dnerf_cookies -O --bound 1 --scale 0.3 #--gui --test +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_dnerf.py data/split-cookie/ --workspace trial_dnerf_cookies_ncr --preload --fp16 --bound 1 --scale 0.3 #--gui --test + +# OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=4 python main_dnerf.py data/vrig-3dprinter/ --workspace trial_dnerf_printer -O --bound 2 --scale 0.33 #--gui --test \ No newline at end of file diff --git a/torch-ngp/scripts/run_gui_nerf.sh b/torch-ngp/scripts/run_gui_nerf.sh new file mode 100644 index 0000000000000000000000000000000000000000..d91e24914d2c4945e1f8efe3921829cab90df784 --- /dev/null +++ b/torch-ngp/scripts/run_gui_nerf.sh @@ -0,0 +1,9 @@ +#! /bin/bash + +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf_fox -O --gui #--error_map +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego -O --bound 1.0 --scale 0.8 --dt_gamma 0 --gui #--error_map +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_llff_data/orchids --workspace trial_nerf_orchids -O --gui --bound 2.0 --scale 0.6 +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/TanksAndTemple/Family --workspace trial_nerf_family -O --bound 1.0 --scale 0.33 --dt_gamma 0 --gui + +OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/figure --workspace trial_nerf_fig -O --gui --bound 1.0 --scale 0.3 --bg_radius 128 +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=5 python main_nerf.py data/vasedeck --workspace trial_nerf_vase -O --gui --bound 4.0 --scale 0.3 \ No newline at end of file diff --git a/torch-ngp/scripts/run_gui_nerf_clip.sh b/torch-ngp/scripts/run_gui_nerf_clip.sh new file mode 100644 index 0000000000000000000000000000000000000000..8531a7b3309e5645198953bc608c699a696b02b4 --- /dev/null +++ b/torch-ngp/scripts/run_gui_nerf_clip.sh @@ -0,0 +1,7 @@ +#! /bin/bash + +OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego -O --bound 1.0 --scale 0.67 --dt_gamma 0 --gui +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego -O --bound 1.0 --scale 0.67 --dt_gamma 0 --gui --rand_pose 0 --clip_text "red" --lr 1e-3 --ckpt latest_model + +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_llff_data/orchids --workspace trial_nerf_orchids -O --gui --bound 2.0 --scale 0.6 +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_llff_data/orchids --workspace trial_nerf_orchids -O --gui --bound 2.0 --scale 0.6 --rand_pose 0 --clip_text "blue flower" --lr 1e-3 --ckpt latest_model \ No newline at end of file diff --git a/torch-ngp/scripts/run_gui_tensoRF.sh b/torch-ngp/scripts/run_gui_tensoRF.sh new file mode 100644 index 0000000000000000000000000000000000000000..22e9f92816bbd0f7c4010ad4e78e1b2aed585901 --- /dev/null +++ b/torch-ngp/scripts/run_gui_tensoRF.sh @@ -0,0 +1,9 @@ +#! /bin/bash + +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/fox --workspace trial_tensoRF_fox -O --gui +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF_lego -O --bound 1.0 --scale 0.8 --dt_gamma 0 --gui + +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/fox --workspace trial_tensorCP_fox --cp -O --gui +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensorCP_lego --cp -O --bound 1.0 --scale 0.8 --dt_gamma 0 --gui + +OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/figure --workspace trial_tensoRF_fig -O --gui --scale 0.33 --bound 1.0 --bg_radius 32 \ No newline at end of file diff --git a/torch-ngp/scripts/run_nerf.sh b/torch-ngp/scripts/run_nerf.sh new file mode 100644 index 0000000000000000000000000000000000000000..1ad6e73d26b8ce3a47fb6382b5add8d254b9fc2e --- /dev/null +++ b/torch-ngp/scripts/run_nerf.sh @@ -0,0 +1,12 @@ +#! /bin/bash + +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_nerf.py data/fox --workspace trial_nerf_fox -O +OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego -O --bound 1 --scale 0.8 --dt_gamma 0 +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego_emap -O --bound 1 --scale 0.8 --dt_gamma 0 --error_map +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_nerf.py data/TanksAndTemple/Barn --workspace trial_nerf_barn -O --bound 1.0 --scale 0.33 --dt_gamma 0 + +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_nerf.py data/firekeeper --workspace trial_nerf_firekeeper_bg_32 -O --bound 1.0 --scale 0.33 --bg_radius 32 #--gui #--test +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_nerf.py data/garden --workspace trial_nerf_garden_bound_16 --cuda_ray --fp16 --bound 16.0 --scale 0.33 #--gui --test + +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_nerf.py data/vasedeck --workspace trial_nerf_vasedeck -O --bound 4.0 --scale 0.33 #--gui #--test +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_nerf.py data/vasedeck --workspace trial_nerf_vasedeck_bg_32 -O --bound 4.0 --scale 0.33 --bg_radius 32 #--gui #--test \ No newline at end of file diff --git a/torch-ngp/scripts/run_sdf.sh b/torch-ngp/scripts/run_sdf.sh new file mode 100644 index 0000000000000000000000000000000000000000..1cd50eaf37dbea51b993a5fb5f41b7a7a6f64b54 --- /dev/null +++ b/torch-ngp/scripts/run_sdf.sh @@ -0,0 +1,7 @@ +#! /bin/bash + +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_sdf.py data/armadillo.obj --workspace trial_sdf --fp16 +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_sdf.py data/armadillo.obj --workspace trial_sdf_ff --fp16 --ff +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_sdf.py data/armadillo.obj --workspace trial_sdf_tcnn --fp16 --tcnn + +OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_sdf.py data/lucy.obj --workspace trial_sdf --fp16 \ No newline at end of file diff --git a/torch-ngp/scripts/run_tensoRF.sh b/torch-ngp/scripts/run_tensoRF.sh new file mode 100644 index 0000000000000000000000000000000000000000..8464b60e3dde268e057eec2c1c9a0c089fde7488 --- /dev/null +++ b/torch-ngp/scripts/run_tensoRF.sh @@ -0,0 +1,11 @@ +#! /bin/bash + + +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_tensoRF.py data/fox --workspace trial_tensoRF_fox -O --error_map +OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF_lego -O --bound 1.0 --scale 0.8 --dt_gamma 0 +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF_lego_emap -O --bound 1.0 --scale 0.8 --dt_gamma 0 --error_map + +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_tensoRF.py data/fox --workspace trial_tensorCP_fox -O --cp --resolution1 500 --error_map +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensorCP_lego --cp --resolution1 500 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --error_map + +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/figure --workspace trial_tensoRF_fig -O --scale 0.33 --bound 1.0 \ No newline at end of file diff --git a/torch-ngp/scripts/tanks2nerf.py b/torch-ngp/scripts/tanks2nerf.py new file mode 100644 index 0000000000000000000000000000000000000000..188dd87707bcc69c035185cf2604aa46cb8050cb --- /dev/null +++ b/torch-ngp/scripts/tanks2nerf.py @@ -0,0 +1,140 @@ +import os +import numpy as np +import math +import json + +import argparse + +# returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel +def closest_point_2_lines(oa, da, ob, db): + da = da / np.linalg.norm(da) + db = db / np.linalg.norm(db) + c = np.cross(da, db) + denom = np.linalg.norm(c)**2 + t = ob - oa + ta = np.linalg.det([t, db, c]) / (denom + 1e-10) + tb = np.linalg.det([t, da, c]) / (denom + 1e-10) + if ta > 0: + ta = 0 + if tb > 0: + tb = 0 + return (oa+ta*da+ob+tb*db) * 0.5, denom + +def rotmat(a, b): + a, b = a / np.linalg.norm(a), b / np.linalg.norm(b) + v = np.cross(a, b) + c = np.dot(a, b) + # handle exception for the opposite direction input + if c < -1 + 1e-10: + return rotmat(a + np.random.uniform(-1e-2, 1e-2, 3), b) + s = np.linalg.norm(v) + kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) + return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10)) + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('path', type=str, help="root directory to the Tanks&Temple dataset (contains rgb/, pose/, intrinsics.txt)") + + opt = parser.parse_args() + print(f'[INFO] process {opt.path}') + + # load data + + intrinsics = np.loadtxt(os.path.join(opt.path, "intrinsics.txt")) + fl_x = intrinsics[0, 0] + fl_y = intrinsics[1, 1] + cx = intrinsics[0, 2] + cy = intrinsics[1, 2] + H = 1080 + W = 1920 + + pose_files = sorted(os.listdir(os.path.join(opt.path, 'pose'))) + img_files = sorted(os.listdir(os.path.join(opt.path, 'rgb'))) + + # read in all poses, and do transform + poses = [] + for pose_f in pose_files: + pose = np.loadtxt(os.path.join(opt.path, 'pose', pose_f)) # [4, 4] + poses.append(pose) + + poses = np.stack(poses, axis=0) # [N, 4, 4] + N = poses.shape[0] + + # the following stuff are from colmap2nerf... + poses[:, 0:3, 1] *= -1 + poses[:, 0:3, 2] *= -1 + poses = poses[:, [1, 0, 2, 3], :] # swap y and z + poses[:, 2, :] *= -1 # flip whole world upside down + + up = poses[:, 0:3, 1].sum(0) + up = up / np.linalg.norm(up) + R = rotmat(up, [0, 0, 1]) # rotate up vector to [0,0,1] + R = np.pad(R, [0, 1]) + R[-1, -1] = 1 + + poses = R @ poses + + totw = 0.0 + totp = np.array([0.0, 0.0, 0.0]) + for i in range(N): + mf = poses[i, :3, :] + for j in range(i + 1, N): + mg = poses[j, :3, :] + p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2]) + #print(i, j, p, w) + if w > 0.01: + totp += p * w + totw += w + totp /= totw + print(f'[INFO] totp = {totp}') + poses[:, :3, 3] -= totp + + avglen = np.linalg.norm(poses[:, :3, 3], axis=-1).mean() + + poses[:, :3, 3] *= 4.0 / avglen + + print(f'[INFO] average radius = {avglen}') + + # process three splits + for split, prefix in zip(['train', 'val', 'test'], ['0_', '1_', '2_']): + + print(f'[INFO] process split = {split}') + + split_poses = [poses[i] for i, x in enumerate(pose_files) if x.startswith(prefix)] + split_images = [x for x in img_files if x.startswith(prefix)] + + if len(split_poses) == 0: + print(f'[INFO] No test data found, use valid as test') + split_poses = [poses[i] for i, x in enumerate(pose_files) if x.startswith('1_')] + split_images = [x for x in img_files if x.startswith('1_')] + + print(f'[INFO] loaded {len(split_images)} images, {len(split_poses)} poses.') + + assert len(split_poses) == len(split_images) + + # construct a transforms.json + frames = [] + for image, pose in zip(split_images, split_poses): + frames.append({ + 'file_path': os.path.join('rgb', image), + 'transform_matrix': pose.tolist(), + }) + + transforms = { + 'w': W, + 'h': H, + 'fl_x': fl_x, + 'fl_y': fl_y, + 'cx': cx, + 'cy': cy, + 'aabb_scale': 2, + 'frames': frames, + } + + # write + output_path = os.path.join(opt.path, f'transforms_{split}.json') + print(f'[INFO] write to {output_path}') + with open(output_path, 'w') as f: + json.dump(transforms, f, indent=2) + diff --git a/torch-ngp/sdf/netowrk.py b/torch-ngp/sdf/netowrk.py new file mode 100644 index 0000000000000000000000000000000000000000..67b55f631d6bbf2777c635b3318ba34b0108bf52 --- /dev/null +++ b/torch-ngp/sdf/netowrk.py @@ -0,0 +1,62 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from encoding import get_encoder + + +class SDFNetwork(nn.Module): + def __init__(self, + encoding="hashgrid", + num_layers=3, + skips=[], + hidden_dim=64, + clip_sdf=None, + ): + super().__init__() + + + self.num_layers = num_layers + self.skips = skips + self.hidden_dim = hidden_dim + self.clip_sdf = clip_sdf + + self.encoder, self.in_dim = get_encoder(encoding) + + backbone = [] + + for l in range(num_layers): + if l == 0: + in_dim = self.in_dim + elif l in self.skips: + in_dim = self.hidden_dim + self.in_dim + else: + in_dim = self.hidden_dim + + if l == num_layers - 1: + out_dim = 1 + else: + out_dim = self.hidden_dim + + backbone.append(nn.Linear(in_dim, out_dim, bias=False)) + + self.backbone = nn.ModuleList(backbone) + + + def forward(self, x): + # x: [B, 3] + + x = self.encoder(x) + + h = x + for l in range(self.num_layers): + if l in self.skips: + h = torch.cat([h, x], dim=-1) + h = self.backbone[l](h) + if l != self.num_layers - 1: + h = F.relu(h, inplace=True) + + if self.clip_sdf is not None: + h = h.clamp(-self.clip_sdf, self.clip_sdf) + + return h \ No newline at end of file diff --git a/torch-ngp/sdf/netowrk_ff.py b/torch-ngp/sdf/netowrk_ff.py new file mode 100644 index 0000000000000000000000000000000000000000..602eccb860709c9a2488a368241bf701da9f7ff2 --- /dev/null +++ b/torch-ngp/sdf/netowrk_ff.py @@ -0,0 +1,47 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from encoding import get_encoder +from ffmlp import FFMLP + + +class SDFNetwork(nn.Module): + def __init__(self, + encoding="hashgrid", + num_layers=3, + skips=[], + hidden_dim=64, + clip_sdf=None, + ): + super().__init__() + + + self.num_layers = num_layers + self.skips = skips + self.hidden_dim = hidden_dim + self.clip_sdf = clip_sdf + + assert self.skips == [], 'FFMLP does not support concatenating inside, please use skips=[].' + + self.encoder, self.in_dim = get_encoder(encoding) + + self.backbone = FFMLP( + input_dim=self.in_dim, + output_dim=1, + hidden_dim=self.hidden_dim, + num_layers=self.num_layers, + ) + + + def forward(self, x): + # x: [B, 3] + + x = self.encoder(x) + + h = self.backbone(x) + + if self.clip_sdf is not None: + h = h.clamp(-self.clip_sdf, self.clip_sdf) + + return h \ No newline at end of file diff --git a/torch-ngp/sdf/network_tcnn.py b/torch-ngp/sdf/network_tcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..5c2eafcd382ed5a5c209b48c1a5359d1d3e7bb84 --- /dev/null +++ b/torch-ngp/sdf/network_tcnn.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import tinycudann as tcnn + +class SDFNetwork(nn.Module): + def __init__(self, + encoding="hashgrid", + num_layers=3, + skips=[], + hidden_dim=64, + clip_sdf=None, + ): + super().__init__() + + + self.num_layers = num_layers + self.skips = skips + self.hidden_dim = hidden_dim + self.clip_sdf = clip_sdf + + assert self.skips == [], 'TCNN does not support concatenating inside, please use skips=[].' + + self.encoder = tcnn.Encoding( + n_input_dims=3, + encoding_config={ + "otype": "HashGrid", + "n_levels": 16, + "n_features_per_level": 2, + "log2_hashmap_size": 19, + "base_resolution": 16, + "per_level_scale": 1.3819, + }, + ) + + self.backbone = tcnn.Network( + n_input_dims=32, + n_output_dims=1, + network_config={ + "otype": "FullyFusedMLP", + "activation": "ReLU", + "output_activation": "None", + "n_neurons": hidden_dim, + "n_hidden_layers": num_layers - 1, + }, + ) + + + def forward(self, x): + # x: [B, 3] + + x = (x + 1) / 2 # to [0, 1] + x = self.encoder(x) + h = self.backbone(x) + + if self.clip_sdf is not None: + h = h.clamp(-self.clip_sdf, self.clip_sdf) + + return h \ No newline at end of file diff --git a/torch-ngp/sdf/provider.py b/torch-ngp/sdf/provider.py new file mode 100644 index 0000000000000000000000000000000000000000..0b7a39f800ae629bc092f1725fd5e4d26294f479 --- /dev/null +++ b/torch-ngp/sdf/provider.py @@ -0,0 +1,88 @@ +import numpy as np + +import torch +from torch.utils.data import Dataset + +import trimesh +import pysdf + +def map_color(value, cmap_name='viridis', vmin=None, vmax=None): + # value: [N], float + # return: RGB, [N, 3], float in [0, 1] + import matplotlib.cm as cm + if vmin is None: vmin = value.min() + if vmax is None: vmax = value.max() + value = (value - vmin) / (vmax - vmin) # range in [0, 1] + cmap = cm.get_cmap(cmap_name) + rgb = cmap(value)[:, :3] # will return rgba, we take only first 3 so we get rgb + return rgb + +def plot_pointcloud(pc, sdfs): + # pc: [N, 3] + # sdfs: [N, 1] + color = map_color(sdfs.squeeze(1)) + pc = trimesh.PointCloud(pc, color) + trimesh.Scene([pc]).show() + +# SDF dataset +class SDFDataset(Dataset): + def __init__(self, path, size=100, num_samples=2**18, clip_sdf=None): + super().__init__() + self.path = path + + # load obj + self.mesh = trimesh.load(path, force='mesh') + + # normalize to [-1, 1] (different from instant-sdf where is [0, 1]) + vs = self.mesh.vertices + vmin = vs.min(0) + vmax = vs.max(0) + v_center = (vmin + vmax) / 2 + v_scale = 2 / np.sqrt(np.sum((vmax - vmin) ** 2)) * 0.95 + vs = (vs - v_center[None, :]) * v_scale + self.mesh.vertices = vs + + print(f"[INFO] mesh: {self.mesh.vertices.shape} {self.mesh.faces.shape}") + + if not self.mesh.is_watertight: + print(f"[WARN] mesh is not watertight! SDF maybe incorrect.") + #trimesh.Scene([self.mesh]).show() + + self.sdf_fn = pysdf.SDF(self.mesh.vertices, self.mesh.faces) + + self.num_samples = num_samples + assert self.num_samples % 8 == 0, "num_samples must be divisible by 8." + self.clip_sdf = clip_sdf + + self.size = size + + + def __len__(self): + return self.size + + def __getitem__(self, _): + + # online sampling + sdfs = np.zeros((self.num_samples, 1)) + # surface + points_surface = self.mesh.sample(self.num_samples * 7 // 8) + # perturb surface + points_surface[self.num_samples // 2:] += 0.01 * np.random.randn(self.num_samples * 3 // 8, 3) + # random + points_uniform = np.random.rand(self.num_samples // 8, 3) * 2 - 1 + points = np.concatenate([points_surface, points_uniform], axis=0).astype(np.float32) + + sdfs[self.num_samples // 2:] = -self.sdf_fn(points[self.num_samples // 2:])[:,None].astype(np.float32) + + # clip sdf + if self.clip_sdf is not None: + sdfs = sdfs.clip(-self.clip_sdf, self.clip_sdf) + + results = { + 'sdfs': sdfs, + 'points': points, + } + + #plot_pointcloud(points, sdfs) + + return results diff --git a/torch-ngp/sdf/utils.py b/torch-ngp/sdf/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ae26735b10543604573ed5e92cfbdde797994680 --- /dev/null +++ b/torch-ngp/sdf/utils.py @@ -0,0 +1,563 @@ +import os +import glob +import tqdm +import random +import warnings +import tensorboardX + +import numpy as np +import pandas as pd + +import time +from datetime import datetime + +import cv2 +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import torch.distributed as dist +from torch.utils.data import Dataset, DataLoader + +import trimesh +import mcubes +from rich.console import Console +from torch_ema import ExponentialMovingAverage + +import packaging + +def custom_meshgrid(*args): + # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid + if packaging.version.parse(torch.__version__) < packaging.version.parse('1.10'): + return torch.meshgrid(*args) + else: + return torch.meshgrid(*args, indexing='ij') + + +def seed_everything(seed): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + #torch.backends.cudnn.deterministic = True + #torch.backends.cudnn.benchmark = True + + +def extract_fields(bound_min, bound_max, resolution, query_func): + N = 64 + X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N) + Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N) + Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N) + + u = np.zeros([resolution, resolution, resolution], dtype=np.float32) + with torch.no_grad(): + for xi, xs in enumerate(X): + for yi, ys in enumerate(Y): + for zi, zs in enumerate(Z): + xx, yy, zz = custom_meshgrid(xs, ys, zs) + pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3] + val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [N, 1] --> [x, y, z] + u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val + return u + + +def extract_geometry(bound_min, bound_max, resolution, threshold, query_func): + #print('threshold: {}'.format(threshold)) + u = extract_fields(bound_min, bound_max, resolution, query_func) + + #print(u.shape, u.max(), u.min(), np.percentile(u, 50)) + + vertices, triangles = mcubes.marching_cubes(u, threshold) + + b_max_np = bound_max.detach().cpu().numpy() + b_min_np = bound_min.detach().cpu().numpy() + + vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :] + return vertices, triangles + + + +class Trainer(object): + def __init__(self, + name, # name of this experiment + model, # network + criterion=None, # loss function, if None, assume inline implementation in train_step + optimizer=None, # optimizer + ema_decay=None, # if use EMA, set the decay + lr_scheduler=None, # scheduler + metrics=[], # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric. + local_rank=0, # which GPU am I + world_size=1, # total num of GPUs + device=None, # device to use, usually setting to None is OK. (auto choose device) + mute=False, # whether to mute all print + fp16=False, # amp optimize level + eval_interval=1, # eval once every $ epoch + max_keep_ckpt=2, # max num of saved ckpts in disk + workspace='workspace', # workspace to save logs & ckpts + best_mode='min', # the smaller/larger result, the better + use_loss_as_metric=True, # use loss as the first metirc + report_metric_at_train=False, # also report metrics at training + use_checkpoint="latest", # which ckpt to use at init time + use_tensorboardX=True, # whether to use tensorboard for logging + scheduler_update_every_step=False, # whether to call scheduler.step() after every train step + ): + + self.name = name + self.mute = mute + self.metrics = metrics + self.local_rank = local_rank + self.world_size = world_size + self.workspace = workspace + self.ema_decay = ema_decay + self.fp16 = fp16 + self.best_mode = best_mode + self.use_loss_as_metric = use_loss_as_metric + self.report_metric_at_train = report_metric_at_train + self.max_keep_ckpt = max_keep_ckpt + self.eval_interval = eval_interval + self.use_checkpoint = use_checkpoint + self.use_tensorboardX = use_tensorboardX + self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S") + self.scheduler_update_every_step = scheduler_update_every_step + self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu') + self.console = Console() + + model.to(self.device) + if self.world_size > 1: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank]) + self.model = model + + if isinstance(criterion, nn.Module): + criterion.to(self.device) + self.criterion = criterion + + if optimizer is None: + self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam + else: + self.optimizer = optimizer(self.model) + + if lr_scheduler is None: + self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler + else: + self.lr_scheduler = lr_scheduler(self.optimizer) + + if ema_decay is not None: + self.ema = ExponentialMovingAverage(self.model.parameters(), decay=ema_decay) + else: + self.ema = None + + self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16) + + # variable init + self.epoch = 0 + self.global_step = 0 + self.local_step = 0 + self.stats = { + "loss": [], + "valid_loss": [], + "results": [], # metrics[0], or valid_loss + "checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt + "best_result": None, + } + + # auto fix + if len(metrics) == 0 or self.use_loss_as_metric: + self.best_mode = 'min' + + # workspace prepare + self.log_ptr = None + if self.workspace is not None: + os.makedirs(self.workspace, exist_ok=True) + self.log_path = os.path.join(workspace, f"log_{self.name}.txt") + self.log_ptr = open(self.log_path, "a+") + + self.ckpt_path = os.path.join(self.workspace, 'checkpoints') + self.best_path = f"{self.ckpt_path}/{self.name}.pth.tar" + os.makedirs(self.ckpt_path, exist_ok=True) + + self.log(f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}') + self.log(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}') + + if self.workspace is not None: + if self.use_checkpoint == "scratch": + self.log("[INFO] Training from scratch ...") + elif self.use_checkpoint == "latest": + self.log("[INFO] Loading latest checkpoint ...") + self.load_checkpoint() + elif self.use_checkpoint == "best": + if os.path.exists(self.best_path): + self.log("[INFO] Loading best checkpoint ...") + self.load_checkpoint(self.best_path) + else: + self.log(f"[INFO] {self.best_path} not found, loading latest ...") + self.load_checkpoint() + else: # path to ckpt + self.log(f"[INFO] Loading {self.use_checkpoint} ...") + self.load_checkpoint(self.use_checkpoint) + + def __del__(self): + if self.log_ptr: + self.log_ptr.close() + + def log(self, *args, **kwargs): + if self.local_rank == 0: + if not self.mute: + #print(*args) + self.console.print(*args, **kwargs) + if self.log_ptr: + print(*args, file=self.log_ptr) + self.log_ptr.flush() # write immediately to file + + ### ------------------------------ + + def train_step(self, data): + # assert batch_size == 1 + X = data["points"][0] # [B, 3] + y = data["sdfs"][0] # [B] + + pred = self.model(X) + loss = self.criterion(pred, y) + + return pred, y, loss + + def eval_step(self, data): + return self.train_step(data) + + def test_step(self, data): + X = data["points"][0] + pred = self.model(X) + return pred + + def save_mesh(self, save_path=None, resolution=256): + + if save_path is None: + save_path = os.path.join(self.workspace, 'validation', f'{self.name}_{self.epoch}.ply') + + self.log(f"==> Saving mesh to {save_path}") + + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + def query_func(pts): + pts = pts.to(self.device) + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=self.fp16): + sdfs = self.model(pts) + return sdfs + + bounds_min = torch.FloatTensor([-1, -1, -1]) + bounds_max = torch.FloatTensor([1, 1, 1]) + + vertices, triangles = extract_geometry(bounds_min, bounds_max, resolution=resolution, threshold=0, query_func=query_func) + + mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault... + mesh.export(save_path) + + self.log(f"==> Finished saving mesh.") + + ### ------------------------------ + + def train(self, train_loader, valid_loader, max_epochs): + if self.use_tensorboardX and self.local_rank == 0: + self.writer = tensorboardX.SummaryWriter(os.path.join(self.workspace, "run", self.name)) + + for epoch in range(self.epoch + 1, max_epochs + 1): + self.epoch = epoch + + self.train_one_epoch(train_loader) + + if self.workspace is not None and self.local_rank == 0: + self.save_checkpoint(full=True, best=False) + + if self.epoch % self.eval_interval == 0: + self.evaluate_one_epoch(valid_loader) + self.save_mesh() + self.save_checkpoint(full=False, best=True) + + if self.use_tensorboardX and self.local_rank == 0: + self.writer.close() + + def evaluate(self, loader): + #if os.path.exists(self.best_path): + # self.load_checkpoint(self.best_path) + #else: + # self.load_checkpoint() + self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX + self.evaluate_one_epoch(loader) + self.use_tensorboardX = use_tensorboardX + + + + def prepare_data(self, data): + if isinstance(data, list): + for i, v in enumerate(data): + if isinstance(v, np.ndarray): + data[i] = torch.from_numpy(v).to(self.device, non_blocking=True) + if torch.is_tensor(v): + data[i] = v.to(self.device, non_blocking=True) + elif isinstance(data, dict): + for k, v in data.items(): + if isinstance(v, np.ndarray): + data[k] = torch.from_numpy(v).to(self.device, non_blocking=True) + if torch.is_tensor(v): + data[k] = v.to(self.device, non_blocking=True) + elif isinstance(data, np.ndarray): + data = torch.from_numpy(data).to(self.device, non_blocking=True) + else: # is_tensor, or other similar objects that has `to` + data = data.to(self.device, non_blocking=True) + + return data + + def train_one_epoch(self, loader): + self.log(f"==> Start Training Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...") + + total_loss = 0 + if self.local_rank == 0 and self.report_metric_at_train: + for metric in self.metrics: + metric.clear() + + self.model.train() + + # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs + # ref: https://pytorch.org/docs/stable/data.html + if self.world_size > 1: + loader.sampler.set_epoch(self.epoch) + + if self.local_rank == 0: + pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') + + self.local_step = 0 + + for data in loader: + + self.local_step += 1 + self.global_step += 1 + + data = self.prepare_data(data) + + self.optimizer.zero_grad() + + with torch.cuda.amp.autocast(enabled=self.fp16): + preds, truths, loss = self.train_step(data) + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + + if self.ema is not None: + self.ema.update() + + if self.scheduler_update_every_step: + self.lr_scheduler.step() + + loss_val = loss.item() + total_loss += loss_val + + if self.local_rank == 0: + if self.report_metric_at_train: + for metric in self.metrics: + metric.update(preds, truths) + + if self.use_tensorboardX: + self.writer.add_scalar("train/loss", loss_val, self.global_step) + self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], self.global_step) + + if self.scheduler_update_every_step: + pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}") + else: + pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})") + pbar.update(loader.batch_size) + + average_loss = total_loss / self.local_step + self.stats["loss"].append(average_loss) + + if self.local_rank == 0: + pbar.close() + if self.report_metric_at_train: + for metric in self.metrics: + self.log(metric.report(), style="red") + if self.use_tensorboardX: + metric.write(self.writer, self.epoch, prefix="train") + metric.clear() + + if not self.scheduler_update_every_step: + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step(average_loss) + else: + self.lr_scheduler.step() + + self.log(f"==> Finished Epoch {self.epoch}.") + + + def evaluate_one_epoch(self, loader): + self.log(f"++> Evaluate at epoch {self.epoch} ...") + + total_loss = 0 + if self.local_rank == 0: + for metric in self.metrics: + metric.clear() + + self.model.eval() + + if self.local_rank == 0: + pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') + + with torch.no_grad(): + self.local_step = 0 + for data in loader: + self.local_step += 1 + + data = self.prepare_data(data) + + if self.ema is not None: + self.ema.store() + self.ema.copy_to() + + with torch.cuda.amp.autocast(enabled=self.fp16): + preds, truths, loss = self.eval_step(data) + + if self.ema is not None: + self.ema.restore() + + # all_gather/reduce the statistics (NCCL only support all_*) + if self.world_size > 1: + dist.all_reduce(loss, op=dist.ReduceOp.SUM) + loss = loss / self.world_size + + preds_list = [torch.zeros_like(preds).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...] + dist.all_gather(preds_list, preds) + preds = torch.cat(preds_list, dim=0) + + truths_list = [torch.zeros_like(truths).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...] + dist.all_gather(truths_list, truths) + truths = torch.cat(truths_list, dim=0) + + loss_val = loss.item() + total_loss += loss_val + + # only rank = 0 will perform evaluation. + if self.local_rank == 0: + + for metric in self.metrics: + metric.update(preds, truths) + + pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})") + pbar.update(loader.batch_size) + + average_loss = total_loss / self.local_step + self.stats["valid_loss"].append(average_loss) + + if self.local_rank == 0: + pbar.close() + if not self.use_loss_as_metric and len(self.metrics) > 0: + result = self.metrics[0].measure() + self.stats["results"].append(result if self.best_mode == 'min' else - result) # if max mode, use -result + else: + self.stats["results"].append(average_loss) # if no metric, choose best by min loss + + for metric in self.metrics: + self.log(metric.report(), style="blue") + if self.use_tensorboardX: + metric.write(self.writer, self.epoch, prefix="evaluate") + metric.clear() + + self.log(f"++> Evaluate epoch {self.epoch} Finished.") + + def save_checkpoint(self, full=False, best=False): + + state = { + 'epoch': self.epoch, + 'stats': self.stats, + } + + if full: + state['optimizer'] = self.optimizer.state_dict() + state['lr_scheduler'] = self.lr_scheduler.state_dict() + state['scaler'] = self.scaler.state_dict() + if self.ema is not None: + state['ema'] = self.ema.state_dict() + + if not best: + + state['model'] = self.model.state_dict() + + file_path = f"{self.ckpt_path}/{self.name}_ep{self.epoch:04d}.pth.tar" + + self.stats["checkpoints"].append(file_path) + + if len(self.stats["checkpoints"]) > self.max_keep_ckpt: + old_ckpt = self.stats["checkpoints"].pop(0) + if os.path.exists(old_ckpt): + os.remove(old_ckpt) + + torch.save(state, file_path) + + else: + if len(self.stats["results"]) > 0: + if self.stats["best_result"] is None or self.stats["results"][-1] < self.stats["best_result"]: + self.log(f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}") + self.stats["best_result"] = self.stats["results"][-1] + + # save ema results + if self.ema is not None: + self.ema.store() + self.ema.copy_to() + + state['model'] = self.model.state_dict() + + if self.ema is not None: + self.ema.restore() + + torch.save(state, self.best_path) + else: + self.log(f"[WARN] no evaluated results found, skip saving best checkpoint.") + + def load_checkpoint(self, checkpoint=None): + if checkpoint is None: + checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/{self.name}_ep*.pth.tar')) + if checkpoint_list: + checkpoint = checkpoint_list[-1] + self.log(f"[INFO] Latest checkpoint is {checkpoint}") + else: + self.log("[WARN] No checkpoint found, model randomly initialized.") + return + + checkpoint_dict = torch.load(checkpoint, map_location=self.device) + + if 'model' not in checkpoint_dict: + self.model.load_state_dict(checkpoint_dict) + self.log("[INFO] loaded model.") + return + + missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint_dict['model'], strict=False) + self.log("[INFO] loaded model.") + if len(missing_keys) > 0: + self.log(f"[WARN] missing keys: {missing_keys}") + if len(unexpected_keys) > 0: + self.log(f"[WARN] unexpected keys: {unexpected_keys}") + + if self.ema is not None and 'ema' in checkpoint_dict: + self.ema.load_state_dict(checkpoint_dict['ema']) + + self.stats = checkpoint_dict['stats'] + self.epoch = checkpoint_dict['epoch'] + + if self.optimizer and 'optimizer' in checkpoint_dict: + try: + self.optimizer.load_state_dict(checkpoint_dict['optimizer']) + self.log("[INFO] loaded optimizer.") + except: + self.log("[WARN] Failed to load optimizer, use default.") + + # strange bug: keyerror 'lr_lambdas' + if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict: + try: + self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler']) + self.log("[INFO] loaded scheduler.") + except: + self.log("[WARN] Failed to load scheduler, use default.") + + if 'scaler' in checkpoint_dict: + self.scaler.load_state_dict(checkpoint_dict['scaler']) \ No newline at end of file diff --git a/torch-ngp/shencoder/__init__.py b/torch-ngp/shencoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2b55c96efe1a86b2da660ccd961e48a8adc3803f --- /dev/null +++ b/torch-ngp/shencoder/__init__.py @@ -0,0 +1 @@ +from .sphere_harmonics import SHEncoder \ No newline at end of file diff --git a/torch-ngp/shencoder/backend.py b/torch-ngp/shencoder/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..cc08a3e9b927eb6d2c30aa19dc7181b3b8a9f44b --- /dev/null +++ b/torch-ngp/shencoder/backend.py @@ -0,0 +1,40 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_sh_encoder', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'shencoder.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/torch-ngp/shencoder/setup.py b/torch-ngp/shencoder/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..342a6015b6f4c2cd1789ba39baa16835855f0671 --- /dev/null +++ b/torch-ngp/shencoder/setup.py @@ -0,0 +1,50 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='shencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_shencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'shencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/torch-ngp/shencoder/sphere_harmonics.py b/torch-ngp/shencoder/sphere_harmonics.py new file mode 100644 index 0000000000000000000000000000000000000000..7bab24e69d0c488b33f840ff9e2057cb260c3b5d --- /dev/null +++ b/torch-ngp/shencoder/sphere_harmonics.py @@ -0,0 +1,87 @@ +import numpy as np + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _shencoder as _backend +except ImportError: + from .backend import _backend + +class _sh_encoder(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision + def forward(ctx, inputs, degree, calc_grad_inputs=False): + # inputs: [B, input_dim], float in [-1, 1] + # RETURN: [B, F], float + + inputs = inputs.contiguous() + B, input_dim = inputs.shape # batch size, coord dim + output_dim = degree ** 2 + + outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) + + if calc_grad_inputs: + dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device) + else: + dy_dx = None + + _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx) + + ctx.save_for_backward(inputs, dy_dx) + ctx.dims = [B, input_dim, degree] + + return outputs + + @staticmethod + #@once_differentiable + @custom_bwd + def backward(ctx, grad): + # grad: [B, C * C] + + inputs, dy_dx = ctx.saved_tensors + + if dy_dx is not None: + grad = grad.contiguous() + B, input_dim, degree = ctx.dims + grad_inputs = torch.zeros_like(inputs) + _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs) + return grad_inputs, None, None + else: + return None, None, None + + + +sh_encode = _sh_encoder.apply + + +class SHEncoder(nn.Module): + def __init__(self, input_dim=3, degree=4): + super().__init__() + + self.input_dim = input_dim # coord dims, must be 3 + self.degree = degree # 0 ~ 4 + self.output_dim = degree ** 2 + + assert self.input_dim == 3, "SH encoder only support input dim == 3" + assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]" + + def __repr__(self): + return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}" + + def forward(self, inputs, size=1): + # inputs: [..., input_dim], normalized real world positions in [-size, size] + # return: [..., degree^2] + + inputs = inputs / size # [-1, 1] + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.reshape(-1, self.input_dim) + + outputs = sh_encode(inputs, self.degree, inputs.requires_grad) + outputs = outputs.reshape(prefix_shape + [self.output_dim]) + + return outputs \ No newline at end of file diff --git a/torch-ngp/shencoder/src/bindings.cpp b/torch-ngp/shencoder/src/bindings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..595b5b3a98b10ad01428fc1c6c548a8abcb3934b --- /dev/null +++ b/torch-ngp/shencoder/src/bindings.cpp @@ -0,0 +1,8 @@ +#include + +#include "shencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)"); + m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)"); +} \ No newline at end of file diff --git a/torch-ngp/shencoder/src/shencoder.cu b/torch-ngp/shencoder/src/shencoder.cu new file mode 100644 index 0000000000000000000000000000000000000000..a92e4ab79ecd96d9db2effa97d14dbe30a2f2bb2 --- /dev/null +++ b/torch-ngp/shencoder/src/shencoder.cu @@ -0,0 +1,439 @@ +#include + +#include +#include +#include + +#include +#include + +#include +#include + +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + + +template +__host__ __device__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +template +__global__ void kernel_sh( + const scalar_t * __restrict__ inputs, + scalar_t * outputs, + uint32_t B, uint32_t D, uint32_t C, + scalar_t * dy_dx +) { + const uint32_t b = threadIdx.x + blockIdx.x * blockDim.x; + if (b >= B) return; + + const uint32_t C2 = C * C; + + // locate + inputs += b * D; + outputs += b * C2; + + scalar_t x = inputs[0], y = inputs[1], z = inputs[2]; + + scalar_t xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z, xyz=xy*z; + scalar_t x4=x2*x2, y4=y2*y2, z4=z2*z2; + scalar_t x6=x4*x2, y6=y4*y2, z6=z4*z2; + + auto write_sh = [&]() { + outputs[0] = 0.28209479177387814f ; // 1/(2*sqrt(pi)) + if (C <= 1) { return; } + outputs[1] = -0.48860251190291987f*y ; // -sqrt(3)*y/(2*sqrt(pi)) + outputs[2] = 0.48860251190291987f*z ; // sqrt(3)*z/(2*sqrt(pi)) + outputs[3] = -0.48860251190291987f*x ; // -sqrt(3)*x/(2*sqrt(pi)) + if (C <= 2) { return; } + outputs[4] = 1.0925484305920792f*xy ; // sqrt(15)*xy/(2*sqrt(pi)) + outputs[5] = -1.0925484305920792f*yz ; // -sqrt(15)*yz/(2*sqrt(pi)) + outputs[6] = 0.94617469575755997f*z2 - 0.31539156525251999f ; // sqrt(5)*(3*z2 - 1)/(4*sqrt(pi)) + outputs[7] = -1.0925484305920792f*xz ; // -sqrt(15)*xz/(2*sqrt(pi)) + outputs[8] = 0.54627421529603959f*x2 - 0.54627421529603959f*y2 ; // sqrt(15)*(x2 - y2)/(4*sqrt(pi)) + if (C <= 3) { return; } + outputs[9] = 0.59004358992664352f*y*(-3.0f*x2 + y2) ; // sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) + outputs[10] = 2.8906114426405538f*xy*z ; // sqrt(105)*xy*z/(2*sqrt(pi)) + outputs[11] = 0.45704579946446572f*y*(1.0f - 5.0f*z2) ; // sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi)) + outputs[12] = 0.3731763325901154f*z*(5.0f*z2 - 3.0f) ; // sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi)) + outputs[13] = 0.45704579946446572f*x*(1.0f - 5.0f*z2) ; // sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi)) + outputs[14] = 1.4453057213202769f*z*(x2 - y2) ; // sqrt(105)*z*(x2 - y2)/(4*sqrt(pi)) + outputs[15] = 0.59004358992664352f*x*(-x2 + 3.0f*y2) ; // sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) + if (C <= 4) { return; } + outputs[16] = 2.5033429417967046f*xy*(x2 - y2) ; // 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi)) + outputs[17] = 1.7701307697799304f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi)) + outputs[18] = 0.94617469575756008f*xy*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi)) + outputs[19] = 0.66904654355728921f*yz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi)) + outputs[20] = -3.1735664074561294f*z2 + 3.7024941420321507f*z4 + 0.31735664074561293f ; // 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi)) + outputs[21] = 0.66904654355728921f*xz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi)) + outputs[22] = 0.47308734787878004f*(x2 - y2)*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi)) + outputs[23] = 1.7701307697799304f*xz*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi)) + outputs[24] = -3.7550144126950569f*x2*y2 + 0.62583573544917614f*x4 + 0.62583573544917614f*y4 ; // 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + if (C <= 5) { return; } + outputs[25] = 0.65638205684017015f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + outputs[26] = 8.3026492595241645f*xy*z*(x2 - y2) ; // 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi)) + outputs[27] = -0.48923829943525038f*y*(3.0f*x2 - y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) + outputs[28] = 4.7935367849733241f*xy*z*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi)) + outputs[29] = 0.45294665119569694f*y*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + outputs[30] = 0.1169503224534236f*z*(-70.0f*z2 + 63.0f*z4 + 15.0f) ; // sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi)) + outputs[31] = 0.45294665119569694f*x*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + outputs[32] = 2.3967683924866621f*z*(x2 - y2)*(3.0f*z2 - 1.0f) ; // sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi)) + outputs[33] = -0.48923829943525038f*x*(x2 - 3.0f*y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi)) + outputs[34] = 2.0756623148810411f*z*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + outputs[35] = 0.65638205684017015f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + if (C <= 6) { return; } + outputs[36] = 1.3663682103838286f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + outputs[37] = 2.3666191622317521f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + outputs[38] = 2.0182596029148963f*xy*(x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) + outputs[39] = -0.92120525951492349f*yz*(3.0f*x2 - y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) + outputs[40] = 0.92120525951492349f*xy*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + outputs[41] = 0.58262136251873131f*yz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + outputs[42] = 6.6747662381009842f*z2 - 20.024298714302954f*z4 + 14.684485723822165f*z6 - 0.31784601133814211f ; // sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi)) + outputs[43] = 0.58262136251873131f*xz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + outputs[44] = 0.46060262975746175f*(x2 - y2)*(11.0f*z2*(3.0f*z2 - 1.0f) - 7.0f*z2 + 1.0f) ; // sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi)) + outputs[45] = -0.92120525951492349f*xz*(x2 - 3.0f*y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi)) + outputs[46] = 0.50456490072872406f*(11.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + outputs[47] = 2.3666191622317521f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + outputs[48] = 10.247761577878714f*x2*y4 - 10.247761577878714f*x4*y2 + 0.6831841051919143f*x6 - 0.6831841051919143f*y6 ; // sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + if (C <= 7) { return; } + outputs[49] = 0.70716273252459627f*y*(-21.0f*x2*y4 + 35.0f*x4*y2 - 7.0f*x6 + y6) ; // 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi)) + outputs[50] = 5.2919213236038001f*xy*z*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + outputs[51] = -0.51891557872026028f*y*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi)) + outputs[52] = 4.1513246297620823f*xy*z*(x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) + outputs[53] = -0.15645893386229404f*y*(3.0f*x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + outputs[54] = 0.44253269244498261f*xy*z*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + outputs[55] = 0.090331607582517306f*y*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + outputs[56] = 0.068284276912004949f*z*(315.0f*z2 - 693.0f*z4 + 429.0f*z6 - 35.0f) ; // sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi)) + outputs[57] = 0.090331607582517306f*x*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + outputs[58] = 0.07375544874083044f*z*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi)) + outputs[59] = -0.15645893386229404f*x*(x2 - 3.0f*y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + outputs[60] = 1.0378311574405206f*z*(13.0f*z2 - 3.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + outputs[61] = -0.51891557872026028f*x*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi)) + outputs[62] = 2.6459606618019f*z*(15.0f*x2*y4 - 15.0f*x4*y2 + x6 - y6) ; // 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + outputs[63] = 0.70716273252459627f*x*(-35.0f*x2*y4 + 21.0f*x4*y2 - x6 + 7.0f*y6) ; // 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi)) + }; + + write_sh(); + + if (dy_dx) { + scalar_t *dx = dy_dx + b * D * C2; + scalar_t *dy = dx + C2; + scalar_t *dz = dy + C2; + + auto write_sh_dx = [&]() { + dx[0] = 0.0f ; // 0 + if (C <= 1) { return; } + dx[1] = 0.0f ; // 0 + dx[2] = 0.0f ; // 0 + dx[3] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi)) + if (C <= 2) { return; } + dx[4] = 1.0925484305920792f*y ; // sqrt(15)*y/(2*sqrt(pi)) + dx[5] = 0.0f ; // 0 + dx[6] = 0.0f ; // 0 + dx[7] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi)) + dx[8] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi)) + if (C <= 3) { return; } + dx[9] = -3.5402615395598609f*xy ; // -3*sqrt(70)*xy/(4*sqrt(pi)) + dx[10] = 2.8906114426405538f*yz ; // sqrt(105)*yz/(2*sqrt(pi)) + dx[11] = 0.0f ; // 0 + dx[12] = 0.0f ; // 0 + dx[13] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi)) + dx[14] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi)) + dx[15] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi)) + if (C <= 4) { return; } + dx[16] = 2.5033429417967046f*y*(3.0f*x2 - y2) ; // 3*sqrt(35)*y*(3*x2 - y2)/(4*sqrt(pi)) + dx[17] = -10.620784618679583f*xy*z ; // -9*sqrt(70)*xy*z/(4*sqrt(pi)) + dx[18] = 0.94617469575756008f*y*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*y*(7*z2 - 1)/(4*sqrt(pi)) + dx[19] = 0.0f ; // 0 + dx[20] = 0.0f ; // 0 + dx[21] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi)) + dx[22] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi)) + dx[23] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi)) + dx[24] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi)) + if (C <= 5) { return; } + dx[25] = 13.127641136803401f*xy*(-x2 + y2) ; // 15*sqrt(154)*xy*(-x2 + y2)/(8*sqrt(pi)) + dx[26] = 8.3026492595241645f*yz*(3.0f*x2 - y2) ; // 3*sqrt(385)*yz*(3*x2 - y2)/(4*sqrt(pi)) + dx[27] = 2.9354297966115022f*xy*(1.0f - 9.0f*z2) ; // 3*sqrt(770)*xy*(1 - 9*z2)/(16*sqrt(pi)) + dx[28] = 4.7935367849733241f*yz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*yz*(3*z2 - 1)/(4*sqrt(pi)) + dx[29] = 0.0f ; // 0 + dx[30] = 0.0f ; // 0 + dx[31] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + dx[32] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi)) + dx[33] = -13.209434084751759f*x2*z2 + 1.4677148983057511f*x2 + 13.209434084751759f*y2*z2 - 1.4677148983057511f*y2 ; // 3*sqrt(770)*(-9*x2*z2 + x2 + 9*y2*z2 - y2)/(32*sqrt(pi)) + dx[34] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi)) + dx[35] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + if (C <= 6) { return; } + dx[36] = 4.0991046311514854f*y*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 3*sqrt(6006)*y*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi)) + dx[37] = 47.332383244635047f*xy*z*(-x2 + y2) ; // 15*sqrt(2002)*xy*z*(-x2 + y2)/(8*sqrt(pi)) + dx[38] = 2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) + dx[39] = 5.5272315570895412f*xy*z*(3.0f - 11.0f*z2) ; // 3*sqrt(2730)*xy*z*(3 - 11*z2)/(16*sqrt(pi)) + dx[40] = 0.92120525951492349f*y*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*y*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + dx[41] = 0.0f ; // 0 + dx[42] = 0.0f ; // 0 + dx[43] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + dx[44] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + dx[45] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) + dx[46] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi)) + dx[47] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + dx[48] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + if (C <= 7) { return; } + dx[49] = 9.9002782553443485f*xy*(10.0f*x2*y2 - 3.0f*x4 - 3.0f*y4) ; // 21*sqrt(715)*xy*(10*x2*y2 - 3*x4 - 3*y4)/(32*sqrt(pi)) + dx[50] = 15.875763970811402f*yz*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 9*sqrt(10010)*yz*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi)) + dx[51] = -10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // -15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi)) + dx[52] = 4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) + dx[53] = 0.93875360317376422f*xy*(66.0f*z2 - 143.0f*z4 - 3.0f) ; // 9*sqrt(35)*xy*(66*z2 - 143*z4 - 3)/(32*sqrt(pi)) + dx[54] = 0.44253269244498261f*yz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*yz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + dx[55] = 0.0f ; // 0 + dx[56] = 0.0f ; // 0 + dx[57] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + dx[58] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + dx[59] = 30.97886890473422f*x2*z2 - 67.120882626924143f*x2*z4 - 1.4081304047606462f*x2 - 30.97886890473422f*y2*z2 + 67.120882626924143f*y2*z4 + 1.4081304047606462f*y2 ; // 9*sqrt(35)*(66*x2*z2 - 143*x2*z4 - 3*x2 - 66*y2*z2 + 143*y2*z4 + 3*y2)/(64*sqrt(pi)) + dx[60] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi)) + dx[61] = -0.51891557872026028f*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 4.0f*x2*(x2 - 5.0f*y2) + x4 + 5.0f*y4) ; // -3*sqrt(385)*(13*z2 - 1)*(-10*x2*y2 + 4*x2*(x2 - 5*y2) + x4 + 5*y4)/(64*sqrt(pi)) + dx[62] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + dx[63] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi)) + }; + + auto write_sh_dy = [&]() { + dy[0] = 0.0f ; // 0 + if (C <= 1) { return; } + dy[1] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi)) + dy[2] = 0.0f ; // 0 + dy[3] = 0.0f ; // 0 + if (C <= 2) { return; } + dy[4] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi)) + dy[5] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi)) + dy[6] = 0.0f ; // 0 + dy[7] = 0.0f ; // 0 + dy[8] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi)) + if (C <= 3) { return; } + dy[9] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi)) + dy[10] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi)) + dy[11] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi)) + dy[12] = 0.0f ; // 0 + dy[13] = 0.0f ; // 0 + dy[14] = -2.8906114426405538f*yz ; // -sqrt(105)*yz/(2*sqrt(pi)) + dy[15] = 3.5402615395598609f*xy ; // 3*sqrt(70)*xy/(4*sqrt(pi)) + if (C <= 4) { return; } + dy[16] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi)) + dy[17] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi)) + dy[18] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi)) + dy[19] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi)) + dy[20] = 0.0f ; // 0 + dy[21] = 0.0f ; // 0 + dy[22] = 0.94617469575756008f*y*(1.0f - 7.0f*z2) ; // 3*sqrt(5)*y*(1 - 7*z2)/(4*sqrt(pi)) + dy[23] = 10.620784618679583f*xy*z ; // 9*sqrt(70)*xy*z/(4*sqrt(pi)) + dy[24] = 2.5033429417967046f*y*(-3.0f*x2 + y2) ; // 3*sqrt(35)*y*(-3*x2 + y2)/(4*sqrt(pi)) + if (C <= 5) { return; } + dy[25] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + dy[26] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi)) + dy[27] = -1.4677148983057511f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // -3*sqrt(770)*(x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) + dy[28] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi)) + dy[29] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + dy[30] = 0.0f ; // 0 + dy[31] = 0.0f ; // 0 + dy[32] = 4.7935367849733241f*yz*(1.0f - 3.0f*z2) ; // sqrt(1155)*yz*(1 - 3*z2)/(4*sqrt(pi)) + dy[33] = 2.9354297966115022f*xy*(9.0f*z2 - 1.0f) ; // 3*sqrt(770)*xy*(9*z2 - 1)/(16*sqrt(pi)) + dy[34] = 8.3026492595241645f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(385)*yz*(-3*x2 + y2)/(4*sqrt(pi)) + dy[35] = 13.127641136803401f*xy*(x2 - y2) ; // 15*sqrt(154)*xy*(x2 - y2)/(8*sqrt(pi)) + if (C <= 6) { return; } + dy[36] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + dy[37] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + dy[38] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi)) + dy[39] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) + dy[40] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + dy[41] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + dy[42] = 0.0f ; // 0 + dy[43] = 0.0f ; // 0 + dy[44] = 0.92120525951492349f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // sqrt(2730)*y*(18*z2 - 33*z4 - 1)/(32*sqrt(pi)) + dy[45] = 5.5272315570895412f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(16*sqrt(pi)) + dy[46] = -2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) + dy[47] = 47.332383244635047f*xy*z*(x2 - y2) ; // 15*sqrt(2002)*xy*z*(x2 - y2)/(8*sqrt(pi)) + dy[48] = 4.0991046311514854f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(6006)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + if (C <= 7) { return; } + dy[49] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi)) + dy[50] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + dy[51] = 0.51891557872026028f*(13.0f*z2 - 1.0f)*(10.0f*x2*y2 - 5.0f*x4 + 4.0f*y2*(5.0f*x2 - y2) - y4) ; // 3*sqrt(385)*(13*z2 - 1)*(10*x2*y2 - 5*x4 + 4*y2*(5*x2 - y2) - y4)/(64*sqrt(pi)) + dy[52] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi)) + dy[53] = -0.46937680158688211f*(x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -9*sqrt(35)*(x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + dy[54] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + dy[55] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + dy[56] = 0.0f ; // 0 + dy[57] = 0.0f ; // 0 + dy[58] = 0.44253269244498261f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 3*sqrt(70)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) + dy[59] = 0.93875360317376422f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 9*sqrt(35)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi)) + dy[60] = -4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) + dy[61] = 10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi)) + dy[62] = 15.875763970811402f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 9*sqrt(10010)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + dy[63] = 9.9002782553443485f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 21*sqrt(715)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + }; + + auto write_sh_dz = [&]() { + dz[0] = 0.0f ; // 0 + if (C <= 1) { return; } + dz[1] = 0.0f ; // 0 + dz[2] = 0.48860251190291992f ; // sqrt(3)/(2*sqrt(pi)) + dz[3] = 0.0f ; // 0 + if (C <= 2) { return; } + dz[4] = 0.0f ; // 0 + dz[5] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi)) + dz[6] = 1.8923493915151202f*z ; // 3*sqrt(5)*z/(2*sqrt(pi)) + dz[7] = -1.0925484305920792f*x ; // -sqrt(15)*x/(2*sqrt(pi)) + dz[8] = 0.0f ; // 0 + if (C <= 3) { return; } + dz[9] = 0.0f ; // 0 + dz[10] = 2.8906114426405538f*xy ; // sqrt(105)*xy/(2*sqrt(pi)) + dz[11] = -4.5704579946446566f*yz ; // -5*sqrt(42)*yz/(4*sqrt(pi)) + dz[12] = 5.597644988851731f*z2 - 1.1195289977703462f ; // 3*sqrt(7)*(5*z2 - 1)/(4*sqrt(pi)) + dz[13] = -4.5704579946446566f*xz ; // -5*sqrt(42)*xz/(4*sqrt(pi)) + dz[14] = 1.4453057213202769f*x2 - 1.4453057213202769f*y2 ; // sqrt(105)*(x2 - y2)/(4*sqrt(pi)) + dz[15] = 0.0f ; // 0 + if (C <= 4) { return; } + dz[16] = 0.0f ; // 0 + dz[17] = 1.7701307697799304f*y*(-3.0f*x2 + y2) ; // 3*sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) + dz[18] = 13.246445740605839f*xy*z ; // 21*sqrt(5)*xy*z/(2*sqrt(pi)) + dz[19] = 2.0071396306718676f*y*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*y*(1 - 7*z2)/(8*sqrt(pi)) + dz[20] = 14.809976568128603f*pow(z, 3) - 6.3471328149122579f*z ; // (105*z**3 - 45*z)/(4*sqrt(pi)) + dz[21] = 2.0071396306718676f*x*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*x*(1 - 7*z2)/(8*sqrt(pi)) + dz[22] = 6.6232228703029197f*z*(x2 - y2) ; // 21*sqrt(5)*z*(x2 - y2)/(4*sqrt(pi)) + dz[23] = 1.7701307697799304f*x*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) + dz[24] = 0.0f ; // 0 + if (C <= 5) { return; } + dz[25] = 0.0f ; // 0 + dz[26] = 8.3026492595241645f*xy*(x2 - y2) ; // 3*sqrt(385)*xy*(x2 - y2)/(4*sqrt(pi)) + dz[27] = 8.8062893898345074f*yz*(-3.0f*x2 + y2) ; // 9*sqrt(770)*yz*(-3*x2 + y2)/(16*sqrt(pi)) + dz[28] = 4.7935367849733241f*xy*(9.0f*z2 - 1.0f) ; // sqrt(1155)*xy*(9*z2 - 1)/(4*sqrt(pi)) + dz[29] = 12.682506233479513f*yz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*yz*(1 - 3*z2)/(4*sqrt(pi)) + dz[30] = -24.559567715218954f*z2 + 36.839351572828434f*z4 + 1.754254836801354f ; // 15*sqrt(11)*(-14*z2 + 21*z4 + 1)/(16*sqrt(pi)) + dz[31] = 12.682506233479513f*xz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*xz*(1 - 3*z2)/(4*sqrt(pi)) + dz[32] = 2.3967683924866621f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // sqrt(1155)*(x2 - y2)*(9*z2 - 1)/(8*sqrt(pi)) + dz[33] = 8.8062893898345074f*xz*(-x2 + 3.0f*y2) ; // 9*sqrt(770)*xz*(-x2 + 3*y2)/(16*sqrt(pi)) + dz[34] = -12.453973889286246f*x2*y2 + 2.0756623148810411f*x4 + 2.0756623148810411f*y4 ; // 3*sqrt(385)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + dz[35] = 0.0f ; // 0 + if (C <= 6) { return; } + dz[36] = 0.0f ; // 0 + dz[37] = 2.3666191622317521f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + dz[38] = 44.401711264127719f*xy*z*(x2 - y2) ; // 33*sqrt(91)*xy*z*(x2 - y2)/(4*sqrt(pi)) + dz[39] = -2.7636157785447706f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*y*(3*x2 - y2)*(11*z2 - 1)/(32*sqrt(pi)) + dz[40] = 11.054463114179082f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(8*sqrt(pi)) + dz[41] = 2.9131068125936568f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*y*(18*z2 - 33*z4 - 1)/(16*sqrt(pi)) + dz[42] = 2.6699064952403937f*z*(-30.0f*z2 + 33.0f*z4 + 5.0f) ; // 21*sqrt(13)*z*(-30*z2 + 33*z4 + 5)/(16*sqrt(pi)) + dz[43] = 2.9131068125936568f*x*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*x*(18*z2 - 33*z4 - 1)/(16*sqrt(pi)) + dz[44] = 5.5272315570895412f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(16*sqrt(pi)) + dz[45] = -2.7636157785447706f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*x*(x2 - 3*y2)*(11*z2 - 1)/(32*sqrt(pi)) + dz[46] = 11.10042781603193f*z*(-6.0f*x2*y2 + x4 + y4) ; // 33*sqrt(91)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + dz[47] = 2.3666191622317521f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + dz[48] = 0.0f ; // 0 + if (C <= 7) { return; } + dz[49] = 0.0f ; // 0 + dz[50] = 5.2919213236038001f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + dz[51] = 13.491805046726766f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 39*sqrt(385)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + dz[52] = 12.453973889286248f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 9*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(8*sqrt(pi)) + dz[53] = -6.8841930899409371f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*yz*(3*x2 - y2)*(13*z2 - 3)/(16*sqrt(pi)) + dz[54] = 2.2126634622249131f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 15*sqrt(70)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi)) + dz[55] = 1.6259689364853116f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) + dz[56] = 64.528641681844675f*z2 - 236.60501950009714f*z4 + 205.05768356675085f*z6 - 2.3899496919201733f ; // 7*sqrt(15)*(135*z2 - 495*z4 + 429*z6 - 5)/(32*sqrt(pi)) + dz[57] = 1.6259689364853116f*xz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*xz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) + dz[58] = 0.07375544874083044f*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) + 132.0f*z2*(13.0f*z2 - 5.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*(x2 - y2)*(143*z2*(3*z2 - 1) + 132*z2*(13*z2 - 5) - 187*z2 + 45)/(64*sqrt(pi)) + dz[59] = -6.8841930899409371f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*xz*(x2 - 3*y2)*(13*z2 - 3)/(16*sqrt(pi)) + dz[60] = 3.1134934723215619f*(13.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 9*sqrt(385)*(13*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + dz[61] = 13.491805046726766f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 39*sqrt(385)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + dz[62] = 39.6894099270285f*x2*y4 - 39.6894099270285f*x4*y2 + 2.6459606618019f*x6 - 2.6459606618019f*y6 ; // 3*sqrt(10010)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + dz[63] = 0.0f ; // 0 + }; + write_sh_dx(); + write_sh_dy(); + write_sh_dz(); + } +} + + +template +__global__ void kernel_sh_backward( + const scalar_t * __restrict__ grad, + const scalar_t * __restrict__ inputs, + uint32_t B, uint32_t D, uint32_t C, + const scalar_t * __restrict__ dy_dx, + scalar_t * grad_inputs +) { + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + const uint32_t b = t / D; + if (b >= B) return; + + const uint32_t d = t - b * D; + const uint32_t C2 = C * C; + + // locate + grad += b * C2; + dy_dx += b * D * C2 + d * C2; + + for (int ch = 0; ch < C2; ch++) { + grad_inputs[t] += grad[ch] * dy_dx[ch]; + //printf("t=%d, b=%d, d=%d, ch=%d, grad=%f (+= %f * %f)\n", t, b, d, ch, grad_inputs[t], grad[ch], dy_dx[ch]); + } + +} + +// inputs: [B, D], float, in [0, 1] +// outputs: [B, L * C], float +template +void sh_encode_forward_cuda(const scalar_t *inputs, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx) { + static constexpr uint32_t N_THREADS = 256; + kernel_sh<<>>(inputs, outputs, B, D, C, dy_dx); +} + + +template +void sh_encode_backward_cuda(const scalar_t *grad, const scalar_t *inputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx, scalar_t *grad_inputs) { + static constexpr uint32_t N_THREADS = 256; + kernel_sh_backward<<>>(grad, inputs, B, D, C, dy_dx, grad_inputs); +} + + +void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional dy_dx) { + CHECK_CUDA(inputs); + CHECK_CUDA(outputs); + // CHECK_CUDA(dy_dx); + + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(outputs); + // CHECK_CONTIGUOUS(dy_dx); + + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(outputs); + // CHECK_IS_FLOATING(dy_dx); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + inputs.scalar_type(), "sh_encode_forward_cuda", ([&] { + sh_encode_forward_cuda(inputs.data_ptr(), outputs.data_ptr(), B, D, C, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr); + })); +} + +void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs) { + CHECK_CUDA(grad); + CHECK_CUDA(inputs); + CHECK_CUDA(dy_dx); + CHECK_CUDA(grad_inputs); + + CHECK_CONTIGUOUS(grad); + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(dy_dx); + CHECK_CONTIGUOUS(grad_inputs); + + CHECK_IS_FLOATING(grad); + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(dy_dx); + CHECK_IS_FLOATING(grad_inputs); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "sh_encode_backward_cuda", ([&] { + sh_encode_backward_cuda(grad.data_ptr(), inputs.data_ptr(), B, D, C, dy_dx.data_ptr(), grad_inputs.data_ptr()); + })); +} \ No newline at end of file diff --git a/torch-ngp/shencoder/src/shencoder.h b/torch-ngp/shencoder/src/shencoder.h new file mode 100644 index 0000000000000000000000000000000000000000..f9e89facb848940d87c0dec390d991ab2c5d1dd9 --- /dev/null +++ b/torch-ngp/shencoder/src/shencoder.h @@ -0,0 +1,10 @@ +# pragma once + +#include +#include + +// inputs: [B, D], float, in [-1, 1] +// outputs: [B, F], float + +void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional dy_dx); +void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs); \ No newline at end of file diff --git a/torch-ngp/tensoRF/network.py b/torch-ngp/tensoRF/network.py new file mode 100644 index 0000000000000000000000000000000000000000..d577ba03d9d9938375cdc77ae6d4a40a599a7385 --- /dev/null +++ b/torch-ngp/tensoRF/network.py @@ -0,0 +1,335 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np + +from encoding import get_encoder +from activation import trunc_exp +from nerf.renderer import NeRFRenderer +import raymarching + + +class NeRFNetwork(NeRFRenderer): + def __init__(self, + resolution=[128] * 3, + sigma_rank=[16] * 3, + color_rank=[48] * 3, + bg_resolution=[512, 512], + bg_rank=8, + color_feat_dim=27, + num_layers=3, + hidden_dim=128, + num_layers_bg=2, + hidden_dim_bg=64, + bound=1, + **kwargs + ): + super().__init__(bound, **kwargs) + + self.resolution = resolution + + # vector-matrix decomposition + self.sigma_rank = sigma_rank + self.color_rank = color_rank + self.color_feat_dim = color_feat_dim + + self.mat_ids = [[0, 1], [0, 2], [1, 2]] + self.vec_ids = [2, 1, 0] + + self.sigma_mat, self.sigma_vec = self.init_one_svd(self.sigma_rank, self.resolution) + self.color_mat, self.color_vec = self.init_one_svd(self.color_rank, self.resolution) + self.basis_mat = nn.Linear(sum(self.color_rank), self.color_feat_dim, bias=False) + + # render module (default to freq feat + freq dir) + self.num_layers = num_layers + self.hidden_dim = hidden_dim + + self.encoder, enc_dim = get_encoder('frequency', input_dim=color_feat_dim, multires=2) + self.encoder_dir, enc_dim_dir = get_encoder('frequency', input_dim=3, multires=2) + + self.in_dim = enc_dim + enc_dim_dir + + color_net = [] + for l in range(num_layers): + if l == 0: + in_dim = self.in_dim + else: + in_dim = self.hidden_dim + + if l == num_layers - 1: + out_dim = 3 # rgb + else: + out_dim = self.hidden_dim + + color_net.append(nn.Linear(in_dim, out_dim, bias=False)) + + self.color_net = nn.ModuleList(color_net) + + # background model + if self.bg_radius > 0: + self.num_layers_bg = num_layers_bg + self.hidden_dim_bg = hidden_dim_bg + + # TODO: just use a matrix to model the background, no need of factorization. + #self.encoder_bg, self.in_dim_bg = get_encoder('hashgrid', input_dim=2, num_levels=4, log2_hashmap_size=18) # much smaller hashgrid + self.bg_resolution = bg_resolution + self.bg_rank = bg_rank + self.bg_mat = nn.Parameter(0.1 * torch.randn((1, bg_rank, bg_resolution[0], bg_resolution[1]))) # [1, R, H, W] + + bg_net = [] + for l in range(num_layers_bg): + if l == 0: + in_dim = bg_rank + enc_dim_dir + else: + in_dim = hidden_dim_bg + + if l == num_layers_bg - 1: + out_dim = 3 # 3 rgb + else: + out_dim = hidden_dim_bg + + bg_net.append(nn.Linear(in_dim, out_dim, bias=False)) + + self.bg_net = nn.ModuleList(bg_net) + else: + self.bg_net = None + + + def init_one_svd(self, n_component, resolution, scale=0.1): + + mat, vec = [], [] + + for i in range(len(self.vec_ids)): + vec_id = self.vec_ids[i] + mat_id_0, mat_id_1 = self.mat_ids[i] + mat.append(nn.Parameter(scale * torch.randn((1, n_component[i], resolution[mat_id_1], resolution[mat_id_0])))) # [1, R, H, W] + vec.append(nn.Parameter(scale * torch.randn((1, n_component[i], resolution[vec_id], 1)))) # [1, R, D, 1] (fake 2d to use grid_sample) + + return nn.ParameterList(mat), nn.ParameterList(vec) + + + def get_sigma_feat(self, x): + # x: [N, 3], in [-1, 1] (outliers will be treated as zero due to grid_sample padding mode) + + N = x.shape[0] + + # plane + line basis + mat_coord = torch.stack((x[..., self.mat_ids[0]], x[..., self.mat_ids[1]], x[..., self.mat_ids[2]])).view(3, -1, 1, 2) # [3, N, 1, 2] + vec_coord = torch.stack((x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])) + vec_coord = torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1).view(3, -1, 1, 2) # [3, N, 1, 2], fake 2d coord + + sigma_feat = torch.zeros([N,], device=x.device) + + for i in range(len(self.sigma_mat)): + mat_feat = F.grid_sample(self.sigma_mat[i], mat_coord[[i]], align_corners=True).view(-1, N) # [1, R, N, 1] --> [R, N] + vec_feat = F.grid_sample(self.sigma_vec[i], vec_coord[[i]], align_corners=True).view(-1, N) # [R, N] + sigma_feat = sigma_feat + torch.sum(mat_feat * vec_feat, dim=0) + + return sigma_feat + + + def get_color_feat(self, x): + # x: [N, 3], in [-1, 1] + + N = x.shape[0] + + # plane + line basis + mat_coord = torch.stack((x[..., self.mat_ids[0]], x[..., self.mat_ids[1]], x[..., self.mat_ids[2]])).view(3, -1, 1, 2) # [3, N, 1, 2] + vec_coord = torch.stack((x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])) + vec_coord = torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1).view(3, -1, 1, 2) # [3, N, 1, 2], fake 2d coord + + mat_feat, vec_feat = [], [] + + for i in range(len(self.color_mat)): + mat_feat.append(F.grid_sample(self.color_mat[i], mat_coord[[i]], align_corners=True).view(-1, N)) # [1, R, N, 1] --> [R, N] + vec_feat.append(F.grid_sample(self.color_vec[i], vec_coord[[i]], align_corners=True).view(-1, N)) # [R, N] + + mat_feat = torch.cat(mat_feat, dim=0) # [3 * R, N] + vec_feat = torch.cat(vec_feat, dim=0) # [3 * R, N] + + color_feat = self.basis_mat((mat_feat * vec_feat).T) # [N, 3R] --> [N, color_feat_dim] + + return color_feat + + + def forward(self, x, d): + # x: [N, 3], in [-bound, bound] + # d: [N, 3], nomalized in [-1, 1] + + # normalize to [-1, 1] inside aabb_train + x = 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1 + + # sigma + sigma_feat = self.get_sigma_feat(x) + sigma = trunc_exp(sigma_feat) + #sigma = F.softplus(sigma_feat - 3) + #sigma = F.relu(sigma_feat) + + # rgb + color_feat = self.get_color_feat(x) + enc_color_feat = self.encoder(color_feat) + enc_d = self.encoder_dir(d) + + h = torch.cat([enc_color_feat, enc_d], dim=-1) + for l in range(self.num_layers): + h = self.color_net[l](h) + if l != self.num_layers - 1: + h = F.relu(h, inplace=True) + + # sigmoid activation for rgb + rgb = torch.sigmoid(h) + + return sigma, rgb + + + def density(self, x): + # x: [N, 3], in [-bound, bound] + + # normalize to [-1, 1] inside aabb_train + x = 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1 + + sigma_feat = self.get_sigma_feat(x) + sigma = trunc_exp(sigma_feat) + #sigma = F.softplus(sigma_feat - 3) + #sigma = F.relu(sigma_feat) + + return { + 'sigma': sigma, + } + + def background(self, x, d): + # x: [N, 2] in [-1, 1] + + N = x.shape[0] + + h = F.grid_sample(self.bg_mat, x.view(1, N, 1, 2), align_corners=True).view(-1, N).T.contiguous() # [R, N] --> [N, R] + d = self.encoder_dir(d) + + h = torch.cat([d, h], dim=-1) + for l in range(self.num_layers_bg): + h = self.bg_net[l](h) + if l != self.num_layers_bg - 1: + h = F.relu(h, inplace=True) + + # sigmoid activation for rgb + rgbs = torch.sigmoid(h) + + return rgbs + + + # allow masked inference + def color(self, x, d, mask=None, **kwargs): + # x: [N, 3] in [-bound, bound] + # mask: [N,], bool, indicates where we actually needs to compute rgb. + + # normalize to [-1, 1] inside aabb_train + x = 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1 + + if mask is not None: + rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] + # in case of empty mask + if not mask.any(): + return rgbs + x = x[mask] + d = d[mask] + + color_feat = self.get_color_feat(x) + color_feat = self.encoder(color_feat) + d = self.encoder_dir(d) + + h = torch.cat([color_feat, d], dim=-1) + for l in range(self.num_layers): + h = self.color_net[l](h) + if l != self.num_layers - 1: + h = F.relu(h, inplace=True) + + # sigmoid activation for rgb + h = torch.sigmoid(h) + + if mask is not None: + rgbs[mask] = h.to(rgbs.dtype) + else: + rgbs = h + + return rgbs + + + # L1 penalty for loss + def density_loss(self): + loss = 0 + for i in range(len(self.sigma_mat)): + loss = loss + torch.mean(torch.abs(self.sigma_mat[i])) + torch.mean(torch.abs(self.sigma_vec[i])) + return loss + + # upsample utils + @torch.no_grad() + def upsample_params(self, mat, vec, resolution): + + for i in range(len(self.vec_ids)): + vec_id = self.vec_ids[i] + mat_id_0, mat_id_1 = self.mat_ids[i] + mat[i] = nn.Parameter(F.interpolate(mat[i].data, size=(resolution[mat_id_1], resolution[mat_id_0]), mode='bilinear', align_corners=True)) + vec[i] = nn.Parameter(F.interpolate(vec[i].data, size=(resolution[vec_id], 1), mode='bilinear', align_corners=True)) + + + @torch.no_grad() + def upsample_model(self, resolution): + self.upsample_params(self.sigma_mat, self.sigma_vec, resolution) + self.upsample_params(self.color_mat, self.color_vec, resolution) + self.resolution = resolution + + @torch.no_grad() + def shrink_model(self): + # shrink aabb_train and the model so it only represents the space inside aabb_train. + + half_grid_size = self.bound / self.grid_size + thresh = min(self.density_thresh, self.mean_density) + + # get new aabb from the coarsest density grid (TODO: from the finest that covers current aabb?) + valid_grid = self.density_grid[self.cascade - 1] > thresh # [N] + valid_pos = raymarching.morton3D_invert(torch.nonzero(valid_grid)) # [Nz] --> [Nz, 3], in [0, H - 1] + #plot_pointcloud(valid_pos.detach().cpu().numpy()) # lots of noisy outliers in hashnerf... + valid_pos = (2 * valid_pos / (self.grid_size - 1) - 1) * (self.bound - half_grid_size) # [Nz, 3], in [-b+hgs, b-hgs] + min_pos = valid_pos.amin(0) - half_grid_size # [3] + max_pos = valid_pos.amax(0) + half_grid_size # [3] + + # shrink model + reso = torch.LongTensor(self.resolution).to(self.aabb_train.device) + units = (self.aabb_train[3:] - self.aabb_train[:3]) / reso + tl = (min_pos - self.aabb_train[:3]) / units + br = (max_pos - self.aabb_train[:3]) / units + tl = torch.round(tl).long().clamp(min=0) + br = torch.minimum(torch.round(br).long(), reso) + + for i in range(len(self.vec_ids)): + vec_id = self.vec_ids[i] + mat_id_0, mat_id_1 = self.mat_ids[i] + + self.sigma_vec[i] = nn.Parameter(self.sigma_vec[i].data[..., tl[vec_id]:br[vec_id], :]) + self.color_vec[i] = nn.Parameter(self.color_vec[i].data[..., tl[vec_id]:br[vec_id], :]) + + self.sigma_mat[i] = nn.Parameter(self.sigma_mat[i].data[..., tl[mat_id_1]:br[mat_id_1], tl[mat_id_0]:br[mat_id_0]]) + self.color_mat[i] = nn.Parameter(self.color_mat[i].data[..., tl[mat_id_1]:br[mat_id_1], tl[mat_id_0]:br[mat_id_0]]) + + self.aabb_train = torch.cat([min_pos, max_pos], dim=0) # [6] + + print(f'[INFO] shrink slice: {tl.cpu().numpy().tolist()} - {br.cpu().numpy().tolist()}') + print(f'[INFO] new aabb: {self.aabb_train.cpu().numpy().tolist()}') + + + # optimizer utils + def get_params(self, lr1, lr2): + params = [ + {'params': self.sigma_mat, 'lr': lr1}, + {'params': self.sigma_vec, 'lr': lr1}, + {'params': self.color_mat, 'lr': lr1}, + {'params': self.color_vec, 'lr': lr1}, + {'params': self.basis_mat.parameters(), 'lr': lr2}, + {'params': self.color_net.parameters(), 'lr': lr2}, + ] + if self.bg_radius > 0: + params.append({'params': self.bg_mat, 'lr': lr1}) + params.append({'params': self.bg_net.parameters(), 'lr': lr2}) + return params + \ No newline at end of file diff --git a/torch-ngp/tensoRF/network_cc.py b/torch-ngp/tensoRF/network_cc.py new file mode 100644 index 0000000000000000000000000000000000000000..b752901ed5d569414a240c50efdd9aec8402b2a6 --- /dev/null +++ b/torch-ngp/tensoRF/network_cc.py @@ -0,0 +1,643 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np + +from encoding import get_encoder +from activation import trunc_exp +from nerf.renderer import NeRFRenderer +import raymarching + + +class NeRFNetwork(NeRFRenderer): + def __init__(self, + resolution=[128] * 3, + degree=4, + # rank_vec_density=[64], + # rank_mat_density=[16], + # rank_vec=[64], + # rank_mat=[64], + rank_vec_density=[64, 64, 64, 64, 64], + rank_mat_density=[0, 4, 8, 12, 16], + rank_vec=[64, 64, 64, 64, 64], + rank_mat=[0, 4, 16, 32, 64], + bg_resolution=[512, 512], + bg_rank=8, + bound=1, + **kwargs + ): + super().__init__(bound, **kwargs) + + self.resolution = resolution + + self.degree = degree + self.encoder_dir, self.enc_dir_dim = get_encoder('sphere_harmonics', degree=self.degree) + self.out_dim = 3 * self.enc_dir_dim # only color dim + + # group list in list for composition + self.rank_vec_density = [rank_vec_density] + self.rank_mat_density = [rank_mat_density] + self.rank_vec = [rank_vec] + self.rank_mat = [rank_mat] + + # all components are divided into K groups + assert len(rank_vec) == len(rank_mat) == len(rank_vec_density) == len(rank_mat_density) + + self.K = [len(rank_vec)] + + # utility + self.group_vec_density = [np.diff(rank_vec_density, prepend=0)] + self.group_mat_density = [np.diff(rank_mat_density, prepend=0)] + self.group_vec = [np.diff(rank_vec, prepend=0)] + self.group_mat = [np.diff(rank_mat, prepend=0)] + + self.mat_ids = [[0, 1], [0, 2], [1, 2]] + self.vec_ids = [2, 1, 0] + + # allocate params + + self.U_vec_density = nn.ParameterList() + self.S_vec_density = nn.ParameterList() + + for k in range(self.K[0]): + if self.group_vec_density[0][k] > 0: + for i in range(3): + vec_id = self.vec_ids[i] + w = torch.randn(self.group_vec_density[0][k], self.resolution[vec_id]) * 0.2 # [R, H] + self.U_vec_density.append(nn.Parameter(w.view(1, self.group_vec_density[0][k], self.resolution[vec_id], 1))) # [1, R, H, 1] + w = torch.ones(1, self.group_vec_density[0][k]) + torch.nn.init.kaiming_normal_(w) + self.S_vec_density.append(nn.Parameter(w)) + + self.U_mat_density = nn.ParameterList() + self.S_mat_density = nn.ParameterList() + + + for k in range(self.K[0]): + if self.group_mat_density[0][k] > 0: + for i in range(3): + mat_id_0, mat_id_1 = self.mat_ids[i] + w = torch.randn(self.group_mat_density[0][k], self.resolution[mat_id_1] * self.resolution[mat_id_0]) * 0.2 # [R, HW] + self.U_mat_density.append(nn.Parameter(w.view(1, self.group_mat_density[0][k], self.resolution[mat_id_1], self.resolution[mat_id_0]))) # [1, R, H, W] + w = torch.ones(1, self.group_mat_density[0][k]) + torch.nn.init.kaiming_normal_(w) + self.S_mat_density.append(nn.Parameter(w)) + + self.U_vec = nn.ParameterList() + self.S_vec = nn.ParameterList() + + for k in range(self.K[0]): + if self.group_vec[0][k] > 0: + for i in range(3): + vec_id = self.vec_ids[i] + w = torch.randn(self.group_vec[0][k], self.resolution[vec_id]) * 0.2 # [R, H] + self.U_vec.append(nn.Parameter(w.view(1, self.group_vec[0][k], self.resolution[vec_id], 1))) # [1, R, H, 1] + w = torch.ones(self.out_dim, self.group_vec[0][k]) + torch.nn.init.kaiming_normal_(w) + self.S_vec.append(nn.Parameter(w)) + + self.U_mat = nn.ParameterList() + self.S_mat = nn.ParameterList() + + for k in range(self.K[0]): + if self.group_mat[0][k] > 0: + for i in range(3): + mat_id_0, mat_id_1 = self.mat_ids[i] + w = torch.randn(self.group_mat[0][k], self.resolution[mat_id_1] * self.resolution[mat_id_0]) * 0.2 # [R, HW] + self.U_mat.append(nn.Parameter(w.view(1, self.group_mat[0][k], self.resolution[mat_id_1], self.resolution[mat_id_0]))) # [1, R, H, W] + w = torch.ones(self.out_dim, self.group_mat[0][k]) + torch.nn.init.kaiming_normal_(w) + self.S_mat.append(nn.Parameter(w)) + + # flag + self.finalized = False if self.K[0] != 1 else True + + # background model + if self.bg_radius > 0: + + self.bg_resolution = bg_resolution + self.bg_rank = bg_rank + self.bg_mat = nn.Parameter(0.2 * torch.randn((1, bg_rank, bg_resolution[0], bg_resolution[1]))) # [1, R, H, W] + + w = torch.ones(self.out_dim, bg_rank) # just color + torch.nn.init.kaiming_normal_(w) + self.bg_S = nn.Parameter(w) + + + def compute_features_density(self, x, K=-1, residual=False, oid=0): + # x: [N, 3], in [-1, 1] + # return: [K, N, out_dim] + + prefix = x.shape[:-1] + N = np.prod(prefix) + + vec_coord = torch.stack((x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])) + vec_coord = torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1).view(3, -1, 1, 2) + + mat_coord = torch.stack((x[..., self.mat_ids[0]], x[..., self.mat_ids[1]], x[..., self.mat_ids[2]])).view(3, -1, 1, 2) # [3, N, 1, 2] + + # calculate first K blocks + if K <= 0: + K = self.K[oid] + + # loop all blocks + if residual: + outputs = [] + + last_y = None + + offset_vec = oid + offset_mat = oid + + for k in range(K): + + y = 0 + + if self.group_vec_density[oid][k]: + vec_feat = F.grid_sample(self.U_vec_density[3 * offset_vec + 0], vec_coord[[0]], align_corners=False).view(-1, N) * \ + F.grid_sample(self.U_vec_density[3 * offset_vec + 1], vec_coord[[1]], align_corners=False).view(-1, N) * \ + F.grid_sample(self.U_vec_density[3 * offset_vec + 2], vec_coord[[2]], align_corners=False).view(-1, N) # [r, N] + + y = y + (self.S_vec_density[offset_vec] @ vec_feat) + + offset_vec += 1 + + if self.group_mat_density[oid][k]: + mat_feat = F.grid_sample(self.U_mat_density[3 * offset_mat + 0], mat_coord[[0]], align_corners=False).view(-1, N) * \ + F.grid_sample(self.U_mat_density[3 * offset_mat + 1], mat_coord[[1]], align_corners=False).view(-1, N) * \ + F.grid_sample(self.U_mat_density[3 * offset_mat + 2], mat_coord[[2]], align_corners=False).view(-1, N) # [r, N] + + y = y + (self.S_mat_density[offset_mat] @ mat_feat) # [out_dim, N] + + offset_mat += 1 + + if last_y is not None: + y = y + last_y + + if residual: + outputs.append(y) + + last_y = y + + if residual: + outputs = torch.stack(outputs, dim=0).permute(0, 2, 1).contiguous().view(K, *prefix, -1) # [K, out_dim, N] --> [K, N, out_dim] + else: + outputs = last_y.permute(1, 0).contiguous().view(*prefix, -1) # [out_dim, N] --> [N, out_dim] + + return outputs + + def compute_features(self, x, K=-1, residual=False, oid=0): + # x: [N, 3], in [-1, 1] + # return: [K, N, out_dim] + + prefix = x.shape[:-1] + N = np.prod(prefix) + + vec_coord = torch.stack((x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])) + vec_coord = torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1).view(3, -1, 1, 2) + + mat_coord = torch.stack((x[..., self.mat_ids[0]], x[..., self.mat_ids[1]], x[..., self.mat_ids[2]])).view(3, -1, 1, 2) # [3, N, 1, 2] + + # calculate first K blocks + if K <= 0: + K = self.K[oid] + + # loop all blocks + if residual: + outputs = [] + + last_y = None + + offset_vec = oid + offset_mat = oid + + for k in range(K): + + y = 0 + + if self.group_vec[oid][k]: + vec_feat = F.grid_sample(self.U_vec[3 * offset_vec + 0], vec_coord[[0]], align_corners=False).view(-1, N) * \ + F.grid_sample(self.U_vec[3 * offset_vec + 1], vec_coord[[1]], align_corners=False).view(-1, N) * \ + F.grid_sample(self.U_vec[3 * offset_vec + 2], vec_coord[[2]], align_corners=False).view(-1, N) # [r, N] + + y = y + (self.S_vec[offset_vec] @ vec_feat) + + offset_vec += 1 + + if self.group_mat[oid][k]: + mat_feat = F.grid_sample(self.U_mat[3 * offset_mat + 0], mat_coord[[0]], align_corners=False).view(-1, N) * \ + F.grid_sample(self.U_mat[3 * offset_mat + 1], mat_coord[[1]], align_corners=False).view(-1, N) * \ + F.grid_sample(self.U_mat[3 * offset_mat + 2], mat_coord[[2]], align_corners=False).view(-1, N) # [r, N] + + y = y + (self.S_mat[offset_mat] @ mat_feat) # [out_dim, N] + + offset_mat += 1 + + if last_y is not None: + y = y + last_y + + if residual: + outputs.append(y) + + last_y = y + + if residual: + outputs = torch.stack(outputs, dim=0).permute(0, 2, 1).contiguous().view(K, *prefix, -1) # [K, out_dim, N] --> [K, N, out_dim] + else: + outputs = last_y.permute(1, 0).contiguous().view(*prefix, -1) # [out_dim, N] --> [N, out_dim] + + return outputs + + + def normalize_coord(self, x, oid=0): + + if oid == 0: + aabb = self.aabb_train + else: + tr = getattr(self, f'T_{oid}') # [4, 4] transformation matrix + x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1) # to homo + x = (x @ tr.T)[:, :3] # [N, 4] --> [N, 3] + + aabb = getattr(self, f'aabb_{oid}') + + return 2 * (x - aabb[:3]) / (aabb[3:] - aabb[:3]) - 1 # [-1, 1] in bbox + + + def normalize_dir(self, d, oid=0): + if oid != 0: + tr = getattr(self, f'R_{oid}') # [3, 3] rotation matrix + d = d @ tr.T + return d + + + def forward(self, x, d, K=-1): + # x: [N, 3], in [-bound, bound] + # d: [N, 3], nomalized in [-1, 1] + + N = x.shape[0] + + # single object + if len(self.K) == 1: + + x_model = self.normalize_coord(x) + feats_density = self.compute_features_density(x_model, K, residual=self.training) # [K, N, 1] + sigma = trunc_exp(feats_density).squeeze(-1) # [K, N] + + enc_d = self.encoder_dir(d) # [N, C] + + h = self.compute_features(x_model, K, residual=self.training) # [K, N, 3C] + h = h.view(K, N, 3, self.degree ** 2) # [K, N, 3, C] + h = (h * enc_d.unsqueeze(1)).sum(-1) # [K, N, 3] + + rgb = torch.sigmoid(h) # [K, N, 3] + + return sigma, rgb + + # multi-object (composed scene), do not support rank-residual training for now. + else: + + sigma_list = [] + h_list = [] + + sigma_all = 0 + rgb_all = 0 + + + for oid in range(1, len(self.K)): + x_model = self.normalize_coord(x, oid=oid) + + feats_density = self.compute_features_density(x_model, -1, residual=False, oid=oid) # [N, 1] + + sigma = trunc_exp(feats_density).squeeze(-1) # [N] + sigma_list.append(sigma.detach().clone()) + + sigma_all += sigma + + d_model = self.normalize_dir(d, oid=oid) + enc_d = self.encoder_dir(d_model) # [N, C] + + h = self.compute_features(x_model, -1, residual=False, oid=oid) # [N, 3C] + h = h.view(N, 3, self.degree ** 2) + h = (h * enc_d.unsqueeze(1)).sum(-1) # [N, 3] + + h_list.append(h) + + + ws = torch.stack(sigma_list, dim=0) # [O, N] + ws = F.softmax(ws, dim=0) + + for oid in range(1, len(self.K)): + rgb_all += h_list[oid - 1] * ws[oid - 1].unsqueeze(-1) + + rgb_all = torch.sigmoid(rgb_all) + + return sigma_all, rgb_all + + + def density(self, x, K=-1): + # x: [N, 3], in [-bound, bound] + + if len(self.K) == 1: + + x_model = self.normalize_coord(x) + feats_density = self.compute_features_density(x_model, K, residual=False) # [N, 1 + 3C] + sigma = trunc_exp(feats_density).squeeze(-1) # [N] + + return { + 'sigma': sigma, + } + + else: + + sigma_all = 0 + for oid in range(1, len(self.K)): + x_model = self.normalize_coord(x, oid=oid) + feats_density = self.compute_features_density(x_model, -1, residual=False, oid=oid) # [N, 1] + sigma = trunc_exp(feats_density).squeeze(-1) # [N] + sigma_all += sigma + + return { + 'sigma': sigma_all, + } + + + def background(self, x, d): + # x: [N, 2] in [-1, 1] + + N = x.shape[0] + + h = F.grid_sample(self.bg_mat, x.view(1, N, 1, 2), align_corners=False).view(-1, N) # [R, N] + h = (self.bg_S @ h).T.contiguous() # [3C, N] --> [N, 3C] + enc_d = self.encoder_dir(d) + + h = h.view(N, 3, -1) + h = (h * enc_d.unsqueeze(1)).sum(-1) # [N, 3] + + # sigmoid activation for rgb + rgb = torch.sigmoid(h) + + return rgb + + + # L1 penalty for loss + def density_loss(self): + loss = 0 + for i in range(len(self.U_vec_density)): + loss = loss + torch.mean(torch.abs(self.U_vec_density[i])) + for i in range(len(self.U_mat_density)): + loss = loss + torch.mean(torch.abs(self.U_mat_density[i])) + return loss + + + # upsample utils + @torch.no_grad() + def upsample_model(self, resolution): + + for i in range(len(self.U_vec_density)): + vec_id = self.vec_ids[i % 3] + self.U_vec_density[i] = nn.Parameter(F.interpolate(self.U_vec_density[i].data, size=(resolution[vec_id], 1), mode='bilinear', align_corners=False)) + + for i in range(len(self.U_mat_density)): + mat_id_0, mat_id_1 = self.mat_ids[i % 3] + self.U_mat_density[i] = nn.Parameter(F.interpolate(self.U_mat_density[i].data, size=(resolution[mat_id_1], resolution[mat_id_0]), mode='bilinear', align_corners=False)) + + for i in range(len(self.U_vec)): + vec_id = self.vec_ids[i % 3] + self.U_vec[i] = nn.Parameter(F.interpolate(self.U_vec[i].data, size=(resolution[vec_id], 1), mode='bilinear', align_corners=False)) + + for i in range(len(self.U_mat)): + mat_id_0, mat_id_1 = self.mat_ids[i % 3] + self.U_mat[i] = nn.Parameter(F.interpolate(self.U_mat[i].data, size=(resolution[mat_id_1], resolution[mat_id_0]), mode='bilinear', align_corners=False)) + + self.resolution = resolution + + print(f'[INFO] upsampled to {resolution}') + + @torch.no_grad() + def shrink_model(self): + # shrink aabb_train and the model so it only represents the space inside aabb_train. + + half_grid_size = self.bound / self.grid_size + thresh = min(self.density_thresh, self.mean_density) + + # get new aabb from the coarsest density grid (TODO: from the finest that covers current aabb?) + valid_grid = self.density_grid[self.cascade - 1] > thresh # [N] + valid_pos = raymarching.morton3D_invert(torch.nonzero(valid_grid)) # [Nz] --> [Nz, 3], in [0, H - 1] + #plot_pointcloud(valid_pos.detach().cpu().numpy()) + valid_pos = (2 * valid_pos / (self.grid_size - 1) - 1) * (self.bound - half_grid_size) # [Nz, 3], in [-b+hgs, b-hgs] + min_pos = valid_pos.amin(0) - half_grid_size # [3] + max_pos = valid_pos.amax(0) + half_grid_size # [3] + + # shrink model + reso = torch.LongTensor(self.resolution).to(self.aabb_train.device) + units = (self.aabb_train[3:] - self.aabb_train[:3]) / reso + tl = (min_pos - self.aabb_train[:3]) / units + br = (max_pos - self.aabb_train[:3]) / units + tl = torch.round(tl).long().clamp(min=0) + br = torch.minimum(torch.round(br).long(), reso) + + for i in range(len(self.U_vec_density)): + vec_id = self.vec_ids[i % 3] + self.U_vec_density[i] = nn.Parameter(self.U_vec_density[i].data[..., tl[vec_id]:br[vec_id], :]) + + for i in range(len(self.U_mat_density)): + mat_id_0, mat_id_1 = self.mat_ids[i % 3] + self.U_mat_density[i] = nn.Parameter(self.U_mat_density[i].data[..., tl[mat_id_1]:br[mat_id_1], tl[mat_id_0]:br[mat_id_0]]) + + for i in range(len(self.U_vec)): + vec_id = self.vec_ids[i % 3] + self.U_vec[i] = nn.Parameter(self.U_vec[i].data[..., tl[vec_id]:br[vec_id], :]) + + for i in range(len(self.U_mat)): + mat_id_0, mat_id_1 = self.mat_ids[i % 3] + self.U_mat[i] = nn.Parameter(self.U_mat[i].data[..., tl[mat_id_1]:br[mat_id_1], tl[mat_id_0]:br[mat_id_0]]) + + self.aabb_train = torch.cat([min_pos, max_pos], dim=0) # [6] + + print(f'[INFO] shrink slice: {tl.cpu().numpy().tolist()} - {br.cpu().numpy().tolist()}') + print(f'[INFO] new aabb: {self.aabb_train.cpu().numpy().tolist()}') + + + @torch.no_grad() + def finalize_group(self, U, S): + + if len(U) == 0 or len(S) == 0: + return nn.ParameterList(), nn.ParameterList() + + # sort rank inside each group + for i in range(len(S)): + importance = S[i].abs().sum(0) # [C, R] --> [R] + for j in range(3): + importance *= U[3 * i + j].view(importance.shape[0], -1).norm(dim=-1) # [R, H] --> [R] + + inds = torch.argsort(importance, descending=True) # important first + + S[i] = nn.Parameter(S[i].data[:, inds]) + for j in range(3): + U[3 * i + j] = nn.Parameter(U[3 * i + j].data[:, inds]) + + # fuse rank across all groups + + S = nn.ParameterList([ + nn.Parameter(torch.cat([s.data for s in S], dim=1)) + ]) + + U = nn.ParameterList([ + nn.Parameter(torch.cat([v.data for v in U[0::3]], dim=1)), + nn.Parameter(torch.cat([v.data for v in U[1::3]], dim=1)), + nn.Parameter(torch.cat([v.data for v in U[2::3]], dim=1)), + ]) + + return U, S + + + # finalize model parameters (fuse all groups) for faster inference, but no longer allow rank-residual training. + @torch.no_grad() + def finalize(self): + self.U_vec_density, self.S_vec_density = self.finalize_group(self.U_vec_density, self.S_vec_density) + self.U_mat_density, self.S_mat_density = self.finalize_group(self.U_mat_density, self.S_mat_density) + self.U_vec, self.S_vec = self.finalize_group(self.U_vec, self.S_vec) + self.U_mat, self.S_mat = self.finalize_group(self.U_mat, self.S_mat) + + # update states + self.rank_vec_density[0] = [self.rank_vec_density[0][-1]] + self.rank_mat_density[0] = [self.rank_mat_density[0][-1]] + self.rank_vec[0] = [self.rank_vec[0][-1]] + self.rank_mat[0] = [self.rank_mat[0][-1]] + + self.group_vec_density[0] = self.rank_vec_density[0] + self.group_mat_density[0] = self.rank_mat_density[0] + self.group_vec[0] = self.rank_vec[0] + self.group_mat[0] = self.rank_mat[0] + + self.K[0] = 1 + + self.finalized = True + + + # assume finalized (sorted), simply slicing! + @torch.no_grad() + def compress_group(self, U, S, rank): + if rank == 0: + return nn.ParameterList(), nn.ParameterList() + S[0] = nn.Parameter(S[0].data[:, :rank].clone()) # clone is necessary, slicing won't change storage! + for i in range(3): + U[i] = nn.Parameter(U[i].data[:, :rank].clone()) + return U, S + + @torch.no_grad() + def compress(self, ranks): + # ranks: (density_vec, density_mat, color_vec, color_mat) + if not self.finalized: + self.finalize() + + self.U_vec_density, self.S_vec_density = self.compress_group(self.U_vec_density, self.S_vec_density, ranks[0]) + self.U_mat_density, self.S_mat_density = self.compress_group(self.U_mat_density, self.S_mat_density, ranks[1]) + self.U_vec, self.S_vec = self.compress_group(self.U_vec, self.S_vec, ranks[2]) + self.U_mat, self.S_mat = self.compress_group(self.U_mat, self.S_mat, ranks[3]) + + # update states + self.rank_vec_density[0] = [ranks[0]] + self.rank_mat_density[0] = [ranks[1]] + self.rank_vec[0] = [ranks[2]] + self.rank_mat[0] = [ranks[3]] + + self.group_vec_density[0] = self.rank_vec_density[0] + self.group_mat_density[0] = self.rank_mat_density[0] + self.group_vec[0] = self.rank_vec[0] + self.group_mat[0] = self.rank_mat[0] + + @torch.no_grad() + def compose(self, other, R=None, s=None, t=None): + if not self.finalized: + self.finalize() + if not other.finalized: + other.finalize() + + # parameters + self.U_vec_density.extend(other.U_vec_density) + self.S_vec_density.extend(other.S_vec_density) + + self.U_mat_density.extend(other.U_mat_density) + self.S_mat_density.extend(other.S_mat_density) + + self.U_vec.extend(other.U_vec) + self.S_vec.extend(other.S_vec) + + self.U_mat.extend(other.U_mat) + self.S_mat.extend(other.S_mat) + + # states + self.rank_vec_density.extend(other.rank_vec_density) + self.rank_mat_density.extend(other.rank_mat_density) + self.rank_vec.extend(other.rank_vec) + self.rank_mat.extend(other.rank_mat) + + self.group_vec_density.extend(other.group_vec_density) + self.group_mat_density.extend(other.group_mat_density) + self.group_vec.extend(other.group_vec) + self.group_mat.extend(other.group_mat) + + self.K.extend(other.K) + + # transforms + oid = len(self.K) - 1 + + # R: a [3, 3] rotation matrix in SO(3) + if R is None: + R = torch.eye(3, dtype=torch.float32) + elif isinstance(R, np.ndarray): + R = torch.from_numpy(R.astype(np.float32)) + else: # tensor + R = R.float() + + # s is a scalar scaling factor + if s is None: + s = 1 + + # t is a [3] translation vector + if t is None: + t = torch.zeros(3, dtype=torch.float32) + elif isinstance(t, np.ndarray): + t = torch.from_numpy(t.astype(np.float32)) + else: # tensor + t = t.float() + + # T: the [4, 4] transformation matrix + # first scale & rotate, then translate. + T = torch.eye(4, dtype=torch.float32) + T[:3, :3] = R * s + T[:3, 3] = t + + # T is the model matrix, but we want the matrix to transform rays, i.e., the inversion. + T = torch.inverse(T).to(self.aabb_train.device) + R = R.T.to(self.aabb_train.device) + + self.register_buffer(f'T_{oid}', T) + self.register_buffer(f'R_{oid}', R) + self.register_buffer(f'aabb_{oid}', other.aabb_train) + + # update density grid multiple times to make sure it is accurate + # TODO: 3 is very empirical... + for _ in range(3): + self.update_extra_state() + + + # optimizer utils + def get_params(self, lr1, lr2): + params = [ + {'params': self.U_vec_density, 'lr': lr1}, + {'params': self.S_vec_density, 'lr': lr2}, + {'params': self.U_mat_density, 'lr': lr1}, + {'params': self.S_mat_density, 'lr': lr2}, + {'params': self.U_vec, 'lr': lr1}, + {'params': self.S_vec, 'lr': lr2}, + {'params': self.U_mat, 'lr': lr1}, + {'params': self.S_mat, 'lr': lr2}, + ] + if self.bg_radius > 0: + params.append({'params': self.bg_mat, 'lr': lr1}) + params.append({'params': self.bg_S, 'lr': lr2}) + return params + \ No newline at end of file diff --git a/torch-ngp/tensoRF/network_cp.py b/torch-ngp/tensoRF/network_cp.py new file mode 100644 index 0000000000000000000000000000000000000000..423367e9db9eaf0df7f36af9724003375cb34cb1 --- /dev/null +++ b/torch-ngp/tensoRF/network_cp.py @@ -0,0 +1,256 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np + +from encoding import get_encoder +from activation import trunc_exp +from nerf.renderer import NeRFRenderer + +import raymarching + + +class NeRFNetwork(NeRFRenderer): + def __init__(self, + resolution=[128] * 3, + sigma_rank=[96] * 3, # ref: https://github.com/apchenstu/TensoRF/commit/7f505875a9f321fa8439a8d5c6a15fc7d2f17303 + color_rank=[288] * 3, + color_feat_dim=27, + num_layers=3, + hidden_dim=128, + bound=1, + **kwargs + ): + super().__init__(bound, **kwargs) + + self.resolution = resolution + + # vector-matrix decomposition + self.sigma_rank = sigma_rank + self.color_rank = color_rank + self.color_feat_dim = color_feat_dim + + self.mat_ids = [[0, 1], [0, 2], [1, 2]] + self.vec_ids = [2, 1, 0] + + self.sigma_vec = self.init_one_svd(self.sigma_rank, self.resolution) + self.color_vec = self.init_one_svd(self.color_rank, self.resolution) + self.basis_mat = nn.Linear(self.color_rank[0], self.color_feat_dim, bias=False) + + # render module (default to freq feat + freq dir) + self.num_layers = num_layers + self.hidden_dim = hidden_dim + + self.encoder, enc_dim = get_encoder('frequency', input_dim=color_feat_dim, multires=2) + self.encoder_dir, enc_dim_dir = get_encoder('frequency', input_dim=3, multires=2) + + self.in_dim = enc_dim + enc_dim_dir + + color_net = [] + for l in range(num_layers): + if l == 0: + in_dim = self.in_dim + else: + in_dim = self.hidden_dim + + if l == num_layers - 1: + out_dim = 3 # rgb + else: + out_dim = self.hidden_dim + + color_net.append(nn.Linear(in_dim, out_dim, bias=False)) + + self.color_net = nn.ModuleList(color_net) + + + def init_one_svd(self, n_component, resolution, scale=0.2): + + vec = [] + + for i in range(len(self.vec_ids)): + vec_id = self.vec_ids[i] + vec.append(torch.nn.Parameter(scale * torch.randn((1, n_component[i], resolution[vec_id], 1)))) # [1, R, D, 1] (fake 2d to use grid_sample) + + return torch.nn.ParameterList(vec) + + + def get_sigma_feat(self, x): + # x: [N, 3], in [-1, 1] + + N = x.shape[0] + + # line basis + vec_coord = torch.stack((x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])) + vec_coord = torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1).view(3, -1, 1, 2) # [3, N, 1, 2], fake 2d coord + + vec_feat = F.grid_sample(self.sigma_vec[0], vec_coord[[0]], align_corners=True).view(-1, N) * \ + F.grid_sample(self.sigma_vec[1], vec_coord[[1]], align_corners=True).view(-1, N) * \ + F.grid_sample(self.sigma_vec[2], vec_coord[[2]], align_corners=True).view(-1, N) # [R, N] + + sigma_feat = torch.sum(vec_feat, dim=0) + + return sigma_feat + + + def get_color_feat(self, x): + # x: [N, 3], in [-1, 1] + + N = x.shape[0] + + # line basis + vec_coord = torch.stack((x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])) + vec_coord = torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1).view(3, -1, 1, 2) # [3, N, 1, 2], fake 2d coord + + vec_feat = F.grid_sample(self.color_vec[0], vec_coord[[0]], align_corners=True).view(-1, N) * \ + F.grid_sample(self.color_vec[1], vec_coord[[1]], align_corners=True).view(-1, N) * \ + F.grid_sample(self.color_vec[2], vec_coord[[2]], align_corners=True).view(-1, N) # [R, N] + + color_feat = self.basis_mat(vec_feat.T) # [N, R] --> [N, color_feat_dim] + + return color_feat + + + def forward(self, x, d): + # x: [N, 3], in [-bound, bound] + # d: [N, 3], nomalized in [-1, 1] + + # normalize to [-1, 1] inside aabb_train + x = 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1 + + # sigma + sigma_feat = self.get_sigma_feat(x) + sigma = trunc_exp(sigma_feat) + + # rgb + color_feat = self.get_color_feat(x) + enc_color_feat = self.encoder(color_feat) + enc_d = self.encoder_dir(d) + + h = torch.cat([enc_color_feat, enc_d], dim=-1) + for l in range(self.num_layers): + h = self.color_net[l](h) + if l != self.num_layers - 1: + h = F.relu(h, inplace=True) + + # sigmoid activation for rgb + rgb = torch.sigmoid(h) + + return sigma, rgb + + + def density(self, x): + # x: [N, 3], in [-bound, bound] + + # normalize to [-1, 1] inside aabb_train + x = 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1 + + sigma_feat = self.get_sigma_feat(x) + sigma = trunc_exp(sigma_feat) + + return { + 'sigma': sigma, + } + + # allow masked inference + def color(self, x, d, mask=None, **kwargs): + # x: [N, 3] in [-bound, bound] + # mask: [N,], bool, indicates where we actually needs to compute rgb. + + # normalize to [-1, 1] inside aabb_train + x = 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1 + + if mask is not None: + rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] + # in case of empty mask + if not mask.any(): + return rgbs + x = x[mask] + d = d[mask] + + color_feat = self.get_color_feat(x) + color_feat = self.encoder(color_feat) + d = self.encoder_dir(d) + + h = torch.cat([color_feat, d], dim=-1) + for l in range(self.num_layers): + h = self.color_net[l](h) + if l != self.num_layers - 1: + h = F.relu(h, inplace=True) + + # sigmoid activation for rgb + h = torch.sigmoid(h) + + if mask is not None: + rgbs[mask] = h.to(rgbs.dtype) + else: + rgbs = h + + return rgbs + + + # L1 penalty for loss + def density_loss(self): + loss = 0 + for i in range(len(self.sigma_vec)): + loss = loss + torch.mean(torch.abs(self.sigma_vec[i])) + return loss + + # upsample utils + @torch.no_grad() + def upsample_params(self, vec, resolution): + + for i in range(len(self.vec_ids)): + vec_id = self.vec_ids[i] + vec[i] = torch.nn.Parameter(F.interpolate(vec[i].data, size=(resolution[vec_id], 1), mode='bilinear', align_corners=True)) + + + @torch.no_grad() + def upsample_model(self, resolution): + self.upsample_params(self.sigma_vec, resolution) + self.upsample_params(self.color_vec, resolution) + self.resolution = resolution + + @torch.no_grad() + def shrink_model(self): + + half_grid_size = self.bound / self.grid_size + thresh = min(self.density_thresh, self.mean_density) + + # get new aabb from the coarsest density grid (TODO: from the finest that covers current aabb?) + valid_grid = self.density_grid[self.cascade - 1] > thresh # [N] + valid_pos = raymarching.morton3D_invert(torch.nonzero(valid_grid)) # [Nz] --> [Nz, 3], in [0, H - 1] + + #plot_pointcloud(valid_pos.detach().cpu().numpy()) # lots of noisy outliers in hashnerf... + valid_pos = (2 * valid_pos / (self.grid_size - 1) - 1) * (self.bound - half_grid_size) # [Nz, 3], in [-b+hgs, b-hgs] + min_pos = valid_pos.amin(0) - half_grid_size # [3] + max_pos = valid_pos.amax(0) + half_grid_size # [3] + + # shrink model + reso = torch.LongTensor(self.resolution).to(self.aabb_train.device) + units = (self.aabb_train[3:] - self.aabb_train[:3]) / reso + tl = (min_pos - self.aabb_train[:3]) / units + br = (max_pos - self.aabb_train[:3]) / units + tl = torch.round(tl).long().clamp(min=0) + br = torch.minimum(torch.round(br).long(), reso) + + for i in range(len(self.vec_ids)): + vec_id = self.vec_ids[i] + + self.sigma_vec[i] = nn.Parameter(self.sigma_vec[i].data[..., tl[vec_id]:br[vec_id], :]) + self.color_vec[i] = nn.Parameter(self.color_vec[i].data[..., tl[vec_id]:br[vec_id], :]) + + self.aabb_train = torch.cat([min_pos, max_pos], dim=0) # [6] + + print(f'[INFO] shrink slice: {tl.cpu().numpy().tolist()} - {br.cpu().numpy().tolist()}') + print(f'[INFO] new aabb: {self.aabb_train.cpu().numpy().tolist()}') + + # optimizer utils + def get_params(self, lr1, lr2): + return [ + {'params': self.sigma_vec, 'lr': lr1}, + {'params': self.color_vec, 'lr': lr1}, + {'params': self.basis_mat.parameters(), 'lr': lr2}, + {'params': self.color_net.parameters(), 'lr': lr2}, + ] + \ No newline at end of file diff --git a/torch-ngp/tensoRF/utils.py b/torch-ngp/tensoRF/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..41de8e6468c66701068fa9ed294d2c41c66741f7 --- /dev/null +++ b/torch-ngp/tensoRF/utils.py @@ -0,0 +1,401 @@ +from nerf.utils import * +from nerf.utils import Trainer as _Trainer + +# for isinstance +from tensoRF.network_cc import NeRFNetwork as CCNeRF + + +class Trainer(_Trainer): + def __init__(self, + name, # name of this experiment + opt, # extra conf + model, # network + criterion=None, # loss function, if None, assume inline implementation in train_step + optimizer=None, # optimizer + ema_decay=None, # if use EMA, set the decay + lr_scheduler=None, # scheduler + metrics=[], # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric. + local_rank=0, # which GPU am I + world_size=1, # total num of GPUs + device=None, # device to use, usually setting to None is OK. (auto choose device) + mute=False, # whether to mute all print + fp16=False, # amp optimize level + eval_interval=1, # eval once every $ epoch + max_keep_ckpt=2, # max num of saved ckpts in disk + workspace='workspace', # workspace to save logs & ckpts + best_mode='min', # the smaller/larger result, the better + use_loss_as_metric=True, # use loss as the first metric + report_metric_at_train=False, # also report metrics at training + use_checkpoint="latest", # which ckpt to use at init time + use_tensorboardX=True, # whether to use tensorboard for logging + scheduler_update_every_step=False, # whether to call scheduler.step() after every train step + ): + + self.optimizer_fn = optimizer + self.lr_scheduler_fn = lr_scheduler + + super().__init__(name, opt, model, criterion, optimizer, ema_decay, lr_scheduler, metrics, local_rank, world_size, device, mute, fp16, eval_interval, max_keep_ckpt, workspace, best_mode, use_loss_as_metric, report_metric_at_train, use_checkpoint, use_tensorboardX, scheduler_update_every_step) + + ### ------------------------------ + + def train_step(self, data): + + pred_rgb, gt_rgb, loss = super().train_step(data) + + # l1 reg + loss += self.model.density_loss() * self.opt.l1_reg_weight + + return pred_rgb, gt_rgb, loss + + + def train_one_epoch(self, loader): + self.log(f"==> Start Training Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...") + + total_loss = 0 + if self.local_rank == 0 and self.report_metric_at_train: + for metric in self.metrics: + metric.clear() + + self.model.train() + + # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs + # ref: https://pytorch.org/docs/stable/data.html + if self.world_size > 1: + loader.sampler.set_epoch(self.epoch) + + if self.local_rank == 0: + pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') + + self.local_step = 0 + + for data in loader: + + # update grid every 16 steps + if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0: + with torch.cuda.amp.autocast(enabled=self.fp16): + self.model.update_extra_state() + + self.local_step += 1 + self.global_step += 1 + + self.optimizer.zero_grad() + + with torch.cuda.amp.autocast(enabled=self.fp16): + preds, truths, loss = self.train_step(data) + + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + + if self.scheduler_update_every_step: + self.lr_scheduler.step() + + loss_val = loss.item() + total_loss += loss_val + + if self.local_rank == 0: + if self.report_metric_at_train: + for metric in self.metrics: + metric.update(preds, truths) + + if self.use_tensorboardX: + self.writer.add_scalar("train/loss", loss_val, self.global_step) + self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], self.global_step) + + if self.scheduler_update_every_step: + pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}") + else: + pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})") + pbar.update(loader.batch_size) + + # Different from _Trainer! + if self.global_step in self.opt.upsample_model_steps: + + # shrink + if self.model.cuda_ray: # and self.global_step == self.opt.upsample_model_steps[0]: + self.model.shrink_model() + + # adaptive voxel size from aabb_train + n_vox = self.upsample_resolutions.pop(0) ** 3 # n_voxels + aabb = self.model.aabb_train.cpu().numpy() + vox_size = np.cbrt(np.prod(aabb[3:] - aabb[:3]) / n_vox) + reso = ((aabb[3:] - aabb[:3]) / vox_size).astype(np.int32).tolist() + + self.log(f"[INFO] upsample model at step {self.global_step} from {self.model.resolution} to {reso}") + self.model.upsample_model(reso) + + # reset optimizer since params changed. + self.optimizer = self.optimizer_fn(self.model) + self.lr_scheduler = self.lr_scheduler_fn(self.optimizer) + + if self.ema is not None: + self.ema.update() + + average_loss = total_loss / self.local_step + self.stats["loss"].append(average_loss) + + if self.local_rank == 0: + pbar.close() + if self.report_metric_at_train: + for metric in self.metrics: + self.log(metric.report(), style="red") + if self.use_tensorboardX: + metric.write(self.writer, self.epoch, prefix="train") + metric.clear() + + if not self.scheduler_update_every_step: + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step(average_loss) + else: + self.lr_scheduler.step() + + self.log(f"==> Finished Epoch {self.epoch}.") + + + # [GUI] just train for 16 steps, without any other overhead that may slow down rendering. + def train_gui(self, train_loader, step=16): + + self.model.train() + + total_loss = torch.tensor([0], dtype=torch.float32, device=self.device) + + loader = iter(train_loader) + + for _ in range(step): + + # mimic an infinite loop dataloader (in case the total dataset is smaller than step) + try: + data = next(loader) + except StopIteration: + loader = iter(train_loader) + data = next(loader) + + # mark untrained grid + if self.global_step == 0: + self.model.mark_untrained_grid(train_loader._data.poses, train_loader._data.intrinsics) + self.error_map = train_loader._data.error_map + + # update grid every 16 steps + if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0: + with torch.cuda.amp.autocast(enabled=self.fp16): + self.model.update_extra_state() + + self.global_step += 1 + + self.optimizer.zero_grad() + + with torch.cuda.amp.autocast(enabled=self.fp16): + preds, truths, loss = self.train_step(data) + + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + + if self.scheduler_update_every_step: + self.lr_scheduler.step() + + total_loss += loss.detach() + + # Different from _Trainer! + if self.global_step in self.opt.upsample_model_steps: + + # shrink + if self.model.cuda_ray: + self.model.shrink_model() + + # adaptive voxel size from aabb_train + n_vox = self.upsample_resolutions.pop(0) ** 3 # n_voxels + aabb = self.model.aabb_train.cpu().numpy() + vox_size = np.cbrt(np.prod(aabb[3:] - aabb[:3]) / n_vox) + reso = ((aabb[3:] - aabb[:3]) / vox_size).astype(np.int32).tolist() + + self.log(f"[INFO] upsample model at step {self.global_step} from {self.model.resolution} to {reso}") + self.model.upsample_model(reso) + + # reset optimizer since params changed. + self.optimizer = self.optimizer_fn(self.model) + self.lr_scheduler = self.lr_scheduler_fn(self.optimizer) + + if self.ema is not None: + self.ema.update() + + average_loss = total_loss.item() / step + + if not self.scheduler_update_every_step: + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step(average_loss) + else: + self.lr_scheduler.step() + + outputs = { + 'loss': average_loss, + 'lr': self.optimizer.param_groups[0]['lr'], + } + + return outputs + + + def save_checkpoint(self, name=None, full=False, best=False, remove_old=True): + + if name is None: + name = f'{self.name}_ep{self.epoch:04d}.pth' + + state = { + 'epoch': self.epoch, + 'global_step': self.global_step, + 'stats': self.stats, + 'resolution': self.model.resolution, # Different from _Trainer! + } + + # special case for CCNeRF... + if isinstance(self.model, CCNeRF): + state['rank_vec_density'] = self.model.rank_vec_density[0] + state['rank_mat_density'] = self.model.rank_mat_density[0] + state['rank_vec'] = self.model.rank_vec[0] + state['rank_mat'] = self.model.rank_mat[0] + + if self.model.cuda_ray: + state['mean_count'] = self.model.mean_count + state['mean_density'] = self.model.mean_density + + if full: + state['optimizer'] = self.optimizer.state_dict() + state['lr_scheduler'] = self.lr_scheduler.state_dict() + state['scaler'] = self.scaler.state_dict() + if self.ema is not None: + state['ema'] = self.ema.state_dict() + + if not best: + + state['model'] = self.model.state_dict() + + file_path = f"{self.ckpt_path}/{name}.pth" + + if remove_old: + self.stats["checkpoints"].append(file_path) + + if len(self.stats["checkpoints"]) > self.max_keep_ckpt: + old_ckpt = self.stats["checkpoints"].pop(0) + if os.path.exists(old_ckpt): + os.remove(old_ckpt) + + torch.save(state, file_path) + + else: + if len(self.stats["results"]) > 0: + if self.stats["best_result"] is None or self.stats["results"][-1] < self.stats["best_result"]: + self.log(f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}") + self.stats["best_result"] = self.stats["results"][-1] + + # save ema results + if self.ema is not None: + self.ema.store() + self.ema.copy_to() + + state['model'] = self.model.state_dict() + + if self.ema is not None: + self.ema.restore() + + torch.save(state, self.best_path) + else: + self.log(f"[WARN] no evaluated results found, skip saving best checkpoint.") + + def load_checkpoint(self, checkpoint=None, model_only=False): + if checkpoint is None: + checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/{self.name}_ep*.pth')) + if checkpoint_list: + checkpoint = checkpoint_list[-1] + self.log(f"[INFO] Latest checkpoint is {checkpoint}") + else: + self.log("[WARN] No checkpoint found, model randomly initialized.") + return + + checkpoint_dict = torch.load(checkpoint, map_location=self.device) + + # if 'model' not in checkpoint_dict: + # # reset resolution + # self.model.upsample_model() # TODO: need to calclate resolution from param size... + # self.optimizer = self.optimizer_fn(self.model) + # self.lr_scheduler = self.lr_scheduler_fn(self.optimizer) + + # self.model.load_state_dict(checkpoint_dict) + # self.log("[INFO] loaded model.") + # return + + # special case for CCNeRF: model structure should be identical to ckpt... + if isinstance(self.model, CCNeRF): + + # print(checkpoint_dict['rank_vec_density'], checkpoint_dict['rank_mat_density'], checkpoint_dict['rank_vec'], checkpoint_dict['rank_mat']) + + # very ugly... + self.model = CCNeRF( + rank_vec_density=checkpoint_dict['rank_vec_density'], + rank_mat_density=checkpoint_dict['rank_mat_density'], + rank_vec=checkpoint_dict['rank_vec'], + rank_mat=checkpoint_dict['rank_mat'], + resolution=checkpoint_dict['resolution'], + bound=self.opt.bound, + cuda_ray=self.opt.cuda_ray, + density_scale=1, + min_near=self.opt.min_near, + density_thresh=self.opt.density_thresh, + bg_radius=self.opt.bg_radius, + ).to(self.device) + + self.log(f"[INFO] ===== re-initialize CCNeRF =====") + self.log(self.model) + + else: + self.model.upsample_model(checkpoint_dict['resolution']) + + if self.optimizer_fn is not None: + self.optimizer = self.optimizer_fn(self.model) + if self.lr_scheduler_fn is not None: + self.lr_scheduler = self.lr_scheduler_fn(self.optimizer) + + missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint_dict['model'], strict=False) + + self.log("[INFO] loaded model.") + if len(missing_keys) > 0: + self.log(f"[WARN] missing keys: {missing_keys}") + if len(unexpected_keys) > 0: + self.log(f"[WARN] unexpected keys: {unexpected_keys}") + + if self.ema is not None and 'ema' in checkpoint_dict: + self.ema.load_state_dict(checkpoint_dict['ema']) + + if self.model.cuda_ray: + if 'mean_count' in checkpoint_dict: + self.model.mean_count = checkpoint_dict['mean_count'] + if 'mean_density' in checkpoint_dict: + self.model.mean_density = checkpoint_dict['mean_density'] + + if model_only: + return + + self.stats = checkpoint_dict['stats'] + self.epoch = checkpoint_dict['epoch'] + self.global_step = checkpoint_dict['global_step'] + self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}") + + if self.optimizer and 'optimizer' in checkpoint_dict: + try: + self.optimizer.load_state_dict(checkpoint_dict['optimizer']) + self.log("[INFO] loaded optimizer.") + except: + self.log("[WARN] Failed to load optimizer.") + + if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict: + try: + self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler']) + self.log("[INFO] loaded scheduler.") + except: + self.log("[WARN] Failed to load scheduler.") + + if self.scaler and 'scaler' in checkpoint_dict: + try: + self.scaler.load_state_dict(checkpoint_dict['scaler']) + self.log("[INFO] loaded scaler.") + except: + self.log("[WARN] Failed to load scaler.") \ No newline at end of file diff --git a/torch-ngp/testing/test_ffmlp.py b/torch-ngp/testing/test_ffmlp.py new file mode 100644 index 0000000000000000000000000000000000000000..8b8add230840161571b0d88a0a8623bfa74391f6 --- /dev/null +++ b/torch-ngp/testing/test_ffmlp.py @@ -0,0 +1,238 @@ +from matplotlib.animation import AVConvBase +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ffmlp import FFMLP +import math + +import tinycudann as tcnn + +class MLP(nn.Module): + def __init__(self, input_dim, output_dim, hidden_dim, num_layers, activation=F.relu): + super().__init__() + + self.num_layers = num_layers + self.hidden_dim = hidden_dim + self.activation = activation + + self.net = nn.ModuleList() + self.net.append(nn.Linear(input_dim, hidden_dim, bias=False)) + for i in range(num_layers - 1): + self.net.append(nn.Linear(hidden_dim, hidden_dim, bias=False)) + self.net.append(nn.Linear(hidden_dim, output_dim, bias=False)) + + self.reset_parameters() + + def reset_parameters(self): + torch.manual_seed(42) + for p in self.parameters(): + #nn.init.constant_(p.data, 1) + std = math.sqrt(3 / self.hidden_dim) + p.data.uniform_(-std, std) + #torch.manual_seed(42) + #nn.init.uniform_(p.data, 0, 1) + #nn.init.eye_(p.data) + + + def forward(self, x): + for i in range(self.num_layers + 1): + x = self.net[i](x) + if i != self.num_layers: + x = self.activation(x) + return x + +# ################################## +# # Functionality +# ################################## + +# BATCH_SIZE = 1280000 # 1048576 # 128 # the least batch to lauch a full block ! +# INPUT_DIM = 16 # 16 # != (16 * m) has bug... +# OUTPUT_DIM = 1 # 16 # > 16 still has bug... +# HIDDEN_DIM = 64 # 16 +# NUM_LAYERS = 3 # 2 + + +# net0 = FFMLP(INPUT_DIM, OUTPUT_DIM, HIDDEN_DIM, NUM_LAYERS).cuda() +# net1 = MLP(INPUT_DIM, OUTPUT_DIM, HIDDEN_DIM, NUM_LAYERS).cuda() + +# # print(net0.weights) +# # print(net1.net[0].weight) + +# for _ in range(5): + +# x0 = torch.randn(BATCH_SIZE, INPUT_DIM).cuda() * 1 +# x1 = x0.detach().clone() +# x0.requires_grad_(True) +# x1.requires_grad_(True) + +# # print('===== x =====') +# # print(x0) +# # print(x1) + +# with torch.cuda.amp.autocast(enabled=True): +# y1 = net1(x1) +# y0 = net0(x0) + + +# print('===== y1 =====') +# print(y1) + +# print('===== y0 =====') +# print(y0) + +# (y1.sum() * 1).backward() +# print('===== grad w1 =====') +# print(net1.net[0].weight.grad.dtype, torch.cat([net1.net[0].weight.grad.view(-1), net1.net[1].weight.grad.view(-1), net1.net[2].weight.grad.view(-1)], dim=0)) +# print(x1.grad.dtype, x1.grad) + +# (y0.sum() * 1).backward() +# print('===== grad w0 =====') +# print(net0.weights.grad.dtype, net0.weights.grad) +# print(x0.grad.dtype, x0.grad) + + + +# ################################## +# # Speed +# ################################## + +BATCH_SIZE = 2**21 +INPUT_DIM = 16 +OUTPUT_DIM = 16 +HIDDEN_DIM = 64 +NUM_LAYERS = 2 + +net0 = FFMLP(INPUT_DIM, OUTPUT_DIM, HIDDEN_DIM, NUM_LAYERS).cuda() +net1 = MLP(INPUT_DIM, OUTPUT_DIM, HIDDEN_DIM, NUM_LAYERS).cuda() +net2 = tcnn.Network(n_input_dims=INPUT_DIM, n_output_dims=OUTPUT_DIM, network_config={ + "otype": "FullyFusedMLP", + "activation": "ReLU", + "output_activation": "None", + "n_neurons": HIDDEN_DIM, + "n_hidden_layers": NUM_LAYERS, + }) + +x = torch.rand(BATCH_SIZE, INPUT_DIM).cuda() * 10 +x1 = x.detach().clone() +x2 = x.detach().clone() +x3 = x.detach().clone() + + + +#with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]) as p: + +starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) +starter.record() +y2 = net1(x2) +ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'pytorch MLP (fp32 train) = {curr_time}') + +starter.record() +y2.sum().backward() +ender.record() +torch.cuda.synchronize() +curr_time = starter.elapsed_time(ender) +print(f'pytorch MLP (fp32 back) = {curr_time}') + +#print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) + +with torch.cuda.amp.autocast(enabled=True): + + #with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]) as p: + starter.record() + y0 = net0(x) + ender.record() + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + print(f'FFMLP (forward) = {curr_time}') + + starter.record() + y0.sum().backward() + ender.record() + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + print(f'FFMLP (backward) = {curr_time}') + + #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) + + #with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]) as p: + starter.record() + y1 = net1(x1) + ender.record() + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + print(f'pytorch MLP (forward) = {curr_time}') + + starter.record() + y1.sum().backward() + ender.record() + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + print(f'pytorch MLP (backward) = {curr_time}') + #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) + + #with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]) as p: + starter.record() + y3 = net2(x3) + ender.record() + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + print(f'TCNN (forward) = {curr_time}') + + starter.record() + y3.sum().backward() + ender.record() + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + print(f'TCNN (backward) = {curr_time}') + #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) + +with torch.no_grad(): + + starter.record() + y1 = net1(x) + ender.record() + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + print(f'pytorch MLP (fp32 infer) = {curr_time}') + + with torch.cuda.amp.autocast(enabled=True): + + + #with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]) as p: + + starter.record() + y0 = net0(x) + ender.record() + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + print(f'FFMLP (infer) = {curr_time}') + + #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) + + #with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]) as p: + + starter.record() + y1 = net1(x) + ender.record() + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + print(f'pytorch MLP (infer) = {curr_time}') + + #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) + + #with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]) as p: + + starter.record() + y2 = net2(x) + ender.record() + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + print(f'TCNN (infer) = {curr_time}') + + #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) + + +# print(y0) +# print(y1) + \ No newline at end of file diff --git a/torch-ngp/testing/test_hashencoder.py b/torch-ngp/testing/test_hashencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..81c9d5381d787ea70e6badae19b180c7bb16c0d5 --- /dev/null +++ b/torch-ngp/testing/test_hashencoder.py @@ -0,0 +1,44 @@ + +import numpy as np +import torch +from gridencoder import GridEncoder + +B = 1 +D = 2 + +enc = GridEncoder(D=D, L=2, C=1, base_resolution=4, log2_hashmap_size=5).cuda() +#enc = GridEncoder(D=D, L=16, C=2, base_resolution=16).cuda() + +print(f"=== enc ===") +print(enc.embeddings.shape) +print(enc.embeddings) + +#x = torch.rand(B, D).cuda() * 2 - 1 # in [-1, 1] +x = torch.FloatTensor(np.array([ + #[-1, 1], + #[1, 1], + [0, 0], + #[-1, -1], + #[1, -1], +])).cuda() + +#x.requires_grad_(True) + +print(f"=== x ===") +print(x) +print(x.shape) + +y = enc(x, calc_grad_inputs=False) + +print(f"=== y ===") +print(y.shape) +print(y) + +y.sum().backward() + +print(f"=== grad enc ===") +print(enc.embeddings.grad.shape) +print(enc.embeddings.grad) + +#print(x.grad.shape) +#print(x.grad) \ No newline at end of file diff --git a/torch-ngp/testing/test_hashgrid_grad.py b/torch-ngp/testing/test_hashgrid_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..358af5b77a27bcd72b99d39ac1c80f5ac4f49d37 --- /dev/null +++ b/torch-ngp/testing/test_hashgrid_grad.py @@ -0,0 +1,62 @@ +# we need check the grad_hash_grid; +import torch +import torch.nn.functional as F +from torch.autograd import gradcheck +import numpy as np +from gridencoder.grid import _grid_encode +import random +import os +# import torch.random as random +device=torch.device(0) +input_dim=3 # 2 +num_levels=4 # 1 +level_dim=2 # 1 +per_level_scale=2 +base_resolution=4 # 2 +log2_hashmap_size=8 # 4 +# inputs , embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False + +output_dim = num_levels * level_dim + +if level_dim % 2 != 0: + print('[WARN] detected HashGrid level_dim % 2 != 0, which will cause very slow backward is also enabled fp16! (maybe fix later)') + +# allocate parameters +offsets = [] +offset = 0 +max_params = 2 ** log2_hashmap_size +for i in range(num_levels): + resolution = int(np.ceil(base_resolution * per_level_scale ** i)) + params_in_level = min(max_params, (resolution + 1) ** input_dim) # limit max number + #params_in_level = np.ceil(params_in_level / 8) * 8 # make divisible + offsets.append(offset) + offset += params_in_level +offsets.append(offset) + +print(offsets) + +def seed_torch(seed=42): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + +#seed_torch() + +# parameters +inputs = torch.rand(1, input_dim, dtype= torch.float64, requires_grad=False).to(device) + +offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)).to(device) +embeddings = torch.randn(offset, level_dim, dtype=torch.float64, requires_grad=True).to(device) * 0.1 + +print(inputs) +print(embeddings) + + +Inputs = (inputs, embeddings, offsets, per_level_scale, base_resolution, inputs.requires_grad) +check_results1 = torch.autograd.gradcheck(_grid_encode.apply, Inputs, eps=1e-2, atol=1e-3, rtol=0.01, fast_mode=False) +print("check_results1", check_results1) diff --git a/torch-ngp/testing/test_raymarching.py b/torch-ngp/testing/test_raymarching.py new file mode 100644 index 0000000000000000000000000000000000000000..ec56d0695e597055bbb3fbf4b3ef6750d5901cb9 --- /dev/null +++ b/torch-ngp/testing/test_raymarching.py @@ -0,0 +1 @@ +import raymarching \ No newline at end of file diff --git a/torch-ngp/testing/test_shencoder.py b/torch-ngp/testing/test_shencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..fe0d9b9fb88d2a15627b44a4f7b31e7adf1cf555 --- /dev/null +++ b/torch-ngp/testing/test_shencoder.py @@ -0,0 +1,147 @@ +import time +import numpy as np +import torch +import torch.nn as nn +from shencoder import SHEncoder + + +class SHEncoder_torch(nn.Module): + def __init__(self, input_dim=3, degree=4): + + super().__init__() + + self.input_dim = input_dim + self.degree = degree + + assert self.input_dim == 3 + assert self.degree >= 1 and self.degree <= 5 + + self.output_dim = degree ** 2 + + self.C0 = 0.28209479177387814 + self.C1 = 0.4886025119029199 + self.C2 = [ + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396 + ] + self.C3 = [ + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435 + ] + self.C4 = [ + 2.5033429417967046, + -1.7701307697799304, + 0.9461746957575601, + -0.6690465435572892, + 0.10578554691520431, + -0.6690465435572892, + 0.47308734787878004, + -1.7701307697799304, + 0.6258357354491761 + ] + + def forward(self, input, **kwargs): + + result = torch.empty((*input.shape[:-1], self.output_dim), dtype=input.dtype, device=input.device) + x, y, z = input.unbind(-1) + + result[..., 0] = self.C0 + if self.degree > 1: + result[..., 1] = -self.C1 * y + result[..., 2] = self.C1 * z + result[..., 3] = -self.C1 * x + if self.degree > 2: + xx, yy, zz = x * x, y * y, z * z + xy, yz, xz = x * y, y * z, x * z + result[..., 4] = self.C2[0] * xy + result[..., 5] = self.C2[1] * yz + #result[..., 6] = self.C2[2] * (2.0 * zz - xx - yy) + result[..., 6] = self.C2[2] * (3.0 * zz - 1) # xx + yy + zz == 1, but this will lead to different backward gradients, interesting... + result[..., 7] = self.C2[3] * xz + result[..., 8] = self.C2[4] * (xx - yy) + if self.degree > 3: + result[..., 9] = self.C3[0] * y * (3 * xx - yy) + result[..., 10] = self.C3[1] * xy * z + result[..., 11] = self.C3[2] * y * (4 * zz - xx - yy) + result[..., 12] = self.C3[3] * z * (2 * zz - 3 * xx - 3 * yy) + result[..., 13] = self.C3[4] * x * (4 * zz - xx - yy) + result[..., 14] = self.C3[5] * z * (xx - yy) + result[..., 15] = self.C3[6] * x * (xx - 3 * yy) + if self.degree > 4: + result[..., 16] = self.C4[0] * xy * (xx - yy) + result[..., 17] = self.C4[1] * yz * (3 * xx - yy) + result[..., 18] = self.C4[2] * xy * (7 * zz - 1) + result[..., 19] = self.C4[3] * yz * (7 * zz - 3) + result[..., 20] = self.C4[4] * (zz * (35 * zz - 30) + 3) + result[..., 21] = self.C4[5] * xz * (7 * zz - 3) + result[..., 22] = self.C4[6] * (xx - yy) * (7 * zz - 1) + result[..., 23] = self.C4[7] * xz * (xx - 3 * yy) + result[..., 24] = self.C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) + + return result + +B = 25600 +C = 3 +degree = 4 + +enc1 = SHEncoder_torch(degree=degree).cuda() +enc2 = SHEncoder(degree=degree).cuda() + +x1 = torch.rand(B, 3).cuda() * 2 - 1 # in [-1, 1] +x1 = x1 / (torch.norm(x1, dim=-1, keepdim=True) + 1e-8) +x1.requires_grad_(True) + +x2 = x1.detach().clone() +x2.requires_grad_(True) + +print(f"=== x ===") +print(x1) + +starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + +with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=True): + + starter.record() + y1 = enc1(x1) + ender.record() + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + print(f'time 1 = {curr_time}') + + starter.record() + y2 = enc2(x2) + ender.record() + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + print(f'time 2 = {curr_time}') + + print(f"=== y ===") + print(y1) + print(y2) + + # starter.record() + # y1.sum().backward() + # ender.record() + # torch.cuda.synchronize() + # curr_time = starter.elapsed_time(ender) + # print(f'time 1 (back) = {curr_time}') + + # starter.record() + # y2.sum().backward() + # ender.record() + # torch.cuda.synchronize() + # curr_time = starter.elapsed_time(ender) + # print(f'time 2 (back) = {curr_time}') + + # print(f"=== grad x ===") + # print(x1.grad) + # print(x2.grad) \ No newline at end of file diff --git a/train_face.py b/train_face.py new file mode 100644 index 0000000000000000000000000000000000000000..eb1e39c71db38c3291d130a31815dd23e68906a6 --- /dev/null +++ b/train_face.py @@ -0,0 +1,397 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import random +import torch +from random import randint +from utils.loss_utils import l1_loss, l2_loss, patchify, ssim +from gaussian_renderer import render, render_motion +import sys +from scene import Scene, GaussianModel, MotionNetwork +from utils.general_utils import safe_state +import lpips +import uuid +from tqdm import tqdm +from utils.image_utils import psnr +from argparse import ArgumentParser, Namespace +from arguments import ModelParams, PipelineParams, OptimizationParams +try: + from torch.utils.tensorboard import SummaryWriter + TENSORBOARD_FOUND = True +except ImportError: + TENSORBOARD_FOUND = False + +def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): + testing_iterations = [i for i in range(0, opt.iterations + 1, 2000)] + checkpoint_iterations = saving_iterations = [i for i in range(0, opt.iterations + 1, 10000)] + [opt.iterations] + + # vars + warm_step = 3000 + opt.densify_until_iter = opt.iterations - 1000 + bg_iter = opt.iterations # opt.densify_until_iter + lpips_start_iter = opt.densify_until_iter - 2000 + motion_stop_iter = bg_iter + mouth_select_iter = bg_iter - 10000 + mouth_step = 1 / mouth_select_iter + hair_mask_interval = 7 + select_interval = 15 + + first_iter = 0 + tb_writer = prepare_output_and_logger(dataset) + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians) + + motion_net = MotionNetwork(args=dataset).cuda() + motion_optimizer = torch.optim.AdamW(motion_net.get_params(5e-3, 5e-4), betas=(0.9, 0.99), eps=1e-8) + scheduler = torch.optim.lr_scheduler.LambdaLR(motion_optimizer, lambda iter: (0.5 ** (iter / mouth_select_iter)) if iter < mouth_select_iter else 0.1 ** (iter / bg_iter)) + + lpips_criterion = lpips.LPIPS(net='alex').eval().cuda() + + gaussians.training_setup(opt) + if checkpoint: + (model_params, motion_params, motion_optimizer_params, first_iter) = torch.load(checkpoint) + gaussians.restore(model_params, opt) + motion_net.load_state_dict(motion_params) + motion_optimizer.load_state_dict(motion_optimizer_params) + + bg_color = [0, 1, 0] # [1, 1, 1] # if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + + iter_start = torch.cuda.Event(enable_timing = True) + iter_end = torch.cuda.Event(enable_timing = True) + + viewpoint_stack = None + ema_loss_for_log = 0.0 + progress_bar = tqdm(range(first_iter, opt.iterations), ascii=True, dynamic_ncols=True, desc="Training progress") + first_iter += 1 + for iteration in range(first_iter, opt.iterations + 1): + + iter_start.record() + + gaussians.update_learning_rate(iteration) + + # Every 1000 its we increase the levels of SH up to a maximum degree + if iteration % 1000 == 0: + gaussians.oneupSHdegree() + + # Pick a random Camera + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras().copy() + viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) + + # find a big mouth + mouth_global_lb = viewpoint_cam.talking_dict['mouth_bound'][0] + mouth_global_ub = viewpoint_cam.talking_dict['mouth_bound'][1] + mouth_global_lb += (mouth_global_ub - mouth_global_lb) * 0.2 + mouth_window = (mouth_global_ub - mouth_global_lb) * 0.2 + + mouth_lb = mouth_global_lb + mouth_step * iteration * (mouth_global_ub - mouth_global_lb) + mouth_ub = mouth_lb + mouth_window + mouth_lb = mouth_lb - mouth_window + + + au_global_lb = 0 + au_global_ub = 1 + au_window = 0.3 + + au_lb = au_global_lb + mouth_step * iteration * (au_global_ub - au_global_lb) + au_ub = au_lb + au_window + au_lb = au_lb - au_window * 0.5 + + + if iteration < warm_step: + if iteration % select_interval == 0: + while viewpoint_cam.talking_dict['mouth_bound'][2] < mouth_lb or viewpoint_cam.talking_dict['mouth_bound'][2] > mouth_ub: + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras().copy() + viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) + + + if warm_step < iteration < mouth_select_iter: + + if iteration % select_interval == 0: + while viewpoint_cam.talking_dict['blink'] < au_lb or viewpoint_cam.talking_dict['blink'] > au_ub: + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras().copy() + viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) + + + + # Render + if (iteration - 1) == debug_from: + pipe.debug = True + + face_mask = torch.as_tensor(viewpoint_cam.talking_dict["face_mask"]).cuda() + hair_mask = torch.as_tensor(viewpoint_cam.talking_dict["hair_mask"]).cuda() + mouth_mask = torch.as_tensor(viewpoint_cam.talking_dict["mouth_mask"]).cuda() + head_mask = face_mask + hair_mask + + if iteration > lpips_start_iter: + max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1) + mouth_mask = (-max_pool(-max_pool(mouth_mask[None].float())))[0].bool() + + + hair_mask_iter = (warm_step < iteration < lpips_start_iter - 1000) and iteration % hair_mask_interval != 0 + + if iteration < warm_step: + render_pkg = render(viewpoint_cam, gaussians, pipe, background) + else: + render_pkg = render_motion(viewpoint_cam, gaussians, motion_net, pipe, background, return_attn=True) + + image_white, alpha, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["alpha"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] + + gt_image = viewpoint_cam.original_image.cuda() / 255.0 + gt_image_white = gt_image * head_mask + background[:, None, None] * ~head_mask + + if iteration > motion_stop_iter: + for param in motion_net.parameters(): + param.requires_grad = False + if iteration > bg_iter: + gaussians._xyz.requires_grad = False + gaussians._opacity.requires_grad = False + # gaussians._features_dc.requires_grad = False + # gaussians._features_rest.requires_grad = False + gaussians._scaling.requires_grad = False + gaussians._rotation.requires_grad = False + + # Loss + if iteration < bg_iter: + if hair_mask_iter: + image_white[:, hair_mask] = background[:, None] + gt_image_white[:, hair_mask] = background[:, None] + + # image_white[:, mouth_mask] = 1 + gt_image_white[:, mouth_mask] = background[:, None] + + Ll1 = l1_loss(image_white, gt_image_white) + loss = Ll1 + opt.lambda_dssim * (1.0 - ssim(image_white, gt_image_white)) + + # mouth_alpha_loss = 1e-2 * (alpha[:,mouth_mask]).mean() + # if not torch.isnan(mouth_alpha_loss): + # loss += mouth_alpha_loss + # print(alpha[:,mouth_mask], mouth_mask.sum()) + + if iteration > warm_step: + loss += 1e-5 * (render_pkg['motion']['d_xyz'].abs()).mean() + loss += 1e-5 * (render_pkg['motion']['d_rot'].abs()).mean() + loss += 1e-5 * (render_pkg['motion']['d_opa'].abs()).mean() + loss += 1e-5 * (render_pkg['motion']['d_scale'].abs()).mean() + + loss += 1e-3 * (((1-alpha) * head_mask).mean() + (alpha * ~head_mask).mean()) + + + [xmin, xmax, ymin, ymax] = viewpoint_cam.talking_dict['lips_rect'] + loss += 1e-4 * (render_pkg["attn"][1, xmin:xmax, ymin:ymax]).mean() + if not hair_mask_iter: + loss += 1e-4 * (render_pkg["attn"][1][hair_mask]).mean() + loss += 1e-4 * (render_pkg["attn"][0][hair_mask]).mean() + + # loss += l2_loss(image_white[:, xmin:xmax, ymin:ymax], image_white[:, xmin:xmax, ymin:ymax]) + + image_t = image_white.clone() + gt_image_t = gt_image_white.clone() + + else: + # with real bg + image = image_white - background[:, None, None] * (1.0 - alpha) + viewpoint_cam.background.cuda() / 255.0 * (1.0 - alpha) + + Ll1 = l1_loss(image, gt_image) + loss = Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) + + image_t = image.clone() + gt_image_t = gt_image.clone() + + if iteration > lpips_start_iter: + # mask mouth + [xmin, xmax, ymin, ymax] = viewpoint_cam.talking_dict['lips_rect'] + loss += 0.01 * lpips_criterion(image_t.clone()[:, xmin:xmax, ymin:ymax] * 2 - 1, gt_image_t.clone()[:, xmin:xmax, ymin:ymax] * 2 - 1).mean() + + image_t[:, xmin:xmax, ymin:ymax] = background[:, None, None] + gt_image_t[:, xmin:xmax, ymin:ymax] = background[:, None, None] + + patch_size = random.randint(32, 48) * 2 + loss += 0.2 * lpips_criterion(patchify(image_t[None, ...] * 2 - 1, patch_size), patchify(gt_image_t[None, ...] * 2 - 1, patch_size)).mean() + # loss += 0.5 * lpips_criterion(image_t[None, ...] * 2 - 1, gt_image_t[None, ...] * 2 - 1).mean() + + + loss.backward() + + iter_end.record() + + with torch.no_grad(): + # Progress bar + ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log + if iteration % 10 == 0: + progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{5}f}", "Mouth": f"{mouth_lb:.{1}f}-{mouth_ub:.{1}f}"}) # , "AU25": f"{au_lb:.{1}f}-{au_ub:.{1}f}" + progress_bar.update(10) + if iteration == opt.iterations: + progress_bar.close() + + # Log and save + training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, motion_net, render if iteration < warm_step else render_motion, (pipe, background)) + if (iteration in saving_iterations): + print("\n[ITER {}] Saving Gaussians".format(iteration)) + scene.save(str(iteration)+'_face') + + if (iteration in checkpoint_iterations): + print("\n[ITER {}] Saving Checkpoint".format(iteration)) + ckpt = (gaussians.capture(), motion_net.state_dict(), motion_optimizer.state_dict(), iteration) + torch.save(ckpt, scene.model_path + "/chkpnt_face_" + str(iteration) + ".pth") + torch.save(ckpt, scene.model_path + "/chkpnt_face_latest" + ".pth") + + + # Densification + if iteration < opt.densify_until_iter: + # Keep track of max radii in image-space for pruning + gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) + gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) + + if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: + size_threshold = 20 if iteration > opt.opacity_reset_interval else None + gaussians.densify_and_prune(opt.densify_grad_threshold, 0.05 + 0.25 * iteration / opt.densify_until_iter, scene.cameras_extent, size_threshold) + + + # bg prune + if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: + from utils.sh_utils import eval_sh + + shs_view = gaussians.get_features.transpose(1, 2).view(-1, 3, (gaussians.max_sh_degree+1)**2) + dir_pp = (gaussians.get_xyz - viewpoint_cam.camera_center.repeat(gaussians.get_features.shape[0], 1)) + dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) + sh2rgb = eval_sh(gaussians.active_sh_degree, shs_view, dir_pp_normalized) + colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) + + bg_color_mask = (colors_precomp[..., 0] < 30/255) * (colors_precomp[..., 1] > 225/255) * (colors_precomp[..., 2] < 30/255) + gaussians.prune_points(bg_color_mask.squeeze()) + + + # Optimizer step + if iteration < opt.iterations: + motion_optimizer.step() + gaussians.optimizer.step() + + motion_optimizer.zero_grad() + gaussians.optimizer.zero_grad(set_to_none = True) + + scheduler.step() + + + +def prepare_output_and_logger(args): + if not args.model_path: + if os.getenv('OAR_JOB_ID'): + unique_str=os.getenv('OAR_JOB_ID') + else: + unique_str = str(uuid.uuid4()) + args.model_path = os.path.join("./output/", unique_str[0:10]) + + # Set up output folder + print("Output folder: {}".format(args.model_path)) + os.makedirs(args.model_path, exist_ok = True) + with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: + cfg_log_f.write(str(Namespace(**vars(args)))) + + # Create Tensorboard writer + tb_writer = None + if TENSORBOARD_FOUND: + tb_writer = SummaryWriter(args.model_path) + else: + print("Tensorboard not available: not logging progress") + return tb_writer + +def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, motion_net, renderFunc, renderArgs): + if tb_writer: + tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) + tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) + tb_writer.add_scalar('iter_time', elapsed, iteration) + + # Report test and samples of training set + if iteration in testing_iterations: + torch.cuda.empty_cache() + validation_configs = ({'name': 'test', 'cameras' : [scene.getTestCameras()[idx % len(scene.getTestCameras())] for idx in range(5, 100, 5)]}, + {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) + + for config in validation_configs: + if config['cameras'] and len(config['cameras']) > 0: + l1_test = 0.0 + psnr_test = 0.0 + for idx, viewpoint in enumerate(config['cameras']): + + if renderFunc is render: + render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs) + else: + render_pkg = renderFunc(viewpoint, scene.gaussians, motion_net, return_attn=True, frame_idx=0, *renderArgs) + + image = torch.clamp(render_pkg["render"], 0.0, 1.0) + alpha = render_pkg["alpha"] + image = image - renderArgs[1][:, None, None] * (1.0 - alpha) + viewpoint.background.cuda() / 255.0 * (1.0 - alpha) + gt_image = torch.clamp(viewpoint.original_image.to("cuda") / 255.0, 0.0, 1.0) + + mouth_mask = torch.as_tensor(viewpoint.talking_dict["mouth_mask"]).cuda() + max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1) + mouth_mask_post = (-max_pool(-max_pool(mouth_mask[None].float())))[0].bool() + + if tb_writer and (idx < 5): + tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) + tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) + # tb_writer.add_images(config['name'] + "_view_{}/depth".format(viewpoint.image_name), (render_pkg["depth"] / render_pkg["depth"].max())[None], global_step=iteration) + tb_writer.add_images(config['name'] + "_view_{}/mouth_mask_post".format(viewpoint.image_name), (~mouth_mask_post * gt_image)[None], global_step=iteration) + tb_writer.add_images(config['name'] + "_view_{}/mouth_mask".format(viewpoint.image_name), (~mouth_mask[None] * gt_image)[None], global_step=iteration) + + if renderFunc is not render: + tb_writer.add_images(config['name'] + "_view_{}/attn_a".format(viewpoint.image_name), (render_pkg["attn"][0] / render_pkg["attn"][0].max())[None, None], global_step=iteration) + tb_writer.add_images(config['name'] + "_view_{}/attn_e".format(viewpoint.image_name), (render_pkg["attn"][1] / render_pkg["attn"][1].max())[None, None], global_step=iteration) + + l1_test += l1_loss(image, gt_image).mean().double() + psnr_test += psnr(image, gt_image).mean().double() + psnr_test /= len(config['cameras']) + l1_test /= len(config['cameras']) + print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) + if tb_writer: + tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) + tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) + + if tb_writer: + tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) + tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) + torch.cuda.empty_cache() + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + lp = ModelParams(parser) + op = OptimizationParams(parser) + pp = PipelineParams(parser) + parser.add_argument('--ip', type=str, default="127.0.0.1") + parser.add_argument('--port', type=int, default=6009) + parser.add_argument('--debug_from', type=int, default=-1) + parser.add_argument('--detect_anomaly', action='store_true', default=False) + parser.add_argument("--test_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--save_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--quiet", action="store_true") + parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--start_checkpoint", type=str, default = None) + args = parser.parse_args(sys.argv[1:]) + args.save_iterations.append(args.iterations) + + print("Optimizing " + args.model_path) + + # Initialize system state (RNG) + safe_state(args.quiet) + + # Start GUI server, configure and run training + torch.autograd.set_detect_anomaly(args.detect_anomaly) + training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) + + # All done + print("\nTraining complete.") diff --git a/train_fuse.py b/train_fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..c7642c70363df307f9cce590a25806a6c8ddf694 --- /dev/null +++ b/train_fuse.py @@ -0,0 +1,264 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import random +import torch +from random import randint +from utils.loss_utils import l1_loss, l2_loss, patchify, ssim +from gaussian_renderer import render, render_motion, render_motion_mouth +import sys +from scene import Scene, GaussianModel, MotionNetwork, MouthMotionNetwork +from utils.general_utils import safe_state +import lpips +import uuid +from tqdm import tqdm +from utils.image_utils import psnr +from argparse import ArgumentParser, Namespace +from arguments import ModelParams, PipelineParams, OptimizationParams +try: + from torch.utils.tensorboard import SummaryWriter + TENSORBOARD_FOUND = True +except ImportError: + TENSORBOARD_FOUND = False + +def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): + opt.iterations = 10000 + opt.densify_until_iter = 0 + + testing_iterations = [i for i in range(0, opt.iterations + 1, 2000)] + checkpoint_iterations = [opt.iterations] + + # vars + bg_iter = opt.densify_until_iter + lpips_start_iter = 5000 + + first_iter = 0 + tb_writer = prepare_output_and_logger(dataset) + + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians) + gaussians_mouth = GaussianModel(dataset.sh_degree) + with torch.no_grad(): + motion_net_mouth = MouthMotionNetwork(args=dataset).cuda() + motion_net = MotionNetwork(args=dataset).cuda() + + gaussians.training_setup(opt) + gaussians_mouth.training_setup(opt) + + (model_params, motion_params, _, _) = torch.load(os.path.join(scene.model_path, "chkpnt_face_latest.pth")) + gaussians.restore(model_params, opt) + motion_net.load_state_dict(motion_params) + + (model_params, motion_params, _, _) = torch.load(os.path.join(scene.model_path, "chkpnt_mouth_latest.pth")) + gaussians_mouth.restore(model_params, opt) + motion_net_mouth.load_state_dict(motion_params) + + lpips_criterion = lpips.LPIPS(net='alex').eval().cuda() + + bg_color = [0, 1, 0] # [1, 1, 1] # if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + + iter_start = torch.cuda.Event(enable_timing = True) + iter_end = torch.cuda.Event(enable_timing = True) + + viewpoint_stack = None + ema_loss_for_log = 0.0 + progress_bar = tqdm(range(first_iter, opt.iterations), ascii=True, dynamic_ncols=True, desc="Training progress") + first_iter += 1 + + for iteration in range(first_iter, opt.iterations + 1): + + iter_start.record() + + # Pick a random Camera + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras().copy() + viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) + + gaussians.update_learning_rate(iteration) + + face_mask = torch.as_tensor(viewpoint_cam.talking_dict["face_mask"]).cuda() + hair_mask = torch.as_tensor(viewpoint_cam.talking_dict["hair_mask"]).cuda() + mouth_mask = torch.as_tensor(viewpoint_cam.talking_dict["mouth_mask"]).cuda() + head_mask = face_mask + hair_mask + mouth_mask + + # Render + if (iteration - 1) == debug_from: + pipe.debug = True + + render_pkg = render_motion(viewpoint_cam, gaussians, motion_net, pipe, background) + render_pkg_mouth = render_motion_mouth(viewpoint_cam, gaussians_mouth, motion_net_mouth, pipe, background) + viewspace_point_tensor, visibility_filter = render_pkg["viewspace_points"], render_pkg["visibility_filter"] + viewspace_point_tensor_mouth, visibility_filter_mouth = render_pkg_mouth["viewspace_points"], render_pkg_mouth["visibility_filter"] + + + alpha_mouth = render_pkg_mouth["alpha"] + alpha = render_pkg["alpha"] + mouth_image = render_pkg_mouth["render"] - background[:, None, None] * (1.0 - alpha_mouth) + viewpoint_cam.background.cuda() / 255.0 * (1.0 - alpha_mouth) + image = render_pkg["render"] - background[:, None, None] * (1.0 - alpha) + mouth_image * (1.0 - alpha) + + gt_image = viewpoint_cam.original_image.cuda() / 255.0 + gt_image_white = gt_image * head_mask + background[:, None, None] * ~head_mask + + if iteration > bg_iter: + for param in motion_net.parameters(): + param.requires_grad = False + for param in motion_net_mouth.parameters(): + param.requires_grad = False + + gaussians._xyz.requires_grad = False + # gaussians._opacity.requires_grad = False + gaussians._scaling.requires_grad = False + gaussians._rotation.requires_grad = False + + gaussians_mouth._xyz.requires_grad = False + gaussians_mouth._opacity.requires_grad = False + gaussians_mouth._scaling.requires_grad = False + gaussians_mouth._rotation.requires_grad = False + + + # Loss + if iteration < bg_iter: + image[:, ~head_mask] = background[:, None] + # gt_image_white[:, ~head_mask] = background[:, None] + + Ll1 = l1_loss(image, gt_image_white) + loss = Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image_white)) + loss += 1e-3 * (((1-alpha) * head_mask).mean() + (alpha * ~head_mask).mean()) + + image_t = image.clone() + gt_image_t = gt_image_white.clone() + + else: + Ll1 = l1_loss(image, gt_image) + loss = Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) + + image_t = image.clone() + gt_image_t = gt_image.clone() + + if iteration > lpips_start_iter: + # mask mouth + # [xmin, xmax, ymin, ymax] = viewpoint_cam.talking_dict['lips_rect'] + # image_t[:, xmin:xmax, ymin:ymax] = 1 + # gt_image_t[:, xmin:xmax, ymin:ymax] = 1 + + patch_size = random.randint(16, 21) * 2 + loss += 0.5 * lpips_criterion(patchify(image_t[None, ...] * 2 - 1, patch_size), patchify(gt_image_t[None, ...] * 2 - 1, patch_size)).mean() + + + + loss.backward() + + iter_end.record() + + with torch.no_grad(): + # Progress bar + ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log + if iteration % 10 == 0: + progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{5}f}"}) + progress_bar.update(10) + if iteration == opt.iterations: + progress_bar.close() + + # Log and save + training_report(tb_writer, iteration, testing_iterations, image, gt_image) + if (iteration in saving_iterations): + print("\n[ITER {}] Saving Gaussians".format(iteration)) + scene.save(iteration) + + if (iteration in checkpoint_iterations): + print("\n[ITER {}] Saving Checkpoint".format(iteration)) + ckpt = (gaussians.capture(), motion_net.state_dict(), gaussians_mouth.capture(), motion_net_mouth.state_dict()) + torch.save(ckpt, scene.model_path + "/chkpnt_fuse_" + str(iteration) + ".pth") + torch.save(ckpt, scene.model_path + "/chkpnt_fuse_latest" + ".pth") + + + # Densification + if iteration < opt.densify_until_iter: + # Keep track of max radii in image-space for pruning + gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) + gaussians_mouth.add_densification_stats(viewspace_point_tensor_mouth, visibility_filter_mouth) + + if iteration % opt.densification_interval == 0: + size_threshold = 20 if iteration > opt.opacity_reset_interval else None + gaussians.densify_and_prune(opt.densify_grad_threshold, 0.3, scene.cameras_extent, size_threshold) + gaussians_mouth.densify_and_prune(opt.densify_grad_threshold, 0.3, scene.cameras_extent, size_threshold) + + + # Optimizer step + if iteration < opt.iterations: + gaussians.optimizer.step() + gaussians_mouth.optimizer.step() + + gaussians.optimizer.zero_grad(set_to_none = True) + gaussians_mouth.optimizer.zero_grad(set_to_none = True) + + +def prepare_output_and_logger(args): + if not args.model_path: + if os.getenv('OAR_JOB_ID'): + unique_str=os.getenv('OAR_JOB_ID') + else: + unique_str = str(uuid.uuid4()) + args.model_path = os.path.join("./output/", unique_str[0:10]) + + # Set up output folder + print("Output folder: {}".format(args.model_path)) + os.makedirs(args.model_path, exist_ok = True) + with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: + cfg_log_f.write(str(Namespace(**vars(args)))) + + # Create Tensorboard writer + tb_writer = None + if TENSORBOARD_FOUND: + tb_writer = SummaryWriter(args.model_path) + else: + print("Tensorboard not available: not logging progress") + return tb_writer + + +def training_report(tb_writer, iteration, testing_iterations, image, gt_image): + # Report test and samples of training set + if iteration in testing_iterations: + tb_writer.add_images("fuse/render", image[None], global_step=iteration) + tb_writer.add_images("fuse/ground_truth", gt_image[None], global_step=iteration) + + + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + lp = ModelParams(parser) + op = OptimizationParams(parser) + pp = PipelineParams(parser) + parser.add_argument('--debug_from', type=int, default=-1) + parser.add_argument('--detect_anomaly', action='store_true', default=False) + parser.add_argument("--test_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--save_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--quiet", action="store_true") + parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--start_checkpoint", type=str, default = None) + args = parser.parse_args(sys.argv[1:]) + args.save_iterations.append(args.iterations) + + print("Optimizing " + args.model_path) + + # Initialize system state (RNG) + safe_state(args.quiet) + + # Start GUI server, configure and run training + torch.autograd.set_detect_anomaly(args.detect_anomaly) + training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) + + # All done + print("\nTraining complete.") diff --git a/train_mouth.py b/train_mouth.py new file mode 100644 index 0000000000000000000000000000000000000000..eab9a6ada9175f3eeb27c7fc3f7fbaa5fe3d6de8 --- /dev/null +++ b/train_mouth.py @@ -0,0 +1,331 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import random +import torch +from random import randint +from utils.loss_utils import l1_loss, l2_loss, patchify, ssim +from gaussian_renderer import render, render_motion, render_motion_mouth +import sys +from scene import Scene, GaussianModel, MouthMotionNetwork +from utils.general_utils import safe_state +import lpips +import uuid +from tqdm import tqdm +from utils.image_utils import psnr +from argparse import ArgumentParser, Namespace +from arguments import ModelParams, PipelineParams, OptimizationParams +try: + from torch.utils.tensorboard import SummaryWriter + TENSORBOARD_FOUND = True +except ImportError: + TENSORBOARD_FOUND = False + +def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): + testing_iterations = [i for i in range(0, opt.iterations + 1, 2000)] + checkpoint_iterations = saving_iterations = [i for i in range(0, opt.iterations + 1, 10000)] + [opt.iterations] + + # vars + warm_step = 3000 + bg_iter = opt.iterations-1000 # opt.densify_until_iter + lpips_start_iter = bg_iter + motion_stop_iter = bg_iter + mouth_select_iter = bg_iter - 10000 + mouth_step = 1 / mouth_select_iter + select_interval = 7 + + first_iter = 0 + tb_writer = prepare_output_and_logger(dataset) + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians) + + motion_net = MouthMotionNetwork(args=dataset).cuda() + motion_optimizer = torch.optim.AdamW(motion_net.get_params(5e-3, 5e-4), betas=(0.9, 0.99), eps=1e-8) + scheduler = torch.optim.lr_scheduler.LambdaLR(motion_optimizer, lambda iter: (0.5 ** (iter / mouth_select_iter)) if iter < mouth_select_iter else 0.1 ** (iter / bg_iter)) + + lpips_criterion = lpips.LPIPS(net='alex').eval().cuda() + + gaussians.training_setup(opt) + if checkpoint: + (model_params, motion_params, motion_optimizer_params, first_iter) = torch.load(checkpoint) + gaussians.restore(model_params, opt) + motion_net.load_state_dict(motion_params) + motion_optimizer.load_state_dict(motion_optimizer_params) + + bg_color = [0, 1, 0] # if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + + iter_start = torch.cuda.Event(enable_timing = True) + iter_end = torch.cuda.Event(enable_timing = True) + + viewpoint_stack = None + ema_loss_for_log = 0.0 + progress_bar = tqdm(range(first_iter, opt.iterations), ascii=True, dynamic_ncols=True, desc="Training progress") + first_iter += 1 + for iteration in range(first_iter, opt.iterations + 1): + + iter_start.record() + + gaussians.update_learning_rate(iteration) + + # Every 1000 its we increase the levels of SH up to a maximum degree + if iteration % 1000 == 0: + gaussians.oneupSHdegree() + + # Pick a random Camera + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras().copy() + viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) + + # find a big mouth + + au_global_lb = viewpoint_cam.talking_dict['au25'][1] + au_global_ub = viewpoint_cam.talking_dict['au25'][4] + au_window = (au_global_ub - au_global_lb) * 0.2 + + au_ub = au_global_ub + au_lb = au_ub - mouth_step * iteration * (au_global_ub - au_global_lb) + + if iteration < warm_step: + while viewpoint_cam.talking_dict['au25'][0] < au_global_ub: + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras().copy() + viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) + + if warm_step < iteration < mouth_select_iter: + if iteration % select_interval == 0: + while viewpoint_cam.talking_dict['au25'][0] < au_lb or viewpoint_cam.talking_dict['au25'][0] > au_ub: + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras().copy() + viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) + + while torch.as_tensor(viewpoint_cam.talking_dict["mouth_mask"]).cuda().sum() < 20: + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras().copy() + viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) + + + + + # Render + if (iteration - 1) == debug_from: + pipe.debug = True + + if iteration > bg_iter: + # turn to black + bg_color = [0, 0, 0] # if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + face_mask = torch.as_tensor(viewpoint_cam.talking_dict["face_mask"]).cuda() + hair_mask = torch.as_tensor(viewpoint_cam.talking_dict["hair_mask"]).cuda() + mouth_mask = torch.as_tensor(viewpoint_cam.talking_dict["mouth_mask"]).cuda() + head_mask = face_mask + hair_mask + + [xmin, xmax, ymin, ymax] = viewpoint_cam.talking_dict['lips_rect'] + lips_mask = torch.zeros_like(mouth_mask) + lips_mask[xmin:xmax, ymin:ymax] = True + + if iteration < warm_step: + render_pkg = render(viewpoint_cam, gaussians, pipe, background) + else: + render_pkg = render_motion_mouth(viewpoint_cam, gaussians, motion_net, pipe, background) + + image_green, alpha, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["alpha"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] + + gt_image = viewpoint_cam.original_image.cuda() / 255.0 + gt_image_green = gt_image * mouth_mask + background[:, None, None] * ~mouth_mask + + if iteration > motion_stop_iter: + for param in motion_net.parameters(): + param.requires_grad = False + if iteration > bg_iter: + gaussians._xyz.requires_grad = False + gaussians._opacity.requires_grad = False + # gaussians._features_dc.requires_grad = False + # gaussians._features_rest.requires_grad = False + gaussians._scaling.requires_grad = False + gaussians._rotation.requires_grad = False + + # Loss + image_green[:, (lips_mask ^ mouth_mask)] = background[:, None] + + Ll1 = l1_loss(image_green, gt_image_green) + loss = Ll1 + opt.lambda_dssim * (1.0 - ssim(image_green, gt_image_green)) + + + if iteration > warm_step: + # loss += 1e-5 * (render_pkg['motion']['d_xyz'].abs()).mean() + loss += 1e-3 * (((1-alpha) * lips_mask).mean() + (alpha * ~lips_mask).mean()) + + image_t = image_green.clone() + gt_image_t = gt_image_green.clone() + + if iteration > lpips_start_iter: + patch_size = random.randint(16, 21) * 2 + loss += 0.5 * lpips_criterion(patchify(image_t[None, ...] * 2 - 1, patch_size), patchify(gt_image_t[None, ...] * 2 - 1, patch_size)).mean() + + + + loss.backward() + + iter_end.record() + + with torch.no_grad(): + # Progress bar + ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log + if iteration % 10 == 0: + progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{5}f}", "AU25": f"{au_lb:.{1}f}-{au_ub:.{1}f}"}) + progress_bar.update(10) + if iteration == opt.iterations: + progress_bar.close() + + if (iteration in saving_iterations): + print("\n[ITER {}] Saving Gaussians".format(iteration)) + scene.save(str(iteration)+'_mouth') + + # Log and save + training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, motion_net, render if iteration < warm_step else render_motion_mouth, (pipe, background)) + if (iteration in checkpoint_iterations): + print("\n[ITER {}] Saving Checkpoint".format(iteration)) + ckpt = (gaussians.capture(), motion_net.state_dict(), motion_optimizer.state_dict(), iteration) + torch.save(ckpt, scene.model_path + "/chkpnt_mouth_" + str(iteration) + ".pth") + torch.save(ckpt, scene.model_path + "/chkpnt_mouth_latest" + ".pth") + + + # Densification + if iteration < opt.densify_until_iter: + # Keep track of max radii in image-space for pruning + gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) + gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) + + if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: + size_threshold = 20 if iteration > opt.opacity_reset_interval else None + gaussians.densify_and_prune(opt.densify_grad_threshold, 0.05 + 0.25 * iteration / opt.densify_until_iter, scene.cameras_extent, size_threshold) + + shs_view = gaussians.get_features.transpose(1, 2).view(-1, 3, (gaussians.max_sh_degree+1)**2) + dir_pp = (gaussians.get_xyz - viewpoint_cam.camera_center.repeat(gaussians.get_features.shape[0], 1)) + dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) + from utils.sh_utils import eval_sh + sh2rgb = eval_sh(gaussians.active_sh_degree, shs_view, dir_pp_normalized) + colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) + + bg_color_mask = (colors_precomp[..., 0] < 20/255) * (colors_precomp[..., 1] > 235/255) * (colors_precomp[..., 2] < 20/255) + gaussians.xyz_gradient_accum[bg_color_mask] /= 2 + gaussians._opacity[bg_color_mask] = gaussians.inverse_opacity_activation(torch.ones_like(gaussians._opacity[bg_color_mask]) * 0.1) + gaussians._scaling[bg_color_mask] /= 10 + + # if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): + # gaussians.reset_opacity() + + # Optimizer step + if iteration < opt.iterations: + motion_optimizer.step() + gaussians.optimizer.step() + + motion_optimizer.zero_grad() + gaussians.optimizer.zero_grad(set_to_none = True) + + scheduler.step() + + + +def prepare_output_and_logger(args): + if not args.model_path: + if os.getenv('OAR_JOB_ID'): + unique_str=os.getenv('OAR_JOB_ID') + else: + unique_str = str(uuid.uuid4()) + args.model_path = os.path.join("./output/", unique_str[0:10]) + + # Set up output folder + print("Output folder: {}".format(args.model_path)) + os.makedirs(args.model_path, exist_ok = True) + with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: + cfg_log_f.write(str(Namespace(**vars(args)))) + + # Create Tensorboard writer + tb_writer = None + if TENSORBOARD_FOUND: + tb_writer = SummaryWriter(args.model_path) + else: + print("Tensorboard not available: not logging progress") + return tb_writer + +def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, motion_net, renderFunc, renderArgs): + # Report test and samples of training set + if iteration in testing_iterations: + torch.cuda.empty_cache() + validation_configs = ({'name': 'test', 'cameras' : [scene.getTestCameras()[idx % len(scene.getTestCameras())] for idx in range(5, 100, 10)]}, + {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) + + for config in validation_configs: + if config['cameras'] and len(config['cameras']) > 0: + l1_test = 0.0 + psnr_test = 0.0 + for idx, viewpoint in enumerate(config['cameras']): + + if renderFunc is render: + render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs) + else: + render_pkg = renderFunc(viewpoint, scene.gaussians, motion_net, *renderArgs) + + image = torch.clamp(render_pkg["render"], 0.0, 1.0) + alpha = render_pkg["alpha"] + image = image - renderArgs[1][:, None, None] * (1.0 - alpha) + viewpoint.background.cuda() / 255.0 * (1.0 - alpha) + gt_image = torch.clamp(viewpoint.original_image.to("cuda") / 255.0, 0.0, 1.0) + if tb_writer and (idx < 5): + tb_writer.add_images(config['name'] + "_view_{}_mouth/render".format(viewpoint.image_name), image[None], global_step=iteration) + tb_writer.add_images(config['name'] + "_view_{}_mouth/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) + tb_writer.add_images(config['name'] + "_view_{}_mouth/depth".format(viewpoint.image_name), (render_pkg["depth"] / render_pkg["depth"].max())[None], global_step=iteration) + + + l1_test += l1_loss(image, gt_image).mean().double() + psnr_test += psnr(image, gt_image).mean().double() + psnr_test /= len(config['cameras']) + l1_test /= len(config['cameras']) + print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) + if tb_writer: + tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) + tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) + + torch.cuda.empty_cache() + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + lp = ModelParams(parser) + op = OptimizationParams(parser) + pp = PipelineParams(parser) + parser.add_argument('--ip', type=str, default="127.0.0.1") + parser.add_argument('--port', type=int, default=6009) + parser.add_argument('--debug_from', type=int, default=-1) + parser.add_argument('--detect_anomaly', action='store_true', default=False) + parser.add_argument("--test_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--save_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--quiet", action="store_true") + parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--start_checkpoint", type=str, default = None) + args = parser.parse_args(sys.argv[1:]) + args.save_iterations.append(args.iterations) + + print("Optimizing " + args.model_path) + + # Initialize system state (RNG) + safe_state(args.quiet) + + # Start GUI server, configure and run training + torch.autograd.set_detect_anomaly(args.detect_anomaly) + training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) + + # All done + print("\nTraining complete.") diff --git a/utils/audio_utils.py b/utils/audio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1ac7d0c76dba9e1c29624ffd690a7142d31d91c1 --- /dev/null +++ b/utils/audio_utils.py @@ -0,0 +1,35 @@ +import torch + +def get_audio_features(features, att_mode, index): + if att_mode == 0: + return features[[index]] + elif att_mode == 1: + left = index - 8 + pad_left = 0 + if left < 0: + pad_left = -left + left = 0 + auds = features[left:index] + if pad_left > 0: + # pad may be longer than auds, so do not use zeros_like + auds = torch.cat([torch.zeros(pad_left, *auds.shape[1:], device=auds.device, dtype=auds.dtype), auds], dim=0) + return auds + elif att_mode == 2: + left = index - 4 + right = index + 4 + pad_left = 0 + pad_right = 0 + if left < 0: + pad_left = -left + left = 0 + if right > features.shape[0]: + pad_right = right - features.shape[0] + right = features.shape[0] + auds = features[left:right] + if pad_left > 0: + auds = torch.cat([torch.zeros_like(auds[:pad_left]), auds], dim=0) + if pad_right > 0: + auds = torch.cat([auds, torch.zeros_like(auds[:pad_right])], dim=0) # [8, 16] + return auds + else: + raise NotImplementedError(f'wrong att_mode: {att_mode}') \ No newline at end of file diff --git a/utils/camera_utils.py b/utils/camera_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c29f345fa18e43b99379e92d3879083179f3431b --- /dev/null +++ b/utils/camera_utils.py @@ -0,0 +1,60 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from scene.cameras import Camera +import numpy as np +from utils.general_utils import PILtoTorch +from utils.graphics_utils import fov2focal + +WARNED = False + +def loadCam(args, id, cam_info, resolution_scale): + + image_rgb = PILtoTorch(cam_info.image).type("torch.ByteTensor") + background = PILtoTorch(cam_info.background)[:3, ...].type("torch.ByteTensor") + + gt_image = image_rgb[:3, ...] + loaded_mask = None + + return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, + FoVx=cam_info.FovX, FoVy=cam_info.FovY, + image=gt_image, gt_alpha_mask=loaded_mask, background=background, talking_dict=cam_info.talking_dict, + image_name=cam_info.image_name, uid=id, data_device=args.data_device) + +def cameraList_from_camInfos(cam_infos, resolution_scale, args): + camera_list = [] + + for id, c in enumerate(cam_infos): + camera_list.append(loadCam(args, id, c, resolution_scale)) + + return camera_list + +def camera_to_JSON(id, camera : Camera): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = camera.R.transpose() + Rt[:3, 3] = camera.T + Rt[3, 3] = 1.0 + + W2C = np.linalg.inv(Rt) + pos = W2C[:3, 3] + rot = W2C[:3, :3] + serializable_array_2d = [x.tolist() for x in rot] + camera_entry = { + 'id' : id, + 'img_name' : camera.image_name, + 'width' : camera.width, + 'height' : camera.height, + 'position': pos.tolist(), + 'rotation': serializable_array_2d, + 'fy' : fov2focal(camera.FovY, camera.height), + 'fx' : fov2focal(camera.FovX, camera.width) + } + return camera_entry diff --git a/utils/general_utils.py b/utils/general_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..689b0af49fcc0ddc6cc0885426d86977396e1db4 --- /dev/null +++ b/utils/general_utils.py @@ -0,0 +1,140 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import sys +from datetime import datetime +import numpy as np +import random + +def inverse_sigmoid(x): + return torch.log(x/(1-x)) + +# def PILtoTorch(pil_image, resolution): +# resized_image_PIL = pil_image.resize(resolution) +# resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 +# if len(resized_image.shape) == 3: +# return resized_image.permute(2, 0, 1) +# else: +# return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) + +def PILtoTorch(np_image): + resized_image = torch.from_numpy(np_image) # / 255.0 + if len(resized_image.shape) == 3: + return resized_image.permute(2, 0, 1) + else: + return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) + +def get_expon_lr_func( + lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 +): + """ + Copied from Plenoxels + + Continuous learning rate decay function. Adapted from JaxNeRF + The returned rate is lr_init when step=0 and lr_final when step=max_steps, and + is log-linearly interpolated elsewhere (equivalent to exponential decay). + If lr_delay_steps>0 then the learning rate will be scaled by some smooth + function of lr_delay_mult, such that the initial learning rate is + lr_init*lr_delay_mult at the beginning of optimization but will be eased back + to the normal learning rate when steps>lr_delay_steps. + :param conf: config subtree 'lr' or similar + :param max_steps: int, the number of steps during optimization. + :return HoF which takes step as input + """ + + def helper(step): + if step < 0 or (lr_init == 0.0 and lr_final == 0.0): + # Disable this parameter + return 0.0 + if lr_delay_steps > 0: + # A kind of reverse cosine decay. + delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( + 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) + ) + else: + delay_rate = 1.0 + t = np.clip(step / max_steps, 0, 1) + log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) + return delay_rate * log_lerp + + return helper + +def strip_lowerdiag(L): + uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") + + uncertainty[:, 0] = L[:, 0, 0] + uncertainty[:, 1] = L[:, 0, 1] + uncertainty[:, 2] = L[:, 0, 2] + uncertainty[:, 3] = L[:, 1, 1] + uncertainty[:, 4] = L[:, 1, 2] + uncertainty[:, 5] = L[:, 2, 2] + return uncertainty + +def strip_symmetric(sym): + return strip_lowerdiag(sym) + +def build_rotation(r): + norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) + + q = r / norm[:, None] + + R = torch.zeros((q.size(0), 3, 3), device='cuda') + + r = q[:, 0] + x = q[:, 1] + y = q[:, 2] + z = q[:, 3] + + R[:, 0, 0] = 1 - 2 * (y*y + z*z) + R[:, 0, 1] = 2 * (x*y - r*z) + R[:, 0, 2] = 2 * (x*z + r*y) + R[:, 1, 0] = 2 * (x*y + r*z) + R[:, 1, 1] = 1 - 2 * (x*x + z*z) + R[:, 1, 2] = 2 * (y*z - r*x) + R[:, 2, 0] = 2 * (x*z - r*y) + R[:, 2, 1] = 2 * (y*z + r*x) + R[:, 2, 2] = 1 - 2 * (x*x + y*y) + return R + +def build_scaling_rotation(s, r): + L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") + R = build_rotation(r) + + L[:,0,0] = s[:,0] + L[:,1,1] = s[:,1] + L[:,2,2] = s[:,2] + + L = R @ L + return L + +def safe_state(silent): + old_f = sys.stdout + class F: + def __init__(self, silent): + self.silent = silent + + def write(self, x): + if not self.silent: + if x.endswith("\n"): + old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) + else: + old_f.write(x) + + def flush(self): + old_f.flush() + + sys.stdout = F(silent) + + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + torch.cuda.set_device(torch.device("cuda:0")) diff --git a/utils/graphics_utils.py b/utils/graphics_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4627d837c74fcdffc898fa0c3071cb7b316802b --- /dev/null +++ b/utils/graphics_utils.py @@ -0,0 +1,77 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import math +import numpy as np +from typing import NamedTuple + +class BasicPointCloud(NamedTuple): + points : np.array + colors : np.array + normals : np.array + +def geom_transform_points(points, transf_matrix): + P, _ = points.shape + ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) + points_hom = torch.cat([points, ones], dim=1) + points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) + + denom = points_out[..., 3:] + 0.0000001 + return (points_out[..., :3] / denom).squeeze(dim=0) + +def getWorld2View(R, t): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = R.transpose() + Rt[:3, 3] = t + Rt[3, 3] = 1.0 + return np.float32(Rt) + +def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = R.transpose() + Rt[:3, 3] = t + Rt[3, 3] = 1.0 + + C2W = np.linalg.inv(Rt) + cam_center = C2W[:3, 3] + cam_center = (cam_center + translate) * scale + C2W[:3, 3] = cam_center + Rt = np.linalg.inv(C2W) + return np.float32(Rt) + +def getProjectionMatrix(znear, zfar, fovX, fovY): + tanHalfFovY = math.tan((fovY / 2)) + tanHalfFovX = math.tan((fovX / 2)) + + top = tanHalfFovY * znear + bottom = -top + right = tanHalfFovX * znear + left = -right + + P = torch.zeros(4, 4) + + z_sign = 1.0 + + P[0, 0] = 2.0 * znear / (right - left) + P[1, 1] = 2.0 * znear / (top - bottom) + P[0, 2] = (right + left) / (right - left) + P[1, 2] = (top + bottom) / (top - bottom) + P[3, 2] = z_sign + P[2, 2] = z_sign * zfar / (zfar - znear) + P[2, 3] = -(zfar * znear) / (zfar - znear) + return P + +def fov2focal(fov, pixels): + return pixels / (2 * math.tan(fov / 2)) + +def focal2fov(focal, pixels): + return 2*math.atan(pixels/(2*focal)) \ No newline at end of file diff --git a/utils/image_utils.py b/utils/image_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cdeaa1b6d250e549181ab165070f82ccd31b3eb9 --- /dev/null +++ b/utils/image_utils.py @@ -0,0 +1,19 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch + +def mse(img1, img2): + return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) + +def psnr(img1, img2): + mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) + return 20 * torch.log10(1.0 / torch.sqrt(mse)) diff --git a/utils/loss_utils.py b/utils/loss_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b1265b08395cf2ef8972028a9b0205f68d714902 --- /dev/null +++ b/utils/loss_utils.py @@ -0,0 +1,68 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import torch.nn.functional as F +from torch.autograd import Variable +from math import exp + +def patchify(input, patch_size): + patches = F.unfold(input, kernel_size=patch_size, stride=patch_size).permute(0,2,1).view(-1, 3, patch_size, patch_size) + return patches + +def l1_loss(network_output, gt): + return torch.abs((network_output - gt)).mean() + +def l2_loss(network_output, gt): + return ((network_output - gt) ** 2).mean() + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + +def ssim(img1, img2, window_size=11, size_average=True): + channel = img1.size(-3) + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + diff --git a/utils/sh_utils.py b/utils/sh_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bbca7d192aa3a7edf8c5b2d24dee535eac765785 --- /dev/null +++ b/utils/sh_utils.py @@ -0,0 +1,118 @@ +# Copyright 2021 The PlenOctree Authors. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +import torch + +C0 = 0.28209479177387814 +C1 = 0.4886025119029199 +C2 = [ + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396 +] +C3 = [ + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435 +] +C4 = [ + 2.5033429417967046, + -1.7701307697799304, + 0.9461746957575601, + -0.6690465435572892, + 0.10578554691520431, + -0.6690465435572892, + 0.47308734787878004, + -1.7701307697799304, + 0.6258357354491761, +] + + +def eval_sh(deg, sh, dirs): + """ + Evaluate spherical harmonics at unit directions + using hardcoded SH polynomials. + Works with torch/np/jnp. + ... Can be 0 or more batch dimensions. + Args: + deg: int SH deg. Currently, 0-3 supported + sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] + dirs: jnp.ndarray unit directions [..., 3] + Returns: + [..., C] + """ + assert deg <= 4 and deg >= 0 + coeff = (deg + 1) ** 2 + assert sh.shape[-1] >= coeff + + result = C0 * sh[..., 0] + if deg > 0: + x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] + result = (result - + C1 * y * sh[..., 1] + + C1 * z * sh[..., 2] - + C1 * x * sh[..., 3]) + + if deg > 1: + xx, yy, zz = x * x, y * y, z * z + xy, yz, xz = x * y, y * z, x * z + result = (result + + C2[0] * xy * sh[..., 4] + + C2[1] * yz * sh[..., 5] + + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + + C2[3] * xz * sh[..., 7] + + C2[4] * (xx - yy) * sh[..., 8]) + + if deg > 2: + result = (result + + C3[0] * y * (3 * xx - yy) * sh[..., 9] + + C3[1] * xy * z * sh[..., 10] + + C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + + C3[5] * z * (xx - yy) * sh[..., 14] + + C3[6] * x * (xx - 3 * yy) * sh[..., 15]) + + if deg > 3: + result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + + C4[1] * yz * (3 * xx - yy) * sh[..., 17] + + C4[2] * xy * (7 * zz - 1) * sh[..., 18] + + C4[3] * yz * (7 * zz - 3) * sh[..., 19] + + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + + C4[5] * xz * (7 * zz - 3) * sh[..., 21] + + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) + return result + +def RGB2SH(rgb): + return (rgb - 0.5) / C0 + +def SH2RGB(sh): + return sh * C0 + 0.5 \ No newline at end of file diff --git a/utils/system_utils.py b/utils/system_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..90ca6d7f77610c967affe313398777cd86920e8e --- /dev/null +++ b/utils/system_utils.py @@ -0,0 +1,28 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from errno import EEXIST +from os import makedirs, path +import os + +def mkdir_p(folder_path): + # Creates a directory. equivalent to using mkdir -p on the command line + try: + makedirs(folder_path) + except OSError as exc: # Python >2.5 + if exc.errno == EEXIST and path.isdir(folder_path): + pass + else: + raise + +def searchForMaxIteration(folder): + saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] + return max(saved_iters)