# ReSWD: src/color_matcher.py
# Author: mboss (commit 6d7ec2d, "Fixes")
import gc
import os
from typing import List, Literal, Optional, Tuple
import imageio
import lightning as L
import torch
from jaxtyping import Float
from torchvision.transforms import Resize
from tqdm import tqdm
from src.loss import AbstractLoss
from src.loss.vector_swd import VectorSWDLoss
from src.utils.asc_cdl import asc_cdl_forward, save_asc_cdl
from src.utils.color_space import rgb_to_lab
from src.utils.image import from_torch, read_img, to_torch, write_img
class CDL(torch.nn.Module):
    """Learnable ASC CDL grade with one independent parameter set per image.

    Parameters are initialised to slope 1, offset 0, power 1 and saturation 1,
    which under the standard ASC CDL formula is the identity grade.

    Args:
        batch_size: Number of independent grades, one per image in the batch.
    """

    def __init__(self, batch_size: int):
        super().__init__()
        # One RGB triple per batch element for slope / offset / power ...
        self.cdl_slope = torch.nn.Parameter(torch.ones(batch_size, 3))
        self.cdl_offset = torch.nn.Parameter(torch.zeros(batch_size, 3))
        self.cdl_power = torch.nn.Parameter(torch.ones(batch_size, 3))
        # ... and a single scalar saturation per batch element.
        self.cdl_saturation = torch.nn.Parameter(torch.ones(batch_size))

    def forward(
        self, x: Float[torch.Tensor, "*B C H W"]
    ) -> Float[torch.Tensor, "*B C H W"]:
        """Apply the current grade to ``x`` via ``asc_cdl_forward``."""
        return asc_cdl_forward(
            x, self.cdl_slope, self.cdl_offset, self.cdl_power, self.cdl_saturation
        )

    def _params(self, b: int) -> dict:
        """Return the CDL parameters of batch element ``b`` in the dict form ``save_asc_cdl`` takes."""
        return {
            "slope": self.cdl_slope[b],
            "offset": self.cdl_offset[b],
            "power": self.cdl_power[b],
            "saturation": self.cdl_saturation[b],
        }

    def to_cdl_xml(self) -> List[str]:
        """Serialise every per-image grade without writing to disk.

        Returns:
            One entry per batch element: whatever ``save_asc_cdl`` returns when
            called with a ``None`` path (presumably the CDL XML text; confirm
            against ``save_asc_cdl``).
        """
        return [
            save_asc_cdl(self._params(b), None)
            for b in range(self.cdl_slope.shape[0])
        ]

    def save(self, path: str):
        """Write one ``cdl_{b}.xml`` file per batch element into directory ``path``.

        The directory is expected to exist already; it is not created here.
        """
        for b in range(self.cdl_slope.shape[0]):
            save_asc_cdl(self._params(b), os.path.join(path, f"cdl_{b}.xml"))
def _to_cielab(rgb: torch.Tensor) -> torch.Tensor:
    """Convert a 4-D RGB batch with ``rgb_to_lab``, using the same layout shuffle for source and target.

    Both sides of the loss MUST go through this exact function so that they are
    compared in the same colour space and layout.
    """
    # NOTE(review): permute(0, 3, 1, 2) followed by permute(0, 2, 3, 1) restores
    # the input layout, so rgb_to_lab receives the tensor in the intermediate
    # (permuted) layout. Confirm this matches rgb_to_lab's expected channel axis.
    return rgb_to_lab(rgb.permute(0, 3, 1, 2)).permute(0, 2, 3, 1).contiguous()


def train(
    criteria: AbstractLoss,
    source_img: Float[torch.Tensor, "B C H W"],
    target_img: Float[torch.Tensor, "B C H W"],
    num_steps: int,
    lr: float,
    match_resolution: int,
    silent: bool = False,
    write_video_animation_path: Optional[str] = None,
) -> Tuple[Float[torch.Tensor, "*B C H W"], CDL, List[float]]:
    """Fit a per-image CDL grade so ``source_img`` matches the colours of ``target_img``.

    Optimisation runs on copies resized with ``Resize(match_resolution)``. The
    learned grade is then applied to the full-resolution source. All work runs
    on CUDA (tensors and modules are moved with ``.cuda()``).

    Args:
        criteria: Distribution loss, called as ``criteria(source, target, step)``
            with both colour tensors flattened to ``(dim0, dim1, -1)``.
        source_img: Batch of images to grade.
        target_img: Batch of reference images.
        num_steps: Number of Adam steps.
        lr: Adam learning rate.
        match_resolution: Size passed to ``torchvision.transforms.Resize`` for the
            optimisation copies.
        silent: Disable the tqdm progress bar.
        write_video_animation_path: If set, a preview frame ``{step:05d}.jpg`` of
            the graded full-resolution source is written there after every step.

    Returns:
        ``(graded full-resolution source, fitted CDL module, per-step loss values)``.
        The graded source is computed without autograd tracking.
    """
    criteria = criteria.cuda()
    resize = Resize(match_resolution, antialias=True)
    source_max_res = resize(source_img).cuda()
    target_max_res = resize(target_img)

    # The target never changes, so it is converted and flattened once, outside
    # the loop, with exactly the same conversion the source gets below.
    target_cielab = _to_cielab(target_max_res.cuda())
    target_flat = target_cielab.view(
        target_cielab.shape[0], target_cielab.shape[1], -1
    )

    source_img = source_img.cuda()
    cdl = CDL(source_img.shape[0]).cuda()
    optim = torch.optim.Adam(cdl.parameters(), lr=lr)
    losses: List[float] = []

    for i in tqdm(range(num_steps), disable=silent):
        optim.zero_grad(set_to_none=True)
        source_cielab = _to_cielab(cdl(source_max_res))
        loss = criteria(
            source_cielab.view(source_cielab.shape[0], source_cielab.shape[1], -1),
            target_flat,
            i,
        )
        loss.backward()
        optim.step()
        losses.append(loss.item())

        if write_video_animation_path is not None:
            # A preview frame needs no autograd graph. Skipping it avoids
            # building a full-resolution graph on every step.
            # NOTE(review): squeeze(0) only removes the batch axis when B == 1.
            with torch.no_grad():
                frame = cdl(source_img).squeeze(0)
            # Inverse of the `* 0.5 + 0.5` applied when the images are loaded in
            # `run`; presumably maps [0, 1] back to from_torch's [-1, 1] range.
            write_img(
                os.path.join(write_video_animation_path, f"{i:05d}.jpg"),
                from_torch(frame * 2 - 1),
            )

    # Final full-resolution grade. This is inference only, so no graph is recorded.
    with torch.no_grad():
        source_full_res_cdl = cdl(source_img)
    gc.collect()
    torch.cuda.empty_cache()
    return source_full_res_cdl, cdl, losses
def run(
    save_dir: str,
    source_img: List[str],
    target_img: List[str],
    matching_resolution: int,
    precision: Literal["32-true", "16-mixed"] = "16-mixed",
    num_projections: int = 64,
    lr: float = 0.01,
    steps: int = 300,
    use_ucv: bool = False,
    use_lcv: bool = False,
    distance: Literal["l1", "l2"] = "l1",
    refresh_projections_every_n_steps: int = 1,
    num_new_candidates: int = 32,
    sampling_mode: Literal["gaussian", "qmc"] = "gaussian",
    write_video: bool = False,
    **kwargs,
):
    """Colour-match source images to target images and save the results.

    Loads the images, fits a CDL grade with a ``VectorSWDLoss`` via ``train``,
    and writes the following into ``save_dir``:

    - ``cdl_{b}.xml`` per image, via ``CDL.save``;
    - ``color_matched_{i}.png`` per graded image;
    - ``animation.mp4``, when ``write_video`` is set, built from the
      per-step frames in ``save_dir/animation``.

    All source images must share one size, and all target images must share
    one size, because each list is stacked into a single batch.

    NOTE(review): ``precision`` and ``**kwargs`` are accepted but not used
    anywhere in this function.

    Returns:
        ``(graded full-resolution sources, fitted CDL module, per-step losses)``,
        as returned by ``train``.
    """
    # `* 0.5 + 0.5` presumably maps to_torch's [-1, 1] output to [0, 1].
    source_imgs = torch.stack(
        [to_torch(read_img(s)) * 0.5 + 0.5 for s in source_img], dim=0
    )
    target_imgs = torch.stack(
        [to_torch(read_img(t)) * 0.5 + 0.5 for t in target_img], dim=0
    )
    criteria = VectorSWDLoss(
        num_proj=num_projections,
        distance=distance,
        use_ucv=use_ucv,
        use_lcv=use_lcv,
        refresh_projections_every_n_steps=refresh_projections_every_n_steps,
        num_new_candidates=num_new_candidates,
        sampling_mode=sampling_mode,
    )

    os.makedirs(save_dir, exist_ok=True)
    animation_dir = os.path.join(save_dir, "animation")
    if write_video:
        os.makedirs(animation_dir, exist_ok=True)

    source_full_res_cdl, cdl, losses = train(
        criteria,
        source_imgs,
        target_imgs,
        steps,
        lr,
        matching_resolution,
        write_video_animation_path=animation_dir if write_video else None,
    )

    cdl.save(save_dir)
    for i, img in enumerate(source_full_res_cdl):
        write_img(
            os.path.join(save_dir, f"color_matched_{i}.png"),
            from_torch(img * 2 - 1),
        )

    if write_video:
        # Use exactly the frames this run wrote (00000.jpg ... {steps-1:05d}.jpg),
        # already in step order. Listing the directory would also pick up stale,
        # higher-numbered frames left behind by an earlier run with more steps.
        frame_paths = [
            os.path.join(animation_dir, f"{i:05d}.jpg") for i in range(steps)
        ]
        with imageio.get_writer(
            os.path.join(save_dir, "animation.mp4"), fps=30, codec="libx264"
        ) as writer:
            for frame_path in frame_paths:
                writer.append_data(imageio.imread(frame_path))

    return source_full_res_cdl, cdl, losses