"""Gradio demo that colorizes grayscale photos with the DDColor model.

This module bundles a self-contained re-implementation of the DDColor
architecture (ConvNeXt encoder + multi-scale color transformer decoder),
checkpoint loading from the Hugging Face Hub, and a small Gradio UI that
shows a before/after slider.
"""

import collections
import json
import math
import os

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

DEFAULT_REPO_ID = "piddnad/ddcolor_modelscope"

# Lazily-built singleton so the model is downloaded and initialized at most
# once per process, on the first request rather than at import time.
_COLORIZER_STATE = {
    "initialized": False,
    "pipeline": None,
}


def _resolve_device(device=None):
    """Return a ``torch.device``, defaulting to CUDA when available.

    Accepts ``None`` (auto-select), a device string, or an existing
    ``torch.device`` (returned unchanged).
    """
    if device is None:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if isinstance(device, str):
        return torch.device(device)
    return device


def _load_checkpoint_state_dict(model_path, map_location="cpu"):
    """Load a checkpoint file and unwrap common state-dict containers.

    Checkpoints may store the weights directly, or nest them under a
    ``"params"`` or ``"state_dict"`` key; all three layouts are handled.
    """
    checkpoint = torch.load(model_path, map_location=map_location)
    if isinstance(checkpoint, dict):
        if "params" in checkpoint:
            return checkpoint["params"]
        if "state_dict" in checkpoint:
            return checkpoint["state_dict"]
    return checkpoint


def _load_model_config(config_path):
    """Read the model's JSON configuration file into a dict."""
    with open(config_path, "r", encoding="utf-8") as handle:
        return json.load(handle)


class DropPath(nn.Module):
    """Stochastic depth: randomly zero whole samples' residual branches.

    Identity at inference time or when ``drop_prob`` is 0.
    """

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all other dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        # Divide by keep_prob so the expected activation is unchanged.
        return x.div(keep_prob) * random_tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """In-place truncated-normal init, with a fallback for old torch.

    Delegates to ``torch.nn.init.trunc_normal_`` when available; otherwise
    uses the inverse-CDF sampling trick on older PyTorch versions.
    """
    if hasattr(torch.nn.init, "trunc_normal_"):
        return torch.nn.init.trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)

    def norm_cdf(value):
        return (1.0 + math.erf(value / math.sqrt(2.0))) / 2.0

    with torch.no_grad():
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)
        # Sample uniformly in CDF space, then map back through erfinv.
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
        return tensor


class LayerNorm(nn.Module):
    """LayerNorm supporting channels_last (N,H,W,C) or channels_first (N,C,H,W)."""

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.data_format == "channels_first":
            # Manual normalization over the channel dim (dim 1).
            mean = x.mean(1, keepdim=True)
            variance = (x - mean).pow(2).mean(1, keepdim=True)
            x = (x - mean) / torch.sqrt(variance + self.eps)
            return self.weight[:, None, None] * x + self.bias[:, None, None]
        raise NotImplementedError(f"Unsupported data_format: {self.data_format}")


class ConvNeXtBlock(nn.Module):
    """ConvNeXt block: depthwise 7x7 conv -> LN -> MLP, with layer scale."""

    def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        # Learnable per-channel residual scaling (layer scale), if enabled.
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        residual = x
        x = self.dwconv(x)
        # Pointwise MLP runs in channels_last layout.
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)
        return residual + self.drop_path(x)


class ConvNeXt(nn.Module):
    """ConvNeXt backbone with per-stage channels_first norms exposed as
    ``norm0``..``norm3`` (used as hook points by :class:`ImageEncoder`)."""

    def __init__(
        self,
        in_chans=3,
        depths=(3, 3, 9, 3),
        dims=(96, 192, 384, 768),
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
    ):
        super().__init__()
        self.downsample_layers = nn.ModuleList()
        # Stem: 4x4 stride-4 patchify conv followed by a channels_first LN.
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        for index in range(3):
            self.downsample_layers.append(
                nn.Sequential(
                    LayerNorm(dims[index], eps=1e-6, data_format="channels_first"),
                    nn.Conv2d(dims[index], dims[index + 1], kernel_size=2, stride=2),
                )
            )

        self.stages = nn.ModuleList()
        # Linearly increasing stochastic-depth rates across all blocks.
        rates = [value.item() for value in torch.linspace(0, drop_path_rate, sum(depths))]
        cursor = 0
        for index in range(4):
            stage = nn.Sequential(
                *[
                    ConvNeXtBlock(
                        dim=dims[index],
                        drop_path=rates[cursor + inner],
                        layer_scale_init_value=layer_scale_init_value,
                    )
                    for inner in range(depths[index])
                ]
            )
            self.stages.append(stage)
            cursor += depths[index]

        # Per-stage norms registered by name so forward hooks can grab
        # normalized intermediate feature maps.
        for index in range(4):
            self.add_module(
                f"norm{index}",
                LayerNorm(dims[index], eps=1e-6, data_format="channels_first"),
            )
        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            trunc_normal_(module.weight, std=0.02)
            nn.init.constant_(module.bias, 0)

    def forward(self, x):
        for index in range(4):
            x = self.downsample_layers[index](x)
            x = self.stages[index](x)
            # Return value intentionally discarded: calling the norm module
            # triggers the forward hooks registered by ImageEncoder, which
            # record the normalized stage features for the decoder.
            getattr(self, f"norm{index}")(x)
        return self.norm(x.mean([-2, -1]))


class PositionEmbeddingSine(nn.Module):
    """Standard DETR-style 2D sinusoidal position embedding."""

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = scale if scale is not None else 2 * math.pi

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.zeros(
                (x.size(0), x.size(2), x.size(3)),
                device=x.device,
                dtype=torch.bool,
            )
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin/cos over the feature dimension.
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)


class SelfAttentionLayer(nn.Module):
    """Transformer self-attention sub-layer with pre/post-norm variants."""

    def __init__(self, d_model, nhead, dropout=0.0, normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        for parameter in self.parameters():
            if parameter.dim() > 1:
                nn.init.xavier_uniform_(parameter)

    def _with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward(self, target, tgt_mask=None, tgt_key_padding_mask=None, query_pos=None):
        if self.normalize_before:
            # Pre-norm: normalize first, residual-add unnormalized input.
            target_norm = self.norm(target)
            query = key = self._with_pos_embed(target_norm, query_pos)
            target2 = self.self_attn(
                query,
                key,
                value=target_norm,
                attn_mask=tgt_mask,
                key_padding_mask=tgt_key_padding_mask,
            )[0]
            return target + self.dropout(target2)
        # Post-norm: attend, residual-add, then normalize.
        query = key = self._with_pos_embed(target, query_pos)
        target2 = self.self_attn(
            query,
            key,
            value=target,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        target = target + self.dropout(target2)
        return self.norm(target)


class CrossAttentionLayer(nn.Module):
    """Transformer cross-attention sub-layer (queries attend to memory)."""

    def __init__(self, d_model, nhead, dropout=0.0, normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        for parameter in self.parameters():
            if parameter.dim() > 1:
                nn.init.xavier_uniform_(parameter)

    def _with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward(
        self,
        target,
        memory,
        memory_mask=None,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    ):
        if self.normalize_before:
            target_norm = self.norm(target)
            target2 = self.multihead_attn(
                query=self._with_pos_embed(target_norm, query_pos),
                key=self._with_pos_embed(memory, pos),
                value=memory,
                attn_mask=memory_mask,
                key_padding_mask=memory_key_padding_mask,
            )[0]
            return target + self.dropout(target2)
        target2 = self.multihead_attn(
            query=self._with_pos_embed(target, query_pos),
            key=self._with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        target = target + self.dropout(target2)
        return self.norm(target)


class FFNLayer(nn.Module):
    """Transformer feed-forward sub-layer with pre/post-norm variants."""

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, normalize_before=False):
        super().__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        for parameter in self.parameters():
            if parameter.dim() > 1:
                nn.init.xavier_uniform_(parameter)

    def forward(self, target):
        if self.normalize_before:
            target_norm = self.norm(target)
            target2 = self.linear2(self.dropout(F.relu(self.linear1(target_norm))))
            return target + self.dropout(target2)
        target2 = self.linear2(self.dropout(F.relu(self.linear1(target))))
        target = target + self.dropout(target2)
        return self.norm(target)


class MLP(nn.Module):
    """Simple multi-layer perceptron; ReLU on all but the last layer."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        widths = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(in_features, out_features)
            for in_features, out_features in zip(
                [input_dim] + widths,
                widths + [output_dim],
            )
        )

    def forward(self, x):
        for index, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if index < len(self.layers) - 1 else layer(x)
        return x


class Hook:
    """Forward hook that captures a module's most recent output tensor."""

    feature = None

    def __init__(self, module):
        self.hook = module.register_forward_hook(self._hook_fn)

    def _hook_fn(self, module, inputs, output):
        if isinstance(output, torch.Tensor):
            self.feature = output
        elif isinstance(output, collections.OrderedDict):
            # e.g. torchvision segmentation models return {"out": tensor}.
            self.feature = output["out"]

    def remove(self):
        self.hook.remove()


class NormType:
    """Enum-like tag for the normalization applied to conv layers."""

    Spectral = "Spectral"


def _batchnorm_2d(num_features):
    """BatchNorm2d with the fastai-style init (bias 1e-3, weight 1.0)."""
    batch_norm = nn.BatchNorm2d(num_features)
    with torch.no_grad():
        batch_norm.bias.fill_(1e-3)
        batch_norm.weight.fill_(1.0)
    return batch_norm


def _init_default(module, init=nn.init.kaiming_normal_):
    """Apply ``init`` to a module's weight and zero its bias, if present."""
    if init is not None:
        if hasattr(module, "weight"):
            init(module.weight)
        if hasattr(module, "bias") and hasattr(module.bias, "data"):
            module.bias.data.fill_(0.0)
    return module


def _icnr(tensor, scale=2, init=nn.init.kaiming_normal_):
    """ICNR init for PixelShuffle convs (reduces checkerboard artifacts).

    Initializes ``scale**2`` output-channel groups identically so the
    shuffled output starts out equivalent to nearest-neighbor upsampling.
    """
    in_channels, out_channels, height, width = tensor.shape
    in_channels_scaled = int(in_channels / (scale**2))
    kernel = init(torch.zeros([in_channels_scaled, out_channels, height, width])).transpose(0, 1)
    kernel = kernel.contiguous().view(in_channels_scaled, out_channels, -1)
    kernel = kernel.repeat(1, 1, scale**2)
    kernel = kernel.contiguous().view([out_channels, in_channels, height, width]).transpose(0, 1)
    tensor.data.copy_(kernel)


def _custom_conv_layer(
    in_channels,
    out_channels,
    ks=3,
    stride=1,
    padding=None,
    bias=None,
    norm_type=NormType.Spectral,
    use_activation=True,
    transpose=False,
    extra_bn=False,
):
    """Build conv (+ optional spectral norm, ReLU, BatchNorm) as a Sequential."""
    if padding is None:
        padding = (ks - 1) // 2 if not transpose else 0
    use_batch_norm = extra_bn
    if bias is None:
        # Bias is redundant when BatchNorm follows.
        bias = not use_batch_norm
    conv_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
    conv = _init_default(
        conv_cls(in_channels, out_channels, kernel_size=ks, bias=bias, stride=stride, padding=padding)
    )
    if norm_type == NormType.Spectral:
        conv = nn.utils.spectral_norm(conv)
    layers = [conv]
    if use_activation:
        layers.append(nn.ReLU(True))
    if use_batch_norm:
        layers.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*layers)


class CustomPixelShuffleICNR(nn.Module):
    """Upsampler: 1x1 conv (ICNR init) -> PixelShuffle -> optional blur."""

    def __init__(self, in_channels, out_channels, scale=2, blur=True, norm_type=NormType.Spectral, extra_bn=False):
        super().__init__()
        self.conv = _custom_conv_layer(
            in_channels,
            out_channels * (scale**2),
            ks=1,
            use_activation=False,
            norm_type=norm_type,
            extra_bn=extra_bn,
        )
        _icnr(self.conv[0].weight)
        self.shuffle = nn.PixelShuffle(scale)
        self.blur_enabled = blur
        # "Blur" trick from fastai: pad + 2x2 average pool smooths the
        # pixel-shuffled output to suppress checkerboard artifacts.
        self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
        self.blur = nn.AvgPool2d(2, stride=1)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        x = self.shuffle(self.relu(self.conv(x)))
        return self.blur(self.pad(x)) if self.blur_enabled else x


class UnetBlockWide(nn.Module):
    """U-Net decoder block: upsample, concat hooked skip feature, conv."""

    def __init__(self, up_in_channels, skip_in_channels, out_channels, hook, blur=False, norm_type=NormType.Spectral):
        super().__init__()
        self.hook = hook
        self.shuf = CustomPixelShuffleICNR(
            up_in_channels,
            out_channels,
            blur=blur,
            norm_type=norm_type,
            extra_bn=True,
        )
        self.bn = _batchnorm_2d(skip_in_channels)
        self.conv = _custom_conv_layer(
            out_channels + skip_in_channels,
            out_channels,
            norm_type=norm_type,
            extra_bn=True,
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        # The skip connection comes from the encoder via the forward hook.
        skip = self.hook.feature
        x = self.shuf(x)
        x = self.relu(torch.cat([x, self.bn(skip)], dim=1))
        return self.conv(x)


class ImageEncoder(nn.Module):
    """ConvNeXt encoder that records intermediate features via hooks."""

    def __init__(self, encoder_name, hook_names):
        super().__init__()
        if encoder_name == "convnext-t":
            self.arch = ConvNeXt(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768))
        elif encoder_name == "convnext-l":
            self.arch = ConvNeXt(depths=(3, 3, 27, 3), dims=(192, 384, 768, 1536))
        else:
            raise NotImplementedError(f"Unsupported encoder: {encoder_name}")
        self.hooks = [Hook(self.arch._modules[name]) for name in hook_names]

    def forward(self, x):
        return self.arch(x)


class MultiScaleColorDecoder(nn.Module):
    """Query-based transformer decoder producing per-query color maps.

    Learned color queries cross-attend to multi-scale pyramid features
    (round-robin over scales per layer); the final per-query embeddings are
    combined with full-resolution image features via an einsum.
    """

    def __init__(
        self,
        in_channels,
        hidden_dim=256,
        num_queries=100,
        nheads=8,
        dim_feedforward=2048,
        dec_layers=9,
        pre_norm=False,
        color_embed_dim=256,
        enforce_input_project=True,
        num_scales=3,
    ):
        super().__init__()
        self.num_layers = dec_layers
        self.num_feature_levels = num_scales
        self.pe_layer = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
        # Learned query features and query positional embeddings.
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        # Learned per-scale level embedding.
        self.level_embed = nn.Embedding(num_scales, hidden_dim)

        self.input_proj = nn.ModuleList()
        for channels in in_channels:
            if channels != hidden_dim or enforce_input_project:
                projection = nn.Conv2d(channels, hidden_dim, kernel_size=1)
                nn.init.kaiming_uniform_(projection.weight, a=1)
                if projection.bias is not None:
                    nn.init.constant_(projection.bias, 0)
                self.input_proj.append(projection)
            else:
                self.input_proj.append(nn.Sequential())

        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()
        for _ in range(dec_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(hidden_dim, nheads, dropout=0.0, normalize_before=pre_norm)
            )
            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(hidden_dim, nheads, dropout=0.0, normalize_before=pre_norm)
            )
            self.transformer_ffn_layers.append(
                FFNLayer(hidden_dim, dim_feedforward=dim_feedforward, dropout=0.0, normalize_before=pre_norm)
            )

        self.decoder_norm = nn.LayerNorm(hidden_dim)
        self.color_embed = MLP(hidden_dim, hidden_dim, color_embed_dim, 3)

    def forward(self, features, image_features):
        src = []
        pos = []
        # Flatten each scale to (HW, B, C) and add its level embedding.
        for index, feature in enumerate(features):
            pos.append(self.pe_layer(feature).flatten(2).permute(2, 0, 1))
            src.append(
                (
                    self.input_proj[index](feature).flatten(2)
                    + self.level_embed.weight[index][None, :, None]
                ).permute(2, 0, 1)
            )

        _, batch_size, _ = src[0].shape
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, batch_size, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, batch_size, 1)

        for index in range(self.num_layers):
            # Cycle through the feature levels, one per decoder layer.
            level_index = index % self.num_feature_levels
            output = self.transformer_cross_attention_layers[index](
                output,
                src[level_index],
                memory_mask=None,
                memory_key_padding_mask=None,
                pos=pos[level_index],
                query_pos=query_embed,
            )
            output = self.transformer_self_attention_layers[index](
                output,
                tgt_mask=None,
                tgt_key_padding_mask=None,
                query_pos=query_embed,
            )
            output = self.transformer_ffn_layers[index](output)

        decoder_output = self.decoder_norm(output).transpose(0, 1)
        color_embed = self.color_embed(decoder_output)
        # (B, Q, C) x (B, C, H, W) -> per-query spatial maps (B, Q, H, W).
        return torch.einsum("bqc,bchw->bqhw", color_embed, image_features)


class DualDecoder(nn.Module):
    """Pixel decoder (U-Net over encoder hooks) + color transformer decoder.

    NOTE: ``_make_layers`` reads feature shapes from the hooks, so the
    encoder must have run at least one forward pass before construction
    (DDColor.__init__ does a dummy forward for exactly this reason).
    """

    def __init__(self, hooks, nf=512, blur=True, num_queries=100, num_scales=3, dec_layers=9):
        super().__init__()
        self.hooks = hooks
        self.nf = nf
        self.blur = blur
        self.layers = self._make_layers()
        embed_dim = nf // 2
        self.last_shuf = CustomPixelShuffleICNR(
            embed_dim,
            embed_dim,
            scale=4,
            blur=self.blur,
            norm_type=NormType.Spectral,
        )
        self.color_decoder = MultiScaleColorDecoder(
            in_channels=[512, 512, 256],
            num_queries=num_queries,
            num_scales=num_scales,
            dec_layers=dec_layers,
        )

    def _make_layers(self):
        layers = []
        in_channels = self.hooks[-1].feature.shape[1]
        out_channels = self.nf
        # Skip connections in reverse order, excluding the deepest hook.
        setup_hooks = self.hooks[-2::-1]
        for index, hook in enumerate(setup_hooks):
            skip_channels = hook.feature.shape[1]
            if index == len(setup_hooks) - 1:
                # Final block halves the width before the last shuffle.
                out_channels = out_channels // 2
            layers.append(
                UnetBlockWide(
                    in_channels,
                    skip_channels,
                    out_channels,
                    hook,
                    blur=self.blur,
                    norm_type=NormType.Spectral,
                )
            )
            in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self):
        encoded = self.hooks[-1].feature
        out0 = self.layers[0](encoded)
        out1 = self.layers[1](out0)
        out2 = self.layers[2](out1)
        out3 = self.last_shuf(out2)
        # Multi-scale features feed the transformer; out3 is the
        # full-resolution feature map the color embeddings project onto.
        return self.color_decoder([out0, out1, out2], out3)


class DDColor(nn.Module):
    """Full DDColor model: encoder, dual decoder, and a refine conv.

    Input is a 3-channel RGB tensor; output is ``num_output_channels``
    (the predicted ab chroma channels by default).
    """

    def __init__(
        self,
        encoder_name="convnext-l",
        decoder_name="MultiScaleColorDecoder",
        num_input_channels=3,
        input_size=(256, 256),
        nf=512,
        num_output_channels=2,
        last_norm="Spectral",
        do_normalize=False,
        num_queries=100,
        num_scales=3,
        dec_layers=9,
    ):
        super().__init__()
        if decoder_name != "MultiScaleColorDecoder":
            raise NotImplementedError(f"Unsupported decoder: {decoder_name}")
        if last_norm != "Spectral":
            raise NotImplementedError(f"Unsupported last_norm: {last_norm}")

        self.encoder = ImageEncoder(encoder_name, ["norm0", "norm1", "norm2", "norm3"])
        self.encoder.eval()
        # Dummy forward pass populates the hooks so DualDecoder can infer
        # skip-connection channel counts from real feature shapes.
        test_input = torch.randn(1, num_input_channels, *input_size)
        with torch.no_grad():
            self.encoder(test_input)

        self.decoder = DualDecoder(
            self.encoder.hooks,
            nf=nf,
            num_queries=num_queries,
            num_scales=num_scales,
            dec_layers=dec_layers,
        )
        # 1x1 conv mixing the per-query maps with the (normalized) input image.
        self.refine_net = nn.Sequential(
            _custom_conv_layer(
                num_queries + 3,
                num_output_channels,
                ks=1,
                use_activation=False,
                norm_type=NormType.Spectral,
            )
        )

        self.do_normalize = do_normalize
        # ImageNet normalization statistics.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def normalize(self, image):
        return (image - self.mean) / self.std

    def denormalize(self, image):
        return image * self.std + self.mean

    def forward(self, image):
        if image.shape[1] == 3:
            image = self.normalize(image)
        # Encoder output itself is unused; it fills the hooks the decoder reads.
        self.encoder(image)
        decoded = self.decoder()
        coarse_input = torch.cat([decoded, image], dim=1)
        output = self.refine_net(coarse_input)
        if self.do_normalize:
            output = self.denormalize(output)
        return output


class ColorizationPipeline:
    """Wraps a DDColor model with pre/post-processing for BGR uint8 images."""

    def __init__(self, model, input_size=512, device=None):
        self.input_size = int(input_size)
        self.device = _resolve_device(device)
        self.model = model.to(self.device)
        self.model.eval()

    def process(self, image_bgr):
        """Colorize a BGR uint8 image, returning a BGR uint8 image.

        The model predicts ab chroma at ``input_size`` resolution; the
        original image's L (lightness) channel is kept at full resolution
        so detail is preserved.
        """
        # inference_mode is faster than no_grad when available.
        context = torch.inference_mode if hasattr(torch, "inference_mode") else torch.no_grad
        with context():
            if image_bgr is None:
                raise ValueError("image is None")
            height, width = image_bgr.shape[:2]
            image = (image_bgr / 255.0).astype(np.float32)
            # Keep the full-resolution lightness channel for the output.
            orig_l = cv2.cvtColor(image, cv2.COLOR_BGR2Lab)[:, :, :1]

            # Build a grayscale RGB input at model resolution (ab zeroed).
            resized = cv2.resize(image, (self.input_size, self.input_size))
            resized_l = cv2.cvtColor(resized, cv2.COLOR_BGR2Lab)[:, :, :1]
            gray_lab = np.concatenate(
                (resized_l, np.zeros_like(resized_l), np.zeros_like(resized_l)),
                axis=-1,
            )
            gray_rgb = cv2.cvtColor(gray_lab, cv2.COLOR_LAB2RGB)

            tensor = (
                torch.from_numpy(gray_rgb.transpose((2, 0, 1)))
                .float()
                .unsqueeze(0)
                .to(self.device)
            )
            output_ab = self.model(tensor).cpu()

            # Upsample predicted chroma back to the original resolution and
            # recombine with the original lightness channel.
            resized_ab = (
                F.interpolate(output_ab, size=(height, width))[0]
                .float()
                .numpy()
                .transpose(1, 2, 0)
            )
            output_lab = np.concatenate((orig_l, resized_ab), axis=-1)
            output_bgr = cv2.cvtColor(output_lab, cv2.COLOR_LAB2BGR)
            return (output_bgr * 255.0).round().astype(np.uint8)


def build_colorizer(repo_id=DEFAULT_REPO_ID, device=None):
    """Download config + weights from the Hub and build a pipeline."""
    device = _resolve_device(device)
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
    config = _load_model_config(config_path)
    model = DDColor(**config)
    state_dict = _load_checkpoint_state_dict(weights_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device)
    model.eval()
    input_size = config.get("input_size", [512, 512])[0]
    return ColorizationPipeline(model, input_size=input_size, device=device)


def _get_colorizer():
    """Return the process-wide colorizer, building it on first use."""
    if _COLORIZER_STATE["initialized"]:
        return _COLORIZER_STATE["pipeline"]
    try:
        colorizer = build_colorizer(
            repo_id=os.getenv("DDCOLOR_REPO_ID", DEFAULT_REPO_ID),
        )
    except Exception as error:
        # Chain the original exception so the real cause stays in the logs.
        raise gr.Error(
            "Failed to initialize the DDColor model from Hugging Face Hub. "
            f"Error: {str(error)[:200]}"
        ) from error
    _COLORIZER_STATE.update(
        {
            "initialized": True,
            "pipeline": colorizer,
        }
    )
    return colorizer


def _normalize_input_image(image):
    """Coerce grayscale/RGBA inputs to a 3-channel RGB array."""
    if image.ndim == 2:
        return np.stack([image, image, image], axis=-1)
    if image.shape[-1] == 4:
        return image[..., :3]
    return image


def color(image):
    """Gradio handler: colorize an RGB numpy image, return (before, after)."""
    if image is None:
        raise gr.Error("Please upload an image.")
    image = _normalize_input_image(image)
    colorizer = _get_colorizer()
    # Pipeline works in BGR; flip channel order in and out.
    result_bgr = colorizer.process(image[..., ::-1])
    result_rgb = result_bgr[..., ::-1]
    print("infer finished!")
    return (image, result_rgb)


def clear_ui():
    """Reset both the input image and the comparison slider."""
    return None, None


examples = [["./input.jpg"]]

with gr.Blocks(fill_width=True) as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="numpy",
                label="Old Photo",
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Colorize", variant="primary")
        with gr.Column():
            comparison_output = gr.ImageSlider(
                type="numpy",
                slider_position=50,
                label="Before / After",
            )
    gr.Examples(
        examples=examples,
        inputs=input_image,
        outputs=comparison_output,
        fn=color,
        cache_examples=True,
        cache_mode="eager",
        preload=0,
    )
    submit_btn.click(
        fn=color,
        inputs=input_image,
        outputs=comparison_output,
    )
    # Clear the stale comparison whenever a new image is provided.
    input_image.input(
        fn=lambda: None,
        outputs=comparison_output,
    )
    clear_btn.click(
        fn=clear_ui,
        outputs=[input_image, comparison_output],
    )

if __name__ == "__main__":
    # NOTE(review): `theme` / `footer_links` as launch() kwargs require a
    # recent Gradio release (classically `theme` is passed to gr.Blocks) —
    # verify against the pinned Gradio version.
    demo.queue().launch(
        share=False,
        ssr_mode=False,
        theme="Nymbo/Nymbo_Theme",
        footer_links=[],
    )