import gc
import os
from typing import List, Literal, Optional, Tuple

import imageio
import lightning as L
import torch
from jaxtyping import Float
from torchvision.transforms import Resize
from tqdm import tqdm

from src.loss import AbstractLoss
from src.loss.vector_swd import VectorSWDLoss
from src.utils.asc_cdl import asc_cdl_forward, save_asc_cdl
from src.utils.color_space import rgb_to_lab
from src.utils.image import from_torch, read_img, to_torch, write_img
|
class CDL(torch.nn.Module):
    """Batched, learnable ASC CDL (Color Decision List) transform.

    Holds per-image slope/offset/power (one value per RGB channel) and a
    per-image saturation scalar, all initialized to the identity transform
    (slope=1, offset=0, power=1, saturation=1) so optimization starts from
    a no-op color grade.
    """

    def __init__(self, batch_size: int):
        """Create identity CDL parameters for a batch of `batch_size` images."""
        super().__init__()
        # Identity initialization: applying the CDL changes nothing until
        # the optimizer moves the parameters.
        self.cdl_slope = torch.nn.Parameter(torch.ones(batch_size, 3))
        self.cdl_offset = torch.nn.Parameter(torch.zeros(batch_size, 3))
        self.cdl_power = torch.nn.Parameter(torch.ones(batch_size, 3))
        self.cdl_saturation = torch.nn.Parameter(torch.ones(batch_size))

    def forward(
        self, x: Float[torch.Tensor, "*B C H W"]
    ) -> Float[torch.Tensor, "*B C H W"]:
        """Apply the batched CDL transform to `x` via `asc_cdl_forward`."""
        return asc_cdl_forward(
            x, self.cdl_slope, self.cdl_offset, self.cdl_power, self.cdl_saturation
        )

    def _params_for(self, b: int) -> dict:
        """Collect the b-th image's CDL parameters into the dict layout
        expected by `save_asc_cdl` (shared by `to_cdl_xml` and `save`)."""
        return {
            "slope": self.cdl_slope[b],
            "offset": self.cdl_offset[b],
            "power": self.cdl_power[b],
            "saturation": self.cdl_saturation[b],
        }

    def to_cdl_xml(self) -> List[str]:
        """Serialize every batch entry to its CDL XML string.

        Returns one string per batch entry. (Fixed: the return annotation
        was `str`, but a list has always been returned.)
        """
        return [
            save_asc_cdl(self._params_for(b), None)
            for b in range(self.cdl_slope.shape[0])
        ]

    def save(self, path: str):
        """Write one `cdl_{b}.xml` file per batch entry into directory `path`.

        NOTE(review): `path` is assumed to already exist — callers (see `run`)
        create it with `os.makedirs` beforehand.
        """
        for b in range(self.cdl_slope.shape[0]):
            save_asc_cdl(self._params_for(b), os.path.join(path, f"cdl_{b}.xml"))
| |
|
| |
|
def train(
    criteria: AbstractLoss,
    source_img: Float[torch.Tensor, "B C H W"],
    target_img: Float[torch.Tensor, "B C H W"],
    num_steps: int,
    lr: float,
    match_resolution: int,
    silent: bool = False,
    write_video_animation_path: Optional[str] = None,
) -> Tuple[Float[torch.Tensor, "*B C H W"], CDL, List[float]]:
    """Fit per-image ASC CDL parameters so `source_img` color-matches `target_img`.

    Optimizes slope/offset/power/saturation with Adam against `criteria`,
    evaluated in CIELAB space at a reduced matching resolution, then applies
    the fitted CDL to the full-resolution source.

    NOTE(review): requires a CUDA device — the loss, images, and CDL module
    are all moved to `.cuda()` unconditionally.

    Args:
        criteria: loss callable invoked as `criteria(source_vecs, target_vecs, step)`.
        source_img: batch of images to be color-corrected, expected in [0, 1].
        target_img: batch of reference images defining the target look.
        num_steps: number of Adam optimization steps.
        lr: Adam learning rate.
        match_resolution: smaller-side resolution used while matching.
        silent: if True, suppress the tqdm progress bar.
        write_video_animation_path: if set, write one full-resolution JPEG
            per step into this directory (frames named "{step:05d}.jpg").

    Returns:
        Tuple of (full-resolution corrected source, fitted `CDL` module,
        per-step loss values).
    """
    criteria = criteria.cuda()

    # Match at reduced resolution; the full resolution is only used for the
    # final output (and the optional animation frames).
    source_max_res = Resize(match_resolution, antialias=True)(source_img)
    target_max_res = Resize(match_resolution, antialias=True)(target_img)

    # Fixed optimization target in CIELAB. The original code chained
    # .permute(0, 3, 1, 2).permute(0, 2, 3, 1) here, which composes to the
    # identity permutation and has been removed.
    target_cielab = rgb_to_lab(target_max_res).cuda().contiguous()

    source_max_res = source_max_res.cuda()
    source_img = source_img.cuda()

    batch_size = source_img.shape[0]
    cdl = CDL(batch_size).cuda()

    optim = torch.optim.Adam(cdl.parameters(), lr=lr)

    losses: List[float] = []
    for step in tqdm(range(num_steps), disable=silent):
        optim.zero_grad(set_to_none=True)

        cdl_source = cdl(source_max_res)
        source_cielab = (
            rgb_to_lab(cdl_source.permute(0, 3, 1, 2)).permute(0, 2, 3, 1).contiguous()
        )

        # Flatten the trailing dimensions so the loss sees (B, D, -1) vectors;
        # the step index lets the loss refresh its projections on schedule.
        loss = criteria(
            source_cielab.view(source_cielab.shape[0], source_cielab.shape[1], -1),
            target_cielab.view(target_cielab.shape[0], target_cielab.shape[1], -1),
            step,
        )

        loss.backward()
        optim.step()

        losses.append(loss.item())

        if write_video_animation_path is not None:
            # NOTE(review): squeeze(0) only removes the batch dim when B == 1;
            # animation output appears to assume a single image — confirm.
            write_img(
                os.path.join(write_video_animation_path, f"{step:05d}.jpg"),
                from_torch(cdl(source_img).squeeze(0) * 2 - 1),
            )

    source_full_res_cdl = cdl(source_img)

    # Release intermediate/optimizer buffers before handing back big tensors.
    gc.collect()
    torch.cuda.empty_cache()

    return source_full_res_cdl, cdl, losses
| |
|
| |
|
def run(
    save_dir: str,
    source_img: List[str],
    target_img: List[str],
    matching_resolution: int,
    precision: Literal["32-true", "16-mixed"] = "16-mixed",
    num_projections: int = 64,
    lr: float = 0.01,
    steps: int = 300,
    use_ucv: bool = False,
    use_lcv: bool = False,
    distance: Literal["l1", "l2"] = "l1",
    refresh_projections_every_n_steps: int = 1,
    num_new_candidates: int = 32,
    sampling_mode: Literal["gaussian", "qmc"] = "gaussian",
    write_video: bool = False,
    **kwargs,
):
    """End-to-end CDL color matching: load images, optimize, save results.

    Reads the source/target image lists, fits per-image CDL parameters with
    `train` using a sliced-Wasserstein loss, then writes the CDL XML files,
    the color-matched PNGs, and (optionally) an MP4 of the optimization.

    NOTE(review): `precision` and `**kwargs` are accepted but never used in
    this function's body — presumably kept for CLI/config compatibility.

    Args:
        save_dir: output directory (created if missing).
        source_img: paths of images to be color-corrected.
        target_img: paths of reference images.
        matching_resolution: smaller-side resolution used during matching.
        num_projections / distance / use_ucv / use_lcv /
        refresh_projections_every_n_steps / num_new_candidates /
        sampling_mode: forwarded to `VectorSWDLoss`.
        lr: Adam learning rate.
        steps: number of optimization steps.
        write_video: if True, record per-step frames and assemble an MP4.

    Returns:
        Tuple of (full-resolution corrected images, fitted `CDL`, loss list).
    """
    # `read_img`/`to_torch` yield values that "* 0.5 + 0.5" maps into [0, 1]
    # (inverse of the "* 2 - 1" applied before every write below).
    source_imgs = torch.stack(
        [to_torch(read_img(s)) * 0.5 + 0.5 for s in source_img], dim=0
    )
    target_imgs = torch.stack(
        [to_torch(read_img(t)) * 0.5 + 0.5 for t in target_img], dim=0
    )

    criteria = VectorSWDLoss(
        num_proj=num_projections,
        distance=distance,
        use_ucv=use_ucv,
        use_lcv=use_lcv,
        refresh_projections_every_n_steps=refresh_projections_every_n_steps,
        num_new_candidates=num_new_candidates,
        sampling_mode=sampling_mode,
    )

    os.makedirs(save_dir, exist_ok=True)
    animation_dir = os.path.join(save_dir, "animation")

    if write_video:
        os.makedirs(animation_dir, exist_ok=True)

    source_full_res_cdl, cdl, losses = train(
        criteria,
        source_imgs,
        target_imgs,
        steps,
        lr,
        matching_resolution,
        write_video_animation_path=animation_dir if write_video else None,
    )

    # One cdl_{b}.xml per batch entry.
    cdl.save(save_dir)

    for i, img in enumerate(source_full_res_cdl):
        write_img(
            os.path.join(save_dir, f"color_matched_{i}.png"),
            from_torch(img * 2 - 1),
        )

    if write_video:
        # Frames are named "{step:05d}.jpg"; sort numerically by step index.
        image_files = sorted(
            (f for f in os.listdir(animation_dir) if f.endswith(".jpg")),
            key=lambda name: int(name.split(".")[0]),
        )

        with imageio.get_writer(
            os.path.join(save_dir, "animation.mp4"), fps=30, codec="libx264"
        ) as writer:
            for image_file in image_files:
                frame = imageio.imread(os.path.join(animation_dir, image_file))
                writer.append_data(frame)

    return source_full_res_cdl, cdl, losses
| |
|