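"""Inference wrapper for SwinTExCo: exemplar-based image and video colorization.

SwinTExCo colorizes grayscale frames by propagating color from a user-supplied
reference image: a Swin transformer (SwinModel) embeds the reference, WarpNet
aligns its features with each input frame, and ColorVidNet predicts the ab
chrominance channels, which are then upsampled and recombined with the
original full-resolution luminance.
"""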
import os

import cv2
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from PIL import ImageEnhance as IE

from src.models.CNN.ColorVidNet import ColorVidNet
from src.models.CNN.FrameColor import frame_colorization
from src.models.CNN.NonlocalNet import WarpNet
from src.models.vit.embed import SwinModel
from src.models.vit.utils import load_params
from src.utils import (
    RGB2Lab,
    ToTensor,
    Normalize,
    uncenter_l,
    tensor_lab2rgb
)
class SwinTExCo:
    def __init__(self, weights_path, swin_backbone='swinv2-cr-t-224', device=None):
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        # Build the three sub-networks: a Swin encoder for semantic features,
        # a non-local warping network for reference-to-frame correspondence,
        # and the colorization CNN itself.
        self.embed_net = SwinModel(pretrained_model=swin_backbone, device=self.device).to(self.device)
        self.nonlocal_net = WarpNet(feature_channel=128).to(self.device)
        self.colornet = ColorVidNet(7).to(self.device)

        # Inference only: put every model in eval mode and load its weights.
        self.embed_net.eval()
        self.nonlocal_net.eval()
        self.colornet.eval()

        self.__load_models(self.embed_net, os.path.join(weights_path, "embed_net.pth"))
        self.__load_models(self.nonlocal_net, os.path.join(weights_path, "nonlocal_net.pth"))
        self.__load_models(self.colornet, os.path.join(weights_path, "colornet.pth"))

        # Shared preprocessing: resize to the backbone's 224x224 input,
        # convert RGB -> Lab, then to a normalized tensor.
        self.processor = T.Compose([
            T.Resize((224, 224)),
            RGB2Lab(),
            ToTensor(),
            Normalize()
        ])
    def __load_models(self, model, weight_path):
        params = load_params(weight_path, self.device)
        model.load_state_dict(params, strict=True)
    def __preprocess_reference(self, img):
        # Boost the reference's color saturation by 50% so the exemplar
        # provides stronger chrominance cues.
        color_enhancer = IE.Color(img)
        img = color_enhancer.enhance(1.5)
        return img
    def __upscale_image(self, large_IA_l, I_current_ab_predict):
        # Upsample the predicted ab channels to the original frame resolution,
        # then recombine them with the full-resolution luminance and convert to RGB.
        H, W = large_IA_l.shape[2:]
        large_current_ab_predict = torch.nn.functional.interpolate(
            I_current_ab_predict, size=(H, W), mode="bilinear", align_corners=False
        )
        large_IA_lab = torch.cat((large_IA_l, large_current_ab_predict), dim=1)
        large_current_rgb_predict = tensor_lab2rgb(large_IA_lab)
        return large_current_rgb_predict.cpu()
    def __process_sample(self, curr_frame, I_last_lab_predict, I_reference_lab, features_B):
        # Keep a full-resolution copy of the luminance channel for the final upscale.
        large_IA_lab = ToTensor()(RGB2Lab()(curr_frame)).unsqueeze(0)
        large_IA_l = large_IA_lab[:, 0:1, :, :].to(self.device)

        # Network-resolution Lab input; only the L channel is fed to the colorizer.
        IA_lab = self.processor(curr_frame)
        IA_lab = IA_lab.unsqueeze(0).to(self.device)
        IA_l = IA_lab[:, 0:1, :, :]

        # The first frame has no previous prediction to propagate.
        if I_last_lab_predict is None:
            I_last_lab_predict = torch.zeros_like(IA_lab).to(self.device)

        with torch.no_grad():
            I_current_ab_predict, _ = frame_colorization(
                IA_l,
                I_reference_lab,
                I_last_lab_predict,
                features_B,
                self.embed_net,
                self.nonlocal_net,
                self.colornet,
                luminance_noise=0,
                temperature=1e-10,
                joint_training=False
            )
        I_last_lab_predict = torch.cat((IA_l, I_current_ab_predict), dim=1)

        # Convert to a uint8 RGB array at the original resolution, laid out (C, H, W).
        IA_predict_rgb = self.__upscale_image(large_IA_l, I_current_ab_predict)
        IA_predict_rgb = (IA_predict_rgb.squeeze(0).cpu().numpy() * 255.)
        IA_predict_rgb = np.clip(IA_predict_rgb, 0, 255).astype(np.uint8)

        return I_last_lab_predict, IA_predict_rgb
    def predict_video(self, video, ref_image):
        # Encode the reference image once; its features condition every frame.
        ref_image = self.__preprocess_reference(ref_image)

        I_last_lab_predict = None

        IB_lab = self.processor(ref_image)
        IB_lab = IB_lab.unsqueeze(0).to(self.device)

        with torch.no_grad():
            I_reference_lab = IB_lab
            I_reference_l = I_reference_lab[:, 0:1, :, :]
            I_reference_ab = I_reference_lab[:, 1:3, :, :]
            I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(self.device)
            features_B = self.embed_net(I_reference_rgb)

        while video.isOpened():
            ret, curr_frame = video.read()
            if not ret:
                break
            # OpenCV reads BGR; the pipeline works in RGB.
            curr_frame = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2RGB)
            curr_frame = Image.fromarray(curr_frame)

            I_last_lab_predict, IA_predict_rgb = self.__process_sample(
                curr_frame, I_last_lab_predict, I_reference_lab, features_B
            )
            # (C, H, W) -> (H, W, C) for consumers expecting image arrays.
            IA_predict_rgb = IA_predict_rgb.transpose(1, 2, 0)

            yield IA_predict_rgb

        video.release()
    def predict_image(self, image, ref_image):
        ref_image = self.__preprocess_reference(ref_image)

        IB_lab = self.processor(ref_image)
        IB_lab = IB_lab.unsqueeze(0).to(self.device)

        with torch.no_grad():
            I_reference_lab = IB_lab
            I_reference_l = I_reference_lab[:, 0:1, :, :]
            I_reference_ab = I_reference_lab[:, 1:3, :, :]
            I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(self.device)
            features_B = self.embed_net(I_reference_rgb)

        # A single image is a one-frame video: there is no previous prediction.
        _, IA_predict_rgb = self.__process_sample(image, None, I_reference_lab, features_B)
        IA_predict_rgb = IA_predict_rgb.transpose(1, 2, 0)

        return IA_predict_rgb
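
# Example: single-image colorization (a minimal sketch; the checkpoint and
# image paths below are placeholders, not files known to ship with this repo):
#
#   model = SwinTExCo('checkpoints/epoch_20/')
#   image = Image.open('input.jpg').convert('RGB')
#   ref_image = Image.open('reference.jpg').convert('RGB')
#   colorized = model.predict_image(image, ref_image)  # (H, W, 3) uint8 RGB
#   Image.fromarray(colorized).save('colorized.png')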
if __name__ == "__main__":
    model = SwinTExCo('checkpoints/epoch_20/')

    # Initialize video reader and writer.
    video = cv2.VideoCapture('sample_input/video_2.mp4')
    fps = video.get(cv2.CAP_PROP_FPS)
    width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    os.makedirs('sample_output', exist_ok=True)  # VideoWriter does not create directories
    video_writer = cv2.VideoWriter('sample_output/video_2_ref_2.mp4',
                                   cv2.VideoWriter_fourcc(*'mp4v'),
                                   fps, (width, height))

    # Initialize reference image.
    ref_image = Image.open('sample_input/refs_2/ref2.jpg').convert('RGB')

    for colorized_frame in model.predict_video(video, ref_image):
        # predict_video yields RGB frames; convert back to BGR for OpenCV.
        video_writer.write(cv2.cvtColor(colorized_frame, cv2.COLOR_RGB2BGR))

    video_writer.release()