diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..53aee21b9a286bf9d1904e9527c3f0b047f4f0ea 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.mp4 filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..8b3d8a077a8756a19b61019756a549494972cac0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,138 @@ +flagged/ +sample_output/ +wandb/ +.vscode +.DS_Store +*ckpt*/ +# Custom +*.pt +data/local +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/README.md b/README.md index c908a08c25f0cbf1578787a063aa2365ce0e1167..9a1ce52391d49d316e66b141754ddbf8f9bc2e06 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ --- -title: SwinTExCo Cpu +title: Exemplar-based Video Colorization using Vision Transformer (CPU version) emoji: 🏃 colorFrom: green colorTo: yellow diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..effa29f423f1ba41ba153d380e06277753c1990b --- /dev/null +++ b/app.py @@ -0,0 +1,50 @@ +import gradio as gr +from src.inference import SwinTExCo +import cv2 +import os +from PIL import Image +import time +import app_config as cfg + + +model = SwinTExCo(weights_path=cfg.ckpt_path) + +def video_colorization(video_path, ref_image, progress=gr.Progress()): + # Initialize video reader + video_reader = cv2.VideoCapture(video_path) + fps = video_reader.get(cv2.CAP_PROP_FPS) + height = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT)) + width = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH)) + num_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT)) + + # Initialize reference image + ref_image = Image.fromarray(ref_image) + + # Initialize video writer + output_path = os.path.join(os.path.dirname(video_path), os.path.basename(video_path).split('.')[0] + '_colorized.mp4') + video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)) + + # Init progress bar + + for colorized_frame, _ in zip(model.predict_video(video_reader, ref_image), progress.tqdm(range(num_frames), desc="Colorizing video", unit="frames")): + colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_RGB2BGR) + video_writer.write(colorized_frame) + + # for i in progress.tqdm(range(1000)): + # time.sleep(0.5) + + video_writer.release() + + return output_path + +app = gr.Interface( + fn=video_colorization, + inputs=[gr.Video(format="mp4", sources="upload", label="Input video (grayscale)", interactive=True), + gr.Image(sources="upload", label="Reference image (color)")], + outputs=gr.Video(label="Output video (colorized)"), + title=cfg.TITLE, + description=cfg.DESCRIPTION +).queue() + + +app.launch() \ No newline at end of file diff --git a/app_config.py b/app_config.py new file mode 100644 index 0000000000000000000000000000000000000000..b2727ead5fdcf6eb4d135a55504089ceeb2a93a9 --- /dev/null +++ b/app_config.py @@ -0,0 +1,9 @@ +ckpt_path = 'checkpoints/epoch_20' +TITLE = 'Deep Exemplar-based Video Colorization using Vision Transformer' +DESCRIPTION = ''' +
+This is a demo app for the thesis: Deep Exemplar-based Video Colorization using Vision Transformer.
+The code is available at: (link will be updated soon).
+Our previous work was also written up as a paper and accepted at the ICTC 2023 conference (Section B1-4). +
+'''.strip() \ No newline at end of file diff --git a/checkpoints/epoch_10/colornet.pth b/checkpoints/epoch_10/colornet.pth new file mode 100644 index 0000000000000000000000000000000000000000..f3a8d9c2b872f34bd7080d01dfbf6d4d22bbe3ab --- /dev/null +++ b/checkpoints/epoch_10/colornet.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1ecb43b5e02b77bec5342e2e296d336bf8f384a07d3c809d1a548fd5fb1e7365 +size 131239411 diff --git a/checkpoints/epoch_10/discriminator.pth b/checkpoints/epoch_10/discriminator.pth new file mode 100644 index 0000000000000000000000000000000000000000..9e14193d0dcd96bc08ca72cf330320a72d3cfdd5 --- /dev/null +++ b/checkpoints/epoch_10/discriminator.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ce8968a9d3d2f99b1bc1e32080507e0d671cee00b66200105c8839be684b84b4 +size 45073068 diff --git a/checkpoints/epoch_10/embed_net.pth b/checkpoints/epoch_10/embed_net.pth new file mode 100644 index 0000000000000000000000000000000000000000..0439349777f69682d5b01e03b96659ad64c817c9 --- /dev/null +++ b/checkpoints/epoch_10/embed_net.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fc711755a75c43025dabe9407cbd11d164eaa9e21f26430d0c16c7493410d902 +size 110352261 diff --git a/checkpoints/epoch_10/learning_state.pth b/checkpoints/epoch_10/learning_state.pth new file mode 100644 index 0000000000000000000000000000000000000000..81ab6499ba83f0d580c2668e4bf2915a1a1f1ff2 --- /dev/null +++ b/checkpoints/epoch_10/learning_state.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8d09b1e96fdf0205930a21928449a44c51cedd965cc0d573068c73971bcb8bd2 +size 748166487 diff --git a/checkpoints/epoch_10/nonlocal_net.pth b/checkpoints/epoch_10/nonlocal_net.pth new file mode 100644 index 0000000000000000000000000000000000000000..389e42c55fcd2e6a853e1f3a288f23fd3d653a8d --- /dev/null +++ b/checkpoints/epoch_10/nonlocal_net.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:86c97d6803d625a0dff8c6c09b70852371906eb5ef77df0277c27875666a68e2 +size 73189765 diff --git a/checkpoints/epoch_12/colornet.pth b/checkpoints/epoch_12/colornet.pth new file mode 100644 index 0000000000000000000000000000000000000000..d56a75ae928a1e69d125437d714b997a3736b8e7 --- /dev/null +++ b/checkpoints/epoch_12/colornet.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50f4b92cd59f4c88c0c1d7c93652413d54b1b96d729fc4b93e235887b5164f28 +size 131239846 diff --git a/checkpoints/epoch_12/discriminator.pth b/checkpoints/epoch_12/discriminator.pth new file mode 100644 index 0000000000000000000000000000000000000000..80f3701e5da8d9ba91e1e59e91c6fc42c228245a --- /dev/null +++ b/checkpoints/epoch_12/discriminator.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b54b0bad6ceec33569cc5833cbf03ed8ddbb5f07998aa634badf8298d3cd15f +size 45073513 diff --git a/checkpoints/epoch_12/embed_net.pth b/checkpoints/epoch_12/embed_net.pth new file mode 100644 index 0000000000000000000000000000000000000000..a7c02eee1819cb220e835fa88db8c0265bb783f8 --- /dev/null +++ b/checkpoints/epoch_12/embed_net.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:73e2a156c0737e3d063af0e95e1e7176362e85120b88275a1aa02dcf488e1865 +size 110352698 diff --git a/checkpoints/epoch_12/learning_state.pth b/checkpoints/epoch_12/learning_state.pth new file mode 100644 index 0000000000000000000000000000000000000000..b2d2a3012bb8c67cb8d7ad3630316b78b0a42341 --- /dev/null +++ 
b/checkpoints/epoch_12/learning_state.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f8bb4dbb3cb8e497a9a2079947f0221823fa8b44695e2d2ad8478be48464fad +size 748166934 diff --git a/checkpoints/epoch_12/nonlocal_net.pth b/checkpoints/epoch_12/nonlocal_net.pth new file mode 100644 index 0000000000000000000000000000000000000000..bf47c4949ef6096e234d8582179d65881d97e9f2 --- /dev/null +++ b/checkpoints/epoch_12/nonlocal_net.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c1f76b53dad7bf15c7d26aa106c95387e75751b8c31fafef2bd73ea7d77160cb +size 73190208 diff --git a/checkpoints/epoch_16/colornet.pth b/checkpoints/epoch_16/colornet.pth new file mode 100644 index 0000000000000000000000000000000000000000..8ad3e949edc5ef3b1b9fbe8242c1be89d8a69398 --- /dev/null +++ b/checkpoints/epoch_16/colornet.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:81ec9cff0ad5b0d920179fa7a9cc229e1424bfc796b7134604ff66b97d748c49 +size 131239846 diff --git a/checkpoints/epoch_16/discriminator.pth b/checkpoints/epoch_16/discriminator.pth new file mode 100644 index 0000000000000000000000000000000000000000..0d417db3092c3558beb444bc8e1f34784d207c2a --- /dev/null +++ b/checkpoints/epoch_16/discriminator.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:42262d5ed7596f38e65774085222530eee57da8dfaa7fe1aa223d824ed166f62 +size 45073513 diff --git a/checkpoints/epoch_16/embed_net.pth b/checkpoints/epoch_16/embed_net.pth new file mode 100644 index 0000000000000000000000000000000000000000..a7c02eee1819cb220e835fa88db8c0265bb783f8 --- /dev/null +++ b/checkpoints/epoch_16/embed_net.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:73e2a156c0737e3d063af0e95e1e7176362e85120b88275a1aa02dcf488e1865 +size 110352698 diff --git a/checkpoints/epoch_16/learning_state.pth b/checkpoints/epoch_16/learning_state.pth new file mode 100644 index 0000000000000000000000000000000000000000..b0c7771adc7c37e9c5e873fde1ab4dfa881e9d6e --- /dev/null +++ b/checkpoints/epoch_16/learning_state.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ea4cf81341750ebf517c696a0f6241bfeede0584b0ce75ad208e3ffc8280877f +size 748166934 diff --git a/checkpoints/epoch_16/nonlocal_net.pth b/checkpoints/epoch_16/nonlocal_net.pth new file mode 100644 index 0000000000000000000000000000000000000000..6e7933d9f6e20f85894a67ca281028007e47e035 --- /dev/null +++ b/checkpoints/epoch_16/nonlocal_net.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:85b63363bc9c79732df78ba50ed19491ed86e961214bbd1f796a871334eba516 +size 73190208 diff --git a/checkpoints/epoch_20/colornet.pth b/checkpoints/epoch_20/colornet.pth new file mode 100644 index 0000000000000000000000000000000000000000..283002997e001fbe870daa4c868165fab041fae0 --- /dev/null +++ b/checkpoints/epoch_20/colornet.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c524f4e5df5f6ce91db1973a30de55299ebcbbde1edd2009718d3b4cd2631339 +size 131239846 diff --git a/checkpoints/epoch_20/discriminator.pth b/checkpoints/epoch_20/discriminator.pth new file mode 100644 index 0000000000000000000000000000000000000000..d9bdc788c644fe5a20280acea1f20ebf9fe55014 --- /dev/null +++ b/checkpoints/epoch_20/discriminator.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fcd80950c796fcfe6e4b6bdeeb358776700458d868da94ee31df3d1d37779310 +size 45073513 diff --git a/checkpoints/epoch_20/embed_net.pth b/checkpoints/epoch_20/embed_net.pth new 
file mode 100644 index 0000000000000000000000000000000000000000..a7c02eee1819cb220e835fa88db8c0265bb783f8 --- /dev/null +++ b/checkpoints/epoch_20/embed_net.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:73e2a156c0737e3d063af0e95e1e7176362e85120b88275a1aa02dcf488e1865 +size 110352698 diff --git a/checkpoints/epoch_20/learning_state.pth b/checkpoints/epoch_20/learning_state.pth new file mode 100644 index 0000000000000000000000000000000000000000..f2efcf054b44b72c12ce845a30f2c37bdfd50535 --- /dev/null +++ b/checkpoints/epoch_20/learning_state.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b1163b210b246b07d8f1c50eb3766d97c6f03bf409c854d00b7c69edb6d7391 +size 748166934 diff --git a/checkpoints/epoch_20/nonlocal_net.pth b/checkpoints/epoch_20/nonlocal_net.pth new file mode 100644 index 0000000000000000000000000000000000000000..f68eb2413624faabd93ff65986e7c13be4da1485 --- /dev/null +++ b/checkpoints/epoch_20/nonlocal_net.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:031e5f38cc79eb3c0ed51ca2ad3c8921fdda2fa05946c357f84881259de74e6d +size 73190208 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..025946634ce9270d63ec6dd5e0443f78a768ce54 Binary files /dev/null and b/requirements.txt differ diff --git a/sample_input/ref1.jpg b/sample_input/ref1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d2ed0a6f318d7b1a4716ef64422575c5d2078afe Binary files /dev/null and b/sample_input/ref1.jpg differ diff --git a/sample_input/video1.mp4 b/sample_input/video1.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e96476e110511c201d972fc780d39a833abb92f1 --- /dev/null +++ b/sample_input/video1.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:077ebcd3cf6c020c95732e74a0fe1fab9b80102bc14d5e201b12c4917e0c0d1d +size 1011726 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/data/dataloader.py b/src/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..16fb6d288c37b1dd450f6d00a60909e26b75c3bb --- /dev/null +++ b/src/data/dataloader.py @@ -0,0 +1,332 @@ +import numpy as np +import pandas as pd +from src.utils import ( + CenterPadCrop_numpy, + Distortion_with_flow_cpu, + Distortion_with_flow_gpu, + Normalize, + RGB2Lab, + ToTensor, + Normalize, + RGB2Lab, + ToTensor, + CenterPad, + read_flow, + SquaredPadding +) +import torch +import torch.utils.data as data +import torchvision.transforms as transforms +from numpy import random +import os +from PIL import Image +from scipy.ndimage.filters import gaussian_filter +from scipy.ndimage import map_coordinates +import glob + + + +def image_loader(path): + with open(path, "rb") as f: + with Image.open(f) as img: + return img.convert("RGB") + + +class CenterCrop(object): + """ + center crop the numpy array + """ + + def __init__(self, image_size): + self.h0, self.w0 = image_size + + def __call__(self, input_numpy): + if input_numpy.ndim == 3: + h, w, channel = input_numpy.shape + output_numpy = np.zeros((self.h0, self.w0, channel)) + output_numpy = input_numpy[ + (h - self.h0) // 2 : (h - self.h0) // 2 + self.h0, (w - self.w0) // 2 : (w - self.w0) // 2 + self.w0, : + ] + else: + h, w = input_numpy.shape + output_numpy = np.zeros((self.h0, self.w0)) + output_numpy = input_numpy[ + (h - self.h0) // 2 : (h - 
self.h0) // 2 + self.h0, (w - self.w0) // 2 : (w - self.w0) // 2 + self.w0 + ] + return output_numpy + + +class VideosDataset(torch.utils.data.Dataset): + def __init__( + self, + video_data_root, + flow_data_root, + mask_data_root, + imagenet_folder, + annotation_file_path, + image_size, + num_refs=5, # max = 20 + image_transform=None, + real_reference_probability=1, + nonzero_placeholder_probability=0.5, + ): + self.video_data_root = video_data_root + self.flow_data_root = flow_data_root + self.mask_data_root = mask_data_root + self.imagenet_folder = imagenet_folder + self.image_transform = image_transform + self.CenterPad = CenterPad(image_size) + self.Resize = transforms.Resize(image_size) + self.ToTensor = ToTensor() + self.CenterCrop = transforms.CenterCrop(image_size) + self.SquaredPadding = SquaredPadding(image_size[0]) + self.num_refs = num_refs + + assert os.path.exists(self.video_data_root), "find no video dataroot" + assert os.path.exists(self.flow_data_root), "find no flow dataroot" + assert os.path.exists(self.imagenet_folder), "find no imagenet folder" + # self.epoch = epoch + self.image_pairs = pd.read_csv(annotation_file_path, dtype=str) + self.real_len = len(self.image_pairs) + # self.image_pairs = pd.concat([self.image_pairs] * self.epoch, ignore_index=True) + self.real_reference_probability = real_reference_probability + self.nonzero_placeholder_probability = nonzero_placeholder_probability + print("##### parsing image pairs in %s: %d pairs #####" % (video_data_root, self.__len__())) + + def __getitem__(self, index): + ( + video_name, + prev_frame, + current_frame, + flow_forward_name, + mask_name, + reference_1_name, + reference_2_name, + reference_3_name, + reference_4_name, + reference_5_name + ) = self.image_pairs.iloc[index, :5+self.num_refs].values.tolist() + + video_path = os.path.join(self.video_data_root, video_name) + flow_path = os.path.join(self.flow_data_root, video_name) + mask_path = os.path.join(self.mask_data_root, video_name) + + prev_frame_path = os.path.join(video_path, prev_frame) + current_frame_path = os.path.join(video_path, current_frame) + list_frame_path = glob.glob(os.path.join(video_path, '*')) + list_frame_path.sort() + + reference_1_path = os.path.join(self.imagenet_folder, reference_1_name) + reference_2_path = os.path.join(self.imagenet_folder, reference_2_name) + reference_3_path = os.path.join(self.imagenet_folder, reference_3_name) + reference_4_path = os.path.join(self.imagenet_folder, reference_4_name) + reference_5_path = os.path.join(self.imagenet_folder, reference_5_name) + + flow_forward_path = os.path.join(flow_path, flow_forward_name) + mask_path = os.path.join(mask_path, mask_name) + + #reference_gt_1_path = prev_frame_path + #reference_gt_2_path = current_frame_path + try: + I1 = Image.open(prev_frame_path).convert("RGB") + I2 = Image.open(current_frame_path).convert("RGB") + try: + I_reference_video = Image.open(list_frame_path[0]).convert("RGB") # Get first frame + except: + I_reference_video = Image.open(current_frame_path).convert("RGB") # Get current frame if error + + reference_list = [reference_1_path, reference_2_path, reference_3_path, reference_4_path, reference_5_path] + while reference_list: # run until getting the colorized reference + reference_path = random.choice(reference_list) + I_reference_video_real = Image.open(reference_path) + if I_reference_video_real.mode == 'L': + reference_list.remove(reference_path) + else: + break + if not reference_list: + I_reference_video_real = I_reference_video + + 
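# The forward optical flow (an H x W x 2 array from read_flow) and the binary + # occlusion mask loaded below are what the training losses use to warp + # predictions between consecutive frames for temporal consistency. + 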
flow_forward = read_flow(flow_forward_path) # numpy + + mask = Image.open(mask_path) # PIL + mask = self.Resize(mask) + mask = np.array(mask) + # mask = self.SquaredPadding(mask, return_pil=False, return_paddings=False) + # binary mask + mask[mask < 240] = 0 + mask[mask >= 240] = 1 + mask = self.ToTensor(mask) + + # transform + I1 = self.image_transform(I1) + I2 = self.image_transform(I2) + I_reference_video = self.image_transform(I_reference_video) + I_reference_video_real = self.image_transform(I_reference_video_real) + flow_forward = self.ToTensor(flow_forward) + flow_forward = self.Resize(flow_forward)#, return_pil=False, return_paddings=False, dtype=np.float32) + + + if np.random.random() < self.real_reference_probability: + I_reference_output = I_reference_video_real # Use reference from imagenet + placeholder = torch.zeros_like(I1) + self_ref_flag = torch.zeros_like(I1) + else: + I_reference_output = I_reference_video # Use reference from ground truth + placeholder = I2 if np.random.random() < self.nonzero_placeholder_probability else torch.zeros_like(I1) + self_ref_flag = torch.ones_like(I1) + + outputs = [ + I1, + I2, + I_reference_output, + flow_forward, + mask, + placeholder, + self_ref_flag, + video_name + prev_frame, + video_name + current_frame, + reference_path + ] + + except Exception as e: + print("error in reading image pair: %s" % str(self.image_pairs[index])) + print(e) + return self.__getitem__(np.random.randint(0, len(self.image_pairs))) + return outputs + + def __len__(self): + return len(self.image_pairs) + + +def parse_imgnet_images(pairs_file): + pairs = [] + with open(pairs_file, "r") as f: + lines = f.readlines() + for line in lines: + line = line.strip().split("|") + image_a = line[0] + image_b = line[1] + pairs.append((image_a, image_b)) + return pairs + + +class VideosDataset_ImageNet(data.Dataset): + def __init__( + self, + imagenet_data_root, + pairs_file, + image_size, + transforms_imagenet=None, + distortion_level=3, + brightnessjitter=0, + nonzero_placeholder_probability=0.5, + extra_reference_transform=None, + real_reference_probability=1, + distortion_device='cpu' + ): + self.imagenet_data_root = imagenet_data_root + self.image_pairs = pd.read_csv(pairs_file, names=['i1', 'i2']) + self.transforms_imagenet_raw = transforms_imagenet + self.extra_reference_transform = transforms.Compose(extra_reference_transform) + self.real_reference_probability = real_reference_probability + self.transforms_imagenet = transforms.Compose(transforms_imagenet) + self.image_size = image_size + self.real_len = len(self.image_pairs) + self.distortion_level = distortion_level + self.distortion_transform = Distortion_with_flow_cpu() if distortion_device == 'cpu' else Distortion_with_flow_gpu() + self.brightnessjitter = brightnessjitter + self.flow_transform = transforms.Compose([CenterPadCrop_numpy(self.image_size), ToTensor()]) + self.nonzero_placeholder_probability = nonzero_placeholder_probability + self.ToTensor = ToTensor() + self.Normalize = Normalize() + print("##### parsing imageNet pairs in %s: %d pairs #####" % (imagenet_data_root, self.__len__())) + + def __getitem__(self, index): + pa, pb = self.image_pairs.iloc[index].values.tolist() + if np.random.random() > 0.5: + pa, pb = pb, pa + + image_a_path = os.path.join(self.imagenet_data_root, pa) + image_b_path = os.path.join(self.imagenet_data_root, pb) + + I1 = image_loader(image_a_path) + I2 = I1 + I_reference_video = I1 + I_reference_video_real = image_loader(image_b_path) + # print("i'm here get image 2") + # 
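Synthesize a pseudo optical-flow field: the Gaussian-smoothed random fields + # below (forward_dx, forward_dy) act as fake motion so a single still image + # can stand in for a consecutive video frame pair. + # 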
generate the flow + alpha = np.random.rand() * self.distortion_level + distortion_range = 50 + random_state = np.random.RandomState(None) + shape = self.image_size[0], self.image_size[1] + # dx: flow on the vertical direction; dy: flow on the horizontal direction + forward_dx = ( + gaussian_filter((random_state.rand(*shape) * 2 - 1), distortion_range, mode="constant", cval=0) * alpha * 1000 + ) + forward_dy = ( + gaussian_filter((random_state.rand(*shape) * 2 - 1), distortion_range, mode="constant", cval=0) * alpha * 1000 + ) + # print("i'm here get image 3") + for transform in self.transforms_imagenet_raw: + if type(transform) is RGB2Lab: + I1_raw = I1 + I1 = transform(I1) + for transform in self.transforms_imagenet_raw: + if type(transform) is RGB2Lab: + I2 = self.distortion_transform(I2, forward_dx, forward_dy) + I2_raw = I2 + I2 = transform(I2) + # print("i'm here get image 4") + I2[0:1, :, :] = I2[0:1, :, :] + torch.randn(1) * self.brightnessjitter + + I_reference_video = self.extra_reference_transform(I_reference_video) + for transform in self.transforms_imagenet_raw: + I_reference_video = transform(I_reference_video) + + I_reference_video_real = self.transforms_imagenet(I_reference_video_real) + # print("i'm here get image 5") + flow_forward_raw = np.stack((forward_dy, forward_dx), axis=-1) + flow_forward = self.flow_transform(flow_forward_raw) + + # update the mask for the pixels on the border + grid_x, grid_y = np.meshgrid(np.arange(self.image_size[0]), np.arange(self.image_size[1]), indexing="ij") + grid = np.stack((grid_y, grid_x), axis=-1) + grid_warp = grid + flow_forward_raw + location_y = grid_warp[:, :, 0].flatten() + location_x = grid_warp[:, :, 1].flatten() + I2_raw = np.array(I2_raw).astype(float) + I21_r = map_coordinates(I2_raw[:, :, 0], np.stack((location_x, location_y)), cval=-1).reshape( + (self.image_size[0], self.image_size[1]) + ) + I21_g = map_coordinates(I2_raw[:, :, 1], np.stack((location_x, location_y)), cval=-1).reshape( + (self.image_size[0], self.image_size[1]) + ) + I21_b = map_coordinates(I2_raw[:, :, 2], np.stack((location_x, location_y)), cval=-1).reshape( + (self.image_size[0], self.image_size[1]) + ) + I21_raw = np.stack((I21_r, I21_g, I21_b), axis=2) + mask = np.ones((self.image_size[0], self.image_size[1])) + mask[(I21_raw[:, :, 0] == -1) & (I21_raw[:, :, 1] == -1) & (I21_raw[:, :, 2] == -1)] = 0 + mask[abs(I21_raw - I1_raw).sum(axis=-1) > 50] = 0 + mask = self.ToTensor(mask) + # print("i'm here get image 6") + if np.random.random() < self.real_reference_probability: + I_reference_output = I_reference_video_real + placeholder = torch.zeros_like(I1) + self_ref_flag = torch.zeros_like(I1) + else: + I_reference_output = I_reference_video + placeholder = I2 if np.random.random() < self.nonzero_placeholder_probability else torch.zeros_like(I1) + self_ref_flag = torch.ones_like(I1) + + # except Exception as e: + # if combo_path is not None: + # print("problem in ", combo_path) + # print("problem in, ", image_a_path) + # print(e) + # return self.__getitem__(np.random.randint(0, len(self.image_pairs))) + # print("i'm here get image 7") + return [I1, I2, I_reference_output, flow_forward, mask, placeholder, self_ref_flag, "holder", pb, pa] + + def __len__(self): + return len(self.image_pairs) \ No newline at end of file diff --git a/src/data/functional.py b/src/data/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..14aa7882d3dfca1ba6649d0b7fdb2c443e3b7f20 --- /dev/null +++ b/src/data/functional.py @@ -0,0 +1,84 @@ +from 
__future__ import division + +import torch +import numbers +import collections.abc # Iterable/Sequence live here; removed from collections in Python 3.10 +import numpy as np +from PIL import Image, ImageOps + + +def _is_pil_image(img): + return isinstance(img, Image.Image) + + +def _is_tensor_image(img): + return torch.is_tensor(img) and img.ndimension() == 3 + + +def _is_numpy_image(img): + return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) + + +def to_mytensor(pic): + pic_arr = np.array(pic) + if pic_arr.ndim == 2: + pic_arr = pic_arr[..., np.newaxis] + img = torch.from_numpy(pic_arr.transpose((2, 0, 1))) + if not isinstance(img, torch.FloatTensor): + return img.float() # no normalize .div(255) + else: + return img + + +def normalize(tensor, mean, std): + if not _is_tensor_image(tensor): + raise TypeError("tensor is not a torch image.") + if tensor.size(0) == 1: + tensor.sub_(mean).div_(std) + else: + for t, m, s in zip(tensor, mean, std): + t.sub_(m).div_(s) + return tensor + + +def resize(img, size, interpolation=Image.BILINEAR): + if not _is_pil_image(img): + raise TypeError("img should be PIL Image. Got {}".format(type(img))) + if not isinstance(size, int) and (not isinstance(size, collections.abc.Iterable) or len(size) != 2): + raise TypeError("Got inappropriate size arg: {}".format(size)) + + if not isinstance(size, int): + return img.resize(size[::-1], interpolation) + + w, h = img.size + if (w <= h and w == size) or (h <= w and h == size): + return img + if w < h: + ow = size + oh = int(round(size * h / w)) + else: + oh = size + ow = int(round(size * w / h)) + return img.resize((ow, oh), interpolation) + + +def pad(img, padding, fill=0): + if not _is_pil_image(img): + raise TypeError("img should be PIL Image. Got {}".format(type(img))) + + if not isinstance(padding, (numbers.Number, tuple)): + raise TypeError("Got inappropriate padding arg") + if not isinstance(fill, (numbers.Number, str, tuple)): + raise TypeError("Got inappropriate fill arg") + + if isinstance(padding, collections.abc.Sequence) and len(padding) not in [2, 4]: + raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding))) + + return ImageOps.expand(img, border=padding, fill=fill) + + +def crop(img, i, j, h, w): + if not _is_pil_image(img): + raise TypeError("img should be PIL Image. 
Got {}".format(type(img))) + + return img.crop((j, i, j + w, i + h)) diff --git a/src/data/transforms.py b/src/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..c1ce42b501ed66802df534491a633498af7f824c --- /dev/null +++ b/src/data/transforms.py @@ -0,0 +1,348 @@ +from __future__ import division + +import collections +import numbers +import random + +import torch +from PIL import Image +from skimage import color + +import src.data.functional as F + +__all__ = [ + "Compose", + "Concatenate", + "ToTensor", + "Normalize", + "Resize", + "Scale", + "CenterCrop", + "Pad", + "RandomCrop", + "RandomHorizontalFlip", + "RandomVerticalFlip", + "RandomResizedCrop", + "RandomSizedCrop", + "FiveCrop", + "TenCrop", + "RGB2Lab", +] + + +def CustomFunc(inputs, func, *args, **kwargs): + im_l = func(inputs[0], *args, **kwargs) + im_ab = func(inputs[1], *args, **kwargs) + warp_ba = func(inputs[2], *args, **kwargs) + warp_aba = func(inputs[3], *args, **kwargs) + im_gbl_ab = func(inputs[4], *args, **kwargs) + bgr_mc_im = func(inputs[5], *args, **kwargs) + + layer_data = [im_l, im_ab, warp_ba, warp_aba, im_gbl_ab, bgr_mc_im] + + for l in range(5): + layer = inputs[6 + l] + err_ba = func(layer[0], *args, **kwargs) + err_ab = func(layer[1], *args, **kwargs) + + layer_data.append([err_ba, err_ab]) + + return layer_data + + +class Compose(object): + """Composes several transforms together. + + Args: + transforms (list of ``Transform`` objects): list of transforms to compose. + + Example: + >>> transforms.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, inputs): + for t in self.transforms: + inputs = t(inputs) + return inputs + + +class Concatenate(object): + """ + Input: [im_l, im_ab, inputs] + inputs = [warp_ba_l, warp_ba_ab, warp_aba, err_pm, err_aba] + + Output:[im_l, err_pm, warp_ba, warp_aba, im_ab, err_aba] + """ + + def __call__(self, inputs): + im_l = inputs[0] + im_ab = inputs[1] + warp_ba = inputs[2] + warp_aba = inputs[3] + im_glb_ab = inputs[4] + bgr_mc_im = inputs[5] + bgr_mc_im = bgr_mc_im[[2, 1, 0], ...] + + err_ba = [] + err_ab = [] + + for l in range(5): + layer = inputs[6 + l] + err_ba.append(layer[0]) + err_ab.append(layer[1]) + + cerr_ba = torch.cat(err_ba, 0) + cerr_ab = torch.cat(err_ab, 0) + + return (im_l, cerr_ba, warp_ba, warp_aba, im_glb_ab, bgr_mc_im, im_ab, cerr_ab) + + +class ToTensor(object): + """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. + + Converts a PIL Image or numpy.ndarray (H x W x C) in the range + [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. + """ + + def __call__(self, inputs): + """ + Args: + pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + + Returns: + Tensor: Converted image. + """ + return CustomFunc(inputs, F.to_mytensor) + + +class Normalize(object): + """Normalize an tensor image with mean and standard deviation. + Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform + will normalize each channel of the input ``torch.*Tensor`` i.e. + ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` + + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + """ + + def __call__(self, inputs): + """ + Args: + tensor (Tensor): Tensor image of size (C, H, W) to be normalized. + + Returns: + Tensor: Normalized Tensor image. 
+ """ + + im_l = F.normalize(inputs[0], 50, 1) # [0, 100] + im_ab = F.normalize(inputs[1], (0, 0), (1, 1)) # [-100, 100] + + inputs[2][0:1, :, :] = F.normalize(inputs[2][0:1, :, :], 50, 1) + inputs[2][1:3, :, :] = F.normalize(inputs[2][1:3, :, :], (0, 0), (1, 1)) + warp_ba = inputs[2] + + inputs[3][0:1, :, :] = F.normalize(inputs[3][0:1, :, :], 50, 1) + inputs[3][1:3, :, :] = F.normalize(inputs[3][1:3, :, :], (0, 0), (1, 1)) + warp_aba = inputs[3] + + im_gbl_ab = F.normalize(inputs[4], (0, 0), (1, 1)) # [-100, 100] + + bgr_mc_im = F.normalize(inputs[5], (123.68, 116.78, 103.938), (1, 1, 1)) + + layer_data = [im_l, im_ab, warp_ba, warp_aba, im_gbl_ab, bgr_mc_im] + + for l in range(5): + layer = inputs[6 + l] + err_ba = F.normalize(layer[0], 127, 2) # [0, 255] + err_ab = F.normalize(layer[1], 127, 2) # [0, 255] + layer_data.append([err_ba, err_ab]) + + return layer_data + + +class Resize(object): + """Resize the input PIL Image to the given size. + + Args: + size (sequence or int): Desired output size. If size is a sequence like + (h, w), output size will be matched to this. If size is an int, + smaller edge of the image will be matched to this number. + i.e, if height > width, then image will be rescaled to + (size * height / width, size) + interpolation (int, optional): Desired interpolation. Default is + ``PIL.Image.BILINEAR`` + """ + + def __init__(self, size, interpolation=Image.BILINEAR): + assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) + self.size = size + self.interpolation = interpolation + + def __call__(self, inputs): + """ + Args: + img (PIL Image): Image to be scaled. + + Returns: + PIL Image: Rescaled image. + """ + return CustomFunc(inputs, F.resize, self.size, self.interpolation) + + +class RandomCrop(object): + """Crop the given PIL Image at a random location. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + padding (int or sequence, optional): Optional padding on each border + of the image. Default is 0, i.e no padding. If a sequence of length + 4 is provided, it is used to pad left, top, right, bottom borders + respectively. + """ + + def __init__(self, size, padding=0): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + self.padding = padding + + @staticmethod + def get_params(img, output_size): + """Get parameters for ``crop`` for a random crop. + + Args: + img (PIL Image): Image to be cropped. + output_size (tuple): Expected output size of the crop. + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. + """ + w, h = img.size + th, tw = output_size + if w == tw and h == th: + return 0, 0, h, w + + i = random.randint(0, h - th) + j = random.randint(0, w - tw) + return i, j, th, tw + + def __call__(self, inputs): + """ + Args: + img (PIL Image): Image to be cropped. + + Returns: + PIL Image: Cropped image. + """ + if self.padding > 0: + inputs = CustomFunc(inputs, F.pad, self.padding) + + i, j, h, w = self.get_params(inputs[0], self.size) + return CustomFunc(inputs, F.crop, i, j, h, w) + + +class CenterCrop(object): + """Crop the given PIL Image at a random location. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + padding (int or sequence, optional): Optional padding on each border + of the image. 
Default is 0, i.e no padding. If a sequence of length + 4 is provided, it is used to pad left, top, right, bottom borders + respectively. + """ + + def __init__(self, size, padding=0): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + self.padding = padding + + @staticmethod + def get_params(img, output_size): + """Get parameters for ``crop`` for a random crop. + + Args: + img (PIL Image): Image to be cropped. + output_size (tuple): Expected output size of the crop. + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. + """ + w, h = img.size + th, tw = output_size + if w == tw and h == th: + return 0, 0, h, w + + i = (h - th) // 2 + j = (w - tw) // 2 + return i, j, th, tw + + def __call__(self, inputs): + """ + Args: + img (PIL Image): Image to be cropped. + + Returns: + PIL Image: Cropped image. + """ + if self.padding > 0: + inputs = CustomFunc(inputs, F.pad, self.padding) + + i, j, h, w = self.get_params(inputs[0], self.size) + return CustomFunc(inputs, F.crop, i, j, h, w) + + +class RandomHorizontalFlip(object): + """Horizontally flip the given PIL Image randomly with a probability of 0.5.""" + + def __call__(self, inputs): + """ + Args: + img (PIL Image): Image to be flipped. + + Returns: + PIL Image: Randomly flipped image. + """ + + if random.random() < 0.5: + return CustomFunc(inputs, F.hflip) + return inputs + + +class RGB2Lab(object): + def __call__(self, inputs): + """ + Args: + img (PIL Image): Image to be flipped. + + Returns: + PIL Image: Randomly flipped image. + """ + + def __call__(self, inputs): + image_lab = color.rgb2lab(inputs[0]) + warp_ba_lab = color.rgb2lab(inputs[2]) + warp_aba_lab = color.rgb2lab(inputs[3]) + im_gbl_lab = color.rgb2lab(inputs[4]) + + inputs[0] = image_lab[:, :, :1] # l channel + inputs[1] = image_lab[:, :, 1:] # ab channel + inputs[2] = warp_ba_lab # lab channel + inputs[3] = warp_aba_lab # lab channel + inputs[4] = im_gbl_lab[:, :, 1:] # ab channel + + return inputs diff --git a/src/inference.py b/src/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..2306f659e4ab5684058c8ccfe7d560de97e84815 --- /dev/null +++ b/src/inference.py @@ -0,0 +1,174 @@ +from src.models.CNN.ColorVidNet import ColorVidNet +from src.models.vit.embed import SwinModel +from src.models.CNN.NonlocalNet import WarpNet +from src.models.CNN.FrameColor import frame_colorization +import torch +from src.models.vit.utils import load_params +import os +import cv2 +from PIL import Image +from PIL import ImageEnhance as IE +import torchvision.transforms as T +from src.utils import ( + RGB2Lab, + ToTensor, + Normalize, + uncenter_l, + tensor_lab2rgb +) +import numpy as np +from tqdm import tqdm + +class SwinTExCo: + def __init__(self, weights_path, swin_backbone='swinv2-cr-t-224', device=None): + if device == None: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + self.device = device + + self.embed_net = SwinModel(pretrained_model=swin_backbone, device=self.device).to(self.device) + self.nonlocal_net = WarpNet(feature_channel=128).to(self.device) + self.colornet = ColorVidNet(7).to(self.device) + + self.embed_net.eval() + self.nonlocal_net.eval() + self.colornet.eval() + + self.__load_models(self.embed_net, os.path.join(weights_path, "embed_net.pth")) + self.__load_models(self.nonlocal_net, os.path.join(weights_path, "nonlocal_net.pth")) + self.__load_models(self.colornet, os.path.join(weights_path, "colornet.pth")) + + self.processor 
= T.Compose([ + T.Resize((224,224)), + RGB2Lab(), + ToTensor(), + Normalize() + ]) + + pass + + def __load_models(self, model, weight_path): + params = load_params(weight_path, self.device) + model.load_state_dict(params, strict=True) + + def __preprocess_reference(self, img): + color_enhancer = IE.Color(img) + img = color_enhancer.enhance(1.5) + return img + + def __upscale_image(self, large_IA_l, I_current_ab_predict): + H, W = large_IA_l.shape[2:] + large_current_ab_predict = torch.nn.functional.interpolate(I_current_ab_predict, + size=(H,W), + mode="bilinear", + align_corners=False) + large_IA_l = torch.cat((large_IA_l, large_current_ab_predict.cpu()), dim=1) + large_current_rgb_predict = tensor_lab2rgb(large_IA_l) + return large_current_rgb_predict + + def __proccess_sample(self, curr_frame, I_last_lab_predict, I_reference_lab, features_B): + large_IA_lab = ToTensor()(RGB2Lab()(curr_frame)).unsqueeze(0) + large_IA_l = large_IA_lab[:, 0:1, :, :] + + IA_lab = self.processor(curr_frame) + IA_lab = IA_lab.unsqueeze(0).to(self.device) + IA_l = IA_lab[:, 0:1, :, :] + if I_last_lab_predict is None: + I_last_lab_predict = torch.zeros_like(IA_lab).to(self.device) + + + with torch.no_grad(): + I_current_ab_predict, _ = frame_colorization( + IA_l, + I_reference_lab, + I_last_lab_predict, + features_B, + self.embed_net, + self.nonlocal_net, + self.colornet, + luminance_noise=0, + temperature=1e-10, + joint_training=False + ) + I_last_lab_predict = torch.cat((IA_l, I_current_ab_predict), dim=1) + + IA_predict_rgb = self.__upscale_image(large_IA_l, I_current_ab_predict) + IA_predict_rgb = (IA_predict_rgb.squeeze(0).cpu().numpy() * 255.) + IA_predict_rgb = np.clip(IA_predict_rgb, 0, 255).astype(np.uint8) + + return I_last_lab_predict, IA_predict_rgb + + def predict_video(self, video, ref_image): + ref_image = self.__preprocess_reference(ref_image) + + I_last_lab_predict = None + + IB_lab = self.processor(ref_image) + IB_lab = IB_lab.unsqueeze(0).to(self.device) + + with torch.no_grad(): + I_reference_lab = IB_lab + I_reference_l = I_reference_lab[:, 0:1, :, :] + I_reference_ab = I_reference_lab[:, 1:3, :, :] + I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(self.device) + features_B = self.embed_net(I_reference_rgb) + + #PBAR = tqdm(total=int(video.get(cv2.CAP_PROP_FRAME_COUNT)), desc="Colorizing video", unit="frame") + while video.isOpened(): + #PBAR.update(1) + ret, curr_frame = video.read() + + if not ret: + break + + curr_frame = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2RGB) + curr_frame = Image.fromarray(curr_frame) + + I_last_lab_predict, IA_predict_rgb = self.__proccess_sample(curr_frame, I_last_lab_predict, I_reference_lab, features_B) + + IA_predict_rgb = IA_predict_rgb.transpose(1,2,0) + + yield IA_predict_rgb + + #PBAR.close() + video.release() + + def predict_image(self, image, ref_image): + ref_image = self.__preprocess_reference(ref_image) + + I_last_lab_predict = None + + IB_lab = self.processor(ref_image) + IB_lab = IB_lab.unsqueeze(0).to(self.device) + + with torch.no_grad(): + I_reference_lab = IB_lab + I_reference_l = I_reference_lab[:, 0:1, :, :] + I_reference_ab = I_reference_lab[:, 1:3, :, :] + I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(self.device) + features_B = self.embed_net(I_reference_rgb) + + curr_frame = image + I_last_lab_predict, IA_predict_rgb = self.__proccess_sample(curr_frame, I_last_lab_predict, I_reference_lab, features_B) + + IA_predict_rgb = 
IA_predict_rgb.transpose(1,2,0) + + return IA_predict_rgb + +if __name__ == "__main__": + model = SwinTExCo('checkpoints/epoch_20/') + + # Initialize video reader and writer + video = cv2.VideoCapture('sample_input/video_2.mp4') + fps = video.get(cv2.CAP_PROP_FPS) + width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) + video_writer = cv2.VideoWriter('sample_output/video_2_ref_2.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)) + + # Initialize reference image + ref_image = Image.open('sample_input/refs_2/ref2.jpg').convert('RGB') + + for colorized_frame in model.predict_video(video, ref_image): + video_writer.write(colorized_frame) + + video_writer.release() \ No newline at end of file diff --git a/src/losses.py b/src/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..dd78f9226bdee39354fa8fb31a05e4aefeb9e55d --- /dev/null +++ b/src/losses.py @@ -0,0 +1,277 @@ +import torch +import torch.nn as nn +from src.utils import feature_normalize + + +### START### CONTEXTUAL LOSS #### +class ContextualLoss(nn.Module): + """ + input is Al, Bl, channel = 1, range ~ [0, 255] + """ + + def __init__(self): + super(ContextualLoss, self).__init__() + return None + + def forward(self, X_features, Y_features, h=0.1, feature_centering=True): + """ + X_features&Y_features are are feature vectors or feature 2d array + h: bandwidth + return the per-sample loss + """ + batch_size = X_features.shape[0] + feature_depth = X_features.shape[1] + + # to normalized feature vectors + if feature_centering: + X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze( + dim=-1 + ) + Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze( + dim=-1 + ) + X_features = feature_normalize(X_features).view( + batch_size, feature_depth, -1 + ) # batch_size * feature_depth * feature_size^2 + Y_features = feature_normalize(Y_features).view( + batch_size, feature_depth, -1 + ) # batch_size * feature_depth * feature_size^2 + + # conine distance = 1 - similarity + X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth + d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2 + + # normalized distance: dij_bar + d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5) # batch_size * feature_size^2 * feature_size^2 + + # pairwise affinity + w = torch.exp((1 - d_norm) / h) + A_ij = w / torch.sum(w, dim=-1, keepdim=True) + + # contextual loss per sample + CX = torch.mean(torch.max(A_ij, dim=1)[0], dim=-1) + return -torch.log(CX) + + +class ContextualLoss_forward(nn.Module): + """ + input is Al, Bl, channel = 1, range ~ [0, 255] + """ + + def __init__(self): + super(ContextualLoss_forward, self).__init__() + return None + + def forward(self, X_features, Y_features, h=0.1, feature_centering=True): + """ + X_features&Y_features are are feature vectors or feature 2d array + h: bandwidth + return the per-sample loss + """ + batch_size = X_features.shape[0] + feature_depth = X_features.shape[1] + + # to normalized feature vectors + if feature_centering: + X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze( + dim=-1 + ) + Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze( + dim=-1 + ) + X_features = feature_normalize(X_features).view( + 
batch_size, feature_depth, -1 + ) # batch_size * feature_depth * feature_size^2 + Y_features = feature_normalize(Y_features).view( + batch_size, feature_depth, -1 + ) # batch_size * feature_depth * feature_size^2 + + # conine distance = 1 - similarity + X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth + d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2 + + # normalized distance: dij_bar + d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5) # batch_size * feature_size^2 * feature_size^2 + + # pairwise affinity + w = torch.exp((1 - d_norm) / h) + A_ij = w / torch.sum(w, dim=-1, keepdim=True) + + # contextual loss per sample + CX = torch.mean(torch.max(A_ij, dim=-1)[0], dim=1) + return -torch.log(CX) + + +### END### CONTEXTUAL LOSS #### + + +########################## + + +def mse_loss_fn(input, target=0): + return torch.mean((input - target) ** 2) + + +### START### PERCEPTUAL LOSS ### +def Perceptual_loss(domain_invariant, weight_perceptual): + instancenorm = nn.InstanceNorm2d(512, affine=False) + + def __call__(A_relu5_1, predict_relu5_1): + if domain_invariant: + feat_loss = ( + mse_loss_fn(instancenorm(predict_relu5_1), instancenorm(A_relu5_1.detach())) * weight_perceptual * 1e5 * 0.2 + ) + else: + feat_loss = mse_loss_fn(predict_relu5_1, A_relu5_1.detach()) * weight_perceptual + return feat_loss + + return __call__ + + +### END### PERCEPTUAL LOSS ### + + +def l1_loss_fn(input, target=0): + return torch.mean(torch.abs(input - target)) + + +### END################# + + +### START### ADVERSIAL LOSS ### +def generator_loss_fn(real_data_lab, fake_data_lab, discriminator, weight_gan, device): + if weight_gan > 0: + y_pred_fake, _ = discriminator(fake_data_lab) + y_pred_real, _ = discriminator(real_data_lab) + + y = torch.ones_like(y_pred_real) + generator_loss = ( + ( + torch.mean((y_pred_real - torch.mean(y_pred_fake) + y) ** 2) + + torch.mean((y_pred_fake - torch.mean(y_pred_real) - y) ** 2) + ) + / 2 + * weight_gan + ) + return generator_loss + + return torch.Tensor([0]).to(device) + + +def discriminator_loss_fn(real_data_lab, fake_data_lab, discriminator): + y_pred_fake, _ = discriminator(fake_data_lab.detach()) + y_pred_real, _ = discriminator(real_data_lab.detach()) + + y = torch.ones_like(y_pred_real) + discriminator_loss = ( + torch.mean((y_pred_real - torch.mean(y_pred_fake) - y) ** 2) + + torch.mean((y_pred_fake - torch.mean(y_pred_real) + y) ** 2) + ) / 2 + return discriminator_loss + + +### END### ADVERSIAL LOSS ##### + + +def consistent_loss_fn( + I_current_lab_predict, + I_last_ab_predict, + I_current_nonlocal_lab_predict, + I_last_nonlocal_lab_predict, + flow_forward, + mask, + warping_layer, + weight_consistent=0.02, + weight_nonlocal_consistent=0.0, + device="cuda", +): + def weighted_mse_loss(input, target, weights): + out = (input - target) ** 2 + out = out * weights.expand_as(out) + return out.mean() + + def consistent(): + I_current_lab_predict_warp = warping_layer(I_current_lab_predict, flow_forward) + I_current_ab_predict_warp = I_current_lab_predict_warp[:, 1:3, :, :] + consistent_loss = weighted_mse_loss(I_current_ab_predict_warp, I_last_ab_predict, mask) * weight_consistent + return consistent_loss + + def nonlocal_consistent(): + I_current_nonlocal_lab_predict_warp = warping_layer(I_current_nonlocal_lab_predict, flow_forward) + nonlocal_consistent_loss = ( + weighted_mse_loss( + I_current_nonlocal_lab_predict_warp[:, 1:3, :, :], + I_last_nonlocal_lab_predict[:, 
1:3, :, :], + mask, + ) + * weight_nonlocal_consistent + ) + + return nonlocal_consistent_loss + + consistent_loss = consistent() if weight_consistent else torch.Tensor([0]).to(device) + nonlocal_consistent_loss = nonlocal_consistent() if weight_nonlocal_consistent else torch.Tensor([0]).to(device) + + return consistent_loss + nonlocal_consistent_loss + + +### END### CONSISTENCY LOSS ##### + + +### START### SMOOTHNESS LOSS ### +def smoothness_loss_fn( + I_current_l, + I_current_lab, + I_current_ab_predict, + A_relu2_1, + weighted_layer_color, + nonlocal_weighted_layer, + weight_smoothness=5.0, + weight_nonlocal_smoothness=0.0, + device="cuda", +): + def smoothness(scale_factor=1.0): + I_current_lab_predict = torch.cat((I_current_l, I_current_ab_predict), dim=1) + IA_ab_weighed = weighted_layer_color( + I_current_lab, + I_current_lab_predict, + patch_size=3, + alpha=10, + scale_factor=scale_factor, + ) + smoothness_loss = ( + mse_loss_fn( + nn.functional.interpolate(I_current_ab_predict, scale_factor=scale_factor), + IA_ab_weighed, + ) + * weight_smoothness + ) + + return smoothness_loss + + def nonlocal_smoothness(scale_factor=0.25, alpha_nonlocal_smoothness=0.5): + nonlocal_smooth_feature = feature_normalize(A_relu2_1) + I_current_lab_predict = torch.cat((I_current_l, I_current_ab_predict), dim=1) + I_current_ab_weighted_nonlocal = nonlocal_weighted_layer( + I_current_lab_predict, + nonlocal_smooth_feature.detach(), + patch_size=3, + alpha=alpha_nonlocal_smoothness, + scale_factor=scale_factor, + ) + nonlocal_smoothness_loss = ( + mse_loss_fn( + nn.functional.interpolate(I_current_ab_predict, scale_factor=scale_factor), + I_current_ab_weighted_nonlocal, + ) + * weight_nonlocal_smoothness + ) + return nonlocal_smoothness_loss + + smoothness_loss = smoothness() if weight_smoothness else torch.Tensor([0]).to(device) + nonlocal_smoothness_loss = nonlocal_smoothness() if weight_nonlocal_smoothness else torch.Tensor([0]).to(device) + + return smoothness_loss + nonlocal_smoothness_loss + + +### END### SMOOTHNESS LOSS ##### diff --git a/src/metrics.py b/src/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..4cfff45b2dda234756c3d809263be3143a47e952 --- /dev/null +++ b/src/metrics.py @@ -0,0 +1,225 @@ +from skimage.metrics import structural_similarity, peak_signal_noise_ratio +import numpy as np +import lpips +import torch +from pytorch_fid.fid_score import calculate_frechet_distance +from pytorch_fid.inception import InceptionV3 +import torch.nn as nn +import cv2 +from scipy import stats +import os + +def calc_ssim(pred_image, gt_image): + ''' + Structural Similarity Index (SSIM) is a perceptual metric that quantifies the image quality degradation that is + caused by processing such as data compression or by losses in data transmission. + + # Arguments + img1: PIL.Image + img2: PIL.Image + # Returns + ssim: float (-1.0, 1.0) + ''' + pred_image = np.array(pred_image.convert('RGB')).astype(np.float32) + gt_image = np.array(gt_image.convert('RGB')).astype(np.float32) + ssim = structural_similarity(pred_image, gt_image, channel_axis=2, data_range=255.) + return ssim + +def calc_psnr(pred_image, gt_image): + ''' + Peak Signal-to-Noise Ratio (PSNR) is an expression for the ratio between the maximum possible value (power) of a signal + and the power of distorting noise that affects the quality of its representation. 
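+ With data_range=255., this equals 10 * log10(255**2 / MSE).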
+ + # Arguments + img1: PIL.Image + img2: PIL.Image + # Returns + psnr: float + ''' + pred_image = np.array(pred_image.convert('RGB')).astype(np.float32) + gt_image = np.array(gt_image.convert('RGB')).astype(np.float32) + + psnr = peak_signal_noise_ratio(gt_image, pred_image, data_range=255.) + return psnr + +class LPIPS_utils: + def __init__(self, device = 'cuda'): + self.loss_fn = lpips.LPIPS(net='vgg', spatial=True) # Can set net = 'squeeze' or 'vgg'or 'alex' + self.loss_fn = self.loss_fn.to(device) + self.device = device + + def compare_lpips(self,img_fake, img_real, data_range=255.): # input: torch 1 c h w / h w c + img_fake = torch.from_numpy(np.array(img_fake).astype(np.float32)/data_range) + img_real = torch.from_numpy(np.array(img_real).astype(np.float32)/data_range) + if img_fake.ndim==3: + img_fake = img_fake.permute(2,0,1).unsqueeze(0) + img_real = img_real.permute(2,0,1).unsqueeze(0) + img_fake = img_fake.to(self.device) + img_real = img_real.to(self.device) + + dist = self.loss_fn.forward(img_fake,img_real) + return dist.mean().item() + +class FID_utils(nn.Module): + """Class for computing the Fréchet Inception Distance (FID) metric score. + It is implemented as a class in order to hold the inception model instance + in its state. + Parameters + ---------- + resize_input : bool (optional) + Whether or not to resize the input images to the image size (299, 299) + on which the inception model was trained. Since the model is fully + convolutional, the score also works without resizing. In literature + and when working with GANs people tend to set this value to True, + however, for internal evaluation this is not necessary. + device : str or torch.device + The device on which to run the inception model. + """ + + def __init__(self, resize_input=True, device="cuda"): + super(FID_utils, self).__init__() + self.device = device + if self.device is None: + self.device = "cuda" if torch.cuda.is_available() else "cpu" + #self.model = InceptionV3(resize_input=resize_input).to(device) + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048] + self.model = InceptionV3([block_idx]).to(device) + self.model = self.model.eval() + + def get_activations(self,batch): # 1 c h w + with torch.no_grad(): + pred = self.model(batch)[0] + # If model output is not scalar, apply global spatial average pooling. + # This happens if you choose a dimensionality not equal 2048. + if pred.size(2) != 1 or pred.size(3) != 1: + #pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) + print("error in get activations!") + #pred = pred.squeeze(3).squeeze(2).cpu().numpy() + return pred + + + def _get_mu_sigma(self, batch,data_range): + """Compute the inception mu and sigma for a batch of images. + Parameters + ---------- + images : np.ndarray + A batch of images with shape (n_images,3, width, height). + Returns + ------- + mu : np.ndarray + The array of mean activations with shape (2048,). + sigma : np.ndarray + The covariance matrix of activations with shape (2048, 2048). 
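+ (Activations are taken from InceptionV3's 2048-dim final pooling block, + selected via BLOCK_INDEX_BY_DIM[2048] in __init__.)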
+ """ + # forward pass + if batch.ndim ==3 and batch.shape[2]==3: + batch=batch.permute(2,0,1).unsqueeze(0) + batch /= data_range + #batch = torch.tensor(batch)#.unsqueeze(1).repeat((1, 3, 1, 1)) + batch = batch.to(self.device, torch.float32) + #(activations,) = self.model(batch) + activations = self.get_activations(batch) + activations = activations.detach().cpu().numpy().squeeze(3).squeeze(2) + + # compute statistics + mu = np.mean(activations,axis=0) + sigma = np.cov(activations, rowvar=False) + + return mu, sigma + + def score(self, images_1, images_2, data_range=255.): + """Compute the FID score. + The input batches should have the shape (n_images,3, width, height). or (h,w,3) + Parameters + ---------- + images_1 : np.ndarray + First batch of images. + images_2 : np.ndarray + Section batch of images. + Returns + ------- + score : float + The FID score. + """ + images_1 = torch.from_numpy(np.array(images_1).astype(np.float32)) + images_2 = torch.from_numpy(np.array(images_2).astype(np.float32)) + images_1 = images_1.to(self.device) + images_2 = images_2.to(self.device) + + mu_1, sigma_1 = self._get_mu_sigma(images_1,data_range) + mu_2, sigma_2 = self._get_mu_sigma(images_2,data_range) + score = calculate_frechet_distance(mu_1, sigma_1, mu_2, sigma_2) + + return score + +def JS_divergence(p, q): + M = (p + q) / 2 + return 0.5 * stats.entropy(p, M) + 0.5 * stats.entropy(q, M) + + +def compute_JS_bgr(input_dir, dilation=1): + input_img_list = os.listdir(input_dir) + input_img_list.sort() + # print(input_img_list) + + hist_b_list = [] # [img1_histb, img2_histb, ...] + hist_g_list = [] + hist_r_list = [] + + for img_name in input_img_list: + # print(os.path.join(input_dir, img_name)) + img_in = cv2.imread(os.path.join(input_dir, img_name)) + H, W, C = img_in.shape + + hist_b = cv2.calcHist([img_in], [0], None, [256], [0,256]) # B + hist_g = cv2.calcHist([img_in], [1], None, [256], [0,256]) # G + hist_r = cv2.calcHist([img_in], [2], None, [256], [0,256]) # R + + hist_b = hist_b / (H * W) + hist_g = hist_g / (H * W) + hist_r = hist_r / (H * W) + + hist_b_list.append(hist_b) + hist_g_list.append(hist_g) + hist_r_list.append(hist_r) + + JS_b_list = [] + JS_g_list = [] + JS_r_list = [] + + for i in range(len(hist_b_list)): + if i + dilation > len(hist_b_list) - 1: + break + hist_b_img1 = hist_b_list[i] + hist_b_img2 = hist_b_list[i + dilation] + JS_b = JS_divergence(hist_b_img1, hist_b_img2) + JS_b_list.append(JS_b) + + hist_g_img1 = hist_g_list[i] + hist_g_img2 = hist_g_list[i+dilation] + JS_g = JS_divergence(hist_g_img1, hist_g_img2) + JS_g_list.append(JS_g) + + hist_r_img1 = hist_r_list[i] + hist_r_img2 = hist_r_list[i+dilation] + JS_r = JS_divergence(hist_r_img1, hist_r_img2) + JS_r_list.append(JS_r) + + return JS_b_list, JS_g_list, JS_r_list + + +def calc_cdc(vid_folder, dilation=[1, 2, 4], weight=[1/3, 1/3, 1/3]): + mean_b, mean_g, mean_r = 0, 0, 0 + for d, w in zip(dilation, weight): + JS_b_list_one, JS_g_list_one, JS_r_list_one = compute_JS_bgr(vid_folder, d) + mean_b += w * np.mean(JS_b_list_one) + mean_g += w * np.mean(JS_g_list_one) + mean_r += w * np.mean(JS_r_list_one) + + cdc = np.mean([mean_b, mean_g, mean_r]) + return cdc + + + + + \ No newline at end of file diff --git a/src/models/CNN/ColorVidNet.py b/src/models/CNN/ColorVidNet.py new file mode 100644 index 0000000000000000000000000000000000000000..0394f06a27582f950cd8df598eef9be715e26242 --- /dev/null +++ b/src/models/CNN/ColorVidNet.py @@ -0,0 +1,141 @@ +import torch +import torch.nn as nn +import torch.nn.parallel + +class 
ColorVidNet(nn.Module): + def __init__(self, ic): + super(ColorVidNet, self).__init__() + self.conv1_1 = nn.Sequential(nn.Conv2d(ic, 32, 3, 1, 1), nn.ReLU(), nn.Conv2d(32, 64, 3, 1, 1)) + self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 1) + self.conv1_2norm = nn.BatchNorm2d(64, affine=False) + self.conv1_2norm_ss = nn.Conv2d(64, 64, 1, 2, bias=False, groups=64) + self.conv2_1 = nn.Conv2d(64, 128, 3, 1, 1) + self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 1) + self.conv2_2norm = nn.BatchNorm2d(128, affine=False) + self.conv2_2norm_ss = nn.Conv2d(128, 128, 1, 2, bias=False, groups=128) + self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 1) + self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 1) + self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 1) + self.conv3_3norm = nn.BatchNorm2d(256, affine=False) + self.conv3_3norm_ss = nn.Conv2d(256, 256, 1, 2, bias=False, groups=256) + self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 1) + self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 1) + self.conv4_3 = nn.Conv2d(512, 512, 3, 1, 1) + self.conv4_3norm = nn.BatchNorm2d(512, affine=False) + self.conv5_1 = nn.Conv2d(512, 512, 3, 1, 2, 2) + self.conv5_2 = nn.Conv2d(512, 512, 3, 1, 2, 2) + self.conv5_3 = nn.Conv2d(512, 512, 3, 1, 2, 2) + self.conv5_3norm = nn.BatchNorm2d(512, affine=False) + self.conv6_1 = nn.Conv2d(512, 512, 3, 1, 2, 2) + self.conv6_2 = nn.Conv2d(512, 512, 3, 1, 2, 2) + self.conv6_3 = nn.Conv2d(512, 512, 3, 1, 2, 2) + self.conv6_3norm = nn.BatchNorm2d(512, affine=False) + self.conv7_1 = nn.Conv2d(512, 512, 3, 1, 1) + self.conv7_2 = nn.Conv2d(512, 512, 3, 1, 1) + self.conv7_3 = nn.Conv2d(512, 512, 3, 1, 1) + self.conv7_3norm = nn.BatchNorm2d(512, affine=False) + self.conv8_1 = nn.ConvTranspose2d(512, 256, 4, 2, 1) + self.conv3_3_short = nn.Conv2d(256, 256, 3, 1, 1) + self.conv8_2 = nn.Conv2d(256, 256, 3, 1, 1) + self.conv8_3 = nn.Conv2d(256, 256, 3, 1, 1) + self.conv8_3norm = nn.BatchNorm2d(256, affine=False) + self.conv9_1 = nn.ConvTranspose2d(256, 128, 4, 2, 1) + self.conv2_2_short = nn.Conv2d(128, 128, 3, 1, 1) + self.conv9_2 = nn.Conv2d(128, 128, 3, 1, 1) + self.conv9_2norm = nn.BatchNorm2d(128, affine=False) + self.conv10_1 = nn.ConvTranspose2d(128, 128, 4, 2, 1) + self.conv1_2_short = nn.Conv2d(64, 128, 3, 1, 1) + self.conv10_2 = nn.Conv2d(128, 128, 3, 1, 1) + self.conv10_ab = nn.Conv2d(128, 2, 1, 1) + + # add self.relux_x + self.relu1_1 = nn.PReLU() + self.relu1_2 = nn.PReLU() + self.relu2_1 = nn.PReLU() + self.relu2_2 = nn.PReLU() + self.relu3_1 = nn.PReLU() + self.relu3_2 = nn.PReLU() + self.relu3_3 = nn.PReLU() + self.relu4_1 = nn.PReLU() + self.relu4_2 = nn.PReLU() + self.relu4_3 = nn.PReLU() + self.relu5_1 = nn.PReLU() + self.relu5_2 = nn.PReLU() + self.relu5_3 = nn.PReLU() + self.relu6_1 = nn.PReLU() + self.relu6_2 = nn.PReLU() + self.relu6_3 = nn.PReLU() + self.relu7_1 = nn.PReLU() + self.relu7_2 = nn.PReLU() + self.relu7_3 = nn.PReLU() + self.relu8_1_comb = nn.PReLU() + self.relu8_2 = nn.PReLU() + self.relu8_3 = nn.PReLU() + self.relu9_1_comb = nn.PReLU() + self.relu9_2 = nn.PReLU() + self.relu10_1_comb = nn.PReLU() + self.relu10_2 = nn.LeakyReLU(0.2, True) + + self.conv8_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(512, 256, 3, 1, 1)) + self.conv9_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(256, 128, 3, 1, 1)) + self.conv10_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(128, 128, 3, 1, 1)) + + self.conv1_2norm = nn.InstanceNorm2d(64) + self.conv2_2norm = nn.InstanceNorm2d(128) + self.conv3_3norm = nn.InstanceNorm2d(256) + self.conv4_3norm = 
nn.InstanceNorm2d(512) + self.conv5_3norm = nn.InstanceNorm2d(512) + self.conv6_3norm = nn.InstanceNorm2d(512) + self.conv7_3norm = nn.InstanceNorm2d(512) + self.conv8_3norm = nn.InstanceNorm2d(256) + self.conv9_2norm = nn.InstanceNorm2d(128) + + def forward(self, x): + """x: gray image (1 channel), ab(2 channel), ab_err, ba_err""" + conv1_1 = self.relu1_1(self.conv1_1(x)) + conv1_2 = self.relu1_2(self.conv1_2(conv1_1)) + conv1_2norm = self.conv1_2norm(conv1_2) + conv1_2norm_ss = self.conv1_2norm_ss(conv1_2norm) + conv2_1 = self.relu2_1(self.conv2_1(conv1_2norm_ss)) + conv2_2 = self.relu2_2(self.conv2_2(conv2_1)) + conv2_2norm = self.conv2_2norm(conv2_2) + conv2_2norm_ss = self.conv2_2norm_ss(conv2_2norm) + conv3_1 = self.relu3_1(self.conv3_1(conv2_2norm_ss)) + conv3_2 = self.relu3_2(self.conv3_2(conv3_1)) + conv3_3 = self.relu3_3(self.conv3_3(conv3_2)) + conv3_3norm = self.conv3_3norm(conv3_3) + conv3_3norm_ss = self.conv3_3norm_ss(conv3_3norm) + conv4_1 = self.relu4_1(self.conv4_1(conv3_3norm_ss)) + conv4_2 = self.relu4_2(self.conv4_2(conv4_1)) + conv4_3 = self.relu4_3(self.conv4_3(conv4_2)) + conv4_3norm = self.conv4_3norm(conv4_3) + conv5_1 = self.relu5_1(self.conv5_1(conv4_3norm)) + conv5_2 = self.relu5_2(self.conv5_2(conv5_1)) + conv5_3 = self.relu5_3(self.conv5_3(conv5_2)) + conv5_3norm = self.conv5_3norm(conv5_3) + conv6_1 = self.relu6_1(self.conv6_1(conv5_3norm)) + conv6_2 = self.relu6_2(self.conv6_2(conv6_1)) + conv6_3 = self.relu6_3(self.conv6_3(conv6_2)) + conv6_3norm = self.conv6_3norm(conv6_3) + conv7_1 = self.relu7_1(self.conv7_1(conv6_3norm)) + conv7_2 = self.relu7_2(self.conv7_2(conv7_1)) + conv7_3 = self.relu7_3(self.conv7_3(conv7_2)) + conv7_3norm = self.conv7_3norm(conv7_3) + conv8_1 = self.conv8_1(conv7_3norm) + conv3_3_short = self.conv3_3_short(conv3_3norm) + conv8_1_comb = self.relu8_1_comb(conv8_1 + conv3_3_short) + conv8_2 = self.relu8_2(self.conv8_2(conv8_1_comb)) + conv8_3 = self.relu8_3(self.conv8_3(conv8_2)) + conv8_3norm = self.conv8_3norm(conv8_3) + conv9_1 = self.conv9_1(conv8_3norm) + conv2_2_short = self.conv2_2_short(conv2_2norm) + conv9_1_comb = self.relu9_1_comb(conv9_1 + conv2_2_short) + conv9_2 = self.relu9_2(self.conv9_2(conv9_1_comb)) + conv9_2norm = self.conv9_2norm(conv9_2) + conv10_1 = self.conv10_1(conv9_2norm) + conv1_2_short = self.conv1_2_short(conv1_2norm) + conv10_1_comb = self.relu10_1_comb(conv10_1 + conv1_2_short) + conv10_2 = self.relu10_2(self.conv10_2(conv10_1_comb)) + conv10_ab = self.conv10_ab(conv10_2) + + return torch.tanh(conv10_ab) * 128 diff --git a/src/models/CNN/FrameColor.py b/src/models/CNN/FrameColor.py new file mode 100644 index 0000000000000000000000000000000000000000..0c4f4e7d0f712aef415909ddac3786f84d7de665 --- /dev/null +++ b/src/models/CNN/FrameColor.py @@ -0,0 +1,76 @@ +import torch +from src.utils import * +from src.models.vit.vit import FeatureTransform + + +def warp_color( + IA_l, + IB_lab, + features_B, + embed_net, + nonlocal_net, + temperature=0.01, +): + IA_rgb_from_gray = gray2rgb_batch(IA_l) + + with torch.no_grad(): + A_feat0, A_feat1, A_feat2, A_feat3 = embed_net(IA_rgb_from_gray) + B_feat0, B_feat1, B_feat2, B_feat3 = features_B + + A_feat0 = feature_normalize(A_feat0) + A_feat1 = feature_normalize(A_feat1) + A_feat2 = feature_normalize(A_feat2) + A_feat3 = feature_normalize(A_feat3) + + B_feat0 = feature_normalize(B_feat0) + B_feat1 = feature_normalize(B_feat1) + B_feat2 = feature_normalize(B_feat2) + B_feat3 = feature_normalize(B_feat3) + + return nonlocal_net( + IB_lab, + A_feat0, + A_feat1, + 
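        # (Descriptive comment, not from the original code: A_feat0..A_feat3 and
        # B_feat0..B_feat3 are the four normalized multi-scale feature levels for
        # the grayscale frame A and the color reference B, produced by embed_net
        # above or passed in precomputed as features_B.)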
A_feat2, + A_feat3, + B_feat0, + B_feat1, + B_feat2, + B_feat3, + temperature=temperature, + ) + + +def frame_colorization( + IA_l, + IB_lab, + IA_last_lab, + features_B, + embed_net, + nonlocal_net, + colornet, + joint_training=True, + luminance_noise=0, + temperature=0.01, +): + if luminance_noise: + IA_l = IA_l + torch.randn_like(IA_l, requires_grad=False) * luminance_noise + + with torch.autograd.set_grad_enabled(joint_training): + nonlocal_BA_lab, similarity_map = warp_color( + IA_l, + IB_lab, + features_B, + embed_net, + nonlocal_net, + temperature=temperature, + ) + nonlocal_BA_ab = nonlocal_BA_lab[:, 1:3, :, :] + IA_ab_predict = colornet( + torch.cat( + (IA_l, nonlocal_BA_ab, similarity_map, IA_last_lab), + dim=1, + ) + ) + + return IA_ab_predict, nonlocal_BA_lab \ No newline at end of file diff --git a/src/models/CNN/GAN_models.py b/src/models/CNN/GAN_models.py new file mode 100644 index 0000000000000000000000000000000000000000..b82bdfd32f27126d03e20c8f4c9d9c1bdf1c4803 --- /dev/null +++ b/src/models/CNN/GAN_models.py @@ -0,0 +1,212 @@ +# DCGAN-like generator and discriminator +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn import Parameter + + +def l2normalize(v, eps=1e-12): + return v / (v.norm() + eps) + + +class SpectralNorm(nn.Module): + def __init__(self, module, name="weight", power_iterations=1): + super(SpectralNorm, self).__init__() + self.module = module + self.name = name + self.power_iterations = power_iterations + if not self._made_params(): + self._make_params() + + def _update_u_v(self): + u = getattr(self.module, self.name + "_u") + v = getattr(self.module, self.name + "_v") + w = getattr(self.module, self.name + "_bar") + + height = w.data.shape[0] + for _ in range(self.power_iterations): + v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data)) + u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data)) + + sigma = u.dot(w.view(height, -1).mv(v)) + setattr(self.module, self.name, w / sigma.expand_as(w)) + + def _made_params(self): + try: + u = getattr(self.module, self.name + "_u") + v = getattr(self.module, self.name + "_v") + w = getattr(self.module, self.name + "_bar") + return True + except AttributeError: + return False + + def _make_params(self): + w = getattr(self.module, self.name) + + height = w.data.shape[0] + width = w.view(height, -1).data.shape[1] + + u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) + v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) + u.data = l2normalize(u.data) + v.data = l2normalize(v.data) + w_bar = Parameter(w.data) + + del self.module._parameters[self.name] + + self.module.register_parameter(self.name + "_u", u) + self.module.register_parameter(self.name + "_v", v) + self.module.register_parameter(self.name + "_bar", w_bar) + + def forward(self, *args): + self._update_u_v() + return self.module.forward(*args) + + +class Generator(nn.Module): + def __init__(self, z_dim): + super(Generator, self).__init__() + self.z_dim = z_dim + + self.model = nn.Sequential( + nn.ConvTranspose2d(z_dim, 512, 4, stride=1), + nn.InstanceNorm2d(512), + nn.ReLU(), + nn.ConvTranspose2d(512, 256, 4, stride=2, padding=(1, 1)), + nn.InstanceNorm2d(256), + nn.ReLU(), + nn.ConvTranspose2d(256, 128, 4, stride=2, padding=(1, 1)), + nn.InstanceNorm2d(128), + nn.ReLU(), + nn.ConvTranspose2d(128, 64, 4, stride=2, padding=(1, 1)), + nn.InstanceNorm2d(64), + nn.ReLU(), + nn.ConvTranspose2d(64, channels, 3, stride=1, padding=(1, 1)), + nn.Tanh(), + ) + + def 
forward(self, z): + return self.model(z.view(-1, self.z_dim, 1, 1)) + + +channels = 3 +leak = 0.1 +w_g = 4 + + +class Discriminator(nn.Module): + def __init__(self): + super(Discriminator, self).__init__() + + self.conv1 = SpectralNorm(nn.Conv2d(channels, 64, 3, stride=1, padding=(1, 1))) + self.conv2 = SpectralNorm(nn.Conv2d(64, 64, 4, stride=2, padding=(1, 1))) + self.conv3 = SpectralNorm(nn.Conv2d(64, 128, 3, stride=1, padding=(1, 1))) + self.conv4 = SpectralNorm(nn.Conv2d(128, 128, 4, stride=2, padding=(1, 1))) + self.conv5 = SpectralNorm(nn.Conv2d(128, 256, 3, stride=1, padding=(1, 1))) + self.conv6 = SpectralNorm(nn.Conv2d(256, 256, 4, stride=2, padding=(1, 1))) + self.conv7 = SpectralNorm(nn.Conv2d(256, 256, 3, stride=1, padding=(1, 1))) + self.conv8 = SpectralNorm(nn.Conv2d(256, 512, 4, stride=2, padding=(1, 1))) + self.fc = SpectralNorm(nn.Linear(w_g * w_g * 512, 1)) + + def forward(self, x): + m = x + m = nn.LeakyReLU(leak)(self.conv1(m)) + m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(64)(self.conv2(m))) + m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(128)(self.conv3(m))) + m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(128)(self.conv4(m))) + m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(256)(self.conv5(m))) + m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(256)(self.conv6(m))) + m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(256)(self.conv7(m))) + m = nn.LeakyReLU(leak)(self.conv8(m)) + + return self.fc(m.view(-1, w_g * w_g * 512)) + + +class Self_Attention(nn.Module): + """Self attention Layer""" + + def __init__(self, in_dim): + super(Self_Attention, self).__init__() + self.chanel_in = in_dim + + self.query_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 1, kernel_size=1)) + self.key_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 1, kernel_size=1)) + self.value_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)) + self.gamma = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) # + + def forward(self, x): + """ + inputs : + x : input feature maps( B X C X W X H) + returns : + out : self attention value + input feature + attention: B X N X N (N is Width*Height) + """ + m_batchsize, C, width, height = x.size() + proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1) # B X CX(N) + proj_key = self.key_conv(x).view(m_batchsize, -1, width * height) # B X C x (*W*H) + energy = torch.bmm(proj_query, proj_key) # transpose check + attention = self.softmax(energy) # BX (N) X (N) + proj_value = self.value_conv(x).view(m_batchsize, -1, width * height) # B X C X N + + out = torch.bmm(proj_value, attention.permute(0, 2, 1)) + out = out.view(m_batchsize, C, width, height) + + out = self.gamma * out + x + return out + +class Discriminator_x64_224(nn.Module): + """ + Discriminative Network + """ + + def __init__(self, in_size=6, ndf=64): + super(Discriminator_x64_224, self).__init__() + self.in_size = in_size + self.ndf = ndf + + self.layer1 = nn.Sequential(SpectralNorm(nn.Conv2d(self.in_size, self.ndf, 4, 2, 1)), nn.LeakyReLU(0.2, inplace=True)) + + self.layer2 = nn.Sequential( + SpectralNorm(nn.Conv2d(self.ndf, self.ndf, 4, 2, 1)), + nn.InstanceNorm2d(self.ndf), + nn.LeakyReLU(0.2, inplace=True), + ) + self.attention = Self_Attention(self.ndf) + self.layer3 = nn.Sequential( + SpectralNorm(nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1)), + nn.InstanceNorm2d(self.ndf * 2), + nn.LeakyReLU(0.2, inplace=True), + ) + self.layer4 = nn.Sequential( + SpectralNorm(nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 
1)), + nn.InstanceNorm2d(self.ndf * 4), + nn.LeakyReLU(0.2, inplace=True), + ) + self.layer5 = nn.Sequential( + SpectralNorm(nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1)), + nn.InstanceNorm2d(self.ndf * 8), + nn.LeakyReLU(0.2, inplace=True), + ) + self.layer6 = nn.Sequential( + SpectralNorm(nn.Conv2d(self.ndf * 8, self.ndf * 16, 4, 2, 1)), + nn.InstanceNorm2d(self.ndf * 16), + nn.LeakyReLU(0.2, inplace=True), + ) + + self.last = SpectralNorm(nn.Conv2d(self.ndf * 16, 1, [3, 3], 1, 0)) + + def forward(self, input): + feature1 = self.layer1(input) + feature2 = self.layer2(feature1) + feature_attention = self.attention(feature2) + feature3 = self.layer3(feature_attention) + feature4 = self.layer4(feature3) + feature5 = self.layer5(feature4) + feature6 = self.layer6(feature5) + output = self.last(feature6) + output = F.avg_pool2d(output, output.size()[2:]).view(output.size()[0], -1) + + return output, feature4 diff --git a/src/models/CNN/NonlocalNet.py b/src/models/CNN/NonlocalNet.py new file mode 100644 index 0000000000000000000000000000000000000000..cce995da91cda1f65ee4bdf3e26cf3e2992f70b0 --- /dev/null +++ b/src/models/CNN/NonlocalNet.py @@ -0,0 +1,437 @@ +import sys +import torch +import torch.nn as nn +import torch.nn.functional as F +from src.utils import uncenter_l + + +def find_local_patch(x, patch_size): + """ + > We take a tensor `x` and return a tensor `x_unfold` that contains all the patches of size + `patch_size` in `x` + + Args: + x: the input tensor + patch_size: the size of the patch to be extracted. + """ + + N, C, H, W = x.shape + x_unfold = F.unfold(x, kernel_size=(patch_size, patch_size), padding=(patch_size // 2, patch_size // 2), stride=(1, 1)) + + return x_unfold.view(N, x_unfold.shape[1], H, W) + + +class WeightedAverage(nn.Module): + def __init__( + self, + ): + super(WeightedAverage, self).__init__() + + def forward(self, x_lab, patch_size=3, alpha=1, scale_factor=1): + """ + It takes a 3-channel image (L, A, B) and returns a 2-channel image (A, B) where each pixel is a + weighted average of the A and B values of the pixels in a 3x3 neighborhood around it + + Args: + x_lab: the input image in LAB color space + patch_size: the size of the patch to use for the local average. Defaults to 3 + alpha: the higher the alpha, the smoother the output. Defaults to 1 + scale_factor: the scale factor of the input image. 
Defaults to 1 + + Returns: + The output of the forward function is a tensor of size (batch_size, 2, height, width) + """ + # alpha=0: less smooth; alpha=inf: smoother + x_lab = F.interpolate(x_lab, scale_factor=scale_factor) + l = x_lab[:, 0:1, :, :] + a = x_lab[:, 1:2, :, :] + b = x_lab[:, 2:3, :, :] + local_l = find_local_patch(l, patch_size) + local_a = find_local_patch(a, patch_size) + local_b = find_local_patch(b, patch_size) + local_difference_l = (local_l - l) ** 2 + correlation = nn.functional.softmax(-1 * local_difference_l / alpha, dim=1) + + return torch.cat( + ( + torch.sum(correlation * local_a, dim=1, keepdim=True), + torch.sum(correlation * local_b, dim=1, keepdim=True), + ), + 1, + ) + + +class WeightedAverage_color(nn.Module): + """ + smooth the image according to the color distance in the LAB space + """ + + def __init__( + self, + ): + super(WeightedAverage_color, self).__init__() + + def forward(self, x_lab, x_lab_predict, patch_size=3, alpha=1, scale_factor=1): + """ + It takes the predicted a and b channels, and the original a and b channels, and finds the + weighted average of the predicted a and b channels based on the similarity of the original a and + b channels to the predicted a and b channels + + Args: + x_lab: the input image in LAB color space + x_lab_predict: the predicted LAB image + patch_size: the size of the patch to use for the local color correction. Defaults to 3 + alpha: controls the smoothness of the output. Defaults to 1 + scale_factor: the scale factor of the input image. Defaults to 1 + + Returns: + The return is the weighted average of the local a and b channels. + """ + """ alpha=0: less smooth; alpha=inf: smoother """ + x_lab = F.interpolate(x_lab, scale_factor=scale_factor) + l = uncenter_l(x_lab[:, 0:1, :, :]) + a = x_lab[:, 1:2, :, :] + b = x_lab[:, 2:3, :, :] + a_predict = x_lab_predict[:, 1:2, :, :] + b_predict = x_lab_predict[:, 2:3, :, :] + local_l = find_local_patch(l, patch_size) + local_a = find_local_patch(a, patch_size) + local_b = find_local_patch(b, patch_size) + local_a_predict = find_local_patch(a_predict, patch_size) + local_b_predict = find_local_patch(b_predict, patch_size) + + local_color_difference = (local_l - l) ** 2 + (local_a - a) ** 2 + (local_b - b) ** 2 + # so that sum of weights equal to 1 + correlation = nn.functional.softmax(-1 * local_color_difference / alpha, dim=1) + + return torch.cat( + ( + torch.sum(correlation * local_a_predict, dim=1, keepdim=True), + torch.sum(correlation * local_b_predict, dim=1, keepdim=True), + ), + 1, + ) + + +class NonlocalWeightedAverage(nn.Module): + def __init__( + self, + ): + super(NonlocalWeightedAverage, self).__init__() + + def forward(self, x_lab, feature, patch_size=3, alpha=0.1, scale_factor=1): + """ + It takes in a feature map and a label map, and returns a smoothed label map + + Args: + x_lab: the input image in LAB color space + feature: the feature map of the input image + patch_size: the size of the patch to be used for the correlation matrix. Defaults to 3 + alpha: the higher the alpha, the smoother the output. + scale_factor: the scale factor of the input image. Defaults to 1 + + Returns: + weighted_ab is the weighted ab channel of the image. 
+ """ + # alpha=0: less smooth; alpha=inf: smoother + # input feature is normalized feature + x_lab = F.interpolate(x_lab, scale_factor=scale_factor) + batch_size, channel, height, width = x_lab.shape + feature = F.interpolate(feature, size=(height, width)) + batch_size = x_lab.shape[0] + x_ab = x_lab[:, 1:3, :, :].view(batch_size, 2, -1) + x_ab = x_ab.permute(0, 2, 1) + + local_feature = find_local_patch(feature, patch_size) + local_feature = local_feature.view(batch_size, local_feature.shape[1], -1) + + correlation_matrix = torch.matmul(local_feature.permute(0, 2, 1), local_feature) + correlation_matrix = nn.functional.softmax(correlation_matrix / alpha, dim=-1) + + weighted_ab = torch.matmul(correlation_matrix, x_ab) + weighted_ab = weighted_ab.permute(0, 2, 1).contiguous() + weighted_ab = weighted_ab.view(batch_size, 2, height, width) + return weighted_ab + + +class CorrelationLayer(nn.Module): + def __init__(self, search_range): + super(CorrelationLayer, self).__init__() + self.search_range = search_range + + def forward(self, x1, x2, alpha=1, raw_output=False, metric="similarity"): + """ + It takes two tensors, x1 and x2, and returns a tensor of shape (batch_size, (search_range * 2 + + 1) ** 2, height, width) where each element is the dot product of the corresponding patch in x1 + and x2 + + Args: + x1: the first image + x2: the image to be warped + alpha: the temperature parameter for the softmax function. Defaults to 1 + raw_output: if True, return the raw output of the network, otherwise return the softmax + output. Defaults to False + metric: "similarity" or "subtraction". Defaults to similarity + + Returns: + The output of the forward function is a softmax of the correlation volume. + """ + shape = list(x1.size()) + shape[1] = (self.search_range * 2 + 1) ** 2 + cv = torch.zeros(shape).to(torch.device("cuda")) + + for i in range(-self.search_range, self.search_range + 1): + for j in range(-self.search_range, self.search_range + 1): + if i < 0: + slice_h, slice_h_r = slice(None, i), slice(-i, None) + elif i > 0: + slice_h, slice_h_r = slice(i, None), slice(None, -i) + else: + slice_h, slice_h_r = slice(None), slice(None) + + if j < 0: + slice_w, slice_w_r = slice(None, j), slice(-j, None) + elif j > 0: + slice_w, slice_w_r = slice(j, None), slice(None, -j) + else: + slice_w, slice_w_r = slice(None), slice(None) + + if metric == "similarity": + cv[:, (self.search_range * 2 + 1) * i + j, slice_h, slice_w] = ( + x1[:, :, slice_h, slice_w] * x2[:, :, slice_h_r, slice_w_r] + ).sum(1) + else: # patchwise subtraction + cv[:, (self.search_range * 2 + 1) * i + j, slice_h, slice_w] = -( + (x1[:, :, slice_h, slice_w] - x2[:, :, slice_h_r, slice_w_r]) ** 2 + ).sum(1) + + # TODO sigmoid? + if raw_output: + return cv + else: + return nn.functional.softmax(cv / alpha, dim=1) + + +class WTA_scale(torch.autograd.Function): + """ + We can implement our own custom autograd Functions by subclassing + torch.autograd.Function and implementing the forward and backward passes + which operate on Tensors. + """ + + @staticmethod + def forward(ctx, input, scale=1e-4): + """ + In the forward pass we receive a Tensor containing the input and return a + Tensor containing the output. You can cache arbitrary Tensors for use in the + backward pass using the save_for_backward method. 
+ """ + activation_max, index_max = torch.max(input, -1, keepdim=True) + input_scale = input * scale # default: 1e-4 + # input_scale = input * scale # default: 1e-4 + output_max_scale = torch.where(input == activation_max, input, input_scale) + + mask = (input == activation_max).type(torch.float) + ctx.save_for_backward(input, mask) + return output_max_scale + + @staticmethod + def backward(ctx, grad_output): + """ + In the backward pass we receive a Tensor containing the gradient of the loss + with respect to the output, and we need to compute the gradient of the loss + with respect to the input. + """ + input, mask = ctx.saved_tensors + mask_ones = torch.ones_like(mask) + mask_small_ones = torch.ones_like(mask) * 1e-4 + # mask_small_ones = torch.ones_like(mask) * 1e-4 + + grad_scale = torch.where(mask == 1, mask_ones, mask_small_ones) + grad_input = grad_output.clone() * grad_scale + return grad_input, None + + +class ResidualBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1): + super(ResidualBlock, self).__init__() + self.padding1 = nn.ReflectionPad2d(padding) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=0, stride=stride) + self.bn1 = nn.InstanceNorm2d(out_channels) + self.prelu = nn.PReLU() + self.padding2 = nn.ReflectionPad2d(padding) + self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=0, stride=stride) + self.bn2 = nn.InstanceNorm2d(out_channels) + + def forward(self, x): + residual = x + out = self.padding1(x) + out = self.conv1(out) + out = self.bn1(out) + out = self.prelu(out) + out = self.padding2(out) + out = self.conv2(out) + out = self.bn2(out) + out += residual + out = self.prelu(out) + return out + + +class WarpNet(nn.Module): + """input is Al, Bl, channel = 1, range~[0,255]""" + + def __init__(self, feature_channel=128): + super(WarpNet, self).__init__() + self.feature_channel = feature_channel + self.in_channels = self.feature_channel * 4 + self.inter_channels = 256 + # 44*44 + self.layer2_1 = nn.Sequential( + nn.ReflectionPad2d(1), + # nn.Conv2d(128, 128, kernel_size=3, padding=0, stride=1), + # nn.Conv2d(96, 128, kernel_size=3, padding=20, stride=1), + nn.Conv2d(96, 128, kernel_size=3, padding=0, stride=1), + nn.InstanceNorm2d(128), + nn.PReLU(), + nn.ReflectionPad2d(1), + nn.Conv2d(128, self.feature_channel, kernel_size=3, padding=0, stride=2), + nn.InstanceNorm2d(self.feature_channel), + nn.PReLU(), + nn.Dropout(0.2), + ) + self.layer3_1 = nn.Sequential( + nn.ReflectionPad2d(1), + # nn.Conv2d(256, 128, kernel_size=3, padding=0, stride=1), + # nn.Conv2d(192, 128, kernel_size=3, padding=10, stride=1), + nn.Conv2d(192, 128, kernel_size=3, padding=0, stride=1), + nn.InstanceNorm2d(128), + nn.PReLU(), + nn.ReflectionPad2d(1), + nn.Conv2d(128, self.feature_channel, kernel_size=3, padding=0, stride=1), + nn.InstanceNorm2d(self.feature_channel), + nn.PReLU(), + nn.Dropout(0.2), + ) + + # 22*22->44*44 + self.layer4_1 = nn.Sequential( + nn.ReflectionPad2d(1), + # nn.Conv2d(512, 256, kernel_size=3, padding=0, stride=1), + # nn.Conv2d(384, 256, kernel_size=3, padding=5, stride=1), + nn.Conv2d(384, 256, kernel_size=3, padding=0, stride=1), + nn.InstanceNorm2d(256), + nn.PReLU(), + nn.ReflectionPad2d(1), + nn.Conv2d(256, self.feature_channel, kernel_size=3, padding=0, stride=1), + nn.InstanceNorm2d(self.feature_channel), + nn.PReLU(), + nn.Upsample(scale_factor=2), + nn.Dropout(0.2), + ) + + # 11*11->44*44 + self.layer5_1 = nn.Sequential( + 
nn.ReflectionPad2d(1), + # nn.Conv2d(1024, 256, kernel_size=3, padding=0, stride=1), + # nn.Conv2d(768, 256, kernel_size=2, padding=2, stride=1), + nn.Conv2d(768, 256, kernel_size=3, padding=0, stride=1), + nn.InstanceNorm2d(256), + nn.PReLU(), + nn.Upsample(scale_factor=2), + nn.ReflectionPad2d(1), + nn.Conv2d(256, self.feature_channel, kernel_size=3, padding=0, stride=1), + nn.InstanceNorm2d(self.feature_channel), + nn.PReLU(), + nn.Upsample(scale_factor=2), + nn.Dropout(0.2), + ) + + self.layer = nn.Sequential( + ResidualBlock(self.feature_channel * 4, self.feature_channel * 4, kernel_size=3, padding=1, stride=1), + ResidualBlock(self.feature_channel * 4, self.feature_channel * 4, kernel_size=3, padding=1, stride=1), + ResidualBlock(self.feature_channel * 4, self.feature_channel * 4, kernel_size=3, padding=1, stride=1), + ) + + self.theta = nn.Conv2d( + in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 + ) + self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0) + + self.upsampling = nn.Upsample(scale_factor=4) + + def forward( + self, + B_lab_map, + A_relu2_1, + A_relu3_1, + A_relu4_1, + A_relu5_1, + B_relu2_1, + B_relu3_1, + B_relu4_1, + B_relu5_1, + temperature=0.001 * 5, + detach_flag=False, + WTA_scale_weight=1, + ): + batch_size = B_lab_map.shape[0] + channel = B_lab_map.shape[1] + image_height = B_lab_map.shape[2] + image_width = B_lab_map.shape[3] + feature_height = int(image_height / 4) + feature_width = int(image_width / 4) + + # scale feature size to 44*44 + A_feature2_1 = self.layer2_1(A_relu2_1) + B_feature2_1 = self.layer2_1(B_relu2_1) + A_feature3_1 = self.layer3_1(A_relu3_1) + B_feature3_1 = self.layer3_1(B_relu3_1) + A_feature4_1 = self.layer4_1(A_relu4_1) + B_feature4_1 = self.layer4_1(B_relu4_1) + A_feature5_1 = self.layer5_1(A_relu5_1) + B_feature5_1 = self.layer5_1(B_relu5_1) + + # concatenate features + if A_feature5_1.shape[2] != A_feature2_1.shape[2] or A_feature5_1.shape[3] != A_feature2_1.shape[3]: + A_feature5_1 = F.pad(A_feature5_1, (0, 0, 1, 1), "replicate") + B_feature5_1 = F.pad(B_feature5_1, (0, 0, 1, 1), "replicate") + + A_features = self.layer(torch.cat((A_feature2_1, A_feature3_1, A_feature4_1, A_feature5_1), 1)) + B_features = self.layer(torch.cat((B_feature2_1, B_feature3_1, B_feature4_1, B_feature5_1), 1)) + + # pairwise cosine similarity + theta = self.theta(A_features).view(batch_size, self.inter_channels, -1) # 2*256*(feature_height*feature_width) + theta = theta - theta.mean(dim=-1, keepdim=True) # center the feature + theta_norm = torch.norm(theta, 2, 1, keepdim=True) + sys.float_info.epsilon + theta = torch.div(theta, theta_norm) + theta_permute = theta.permute(0, 2, 1) # 2*(feature_height*feature_width)*256 + phi = self.phi(B_features).view(batch_size, self.inter_channels, -1) # 2*256*(feature_height*feature_width) + phi = phi - phi.mean(dim=-1, keepdim=True) # center the feature + phi_norm = torch.norm(phi, 2, 1, keepdim=True) + sys.float_info.epsilon + phi = torch.div(phi, phi_norm) + f = torch.matmul(theta_permute, phi) # 2*(feature_height*feature_width)*(feature_height*feature_width) + if detach_flag: + f = f.detach() + + f_similarity = f.unsqueeze_(dim=1) + similarity_map = torch.max(f_similarity, -1, keepdim=True)[0] + similarity_map = similarity_map.view(batch_size, 1, feature_height, feature_width) + + # f can be negative + f_WTA = f if WTA_scale_weight == 1 else WTA_scale.apply(f, WTA_scale_weight) + f_WTA = f_WTA / 
temperature + f_div_C = F.softmax(f_WTA.squeeze_(), dim=-1) # 2*1936*1936; + + # downsample the reference color + B_lab = F.avg_pool2d(B_lab_map, 4) + B_lab = B_lab.view(batch_size, channel, -1) + B_lab = B_lab.permute(0, 2, 1) # 2*1936*channel + + # multiply the corr map with color + y = torch.matmul(f_div_C, B_lab) # 2*1936*channel + y = y.permute(0, 2, 1).contiguous() + y = y.view(batch_size, channel, feature_height, feature_width) # 2*3*44*44 + y = self.upsampling(y) + similarity_map = self.upsampling(similarity_map) + + return y, similarity_map \ No newline at end of file diff --git a/src/models/CNN/__init__.py b/src/models/CNN/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/models/vit/__init__.py b/src/models/vit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/models/vit/blocks.py b/src/models/vit/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..fa1634453ba76163a1d6b72123765933825049be --- /dev/null +++ b/src/models/vit/blocks.py @@ -0,0 +1,80 @@ +import torch.nn as nn +from timm.models.layers import DropPath + + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout, out_dim=None): + super().__init__() + self.fc1 = nn.Linear(dim, hidden_dim) + self.act = nn.GELU() + if out_dim is None: + out_dim = dim + self.fc2 = nn.Linear(hidden_dim, out_dim) + self.drop = nn.Dropout(dropout) + + @property + def unwrapped(self): + return self + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, heads, dropout): + super().__init__() + self.heads = heads + head_dim = dim // heads + self.scale = head_dim**-0.5 + self.attn = None + + self.qkv = nn.Linear(dim, dim * 3) + self.attn_drop = nn.Dropout(dropout) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(dropout) + + @property + def unwrapped(self): + return self + + def forward(self, x, mask=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x, attn + + +class Block(nn.Module): + def __init__(self, dim, heads, mlp_dim, dropout, drop_path): + super().__init__() + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.attn = Attention(dim, heads, dropout) + self.mlp = FeedForward(dim, mlp_dim, dropout) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x, mask=None, return_attention=False): + y, attn = self.attn(self.norm1(x), mask) + if return_attention: + return attn + x = x + self.drop_path(y) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x diff --git a/src/models/vit/config.py b/src/models/vit/config.py new file mode 100644 index 0000000000000000000000000000000000000000..9af621c90b36c60d0e1bd0b399c5a01295c7343b --- /dev/null +++ b/src/models/vit/config.py @@ -0,0 +1,22 @@ +import yaml +from pathlib import 
Path + +import os + + +def load_config(): + return yaml.load( + open(Path(__file__).parent / "config.yml", "r"), Loader=yaml.FullLoader + ) + + +def check_os_environ(key, use): + if key not in os.environ: + raise ValueError( + f"{key} is not defined in the os variables, it is required for {use}." + ) + + +def dataset_dir(): + check_os_environ("DATASET", "data loading") + return os.environ["DATASET"] diff --git a/src/models/vit/config.yml b/src/models/vit/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..79485c328c22ff0fb50d60f0d4c360e0801cc0bc --- /dev/null +++ b/src/models/vit/config.yml @@ -0,0 +1,132 @@ +model: + # deit + deit_tiny_distilled_patch16_224: + image_size: 224 + patch_size: 16 + d_model: 192 + n_heads: 3 + n_layers: 12 + normalization: deit + distilled: true + deit_small_distilled_patch16_224: + image_size: 224 + patch_size: 16 + d_model: 384 + n_heads: 6 + n_layers: 12 + normalization: deit + distilled: true + deit_base_distilled_patch16_224: + image_size: 224 + patch_size: 16 + d_model: 768 + n_heads: 12 + n_layers: 12 + normalization: deit + distilled: true + deit_base_distilled_patch16_384: + image_size: 384 + patch_size: 16 + d_model: 768 + n_heads: 12 + n_layers: 12 + normalization: deit + distilled: true + # vit + vit_base_patch8_384: + image_size: 384 + patch_size: 8 + d_model: 768 + n_heads: 12 + n_layers: 12 + normalization: vit + distilled: false + vit_tiny_patch16_384: + image_size: 384 + patch_size: 16 + d_model: 192 + n_heads: 3 + n_layers: 12 + normalization: vit + distilled: false + vit_small_patch16_384: + image_size: 384 + patch_size: 16 + d_model: 384 + n_heads: 6 + n_layers: 12 + normalization: vit + distilled: false + vit_base_patch16_384: + image_size: 384 + patch_size: 16 + d_model: 768 + n_heads: 12 + n_layers: 12 + normalization: vit + distilled: false + vit_large_patch16_384: + image_size: 384 + patch_size: 16 + d_model: 1024 + n_heads: 16 + n_layers: 24 + normalization: vit + vit_small_patch32_384: + image_size: 384 + patch_size: 32 + d_model: 384 + n_heads: 6 + n_layers: 12 + normalization: vit + distilled: false + vit_base_patch32_384: + image_size: 384 + patch_size: 32 + d_model: 768 + n_heads: 12 + n_layers: 12 + normalization: vit + vit_large_patch32_384: + image_size: 384 + patch_size: 32 + d_model: 1024 + n_heads: 16 + n_layers: 24 + normalization: vit +decoder: + linear: {} + deeplab_dec: + encoder_layer: -1 + mask_transformer: + drop_path_rate: 0.0 + dropout: 0.1 + n_layers: 2 +dataset: + ade20k: + epochs: 64 + eval_freq: 2 + batch_size: 8 + learning_rate: 0.001 + im_size: 512 + crop_size: 512 + window_size: 512 + window_stride: 512 + pascal_context: + epochs: 256 + eval_freq: 8 + batch_size: 16 + learning_rate: 0.001 + im_size: 520 + crop_size: 480 + window_size: 480 + window_stride: 320 + cityscapes: + epochs: 216 + eval_freq: 4 + batch_size: 8 + learning_rate: 0.01 + im_size: 1024 + crop_size: 768 + window_size: 768 + window_stride: 512 diff --git a/src/models/vit/decoder.py b/src/models/vit/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..acdb2f83660904423b97f9163bd81a4016dc8723 --- /dev/null +++ b/src/models/vit/decoder.py @@ -0,0 +1,34 @@ +import torch.nn as nn +from einops import rearrange +from src.models.vit.utils import init_weights + + +class DecoderLinear(nn.Module): + def __init__( + self, + n_cls, + d_encoder, + scale_factor, + dropout_rate=0.3, + ): + super().__init__() + self.scale_factor = scale_factor + self.head = nn.Linear(d_encoder, n_cls) + self.upsampling = 
nn.Upsample(scale_factor=scale_factor**2, mode="linear") + self.norm = nn.LayerNorm((n_cls, 24 * scale_factor, 24 * scale_factor)) + self.dropout = nn.Dropout(dropout_rate) + self.gelu = nn.GELU() + self.apply(init_weights) + + def forward(self, x, img_size): + H, _ = img_size + x = self.head(x) ####### (2, 577, 64) + x = x.transpose(2, 1) ## (2, 64, 576) + x = self.upsampling(x) # (2, 64, 576*scale_factor*scale_factor) + x = x.transpose(2, 1) ## (2, 576*scale_factor*scale_factor, 64) + x = rearrange(x, "b (h w) c -> b c h w", h=H // (16 // self.scale_factor)) # (2, 64, 24*scale_factor, 24*scale_factor) + x = self.norm(x) + x = self.dropout(x) + x = self.gelu(x) + + return x # (2, 64, a, a) diff --git a/src/models/vit/embed.py b/src/models/vit/embed.py new file mode 100644 index 0000000000000000000000000000000000000000..50fe91a1b6798f002b50e7253f42743763311908 --- /dev/null +++ b/src/models/vit/embed.py @@ -0,0 +1,52 @@ +from torch import nn +from timm import create_model +from torchvision.transforms import Normalize + +class SwinModel(nn.Module): + def __init__(self, pretrained_model="swinv2-cr-t-224", device="cuda") -> None: + """ + vit_tiny_patch16_224.augreg_in21k_ft_in1k + swinv2_cr_tiny_ns_224.sw_in1k + """ + super().__init__() + self.device = device + self.pretrained_model = pretrained_model + if pretrained_model == "swinv2-cr-t-224": + self.pretrained = create_model( + "swinv2_cr_tiny_ns_224.sw_in1k", + pretrained=True, + features_only=True, + out_indices=[-4, -3, -2, -1], + ).to(device) + elif pretrained_model == "swinv2-t-256": + self.pretrained = create_model( + "swinv2_tiny_window16_256.ms_in1k", + pretrained=True, + features_only=True, + out_indices=[-4, -3, -2, -1], + ).to(device) + elif pretrained_model == "swinv2-cr-s-224": + self.pretrained = create_model( + "swinv2_cr_small_ns_224.sw_in1k", + pretrained=True, + features_only=True, + out_indices=[-4, -3, -2, -1], + ).to(device) + else: + raise NotImplementedError + + self.pretrained.eval() + self.normalizer = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.upsample = nn.Upsample(scale_factor=2) + + for params in self.pretrained.parameters(): + params.requires_grad = False + + def forward(self, x): + outputs = self.pretrained(x) + if self.pretrained_model in ["swinv2-t-256"]: + for i in range(len(outputs)): + outputs[i] = outputs[i].permute(0, 3, 1, 2) # Change channel-last to channel-first + outputs = [self.upsample(feat) for feat in outputs] + + return outputs \ No newline at end of file diff --git a/src/models/vit/factory.py b/src/models/vit/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..ab2cad05744bf6ed60ee6278a5b79b92321ac4c5 --- /dev/null +++ b/src/models/vit/factory.py @@ -0,0 +1,45 @@ +import os +import torch +from timm.models.vision_transformer import default_cfgs +from timm.models.helpers import load_pretrained, load_custom_pretrained +from src.models.vit.utils import checkpoint_filter_fn +from src.models.vit.vit import VisionTransformer + + +def create_vit(model_cfg): + model_cfg = model_cfg.copy() + backbone = model_cfg.pop("backbone") + + model_cfg.pop("normalization") + model_cfg["n_cls"] = 1000 + mlp_expansion_ratio = 4 + model_cfg["d_ff"] = mlp_expansion_ratio * model_cfg["d_model"] + + if backbone in default_cfgs: + default_cfg = default_cfgs[backbone] + else: + default_cfg = dict( + pretrained=False, + num_classes=1000, + drop_rate=0.0, + drop_path_rate=0.0, + drop_block_rate=None, + ) + + default_cfg["input_size"] = ( + 3, + 
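        # (Descriptive comment, not from the original code: timm's default_cfg
        # stores input_size channels-first, (C, H, W) -- the 3 above is the RGB
        # channel count, followed by height and width from model_cfg["image_size"].)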
model_cfg["image_size"][0], + model_cfg["image_size"][1], + ) + model = VisionTransformer(**model_cfg) + if backbone == "vit_base_patch8_384": + path = os.path.expandvars("$TORCH_HOME/hub/checkpoints/vit_base_patch8_384.pth") + state_dict = torch.load(path, map_location="cpu") + filtered_dict = checkpoint_filter_fn(state_dict, model) + model.load_state_dict(filtered_dict, strict=True) + elif "deit" in backbone: + load_pretrained(model, default_cfg, filter_fn=checkpoint_filter_fn) + else: + load_custom_pretrained(model, default_cfg) + + return model diff --git a/src/models/vit/utils.py b/src/models/vit/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f9aad57595ed5c7c2b84f436223ea9f8d8b961a2 --- /dev/null +++ b/src/models/vit/utils.py @@ -0,0 +1,71 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import trunc_normal_ +from collections import OrderedDict + + +def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens): + # Rescale the grid of position embeddings when loading from state_dict. Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + posemb_tok, posemb_grid = ( + posemb[:, :num_extra_tokens], + posemb[0, num_extra_tokens:], + ) + if grid_old_shape is None: + gs_old_h = int(math.sqrt(len(posemb_grid))) + gs_old_w = gs_old_h + else: + gs_old_h, gs_old_w = grid_old_shape + + gs_h, gs_w = grid_new_shape + posemb_grid = posemb_grid.reshape(1, gs_old_h, gs_old_w, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + return posemb + + +def init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + +def checkpoint_filter_fn(state_dict, model): + """convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + if "model" in state_dict: + # For deit models + state_dict = state_dict["model"] + num_extra_tokens = 1 + ("dist_token" in state_dict.keys()) + patch_size = model.patch_size + image_size = model.patch_embed.image_size + for k, v in state_dict.items(): + if k == "pos_embed" and v.shape != model.pos_embed.shape: + # To resize pos embedding when using model at different size from pretrained weights + v = resize_pos_embed( + v, + None, + (image_size[0] // patch_size, image_size[1] // patch_size), + num_extra_tokens, + ) + out_dict[k] = v + return out_dict + +def load_params(ckpt_file, device): + # params = torch.load(ckpt_file, map_location=f'cuda:{local_rank}') + # new_params = [] + # for key, value in params.items(): + # new_params.append(("module."+key if has_module else key, value)) + # return OrderedDict(new_params) + params = torch.load(ckpt_file, map_location=device) + new_params = [] + for key, value in params.items(): + new_params.append((key, value)) + return OrderedDict(new_params) diff --git a/src/models/vit/vit.py b/src/models/vit/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..f7d868d3ef239283195f492f29f5434ebe55b322 --- /dev/null +++ b/src/models/vit/vit.py @@ -0,0 +1,199 @@ +import torch +import torch.nn as nn +import 
torch.nn.functional as F +from timm.models.vision_transformer import _load_weights +from timm.models.layers import trunc_normal_ +from typing import List + +from src.models.vit.utils import init_weights, resize_pos_embed +from src.models.vit.blocks import Block +from src.models.vit.decoder import DecoderLinear + + +class PatchEmbedding(nn.Module): + def __init__(self, image_size, patch_size, embed_dim, channels): + super().__init__() + + self.image_size = image_size + if image_size[0] % patch_size != 0 or image_size[1] % patch_size != 0: + raise ValueError("image dimensions must be divisible by the patch size") + self.grid_size = image_size[0] // patch_size, image_size[1] // patch_size + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.patch_size = patch_size + + self.proj = nn.Conv2d(channels, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, im): + B, C, H, W = im.shape + x = self.proj(im).flatten(2).transpose(1, 2) + return x + + +class VisionTransformer(nn.Module): + def __init__( + self, + image_size, + patch_size, + n_layers, + d_model, + d_ff, + n_heads, + n_cls, + dropout=0.1, + drop_path_rate=0.0, + distilled=False, + channels=3, + ): + super().__init__() + self.patch_embed = PatchEmbedding( + image_size, + patch_size, + d_model, + channels, + ) + self.patch_size = patch_size + self.n_layers = n_layers + self.d_model = d_model + self.d_ff = d_ff + self.n_heads = n_heads + self.dropout = nn.Dropout(dropout) + self.n_cls = n_cls + + # cls and pos tokens + self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) + self.distilled = distilled + if self.distilled: + self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model)) + self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 2, d_model)) + self.head_dist = nn.Linear(d_model, n_cls) + else: + self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 1, d_model)) + + # transformer blocks + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)] + self.blocks = nn.ModuleList([Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)]) + + # output head + self.norm = nn.LayerNorm(d_model) + self.head = nn.Linear(d_model, n_cls) + + trunc_normal_(self.pos_embed, std=0.02) + trunc_normal_(self.cls_token, std=0.02) + if self.distilled: + trunc_normal_(self.dist_token, std=0.02) + self.pre_logits = nn.Identity() + + self.apply(init_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed", "cls_token", "dist_token"} + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=""): + _load_weights(self, checkpoint_path, prefix) + + def forward(self, im, head_out_idx: List[int], n_dim_output=3, return_features=False): + B, _, H, W = im.shape + PS = self.patch_size + assert n_dim_output == 3 or n_dim_output == 4, "n_dim_output must be 3 or 4" + x = self.patch_embed(im) + cls_tokens = self.cls_token.expand(B, -1, -1) + if self.distilled: + dist_tokens = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_tokens, x), dim=1) + else: + x = torch.cat((cls_tokens, x), dim=1) + + pos_embed = self.pos_embed + num_extra_tokens = 1 + self.distilled + if x.shape[1] != pos_embed.shape[1]: + pos_embed = resize_pos_embed( + pos_embed, + self.patch_embed.grid_size, + (H // PS, W // PS), + num_extra_tokens, + ) + x = x + pos_embed + x = self.dropout(x) + device = x.device + + if n_dim_output == 3: + heads_out = torch.zeros(size=(len(head_out_idx), B, (H // PS) ** 2 + 1, self.d_model)).to(device) + 
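            # Shape note (illustrative, not from the original code; assumes a square
            # H x W input with patch size PS): the 3-D branch above keeps raw token
            # sequences of shape (B, (H // PS)**2 + 1, d_model), while the 4-D branch
            # below drops the leading class token and reshapes tokens into feature
            # maps -- the hard-coded 24 assumes 384-px inputs with 16-px patches
            # (384 // 16 = 24).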
else: + heads_out = torch.zeros(size=(len(head_out_idx), B, self.d_model, H // PS, H // PS)).to(device) + self.register_buffer("heads_out", heads_out) + + head_idx = 0 + for idx_layer, blk in enumerate(self.blocks): + x = blk(x) + if idx_layer in head_out_idx: + if n_dim_output == 3: + heads_out[head_idx] = x + else: + heads_out[head_idx] = x[:, 1:, :].reshape((-1, 24, 24, self.d_model)).permute(0, 3, 1, 2) + head_idx += 1 + + x = self.norm(x) + + if return_features: + return heads_out + + if self.distilled: + x, x_dist = x[:, 0], x[:, 1] + x = self.head(x) + x_dist = self.head_dist(x_dist) + x = (x + x_dist) / 2 + else: + x = x[:, 0] + x = self.head(x) + return x + + def get_attention_map(self, im, layer_id): + if layer_id >= self.n_layers or layer_id < 0: + raise ValueError(f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}.") + B, _, H, W = im.shape + PS = self.patch_size + + x = self.patch_embed(im) + cls_tokens = self.cls_token.expand(B, -1, -1) + if self.distilled: + dist_tokens = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_tokens, x), dim=1) + else: + x = torch.cat((cls_tokens, x), dim=1) + + pos_embed = self.pos_embed + num_extra_tokens = 1 + self.distilled + if x.shape[1] != pos_embed.shape[1]: + pos_embed = resize_pos_embed( + pos_embed, + self.patch_embed.grid_size, + (H // PS, W // PS), + num_extra_tokens, + ) + x = x + pos_embed + + for i, blk in enumerate(self.blocks): + if i < layer_id: + x = blk(x) + else: + return blk(x, return_attention=True) + + +class FeatureTransform(nn.Module): + def __init__(self, img_size, d_encoder, nls_list=[128, 256, 512, 512], scale_factor_list=[8, 4, 2, 1]): + super(FeatureTransform, self).__init__() + self.img_size = img_size + + self.decoder_0 = DecoderLinear(n_cls=nls_list[0], d_encoder=d_encoder, scale_factor=scale_factor_list[0]) + self.decoder_1 = DecoderLinear(n_cls=nls_list[1], d_encoder=d_encoder, scale_factor=scale_factor_list[1]) + self.decoder_2 = DecoderLinear(n_cls=nls_list[2], d_encoder=d_encoder, scale_factor=scale_factor_list[2]) + self.decoder_3 = DecoderLinear(n_cls=nls_list[3], d_encoder=d_encoder, scale_factor=scale_factor_list[3]) + + def forward(self, x_list): + feat_3 = self.decoder_3(x_list[3][:, 1:, :], self.img_size) # (2, 512, 24, 24) + feat_2 = self.decoder_2(x_list[2][:, 1:, :], self.img_size) # (2, 512, 48, 48) + feat_1 = self.decoder_1(x_list[1][:, 1:, :], self.img_size) # (2, 256, 96, 96) + feat_0 = self.decoder_0(x_list[0][:, 1:, :], self.img_size) # (2, 128, 192, 192) + return feat_0, feat_1, feat_2, feat_3 \ No newline at end of file diff --git a/src/scheduler.py b/src/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..87a9a1c5fcbf2df2a9263d49a5d3f5ba87ccb48d --- /dev/null +++ b/src/scheduler.py @@ -0,0 +1,40 @@ +from torch.optim.lr_scheduler import _LRScheduler + +class PolynomialLR(_LRScheduler): + def __init__( + self, + optimizer, + step_size, + iter_warmup, + iter_max, + power, + min_lr=0, + last_epoch=-1, + ): + self.step_size = step_size + self.iter_warmup = int(iter_warmup) + self.iter_max = int(iter_max) + self.power = power + self.min_lr = min_lr + super(PolynomialLR, self).__init__(optimizer, last_epoch) + + def polynomial_decay(self, lr): + iter_cur = float(self.last_epoch) + if iter_cur < self.iter_warmup: + coef = iter_cur / self.iter_warmup + coef *= (1 - self.iter_warmup / self.iter_max) ** self.power + else: + coef = (1 - iter_cur / self.iter_max) ** self.power + return (lr - self.min_lr) * coef + self.min_lr 
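    # Worked example of polynomial_decay above (the values are made-up for
    # illustration, not defaults from this repo): with step_size=1,
    # iter_warmup=100, iter_max=1000, power=0.9, min_lr=0 and a base lr of 1e-3:
    #   iter 50   (warmup): coef = (50/100) * (1 - 100/1000)**0.9 ~= 0.455 -> lr ~= 4.55e-4
    #   iter 500  (decay):  coef = (1 - 500/1000)**0.9 ~= 0.536          -> lr ~= 5.36e-4
    #   iter 1000:          coef = 0                                     -> lr = min_lr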
+
+    def get_lr(self):
+        if (
+            (self.last_epoch == 0)
+            or (self.last_epoch % self.step_size != 0)
+            or (self.last_epoch > self.iter_max)
+        ):
+            return [group["lr"] for group in self.optimizer.param_groups]
+        return [self.polynomial_decay(lr) for lr in self.base_lrs]
+
+    def step_update(self, num_updates):
+        self.step()
\ No newline at end of file
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7c76f96480aceb430ebca9f4d6655b2f0847a23
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,849 @@
+import sys
+import time
+import numpy as np
+from PIL import Image
+from skimage import color
+from skimage.transform import resize
+import src.data.functional as F
+import torch
+from torch import nn
+import torch.nn.functional as F_torch
+import torchvision.transforms.functional as F_torchvision
+from numba import cuda, jit
+import math
+import torchvision.utils as vutils
+from torch.autograd import Variable
+
+rgb_from_xyz = np.array(
+    [
+        [3.24048134, -0.96925495, 0.05564664],
+        [-1.53715152, 1.87599, -0.20404134],
+        [-0.49853633, 0.04155593, 1.05731107],
+    ]
+)
+l_norm, ab_norm = 1.0, 1.0
+l_mean, ab_mean = 50.0, 0
+
+
+class SquaredPadding:
+    def __init__(self, target_size=384, fill_value=0):
+        self.target_size = target_size
+        self.fill_value = fill_value
+
+    def __call__(self, img, return_pil=True, return_paddings=False, dtype=np.uint8):
+        if not isinstance(img, np.ndarray):
+            img = np.array(img)
+        ndim = len(img.shape)
+        H, W = img.shape[:2]
+        if H > W:
+            H_new, W_new = self.target_size, int(W / H * self.target_size)
+            # Resize image
+            img = resize(img, (H_new, W_new), preserve_range=True).astype(dtype)
+
+            # Padding image
+            padded_size = H_new - W_new
+            if ndim == 3:
+                paddings = [(0, 0), (padded_size // 2, (padded_size // 2) + (padded_size % 2)), (0, 0)]
+            elif ndim == 2:
+                paddings = [(0, 0), (padded_size // 2, (padded_size // 2) + (padded_size % 2))]
+            padded_img = np.pad(img, paddings, mode='constant', constant_values=self.fill_value)
+        else:
+            H_new, W_new = int(H / W * self.target_size), self.target_size
+            # Resize image
+            img = resize(img, (H_new, W_new), preserve_range=True).astype(dtype)
+
+            # Padding image
+            padded_size = W_new - H_new
+            if ndim == 3:
+                paddings = [(padded_size // 2, (padded_size // 2) + (padded_size % 2)), (0, 0), (0, 0)]
+            elif ndim == 2:
+                paddings = [(padded_size // 2, (padded_size // 2) + (padded_size % 2)), (0, 0)]
+            padded_img = np.pad(img, paddings, mode='constant', constant_values=self.fill_value)
+
+        if return_pil:
+            padded_img = Image.fromarray(padded_img)
+
+        if return_paddings:
+            return padded_img, paddings
+
+        return padded_img
+
+class UnpaddingSquare():
+    def __call__(self, img, paddings):
+        if not isinstance(img, np.ndarray):
+            img = np.array(img)
+
+        H, W = img.shape[0], img.shape[1]
+        (pad_top, pad_bottom), (pad_left, pad_right), _ = paddings
+        W_ori = W - pad_left - pad_right
+        H_ori = H - pad_top - pad_bottom
+
+        return img[pad_top:pad_top + H_ori, pad_left:pad_left + W_ori, :]
+
+class UnpaddingSquare_Tensor():
+    def __call__(self, img, paddings):
+        H, W = img.shape[1], img.shape[2]
+        (pad_top, pad_bottom), (pad_left, pad_right), _ = paddings
+        W_ori = W - pad_left - pad_right
+        H_ori = H - pad_top - pad_bottom
+
+        return img[:, pad_top:pad_top + H_ori, pad_left:pad_left + W_ori]
+
+class ResizeFlow(object):
+    def __init__(self, 
target_size=(224,224)): + self.target_size = target_size + pass + + def __call__(self, flow): + return F_torch.interpolate(flow.unsqueeze(0), self.target_size, mode='bilinear', align_corners=True).squeeze(0) + +class SquaredPaddingFlow(object): + def __init__(self, fill_value=0): + self.fill_value = fill_value + + def __call__(self, flow): + H, W = flow.size(1), flow.size(2) + + if H > W: + # Padding flow + padded_size = H - W + paddings = (padded_size // 2, (padded_size // 2) + (padded_size % 2), 0, 0) + padded_img = F_torch.pad(flow, paddings, value=self.fill_value) + else: + # Padding flow + padded_size = W - H + paddings = (0, 0, padded_size // 2, (padded_size // 2) + (padded_size % 2)) + padded_img = F_torch.pad(flow, paddings, value=self.fill_value) + + return padded_img + + +def gray2rgb_batch(l): + # gray image tensor to rgb image tensor + l_uncenter = uncenter_l(l) + l_uncenter = l_uncenter / (2 * l_mean) + return torch.cat((l_uncenter, l_uncenter, l_uncenter), dim=1) + +def batch_lab2rgb_transpose_mc(img_l_mc, img_ab_mc, nrow=8): + if isinstance(img_l_mc, Variable): + img_l_mc = img_l_mc.data.cpu() + if isinstance(img_ab_mc, Variable): + img_ab_mc = img_ab_mc.data.cpu() + + if img_l_mc.is_cuda: + img_l_mc = img_l_mc.cpu() + if img_ab_mc.is_cuda: + img_ab_mc = img_ab_mc.cpu() + + assert img_l_mc.dim() == 4 and img_ab_mc.dim() == 4, "only for batch input" + + img_l = img_l_mc * l_norm + l_mean + img_ab = img_ab_mc * ab_norm + ab_mean + pred_lab = torch.cat((img_l, img_ab), dim=1) + grid_lab = vutils.make_grid(pred_lab, nrow=nrow).numpy().astype("float64") + return (np.clip(color.lab2rgb(grid_lab.transpose((1, 2, 0))), 0, 1) * 255).astype("uint8") + + +def vgg_preprocess(tensor): + # input is RGB tensor which ranges in [0,1] + # output is BGR tensor which ranges in [0,255] + tensor_bgr = torch.cat((tensor[:, 2:3, :, :], tensor[:, 1:2, :, :], tensor[:, 0:1, :, :]), dim=1) + tensor_bgr_ml = tensor_bgr - torch.Tensor([0.40760392, 0.45795686, 0.48501961]).type_as(tensor_bgr).view(1, 3, 1, 1) + return tensor_bgr_ml * 255 + + +def tensor_lab2rgb(input): + """ + n * 3* h *w + """ + input_trans = input.transpose(1, 2).transpose(2, 3) # n * h * w * 3 + L, a, b = ( + input_trans[:, :, :, 0:1], + input_trans[:, :, :, 1:2], + input_trans[:, :, :, 2:], + ) + y = (L + 16.0) / 116.0 + x = (a / 500.0) + y + z = y - (b / 200.0) + + neg_mask = z.data < 0 + z[neg_mask] = 0 + xyz = torch.cat((x, y, z), dim=3) + + mask = xyz.data > 0.2068966 + mask_xyz = xyz.clone() + mask_xyz[mask] = torch.pow(xyz[mask], 3.0) + mask_xyz[~mask] = (xyz[~mask] - 16.0 / 116.0) / 7.787 + mask_xyz[:, :, :, 0] = mask_xyz[:, :, :, 0] * 0.95047 + mask_xyz[:, :, :, 2] = mask_xyz[:, :, :, 2] * 1.08883 + + rgb_trans = torch.mm(mask_xyz.view(-1, 3), torch.from_numpy(rgb_from_xyz).type_as(xyz)).view( + input.size(0), input.size(2), input.size(3), 3 + ) + rgb = rgb_trans.transpose(2, 3).transpose(1, 2) + + mask = rgb > 0.0031308 + mask_rgb = rgb.clone() + mask_rgb[mask] = 1.055 * torch.pow(rgb[mask], 1 / 2.4) - 0.055 + mask_rgb[~mask] = rgb[~mask] * 12.92 + + neg_mask = mask_rgb.data < 0 + large_mask = mask_rgb.data > 1 + mask_rgb[neg_mask] = 0 + mask_rgb[large_mask] = 1 + return mask_rgb + + +###### loss functions ###### +def feature_normalize(feature_in): + feature_in_norm = torch.norm(feature_in, 2, 1, keepdim=True) + sys.float_info.epsilon + feature_in_norm = torch.div(feature_in, feature_in_norm) + return feature_in_norm + + +# denormalization for l +def uncenter_l(l): + return l * l_norm + l_mean + + +def get_grid(x): + 
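    # Illustrative note (not from the original code): get_grid builds the identity
    # sampling grid expected by F_torch.grid_sample, with x and y coordinates
    # normalized to [-1, 1]. For example, a 1x1x2x2 input yields channel 0
    # (horizontal) [[-1, 1], [-1, 1]] and channel 1 (vertical) [[-1, -1], [1, 1]],
    # so adding a zero flow and sampling with align_corners=True reproduces the input.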
+
+
+def get_grid(x):
+    torchHorizontal = torch.linspace(-1.0, 1.0, x.size(3)).view(1, 1, 1, x.size(3)).expand(x.size(0), 1, x.size(2), x.size(3))
+    torchVertical = torch.linspace(-1.0, 1.0, x.size(2)).view(1, 1, x.size(2), 1).expand(x.size(0), 1, x.size(2), x.size(3))
+
+    return torch.cat([torchHorizontal, torchVertical], 1)
+
+
+class WarpingLayer(nn.Module):
+    def __init__(self, device):
+        super(WarpingLayer, self).__init__()
+        self.device = device
+
+    def forward(self, x, flow):
+        """
+        Warp the input image according to the given optical flow.
+
+        Args:
+            x: the input image
+            flow: the flow tensor of shape (batch_size, 2, height, width)
+
+        Returns:
+            The warped image
+        """
+        # WarpingLayer uses F.grid_sample, which expects a normalized grid.
+        # We still output unnormalized flow for the convenience of comparing EPEs
+        # with FlowNet2 and the original code, so the flow is normalized here.
+        flow_for_grid = torch.zeros_like(flow).to(self.device)
+        flow_for_grid[:, 0, :, :] = flow[:, 0, :, :] / ((flow.size(3) - 1.0) / 2.0)
+        flow_for_grid[:, 1, :, :] = flow[:, 1, :, :] / ((flow.size(2) - 1.0) / 2.0)
+
+        grid = (get_grid(x).to(self.device) + flow_for_grid).permute(0, 2, 3, 1)
+        return F_torch.grid_sample(x, grid, align_corners=True)
+
+
+class CenterPad_threshold(object):
+    def __init__(self, image_size, threshold=3 / 4):
+        self.height = image_size[0]
+        self.width = image_size[1]
+        self.threshold = threshold
+
+    def __call__(self, image):
+        # resize and pad the image to the target aspect ratio,
+        # cropping overly tall inputs at the given threshold
+        I = np.array(image)
+
+        height_old = np.size(I, 0)
+        width_old = np.size(I, 1)
+        old_size = [height_old, width_old]
+        height = self.height
+        width = self.width
+        I_pad = np.zeros((height, width, np.size(I, 2)))
+
+        ratio = height / width
+
+        if height_old / width_old == ratio:
+            if height_old == height:
+                return Image.fromarray(I.astype(np.uint8))
+            new_size = [int(x * height / height_old) for x in old_size]
+            I_resize = resize(I, new_size, mode="reflect", preserve_range=True, clip=False, anti_aliasing=True)
+            return Image.fromarray(I_resize.astype(np.uint8))
+
+        if height_old / width_old > self.threshold:
+            height_new = int(width_old * self.threshold)
+            height_margin = height_old - height_new
+            height_crop_start = height_margin // 2
+            I_crop = I[height_crop_start : (height_crop_start + height_new), :, :]
+            I_resize = resize(I_crop, [height, width], mode="reflect", preserve_range=True, clip=False, anti_aliasing=True)
+
+            return Image.fromarray(I_resize.astype(np.uint8))
+
+        if height_old / width_old > ratio:  # pad the width and crop
+            new_size = [int(x * width / width_old) for x in old_size]
+            I_resize = resize(I, new_size, mode="reflect", preserve_range=True, clip=False, anti_aliasing=True)
+            width_resize = np.size(I_resize, 1)
+            height_resize = np.size(I_resize, 0)
+            start_height = (height_resize - height) // 2
+            I_pad[:, :, :] = I_resize[start_height : (start_height + height), :, :]
+        else:  # pad the height and crop
+            new_size = [int(x * height / height_old) for x in old_size]
+            I_resize = resize(I, new_size, mode="reflect", preserve_range=True, clip=False, anti_aliasing=True)
+            width_resize = np.size(I_resize, 1)
+            height_resize = np.size(I_resize, 0)
+            start_width = (width_resize - width) // 2
+            I_pad[:, :, :] = I_resize[:, start_width : (start_width + width), :]
+
+        return Image.fromarray(I_pad.astype(np.uint8))
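+
+
+# Sanity check (illustrative): warping with an all-zero flow is the identity
+# map, up to interpolation error.
+#
+#   warp = WarpingLayer(device='cpu')
+#   x = torch.rand(1, 3, 64, 64)
+#   flow = torch.zeros(1, 2, 64, 64)
+#   assert torch.allclose(warp(x, flow), x, atol=1e-5)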
+
+
+class Normalize(object):
+    def __call__(self, inputs):
+        inputs[0:1, :, :] = F.normalize(inputs[0:1, :, :], 50, 1)
+        inputs[1:3, :, :] = F.normalize(inputs[1:3, :, :], (0, 0), (1, 1))
+        return inputs
+
+
+class RGB2Lab(object):
+    def __call__(self, inputs):
+        return color.rgb2lab(inputs)
+
+
+class ToTensor(object):
+    def __call__(self, inputs):
+        return F.to_mytensor(inputs)
+
+
+class CenterPad(object):
+    def __init__(self, image_size):
+        self.height = image_size[0]
+        self.width = image_size[1]
+
+    def __call__(self, image):
+        # resize and pad the image to the target aspect ratio
+        I = np.array(image)
+
+        height_old = np.size(I, 0)
+        width_old = np.size(I, 1)
+        old_size = [height_old, width_old]
+        height = self.height
+        width = self.width
+        I_pad = np.zeros((height, width, np.size(I, 2)))
+
+        ratio = height / width
+        if height_old / width_old == ratio:
+            if height_old == height:
+                return Image.fromarray(I.astype(np.uint8))
+            new_size = [int(x * height / height_old) for x in old_size]
+            I_resize = resize(I, new_size, mode="reflect", preserve_range=True, clip=False, anti_aliasing=True)
+            return Image.fromarray(I_resize.astype(np.uint8))
+
+        if height_old / width_old > ratio:  # pad the width and crop
+            new_size = [int(x * width / width_old) for x in old_size]
+            I_resize = resize(I, new_size, mode="reflect", preserve_range=True, clip=False, anti_aliasing=True)
+            width_resize = np.size(I_resize, 1)
+            height_resize = np.size(I_resize, 0)
+            start_height = (height_resize - height) // 2
+            I_pad[:, :, :] = I_resize[start_height : (start_height + height), :, :]
+        else:  # pad the height and crop
+            new_size = [int(x * height / height_old) for x in old_size]
+            I_resize = resize(I, new_size, mode="reflect", preserve_range=True, clip=False, anti_aliasing=True)
+            width_resize = np.size(I_resize, 1)
+            height_resize = np.size(I_resize, 0)
+            start_width = (width_resize - width) // 2
+            I_pad[:, :, :] = I_resize[:, start_width : (start_width + width), :]
+
+        return Image.fromarray(I_pad.astype(np.uint8))
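+
+
+# The transforms above are designed to be chained; a typical preprocessing
+# pipeline might look like the following (a sketch only -- the exact
+# composition used during training lives in the dataset code):
+#
+#   from torchvision import transforms
+#   preprocess = transforms.Compose([CenterPad((224, 224)), RGB2Lab(), ToTensor(), Normalize()])
+#   lab = preprocess(Image.open('frame.png'))  # hypothetical input frame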
+
+
+class CenterPadCrop_numpy(object):
+    """
+    Pad the image according to its height, then center-crop to the target size.
+    """
+
+    def __init__(self, image_size):
+        self.height = image_size[0]
+        self.width = image_size[1]
+
+    def __call__(self, image, threshold=3 / 4):
+        I = np.array(image)
+        height_old = np.size(I, 0)
+        width_old = np.size(I, 1)
+        old_size = [height_old, width_old]
+        height = self.height
+        width = self.width
+        padding_size = width
+        if image.ndim == 2:
+            I_pad = np.zeros((width, width))
+        else:
+            I_pad = np.zeros((width, width, I.shape[2]))
+
+        ratio = height / width
+        if height_old / width_old == ratio:
+            return I
+
+        if height_old / width_old > ratio:  # pad the width and crop
+            new_size = [int(x * width / width_old) for x in old_size]
+            I_resize = resize(I, new_size, mode="reflect", preserve_range=True, clip=False, anti_aliasing=True)
+            width_resize = np.size(I_resize, 1)
+            height_resize = np.size(I_resize, 0)
+            start_height = (height_resize - height) // 2
+            start_height_block = (padding_size - height) // 2
+            if image.ndim == 2:
+                I_pad[start_height_block : (start_height_block + height), :] = I_resize[
+                    start_height : (start_height + height), :
+                ]
+            else:
+                I_pad[start_height_block : (start_height_block + height), :, :] = I_resize[
+                    start_height : (start_height + height), :, :
+                ]
+        else:  # pad the height and crop
+            new_size = [int(x * height / height_old) for x in old_size]
+            I_resize = resize(I, new_size, mode="reflect", preserve_range=True, clip=False, anti_aliasing=True)
+            width_resize = np.size(I_resize, 1)
+            height_resize = np.size(I_resize, 0)
+            start_width = (width_resize - width) // 2
+            start_width_block = (padding_size - width) // 2
+            if image.ndim == 2:
+                I_pad[:, start_width_block : (start_width_block + width)] = I_resize[:, start_width : (start_width + width)]
+            else:
+                I_pad[:, start_width_block : (start_width_block + width), :] = I_resize[
+                    :, start_width : (start_width + width), :
+                ]
+
+        crop_start_height = (I_pad.shape[0] - height) // 2
+        crop_start_width = (I_pad.shape[1] - width) // 2
+
+        if image.ndim == 2:
+            return I_pad[crop_start_height : (crop_start_height + height), crop_start_width : (crop_start_width + width)]
+        else:
+            return I_pad[crop_start_height : (crop_start_height + height), crop_start_width : (crop_start_width + width), :]
+
+
+@jit(nopython=True, nogil=True)
+def biInterpolation_cpu(distorted, i, j):
+    # use the integer parts of (i, j) for indexing and the fractional parts
+    # as the bilinear interpolation weights
+    i_int = np.uint16(i)
+    j_int = np.uint16(j)
+    Q11 = distorted[j_int, i_int]
+    Q12 = distorted[j_int, i_int + 1]
+    Q21 = distorted[j_int + 1, i_int]
+    Q22 = distorted[j_int + 1, i_int + 1]
+
+    return np.uint8(
+        Q11 * (i_int + 1 - i) * (j_int + 1 - j)
+        + Q12 * (i - i_int) * (j_int + 1 - j)
+        + Q21 * (i_int + 1 - i) * (j - j_int)
+        + Q22 * (i - i_int) * (j - j_int)
+    )
+
+
+@jit(nopython=True, nogil=True)
+def iterSearchShader_cpu(padu, padv, xr, yr, W, H, maxIter, precision):
+    if abs(padu[yr, xr]) < precision and abs(padv[yr, xr]) < precision:
+        return xr, yr
+
+    # initialization proposed in the paper (see the paper for the derivation)
+    if (xr + 1) <= (W - 1):
+        dif = padu[yr, xr + 1] - padu[yr, xr]
+    else:
+        dif = padu[yr, xr] - padu[yr, xr - 1]
+    u_next = padu[yr, xr] / (1 + dif)
+    if (yr + 1) <= (H - 1):
+        dif = padv[yr + 1, xr] - padv[yr, xr]
+    else:
+        dif = padv[yr, xr] - padv[yr - 1, xr]
+    v_next = padv[yr, xr] / (1 + dif)
+    i = xr - u_next
+    j = yr - v_next
+    i_int = int(i)
+    j_int = int(j)
+
+    # the traditional iterative search method
+    for _ in range(maxIter):
+        if not 0 <= i <= (W - 1) or not 0 <= j <= (H - 1):
+            return i, j
+
+        u11 = padu[j_int, i_int]
+        v11 = padv[j_int, i_int]
+
+        u12 = padu[j_int, i_int + 1]
+        v12 = padv[j_int, i_int + 1]
+
+        u21 = padu[j_int + 1, i_int]
+        v21 = padv[j_int + 1, i_int]
+
+        u22 = padu[j_int + 1, i_int + 1]
+        v22 = padv[j_int + 1, i_int + 1]
+
+        u = (
+            u11 * (i_int + 1 - i) * (j_int + 1 - j)
+            + u12 * (i - i_int) * (j_int + 1 - j)
+            + u21 * (i_int + 1 - i) * (j - j_int)
+            + u22 * (i - i_int) * (j - j_int)
+        )
+
+        v = (
+            v11 * (i_int + 1 - i) * (j_int + 1 - j)
+            + v12 * (i - i_int) * (j_int + 1 - j)
+            + v21 * (i_int + 1 - i) * (j - j_int)
+            + v22 * (i - i_int) * (j - j_int)
+        )
+
+        i_next = xr - u
+        j_next = yr - v
+
+        if abs(i - i_next) < precision and abs(j - j_next) < precision:
+            return i, j
+
+        i = i_next
+        j = j_next
+        # keep the integer parts in sync with the current estimate
+        i_int = int(i)
+        j_int = int(j)
+
+    # if the search does not converge within maxIter, return the last iterate
+    return i_next, j_next
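+
+
+# Quick numeric check (illustrative): sampling halfway between four pixels
+# returns their average.
+#
+#   patch = np.array([[0.0, 100.0], [100.0, 200.0]])
+#   biInterpolation_cpu(patch, 0.5, 0.5)  # -> 100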
+
+
+@jit(nopython=True, nogil=True)
+def iterSearch_cpu(distortImg, resultImg, padu, padv, W, H, maxIter=5, precision=1e-2):
+    for xr in range(W):
+        for yr in range(H):
+            # (xr, yr) is the point in the result image; (i, j) is the search result in the distorted image
+            i, j = iterSearchShader_cpu(padu, padv, xr, yr, W, H, maxIter, precision)
+
+            # reflect the pixels outside the border
+            if i > W - 1:
+                i = 2 * W - 1 - i
+            if i < 0:
+                i = -i
+            if j > H - 1:
+                j = 2 * H - 1 - j
+            if j < 0:
+                j = -j
+
+            # bilinear interpolation to get the pixel at (i, j) in the distorted image
+            resultImg[yr, xr, 0] = biInterpolation_cpu(distortImg[:, :, 0], i, j)
+            resultImg[yr, xr, 1] = biInterpolation_cpu(distortImg[:, :, 1], i, j)
+            resultImg[yr, xr, 2] = biInterpolation_cpu(distortImg[:, :, 2], i, j)
+    return None
+
+
+def forward_mapping_cpu(source_image, u, v, maxIter=5, precision=1e-2):
+    """
+    Warp the image according to the forward flow.
+    u: horizontal component
+    v: vertical component
+    """
+    H = source_image.shape[0]
+    W = source_image.shape[1]
+
+    # pad the image and the flow by replicating the last row/column
+    distortImg = np.zeros((H + 1, W + 1, 3), dtype=np.uint8)
+    distortImg[0:H, 0:W] = source_image[0:H, 0:W]
+    distortImg[H, 0:W] = source_image[H - 1, 0:W]
+    distortImg[0:H, W] = source_image[0:H, W - 1]
+    distortImg[H, W] = source_image[H - 1, W - 1]
+
+    padu = np.zeros((H + 1, W + 1), dtype=np.float32)
+    padu[0:H, 0:W] = u[0:H, 0:W]
+    padu[H, 0:W] = u[H - 1, 0:W]
+    padu[0:H, W] = u[0:H, W - 1]
+    padu[H, W] = u[H - 1, W - 1]
+
+    padv = np.zeros((H + 1, W + 1), dtype=np.float32)
+    padv[0:H, 0:W] = v[0:H, 0:W]
+    padv[H, 0:W] = v[H - 1, 0:W]
+    padv[0:H, W] = v[0:H, W - 1]
+    padv[H, W] = v[H - 1, W - 1]
+
+    resultImg = np.zeros((H, W, 3), dtype=np.uint8)
+    iterSearch_cpu(distortImg, resultImg, padu, padv, W, H, maxIter, precision)
+    return resultImg
+
+
+class Distortion_with_flow_cpu(object):
+    """Elastic distortion driven by a forward flow (CPU version)."""
+
+    def __init__(self, maxIter=3, precision=1e-3):
+        self.maxIter = maxIter
+        self.precision = precision
+
+    def __call__(self, inputs, dx, dy):
+        # note the argument order: dy is passed as the horizontal component and
+        # dx as the vertical one, matching forward_mapping_cpu(source, u, v)
+        inputs = np.array(inputs)
+        remap_image = forward_mapping_cpu(inputs, dy, dx, maxIter=self.maxIter, precision=self.precision)
+        return Image.fromarray(remap_image)
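+
+
+# Illustrative use (a sketch): shift an image two pixels to the right.
+# Per the argument convention noted above, the horizontal displacement is
+# passed via dy and the vertical one via dx.
+#
+#   img = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))
+#   dy = np.full((64, 64), 2.0, dtype=np.float32)   # horizontal displacement
+#   dx = np.zeros((64, 64), dtype=np.float32)       # vertical displacement
+#   warped = Distortion_with_flow_cpu()(img, dx, dy)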
+
+
+@cuda.jit(device=True)
+def biInterpolation_gpu(distorted, i, j):
+    # use the integer parts of (i, j) for indexing and the fractional parts
+    # as the bilinear interpolation weights
+    i_int = int(i)
+    j_int = int(j)
+    Q11 = distorted[j_int, i_int]
+    Q12 = distorted[j_int, i_int + 1]
+    Q21 = distorted[j_int + 1, i_int]
+    Q22 = distorted[j_int + 1, i_int + 1]
+
+    return np.uint8(
+        Q11 * (i_int + 1 - i) * (j_int + 1 - j)
+        + Q12 * (i - i_int) * (j_int + 1 - j)
+        + Q21 * (i_int + 1 - i) * (j - j_int)
+        + Q22 * (i - i_int) * (j - j_int)
+    )
+
+
+@cuda.jit(device=True)
+def iterSearchShader_gpu(padu, padv, xr, yr, W, H, maxIter, precision):
+    if abs(padu[yr, xr]) < precision and abs(padv[yr, xr]) < precision:
+        return xr, yr
+
+    # initialization proposed in the paper (see the paper for the derivation)
+    if (xr + 1) <= (W - 1):
+        dif = padu[yr, xr + 1] - padu[yr, xr]
+    else:
+        dif = padu[yr, xr] - padu[yr, xr - 1]
+    u_next = padu[yr, xr] / (1 + dif)
+    if (yr + 1) <= (H - 1):
+        dif = padv[yr + 1, xr] - padv[yr, xr]
+    else:
+        dif = padv[yr, xr] - padv[yr - 1, xr]
+    v_next = padv[yr, xr] / (1 + dif)
+    i = xr - u_next
+    j = yr - v_next
+    i_int = int(i)
+    j_int = int(j)
+
+    # the traditional iterative search method
+    for _ in range(maxIter):
+        if not 0 <= i <= (W - 1) or not 0 <= j <= (H - 1):
+            return i, j
+
+        u11 = padu[j_int, i_int]
+        v11 = padv[j_int, i_int]
+
+        u12 = padu[j_int, i_int + 1]
+        v12 = padv[j_int, i_int + 1]
+
+        u21 = padu[j_int + 1, i_int]
+        v21 = padv[j_int + 1, i_int]
+
+        u22 = padu[j_int + 1, i_int + 1]
+        v22 = padv[j_int + 1, i_int + 1]
+
+        u = (
+            u11 * (i_int + 1 - i) * (j_int + 1 - j)
+            + u12 * (i - i_int) * (j_int + 1 - j)
+            + u21 * (i_int + 1 - i) * (j - j_int)
+            + u22 * (i - i_int) * (j - j_int)
+        )
+
+        v = (
+            v11 * (i_int + 1 - i) * (j_int + 1 - j)
+            + v12 * (i - i_int) * (j_int + 1 - j)
+            + v21 * (i_int + 1 - i) * (j - j_int)
+            + v22 * (i - i_int) * (j - j_int)
+        )
+
+        i_next = xr - u
+        j_next = yr - v
+
+        if abs(i - i_next) < precision and abs(j - j_next) < precision:
+            return i, j
+
+        i = i_next
+        j = j_next
+        # keep the integer parts in sync with the current estimate
+        i_int = int(i)
+        j_int = int(j)
+
+    # if the search does not converge within maxIter, return the last iterate
+    return i_next, j_next
+
+
+@cuda.jit
+def iterSearch_gpu(distortImg, resultImg, padu, padv, W, H, maxIter=5, precision=1e-2):
+    start_x, start_y = cuda.grid(2)
+    stride_x, stride_y = cuda.gridsize(2)
+
+    for xr in range(start_x, W, stride_x):
+        for yr in range(start_y, H, stride_y):
+            i, j = iterSearchShader_gpu(padu, padv, xr, yr, W, H, maxIter, precision)
+
+            # reflect the pixels outside the border
+            if i > W - 1:
+                i = 2 * W - 1 - i
+            if i < 0:
+                i = -i
+            if j > H - 1:
+                j = 2 * H - 1 - j
+            if j < 0:
+                j = -j
+
+            resultImg[yr, xr, 0] = biInterpolation_gpu(distortImg[:, :, 0], i, j)
+            resultImg[yr, xr, 1] = biInterpolation_gpu(distortImg[:, :, 1], i, j)
+            resultImg[yr, xr, 2] = biInterpolation_gpu(distortImg[:, :, 2], i, j)
+    return None
+
+
+def forward_mapping_gpu(source_image, u, v, maxIter=5, precision=1e-2):
+    """
+    Warp the image according to the forward flow.
+    u: horizontal component
+    v: vertical component
+    """
+    H = source_image.shape[0]
+    W = source_image.shape[1]
+
+    resultImg = np.zeros((H, W, 3), dtype=np.uint8)
+
+    # pad the image and the flow by replicating the last row/column
+    distortImg = np.zeros((H + 1, W + 1, 3), dtype=np.uint8)
+    distortImg[0:H, 0:W] = source_image[0:H, 0:W]
+    distortImg[H, 0:W] = source_image[H - 1, 0:W]
+    distortImg[0:H, W] = source_image[0:H, W - 1]
+    distortImg[H, W] = source_image[H - 1, W - 1]
+
+    padu = np.zeros((H + 1, W + 1), dtype=np.float32)
+    padu[0:H, 0:W] = u[0:H, 0:W]
+    padu[H, 0:W] = u[H - 1, 0:W]
+    padu[0:H, W] = u[0:H, W - 1]
+    padu[H, W] = u[H - 1, W - 1]
+
+    padv = np.zeros((H + 1, W + 1), dtype=np.float32)
+    padv[0:H, 0:W] = v[0:H, 0:W]
+    padv[H, 0:W] = v[H - 1, 0:W]
+    padv[0:H, W] = v[0:H, W - 1]
+    padv[H, W] = v[H - 1, W - 1]
+
+    padu = cuda.to_device(padu)
+    padv = cuda.to_device(padv)
+    distortImg = cuda.to_device(distortImg)
+    resultImg = cuda.to_device(resultImg)
+
+    threadsperblock = (16, 16)
+    blockspergrid_x = math.ceil(W / threadsperblock[0])
+    blockspergrid_y = math.ceil(H / threadsperblock[1])
+    blockspergrid = (blockspergrid_x, blockspergrid_y)
+
+    iterSearch_gpu[blockspergrid, threadsperblock](distortImg, resultImg, padu, padv, W, H, maxIter, precision)
+    resultImg = resultImg.copy_to_host()
+    return resultImg
+
+
+class Distortion_with_flow_gpu(object):
+    """Elastic distortion driven by a forward flow (GPU version)."""
+
+    def __init__(self, maxIter=3, precision=1e-3):
+        self.maxIter = maxIter
+        self.precision = precision
+
+    def __call__(self, inputs, dx, dy):
+        # same argument convention as Distortion_with_flow_cpu
+        inputs = np.array(inputs)
+        remap_image = forward_mapping_gpu(inputs, dy, dx, maxIter=self.maxIter, precision=self.precision)
+        return Image.fromarray(remap_image)
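+
+
+# The GPU path mirrors the CPU path; each CUDA thread handles a strided
+# subset of output pixels. Illustrative use (requires a CUDA-capable GPU
+# and a working numba.cuda install):
+#
+#   img = Image.open('frame.png')  # hypothetical input
+#   dx = np.zeros((img.height, img.width), dtype=np.float32)
+#   dy = np.zeros((img.height, img.width), dtype=np.float32)
+#   warped = Distortion_with_flow_gpu()(img, dx, dy)  # zero flow: identity warp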
Invalid .flo file") + elif (123.25==magic): + w = np.fromfile(f, np.int32, count=1)[0] + h = np.fromfile(f, np.int32, count=1)[0] + # print("Reading %d x %d flo file" % (h, w)) + data2d = np.fromfile(f, np.float16, count=2 * w * h) + # reshape data into 3D array (columns, rows, channels) + data2d = np.resize(data2d, (h, w, 2)) + elif (202021.25 == magic): + w = np.fromfile(f, np.int32, count=1)[0] + h = np.fromfile(f, np.int32, count=1)[0] + # print("Reading %d x %d flo file" % (h, w)) + data2d = np.fromfile(f, np.float32, count=2 * w * h) + # reshape data into 3D array (columns, rows, channels) + data2d = np.resize(data2d, (h, w, 2)) + f.close() + return data2d.astype(np.float32) + +class LossHandler: + def __init__(self): + self.loss_dict = {} + self.count_sample = 0 + + def add_loss(self, key, loss): + if key not in self.loss_dict: + self.loss_dict[key] = 0 + self.loss_dict[key] += loss + + def get_loss(self, key): + return self.loss_dict[key] / self.count_sample + + def count_one_sample(self): + self.count_sample += 1 + + def reset(self): + self.loss_dict = {} + self.count_sample = 0 + + +class TimeHandler: + def __init__(self): + self.time_handler = {} + + def compute_time(self, key): + if key not in self.time_handler: + self.time_handler[key] = time.time() + return None + else: + return time.time() - self.time_handler.pop(key) + + +def print_num_params(model, is_trainable=False): + model_name = model.__class__.__name__.ljust(30) + + if is_trainable: + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"| TRAINABLE | {model_name} | {('{:,}'.format(num_params)).rjust(10)} |") + else: + num_params = sum(p.numel() for p in model.parameters()) + print(f"| GENERAL | {model_name} | {('{:,}'.format(num_params)).rjust(10)} |") + + return num_params diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/convert_folder_to_video.py b/utils/convert_folder_to_video.py new file mode 100644 index 0000000000000000000000000000000000000000..7b0e79c2197e56028de5079cc0f59775a67dd65f --- /dev/null +++ b/utils/convert_folder_to_video.py @@ -0,0 +1,29 @@ +import cv2 +import os +import argparse +from tqdm import tqdm + +def convert_frames_to_video(input_folder_path, output_path, fps=24): + list_frames = sorted(os.listdir(input_folder_path)) + first_frame = cv2.imread(os.path.join(input_folder_path, list_frames[0])) + height, width, _ = first_frame.shape + + # Create a VideoWriter object + fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Use appropriate codec based on the file extension + video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) + + for frame_file in tqdm(list_frames): + frame_path = os.path.join(input_folder_path, frame_file) + frame = cv2.imread(frame_path) + video_writer.write(frame) + + video_writer.release() + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--input_folder_path', type=str, required=True) + parser.add_argument('--output_path', type=str, required=True) + parser.add_argument('--fps', type=int, default=24) + args = parser.parse_args() + + convert_frames_to_video(args.input_folder_path, args.output_path, args.fps) \ No newline at end of file