import gc
import os
from typing import List, Literal, Optional, Tuple

import imageio
import lightning as L
import torch
from jaxtyping import Float
from torchvision.transforms import Resize
from tqdm import tqdm

from src.loss import AbstractLoss
from src.loss.vector_swd import VectorSWDLoss
from src.utils.asc_cdl import asc_cdl_forward, save_asc_cdl
from src.utils.color_space import rgb_to_lab
from src.utils.image import from_torch, read_img, to_torch, write_img


class CDL(torch.nn.Module):
    """Per-image ASC CDL parameters (slope, offset, power, saturation),
    initialized to the identity grade."""

    def __init__(self, batch_size: int):
        super().__init__()
        self.cdl_slope = torch.nn.Parameter(torch.ones(batch_size, 3))
        self.cdl_offset = torch.nn.Parameter(torch.zeros(batch_size, 3))
        self.cdl_power = torch.nn.Parameter(torch.ones(batch_size, 3))
        self.cdl_saturation = torch.nn.Parameter(torch.ones(batch_size))

    def forward(
        self, x: Float[torch.Tensor, "*B C H W"]
    ) -> Float[torch.Tensor, "*B C H W"]:
        return asc_cdl_forward(
            x, self.cdl_slope, self.cdl_offset, self.cdl_power, self.cdl_saturation
        )

    def to_cdl_xml(self) -> List[str]:
        # With a None path, save_asc_cdl is expected to return the serialized
        # XML string instead of writing a file.
        ret = []
        for b in range(self.cdl_slope.shape[0]):
            ret.append(
                save_asc_cdl(
                    {
                        "slope": self.cdl_slope[b],
                        "offset": self.cdl_offset[b],
                        "power": self.cdl_power[b],
                        "saturation": self.cdl_saturation[b],
                    },
                    None,
                )
            )
        return ret

    def save(self, path: str):
        for b in range(self.cdl_slope.shape[0]):
            save_asc_cdl(
                {
                    "slope": self.cdl_slope[b],
                    "offset": self.cdl_offset[b],
                    "power": self.cdl_power[b],
                    "saturation": self.cdl_saturation[b],
                },
                os.path.join(path, f"cdl_{b}.xml"),
            )


def train(
    criteria: AbstractLoss,
    source_img: Float[torch.Tensor, "B C H W"],
    target_img: Float[torch.Tensor, "B C H W"],
    num_steps: int,
    lr: float,
    match_resolution: int,
    silent: bool = False,
    write_video_animation_path: Optional[str] = None,
) -> Tuple[Float[torch.Tensor, "*B C H W"], CDL, List[float]]:
    criteria = criteria.cuda()
    # Fit at the (typically smaller) matching resolution; the fitted CDL is
    # applied at full resolution afterwards.
    source_max_res = Resize(match_resolution, antialias=True)(source_img)
    target_max_res = Resize(match_resolution, antialias=True)(target_img)
    # rgb_to_lab here operates on channels-last (B, H, W, C) tensors, so
    # permute in and back out; the loss consumes (B, C, H*W) CIELAB vectors.
    target_cielab = (
        rgb_to_lab(target_max_res.cuda().permute(0, 2, 3, 1))
        .permute(0, 3, 1, 2)
        .contiguous()
    )
    source_max_res = source_max_res.cuda()
    source_img = source_img.cuda()
    batch_size = source_img.shape[0]
    cdl = CDL(batch_size).cuda()
    optim = torch.optim.Adam(cdl.parameters(), lr=lr)
    losses = []
    for i in tqdm(range(num_steps), disable=silent):
        optim.zero_grad(set_to_none=True)
        cdl_source = cdl(source_max_res)
        source_cielab = (
            rgb_to_lab(cdl_source.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()
        )
        # Flatten to (B, C, H*W) so each pixel contributes a 3-vector in CIELAB.
        loss = criteria(
            source_cielab.view(source_cielab.shape[0], source_cielab.shape[1], -1),
            target_cielab.view(target_cielab.shape[0], target_cielab.shape[1], -1),
            i,
        )
        loss.backward()
        optim.step()
        losses.append(loss.item())
        if write_video_animation_path is not None:
            # Animation frames assume a single source image (the squeeze(0)
            # drops a batch dimension of 1); no_grad keeps the full-resolution
            # forward pass out of the autograd graph.
            with torch.no_grad():
                write_img(
                    os.path.join(write_video_animation_path, f"{i:05d}.jpg"),
                    from_torch(cdl(source_img).squeeze(0) * 2 - 1),
                )
    # Apply the fitted CDL at full resolution; no gradients are needed here,
    # and no_grad keeps the returned tensor safe to convert to an image.
    with torch.no_grad():
        source_full_res_cdl = cdl(source_img)
    gc.collect()
    torch.cuda.empty_cache()
    return source_full_res_cdl, cdl, losses
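
# Usage sketch for `train` (hypothetical tensors; requires a CUDA device; the
# VectorSWDLoss arguments mirror the defaults that `run` below passes):
#
#     criteria = VectorSWDLoss(
#         num_proj=64, distance="l1", use_ucv=False, use_lcv=False,
#         refresh_projections_every_n_steps=1, num_new_candidates=32,
#         sampling_mode="gaussian",
#     )
#     src = torch.rand(1, 3, 1080, 1920)
#     tgt = torch.rand(1, 3, 1080, 1920)
#     graded, cdl, losses = train(
#         criteria, src, tgt, num_steps=300, lr=0.01, match_resolution=512
#     )
#     print(cdl.to_cdl_xml()[0])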


def run(
    save_dir: str,
    source_img: List[str],
    target_img: List[str],
    matching_resolution: int,
    precision: Literal["32-true", "16-mixed"] = "16-mixed",  # currently unused
    num_projections: int = 64,
    lr: float = 0.01,
    steps: int = 300,
    use_ucv: bool = False,
    use_lcv: bool = False,
    distance: Literal["l1", "l2"] = "l1",
    refresh_projections_every_n_steps: int = 1,
    num_new_candidates: int = 32,
    sampling_mode: Literal["gaussian", "qmc"] = "gaussian",
    write_video: bool = False,
    **kwargs,
):
    # read_img yields images in [-1, 1] here; remap to [0, 1] for the fit
    # (the write_img calls below invert this with `img * 2 - 1`).
    source_imgs = torch.stack(
        [to_torch(read_img(s)) * 0.5 + 0.5 for s in source_img], dim=0
    )
    target_imgs = torch.stack(
        [to_torch(read_img(t)) * 0.5 + 0.5 for t in target_img], dim=0
    )
    criteria = VectorSWDLoss(
        num_proj=num_projections,
        distance=distance,
        use_ucv=use_ucv,
        use_lcv=use_lcv,
        refresh_projections_every_n_steps=refresh_projections_every_n_steps,
        num_new_candidates=num_new_candidates,
        sampling_mode=sampling_mode,
    )
    os.makedirs(save_dir, exist_ok=True)
    animation_dir = os.path.join(save_dir, "animation")
    if write_video:
        os.makedirs(animation_dir, exist_ok=True)
    source_full_res_cdl, cdl, losses = train(
        criteria,
        source_imgs,
        target_imgs,
        steps,
        lr,
        matching_resolution,
        write_video_animation_path=animation_dir if write_video else None,
    )
    cdl.save(save_dir)
    for i, img in enumerate(source_full_res_cdl):
        write_img(
            os.path.join(save_dir, f"color_matched_{i}.png"),
            from_torch(img * 2 - 1),
        )
    if write_video:
        # Collect the per-step frames and sort them into step order.
        image_files = [f for f in os.listdir(animation_dir) if f.endswith(".jpg")]
        image_files.sort(key=lambda x: int(x.split(".")[0]))
        # Encode the frames into a video.
        with imageio.get_writer(
            os.path.join(save_dir, "animation.mp4"), fps=30, codec="libx264"
        ) as writer:
            for image_file in image_files:
                image = imageio.imread(os.path.join(animation_dir, image_file))
                writer.append_data(image)
    return source_full_res_cdl, cdl, losses
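
# Example invocation (a sketch; the file paths are hypothetical):
#
#     run(
#         save_dir="out",
#         source_img=["shots/source.png"],
#         target_img=["refs/target.png"],
#         matching_resolution=512,
#         steps=300,
#         write_video=True,
#     )
#
# This writes color_matched_0.png, cdl_0.xml, and (with write_video=True) an
# animation.mp4 of the optimization into `out/`.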