diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..53aee21b9a286bf9d1904e9527c3f0b047f4f0ea 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..8b3d8a077a8756a19b61019756a549494972cac0
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,138 @@
+flagged/
+sample_output/
+wandb/
+.vscode
+.DS_Store
+*ckpt*/
+# Custom
+*.pt
+data/local
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
diff --git a/README.md b/README.md
index c908a08c25f0cbf1578787a063aa2365ce0e1167..9a1ce52391d49d316e66b141754ddbf8f9bc2e06 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
---
-title: SwinTExCo Cpu
+title: Exemplar-based Video Colorization using Vision Transformer (CPU version)
emoji: 🏃
colorFrom: green
colorTo: yellow
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..effa29f423f1ba41ba153d380e06277753c1990b
--- /dev/null
+++ b/app.py
@@ -0,0 +1,50 @@
+import gradio as gr
+from src.inference import SwinTExCo
+import cv2
+import os
+from PIL import Image
+import time
+import app_config as cfg
+
+
+model = SwinTExCo(weights_path=cfg.ckpt_path)
+
+def video_colorization(video_path, ref_image, progress=gr.Progress()):
+ # Initialize video reader
+ video_reader = cv2.VideoCapture(video_path)
+ fps = video_reader.get(cv2.CAP_PROP_FPS)
+ height = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ width = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))
+ num_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))
+
+ # Initialize reference image
+ ref_image = Image.fromarray(ref_image)
+
+ # Initialize video writer
+ output_path = os.path.join(os.path.dirname(video_path), os.path.basename(video_path).split('.')[0] + '_colorized.mp4')
+ video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
+
+ # Init progress bar
+
+ for colorized_frame, _ in zip(model.predict_video(video_reader, ref_image), progress.tqdm(range(num_frames), desc="Colorizing video", unit="frames")):
+ colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_RGB2BGR)
+ video_writer.write(colorized_frame)
+
+ # for i in progress.tqdm(range(1000)):
+ # time.sleep(0.5)
+
+ video_writer.release()
+
+ return output_path
+
+app = gr.Interface(
+ fn=video_colorization,
+ inputs=[gr.Video(format="mp4", sources="upload", label="Input video (grayscale)", interactive=True),
+ gr.Image(sources="upload", label="Reference image (color)")],
+ outputs=gr.Video(label="Output video (colorized)"),
+ title=cfg.TITLE,
+ description=cfg.DESCRIPTION
+).queue()
+
+
+app.launch()
\ No newline at end of file
diff --git a/app_config.py b/app_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2727ead5fdcf6eb4d135a55504089ceeb2a93a9
--- /dev/null
+++ b/app_config.py
@@ -0,0 +1,9 @@
+ckpt_path = 'checkpoints/epoch_20'
+TITLE = 'Deep Exemplar-based Video Colorization using Vision Transformer'
+DESCRIPTION = '''
+
+This is a demo app of the thesis: Deep Exemplar-based Video Colorization using Vision Transformer.
+The code is available at: The link will be updated soon.
+Our previous work was also written into paper and accepted at the ICTC 2023 conference (Section B1-4).
+
+'''.strip()
\ No newline at end of file
diff --git a/checkpoints/epoch_10/colornet.pth b/checkpoints/epoch_10/colornet.pth
new file mode 100644
index 0000000000000000000000000000000000000000..f3a8d9c2b872f34bd7080d01dfbf6d4d22bbe3ab
--- /dev/null
+++ b/checkpoints/epoch_10/colornet.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1ecb43b5e02b77bec5342e2e296d336bf8f384a07d3c809d1a548fd5fb1e7365
+size 131239411
diff --git a/checkpoints/epoch_10/discriminator.pth b/checkpoints/epoch_10/discriminator.pth
new file mode 100644
index 0000000000000000000000000000000000000000..9e14193d0dcd96bc08ca72cf330320a72d3cfdd5
--- /dev/null
+++ b/checkpoints/epoch_10/discriminator.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ce8968a9d3d2f99b1bc1e32080507e0d671cee00b66200105c8839be684b84b4
+size 45073068
diff --git a/checkpoints/epoch_10/embed_net.pth b/checkpoints/epoch_10/embed_net.pth
new file mode 100644
index 0000000000000000000000000000000000000000..0439349777f69682d5b01e03b96659ad64c817c9
--- /dev/null
+++ b/checkpoints/epoch_10/embed_net.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc711755a75c43025dabe9407cbd11d164eaa9e21f26430d0c16c7493410d902
+size 110352261
diff --git a/checkpoints/epoch_10/learning_state.pth b/checkpoints/epoch_10/learning_state.pth
new file mode 100644
index 0000000000000000000000000000000000000000..81ab6499ba83f0d580c2668e4bf2915a1a1f1ff2
--- /dev/null
+++ b/checkpoints/epoch_10/learning_state.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8d09b1e96fdf0205930a21928449a44c51cedd965cc0d573068c73971bcb8bd2
+size 748166487
diff --git a/checkpoints/epoch_10/nonlocal_net.pth b/checkpoints/epoch_10/nonlocal_net.pth
new file mode 100644
index 0000000000000000000000000000000000000000..389e42c55fcd2e6a853e1f3a288f23fd3d653a8d
--- /dev/null
+++ b/checkpoints/epoch_10/nonlocal_net.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:86c97d6803d625a0dff8c6c09b70852371906eb5ef77df0277c27875666a68e2
+size 73189765
diff --git a/checkpoints/epoch_12/colornet.pth b/checkpoints/epoch_12/colornet.pth
new file mode 100644
index 0000000000000000000000000000000000000000..d56a75ae928a1e69d125437d714b997a3736b8e7
--- /dev/null
+++ b/checkpoints/epoch_12/colornet.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50f4b92cd59f4c88c0c1d7c93652413d54b1b96d729fc4b93e235887b5164f28
+size 131239846
diff --git a/checkpoints/epoch_12/discriminator.pth b/checkpoints/epoch_12/discriminator.pth
new file mode 100644
index 0000000000000000000000000000000000000000..80f3701e5da8d9ba91e1e59e91c6fc42c228245a
--- /dev/null
+++ b/checkpoints/epoch_12/discriminator.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b54b0bad6ceec33569cc5833cbf03ed8ddbb5f07998aa634badf8298d3cd15f
+size 45073513
diff --git a/checkpoints/epoch_12/embed_net.pth b/checkpoints/epoch_12/embed_net.pth
new file mode 100644
index 0000000000000000000000000000000000000000..a7c02eee1819cb220e835fa88db8c0265bb783f8
--- /dev/null
+++ b/checkpoints/epoch_12/embed_net.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73e2a156c0737e3d063af0e95e1e7176362e85120b88275a1aa02dcf488e1865
+size 110352698
diff --git a/checkpoints/epoch_12/learning_state.pth b/checkpoints/epoch_12/learning_state.pth
new file mode 100644
index 0000000000000000000000000000000000000000..b2d2a3012bb8c67cb8d7ad3630316b78b0a42341
--- /dev/null
+++ b/checkpoints/epoch_12/learning_state.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8f8bb4dbb3cb8e497a9a2079947f0221823fa8b44695e2d2ad8478be48464fad
+size 748166934
diff --git a/checkpoints/epoch_12/nonlocal_net.pth b/checkpoints/epoch_12/nonlocal_net.pth
new file mode 100644
index 0000000000000000000000000000000000000000..bf47c4949ef6096e234d8582179d65881d97e9f2
--- /dev/null
+++ b/checkpoints/epoch_12/nonlocal_net.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c1f76b53dad7bf15c7d26aa106c95387e75751b8c31fafef2bd73ea7d77160cb
+size 73190208
diff --git a/checkpoints/epoch_16/colornet.pth b/checkpoints/epoch_16/colornet.pth
new file mode 100644
index 0000000000000000000000000000000000000000..8ad3e949edc5ef3b1b9fbe8242c1be89d8a69398
--- /dev/null
+++ b/checkpoints/epoch_16/colornet.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:81ec9cff0ad5b0d920179fa7a9cc229e1424bfc796b7134604ff66b97d748c49
+size 131239846
diff --git a/checkpoints/epoch_16/discriminator.pth b/checkpoints/epoch_16/discriminator.pth
new file mode 100644
index 0000000000000000000000000000000000000000..0d417db3092c3558beb444bc8e1f34784d207c2a
--- /dev/null
+++ b/checkpoints/epoch_16/discriminator.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:42262d5ed7596f38e65774085222530eee57da8dfaa7fe1aa223d824ed166f62
+size 45073513
diff --git a/checkpoints/epoch_16/embed_net.pth b/checkpoints/epoch_16/embed_net.pth
new file mode 100644
index 0000000000000000000000000000000000000000..a7c02eee1819cb220e835fa88db8c0265bb783f8
--- /dev/null
+++ b/checkpoints/epoch_16/embed_net.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73e2a156c0737e3d063af0e95e1e7176362e85120b88275a1aa02dcf488e1865
+size 110352698
diff --git a/checkpoints/epoch_16/learning_state.pth b/checkpoints/epoch_16/learning_state.pth
new file mode 100644
index 0000000000000000000000000000000000000000..b0c7771adc7c37e9c5e873fde1ab4dfa881e9d6e
--- /dev/null
+++ b/checkpoints/epoch_16/learning_state.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ea4cf81341750ebf517c696a0f6241bfeede0584b0ce75ad208e3ffc8280877f
+size 748166934
diff --git a/checkpoints/epoch_16/nonlocal_net.pth b/checkpoints/epoch_16/nonlocal_net.pth
new file mode 100644
index 0000000000000000000000000000000000000000..6e7933d9f6e20f85894a67ca281028007e47e035
--- /dev/null
+++ b/checkpoints/epoch_16/nonlocal_net.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:85b63363bc9c79732df78ba50ed19491ed86e961214bbd1f796a871334eba516
+size 73190208
diff --git a/checkpoints/epoch_20/colornet.pth b/checkpoints/epoch_20/colornet.pth
new file mode 100644
index 0000000000000000000000000000000000000000..283002997e001fbe870daa4c868165fab041fae0
--- /dev/null
+++ b/checkpoints/epoch_20/colornet.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c524f4e5df5f6ce91db1973a30de55299ebcbbde1edd2009718d3b4cd2631339
+size 131239846
diff --git a/checkpoints/epoch_20/discriminator.pth b/checkpoints/epoch_20/discriminator.pth
new file mode 100644
index 0000000000000000000000000000000000000000..d9bdc788c644fe5a20280acea1f20ebf9fe55014
--- /dev/null
+++ b/checkpoints/epoch_20/discriminator.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fcd80950c796fcfe6e4b6bdeeb358776700458d868da94ee31df3d1d37779310
+size 45073513
diff --git a/checkpoints/epoch_20/embed_net.pth b/checkpoints/epoch_20/embed_net.pth
new file mode 100644
index 0000000000000000000000000000000000000000..a7c02eee1819cb220e835fa88db8c0265bb783f8
--- /dev/null
+++ b/checkpoints/epoch_20/embed_net.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73e2a156c0737e3d063af0e95e1e7176362e85120b88275a1aa02dcf488e1865
+size 110352698
diff --git a/checkpoints/epoch_20/learning_state.pth b/checkpoints/epoch_20/learning_state.pth
new file mode 100644
index 0000000000000000000000000000000000000000..f2efcf054b44b72c12ce845a30f2c37bdfd50535
--- /dev/null
+++ b/checkpoints/epoch_20/learning_state.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b1163b210b246b07d8f1c50eb3766d97c6f03bf409c854d00b7c69edb6d7391
+size 748166934
diff --git a/checkpoints/epoch_20/nonlocal_net.pth b/checkpoints/epoch_20/nonlocal_net.pth
new file mode 100644
index 0000000000000000000000000000000000000000..f68eb2413624faabd93ff65986e7c13be4da1485
--- /dev/null
+++ b/checkpoints/epoch_20/nonlocal_net.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:031e5f38cc79eb3c0ed51ca2ad3c8921fdda2fa05946c357f84881259de74e6d
+size 73190208
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..025946634ce9270d63ec6dd5e0443f78a768ce54
Binary files /dev/null and b/requirements.txt differ
diff --git a/sample_input/ref1.jpg b/sample_input/ref1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d2ed0a6f318d7b1a4716ef64422575c5d2078afe
Binary files /dev/null and b/sample_input/ref1.jpg differ
diff --git a/sample_input/video1.mp4 b/sample_input/video1.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..e96476e110511c201d972fc780d39a833abb92f1
--- /dev/null
+++ b/sample_input/video1.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:077ebcd3cf6c020c95732e74a0fe1fab9b80102bc14d5e201b12c4917e0c0d1d
+size 1011726
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/data/dataloader.py b/src/data/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..16fb6d288c37b1dd450f6d00a60909e26b75c3bb
--- /dev/null
+++ b/src/data/dataloader.py
@@ -0,0 +1,332 @@
+import numpy as np
+import pandas as pd
+from src.utils import (
+    CenterPadCrop_numpy,
+    Distortion_with_flow_cpu,
+    Distortion_with_flow_gpu,
+    Normalize,
+    RGB2Lab,
+    ToTensor,
+    CenterPad,
+    read_flow,
+    SquaredPadding
+)
+import torch
+import torch.utils.data as data
+import torchvision.transforms as transforms
+from numpy import random
+import os
+from PIL import Image
+from scipy.ndimage import gaussian_filter
+from scipy.ndimage import map_coordinates
+import glob
+
+
+
+def image_loader(path):
+ with open(path, "rb") as f:
+ with Image.open(f) as img:
+ return img.convert("RGB")
+
+
+class CenterCrop(object):
+ """
+ center crop the numpy array
+ """
+
+ def __init__(self, image_size):
+ self.h0, self.w0 = image_size
+
+ def __call__(self, input_numpy):
+ if input_numpy.ndim == 3:
+ h, w, channel = input_numpy.shape
+ output_numpy = np.zeros((self.h0, self.w0, channel))
+ output_numpy = input_numpy[
+ (h - self.h0) // 2 : (h - self.h0) // 2 + self.h0, (w - self.w0) // 2 : (w - self.w0) // 2 + self.w0, :
+ ]
+ else:
+ h, w = input_numpy.shape
+ output_numpy = np.zeros((self.h0, self.w0))
+ output_numpy = input_numpy[
+ (h - self.h0) // 2 : (h - self.h0) // 2 + self.h0, (w - self.w0) // 2 : (w - self.w0) // 2 + self.w0
+ ]
+ return output_numpy
+
+
+class VideosDataset(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ video_data_root,
+ flow_data_root,
+ mask_data_root,
+ imagenet_folder,
+ annotation_file_path,
+ image_size,
+ num_refs=5, # max = 20
+ image_transform=None,
+ real_reference_probability=1,
+ nonzero_placeholder_probability=0.5,
+ ):
+ self.video_data_root = video_data_root
+ self.flow_data_root = flow_data_root
+ self.mask_data_root = mask_data_root
+ self.imagenet_folder = imagenet_folder
+ self.image_transform = image_transform
+ self.CenterPad = CenterPad(image_size)
+ self.Resize = transforms.Resize(image_size)
+ self.ToTensor = ToTensor()
+ self.CenterCrop = transforms.CenterCrop(image_size)
+ self.SquaredPadding = SquaredPadding(image_size[0])
+ self.num_refs = num_refs
+
+ assert os.path.exists(self.video_data_root), "find no video dataroot"
+ assert os.path.exists(self.flow_data_root), "find no flow dataroot"
+ assert os.path.exists(self.imagenet_folder), "find no imagenet folder"
+ # self.epoch = epoch
+ self.image_pairs = pd.read_csv(annotation_file_path, dtype=str)
+ self.real_len = len(self.image_pairs)
+ # self.image_pairs = pd.concat([self.image_pairs] * self.epoch, ignore_index=True)
+ self.real_reference_probability = real_reference_probability
+ self.nonzero_placeholder_probability = nonzero_placeholder_probability
+ print("##### parsing image pairs in %s: %d pairs #####" % (video_data_root, self.__len__()))
+
+ def __getitem__(self, index):
+ (
+ video_name,
+ prev_frame,
+ current_frame,
+ flow_forward_name,
+ mask_name,
+ reference_1_name,
+ reference_2_name,
+ reference_3_name,
+ reference_4_name,
+ reference_5_name
+ ) = self.image_pairs.iloc[index, :5+self.num_refs].values.tolist()
+
+ video_path = os.path.join(self.video_data_root, video_name)
+ flow_path = os.path.join(self.flow_data_root, video_name)
+ mask_path = os.path.join(self.mask_data_root, video_name)
+
+ prev_frame_path = os.path.join(video_path, prev_frame)
+ current_frame_path = os.path.join(video_path, current_frame)
+ list_frame_path = glob.glob(os.path.join(video_path, '*'))
+ list_frame_path.sort()
+
+ reference_1_path = os.path.join(self.imagenet_folder, reference_1_name)
+ reference_2_path = os.path.join(self.imagenet_folder, reference_2_name)
+ reference_3_path = os.path.join(self.imagenet_folder, reference_3_name)
+ reference_4_path = os.path.join(self.imagenet_folder, reference_4_name)
+ reference_5_path = os.path.join(self.imagenet_folder, reference_5_name)
+
+ flow_forward_path = os.path.join(flow_path, flow_forward_name)
+ mask_path = os.path.join(mask_path, mask_name)
+
+ #reference_gt_1_path = prev_frame_path
+ #reference_gt_2_path = current_frame_path
+ try:
+ I1 = Image.open(prev_frame_path).convert("RGB")
+ I2 = Image.open(current_frame_path).convert("RGB")
+ try:
+ I_reference_video = Image.open(list_frame_path[0]).convert("RGB") # Get first frame
+            except Exception:
+ I_reference_video = Image.open(current_frame_path).convert("RGB") # Get current frame if error
+
+ reference_list = [reference_1_path, reference_2_path, reference_3_path, reference_4_path, reference_5_path]
+ while reference_list: # run until getting the colorized reference
+ reference_path = random.choice(reference_list)
+ I_reference_video_real = Image.open(reference_path)
+ if I_reference_video_real.mode == 'L':
+ reference_list.remove(reference_path)
+ else:
+ break
+ if not reference_list:
+ I_reference_video_real = I_reference_video
+
+ flow_forward = read_flow(flow_forward_path) # numpy
+
+ mask = Image.open(mask_path) # PIL
+ mask = self.Resize(mask)
+ mask = np.array(mask)
+ # mask = self.SquaredPadding(mask, return_pil=False, return_paddings=False)
+ # binary mask
+ mask[mask < 240] = 0
+ mask[mask >= 240] = 1
+ mask = self.ToTensor(mask)
+
+ # transform
+ I1 = self.image_transform(I1)
+ I2 = self.image_transform(I2)
+ I_reference_video = self.image_transform(I_reference_video)
+ I_reference_video_real = self.image_transform(I_reference_video_real)
+ flow_forward = self.ToTensor(flow_forward)
+ flow_forward = self.Resize(flow_forward)#, return_pil=False, return_paddings=False, dtype=np.float32)
+
+
+ if np.random.random() < self.real_reference_probability:
+ I_reference_output = I_reference_video_real # Use reference from imagenet
+ placeholder = torch.zeros_like(I1)
+ self_ref_flag = torch.zeros_like(I1)
+ else:
+ I_reference_output = I_reference_video # Use reference from ground truth
+ placeholder = I2 if np.random.random() < self.nonzero_placeholder_probability else torch.zeros_like(I1)
+ self_ref_flag = torch.ones_like(I1)
+
+ outputs = [
+ I1,
+ I2,
+ I_reference_output,
+ flow_forward,
+ mask,
+ placeholder,
+ self_ref_flag,
+ video_name + prev_frame,
+ video_name + current_frame,
+ reference_path
+ ]
+
+ except Exception as e:
+            print("error in reading image pair: %s" % str(self.image_pairs.iloc[index].tolist()))
+ print(e)
+ return self.__getitem__(np.random.randint(0, len(self.image_pairs)))
+ return outputs
+
+ def __len__(self):
+ return len(self.image_pairs)
+
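+# Illustrative usage sketch (not part of the training pipeline). All paths below are
+# placeholders, and the transform chain simply mirrors the RGB2Lab -> ToTensor -> Normalize
+# utilities imported above; adapt both to the actual dataset layout.
+#
+#   dataset = VideosDataset(
+#       video_data_root="data/local/videos",
+#       flow_data_root="data/local/flows",
+#       mask_data_root="data/local/masks",
+#       imagenet_folder="data/local/imagenet",
+#       annotation_file_path="data/local/pairs.csv",
+#       image_size=(224, 224),
+#       image_transform=transforms.Compose([CenterPad((224, 224)), RGB2Lab(), ToTensor(), Normalize()]),
+#   )
+#   loader = data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)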
+
+def parse_imgnet_images(pairs_file):
+ pairs = []
+ with open(pairs_file, "r") as f:
+ lines = f.readlines()
+ for line in lines:
+ line = line.strip().split("|")
+ image_a = line[0]
+ image_b = line[1]
+ pairs.append((image_a, image_b))
+ return pairs
+
+
+class VideosDataset_ImageNet(data.Dataset):
+ def __init__(
+ self,
+ imagenet_data_root,
+ pairs_file,
+ image_size,
+ transforms_imagenet=None,
+ distortion_level=3,
+ brightnessjitter=0,
+ nonzero_placeholder_probability=0.5,
+ extra_reference_transform=None,
+ real_reference_probability=1,
+ distortion_device='cpu'
+ ):
+ self.imagenet_data_root = imagenet_data_root
+ self.image_pairs = pd.read_csv(pairs_file, names=['i1', 'i2'])
+ self.transforms_imagenet_raw = transforms_imagenet
+ self.extra_reference_transform = transforms.Compose(extra_reference_transform)
+ self.real_reference_probability = real_reference_probability
+ self.transforms_imagenet = transforms.Compose(transforms_imagenet)
+ self.image_size = image_size
+ self.real_len = len(self.image_pairs)
+ self.distortion_level = distortion_level
+ self.distortion_transform = Distortion_with_flow_cpu() if distortion_device == 'cpu' else Distortion_with_flow_gpu()
+ self.brightnessjitter = brightnessjitter
+ self.flow_transform = transforms.Compose([CenterPadCrop_numpy(self.image_size), ToTensor()])
+ self.nonzero_placeholder_probability = nonzero_placeholder_probability
+ self.ToTensor = ToTensor()
+ self.Normalize = Normalize()
+ print("##### parsing imageNet pairs in %s: %d pairs #####" % (imagenet_data_root, self.__len__()))
+
+ def __getitem__(self, index):
+ pa, pb = self.image_pairs.iloc[index].values.tolist()
+ if np.random.random() > 0.5:
+ pa, pb = pb, pa
+
+ image_a_path = os.path.join(self.imagenet_data_root, pa)
+ image_b_path = os.path.join(self.imagenet_data_root, pb)
+
+ I1 = image_loader(image_a_path)
+ I2 = I1
+ I_reference_video = I1
+ I_reference_video_real = image_loader(image_b_path)
+ # print("i'm here get image 2")
+ # generate the flow
+ alpha = np.random.rand() * self.distortion_level
+ distortion_range = 50
+ random_state = np.random.RandomState(None)
+ shape = self.image_size[0], self.image_size[1]
+ # dx: flow on the vertical direction; dy: flow on the horizontal direction
+ forward_dx = (
+ gaussian_filter((random_state.rand(*shape) * 2 - 1), distortion_range, mode="constant", cval=0) * alpha * 1000
+ )
+ forward_dy = (
+ gaussian_filter((random_state.rand(*shape) * 2 - 1), distortion_range, mode="constant", cval=0) * alpha * 1000
+ )
+ # print("i'm here get image 3")
+ for transform in self.transforms_imagenet_raw:
+ if type(transform) is RGB2Lab:
+ I1_raw = I1
+ I1 = transform(I1)
+ for transform in self.transforms_imagenet_raw:
+ if type(transform) is RGB2Lab:
+ I2 = self.distortion_transform(I2, forward_dx, forward_dy)
+ I2_raw = I2
+ I2 = transform(I2)
+ # print("i'm here get image 4")
+ I2[0:1, :, :] = I2[0:1, :, :] + torch.randn(1) * self.brightnessjitter
+
+ I_reference_video = self.extra_reference_transform(I_reference_video)
+ for transform in self.transforms_imagenet_raw:
+ I_reference_video = transform(I_reference_video)
+
+ I_reference_video_real = self.transforms_imagenet(I_reference_video_real)
+ # print("i'm here get image 5")
+ flow_forward_raw = np.stack((forward_dy, forward_dx), axis=-1)
+ flow_forward = self.flow_transform(flow_forward_raw)
+
+ # update the mask for the pixels on the border
+ grid_x, grid_y = np.meshgrid(np.arange(self.image_size[0]), np.arange(self.image_size[1]), indexing="ij")
+ grid = np.stack((grid_y, grid_x), axis=-1)
+ grid_warp = grid + flow_forward_raw
+ location_y = grid_warp[:, :, 0].flatten()
+ location_x = grid_warp[:, :, 1].flatten()
+ I2_raw = np.array(I2_raw).astype(float)
+ I21_r = map_coordinates(I2_raw[:, :, 0], np.stack((location_x, location_y)), cval=-1).reshape(
+ (self.image_size[0], self.image_size[1])
+ )
+ I21_g = map_coordinates(I2_raw[:, :, 1], np.stack((location_x, location_y)), cval=-1).reshape(
+ (self.image_size[0], self.image_size[1])
+ )
+ I21_b = map_coordinates(I2_raw[:, :, 2], np.stack((location_x, location_y)), cval=-1).reshape(
+ (self.image_size[0], self.image_size[1])
+ )
+ I21_raw = np.stack((I21_r, I21_g, I21_b), axis=2)
+ mask = np.ones((self.image_size[0], self.image_size[1]))
+ mask[(I21_raw[:, :, 0] == -1) & (I21_raw[:, :, 1] == -1) & (I21_raw[:, :, 2] == -1)] = 0
+ mask[abs(I21_raw - I1_raw).sum(axis=-1) > 50] = 0
+ mask = self.ToTensor(mask)
+ # print("i'm here get image 6")
+ if np.random.random() < self.real_reference_probability:
+ I_reference_output = I_reference_video_real
+ placeholder = torch.zeros_like(I1)
+ self_ref_flag = torch.zeros_like(I1)
+ else:
+ I_reference_output = I_reference_video
+ placeholder = I2 if np.random.random() < self.nonzero_placeholder_probability else torch.zeros_like(I1)
+ self_ref_flag = torch.ones_like(I1)
+
+ # except Exception as e:
+ # if combo_path is not None:
+ # print("problem in ", combo_path)
+ # print("problem in, ", image_a_path)
+ # print(e)
+ # return self.__getitem__(np.random.randint(0, len(self.image_pairs)))
+ # print("i'm here get image 7")
+ return [I1, I2, I_reference_output, flow_forward, mask, placeholder, self_ref_flag, "holder", pb, pa]
+
+ def __len__(self):
+ return len(self.image_pairs)
\ No newline at end of file
diff --git a/src/data/functional.py b/src/data/functional.py
new file mode 100644
index 0000000000000000000000000000000000000000..14aa7882d3dfca1ba6649d0b7fdb2c443e3b7f20
--- /dev/null
+++ b/src/data/functional.py
@@ -0,0 +1,84 @@
+from __future__ import division
+
+import torch
+import numbers
+import collections.abc
+import numpy as np
+from PIL import Image, ImageOps
+
+
+def _is_pil_image(img):
+ return isinstance(img, Image.Image)
+
+
+def _is_tensor_image(img):
+ return torch.is_tensor(img) and img.ndimension() == 3
+
+
+def _is_numpy_image(img):
+ return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
+
+
+def to_mytensor(pic):
+ pic_arr = np.array(pic)
+ if pic_arr.ndim == 2:
+ pic_arr = pic_arr[..., np.newaxis]
+ img = torch.from_numpy(pic_arr.transpose((2, 0, 1)))
+ if not isinstance(img, torch.FloatTensor):
+ return img.float() # no normalize .div(255)
+ else:
+ return img
+
+
+def normalize(tensor, mean, std):
+ if not _is_tensor_image(tensor):
+ raise TypeError("tensor is not a torch image.")
+ if tensor.size(0) == 1:
+ tensor.sub_(mean).div_(std)
+ else:
+ for t, m, s in zip(tensor, mean, std):
+ t.sub_(m).div_(s)
+ return tensor
+
+
+def resize(img, size, interpolation=Image.BILINEAR):
+ if not _is_pil_image(img):
+ raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+    if not isinstance(size, int) and (not isinstance(size, collections.abc.Iterable) or len(size) != 2):
+ raise TypeError("Got inappropriate size arg: {}".format(size))
+
+ if not isinstance(size, int):
+ return img.resize(size[::-1], interpolation)
+
+ w, h = img.size
+ if (w <= h and w == size) or (h <= w and h == size):
+ return img
+ if w < h:
+ ow = size
+ oh = int(round(size * h / w))
+ else:
+ oh = size
+ ow = int(round(size * w / h))
+ return img.resize((ow, oh), interpolation)
+
+
+def pad(img, padding, fill=0):
+ if not _is_pil_image(img):
+ raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+
+ if not isinstance(padding, (numbers.Number, tuple)):
+ raise TypeError("Got inappropriate padding arg")
+ if not isinstance(fill, (numbers.Number, str, tuple)):
+ raise TypeError("Got inappropriate fill arg")
+
+    if isinstance(padding, collections.abc.Sequence) and len(padding) not in [2, 4]:
+ raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding)))
+
+ return ImageOps.expand(img, border=padding, fill=fill)
+
+
+def crop(img, i, j, h, w):
+ if not _is_pil_image(img):
+ raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+
+ return img.crop((j, i, j + w, i + h))
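+
+
+def hflip(img):
+    # Referenced by transforms.RandomHorizontalFlip but missing from the original module;
+    # a minimal implementation (mirroring torchvision's functional hflip) is sketched here.
+    if not _is_pil_image(img):
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
+
+    return img.transpose(Image.FLIP_LEFT_RIGHT)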
diff --git a/src/data/transforms.py b/src/data/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1ce42b501ed66802df534491a633498af7f824c
--- /dev/null
+++ b/src/data/transforms.py
@@ -0,0 +1,348 @@
+from __future__ import division
+
+import collections.abc
+import numbers
+import random
+
+import torch
+from PIL import Image
+from skimage import color
+
+import src.data.functional as F
+
+__all__ = [
+    "Compose",
+    "Concatenate",
+    "ToTensor",
+    "Normalize",
+    "Resize",
+    "CenterCrop",
+    "RandomCrop",
+    "RandomHorizontalFlip",
+    "RGB2Lab",
+]
+
+
+def CustomFunc(inputs, func, *args, **kwargs):
+ im_l = func(inputs[0], *args, **kwargs)
+ im_ab = func(inputs[1], *args, **kwargs)
+ warp_ba = func(inputs[2], *args, **kwargs)
+ warp_aba = func(inputs[3], *args, **kwargs)
+ im_gbl_ab = func(inputs[4], *args, **kwargs)
+ bgr_mc_im = func(inputs[5], *args, **kwargs)
+
+ layer_data = [im_l, im_ab, warp_ba, warp_aba, im_gbl_ab, bgr_mc_im]
+
+ for l in range(5):
+ layer = inputs[6 + l]
+ err_ba = func(layer[0], *args, **kwargs)
+ err_ab = func(layer[1], *args, **kwargs)
+
+ layer_data.append([err_ba, err_ab])
+
+ return layer_data
+
+
+class Compose(object):
+ """Composes several transforms together.
+
+ Args:
+ transforms (list of ``Transform`` objects): list of transforms to compose.
+
+ Example:
+ >>> transforms.Compose([
+ >>> transforms.CenterCrop(10),
+ >>> transforms.ToTensor(),
+ >>> ])
+ """
+
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, inputs):
+ for t in self.transforms:
+ inputs = t(inputs)
+ return inputs
+
+
+class Concatenate(object):
+ """
+ Input: [im_l, im_ab, inputs]
+ inputs = [warp_ba_l, warp_ba_ab, warp_aba, err_pm, err_aba]
+
+ Output:[im_l, err_pm, warp_ba, warp_aba, im_ab, err_aba]
+ """
+
+ def __call__(self, inputs):
+ im_l = inputs[0]
+ im_ab = inputs[1]
+ warp_ba = inputs[2]
+ warp_aba = inputs[3]
+ im_glb_ab = inputs[4]
+ bgr_mc_im = inputs[5]
+ bgr_mc_im = bgr_mc_im[[2, 1, 0], ...]
+
+ err_ba = []
+ err_ab = []
+
+ for l in range(5):
+ layer = inputs[6 + l]
+ err_ba.append(layer[0])
+ err_ab.append(layer[1])
+
+ cerr_ba = torch.cat(err_ba, 0)
+ cerr_ab = torch.cat(err_ab, 0)
+
+ return (im_l, cerr_ba, warp_ba, warp_aba, im_glb_ab, bgr_mc_im, im_ab, cerr_ab)
+
+
+class ToTensor(object):
+ """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
+
+ Converts a PIL Image or numpy.ndarray (H x W x C) in the range
+ [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
+ """
+
+ def __call__(self, inputs):
+ """
+ Args:
+ pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
+
+ Returns:
+ Tensor: Converted image.
+ """
+ return CustomFunc(inputs, F.to_mytensor)
+
+
+class Normalize(object):
+ """Normalize an tensor image with mean and standard deviation.
+ Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
+ will normalize each channel of the input ``torch.*Tensor`` i.e.
+ ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
+
+ Args:
+ mean (sequence): Sequence of means for each channel.
+ std (sequence): Sequence of standard deviations for each channel.
+ """
+
+ def __call__(self, inputs):
+ """
+ Args:
+ tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
+
+ Returns:
+ Tensor: Normalized Tensor image.
+ """
+
+ im_l = F.normalize(inputs[0], 50, 1) # [0, 100]
+ im_ab = F.normalize(inputs[1], (0, 0), (1, 1)) # [-100, 100]
+
+ inputs[2][0:1, :, :] = F.normalize(inputs[2][0:1, :, :], 50, 1)
+ inputs[2][1:3, :, :] = F.normalize(inputs[2][1:3, :, :], (0, 0), (1, 1))
+ warp_ba = inputs[2]
+
+ inputs[3][0:1, :, :] = F.normalize(inputs[3][0:1, :, :], 50, 1)
+ inputs[3][1:3, :, :] = F.normalize(inputs[3][1:3, :, :], (0, 0), (1, 1))
+ warp_aba = inputs[3]
+
+ im_gbl_ab = F.normalize(inputs[4], (0, 0), (1, 1)) # [-100, 100]
+
+ bgr_mc_im = F.normalize(inputs[5], (123.68, 116.78, 103.938), (1, 1, 1))
+
+ layer_data = [im_l, im_ab, warp_ba, warp_aba, im_gbl_ab, bgr_mc_im]
+
+ for l in range(5):
+ layer = inputs[6 + l]
+ err_ba = F.normalize(layer[0], 127, 2) # [0, 255]
+ err_ab = F.normalize(layer[1], 127, 2) # [0, 255]
+ layer_data.append([err_ba, err_ab])
+
+ return layer_data
+
+
+class Resize(object):
+ """Resize the input PIL Image to the given size.
+
+ Args:
+ size (sequence or int): Desired output size. If size is a sequence like
+ (h, w), output size will be matched to this. If size is an int,
+ smaller edge of the image will be matched to this number.
+ i.e, if height > width, then image will be rescaled to
+ (size * height / width, size)
+ interpolation (int, optional): Desired interpolation. Default is
+ ``PIL.Image.BILINEAR``
+ """
+
+ def __init__(self, size, interpolation=Image.BILINEAR):
+        assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
+ self.size = size
+ self.interpolation = interpolation
+
+ def __call__(self, inputs):
+ """
+ Args:
+ img (PIL Image): Image to be scaled.
+
+ Returns:
+ PIL Image: Rescaled image.
+ """
+ return CustomFunc(inputs, F.resize, self.size, self.interpolation)
+
+
+class RandomCrop(object):
+ """Crop the given PIL Image at a random location.
+
+ Args:
+ size (sequence or int): Desired output size of the crop. If size is an
+ int instead of sequence like (h, w), a square crop (size, size) is
+ made.
+ padding (int or sequence, optional): Optional padding on each border
+ of the image. Default is 0, i.e no padding. If a sequence of length
+ 4 is provided, it is used to pad left, top, right, bottom borders
+ respectively.
+ """
+
+ def __init__(self, size, padding=0):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+ self.padding = padding
+
+ @staticmethod
+ def get_params(img, output_size):
+ """Get parameters for ``crop`` for a random crop.
+
+ Args:
+ img (PIL Image): Image to be cropped.
+ output_size (tuple): Expected output size of the crop.
+
+ Returns:
+ tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
+ """
+ w, h = img.size
+ th, tw = output_size
+ if w == tw and h == th:
+ return 0, 0, h, w
+
+ i = random.randint(0, h - th)
+ j = random.randint(0, w - tw)
+ return i, j, th, tw
+
+ def __call__(self, inputs):
+ """
+ Args:
+ img (PIL Image): Image to be cropped.
+
+ Returns:
+ PIL Image: Cropped image.
+ """
+ if self.padding > 0:
+ inputs = CustomFunc(inputs, F.pad, self.padding)
+
+ i, j, h, w = self.get_params(inputs[0], self.size)
+ return CustomFunc(inputs, F.crop, i, j, h, w)
+
+
+class CenterCrop(object):
+    """Crop the given PIL Image at its center.
+
+ Args:
+ size (sequence or int): Desired output size of the crop. If size is an
+ int instead of sequence like (h, w), a square crop (size, size) is
+ made.
+ padding (int or sequence, optional): Optional padding on each border
+ of the image. Default is 0, i.e no padding. If a sequence of length
+ 4 is provided, it is used to pad left, top, right, bottom borders
+ respectively.
+ """
+
+ def __init__(self, size, padding=0):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+ self.padding = padding
+
+ @staticmethod
+ def get_params(img, output_size):
+ """Get parameters for ``crop`` for a random crop.
+
+ Args:
+ img (PIL Image): Image to be cropped.
+ output_size (tuple): Expected output size of the crop.
+
+ Returns:
+ tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
+ """
+ w, h = img.size
+ th, tw = output_size
+ if w == tw and h == th:
+ return 0, 0, h, w
+
+ i = (h - th) // 2
+ j = (w - tw) // 2
+ return i, j, th, tw
+
+ def __call__(self, inputs):
+ """
+ Args:
+ img (PIL Image): Image to be cropped.
+
+ Returns:
+ PIL Image: Cropped image.
+ """
+ if self.padding > 0:
+ inputs = CustomFunc(inputs, F.pad, self.padding)
+
+ i, j, h, w = self.get_params(inputs[0], self.size)
+ return CustomFunc(inputs, F.crop, i, j, h, w)
+
+
+class RandomHorizontalFlip(object):
+ """Horizontally flip the given PIL Image randomly with a probability of 0.5."""
+
+ def __call__(self, inputs):
+ """
+ Args:
+ img (PIL Image): Image to be flipped.
+
+ Returns:
+ PIL Image: Randomly flipped image.
+ """
+
+ if random.random() < 0.5:
+ return CustomFunc(inputs, F.hflip)
+ return inputs
+
+
+class RGB2Lab(object):
+    def __call__(self, inputs):
+        """
+        Convert the RGB entries of ``inputs`` to the Lab color space.
+
+        Args:
+            inputs (list): sequence whose entries 0, 2, 3 and 4 are RGB images.
+
+        Returns:
+            list: the same sequence with entry 0 reduced to the L channel, entry 1
+            set to the ab channels of entry 0, and entries 2-4 converted to Lab
+            (entry 4 keeping only the ab channels).
+        """
+        image_lab = color.rgb2lab(inputs[0])
+        warp_ba_lab = color.rgb2lab(inputs[2])
+        warp_aba_lab = color.rgb2lab(inputs[3])
+        im_gbl_lab = color.rgb2lab(inputs[4])
+
+        inputs[0] = image_lab[:, :, :1]  # l channel
+        inputs[1] = image_lab[:, :, 1:]  # ab channel
+        inputs[2] = warp_ba_lab  # lab channel
+        inputs[3] = warp_aba_lab  # lab channel
+        inputs[4] = im_gbl_lab[:, :, 1:]  # ab channel
+
+        return inputs
diff --git a/src/inference.py b/src/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..2306f659e4ab5684058c8ccfe7d560de97e84815
--- /dev/null
+++ b/src/inference.py
@@ -0,0 +1,174 @@
+from src.models.CNN.ColorVidNet import ColorVidNet
+from src.models.vit.embed import SwinModel
+from src.models.CNN.NonlocalNet import WarpNet
+from src.models.CNN.FrameColor import frame_colorization
+import torch
+from src.models.vit.utils import load_params
+import os
+import cv2
+from PIL import Image
+from PIL import ImageEnhance as IE
+import torchvision.transforms as T
+from src.utils import (
+ RGB2Lab,
+ ToTensor,
+ Normalize,
+ uncenter_l,
+ tensor_lab2rgb
+)
+import numpy as np
+from tqdm import tqdm
+
+class SwinTExCo:
+ def __init__(self, weights_path, swin_backbone='swinv2-cr-t-224', device=None):
+        if device is None:
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ else:
+ self.device = device
+
+ self.embed_net = SwinModel(pretrained_model=swin_backbone, device=self.device).to(self.device)
+ self.nonlocal_net = WarpNet(feature_channel=128).to(self.device)
+ self.colornet = ColorVidNet(7).to(self.device)
+
+ self.embed_net.eval()
+ self.nonlocal_net.eval()
+ self.colornet.eval()
+
+ self.__load_models(self.embed_net, os.path.join(weights_path, "embed_net.pth"))
+ self.__load_models(self.nonlocal_net, os.path.join(weights_path, "nonlocal_net.pth"))
+ self.__load_models(self.colornet, os.path.join(weights_path, "colornet.pth"))
+
+ self.processor = T.Compose([
+ T.Resize((224,224)),
+ RGB2Lab(),
+ ToTensor(),
+ Normalize()
+ ])
+
+ def __load_models(self, model, weight_path):
+ params = load_params(weight_path, self.device)
+ model.load_state_dict(params, strict=True)
+
+ def __preprocess_reference(self, img):
+ color_enhancer = IE.Color(img)
+ img = color_enhancer.enhance(1.5)
+ return img
+
+ def __upscale_image(self, large_IA_l, I_current_ab_predict):
+ H, W = large_IA_l.shape[2:]
+ large_current_ab_predict = torch.nn.functional.interpolate(I_current_ab_predict,
+ size=(H,W),
+ mode="bilinear",
+ align_corners=False)
+ large_IA_l = torch.cat((large_IA_l, large_current_ab_predict.cpu()), dim=1)
+ large_current_rgb_predict = tensor_lab2rgb(large_IA_l)
+ return large_current_rgb_predict
+
+ def __proccess_sample(self, curr_frame, I_last_lab_predict, I_reference_lab, features_B):
+ large_IA_lab = ToTensor()(RGB2Lab()(curr_frame)).unsqueeze(0)
+ large_IA_l = large_IA_lab[:, 0:1, :, :]
+
+ IA_lab = self.processor(curr_frame)
+ IA_lab = IA_lab.unsqueeze(0).to(self.device)
+ IA_l = IA_lab[:, 0:1, :, :]
+ if I_last_lab_predict is None:
+ I_last_lab_predict = torch.zeros_like(IA_lab).to(self.device)
+
+
+ with torch.no_grad():
+ I_current_ab_predict, _ = frame_colorization(
+ IA_l,
+ I_reference_lab,
+ I_last_lab_predict,
+ features_B,
+ self.embed_net,
+ self.nonlocal_net,
+ self.colornet,
+ luminance_noise=0,
+ temperature=1e-10,
+ joint_training=False
+ )
+ I_last_lab_predict = torch.cat((IA_l, I_current_ab_predict), dim=1)
+
+ IA_predict_rgb = self.__upscale_image(large_IA_l, I_current_ab_predict)
+ IA_predict_rgb = (IA_predict_rgb.squeeze(0).cpu().numpy() * 255.)
+ IA_predict_rgb = np.clip(IA_predict_rgb, 0, 255).astype(np.uint8)
+
+ return I_last_lab_predict, IA_predict_rgb
+
+ def predict_video(self, video, ref_image):
+ ref_image = self.__preprocess_reference(ref_image)
+
+ I_last_lab_predict = None
+
+ IB_lab = self.processor(ref_image)
+ IB_lab = IB_lab.unsqueeze(0).to(self.device)
+
+ with torch.no_grad():
+ I_reference_lab = IB_lab
+ I_reference_l = I_reference_lab[:, 0:1, :, :]
+ I_reference_ab = I_reference_lab[:, 1:3, :, :]
+ I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(self.device)
+ features_B = self.embed_net(I_reference_rgb)
+
+ #PBAR = tqdm(total=int(video.get(cv2.CAP_PROP_FRAME_COUNT)), desc="Colorizing video", unit="frame")
+ while video.isOpened():
+ #PBAR.update(1)
+ ret, curr_frame = video.read()
+
+ if not ret:
+ break
+
+ curr_frame = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2RGB)
+ curr_frame = Image.fromarray(curr_frame)
+
+ I_last_lab_predict, IA_predict_rgb = self.__proccess_sample(curr_frame, I_last_lab_predict, I_reference_lab, features_B)
+
+ IA_predict_rgb = IA_predict_rgb.transpose(1,2,0)
+
+ yield IA_predict_rgb
+
+ #PBAR.close()
+ video.release()
+
+ def predict_image(self, image, ref_image):
+ ref_image = self.__preprocess_reference(ref_image)
+
+ I_last_lab_predict = None
+
+ IB_lab = self.processor(ref_image)
+ IB_lab = IB_lab.unsqueeze(0).to(self.device)
+
+ with torch.no_grad():
+ I_reference_lab = IB_lab
+ I_reference_l = I_reference_lab[:, 0:1, :, :]
+ I_reference_ab = I_reference_lab[:, 1:3, :, :]
+ I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(self.device)
+ features_B = self.embed_net(I_reference_rgb)
+
+ curr_frame = image
+ I_last_lab_predict, IA_predict_rgb = self.__proccess_sample(curr_frame, I_last_lab_predict, I_reference_lab, features_B)
+
+ IA_predict_rgb = IA_predict_rgb.transpose(1,2,0)
+
+ return IA_predict_rgb
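+
+# Illustrative single-image usage (sketch only; the grayscale frame path is a placeholder):
+#   model = SwinTExCo('checkpoints/epoch_20/')
+#   image = Image.open('path/to/grayscale_frame.jpg').convert('RGB')
+#   ref_image = Image.open('sample_input/ref1.jpg').convert('RGB')
+#   colorized = model.predict_image(image, ref_image)  # H x W x 3 uint8 RGB array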
+
+if __name__ == "__main__":
+ model = SwinTExCo('checkpoints/epoch_20/')
+
+ # Initialize video reader and writer
+ video = cv2.VideoCapture('sample_input/video_2.mp4')
+ fps = video.get(cv2.CAP_PROP_FPS)
+ width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
+ height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ video_writer = cv2.VideoWriter('sample_output/video_2_ref_2.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
+
+ # Initialize reference image
+ ref_image = Image.open('sample_input/refs_2/ref2.jpg').convert('RGB')
+
+    for colorized_frame in model.predict_video(video, ref_image):
+        # predict_video yields RGB frames; convert to BGR before writing with OpenCV
+        colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_RGB2BGR)
+        video_writer.write(colorized_frame)
+
+ video_writer.release()
\ No newline at end of file
diff --git a/src/losses.py b/src/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd78f9226bdee39354fa8fb31a05e4aefeb9e55d
--- /dev/null
+++ b/src/losses.py
@@ -0,0 +1,277 @@
+import torch
+import torch.nn as nn
+from src.utils import feature_normalize
+
+
+### START### CONTEXTUAL LOSS ####
+class ContextualLoss(nn.Module):
+ """
+ input is Al, Bl, channel = 1, range ~ [0, 255]
+ """
+
+ def __init__(self):
+ super(ContextualLoss, self).__init__()
+ return None
+
+ def forward(self, X_features, Y_features, h=0.1, feature_centering=True):
+ """
+        X_features & Y_features are feature vectors or 2D feature maps
+ h: bandwidth
+ return the per-sample loss
+ """
+ batch_size = X_features.shape[0]
+ feature_depth = X_features.shape[1]
+
+ # to normalized feature vectors
+ if feature_centering:
+ X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(
+ dim=-1
+ )
+ Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(
+ dim=-1
+ )
+ X_features = feature_normalize(X_features).view(
+ batch_size, feature_depth, -1
+ ) # batch_size * feature_depth * feature_size^2
+ Y_features = feature_normalize(Y_features).view(
+ batch_size, feature_depth, -1
+ ) # batch_size * feature_depth * feature_size^2
+
+        # cosine distance = 1 - similarity
+ X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth
+ d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2
+
+ # normalized distance: dij_bar
+ d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5) # batch_size * feature_size^2 * feature_size^2
+
+ # pairwise affinity
+ w = torch.exp((1 - d_norm) / h)
+ A_ij = w / torch.sum(w, dim=-1, keepdim=True)
+
+ # contextual loss per sample
+ CX = torch.mean(torch.max(A_ij, dim=1)[0], dim=-1)
+ return -torch.log(CX)
+
+
+class ContextualLoss_forward(nn.Module):
+ """
+ input is Al, Bl, channel = 1, range ~ [0, 255]
+ """
+
+ def __init__(self):
+ super(ContextualLoss_forward, self).__init__()
+ return None
+
+ def forward(self, X_features, Y_features, h=0.1, feature_centering=True):
+ """
+        X_features & Y_features are feature vectors or 2D feature maps
+ h: bandwidth
+ return the per-sample loss
+ """
+ batch_size = X_features.shape[0]
+ feature_depth = X_features.shape[1]
+
+ # to normalized feature vectors
+ if feature_centering:
+ X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(
+ dim=-1
+ )
+ Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(
+ dim=-1
+ )
+ X_features = feature_normalize(X_features).view(
+ batch_size, feature_depth, -1
+ ) # batch_size * feature_depth * feature_size^2
+ Y_features = feature_normalize(Y_features).view(
+ batch_size, feature_depth, -1
+ ) # batch_size * feature_depth * feature_size^2
+
+        # cosine distance = 1 - similarity
+ X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth
+ d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2
+
+ # normalized distance: dij_bar
+ d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5) # batch_size * feature_size^2 * feature_size^2
+
+ # pairwise affinity
+ w = torch.exp((1 - d_norm) / h)
+ A_ij = w / torch.sum(w, dim=-1, keepdim=True)
+
+ # contextual loss per sample
+ CX = torch.mean(torch.max(A_ij, dim=-1)[0], dim=1)
+ return -torch.log(CX)
+
+
+### END### CONTEXTUAL LOSS ####
+
+
+##########################
+
+
+def mse_loss_fn(input, target=0):
+ return torch.mean((input - target) ** 2)
+
+
+### START### PERCEPTUAL LOSS ###
+def Perceptual_loss(domain_invariant, weight_perceptual):
+ instancenorm = nn.InstanceNorm2d(512, affine=False)
+
+ def __call__(A_relu5_1, predict_relu5_1):
+ if domain_invariant:
+ feat_loss = (
+ mse_loss_fn(instancenorm(predict_relu5_1), instancenorm(A_relu5_1.detach())) * weight_perceptual * 1e5 * 0.2
+ )
+ else:
+ feat_loss = mse_loss_fn(predict_relu5_1, A_relu5_1.detach()) * weight_perceptual
+ return feat_loss
+
+ return __call__
+
+
+### END### PERCEPTUAL LOSS ###
+
+
+def l1_loss_fn(input, target=0):
+ return torch.mean(torch.abs(input - target))
+
+
+### END#################
+
+
+### START### ADVERSARIAL LOSS ###
+def generator_loss_fn(real_data_lab, fake_data_lab, discriminator, weight_gan, device):
+ if weight_gan > 0:
+ y_pred_fake, _ = discriminator(fake_data_lab)
+ y_pred_real, _ = discriminator(real_data_lab)
+
+ y = torch.ones_like(y_pred_real)
+ generator_loss = (
+ (
+ torch.mean((y_pred_real - torch.mean(y_pred_fake) + y) ** 2)
+ + torch.mean((y_pred_fake - torch.mean(y_pred_real) - y) ** 2)
+ )
+ / 2
+ * weight_gan
+ )
+ return generator_loss
+
+ return torch.Tensor([0]).to(device)
+
+
+def discriminator_loss_fn(real_data_lab, fake_data_lab, discriminator):
+ y_pred_fake, _ = discriminator(fake_data_lab.detach())
+ y_pred_real, _ = discriminator(real_data_lab.detach())
+
+ y = torch.ones_like(y_pred_real)
+ discriminator_loss = (
+ torch.mean((y_pred_real - torch.mean(y_pred_fake) - y) ** 2)
+ + torch.mean((y_pred_fake - torch.mean(y_pred_real) + y) ** 2)
+ ) / 2
+ return discriminator_loss
+
+
+### END### ADVERSARIAL LOSS #####
+
+
+def consistent_loss_fn(
+ I_current_lab_predict,
+ I_last_ab_predict,
+ I_current_nonlocal_lab_predict,
+ I_last_nonlocal_lab_predict,
+ flow_forward,
+ mask,
+ warping_layer,
+ weight_consistent=0.02,
+ weight_nonlocal_consistent=0.0,
+ device="cuda",
+):
+ def weighted_mse_loss(input, target, weights):
+ out = (input - target) ** 2
+ out = out * weights.expand_as(out)
+ return out.mean()
+
+ def consistent():
+ I_current_lab_predict_warp = warping_layer(I_current_lab_predict, flow_forward)
+ I_current_ab_predict_warp = I_current_lab_predict_warp[:, 1:3, :, :]
+ consistent_loss = weighted_mse_loss(I_current_ab_predict_warp, I_last_ab_predict, mask) * weight_consistent
+ return consistent_loss
+
+ def nonlocal_consistent():
+ I_current_nonlocal_lab_predict_warp = warping_layer(I_current_nonlocal_lab_predict, flow_forward)
+ nonlocal_consistent_loss = (
+ weighted_mse_loss(
+ I_current_nonlocal_lab_predict_warp[:, 1:3, :, :],
+ I_last_nonlocal_lab_predict[:, 1:3, :, :],
+ mask,
+ )
+ * weight_nonlocal_consistent
+ )
+
+ return nonlocal_consistent_loss
+
+ consistent_loss = consistent() if weight_consistent else torch.Tensor([0]).to(device)
+ nonlocal_consistent_loss = nonlocal_consistent() if weight_nonlocal_consistent else torch.Tensor([0]).to(device)
+
+ return consistent_loss + nonlocal_consistent_loss
+
+
+### END### CONSISTENCY LOSS #####
+
+
+### START### SMOOTHNESS LOSS ###
+def smoothness_loss_fn(
+ I_current_l,
+ I_current_lab,
+ I_current_ab_predict,
+ A_relu2_1,
+ weighted_layer_color,
+ nonlocal_weighted_layer,
+ weight_smoothness=5.0,
+ weight_nonlocal_smoothness=0.0,
+ device="cuda",
+):
+ def smoothness(scale_factor=1.0):
+ I_current_lab_predict = torch.cat((I_current_l, I_current_ab_predict), dim=1)
+ IA_ab_weighed = weighted_layer_color(
+ I_current_lab,
+ I_current_lab_predict,
+ patch_size=3,
+ alpha=10,
+ scale_factor=scale_factor,
+ )
+ smoothness_loss = (
+ mse_loss_fn(
+ nn.functional.interpolate(I_current_ab_predict, scale_factor=scale_factor),
+ IA_ab_weighed,
+ )
+ * weight_smoothness
+ )
+
+ return smoothness_loss
+
+ def nonlocal_smoothness(scale_factor=0.25, alpha_nonlocal_smoothness=0.5):
+ nonlocal_smooth_feature = feature_normalize(A_relu2_1)
+ I_current_lab_predict = torch.cat((I_current_l, I_current_ab_predict), dim=1)
+ I_current_ab_weighted_nonlocal = nonlocal_weighted_layer(
+ I_current_lab_predict,
+ nonlocal_smooth_feature.detach(),
+ patch_size=3,
+ alpha=alpha_nonlocal_smoothness,
+ scale_factor=scale_factor,
+ )
+ nonlocal_smoothness_loss = (
+ mse_loss_fn(
+ nn.functional.interpolate(I_current_ab_predict, scale_factor=scale_factor),
+ I_current_ab_weighted_nonlocal,
+ )
+ * weight_nonlocal_smoothness
+ )
+ return nonlocal_smoothness_loss
+
+ smoothness_loss = smoothness() if weight_smoothness else torch.Tensor([0]).to(device)
+ nonlocal_smoothness_loss = nonlocal_smoothness() if weight_nonlocal_smoothness else torch.Tensor([0]).to(device)
+
+ return smoothness_loss + nonlocal_smoothness_loss
+
+
+### END### SMOOTHNESS LOSS #####
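+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative only): exercises the contextual loss and the
+    # relativistic average LSGAN losses on random tensors. The feature map sizes and the
+    # dummy discriminator below are placeholders, not part of the actual training setup.
+    class _DummyDiscriminator(nn.Module):
+        def forward(self, x):
+            # return a (score, features) tuple, as expected by the loss functions above
+            return x.mean(dim=(1, 2, 3)).unsqueeze(1), None
+
+    feat_pred = torch.randn(2, 64, 28, 28)
+    feat_ref = torch.randn(2, 64, 28, 28)
+    print("contextual loss:", ContextualLoss_forward()(feat_pred, feat_ref, h=0.1).mean().item())
+
+    real_lab = torch.randn(2, 3, 64, 64)
+    fake_lab = torch.randn(2, 3, 64, 64)
+    disc = _DummyDiscriminator()
+    print("generator loss:", generator_loss_fn(real_lab, fake_lab, disc, weight_gan=1.0, device="cpu").item())
+    print("discriminator loss:", discriminator_loss_fn(real_lab, fake_lab, disc).item())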
diff --git a/src/metrics.py b/src/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cfff45b2dda234756c3d809263be3143a47e952
--- /dev/null
+++ b/src/metrics.py
@@ -0,0 +1,225 @@
+from skimage.metrics import structural_similarity, peak_signal_noise_ratio
+import numpy as np
+import lpips
+import torch
+from pytorch_fid.fid_score import calculate_frechet_distance
+from pytorch_fid.inception import InceptionV3
+import torch.nn as nn
+import cv2
+from scipy import stats
+import os
+
+def calc_ssim(pred_image, gt_image):
+ '''
+ Structural Similarity Index (SSIM) is a perceptual metric that quantifies the image quality degradation that is
+ caused by processing such as data compression or by losses in data transmission.
+
+ # Arguments
+        pred_image: PIL.Image
+        gt_image: PIL.Image
+ # Returns
+ ssim: float (-1.0, 1.0)
+ '''
+ pred_image = np.array(pred_image.convert('RGB')).astype(np.float32)
+ gt_image = np.array(gt_image.convert('RGB')).astype(np.float32)
+ ssim = structural_similarity(pred_image, gt_image, channel_axis=2, data_range=255.)
+ return ssim
+
+def calc_psnr(pred_image, gt_image):
+ '''
+ Peak Signal-to-Noise Ratio (PSNR) is an expression for the ratio between the maximum possible value (power) of a signal
+ and the power of distorting noise that affects the quality of its representation.
+
+ # Arguments
+        pred_image: PIL.Image
+        gt_image: PIL.Image
+ # Returns
+ psnr: float
+ '''
+ pred_image = np.array(pred_image.convert('RGB')).astype(np.float32)
+ gt_image = np.array(gt_image.convert('RGB')).astype(np.float32)
+
+ psnr = peak_signal_noise_ratio(gt_image, pred_image, data_range=255.)
+ return psnr
+
+class LPIPS_utils:
+ def __init__(self, device = 'cuda'):
+ self.loss_fn = lpips.LPIPS(net='vgg', spatial=True) # Can set net = 'squeeze' or 'vgg'or 'alex'
+ self.loss_fn = self.loss_fn.to(device)
+ self.device = device
+
+ def compare_lpips(self,img_fake, img_real, data_range=255.): # input: torch 1 c h w / h w c
+ img_fake = torch.from_numpy(np.array(img_fake).astype(np.float32)/data_range)
+ img_real = torch.from_numpy(np.array(img_real).astype(np.float32)/data_range)
+ if img_fake.ndim==3:
+ img_fake = img_fake.permute(2,0,1).unsqueeze(0)
+ img_real = img_real.permute(2,0,1).unsqueeze(0)
+ img_fake = img_fake.to(self.device)
+ img_real = img_real.to(self.device)
+
+ dist = self.loss_fn.forward(img_fake,img_real)
+ return dist.mean().item()
+
+class FID_utils(nn.Module):
+ """Class for computing the Fréchet Inception Distance (FID) metric score.
+ It is implemented as a class in order to hold the inception model instance
+ in its state.
+ Parameters
+ ----------
+ resize_input : bool (optional)
+ Whether or not to resize the input images to the image size (299, 299)
+ on which the inception model was trained. Since the model is fully
+ convolutional, the score also works without resizing. In literature
+ and when working with GANs people tend to set this value to True,
+ however, for internal evaluation this is not necessary.
+ device : str or torch.device
+ The device on which to run the inception model.
+ """
+
+ def __init__(self, resize_input=True, device="cuda"):
+ super(FID_utils, self).__init__()
+ self.device = device
+ if self.device is None:
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ #self.model = InceptionV3(resize_input=resize_input).to(device)
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
+ self.model = InceptionV3([block_idx]).to(device)
+ self.model = self.model.eval()
+
+ def get_activations(self,batch): # 1 c h w
+ with torch.no_grad():
+ pred = self.model(batch)[0]
+ # If model output is not scalar, apply global spatial average pooling.
+ # This happens if you choose a dimensionality not equal 2048.
+ if pred.size(2) != 1 or pred.size(3) != 1:
+ #pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+ print("error in get activations!")
+ #pred = pred.squeeze(3).squeeze(2).cpu().numpy()
+ return pred
+
+
+ def _get_mu_sigma(self, batch,data_range):
+ """Compute the inception mu and sigma for a batch of images.
+ Parameters
+ ----------
+ images : np.ndarray
+ A batch of images with shape (n_images,3, width, height).
+ Returns
+ -------
+ mu : np.ndarray
+ The array of mean activations with shape (2048,).
+ sigma : np.ndarray
+ The covariance matrix of activations with shape (2048, 2048).
+ """
+ # forward pass
+ if batch.ndim ==3 and batch.shape[2]==3:
+ batch=batch.permute(2,0,1).unsqueeze(0)
+ batch /= data_range
+ #batch = torch.tensor(batch)#.unsqueeze(1).repeat((1, 3, 1, 1))
+ batch = batch.to(self.device, torch.float32)
+ #(activations,) = self.model(batch)
+ activations = self.get_activations(batch)
+ activations = activations.detach().cpu().numpy().squeeze(3).squeeze(2)
+
+ # compute statistics
+ mu = np.mean(activations,axis=0)
+ sigma = np.cov(activations, rowvar=False)
+
+ return mu, sigma
+
+ def score(self, images_1, images_2, data_range=255.):
+        """Compute the FID score between two images or batches of images.
+        Inputs should have shape (n_images, 3, height, width) or (height, width, 3).
+        Parameters
+        ----------
+        images_1 : PIL.Image or np.ndarray
+            First image or batch of images.
+        images_2 : PIL.Image or np.ndarray
+            Second image or batch of images.
+        Returns
+        -------
+        score : float
+            The FID score.
+        """
+ images_1 = torch.from_numpy(np.array(images_1).astype(np.float32))
+ images_2 = torch.from_numpy(np.array(images_2).astype(np.float32))
+ images_1 = images_1.to(self.device)
+ images_2 = images_2.to(self.device)
+
+ mu_1, sigma_1 = self._get_mu_sigma(images_1,data_range)
+ mu_2, sigma_2 = self._get_mu_sigma(images_2,data_range)
+ score = calculate_frechet_distance(mu_1, sigma_1, mu_2, sigma_2)
+
+ return score
+
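+# Illustrative usage sketch (not part of the original module); file paths are hypothetical
+# placeholders. LPIPS compares a single pair of images, while FID_utils.score is meant for
+# batches of shape (n_images, 3, H, W) so that meaningful statistics can be computed.
+#
+#   from PIL import Image
+#   lpips_metric = LPIPS_utils(device='cpu')
+#   dist = lpips_metric.compare_lpips(Image.open('pred_frame.png'), Image.open('gt_frame.png'))
+#
+#   fid_metric = FID_utils(device='cpu')
+#   fid = fid_metric.score(pred_frames, gt_frames)   # pred_frames, gt_frames: (N, 3, H, W) arrays
+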
+def JS_divergence(p, q):
+ M = (p + q) / 2
+ return 0.5 * stats.entropy(p, M) + 0.5 * stats.entropy(q, M)
+
+
+def compute_JS_bgr(input_dir, dilation=1):
+ input_img_list = os.listdir(input_dir)
+ input_img_list.sort()
+ # print(input_img_list)
+
+ hist_b_list = [] # [img1_histb, img2_histb, ...]
+ hist_g_list = []
+ hist_r_list = []
+
+ for img_name in input_img_list:
+ # print(os.path.join(input_dir, img_name))
+ img_in = cv2.imread(os.path.join(input_dir, img_name))
+ H, W, C = img_in.shape
+
+ hist_b = cv2.calcHist([img_in], [0], None, [256], [0,256]) # B
+ hist_g = cv2.calcHist([img_in], [1], None, [256], [0,256]) # G
+ hist_r = cv2.calcHist([img_in], [2], None, [256], [0,256]) # R
+
+ hist_b = hist_b / (H * W)
+ hist_g = hist_g / (H * W)
+ hist_r = hist_r / (H * W)
+
+ hist_b_list.append(hist_b)
+ hist_g_list.append(hist_g)
+ hist_r_list.append(hist_r)
+
+ JS_b_list = []
+ JS_g_list = []
+ JS_r_list = []
+
+ for i in range(len(hist_b_list)):
+ if i + dilation > len(hist_b_list) - 1:
+ break
+ hist_b_img1 = hist_b_list[i]
+ hist_b_img2 = hist_b_list[i + dilation]
+ JS_b = JS_divergence(hist_b_img1, hist_b_img2)
+ JS_b_list.append(JS_b)
+
+ hist_g_img1 = hist_g_list[i]
+ hist_g_img2 = hist_g_list[i+dilation]
+ JS_g = JS_divergence(hist_g_img1, hist_g_img2)
+ JS_g_list.append(JS_g)
+
+ hist_r_img1 = hist_r_list[i]
+ hist_r_img2 = hist_r_list[i+dilation]
+ JS_r = JS_divergence(hist_r_img1, hist_r_img2)
+ JS_r_list.append(JS_r)
+
+ return JS_b_list, JS_g_list, JS_r_list
+
+
+def calc_cdc(vid_folder, dilation=[1, 2, 4], weight=[1/3, 1/3, 1/3]):
+ mean_b, mean_g, mean_r = 0, 0, 0
+ for d, w in zip(dilation, weight):
+ JS_b_list_one, JS_g_list_one, JS_r_list_one = compute_JS_bgr(vid_folder, d)
+ mean_b += w * np.mean(JS_b_list_one)
+ mean_g += w * np.mean(JS_g_list_one)
+ mean_r += w * np.mean(JS_r_list_one)
+
+ cdc = np.mean([mean_b, mean_g, mean_r])
+ return cdc
+
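+# Illustrative usage sketch (not part of the original module): CDC (color distribution
+# consistency) is computed over a directory of consecutive video frames; the folder name
+# below is a hypothetical placeholder.
+#
+#   cdc_score = calc_cdc('results/video_0001/')   # lower means more temporally consistent colors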
+
+
+
+
\ No newline at end of file
diff --git a/src/models/CNN/ColorVidNet.py b/src/models/CNN/ColorVidNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..0394f06a27582f950cd8df598eef9be715e26242
--- /dev/null
+++ b/src/models/CNN/ColorVidNet.py
@@ -0,0 +1,141 @@
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+
+class ColorVidNet(nn.Module):
+ def __init__(self, ic):
+ super(ColorVidNet, self).__init__()
+ self.conv1_1 = nn.Sequential(nn.Conv2d(ic, 32, 3, 1, 1), nn.ReLU(), nn.Conv2d(32, 64, 3, 1, 1))
+ self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 1)
+ self.conv1_2norm = nn.BatchNorm2d(64, affine=False)
+ self.conv1_2norm_ss = nn.Conv2d(64, 64, 1, 2, bias=False, groups=64)
+ self.conv2_1 = nn.Conv2d(64, 128, 3, 1, 1)
+ self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 1)
+ self.conv2_2norm = nn.BatchNorm2d(128, affine=False)
+ self.conv2_2norm_ss = nn.Conv2d(128, 128, 1, 2, bias=False, groups=128)
+ self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 1)
+ self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 1)
+ self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 1)
+ self.conv3_3norm = nn.BatchNorm2d(256, affine=False)
+ self.conv3_3norm_ss = nn.Conv2d(256, 256, 1, 2, bias=False, groups=256)
+ self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 1)
+ self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 1)
+ self.conv4_3 = nn.Conv2d(512, 512, 3, 1, 1)
+ self.conv4_3norm = nn.BatchNorm2d(512, affine=False)
+ self.conv5_1 = nn.Conv2d(512, 512, 3, 1, 2, 2)
+ self.conv5_2 = nn.Conv2d(512, 512, 3, 1, 2, 2)
+ self.conv5_3 = nn.Conv2d(512, 512, 3, 1, 2, 2)
+ self.conv5_3norm = nn.BatchNorm2d(512, affine=False)
+ self.conv6_1 = nn.Conv2d(512, 512, 3, 1, 2, 2)
+ self.conv6_2 = nn.Conv2d(512, 512, 3, 1, 2, 2)
+ self.conv6_3 = nn.Conv2d(512, 512, 3, 1, 2, 2)
+ self.conv6_3norm = nn.BatchNorm2d(512, affine=False)
+ self.conv7_1 = nn.Conv2d(512, 512, 3, 1, 1)
+ self.conv7_2 = nn.Conv2d(512, 512, 3, 1, 1)
+ self.conv7_3 = nn.Conv2d(512, 512, 3, 1, 1)
+ self.conv7_3norm = nn.BatchNorm2d(512, affine=False)
+ self.conv8_1 = nn.ConvTranspose2d(512, 256, 4, 2, 1)
+ self.conv3_3_short = nn.Conv2d(256, 256, 3, 1, 1)
+ self.conv8_2 = nn.Conv2d(256, 256, 3, 1, 1)
+ self.conv8_3 = nn.Conv2d(256, 256, 3, 1, 1)
+ self.conv8_3norm = nn.BatchNorm2d(256, affine=False)
+ self.conv9_1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
+ self.conv2_2_short = nn.Conv2d(128, 128, 3, 1, 1)
+ self.conv9_2 = nn.Conv2d(128, 128, 3, 1, 1)
+ self.conv9_2norm = nn.BatchNorm2d(128, affine=False)
+ self.conv10_1 = nn.ConvTranspose2d(128, 128, 4, 2, 1)
+ self.conv1_2_short = nn.Conv2d(64, 128, 3, 1, 1)
+ self.conv10_2 = nn.Conv2d(128, 128, 3, 1, 1)
+ self.conv10_ab = nn.Conv2d(128, 2, 1, 1)
+
+ # add self.relux_x
+ self.relu1_1 = nn.PReLU()
+ self.relu1_2 = nn.PReLU()
+ self.relu2_1 = nn.PReLU()
+ self.relu2_2 = nn.PReLU()
+ self.relu3_1 = nn.PReLU()
+ self.relu3_2 = nn.PReLU()
+ self.relu3_3 = nn.PReLU()
+ self.relu4_1 = nn.PReLU()
+ self.relu4_2 = nn.PReLU()
+ self.relu4_3 = nn.PReLU()
+ self.relu5_1 = nn.PReLU()
+ self.relu5_2 = nn.PReLU()
+ self.relu5_3 = nn.PReLU()
+ self.relu6_1 = nn.PReLU()
+ self.relu6_2 = nn.PReLU()
+ self.relu6_3 = nn.PReLU()
+ self.relu7_1 = nn.PReLU()
+ self.relu7_2 = nn.PReLU()
+ self.relu7_3 = nn.PReLU()
+ self.relu8_1_comb = nn.PReLU()
+ self.relu8_2 = nn.PReLU()
+ self.relu8_3 = nn.PReLU()
+ self.relu9_1_comb = nn.PReLU()
+ self.relu9_2 = nn.PReLU()
+ self.relu10_1_comb = nn.PReLU()
+ self.relu10_2 = nn.LeakyReLU(0.2, True)
+
+ self.conv8_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(512, 256, 3, 1, 1))
+ self.conv9_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(256, 128, 3, 1, 1))
+ self.conv10_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(128, 128, 3, 1, 1))
+
+ self.conv1_2norm = nn.InstanceNorm2d(64)
+ self.conv2_2norm = nn.InstanceNorm2d(128)
+ self.conv3_3norm = nn.InstanceNorm2d(256)
+ self.conv4_3norm = nn.InstanceNorm2d(512)
+ self.conv5_3norm = nn.InstanceNorm2d(512)
+ self.conv6_3norm = nn.InstanceNorm2d(512)
+ self.conv7_3norm = nn.InstanceNorm2d(512)
+ self.conv8_3norm = nn.InstanceNorm2d(256)
+ self.conv9_2norm = nn.InstanceNorm2d(128)
+
+ def forward(self, x):
+        """x: channel-wise concatenation of the gray image (1 channel), the warped ab channels (2 channels), the similarity map (1 channel) and the previous frame's lab output (see frame_colorization in FrameColor.py)"""
+ conv1_1 = self.relu1_1(self.conv1_1(x))
+ conv1_2 = self.relu1_2(self.conv1_2(conv1_1))
+ conv1_2norm = self.conv1_2norm(conv1_2)
+ conv1_2norm_ss = self.conv1_2norm_ss(conv1_2norm)
+ conv2_1 = self.relu2_1(self.conv2_1(conv1_2norm_ss))
+ conv2_2 = self.relu2_2(self.conv2_2(conv2_1))
+ conv2_2norm = self.conv2_2norm(conv2_2)
+ conv2_2norm_ss = self.conv2_2norm_ss(conv2_2norm)
+ conv3_1 = self.relu3_1(self.conv3_1(conv2_2norm_ss))
+ conv3_2 = self.relu3_2(self.conv3_2(conv3_1))
+ conv3_3 = self.relu3_3(self.conv3_3(conv3_2))
+ conv3_3norm = self.conv3_3norm(conv3_3)
+ conv3_3norm_ss = self.conv3_3norm_ss(conv3_3norm)
+ conv4_1 = self.relu4_1(self.conv4_1(conv3_3norm_ss))
+ conv4_2 = self.relu4_2(self.conv4_2(conv4_1))
+ conv4_3 = self.relu4_3(self.conv4_3(conv4_2))
+ conv4_3norm = self.conv4_3norm(conv4_3)
+ conv5_1 = self.relu5_1(self.conv5_1(conv4_3norm))
+ conv5_2 = self.relu5_2(self.conv5_2(conv5_1))
+ conv5_3 = self.relu5_3(self.conv5_3(conv5_2))
+ conv5_3norm = self.conv5_3norm(conv5_3)
+ conv6_1 = self.relu6_1(self.conv6_1(conv5_3norm))
+ conv6_2 = self.relu6_2(self.conv6_2(conv6_1))
+ conv6_3 = self.relu6_3(self.conv6_3(conv6_2))
+ conv6_3norm = self.conv6_3norm(conv6_3)
+ conv7_1 = self.relu7_1(self.conv7_1(conv6_3norm))
+ conv7_2 = self.relu7_2(self.conv7_2(conv7_1))
+ conv7_3 = self.relu7_3(self.conv7_3(conv7_2))
+ conv7_3norm = self.conv7_3norm(conv7_3)
+ conv8_1 = self.conv8_1(conv7_3norm)
+ conv3_3_short = self.conv3_3_short(conv3_3norm)
+ conv8_1_comb = self.relu8_1_comb(conv8_1 + conv3_3_short)
+ conv8_2 = self.relu8_2(self.conv8_2(conv8_1_comb))
+ conv8_3 = self.relu8_3(self.conv8_3(conv8_2))
+ conv8_3norm = self.conv8_3norm(conv8_3)
+ conv9_1 = self.conv9_1(conv8_3norm)
+ conv2_2_short = self.conv2_2_short(conv2_2norm)
+ conv9_1_comb = self.relu9_1_comb(conv9_1 + conv2_2_short)
+ conv9_2 = self.relu9_2(self.conv9_2(conv9_1_comb))
+ conv9_2norm = self.conv9_2norm(conv9_2)
+ conv10_1 = self.conv10_1(conv9_2norm)
+ conv1_2_short = self.conv1_2_short(conv1_2norm)
+ conv10_1_comb = self.relu10_1_comb(conv10_1 + conv1_2_short)
+ conv10_2 = self.relu10_2(self.conv10_2(conv10_1_comb))
+ conv10_ab = self.conv10_ab(conv10_2)
+
+ return torch.tanh(conv10_ab) * 128
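+
+
+# Illustrative usage sketch (not part of the original module). ColorVidNet expects the
+# conditioning tensor described in forward(); ic=7 below matches 1 (gray) + 2 (warped ab)
+# + 1 (similarity map) + 3 (previous lab), but how the network is instantiated elsewhere
+# in the project is an assumption here.
+#
+#   net = ColorVidNet(ic=7)
+#   x = torch.randn(1, 7, 224, 224)
+#   ab = net(x)   # (1, 2, 224, 224), values roughly in [-128, 128]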
diff --git a/src/models/CNN/FrameColor.py b/src/models/CNN/FrameColor.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c4f4e7d0f712aef415909ddac3786f84d7de665
--- /dev/null
+++ b/src/models/CNN/FrameColor.py
@@ -0,0 +1,76 @@
+import torch
+from src.utils import *
+from src.models.vit.vit import FeatureTransform
+
+
+def warp_color(
+ IA_l,
+ IB_lab,
+ features_B,
+ embed_net,
+ nonlocal_net,
+ temperature=0.01,
+):
+ IA_rgb_from_gray = gray2rgb_batch(IA_l)
+
+ with torch.no_grad():
+ A_feat0, A_feat1, A_feat2, A_feat3 = embed_net(IA_rgb_from_gray)
+ B_feat0, B_feat1, B_feat2, B_feat3 = features_B
+
+ A_feat0 = feature_normalize(A_feat0)
+ A_feat1 = feature_normalize(A_feat1)
+ A_feat2 = feature_normalize(A_feat2)
+ A_feat3 = feature_normalize(A_feat3)
+
+ B_feat0 = feature_normalize(B_feat0)
+ B_feat1 = feature_normalize(B_feat1)
+ B_feat2 = feature_normalize(B_feat2)
+ B_feat3 = feature_normalize(B_feat3)
+
+ return nonlocal_net(
+ IB_lab,
+ A_feat0,
+ A_feat1,
+ A_feat2,
+ A_feat3,
+ B_feat0,
+ B_feat1,
+ B_feat2,
+ B_feat3,
+ temperature=temperature,
+ )
+
+
+def frame_colorization(
+ IA_l,
+ IB_lab,
+ IA_last_lab,
+ features_B,
+ embed_net,
+ nonlocal_net,
+ colornet,
+ joint_training=True,
+ luminance_noise=0,
+ temperature=0.01,
+):
+ if luminance_noise:
+ IA_l = IA_l + torch.randn_like(IA_l, requires_grad=False) * luminance_noise
+
+ with torch.autograd.set_grad_enabled(joint_training):
+ nonlocal_BA_lab, similarity_map = warp_color(
+ IA_l,
+ IB_lab,
+ features_B,
+ embed_net,
+ nonlocal_net,
+ temperature=temperature,
+ )
+ nonlocal_BA_ab = nonlocal_BA_lab[:, 1:3, :, :]
+ IA_ab_predict = colornet(
+ torch.cat(
+ (IA_l, nonlocal_BA_ab, similarity_map, IA_last_lab),
+ dim=1,
+ )
+ )
+
+ return IA_ab_predict, nonlocal_BA_lab
\ No newline at end of file
diff --git a/src/models/CNN/GAN_models.py b/src/models/CNN/GAN_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..b82bdfd32f27126d03e20c8f4c9d9c1bdf1c4803
--- /dev/null
+++ b/src/models/CNN/GAN_models.py
@@ -0,0 +1,212 @@
+# DCGAN-like generator and discriminator
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn import Parameter
+
+
+def l2normalize(v, eps=1e-12):
+ return v / (v.norm() + eps)
+
+
+class SpectralNorm(nn.Module):
+ def __init__(self, module, name="weight", power_iterations=1):
+ super(SpectralNorm, self).__init__()
+ self.module = module
+ self.name = name
+ self.power_iterations = power_iterations
+ if not self._made_params():
+ self._make_params()
+
+ def _update_u_v(self):
+ u = getattr(self.module, self.name + "_u")
+ v = getattr(self.module, self.name + "_v")
+ w = getattr(self.module, self.name + "_bar")
+
+ height = w.data.shape[0]
+ for _ in range(self.power_iterations):
+ v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
+ u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
+
+ sigma = u.dot(w.view(height, -1).mv(v))
+ setattr(self.module, self.name, w / sigma.expand_as(w))
+
+ def _made_params(self):
+ try:
+ u = getattr(self.module, self.name + "_u")
+ v = getattr(self.module, self.name + "_v")
+ w = getattr(self.module, self.name + "_bar")
+ return True
+ except AttributeError:
+ return False
+
+ def _make_params(self):
+ w = getattr(self.module, self.name)
+
+ height = w.data.shape[0]
+ width = w.view(height, -1).data.shape[1]
+
+ u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
+ v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
+ u.data = l2normalize(u.data)
+ v.data = l2normalize(v.data)
+ w_bar = Parameter(w.data)
+
+ del self.module._parameters[self.name]
+
+ self.module.register_parameter(self.name + "_u", u)
+ self.module.register_parameter(self.name + "_v", v)
+ self.module.register_parameter(self.name + "_bar", w_bar)
+
+ def forward(self, *args):
+ self._update_u_v()
+ return self.module.forward(*args)
+
+
+class Generator(nn.Module):
+ def __init__(self, z_dim):
+ super(Generator, self).__init__()
+ self.z_dim = z_dim
+
+ self.model = nn.Sequential(
+ nn.ConvTranspose2d(z_dim, 512, 4, stride=1),
+ nn.InstanceNorm2d(512),
+ nn.ReLU(),
+ nn.ConvTranspose2d(512, 256, 4, stride=2, padding=(1, 1)),
+ nn.InstanceNorm2d(256),
+ nn.ReLU(),
+ nn.ConvTranspose2d(256, 128, 4, stride=2, padding=(1, 1)),
+ nn.InstanceNorm2d(128),
+ nn.ReLU(),
+ nn.ConvTranspose2d(128, 64, 4, stride=2, padding=(1, 1)),
+ nn.InstanceNorm2d(64),
+ nn.ReLU(),
+ nn.ConvTranspose2d(64, channels, 3, stride=1, padding=(1, 1)),
+ nn.Tanh(),
+ )
+
+ def forward(self, z):
+ return self.model(z.view(-1, self.z_dim, 1, 1))
+
+
+# Module-level constants shared by Generator and Discriminator. Note that Generator.__init__
+# above references `channels`; this works because Python resolves the name at call time and
+# these assignments run at import time, before any Generator is instantiated.
+channels = 3
+leak = 0.1
+w_g = 4
+
+
+class Discriminator(nn.Module):
+ def __init__(self):
+ super(Discriminator, self).__init__()
+
+ self.conv1 = SpectralNorm(nn.Conv2d(channels, 64, 3, stride=1, padding=(1, 1)))
+ self.conv2 = SpectralNorm(nn.Conv2d(64, 64, 4, stride=2, padding=(1, 1)))
+ self.conv3 = SpectralNorm(nn.Conv2d(64, 128, 3, stride=1, padding=(1, 1)))
+ self.conv4 = SpectralNorm(nn.Conv2d(128, 128, 4, stride=2, padding=(1, 1)))
+ self.conv5 = SpectralNorm(nn.Conv2d(128, 256, 3, stride=1, padding=(1, 1)))
+ self.conv6 = SpectralNorm(nn.Conv2d(256, 256, 4, stride=2, padding=(1, 1)))
+ self.conv7 = SpectralNorm(nn.Conv2d(256, 256, 3, stride=1, padding=(1, 1)))
+ self.conv8 = SpectralNorm(nn.Conv2d(256, 512, 4, stride=2, padding=(1, 1)))
+ self.fc = SpectralNorm(nn.Linear(w_g * w_g * 512, 1))
+
+ def forward(self, x):
+ m = x
+ m = nn.LeakyReLU(leak)(self.conv1(m))
+ m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(64)(self.conv2(m)))
+ m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(128)(self.conv3(m)))
+ m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(128)(self.conv4(m)))
+ m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(256)(self.conv5(m)))
+ m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(256)(self.conv6(m)))
+ m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(256)(self.conv7(m)))
+ m = nn.LeakyReLU(leak)(self.conv8(m))
+
+ return self.fc(m.view(-1, w_g * w_g * 512))
+
+
+class Self_Attention(nn.Module):
+ """Self attention Layer"""
+
+ def __init__(self, in_dim):
+ super(Self_Attention, self).__init__()
+ self.chanel_in = in_dim
+
+ self.query_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 1, kernel_size=1))
+ self.key_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 1, kernel_size=1))
+ self.value_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1))
+ self.gamma = nn.Parameter(torch.zeros(1))
+
+ self.softmax = nn.Softmax(dim=-1) #
+
+ def forward(self, x):
+        """
+        inputs:
+            x: input feature maps (B x C x W x H)
+        returns:
+            out: self-attention output + input feature, same shape as x
+        """
+ m_batchsize, C, width, height = x.size()
+ proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1) # B X CX(N)
+ proj_key = self.key_conv(x).view(m_batchsize, -1, width * height) # B X C x (*W*H)
+ energy = torch.bmm(proj_query, proj_key) # transpose check
+ attention = self.softmax(energy) # BX (N) X (N)
+ proj_value = self.value_conv(x).view(m_batchsize, -1, width * height) # B X C X N
+
+ out = torch.bmm(proj_value, attention.permute(0, 2, 1))
+ out = out.view(m_batchsize, C, width, height)
+
+ out = self.gamma * out + x
+ return out
+
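+# Illustrative shape sketch (not part of the original module): the attention matrix is
+# N x N with N = W * H, so memory grows quadratically with spatial size.
+#
+#   attn = Self_Attention(in_dim=64)
+#   y = attn(torch.randn(2, 64, 28, 28))   # same shape as the input: (2, 64, 28, 28)
+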
+class Discriminator_x64_224(nn.Module):
+ """
+ Discriminative Network
+ """
+
+ def __init__(self, in_size=6, ndf=64):
+ super(Discriminator_x64_224, self).__init__()
+ self.in_size = in_size
+ self.ndf = ndf
+
+ self.layer1 = nn.Sequential(SpectralNorm(nn.Conv2d(self.in_size, self.ndf, 4, 2, 1)), nn.LeakyReLU(0.2, inplace=True))
+
+ self.layer2 = nn.Sequential(
+ SpectralNorm(nn.Conv2d(self.ndf, self.ndf, 4, 2, 1)),
+ nn.InstanceNorm2d(self.ndf),
+ nn.LeakyReLU(0.2, inplace=True),
+ )
+ self.attention = Self_Attention(self.ndf)
+ self.layer3 = nn.Sequential(
+ SpectralNorm(nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1)),
+ nn.InstanceNorm2d(self.ndf * 2),
+ nn.LeakyReLU(0.2, inplace=True),
+ )
+ self.layer4 = nn.Sequential(
+ SpectralNorm(nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1)),
+ nn.InstanceNorm2d(self.ndf * 4),
+ nn.LeakyReLU(0.2, inplace=True),
+ )
+ self.layer5 = nn.Sequential(
+ SpectralNorm(nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1)),
+ nn.InstanceNorm2d(self.ndf * 8),
+ nn.LeakyReLU(0.2, inplace=True),
+ )
+ self.layer6 = nn.Sequential(
+ SpectralNorm(nn.Conv2d(self.ndf * 8, self.ndf * 16, 4, 2, 1)),
+ nn.InstanceNorm2d(self.ndf * 16),
+ nn.LeakyReLU(0.2, inplace=True),
+ )
+
+ self.last = SpectralNorm(nn.Conv2d(self.ndf * 16, 1, [3, 3], 1, 0))
+
+ def forward(self, input):
+ feature1 = self.layer1(input)
+ feature2 = self.layer2(feature1)
+ feature_attention = self.attention(feature2)
+ feature3 = self.layer3(feature_attention)
+ feature4 = self.layer4(feature3)
+ feature5 = self.layer5(feature4)
+ feature6 = self.layer6(feature5)
+ output = self.last(feature6)
+ output = F.avg_pool2d(output, output.size()[2:]).view(output.size()[0], -1)
+
+ return output, feature4
diff --git a/src/models/CNN/NonlocalNet.py b/src/models/CNN/NonlocalNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..cce995da91cda1f65ee4bdf3e26cf3e2992f70b0
--- /dev/null
+++ b/src/models/CNN/NonlocalNet.py
@@ -0,0 +1,437 @@
+import sys
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from src.utils import uncenter_l
+
+
+def find_local_patch(x, patch_size):
+    """
+    Extract all local patches of size `patch_size` from `x` using `F.unfold`.
+
+    Args:
+        x: the input tensor of shape (N, C, H, W)
+        patch_size: the size of the square patch to be extracted
+
+    Returns:
+        A tensor of shape (N, C * patch_size**2, H, W) holding the unfolded patches.
+    """
+
+ N, C, H, W = x.shape
+ x_unfold = F.unfold(x, kernel_size=(patch_size, patch_size), padding=(patch_size // 2, patch_size // 2), stride=(1, 1))
+
+ return x_unfold.view(N, x_unfold.shape[1], H, W)
+
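+# Illustrative shape sketch (not part of the original module): with patch_size=3 each
+# spatial location keeps its 3x3 neighbourhood flattened into the channel dimension.
+#
+#   patches = find_local_patch(torch.randn(1, 2, 64, 64), patch_size=3)   # (1, 2*9, 64, 64)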
+
+class WeightedAverage(nn.Module):
+ def __init__(
+ self,
+ ):
+ super(WeightedAverage, self).__init__()
+
+ def forward(self, x_lab, patch_size=3, alpha=1, scale_factor=1):
+ """
+        It takes a 3-channel image (L, A, B) and returns a 2-channel image (A, B) where each pixel is a
+        weighted average of the A and B values of the pixels in a patch_size x patch_size neighborhood around it.
+
+ Args:
+ x_lab: the input image in LAB color space
+ patch_size: the size of the patch to use for the local average. Defaults to 3
+ alpha: the higher the alpha, the smoother the output. Defaults to 1
+ scale_factor: the scale factor of the input image. Defaults to 1
+
+ Returns:
+ The output of the forward function is a tensor of size (batch_size, 2, height, width)
+ """
+ # alpha=0: less smooth; alpha=inf: smoother
+ x_lab = F.interpolate(x_lab, scale_factor=scale_factor)
+ l = x_lab[:, 0:1, :, :]
+ a = x_lab[:, 1:2, :, :]
+ b = x_lab[:, 2:3, :, :]
+ local_l = find_local_patch(l, patch_size)
+ local_a = find_local_patch(a, patch_size)
+ local_b = find_local_patch(b, patch_size)
+ local_difference_l = (local_l - l) ** 2
+ correlation = nn.functional.softmax(-1 * local_difference_l / alpha, dim=1)
+
+ return torch.cat(
+ (
+ torch.sum(correlation * local_a, dim=1, keepdim=True),
+ torch.sum(correlation * local_b, dim=1, keepdim=True),
+ ),
+ 1,
+ )
+
+
+class WeightedAverage_color(nn.Module):
+ """
+ smooth the image according to the color distance in the LAB space
+ """
+
+ def __init__(
+ self,
+ ):
+ super(WeightedAverage_color, self).__init__()
+
+ def forward(self, x_lab, x_lab_predict, patch_size=3, alpha=1, scale_factor=1):
+ """
+        It smooths the predicted a and b channels: each output pixel is a weighted average of the
+        predicted a/b values in its local neighborhood, with weights given by how close the
+        neighboring pixels are to the center pixel in the original LAB image.
+
+ Args:
+ x_lab: the input image in LAB color space
+ x_lab_predict: the predicted LAB image
+ patch_size: the size of the patch to use for the local color correction. Defaults to 3
+ alpha: controls the smoothness of the output. Defaults to 1
+ scale_factor: the scale factor of the input image. Defaults to 1
+
+ Returns:
+ The return is the weighted average of the local a and b channels.
+ """
+        # alpha=0: less smooth; alpha=inf: smoother
+ x_lab = F.interpolate(x_lab, scale_factor=scale_factor)
+ l = uncenter_l(x_lab[:, 0:1, :, :])
+ a = x_lab[:, 1:2, :, :]
+ b = x_lab[:, 2:3, :, :]
+ a_predict = x_lab_predict[:, 1:2, :, :]
+ b_predict = x_lab_predict[:, 2:3, :, :]
+ local_l = find_local_patch(l, patch_size)
+ local_a = find_local_patch(a, patch_size)
+ local_b = find_local_patch(b, patch_size)
+ local_a_predict = find_local_patch(a_predict, patch_size)
+ local_b_predict = find_local_patch(b_predict, patch_size)
+
+ local_color_difference = (local_l - l) ** 2 + (local_a - a) ** 2 + (local_b - b) ** 2
+ # so that sum of weights equal to 1
+ correlation = nn.functional.softmax(-1 * local_color_difference / alpha, dim=1)
+
+ return torch.cat(
+ (
+ torch.sum(correlation * local_a_predict, dim=1, keepdim=True),
+ torch.sum(correlation * local_b_predict, dim=1, keepdim=True),
+ ),
+ 1,
+ )
+
+
+class NonlocalWeightedAverage(nn.Module):
+ def __init__(
+ self,
+ ):
+ super(NonlocalWeightedAverage, self).__init__()
+
+ def forward(self, x_lab, feature, patch_size=3, alpha=0.1, scale_factor=1):
+ """
+        It takes in a feature map and a LAB image, and returns nonlocally smoothed ab channels
+
+ Args:
+ x_lab: the input image in LAB color space
+ feature: the feature map of the input image
+ patch_size: the size of the patch to be used for the correlation matrix. Defaults to 3
+ alpha: the higher the alpha, the smoother the output.
+ scale_factor: the scale factor of the input image. Defaults to 1
+
+ Returns:
+ weighted_ab is the weighted ab channel of the image.
+ """
+ # alpha=0: less smooth; alpha=inf: smoother
+ # input feature is normalized feature
+ x_lab = F.interpolate(x_lab, scale_factor=scale_factor)
+ batch_size, channel, height, width = x_lab.shape
+ feature = F.interpolate(feature, size=(height, width))
+ batch_size = x_lab.shape[0]
+ x_ab = x_lab[:, 1:3, :, :].view(batch_size, 2, -1)
+ x_ab = x_ab.permute(0, 2, 1)
+
+ local_feature = find_local_patch(feature, patch_size)
+ local_feature = local_feature.view(batch_size, local_feature.shape[1], -1)
+
+ correlation_matrix = torch.matmul(local_feature.permute(0, 2, 1), local_feature)
+ correlation_matrix = nn.functional.softmax(correlation_matrix / alpha, dim=-1)
+
+ weighted_ab = torch.matmul(correlation_matrix, x_ab)
+ weighted_ab = weighted_ab.permute(0, 2, 1).contiguous()
+ weighted_ab = weighted_ab.view(batch_size, 2, height, width)
+ return weighted_ab
+
+
+class CorrelationLayer(nn.Module):
+ def __init__(self, search_range):
+ super(CorrelationLayer, self).__init__()
+ self.search_range = search_range
+
+ def forward(self, x1, x2, alpha=1, raw_output=False, metric="similarity"):
+ """
+ It takes two tensors, x1 and x2, and returns a tensor of shape (batch_size, (search_range * 2 +
+ 1) ** 2, height, width) where each element is the dot product of the corresponding patch in x1
+ and x2
+
+ Args:
+ x1: the first image
+ x2: the image to be warped
+ alpha: the temperature parameter for the softmax function. Defaults to 1
+ raw_output: if True, return the raw output of the network, otherwise return the softmax
+ output. Defaults to False
+ metric: "similarity" or "subtraction". Defaults to similarity
+
+ Returns:
+ The output of the forward function is a softmax of the correlation volume.
+ """
+ shape = list(x1.size())
+ shape[1] = (self.search_range * 2 + 1) ** 2
+ cv = torch.zeros(shape).to(torch.device("cuda"))
+
+ for i in range(-self.search_range, self.search_range + 1):
+ for j in range(-self.search_range, self.search_range + 1):
+ if i < 0:
+ slice_h, slice_h_r = slice(None, i), slice(-i, None)
+ elif i > 0:
+ slice_h, slice_h_r = slice(i, None), slice(None, -i)
+ else:
+ slice_h, slice_h_r = slice(None), slice(None)
+
+ if j < 0:
+ slice_w, slice_w_r = slice(None, j), slice(-j, None)
+ elif j > 0:
+ slice_w, slice_w_r = slice(j, None), slice(None, -j)
+ else:
+ slice_w, slice_w_r = slice(None), slice(None)
+
+ if metric == "similarity":
+ cv[:, (self.search_range * 2 + 1) * i + j, slice_h, slice_w] = (
+ x1[:, :, slice_h, slice_w] * x2[:, :, slice_h_r, slice_w_r]
+ ).sum(1)
+ else: # patchwise subtraction
+ cv[:, (self.search_range * 2 + 1) * i + j, slice_h, slice_w] = -(
+ (x1[:, :, slice_h, slice_w] - x2[:, :, slice_h_r, slice_w_r]) ** 2
+ ).sum(1)
+
+ # TODO sigmoid?
+ if raw_output:
+ return cv
+ else:
+ return nn.functional.softmax(cv / alpha, dim=1)
+
+
+class WTA_scale(torch.autograd.Function):
+ """
+ We can implement our own custom autograd Functions by subclassing
+ torch.autograd.Function and implementing the forward and backward passes
+ which operate on Tensors.
+ """
+
+ @staticmethod
+ def forward(ctx, input, scale=1e-4):
+ """
+ In the forward pass we receive a Tensor containing the input and return a
+ Tensor containing the output. You can cache arbitrary Tensors for use in the
+ backward pass using the save_for_backward method.
+ """
+ activation_max, index_max = torch.max(input, -1, keepdim=True)
+ input_scale = input * scale # default: 1e-4
+ output_max_scale = torch.where(input == activation_max, input, input_scale)
+
+ mask = (input == activation_max).type(torch.float)
+ ctx.save_for_backward(input, mask)
+ return output_max_scale
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ """
+ In the backward pass we receive a Tensor containing the gradient of the loss
+ with respect to the output, and we need to compute the gradient of the loss
+ with respect to the input.
+ """
+ input, mask = ctx.saved_tensors
+ mask_ones = torch.ones_like(mask)
+ mask_small_ones = torch.ones_like(mask) * 1e-4
+
+ grad_scale = torch.where(mask == 1, mask_ones, mask_small_ones)
+ grad_input = grad_output.clone() * grad_scale
+ return grad_input, None
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1):
+ super(ResidualBlock, self).__init__()
+ self.padding1 = nn.ReflectionPad2d(padding)
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=0, stride=stride)
+ self.bn1 = nn.InstanceNorm2d(out_channels)
+ self.prelu = nn.PReLU()
+ self.padding2 = nn.ReflectionPad2d(padding)
+ self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=0, stride=stride)
+ self.bn2 = nn.InstanceNorm2d(out_channels)
+
+ def forward(self, x):
+ residual = x
+ out = self.padding1(x)
+ out = self.conv1(out)
+ out = self.bn1(out)
+ out = self.prelu(out)
+ out = self.padding2(out)
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out += residual
+ out = self.prelu(out)
+ return out
+
+
+class WarpNet(nn.Module):
+    """Warps the reference (B) colors onto the target (A) via dense feature correspondence; inputs are the reference lab map and the relu2_1..relu5_1 feature maps of both images."""
+
+ def __init__(self, feature_channel=128):
+ super(WarpNet, self).__init__()
+ self.feature_channel = feature_channel
+ self.in_channels = self.feature_channel * 4
+ self.inter_channels = 256
+ # 44*44
+ self.layer2_1 = nn.Sequential(
+ nn.ReflectionPad2d(1),
+ # nn.Conv2d(128, 128, kernel_size=3, padding=0, stride=1),
+ # nn.Conv2d(96, 128, kernel_size=3, padding=20, stride=1),
+ nn.Conv2d(96, 128, kernel_size=3, padding=0, stride=1),
+ nn.InstanceNorm2d(128),
+ nn.PReLU(),
+ nn.ReflectionPad2d(1),
+ nn.Conv2d(128, self.feature_channel, kernel_size=3, padding=0, stride=2),
+ nn.InstanceNorm2d(self.feature_channel),
+ nn.PReLU(),
+ nn.Dropout(0.2),
+ )
+ self.layer3_1 = nn.Sequential(
+ nn.ReflectionPad2d(1),
+ # nn.Conv2d(256, 128, kernel_size=3, padding=0, stride=1),
+ # nn.Conv2d(192, 128, kernel_size=3, padding=10, stride=1),
+ nn.Conv2d(192, 128, kernel_size=3, padding=0, stride=1),
+ nn.InstanceNorm2d(128),
+ nn.PReLU(),
+ nn.ReflectionPad2d(1),
+ nn.Conv2d(128, self.feature_channel, kernel_size=3, padding=0, stride=1),
+ nn.InstanceNorm2d(self.feature_channel),
+ nn.PReLU(),
+ nn.Dropout(0.2),
+ )
+
+ # 22*22->44*44
+ self.layer4_1 = nn.Sequential(
+ nn.ReflectionPad2d(1),
+ # nn.Conv2d(512, 256, kernel_size=3, padding=0, stride=1),
+ # nn.Conv2d(384, 256, kernel_size=3, padding=5, stride=1),
+ nn.Conv2d(384, 256, kernel_size=3, padding=0, stride=1),
+ nn.InstanceNorm2d(256),
+ nn.PReLU(),
+ nn.ReflectionPad2d(1),
+ nn.Conv2d(256, self.feature_channel, kernel_size=3, padding=0, stride=1),
+ nn.InstanceNorm2d(self.feature_channel),
+ nn.PReLU(),
+ nn.Upsample(scale_factor=2),
+ nn.Dropout(0.2),
+ )
+
+ # 11*11->44*44
+ self.layer5_1 = nn.Sequential(
+ nn.ReflectionPad2d(1),
+ # nn.Conv2d(1024, 256, kernel_size=3, padding=0, stride=1),
+ # nn.Conv2d(768, 256, kernel_size=2, padding=2, stride=1),
+ nn.Conv2d(768, 256, kernel_size=3, padding=0, stride=1),
+ nn.InstanceNorm2d(256),
+ nn.PReLU(),
+ nn.Upsample(scale_factor=2),
+ nn.ReflectionPad2d(1),
+ nn.Conv2d(256, self.feature_channel, kernel_size=3, padding=0, stride=1),
+ nn.InstanceNorm2d(self.feature_channel),
+ nn.PReLU(),
+ nn.Upsample(scale_factor=2),
+ nn.Dropout(0.2),
+ )
+
+ self.layer = nn.Sequential(
+ ResidualBlock(self.feature_channel * 4, self.feature_channel * 4, kernel_size=3, padding=1, stride=1),
+ ResidualBlock(self.feature_channel * 4, self.feature_channel * 4, kernel_size=3, padding=1, stride=1),
+ ResidualBlock(self.feature_channel * 4, self.feature_channel * 4, kernel_size=3, padding=1, stride=1),
+ )
+
+ self.theta = nn.Conv2d(
+ in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)
+
+ self.upsampling = nn.Upsample(scale_factor=4)
+
+ def forward(
+ self,
+ B_lab_map,
+ A_relu2_1,
+ A_relu3_1,
+ A_relu4_1,
+ A_relu5_1,
+ B_relu2_1,
+ B_relu3_1,
+ B_relu4_1,
+ B_relu5_1,
+ temperature=0.001 * 5,
+ detach_flag=False,
+ WTA_scale_weight=1,
+ ):
+ batch_size = B_lab_map.shape[0]
+ channel = B_lab_map.shape[1]
+ image_height = B_lab_map.shape[2]
+ image_width = B_lab_map.shape[3]
+ feature_height = int(image_height / 4)
+ feature_width = int(image_width / 4)
+
+ # scale feature size to 44*44
+ A_feature2_1 = self.layer2_1(A_relu2_1)
+ B_feature2_1 = self.layer2_1(B_relu2_1)
+ A_feature3_1 = self.layer3_1(A_relu3_1)
+ B_feature3_1 = self.layer3_1(B_relu3_1)
+ A_feature4_1 = self.layer4_1(A_relu4_1)
+ B_feature4_1 = self.layer4_1(B_relu4_1)
+ A_feature5_1 = self.layer5_1(A_relu5_1)
+ B_feature5_1 = self.layer5_1(B_relu5_1)
+
+ # concatenate features
+ if A_feature5_1.shape[2] != A_feature2_1.shape[2] or A_feature5_1.shape[3] != A_feature2_1.shape[3]:
+ A_feature5_1 = F.pad(A_feature5_1, (0, 0, 1, 1), "replicate")
+ B_feature5_1 = F.pad(B_feature5_1, (0, 0, 1, 1), "replicate")
+
+ A_features = self.layer(torch.cat((A_feature2_1, A_feature3_1, A_feature4_1, A_feature5_1), 1))
+ B_features = self.layer(torch.cat((B_feature2_1, B_feature3_1, B_feature4_1, B_feature5_1), 1))
+
+ # pairwise cosine similarity
+ theta = self.theta(A_features).view(batch_size, self.inter_channels, -1) # 2*256*(feature_height*feature_width)
+ theta = theta - theta.mean(dim=-1, keepdim=True) # center the feature
+ theta_norm = torch.norm(theta, 2, 1, keepdim=True) + sys.float_info.epsilon
+ theta = torch.div(theta, theta_norm)
+ theta_permute = theta.permute(0, 2, 1) # 2*(feature_height*feature_width)*256
+ phi = self.phi(B_features).view(batch_size, self.inter_channels, -1) # 2*256*(feature_height*feature_width)
+ phi = phi - phi.mean(dim=-1, keepdim=True) # center the feature
+ phi_norm = torch.norm(phi, 2, 1, keepdim=True) + sys.float_info.epsilon
+ phi = torch.div(phi, phi_norm)
+ f = torch.matmul(theta_permute, phi) # 2*(feature_height*feature_width)*(feature_height*feature_width)
+ if detach_flag:
+ f = f.detach()
+
+ f_similarity = f.unsqueeze_(dim=1)
+ similarity_map = torch.max(f_similarity, -1, keepdim=True)[0]
+ similarity_map = similarity_map.view(batch_size, 1, feature_height, feature_width)
+
+ # f can be negative
+ f_WTA = f if WTA_scale_weight == 1 else WTA_scale.apply(f, WTA_scale_weight)
+ f_WTA = f_WTA / temperature
+ f_div_C = F.softmax(f_WTA.squeeze_(), dim=-1) # 2*1936*1936;
+
+ # downsample the reference color
+ B_lab = F.avg_pool2d(B_lab_map, 4)
+ B_lab = B_lab.view(batch_size, channel, -1)
+ B_lab = B_lab.permute(0, 2, 1) # 2*1936*channel
+
+ # multiply the corr map with color
+ y = torch.matmul(f_div_C, B_lab) # 2*1936*channel
+ y = y.permute(0, 2, 1).contiguous()
+ y = y.view(batch_size, channel, feature_height, feature_width) # 2*3*44*44
+ y = self.upsampling(y)
+ similarity_map = self.upsampling(similarity_map)
+
+ return y, similarity_map
\ No newline at end of file
diff --git a/src/models/CNN/__init__.py b/src/models/CNN/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/models/__init__.py b/src/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/models/vit/__init__.py b/src/models/vit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/models/vit/blocks.py b/src/models/vit/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa1634453ba76163a1d6b72123765933825049be
--- /dev/null
+++ b/src/models/vit/blocks.py
@@ -0,0 +1,80 @@
+import torch.nn as nn
+from timm.models.layers import DropPath
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, hidden_dim, dropout, out_dim=None):
+ super().__init__()
+ self.fc1 = nn.Linear(dim, hidden_dim)
+ self.act = nn.GELU()
+ if out_dim is None:
+ out_dim = dim
+ self.fc2 = nn.Linear(hidden_dim, out_dim)
+ self.drop = nn.Dropout(dropout)
+
+ @property
+ def unwrapped(self):
+ return self
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, heads, dropout):
+ super().__init__()
+ self.heads = heads
+ head_dim = dim // heads
+ self.scale = head_dim**-0.5
+ self.attn = None
+
+ self.qkv = nn.Linear(dim, dim * 3)
+ self.attn_drop = nn.Dropout(dropout)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(dropout)
+
+ @property
+ def unwrapped(self):
+ return self
+
+ def forward(self, x, mask=None):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
+ q, k, v = (
+ qkv[0],
+ qkv[1],
+ qkv[2],
+ )
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x, attn
+
+
+class Block(nn.Module):
+ def __init__(self, dim, heads, mlp_dim, dropout, drop_path):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.attn = Attention(dim, heads, dropout)
+ self.mlp = FeedForward(dim, mlp_dim, dropout)
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, x, mask=None, return_attention=False):
+ y, attn = self.attn(self.norm1(x), mask)
+ if return_attention:
+ return attn
+ x = x + self.drop_path(y)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
diff --git a/src/models/vit/config.py b/src/models/vit/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9af621c90b36c60d0e1bd0b399c5a01295c7343b
--- /dev/null
+++ b/src/models/vit/config.py
@@ -0,0 +1,22 @@
+import yaml
+from pathlib import Path
+
+import os
+
+
+def load_config():
+ return yaml.load(
+ open(Path(__file__).parent / "config.yml", "r"), Loader=yaml.FullLoader
+ )
+
+
+def check_os_environ(key, use):
+ if key not in os.environ:
+ raise ValueError(
+ f"{key} is not defined in the os variables, it is required for {use}."
+ )
+
+
+def dataset_dir():
+ check_os_environ("DATASET", "data loading")
+ return os.environ["DATASET"]
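+
+
+# Illustrative usage sketch (not part of the original module): load_config() returns the
+# parsed config.yml as a nested dict, so backbone hyper-parameters can be looked up by name.
+#
+#   cfg = load_config()
+#   vit_cfg = cfg["model"]["vit_tiny_patch16_384"]   # {'image_size': 384, 'patch_size': 16, ...}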
diff --git a/src/models/vit/config.yml b/src/models/vit/config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..79485c328c22ff0fb50d60f0d4c360e0801cc0bc
--- /dev/null
+++ b/src/models/vit/config.yml
@@ -0,0 +1,132 @@
+model:
+ # deit
+ deit_tiny_distilled_patch16_224:
+ image_size: 224
+ patch_size: 16
+ d_model: 192
+ n_heads: 3
+ n_layers: 12
+ normalization: deit
+ distilled: true
+ deit_small_distilled_patch16_224:
+ image_size: 224
+ patch_size: 16
+ d_model: 384
+ n_heads: 6
+ n_layers: 12
+ normalization: deit
+ distilled: true
+ deit_base_distilled_patch16_224:
+ image_size: 224
+ patch_size: 16
+ d_model: 768
+ n_heads: 12
+ n_layers: 12
+ normalization: deit
+ distilled: true
+ deit_base_distilled_patch16_384:
+ image_size: 384
+ patch_size: 16
+ d_model: 768
+ n_heads: 12
+ n_layers: 12
+ normalization: deit
+ distilled: true
+ # vit
+ vit_base_patch8_384:
+ image_size: 384
+ patch_size: 8
+ d_model: 768
+ n_heads: 12
+ n_layers: 12
+ normalization: vit
+ distilled: false
+ vit_tiny_patch16_384:
+ image_size: 384
+ patch_size: 16
+ d_model: 192
+ n_heads: 3
+ n_layers: 12
+ normalization: vit
+ distilled: false
+ vit_small_patch16_384:
+ image_size: 384
+ patch_size: 16
+ d_model: 384
+ n_heads: 6
+ n_layers: 12
+ normalization: vit
+ distilled: false
+ vit_base_patch16_384:
+ image_size: 384
+ patch_size: 16
+ d_model: 768
+ n_heads: 12
+ n_layers: 12
+ normalization: vit
+ distilled: false
+ vit_large_patch16_384:
+ image_size: 384
+ patch_size: 16
+ d_model: 1024
+ n_heads: 16
+ n_layers: 24
+ normalization: vit
+ vit_small_patch32_384:
+ image_size: 384
+ patch_size: 32
+ d_model: 384
+ n_heads: 6
+ n_layers: 12
+ normalization: vit
+ distilled: false
+ vit_base_patch32_384:
+ image_size: 384
+ patch_size: 32
+ d_model: 768
+ n_heads: 12
+ n_layers: 12
+ normalization: vit
+ vit_large_patch32_384:
+ image_size: 384
+ patch_size: 32
+ d_model: 1024
+ n_heads: 16
+ n_layers: 24
+ normalization: vit
+decoder:
+ linear: {}
+ deeplab_dec:
+ encoder_layer: -1
+ mask_transformer:
+ drop_path_rate: 0.0
+ dropout: 0.1
+ n_layers: 2
+dataset:
+ ade20k:
+ epochs: 64
+ eval_freq: 2
+ batch_size: 8
+ learning_rate: 0.001
+ im_size: 512
+ crop_size: 512
+ window_size: 512
+ window_stride: 512
+ pascal_context:
+ epochs: 256
+ eval_freq: 8
+ batch_size: 16
+ learning_rate: 0.001
+ im_size: 520
+ crop_size: 480
+ window_size: 480
+ window_stride: 320
+ cityscapes:
+ epochs: 216
+ eval_freq: 4
+ batch_size: 8
+ learning_rate: 0.01
+ im_size: 1024
+ crop_size: 768
+ window_size: 768
+ window_stride: 512
diff --git a/src/models/vit/decoder.py b/src/models/vit/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..acdb2f83660904423b97f9163bd81a4016dc8723
--- /dev/null
+++ b/src/models/vit/decoder.py
@@ -0,0 +1,34 @@
+import torch.nn as nn
+from einops import rearrange
+from src.models.vit.utils import init_weights
+
+
+class DecoderLinear(nn.Module):
+ def __init__(
+ self,
+ n_cls,
+ d_encoder,
+ scale_factor,
+ dropout_rate=0.3,
+ ):
+ super().__init__()
+ self.scale_factor = scale_factor
+ self.head = nn.Linear(d_encoder, n_cls)
+ self.upsampling = nn.Upsample(scale_factor=scale_factor**2, mode="linear")
+ self.norm = nn.LayerNorm((n_cls, 24 * scale_factor, 24 * scale_factor))
+ self.dropout = nn.Dropout(dropout_rate)
+ self.gelu = nn.GELU()
+ self.apply(init_weights)
+
+ def forward(self, x, img_size):
+ H, _ = img_size
+ x = self.head(x) ####### (2, 577, 64)
+ x = x.transpose(2, 1) ## (2, 64, 576)
+ x = self.upsampling(x) # (2, 64, 576*scale_factor*scale_factor)
+ x = x.transpose(2, 1) ## (2, 576*scale_factor*scale_factor, 64)
+ x = rearrange(x, "b (h w) c -> b c h w", h=H // (16 // self.scale_factor)) # (2, 64, 24*scale_factor, 24*scale_factor)
+ x = self.norm(x)
+ x = self.dropout(x)
+ x = self.gelu(x)
+
+ return x # (2, 64, a, a)
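+
+
+# Illustrative shape sketch (not part of the original module), assuming `import torch`:
+# with a 384x384 input, 576 patch tokens and scale_factor=8, the decoder produces a
+# 192x192 feature map.
+#
+#   dec = DecoderLinear(n_cls=128, d_encoder=192, scale_factor=8)
+#   feat = dec(torch.randn(2, 576, 192), img_size=(384, 384))   # (2, 128, 192, 192)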
diff --git a/src/models/vit/embed.py b/src/models/vit/embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..50fe91a1b6798f002b50e7253f42743763311908
--- /dev/null
+++ b/src/models/vit/embed.py
@@ -0,0 +1,52 @@
+from torch import nn
+from timm import create_model
+from torchvision.transforms import Normalize
+
+class SwinModel(nn.Module):
+ def __init__(self, pretrained_model="swinv2-cr-t-224", device="cuda") -> None:
+        """
+        Frozen timm backbone used as the feature extractor. `pretrained_model` selects the
+        checkpoint, e.g. 'swinv2-cr-t-224' -> 'swinv2_cr_tiny_ns_224.sw_in1k' and
+        'swinv2-t-256' -> 'swinv2_tiny_window16_256.ms_in1k'.
+        """
+ super().__init__()
+ self.device = device
+ self.pretrained_model = pretrained_model
+ if pretrained_model == "swinv2-cr-t-224":
+ self.pretrained = create_model(
+ "swinv2_cr_tiny_ns_224.sw_in1k",
+ pretrained=True,
+ features_only=True,
+ out_indices=[-4, -3, -2, -1],
+ ).to(device)
+ elif pretrained_model == "swinv2-t-256":
+ self.pretrained = create_model(
+ "swinv2_tiny_window16_256.ms_in1k",
+ pretrained=True,
+ features_only=True,
+ out_indices=[-4, -3, -2, -1],
+ ).to(device)
+ elif pretrained_model == "swinv2-cr-s-224":
+ self.pretrained = create_model(
+ "swinv2_cr_small_ns_224.sw_in1k",
+ pretrained=True,
+ features_only=True,
+ out_indices=[-4, -3, -2, -1],
+ ).to(device)
+ else:
+ raise NotImplementedError
+
+ self.pretrained.eval()
+ self.normalizer = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ self.upsample = nn.Upsample(scale_factor=2)
+
+ for params in self.pretrained.parameters():
+ params.requires_grad = False
+
+ def forward(self, x):
+ outputs = self.pretrained(x)
+ if self.pretrained_model in ["swinv2-t-256"]:
+ for i in range(len(outputs)):
+ outputs[i] = outputs[i].permute(0, 3, 1, 2) # Change channel-last to channel-first
+ outputs = [self.upsample(feat) for feat in outputs]
+
+ return outputs
\ No newline at end of file
diff --git a/src/models/vit/factory.py b/src/models/vit/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab2cad05744bf6ed60ee6278a5b79b92321ac4c5
--- /dev/null
+++ b/src/models/vit/factory.py
@@ -0,0 +1,45 @@
+import os
+import torch
+from timm.models.vision_transformer import default_cfgs
+from timm.models.helpers import load_pretrained, load_custom_pretrained
+from src.models.vit.utils import checkpoint_filter_fn
+from src.models.vit.vit import VisionTransformer
+
+
+def create_vit(model_cfg):
+ model_cfg = model_cfg.copy()
+ backbone = model_cfg.pop("backbone")
+
+ model_cfg.pop("normalization")
+ model_cfg["n_cls"] = 1000
+ mlp_expansion_ratio = 4
+ model_cfg["d_ff"] = mlp_expansion_ratio * model_cfg["d_model"]
+
+ if backbone in default_cfgs:
+ default_cfg = default_cfgs[backbone]
+ else:
+ default_cfg = dict(
+ pretrained=False,
+ num_classes=1000,
+ drop_rate=0.0,
+ drop_path_rate=0.0,
+ drop_block_rate=None,
+ )
+
+ default_cfg["input_size"] = (
+ 3,
+ model_cfg["image_size"][0],
+ model_cfg["image_size"][1],
+ )
+ model = VisionTransformer(**model_cfg)
+ if backbone == "vit_base_patch8_384":
+ path = os.path.expandvars("$TORCH_HOME/hub/checkpoints/vit_base_patch8_384.pth")
+ state_dict = torch.load(path, map_location="cpu")
+ filtered_dict = checkpoint_filter_fn(state_dict, model)
+ model.load_state_dict(filtered_dict, strict=True)
+ elif "deit" in backbone:
+ load_pretrained(model, default_cfg, filter_fn=checkpoint_filter_fn)
+ else:
+ load_custom_pretrained(model, default_cfg)
+
+ return model
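+
+
+# Illustrative usage sketch (not part of the original module): the expected model_cfg keys
+# mirror an entry of config.yml plus "backbone" and an "image_size" pair; the exact values
+# below are an assumption for demonstration only.
+#
+#   cfg = {"backbone": "vit_tiny_patch16_384", "image_size": (384, 384), "patch_size": 16,
+#          "d_model": 192, "n_heads": 3, "n_layers": 12, "normalization": "vit",
+#          "distilled": False, "dropout": 0.0, "drop_path_rate": 0.0}
+#   model = create_vit(cfg)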
diff --git a/src/models/vit/utils.py b/src/models/vit/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9aad57595ed5c7c2b84f436223ea9f8d8b961a2
--- /dev/null
+++ b/src/models/vit/utils.py
@@ -0,0 +1,71 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import trunc_normal_
+from collections import OrderedDict
+
+
+def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens):
+ # Rescale the grid of position embeddings when loading from state_dict. Adapted from
+ # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+ posemb_tok, posemb_grid = (
+ posemb[:, :num_extra_tokens],
+ posemb[0, num_extra_tokens:],
+ )
+ if grid_old_shape is None:
+ gs_old_h = int(math.sqrt(len(posemb_grid)))
+ gs_old_w = gs_old_h
+ else:
+ gs_old_h, gs_old_w = grid_old_shape
+
+ gs_h, gs_w = grid_new_shape
+ posemb_grid = posemb_grid.reshape(1, gs_old_h, gs_old_w, -1).permute(0, 3, 1, 2)
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+ return posemb
+
+
+def init_weights(m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+
+def checkpoint_filter_fn(state_dict, model):
+ """convert patch embedding weight from manual patchify + linear proj to conv"""
+ out_dict = {}
+ if "model" in state_dict:
+ # For deit models
+ state_dict = state_dict["model"]
+ num_extra_tokens = 1 + ("dist_token" in state_dict.keys())
+ patch_size = model.patch_size
+ image_size = model.patch_embed.image_size
+ for k, v in state_dict.items():
+ if k == "pos_embed" and v.shape != model.pos_embed.shape:
+ # To resize pos embedding when using model at different size from pretrained weights
+ v = resize_pos_embed(
+ v,
+ None,
+ (image_size[0] // patch_size, image_size[1] // patch_size),
+ num_extra_tokens,
+ )
+ out_dict[k] = v
+ return out_dict
+
+def load_params(ckpt_file, device):
+    """Load a checkpoint file and return its parameters as an OrderedDict mapped to the given device."""
+    params = torch.load(ckpt_file, map_location=device)
+    return OrderedDict(params.items())
diff --git a/src/models/vit/vit.py b/src/models/vit/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7d868d3ef239283195f492f29f5434ebe55b322
--- /dev/null
+++ b/src/models/vit/vit.py
@@ -0,0 +1,199 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.vision_transformer import _load_weights
+from timm.models.layers import trunc_normal_
+from typing import List
+
+from src.models.vit.utils import init_weights, resize_pos_embed
+from src.models.vit.blocks import Block
+from src.models.vit.decoder import DecoderLinear
+
+
+class PatchEmbedding(nn.Module):
+ def __init__(self, image_size, patch_size, embed_dim, channels):
+ super().__init__()
+
+ self.image_size = image_size
+ if image_size[0] % patch_size != 0 or image_size[1] % patch_size != 0:
+ raise ValueError("image dimensions must be divisible by the patch size")
+ self.grid_size = image_size[0] // patch_size, image_size[1] // patch_size
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.patch_size = patch_size
+
+ self.proj = nn.Conv2d(channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, im):
+ B, C, H, W = im.shape
+ x = self.proj(im).flatten(2).transpose(1, 2)
+ return x
+
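+# Illustrative shape sketch (not part of the original module): a 384x384 image with
+# patch_size 16 yields 24*24 = 576 patch tokens of dimension embed_dim.
+#
+#   pe = PatchEmbedding(image_size=(384, 384), patch_size=16, embed_dim=192, channels=3)
+#   tokens = pe(torch.randn(1, 3, 384, 384))   # (1, 576, 192)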
+
+class VisionTransformer(nn.Module):
+ def __init__(
+ self,
+ image_size,
+ patch_size,
+ n_layers,
+ d_model,
+ d_ff,
+ n_heads,
+ n_cls,
+ dropout=0.1,
+ drop_path_rate=0.0,
+ distilled=False,
+ channels=3,
+ ):
+ super().__init__()
+ self.patch_embed = PatchEmbedding(
+ image_size,
+ patch_size,
+ d_model,
+ channels,
+ )
+ self.patch_size = patch_size
+ self.n_layers = n_layers
+ self.d_model = d_model
+ self.d_ff = d_ff
+ self.n_heads = n_heads
+ self.dropout = nn.Dropout(dropout)
+ self.n_cls = n_cls
+
+ # cls and pos tokens
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
+ self.distilled = distilled
+ if self.distilled:
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
+ self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 2, d_model))
+ self.head_dist = nn.Linear(d_model, n_cls)
+ else:
+ self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 1, d_model))
+
+ # transformer blocks
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
+ self.blocks = nn.ModuleList([Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)])
+
+ # output head
+ self.norm = nn.LayerNorm(d_model)
+ self.head = nn.Linear(d_model, n_cls)
+
+ trunc_normal_(self.pos_embed, std=0.02)
+ trunc_normal_(self.cls_token, std=0.02)
+ if self.distilled:
+ trunc_normal_(self.dist_token, std=0.02)
+ self.pre_logits = nn.Identity()
+
+ self.apply(init_weights)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {"pos_embed", "cls_token", "dist_token"}
+
+ @torch.jit.ignore()
+ def load_pretrained(self, checkpoint_path, prefix=""):
+ _load_weights(self, checkpoint_path, prefix)
+
+ def forward(self, im, head_out_idx: List[int], n_dim_output=3, return_features=False):
+ B, _, H, W = im.shape
+ PS = self.patch_size
+ assert n_dim_output == 3 or n_dim_output == 4, "n_dim_output must be 3 or 4"
+ x = self.patch_embed(im)
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ if self.distilled:
+ dist_tokens = self.dist_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
+ else:
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ pos_embed = self.pos_embed
+ num_extra_tokens = 1 + self.distilled
+ if x.shape[1] != pos_embed.shape[1]:
+ pos_embed = resize_pos_embed(
+ pos_embed,
+ self.patch_embed.grid_size,
+ (H // PS, W // PS),
+ num_extra_tokens,
+ )
+ x = x + pos_embed
+ x = self.dropout(x)
+ device = x.device
+
+ if n_dim_output == 3:
+ heads_out = torch.zeros(size=(len(head_out_idx), B, (H // PS) ** 2 + 1, self.d_model)).to(device)
+ else:
+ heads_out = torch.zeros(size=(len(head_out_idx), B, self.d_model, H // PS, H // PS)).to(device)
+ self.register_buffer("heads_out", heads_out)
+
+ head_idx = 0
+ for idx_layer, blk in enumerate(self.blocks):
+ x = blk(x)
+ if idx_layer in head_out_idx:
+ if n_dim_output == 3:
+ heads_out[head_idx] = x
+ else:
+ heads_out[head_idx] = x[:, 1:, :].reshape((-1, 24, 24, self.d_model)).permute(0, 3, 1, 2)
+ head_idx += 1
+
+ x = self.norm(x)
+
+ if return_features:
+ return heads_out
+
+ if self.distilled:
+ x, x_dist = x[:, 0], x[:, 1]
+ x = self.head(x)
+ x_dist = self.head_dist(x_dist)
+ x = (x + x_dist) / 2
+ else:
+ x = x[:, 0]
+ x = self.head(x)
+ return x
+
+ def get_attention_map(self, im, layer_id):
+ if layer_id >= self.n_layers or layer_id < 0:
+ raise ValueError(f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}.")
+ B, _, H, W = im.shape
+ PS = self.patch_size
+
+ x = self.patch_embed(im)
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ if self.distilled:
+ dist_tokens = self.dist_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
+ else:
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ pos_embed = self.pos_embed
+ num_extra_tokens = 1 + self.distilled
+ if x.shape[1] != pos_embed.shape[1]:
+ pos_embed = resize_pos_embed(
+ pos_embed,
+ self.patch_embed.grid_size,
+ (H // PS, W // PS),
+ num_extra_tokens,
+ )
+ x = x + pos_embed
+
+ for i, blk in enumerate(self.blocks):
+ if i < layer_id:
+ x = blk(x)
+ else:
+ return blk(x, return_attention=True)
+
+
+class FeatureTransform(nn.Module):
+ def __init__(self, img_size, d_encoder, nls_list=[128, 256, 512, 512], scale_factor_list=[8, 4, 2, 1]):
+ super(FeatureTransform, self).__init__()
+ self.img_size = img_size
+
+ self.decoder_0 = DecoderLinear(n_cls=nls_list[0], d_encoder=d_encoder, scale_factor=scale_factor_list[0])
+ self.decoder_1 = DecoderLinear(n_cls=nls_list[1], d_encoder=d_encoder, scale_factor=scale_factor_list[1])
+ self.decoder_2 = DecoderLinear(n_cls=nls_list[2], d_encoder=d_encoder, scale_factor=scale_factor_list[2])
+ self.decoder_3 = DecoderLinear(n_cls=nls_list[3], d_encoder=d_encoder, scale_factor=scale_factor_list[3])
+
+ def forward(self, x_list):
+ feat_3 = self.decoder_3(x_list[3][:, 1:, :], self.img_size) # (2, 512, 24, 24)
+ feat_2 = self.decoder_2(x_list[2][:, 1:, :], self.img_size) # (2, 512, 48, 48)
+ feat_1 = self.decoder_1(x_list[1][:, 1:, :], self.img_size) # (2, 256, 96, 96)
+ feat_0 = self.decoder_0(x_list[0][:, 1:, :], self.img_size) # (2, 128, 192, 192)
+ return feat_0, feat_1, feat_2, feat_3
\ No newline at end of file
diff --git a/src/scheduler.py b/src/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..87a9a1c5fcbf2df2a9263d49a5d3f5ba87ccb48d
--- /dev/null
+++ b/src/scheduler.py
@@ -0,0 +1,40 @@
+from torch.optim.lr_scheduler import _LRScheduler
+
+class PolynomialLR(_LRScheduler):
+ def __init__(
+ self,
+ optimizer,
+ step_size,
+ iter_warmup,
+ iter_max,
+ power,
+ min_lr=0,
+ last_epoch=-1,
+ ):
+ self.step_size = step_size
+ self.iter_warmup = int(iter_warmup)
+ self.iter_max = int(iter_max)
+ self.power = power
+ self.min_lr = min_lr
+ super(PolynomialLR, self).__init__(optimizer, last_epoch)
+
+ def polynomial_decay(self, lr):
+ iter_cur = float(self.last_epoch)
+ if iter_cur < self.iter_warmup:
+ coef = iter_cur / self.iter_warmup
+ coef *= (1 - self.iter_warmup / self.iter_max) ** self.power
+ else:
+ coef = (1 - iter_cur / self.iter_max) ** self.power
+ return (lr - self.min_lr) * coef + self.min_lr
+
+ def get_lr(self):
+ if (
+ (self.last_epoch == 0)
+ or (self.last_epoch % self.step_size != 0)
+ or (self.last_epoch > self.iter_max)
+ ):
+ return [group["lr"] for group in self.optimizer.param_groups]
+ return [self.polynomial_decay(lr) for lr in self.base_lrs]
+
+ def step_update(self, num_updates):
+ self.step()
\ No newline at end of file
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7c76f96480aceb430ebca9f4d6655b2f0847a23
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,849 @@
+import sys
+import time
+import numpy as np
+from PIL import Image
+from skimage import color
+from skimage.transform import resize
+import src.data.functional as F
+import torch
+from torch import nn
+import torch.nn.functional as F_torch
+import torchvision.transforms.functional as F_torchvision
+from numba import cuda, jit
+import math
+import torchvision.utils as vutils
+from torch.autograd import Variable
+
+rgb_from_xyz = np.array(
+ [
+ [3.24048134, -0.96925495, 0.05564664],
+ [-1.53715152, 1.87599, -0.20404134],
+ [-0.49853633, 0.04155593, 1.05731107],
+ ]
+)
+l_norm, ab_norm = 1.0, 1.0
+l_mean, ab_mean = 50.0, 0
+
+
+class SquaredPadding:
+ def __init__(self, target_size=384, fill_value=0):
+ self.target_size = target_size
+ self.fill_value = fill_value
+
+ def __call__(self, img, return_pil=True, return_paddings=False, dtype=np.uint8):
+ if not isinstance(img, np.ndarray):
+ img = np.array(img)
+ ndim = len(img.shape)
+ H, W = img.shape[:2]
+ if H > W:
+ H_new, W_new = self.target_size, int(W/H*self.target_size)
+ # Resize image
+ img = resize(img, (H_new, W_new), preserve_range=True).astype(dtype)
+
+ # Padding image
+ padded_size = H_new - W_new
+ if ndim == 3:
+ paddings = [(0, 0), (padded_size // 2, (padded_size // 2) + (padded_size % 2)), (0,0)]
+ elif ndim == 2:
+ paddings = [(0, 0), (padded_size // 2, (padded_size // 2) + (padded_size % 2))]
+ padded_img = np.pad(img, paddings, mode='constant', constant_values=self.fill_value)
+ else:
+ H_new, W_new = int(H/W*self.target_size), self.target_size
+ # Resize image
+ img = resize(img, (H_new, W_new), preserve_range=True).astype(dtype)
+
+ # Padding image
+ padded_size = W_new - H_new
+ if ndim == 3:
+ paddings = [(padded_size // 2, (padded_size // 2) + (padded_size % 2)), (0, 0), (0,0)]
+ elif ndim == 2:
+ paddings = [(padded_size // 2, (padded_size // 2) + (padded_size % 2)), (0, 0)]
+ padded_img = np.pad(img, paddings, mode='constant', constant_values=self.fill_value)
+
+ if return_pil:
+ padded_img = Image.fromarray(padded_img)
+
+ if return_paddings:
+ return padded_img, paddings
+
+ return padded_img
+
+class UnpaddingSquare():
+ def __call__(self, img, paddings):
+ if not isinstance(img, np.ndarray):
+ img = np.array(img)
+
+ H, W = img.shape[0], img.shape[1]
+ (pad_top, pad_bottom), (pad_left, pad_right), _ = paddings
+ W_ori = W - pad_left - pad_right
+ H_ori = H - pad_top - pad_bottom
+
+ return img[pad_top:pad_top+H_ori, pad_left:pad_left+W_ori, :]
+
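+# Illustrative round trip (not called elsewhere in this file; "frame.png" is a
+# placeholder path): SquaredPadding resizes the longer side to `target_size`, pads the
+# shorter side symmetrically, and can return the paddings so that UnpaddingSquare can
+# crop them away again.
+#
+#   pad, unpad = SquaredPadding(target_size=384), UnpaddingSquare()
+#   padded, paddings = pad(Image.open("frame.png"), return_pil=False, return_paddings=True)
+#   restored = unpad(padded, paddings)   # back to the resized, unpadded shape
+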
+class UnpaddingSquare_Tensor():
+ def __call__(self, img, paddings):
+ H, W = img.shape[1], img.shape[2]
+ (pad_top, pad_bottom), (pad_left, pad_right), _ = paddings
+ W_ori = W - pad_left - pad_right
+ H_ori = H - pad_top - pad_bottom
+
+ return img[:, pad_top:pad_top+H_ori, pad_left:pad_left+W_ori]
+
+class ResizeFlow(object):
+    def __init__(self, target_size=(224, 224)):
+        self.target_size = target_size
+
+ def __call__(self, flow):
+ return F_torch.interpolate(flow.unsqueeze(0), self.target_size, mode='bilinear', align_corners=True).squeeze(0)
+
+class SquaredPaddingFlow(object):
+ def __init__(self, fill_value=0):
+ self.fill_value = fill_value
+
+ def __call__(self, flow):
+ H, W = flow.size(1), flow.size(2)
+
+ if H > W:
+ # Padding flow
+ padded_size = H - W
+ paddings = (padded_size // 2, (padded_size // 2) + (padded_size % 2), 0, 0)
+ padded_img = F_torch.pad(flow, paddings, value=self.fill_value)
+ else:
+ # Padding flow
+ padded_size = W - H
+ paddings = (0, 0, padded_size // 2, (padded_size // 2) + (padded_size % 2))
+ padded_img = F_torch.pad(flow, paddings, value=self.fill_value)
+
+ return padded_img
+
+
+def gray2rgb_batch(l):
+ # gray image tensor to rgb image tensor
+ l_uncenter = uncenter_l(l)
+ l_uncenter = l_uncenter / (2 * l_mean)
+ return torch.cat((l_uncenter, l_uncenter, l_uncenter), dim=1)
+
+def batch_lab2rgb_transpose_mc(img_l_mc, img_ab_mc, nrow=8):
+ if isinstance(img_l_mc, Variable):
+ img_l_mc = img_l_mc.data.cpu()
+ if isinstance(img_ab_mc, Variable):
+ img_ab_mc = img_ab_mc.data.cpu()
+
+ if img_l_mc.is_cuda:
+ img_l_mc = img_l_mc.cpu()
+ if img_ab_mc.is_cuda:
+ img_ab_mc = img_ab_mc.cpu()
+
+ assert img_l_mc.dim() == 4 and img_ab_mc.dim() == 4, "only for batch input"
+
+ img_l = img_l_mc * l_norm + l_mean
+ img_ab = img_ab_mc * ab_norm + ab_mean
+ pred_lab = torch.cat((img_l, img_ab), dim=1)
+ grid_lab = vutils.make_grid(pred_lab, nrow=nrow).numpy().astype("float64")
+ return (np.clip(color.lab2rgb(grid_lab.transpose((1, 2, 0))), 0, 1) * 255).astype("uint8")
+
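+# Illustrative call (tensor names are placeholders): given a (B, 1, H, W) L-channel
+# tensor and a (B, 2, H, W) ab tensor in the normalized space used above, the result is
+# a uint8 RGB grid suitable for previews or logging.
+#
+#   rgb_grid = batch_lab2rgb_transpose_mc(pred_l, pred_ab, nrow=4)
+#   Image.fromarray(rgb_grid).save("preview.png")   # placeholder output path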
+
+def vgg_preprocess(tensor):
+ # input is RGB tensor which ranges in [0,1]
+ # output is BGR tensor which ranges in [0,255]
+ tensor_bgr = torch.cat((tensor[:, 2:3, :, :], tensor[:, 1:2, :, :], tensor[:, 0:1, :, :]), dim=1)
+ tensor_bgr_ml = tensor_bgr - torch.Tensor([0.40760392, 0.45795686, 0.48501961]).type_as(tensor_bgr).view(1, 3, 1, 1)
+ return tensor_bgr_ml * 255
+
+
+def tensor_lab2rgb(input):
+ """
+ n * 3* h *w
+ """
+ input_trans = input.transpose(1, 2).transpose(2, 3) # n * h * w * 3
+ L, a, b = (
+ input_trans[:, :, :, 0:1],
+ input_trans[:, :, :, 1:2],
+ input_trans[:, :, :, 2:],
+ )
+ y = (L + 16.0) / 116.0
+ x = (a / 500.0) + y
+ z = y - (b / 200.0)
+
+ neg_mask = z.data < 0
+ z[neg_mask] = 0
+ xyz = torch.cat((x, y, z), dim=3)
+
+ mask = xyz.data > 0.2068966
+ mask_xyz = xyz.clone()
+ mask_xyz[mask] = torch.pow(xyz[mask], 3.0)
+ mask_xyz[~mask] = (xyz[~mask] - 16.0 / 116.0) / 7.787
+ mask_xyz[:, :, :, 0] = mask_xyz[:, :, :, 0] * 0.95047
+ mask_xyz[:, :, :, 2] = mask_xyz[:, :, :, 2] * 1.08883
+
+ rgb_trans = torch.mm(mask_xyz.view(-1, 3), torch.from_numpy(rgb_from_xyz).type_as(xyz)).view(
+ input.size(0), input.size(2), input.size(3), 3
+ )
+ rgb = rgb_trans.transpose(2, 3).transpose(1, 2)
+
+ mask = rgb > 0.0031308
+ mask_rgb = rgb.clone()
+ mask_rgb[mask] = 1.055 * torch.pow(rgb[mask], 1 / 2.4) - 0.055
+ mask_rgb[~mask] = rgb[~mask] * 12.92
+
+ neg_mask = mask_rgb.data < 0
+ large_mask = mask_rgb.data > 1
+ mask_rgb[neg_mask] = 0
+ mask_rgb[large_mask] = 1
+ return mask_rgb
+
+
+###### loss functions ######
+def feature_normalize(feature_in):
+ feature_in_norm = torch.norm(feature_in, 2, 1, keepdim=True) + sys.float_info.epsilon
+ feature_in_norm = torch.div(feature_in, feature_in_norm)
+ return feature_in_norm
+
+
+# denormalization for l
+def uncenter_l(l):
+ return l * l_norm + l_mean
+
+
+def get_grid(x):
+ torchHorizontal = torch.linspace(-1.0, 1.0, x.size(3)).view(1, 1, 1, x.size(3)).expand(x.size(0), 1, x.size(2), x.size(3))
+ torchVertical = torch.linspace(-1.0, 1.0, x.size(2)).view(1, 1, x.size(2), 1).expand(x.size(0), 1, x.size(2), x.size(3))
+
+ return torch.cat([torchHorizontal, torchVertical], 1)
+
+
+class WarpingLayer(nn.Module):
+ def __init__(self, device):
+ super(WarpingLayer, self).__init__()
+ self.device = device
+
+ def forward(self, x, flow):
+ """
+ It takes the input image and the flow and warps the input image according to the flow
+
+ Args:
+ x: the input image
+ flow: the flow tensor, which is a 4D tensor of shape (batch_size, 2, height, width)
+
+ Returns:
+ The warped image
+ """
+        # grid_sample expects a grid normalized to [-1, 1], but the flow is kept in
+        # pixel units (for the convenience of comparing EPEs with FlowNet2 and the
+        # original code), so the flow is normalized here before sampling
+        flow_for_grid = torch.zeros_like(flow).to(self.device)
+        flow_for_grid[:, 0, :, :] = flow[:, 0, :, :] / ((flow.size(3) - 1.0) / 2.0)
+        flow_for_grid[:, 1, :, :] = flow[:, 1, :, :] / ((flow.size(2) - 1.0) / 2.0)
+
+        grid = (get_grid(x).to(self.device) + flow_for_grid).permute(0, 2, 3, 1)
+        return F_torch.grid_sample(x, grid, align_corners=True)
+
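+# Worked example for the flow normalization above: with align_corners=True the sampling
+# grid spans [-1, 1] over W - 1 pixels horizontally, so a flow of u pixels becomes
+# u / ((W - 1) / 2) in grid units; e.g. W = 385 and u = 192 give 192 / 192 = 1.0,
+# i.e. a shift across half of the image width.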
+
+class CenterPad_threshold(object):
+ def __init__(self, image_size, threshold=3 / 4):
+ self.height = image_size[0]
+ self.width = image_size[1]
+ self.threshold = threshold
+
+ def __call__(self, image):
+ # pad the image to 16:9
+ # pad height
+ I = np.array(image)
+
+ # for padded input
+ height_old = np.size(I, 0)
+ width_old = np.size(I, 1)
+ old_size = [height_old, width_old]
+ height = self.height
+ width = self.width
+ I_pad = np.zeros((height, width, np.size(I, 2)))
+
+ ratio = height / width
+
+ if height_old / width_old == ratio:
+ if height_old == height:
+ return Image.fromarray(I.astype(np.uint8))
+ new_size = [int(x * height / height_old) for x in old_size]
+ I_resize = resize(I, new_size, mode="reflect", preserve_range=True, clip=False, anti_aliasing=True)
+ return Image.fromarray(I_resize.astype(np.uint8))
+
+ if height_old / width_old > self.threshold:
+ width_new, height_new = width_old, int(width_old * self.threshold)
+ height_margin = height_old - height_new
+ height_crop_start = height_margin // 2
+ I_crop = I[height_crop_start : (height_crop_start + height_new), :, :]
+ I_resize = resize(I_crop, [height, width], mode="reflect", preserve_range=True, clip=False, anti_aliasing=True)
+
+ return Image.fromarray(I_resize.astype(np.uint8))
+
+ if height_old / width_old > ratio: # pad the width and crop
+ new_size = [int(x * width / width_old) for x in old_size]
+ I_resize = resize(I, new_size, mode="reflect", preserve_range=True, clip=False, anti_aliasing=True)
+ width_resize = np.size(I_resize, 1)
+ height_resize = np.size(I_resize, 0)
+ start_height = (height_resize - height) // 2
+ I_pad[:, :, :] = I_resize[start_height : (start_height + height), :, :]
+ else: # pad the height and crop
+ new_size = [int(x * height / height_old) for x in old_size]
+ I_resize = resize(I, new_size, mode="reflect", preserve_range=True, clip=False, anti_aliasing=True)
+ width_resize = np.size(I_resize, 1)
+ height_resize = np.size(I_resize, 0)
+ start_width = (width_resize - width) // 2
+ I_pad[:, :, :] = I_resize[:, start_width : (start_width + width), :]
+
+ return Image.fromarray(I_pad.astype(np.uint8))
+
+
+class Normalize(object):
+ def __init__(self):
+ pass
+
+ def __call__(self, inputs):
+ inputs[0:1, :, :] = F.normalize(inputs[0:1, :, :], 50, 1)
+ inputs[1:3, :, :] = F.normalize(inputs[1:3, :, :], (0, 0), (1, 1))
+ return inputs
+
+
+class RGB2Lab(object):
+ def __init__(self):
+ pass
+
+ def __call__(self, inputs):
+ return color.rgb2lab(inputs)
+
+
+class ToTensor(object):
+ def __init__(self):
+ pass
+
+ def __call__(self, inputs):
+ return F.to_mytensor(inputs)
+
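+# Illustrative composition (the ordering is inferred from the transforms above, and
+# "frame.png" is a placeholder path): RGB array -> Lab array -> tensor -> channel-wise
+# normalization with the L mean of 50.
+#
+#   from torchvision import transforms
+#   to_lab_tensor = transforms.Compose([RGB2Lab(), ToTensor(), Normalize()])
+#   lab = to_lab_tensor(np.array(Image.open("frame.png").convert("RGB")))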
+
+class CenterPad(object):
+ def __init__(self, image_size):
+ self.height = image_size[0]
+ self.width = image_size[1]
+
+ def __call__(self, image):
+ # pad the image to 16:9
+ # pad height
+ I = np.array(image)
+
+ # for padded input
+ height_old = np.size(I, 0)
+ width_old = np.size(I, 1)
+ old_size = [height_old, width_old]
+ height = self.height
+ width = self.width
+ I_pad = np.zeros((height, width, np.size(I, 2)))
+
+ ratio = height / width
+ if height_old / width_old == ratio:
+ if height_old == height:
+ return Image.fromarray(I.astype(np.uint8))
+ new_size = [int(x * height / height_old) for x in old_size]
+ I_resize = resize(I, new_size, mode="reflect", preserve_range=True, clip=False, anti_aliasing=True)
+ return Image.fromarray(I_resize.astype(np.uint8))
+
+ if height_old / width_old > ratio: # pad the width and crop
+ new_size = [int(x * width / width_old) for x in old_size]
+ I_resize = resize(I, new_size, mode="reflect", preserve_range=True, clip=False, anti_aliasing=True)
+ width_resize = np.size(I_resize, 1)
+ height_resize = np.size(I_resize, 0)
+ start_height = (height_resize - height) // 2
+ I_pad[:, :, :] = I_resize[start_height : (start_height + height), :, :]
+ else: # pad the height and crop
+ new_size = [int(x * height / height_old) for x in old_size]
+ I_resize = resize(I, new_size, mode="reflect", preserve_range=True, clip=False, anti_aliasing=True)
+ width_resize = np.size(I_resize, 1)
+ height_resize = np.size(I_resize, 0)
+ start_width = (width_resize - width) // 2
+ I_pad[:, :, :] = I_resize[:, start_width : (start_width + width), :]
+
+ return Image.fromarray(I_pad.astype(np.uint8))
+
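+# Worked example for CenterPad (sizes are illustrative): with CenterPad((216, 384)) the
+# target ratio is 216 / 384 = 0.5625; a 480x640 input (ratio 0.75) takes the
+# "pad the width and crop" branch, is resized to 288x384 to match the target width,
+# and its height is then center-cropped from 288 to 216 rows.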
+
+class CenterPadCrop_numpy(object):
+ """
+ pad the image according to the height
+ """
+
+ def __init__(self, image_size):
+ self.height = image_size[0]
+ self.width = image_size[1]
+
+ def __call__(self, image, threshold=3 / 4):
+ # pad the image to 16:9
+ # pad height
+ I = np.array(image)
+ # for padded input
+ height_old = np.size(I, 0)
+ width_old = np.size(I, 1)
+ old_size = [height_old, width_old]
+ height = self.height
+ width = self.width
+ padding_size = width
+ if image.ndim == 2:
+ I_pad = np.zeros((width, width))
+ else:
+ I_pad = np.zeros((width, width, I.shape[2]))
+
+ ratio = height / width
+ if height_old / width_old == ratio:
+ return I
+
+ # if height_old / width_old > threshold:
+ # width_new, height_new = width_old, int(width_old * threshold)
+ # height_margin = height_old - height_new
+ # height_crop_start = height_margin // 2
+ # I_crop = I[height_start : (height_start + height_new), :]
+ # I_resize = resize(
+ # I_crop, [height, width], mode="reflect", preserve_range=True, clip=False, anti_aliasing=True
+ # )
+ # return I_resize
+
+ if height_old / width_old > ratio: # pad the width and crop
+ new_size = [int(x * width / width_old) for x in old_size]
+ I_resize = resize(I, new_size, mode="reflect", preserve_range=True, clip=False, anti_aliasing=True)
+ width_resize = np.size(I_resize, 1)
+ height_resize = np.size(I_resize, 0)
+ start_height = (height_resize - height) // 2
+ start_height_block = (padding_size - height) // 2
+ if image.ndim == 2:
+ I_pad[start_height_block : (start_height_block + height), :] = I_resize[
+ start_height : (start_height + height), :
+ ]
+ else:
+ I_pad[start_height_block : (start_height_block + height), :, :] = I_resize[
+ start_height : (start_height + height), :, :
+ ]
+ else: # pad the height and crop
+ new_size = [int(x * height / height_old) for x in old_size]
+ I_resize = resize(I, new_size, mode="reflect", preserve_range=True, clip=False, anti_aliasing=True)
+ width_resize = np.size(I_resize, 1)
+ height_resize = np.size(I_resize, 0)
+ start_width = (width_resize - width) // 2
+ start_width_block = (padding_size - width) // 2
+ if image.ndim == 2:
+ I_pad[:, start_width_block : (start_width_block + width)] = I_resize[:, start_width : (start_width + width)]
+
+ else:
+ I_pad[:, start_width_block : (start_width_block + width), :] = I_resize[
+ :, start_width : (start_width + width), :
+ ]
+
+ crop_start_height = (I_pad.shape[0] - height) // 2
+ crop_start_width = (I_pad.shape[1] - width) // 2
+
+ if image.ndim == 2:
+ return I_pad[crop_start_height : (crop_start_height + height), crop_start_width : (crop_start_width + width)]
+ else:
+ return I_pad[crop_start_height : (crop_start_height + height), crop_start_width : (crop_start_width + width), :]
+
+
+@jit(nopython=True, nogil=True)
+def biInterpolation_cpu(distorted, i, j):
+    # keep the fractional position (i, j); only the corner indices are rounded down
+    i_int = np.uint16(i)
+    j_int = np.uint16(j)
+    Q11 = distorted[j_int, i_int]
+    Q12 = distorted[j_int, i_int + 1]
+    Q21 = distorted[j_int + 1, i_int]
+    Q22 = distorted[j_int + 1, i_int + 1]
+
+    return np.uint8(
+        Q11 * (i_int + 1 - i) * (j_int + 1 - j)
+        + Q12 * (i - i_int) * (j_int + 1 - j)
+        + Q21 * (i_int + 1 - i) * (j - j_int)
+        + Q22 * (i - i_int) * (j - j_int)
+    )
+
+@jit(nopython=True, nogil=True)
+def iterSearchShader_cpu(padu, padv, xr, yr, W, H, maxIter, precision):
+ # print('processing location', (xr, yr))
+ #
+ if abs(padu[yr, xr]) < precision and abs(padv[yr, xr]) < precision:
+ return xr, yr
+
+    # initialization step proposed in the paper; see the paper for details
+ if (xr + 1) <= (W - 1):
+ dif = padu[yr, xr + 1] - padu[yr, xr]
+ else:
+ dif = padu[yr, xr] - padu[yr, xr - 1]
+ u_next = padu[yr, xr] / (1 + dif)
+ if (yr + 1) <= (H - 1):
+ dif = padv[yr + 1, xr] - padv[yr, xr]
+ else:
+ dif = padv[yr, xr] - padv[yr - 1, xr]
+ v_next = padv[yr, xr] / (1 + dif)
+ i = xr - u_next
+ j = yr - v_next
+ i_int = int(i)
+ j_int = int(j)
+
+ # The same as traditional iterative search method
+ for _ in range(maxIter):
+ if not 0 <= i <= (W - 1) or not 0 <= j <= (H - 1):
+ return i, j
+
+ u11 = padu[j_int, i_int]
+ v11 = padv[j_int, i_int]
+
+ u12 = padu[j_int, i_int + 1]
+ v12 = padv[j_int, i_int + 1]
+
+ int1 = padu[j_int + 1, i_int]
+ v21 = padv[j_int + 1, i_int]
+
+ int2 = padu[j_int + 1, i_int + 1]
+ v22 = padv[j_int + 1, i_int + 1]
+
+ u = (
+ u11 * (i_int + 1 - i) * (j_int + 1 - j)
+ + u12 * (i - i_int) * (j_int + 1 - j)
+ + int1 * (i_int + 1 - i) * (j - j_int)
+ + int2 * (i - i_int) * (j - j_int)
+ )
+
+ v = (
+ v11 * (i_int + 1 - i) * (j_int + 1 - j)
+ + v12 * (i - i_int) * (j_int + 1 - j)
+ + v21 * (i_int + 1 - i) * (j - j_int)
+ + v22 * (i - i_int) * (j - j_int)
+ )
+
+ i_next = xr - u
+ j_next = yr - v
+
+ if abs(i - i_next) < precision and abs(j - j_next) < precision:
+ return i, j
+
+ i = i_next
+ j = j_next
+
+ # if the search doesn't converge within max iter, it will return the last iter result
+ return i_next, j_next
+
+@jit(nopython=True, nogil=True)
+def iterSearch_cpu(distortImg, resultImg, padu, padv, W, H, maxIter=5, precision=1e-2):
+ for xr in range(W):
+ for yr in range(H):
+ # (xr, yr) is the point in result image, (i, j) is the search result in distorted image
+ i, j = iterSearchShader_cpu(padu, padv, xr, yr, W, H, maxIter, precision)
+
+ # reflect the pixels outside the border
+ if i > W - 1:
+ i = 2 * W - 1 - i
+ if i < 0:
+ i = -i
+ if j > H - 1:
+ j = 2 * H - 1 - j
+ if j < 0:
+ j = -j
+
+ # Bilinear interpolation to get the pixel at (i, j) in distorted image
+ resultImg[yr, xr, 0] = biInterpolation_cpu(
+ distortImg[:, :, 0],
+ i,
+ j,
+ )
+ resultImg[yr, xr, 1] = biInterpolation_cpu(
+ distortImg[:, :, 1],
+ i,
+ j,
+ )
+ resultImg[yr, xr, 2] = biInterpolation_cpu(
+ distortImg[:, :, 2],
+ i,
+ j,
+ )
+ return None
+
+
+def forward_mapping_cpu(source_image, u, v, maxIter=5, precision=1e-2):
+ """
+ warp the image according to the forward flow
+ u: horizontal
+ v: vertical
+ """
+ H = source_image.shape[0]
+ W = source_image.shape[1]
+
+ distortImg = np.array(np.zeros((H + 1, W + 1, 3)), dtype=np.uint8)
+ distortImg[0:H, 0:W] = source_image[0:H, 0:W]
+ distortImg[H, 0:W] = source_image[H - 1, 0:W]
+ distortImg[0:H, W] = source_image[0:H, W - 1]
+ distortImg[H, W] = source_image[H - 1, W - 1]
+
+ padu = np.array(np.zeros((H + 1, W + 1)), dtype=np.float32)
+ padu[0:H, 0:W] = u[0:H, 0:W]
+ padu[H, 0:W] = u[H - 1, 0:W]
+ padu[0:H, W] = u[0:H, W - 1]
+ padu[H, W] = u[H - 1, W - 1]
+
+ padv = np.array(np.zeros((H + 1, W + 1)), dtype=np.float32)
+ padv[0:H, 0:W] = v[0:H, 0:W]
+ padv[H, 0:W] = v[H - 1, 0:W]
+ padv[0:H, W] = v[0:H, W - 1]
+ padv[H, W] = v[H - 1, W - 1]
+
+ resultImg = np.array(np.zeros((H, W, 3)), dtype=np.uint8)
+ iterSearch_cpu(distortImg, resultImg, padu, padv, W, H, maxIter, precision)
+ return resultImg
+
+class Distortion_with_flow_cpu(object):
+ """Elastic distortion"""
+
+ def __init__(self, maxIter=3, precision=1e-3):
+ self.maxIter = maxIter
+ self.precision = precision
+
+ def __call__(self, inputs, dx, dy):
+ inputs = np.array(inputs)
+ shape = inputs.shape[0], inputs.shape[1]
+ remap_image = forward_mapping_cpu(inputs, dy, dx, maxIter=self.maxIter, precision=self.precision)
+
+ return Image.fromarray(remap_image)
+
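+# Illustrative usage (array names are placeholders): `dx` and `dy` are the two H x W
+# float32 flow components and `image` is an RGB PIL image of the same size; the call
+# forward-warps the image on the CPU via the iterative search above.
+#
+#   distort = Distortion_with_flow_cpu()
+#   warped = distort(image, dx, dy)
+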
+@cuda.jit(device=True)
+def biInterpolation_gpu(distorted, i, j):
+    # keep the fractional position (i, j); only the corner indices are rounded down
+    i_int = int(i)
+    j_int = int(j)
+    Q11 = distorted[j_int, i_int]
+    Q12 = distorted[j_int, i_int + 1]
+    Q21 = distorted[j_int + 1, i_int]
+    Q22 = distorted[j_int + 1, i_int + 1]
+
+    return np.uint8(
+        Q11 * (i_int + 1 - i) * (j_int + 1 - j)
+        + Q12 * (i - i_int) * (j_int + 1 - j)
+        + Q21 * (i_int + 1 - i) * (j - j_int)
+        + Q22 * (i - i_int) * (j - j_int)
+    )
+
+@cuda.jit(device=True)
+def iterSearchShader_gpu(padu, padv, xr, yr, W, H, maxIter, precision):
+ # print('processing location', (xr, yr))
+ #
+ if abs(padu[yr, xr]) < precision and abs(padv[yr, xr]) < precision:
+ return xr, yr
+
+    # initialization step proposed in the paper; see the paper for details
+ if (xr + 1) <= (W - 1):
+ dif = padu[yr, xr + 1] - padu[yr, xr]
+ else:
+ dif = padu[yr, xr] - padu[yr, xr - 1]
+ u_next = padu[yr, xr] / (1 + dif)
+ if (yr + 1) <= (H - 1):
+ dif = padv[yr + 1, xr] - padv[yr, xr]
+ else:
+ dif = padv[yr, xr] - padv[yr - 1, xr]
+ v_next = padv[yr, xr] / (1 + dif)
+ i = xr - u_next
+ j = yr - v_next
+ i_int = int(i)
+ j_int = int(j)
+
+ # The same as traditional iterative search method
+ for _ in range(maxIter):
+ if not 0 <= i <= (W - 1) or not 0 <= j <= (H - 1):
+ return i, j
+
+ u11 = padu[j_int, i_int]
+ v11 = padv[j_int, i_int]
+
+ u12 = padu[j_int, i_int + 1]
+ v12 = padv[j_int, i_int + 1]
+
+ int1 = padu[j_int + 1, i_int]
+ v21 = padv[j_int + 1, i_int]
+
+ int2 = padu[j_int + 1, i_int + 1]
+ v22 = padv[j_int + 1, i_int + 1]
+
+ u = (
+ u11 * (i_int + 1 - i) * (j_int + 1 - j)
+ + u12 * (i - i_int) * (j_int + 1 - j)
+ + int1 * (i_int + 1 - i) * (j - j_int)
+ + int2 * (i - i_int) * (j - j_int)
+ )
+
+ v = (
+ v11 * (i_int + 1 - i) * (j_int + 1 - j)
+ + v12 * (i - i_int) * (j_int + 1 - j)
+ + v21 * (i_int + 1 - i) * (j - j_int)
+ + v22 * (i - i_int) * (j - j_int)
+ )
+
+ i_next = xr - u
+ j_next = yr - v
+
+ if abs(i - i_next) < precision and abs(j - j_next) < precision:
+ return i, j
+
+ i = i_next
+ j = j_next
+
+ # if the search doesn't converge within max iter, it will return the last iter result
+ return i_next, j_next
+
+@cuda.jit
+def iterSearch_gpu(distortImg, resultImg, padu, padv, W, H, maxIter=5, precision=1e-2):
+
+ start_x, start_y = cuda.grid(2)
+ stride_x, stride_y = cuda.gridsize(2)
+
+ for xr in range(start_x, W, stride_x):
+ for yr in range(start_y, H, stride_y):
+
+ i,j = iterSearchShader_gpu(padu, padv, xr, yr, W, H, maxIter, precision)
+
+ if i > W - 1:
+ i = 2 * W - 1 - i
+ if i < 0:
+ i = -i
+ if j > H - 1:
+ j = 2 * H - 1 - j
+ if j < 0:
+ j = -j
+
+ resultImg[yr, xr,0] = biInterpolation_gpu(distortImg[:,:,0], i, j)
+ resultImg[yr, xr,1] = biInterpolation_gpu(distortImg[:,:,1], i, j)
+ resultImg[yr, xr,2] = biInterpolation_gpu(distortImg[:,:,2], i, j)
+ return None
+
+def forward_mapping_gpu(source_image, u, v, maxIter=5, precision=1e-2):
+ """
+ warp the image according to the forward flow
+ u: horizontal
+ v: vertical
+ """
+ H = source_image.shape[0]
+ W = source_image.shape[1]
+
+ resultImg = np.array(np.zeros((H, W, 3)), dtype=np.uint8)
+
+ distortImg = np.array(np.zeros((H + 1, W + 1, 3)), dtype=np.uint8)
+ distortImg[0:H, 0:W] = source_image[0:H, 0:W]
+ distortImg[H, 0:W] = source_image[H - 1, 0:W]
+ distortImg[0:H, W] = source_image[0:H, W - 1]
+ distortImg[H, W] = source_image[H - 1, W - 1]
+
+ padu = np.array(np.zeros((H + 1, W + 1)), dtype=np.float32)
+ padu[0:H, 0:W] = u[0:H, 0:W]
+ padu[H, 0:W] = u[H - 1, 0:W]
+ padu[0:H, W] = u[0:H, W - 1]
+ padu[H, W] = u[H - 1, W - 1]
+
+ padv = np.array(np.zeros((H + 1, W + 1)), dtype=np.float32)
+ padv[0:H, 0:W] = v[0:H, 0:W]
+ padv[H, 0:W] = v[H - 1, 0:W]
+ padv[0:H, W] = v[0:H, W - 1]
+ padv[H, W] = v[H - 1, W - 1]
+
+ padu = cuda.to_device(padu)
+ padv = cuda.to_device(padv)
+ distortImg = cuda.to_device(distortImg)
+ resultImg = cuda.to_device(resultImg)
+
+ threadsperblock = (16, 16)
+ blockspergrid_x = math.ceil(W / threadsperblock[0])
+ blockspergrid_y = math.ceil(H / threadsperblock[1])
+ blockspergrid = (blockspergrid_x, blockspergrid_y)
+
+
+ iterSearch_gpu[blockspergrid, threadsperblock](distortImg, resultImg, padu, padv, W, H, maxIter, precision)
+ resultImg = resultImg.copy_to_host()
+ return resultImg
+
+class Distortion_with_flow_gpu(object):
+
+ def __init__(self, maxIter=3, precision=1e-3):
+ self.maxIter = maxIter
+ self.precision = precision
+
+ def __call__(self, inputs, dx, dy):
+ inputs = np.array(inputs)
+ shape = inputs.shape[0], inputs.shape[1]
+ remap_image = forward_mapping_gpu(inputs, dy, dx, maxIter=self.maxIter, precision=self.precision)
+
+ return Image.fromarray(remap_image)
+
+def read_flow(filename):
+ """
+ read optical flow from Middlebury .flo file
+ :param filename: name of the flow file
+ :return: optical flow data in matrix
+ """
+    f = open(filename, "rb")
+    magic = np.fromfile(f, np.float32, count=1)[0]
+    data2d = None
+    if magic != 202021.25 and magic != 123.25:
+        f.close()
+        raise ValueError("Magic number incorrect. Invalid .flo file: " + filename)
+    elif magic == 123.25:
+        w = np.fromfile(f, np.int32, count=1)[0]
+        h = np.fromfile(f, np.int32, count=1)[0]
+        # half-precision variant of the format
+        data2d = np.fromfile(f, np.float16, count=2 * w * h)
+        # reshape data into 3D array (rows, columns, channels)
+        data2d = np.resize(data2d, (h, w, 2))
+    else:  # magic == 202021.25, the standard Middlebury format
+        w = np.fromfile(f, np.int32, count=1)[0]
+        h = np.fromfile(f, np.int32, count=1)[0]
+        data2d = np.fromfile(f, np.float32, count=2 * w * h)
+        # reshape data into 3D array (rows, columns, channels)
+        data2d = np.resize(data2d, (h, w, 2))
+    f.close()
+    return data2d.astype(np.float32)
+
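+# Layout handled by read_flow: a float32 magic number (202021.25, the bytes "PIEH"),
+# then int32 width and height, then width * height * 2 interleaved (u, v) samples
+# stored row by row; the 123.25 magic marks the float16 variant handled above.
+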
+class LossHandler:
+ def __init__(self):
+ self.loss_dict = {}
+ self.count_sample = 0
+
+ def add_loss(self, key, loss):
+ if key not in self.loss_dict:
+ self.loss_dict[key] = 0
+ self.loss_dict[key] += loss
+
+ def get_loss(self, key):
+ return self.loss_dict[key] / self.count_sample
+
+ def count_one_sample(self):
+ self.count_sample += 1
+
+ def reset(self):
+ self.loss_dict = {}
+ self.count_sample = 0
+
+
+class TimeHandler:
+ def __init__(self):
+ self.time_handler = {}
+
+ def compute_time(self, key):
+ if key not in self.time_handler:
+ self.time_handler[key] = time.time()
+ return None
+ else:
+ return time.time() - self.time_handler.pop(key)
+
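+# Illustrative use of the two trackers above (the loop body is a placeholder):
+#
+#   loss_handler, time_handler = LossHandler(), TimeHandler()
+#   for batch in loader:
+#       time_handler.compute_time("batch")              # first call starts the timer
+#       loss_handler.add_loss("total", loss.item())
+#       loss_handler.count_one_sample()
+#       elapsed = time_handler.compute_time("batch")    # second call returns seconds
+#   mean_total = loss_handler.get_loss("total")
+#   loss_handler.reset()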
+
+def print_num_params(model, is_trainable=False):
+ model_name = model.__class__.__name__.ljust(30)
+
+ if is_trainable:
+ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print(f"| TRAINABLE | {model_name} | {('{:,}'.format(num_params)).rjust(10)} |")
+ else:
+ num_params = sum(p.numel() for p in model.parameters())
+ print(f"| GENERAL | {model_name} | {('{:,}'.format(num_params)).rjust(10)} |")
+
+ return num_params
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/convert_folder_to_video.py b/utils/convert_folder_to_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b0e79c2197e56028de5079cc0f59775a67dd65f
--- /dev/null
+++ b/utils/convert_folder_to_video.py
@@ -0,0 +1,29 @@
+import cv2
+import os
+import argparse
+from tqdm import tqdm
+
+def convert_frames_to_video(input_folder_path, output_path, fps=24):
+ list_frames = sorted(os.listdir(input_folder_path))
+ first_frame = cv2.imread(os.path.join(input_folder_path, list_frames[0]))
+ height, width, _ = first_frame.shape
+
+ # Create a VideoWriter object
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Use appropriate codec based on the file extension
+ video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+ for frame_file in tqdm(list_frames):
+ frame_path = os.path.join(input_folder_path, frame_file)
+ frame = cv2.imread(frame_path)
+ video_writer.write(frame)
+
+ video_writer.release()
+
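+# Example invocation (paths are placeholders):
+#   python utils/convert_folder_to_video.py --input_folder_path ./frames --output_path ./out.mp4 --fps 24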
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input_folder_path', type=str, required=True)
+ parser.add_argument('--output_path', type=str, required=True)
+ parser.add_argument('--fps', type=int, default=24)
+ args = parser.parse_args()
+
+ convert_frames_to_video(args.input_folder_path, args.output_path, args.fps)
\ No newline at end of file