Spaces:
Running
Running
| import collections | |
| import json | |
| import math | |
| import os | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from huggingface_hub import hf_hub_download | |
# Hugging Face Hub repository from which build_colorizer downloads
# "config.json" and "pytorch_model.bin".
DEFAULT_REPO_ID = "piddnad/ddcolor_modelscope"

# Process-wide lazy cache used by _get_colorizer: "initialized" becomes True once
# "pipeline" holds a successfully built ColorizationPipeline.
_COLORIZER_STATE = dict(initialized=False, pipeline=None)
def _resolve_device(device=None):
    """Turn a device spec into something usable with ``Module.to``.

    Args:
        device: ``None`` (pick CUDA when available, otherwise CPU), a device string
            such as ``"cpu"`` / ``"cuda:0"``, or an existing ``torch.device``.

    Returns:
        A ``torch.device`` for ``None`` or string input; any other value is
        returned unchanged.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device) if isinstance(device, str) else device
def _load_checkpoint_state_dict(model_path, map_location="cpu"):
    """Load a checkpoint file and extract the model state dict from it.

    If the loaded object is a dict containing ``"params"`` (checked first) or
    ``"state_dict"``, that entry is returned; otherwise the loaded object itself
    is returned as-is.

    Args:
        model_path: path of the checkpoint file.
        map_location: forwarded to ``torch.load``.
    """
    # NOTE(review): torch.load is pickle-based. The checkpoint comes from a Hub repo
    # whose id can be overridden via DDCOLOR_REPO_ID, so only point it at trusted repos
    # (or pass weights_only=True where the installed torch supports it).
    checkpoint = torch.load(model_path, map_location=map_location)
    if not isinstance(checkpoint, dict):
        return checkpoint
    for wrapper_key in ("params", "state_dict"):
        if wrapper_key in checkpoint:
            return checkpoint[wrapper_key]
    return checkpoint
def _load_model_config(config_path):
    """Read a UTF-8 encoded JSON file and return the decoded object."""
    with open(config_path, encoding="utf-8") as stream:
        config = json.load(stream)
    return config
class DropPath(nn.Module):
    """Stochastic depth: randomly drop a residual branch for whole samples while training.

    Each sample in the batch is kept with probability ``1 - drop_prob``; kept samples
    are rescaled by ``1 / (1 - drop_prob)`` so the expected output is unchanged.
    In eval mode, or when ``drop_prob`` is 0, the input is returned untouched.

    Args:
        drop_prob: probability of dropping a sample's branch output.
    """

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1.0 - self.drop_prob
        if keep_prob <= 0.0:
            # Every sample is dropped. The rescale below would divide by zero and
            # turn the output into NaN, so return the exact limit instead.
            return torch.zeros_like(x)
        # One draw per sample, broadcast over all remaining dimensions.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # 1 with probability keep_prob, else 0
        return x.div(keep_prob) * random_tensor
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Fill ``tensor`` in place with N(mean, std) samples truncated to ``[a, b]``.

    Uses ``torch.nn.init.trunc_normal_`` when the installed torch provides it;
    otherwise falls back to inverse-CDF sampling (uniform draw in CDF space,
    mapped back with ``erfinv``, then clamped to the bounds).

    Returns:
        The same ``tensor`` object.
    """
    native = getattr(torch.nn.init, "trunc_normal_", None)
    if native is not None:
        return native(tensor, mean=mean, std=std, a=a, b=b)

    def _std_normal_cdf(z):
        return (1.0 + math.erf(z / math.sqrt(2.0))) / 2.0

    lo_cdf = _std_normal_cdf((a - mean) / std)
    hi_cdf = _std_normal_cdf((b - mean) / std)
    with torch.no_grad():
        # Uniform in [2*F(a)-1, 2*F(b)-1]; erfinv maps that back to a standard normal.
        tensor.uniform_(2 * lo_cdf - 1, 2 * hi_cdf - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
    return tensor
class LayerNorm(nn.Module):
    """LayerNorm over the channel axis, for channels-last or channels-first tensors.

    ``"channels_last"`` delegates to ``F.layer_norm`` over the last dimension.
    ``"channels_first"`` normalises over dimension 1 (biased variance) and applies
    the affine parameters broadcast as ``[C, 1, 1]``, i.e. for (N, C, H, W) input.
    Any other ``data_format`` raises ``NotImplementedError`` when called.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        layout = self.data_format
        if layout == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if layout != "channels_first":
            raise NotImplementedError(f"Unsupported data_format: {self.data_format}")
        centered = x - x.mean(1, keepdim=True)
        variance = centered.pow(2).mean(1, keepdim=True)
        normed = centered / torch.sqrt(variance + self.eps)
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
class ConvNeXtBlock(nn.Module):
    """One ConvNeXt residual block for (N, C, H, W) feature maps.

    Branch: 7x7 depthwise conv -> LayerNorm -> Linear(C, 4C) -> GELU -> Linear(4C, C)
    -> optional per-channel layer scale ``gamma`` -> stochastic depth. The branch
    output is added to the block input.

    Args:
        dim: number of channels C.
        drop_path: stochastic-depth probability for the branch (0 disables it).
        layer_scale_init_value: initial value of ``gamma``; a value <= 0 disables layer scale.
    """

    def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma = None
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        shortcut = x
        # Depthwise conv runs channels-first; switch to channels-last so the LayerNorm
        # and the Linear "pointwise convolutions" act on the channel axis.
        branch = self.dwconv(x).permute(0, 2, 3, 1)
        branch = self.pwconv2(self.act(self.pwconv1(self.norm(branch))))
        if self.gamma is not None:
            branch = self.gamma * branch
        return shortcut + self.drop_path(branch.permute(0, 3, 1, 2))
class ConvNeXt(nn.Module):
    """ConvNeXt backbone with four downsampling stages, used as DDColor's image encoder.

    The attribute names (``downsample_layers``, ``stages``, ``norm0``..``norm3`` and
    ``norm``) are the state-dict prefixes, so they must match the loaded checkpoint.

    Args:
        in_chans: number of input image channels.
        depths: number of ConvNeXtBlocks in each of the four stages.
        dims: channel width of each of the four stages.
        drop_path_rate: maximum stochastic-depth rate; per-block rates increase
            linearly from 0 to this value over all blocks.
        layer_scale_init_value: initial layer-scale value passed to every block.
    """

    def __init__(
        self,
        in_chans=3,
        depths=(3, 3, 9, 3),
        dims=(96, 192, 384, 768),
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
    ):
        super().__init__()
        self.downsample_layers = nn.ModuleList()
        # Stem: non-overlapping 4x4 conv with stride 4, then a channels-first LayerNorm.
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        # Three more downsamplers: LayerNorm, then a 2x2 stride-2 conv dims[i] -> dims[i + 1].
        for index in range(3):
            self.downsample_layers.append(
                nn.Sequential(
                    LayerNorm(dims[index], eps=1e-6, data_format="channels_first"),
                    nn.Conv2d(dims[index], dims[index + 1], kernel_size=2, stride=2),
                )
            )
        self.stages = nn.ModuleList()
        # One stochastic-depth rate per block across all stages, rising linearly.
        rates = [value.item() for value in torch.linspace(0, drop_path_rate, sum(depths))]
        cursor = 0  # index of the first block of the current stage within `rates`
        for index in range(4):
            stage = nn.Sequential(
                *[
                    ConvNeXtBlock(
                        dim=dims[index],
                        drop_path=rates[cursor + inner],
                        layer_scale_init_value=layer_scale_init_value,
                    )
                    for inner in range(depths[index])
                ]
            )
            self.stages.append(stage)
            cursor += depths[index]
        # Extra per-stage norms norm0..norm3. Their outputs are not used by this module;
        # they exist so forward hooks (see ImageEncoder / Hook) can capture normalised
        # multi-scale features.
        for index in range(4):
            self.add_module(
                f"norm{index}",
                LayerNorm(dims[index], eps=1e-6, data_format="channels_first"),
            )
        # Final norm applied to the globally pooled features returned by forward.
        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Truncated-normal (std 0.02) weights and zero bias for every Conv2d / Linear."""
        # NOTE(review): assumes each Conv2d/Linear has a bias; true for the layers built above.
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            trunc_normal_(module.weight, std=0.02)
            nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Run all four stages and return pooled, normalised features.

        Args:
            x: input image batch (N, in_chans, H, W).

        Returns:
            Tensor (N, dims[-1]): spatial mean of the last stage followed by ``self.norm``.
        """
        for index in range(4):
            x = self.downsample_layers[index](x)
            x = self.stages[index](x)
            # Result intentionally discarded: the call only triggers forward hooks
            # registered on norm{index}.
            getattr(self, f"norm{index}")(x)
        # Mean over the two spatial axes: (N, C, H, W) -> (N, C).
        return self.norm(x.mean([-2, -1]))
class PositionEmbeddingSine(nn.Module):
    """2-D sinusoidal positional encoding (DETR style) for (N, C, H, W) feature maps.

    The output has ``2 * num_pos_feats`` channels: first the y (row) encoding, then
    the x (column) encoding. Each half interleaves sin (even frequencies) and
    cos (odd frequencies). Only the input's batch size, spatial size and device are used.

    Args:
        num_pos_feats: channels per axis.
        temperature: base of the geometric frequency progression.
        normalize: if True, rescale cumulative positions to ``(0, scale]``.
        scale: scale used when normalising; defaults to ``2 * pi``.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, x, mask=None):
        if mask is None:
            # No padding: every position is valid.
            mask = torch.zeros(
                (x.size(0), x.size(2), x.size(3)),
                device=x.device,
                dtype=torch.bool,
            )
        valid = ~mask
        # 1-based running position of each valid pixel along rows / columns.
        rows = valid.cumsum(1, dtype=torch.float32)
        cols = valid.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            rows = rows / (rows[:, -1:, :] + eps) * self.scale
            cols = cols / (cols[:, :, -1:] + eps) * self.scale

        freqs = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        freqs = self.temperature ** (2 * (freqs // 2) / self.num_pos_feats)

        def _encode(positions):
            scaled = positions[..., None] / freqs
            return torch.stack((scaled[..., 0::2].sin(), scaled[..., 1::2].cos()), dim=4).flatten(3)

        return torch.cat((_encode(rows), _encode(cols)), dim=3).permute(0, 3, 1, 2)
class SelfAttentionLayer(nn.Module):
    """Residual multi-head self-attention over a sequence of queries.

    Tensors are sequence-first (L, N, C), as ``nn.MultiheadAttention`` is built
    without ``batch_first``. ``query_pos`` is added to queries and keys but not to values.
    With ``normalize_before`` the LayerNorm is applied to the input (pre-norm, no
    output norm); otherwise it is applied after the residual sum (post-norm).
    """

    def __init__(self, d_model, nhead, dropout=0.0, normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier-uniform for every matrix-shaped parameter; vectors keep their defaults.
        matrices = (param for param in self.parameters() if param.dim() > 1)
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def _with_pos_embed(self, tensor, pos):
        if pos is None:
            return tensor
        return tensor + pos

    def forward(self, target, tgt_mask=None, tgt_key_padding_mask=None, query_pos=None):
        source = self.norm(target) if self.normalize_before else target
        query_key = self._with_pos_embed(source, query_pos)
        attended, _ = self.self_attn(
            query_key,
            query_key,
            value=source,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )
        updated = target + self.dropout(attended)
        return updated if self.normalize_before else self.norm(updated)
class CrossAttentionLayer(nn.Module):
    """Residual multi-head cross-attention from queries (``target``) to ``memory``.

    Tensors are sequence-first (L, N, C). ``query_pos`` is added to the queries and
    ``pos`` to the keys; the values are the raw ``memory``. ``normalize_before``
    selects pre-norm (norm the queries, no output norm) versus post-norm.
    """

    def __init__(self, d_model, nhead, dropout=0.0, normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier-uniform for every matrix-shaped parameter; vectors keep their defaults.
        matrices = (param for param in self.parameters() if param.dim() > 1)
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def _with_pos_embed(self, tensor, pos):
        if pos is None:
            return tensor
        return tensor + pos

    def forward(
        self,
        target,
        memory,
        memory_mask=None,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    ):
        source = self.norm(target) if self.normalize_before else target
        attended, _ = self.multihead_attn(
            query=self._with_pos_embed(source, query_pos),
            key=self._with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        updated = target + self.dropout(attended)
        return updated if self.normalize_before else self.norm(updated)
class FFNLayer(nn.Module):
    """Residual two-layer feed-forward block (Linear -> ReLU -> Linear).

    ``normalize_before`` selects pre-norm (LayerNorm on the input, no output norm)
    versus post-norm (LayerNorm after the residual sum).
    """

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, normalize_before=False):
        super().__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier-uniform for every matrix-shaped parameter; vectors keep their defaults.
        matrices = (param for param in self.parameters() if param.dim() > 1)
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def forward(self, target):
        pre_norm = self.normalize_before
        source = self.norm(target) if pre_norm else target
        hidden = self.dropout(F.relu(self.linear1(source)))
        updated = target + self.dropout(self.linear2(hidden))
        return updated if pre_norm else self.norm(updated)
class MLP(nn.Module):
    """Stack of ``num_layers`` Linear layers with ReLU between them (none after the last).

    Widths run ``input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(dims, dims[1:])
        )

    def forward(self, x):
        *hidden_layers, final_layer = self.layers
        for layer in hidden_layers:
            x = F.relu(layer(x))
        return final_layer(x)
class Hook:
    """Capture a module's most recent forward output in ``self.feature``.

    Tensor outputs are stored directly. For ``collections.OrderedDict`` outputs the
    ``"out"`` entry is stored. Any other output type leaves the previous value
    (initially ``None``) untouched.
    """

    feature = None

    def __init__(self, module):
        self.hook = module.register_forward_hook(self._hook_fn)

    def _hook_fn(self, module, inputs, output):
        if isinstance(output, torch.Tensor):
            captured = output
        elif isinstance(output, collections.OrderedDict):
            captured = output["out"]
        else:
            return
        self.feature = captured

    def remove(self):
        """Detach the forward hook; ``feature`` keeps its last captured value."""
        self.hook.remove()
class NormType:
    """String constants naming the weight-normalisation schemes understood here."""

    Spectral: str = "Spectral"  # wrap convolutions with nn.utils.spectral_norm
def _batchnorm_2d(num_features):
    """Create a BatchNorm2d whose affine weight is 1 and bias is 1e-3."""
    norm = nn.BatchNorm2d(num_features)
    nn.init.constant_(norm.weight, 1.0)
    nn.init.constant_(norm.bias, 1e-3)
    return norm
def _init_default(module, init=nn.init.kaiming_normal_):
    """Apply ``init`` to ``module.weight`` and zero its bias, in place.

    Nothing is changed when ``init`` is None. Modules without a ``weight`` skip the
    init call; a missing or ``None`` bias is left alone.

    Returns:
        The same ``module``, for chaining.
    """
    if init is None:
        return module
    if hasattr(module, "weight"):
        init(module.weight)
    bias = getattr(module, "bias", None)
    if hasattr(bias, "data"):
        bias.data.fill_(0.0)
    return module
def _icnr(tensor, scale=2, init=nn.init.kaiming_normal_):
    """ICNR-initialise a conv weight that feeds a PixelShuffle of factor ``scale``, in place.

    A kernel with ``dim0 / scale**2`` rows is drawn with ``init`` and replicated so that
    each group of ``scale**2`` consecutive rows along dim 0 is identical. At init time the
    shuffled output is then equivalent to nearest-neighbour upsampling.

    Args:
        tensor: 4-D weight tensor; its first dimension must be divisible by ``scale**2``.
        scale: pixel-shuffle factor.
        init: in-place initialiser that returns the tensor it fills.
    """
    rows, cols, kernel_h, kernel_w = tensor.shape
    base_rows = int(rows / (scale**2))
    repeats = scale**2
    seed = init(torch.zeros([base_rows, cols, kernel_h, kernel_w]))
    kernel = seed.transpose(0, 1).contiguous().view(base_rows, cols, -1).repeat(1, 1, repeats)
    kernel = kernel.contiguous().view([cols, rows, kernel_h, kernel_w]).transpose(0, 1)
    tensor.data.copy_(kernel)
def _custom_conv_layer(
    in_channels,
    out_channels,
    ks=3,
    stride=1,
    padding=None,
    bias=None,
    norm_type=NormType.Spectral,
    use_activation=True,
    transpose=False,
    extra_bn=False,
):
    """Build ``Sequential(conv[, ReLU][, BatchNorm2d])`` with a Kaiming-initialised conv.

    Args:
        in_channels, out_channels: conv channel counts.
        ks: kernel size.
        stride: conv stride.
        padding: defaults to ``(ks - 1) // 2`` ("same" for odd ks), or 0 for transposed convs.
        bias: defaults to True unless ``extra_bn`` adds a BatchNorm after the conv.
        norm_type: ``NormType.Spectral`` wraps the conv with spectral normalisation.
        use_activation: append an in-place ReLU.
        transpose: use ``ConvTranspose2d`` instead of ``Conv2d``.
        extra_bn: append a ``BatchNorm2d(out_channels)`` at the end.
    """
    if padding is None:
        padding = 0 if transpose else (ks - 1) // 2
    if bias is None:
        bias = not extra_bn
    conv_type = nn.ConvTranspose2d if transpose else nn.Conv2d
    conv = _init_default(
        conv_type(in_channels, out_channels, kernel_size=ks, bias=bias, stride=stride, padding=padding)
    )
    if norm_type == NormType.Spectral:
        conv = nn.utils.spectral_norm(conv)
    tail = []
    if use_activation:
        tail.append(nn.ReLU(True))
    if extra_bn:
        tail.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(conv, *tail)
class CustomPixelShuffleICNR(nn.Module):
    """Sub-pixel upsampler: 1x1 conv -> ReLU -> PixelShuffle(scale) -> optional blur.

    The 1x1 conv expands to ``out_channels * scale**2`` channels and its weight is
    ICNR-initialised. With ``blur`` enabled, a 1-pixel top/left replication pad and a
    2x2 stride-1 average pool smooth the result while keeping H and W unchanged.
    """

    def __init__(self, in_channels, out_channels, scale=2, blur=True, norm_type=NormType.Spectral, extra_bn=False):
        super().__init__()
        expanded_channels = out_channels * (scale**2)
        self.conv = _custom_conv_layer(
            in_channels,
            expanded_channels,
            ks=1,
            use_activation=False,
            norm_type=norm_type,
            extra_bn=extra_bn,
        )
        # NOTE(review): under spectral norm, conv[0].weight is presumably recomputed from
        # weight_orig before each forward, so this init may not persist. Weights are loaded
        # from a checkpoint in this app anyway; confirm before relying on it for training.
        _icnr(self.conv[0].weight)
        self.shuffle = nn.PixelShuffle(scale)
        self.blur_enabled = blur
        self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
        self.blur = nn.AvgPool2d(2, stride=1)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        upsampled = self.shuffle(self.relu(self.conv(x)))
        if not self.blur_enabled:
            return upsampled
        return self.blur(self.pad(upsampled))
class UnetBlockWide(nn.Module):
    """U-Net decoder block: 2x-upsample the input and fuse it with a hooked skip feature.

    The skip tensor is not an argument of ``forward``. It is read from
    ``hook.feature``, so the encoder must have run a forward pass beforehand.

    Args:
        up_in_channels: channels of the tensor passed to ``forward``.
        skip_in_channels: channels of the hooked skip feature.
        out_channels: channels of the upsampled tensor and of the block output.
        hook: object whose ``feature`` attribute holds the skip tensor.
        blur: whether the pixel-shuffle upsampler applies its blur.
        norm_type: weight normalisation used by the convolutions.

    Raises:
        RuntimeError: from ``forward`` when the hook has not captured a feature yet.
    """

    def __init__(self, up_in_channels, skip_in_channels, out_channels, hook, blur=False, norm_type=NormType.Spectral):
        super().__init__()
        self.hook = hook
        self.shuf = CustomPixelShuffleICNR(
            up_in_channels,
            out_channels,
            blur=blur,
            norm_type=norm_type,
            extra_bn=True,
        )
        self.bn = _batchnorm_2d(skip_in_channels)
        self.conv = _custom_conv_layer(
            out_channels + skip_in_channels,
            out_channels,
            norm_type=norm_type,
            extra_bn=True,
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        skip = self.hook.feature
        if skip is None:
            raise RuntimeError("UnetBlockWide: hook has no captured feature; run the encoder first.")
        upsampled = self.shuf(x)
        # Strided encoder convs floor odd sizes (e.g. inputs that are a multiple of 4 but
        # not of 32), so the 2x-upsampled map can be one pixel short of the skip map.
        # Match the skip's spatial size so the channel concatenation below is valid.
        if upsampled.shape[-2:] != skip.shape[-2:]:
            upsampled = F.interpolate(upsampled, size=skip.shape[-2:], mode="nearest")
        fused = self.relu(torch.cat([upsampled, self.bn(skip)], dim=1))
        return self.conv(fused)
class ImageEncoder(nn.Module):
    """ConvNeXt backbone plus forward hooks on selected sub-modules.

    Args:
        encoder_name: ``"convnext-t"`` or ``"convnext-l"``; any other value raises
            ``NotImplementedError``.
        hook_names: names of direct children of the backbone (e.g. ``"norm0"``) whose
            outputs are captured; ``self.hooks`` keeps the ``Hook`` objects in that order.
    """

    def __init__(self, encoder_name, hook_names):
        super().__init__()
        if encoder_name == "convnext-t":
            depths, dims = (3, 3, 9, 3), (96, 192, 384, 768)
        elif encoder_name == "convnext-l":
            depths, dims = (3, 3, 27, 3), (192, 384, 768, 1536)
        else:
            raise NotImplementedError(f"Unsupported encoder: {encoder_name}")
        self.arch = ConvNeXt(depths=depths, dims=dims)
        self.hooks = [Hook(self.arch._modules[name]) for name in hook_names]

    def forward(self, x):
        """Run the backbone; hooked features are then available via ``self.hooks``."""
        pooled = self.arch(x)
        return pooled
class MultiScaleColorDecoder(nn.Module):
    """Query-based transformer decoder that turns multi-scale features into per-query maps.

    ``num_queries`` learned queries go through ``dec_layers`` rounds of
    cross-attention (to one feature level per round, cycling through the levels),
    self-attention and an FFN. The final query embeddings are projected by an MLP and
    dotted with every pixel of ``image_features``, giving one map per query.

    Args:
        in_channels: channel count of each input feature level, one entry per level.
        hidden_dim: transformer width.
        num_queries: number of learned queries (= number of output maps).
        nheads: attention heads.
        dim_feedforward: FFN hidden width.
        dec_layers: number of decoder rounds.
        pre_norm: pre-norm instead of post-norm in the sub-layers.
        color_embed_dim: output width of the query MLP; must equal the channel count of
            ``image_features`` for the final einsum.
        enforce_input_project: add a 1x1 projection even when a level already has
            ``hidden_dim`` channels.
        num_scales: number of feature levels cycled through by the rounds.
    """

    def __init__(
        self,
        in_channels,
        hidden_dim=256,
        num_queries=100,
        nheads=8,
        dim_feedforward=2048,
        dec_layers=9,
        pre_norm=False,
        color_embed_dim=256,
        enforce_input_project=True,
        num_scales=3,
    ):
        super().__init__()
        self.num_layers = dec_layers
        self.num_feature_levels = num_scales
        # hidden_dim // 2 features per axis -> hidden_dim channels in total (y then x).
        self.pe_layer = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
        # Learned query content (query_feat) and learned query positions (query_embed).
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        # Learned per-level embedding added to the projected features of each level.
        self.level_embed = nn.Embedding(num_scales, hidden_dim)
        self.input_proj = nn.ModuleList()
        for channels in in_channels:
            if channels != hidden_dim or enforce_input_project:
                projection = nn.Conv2d(channels, hidden_dim, kernel_size=1)
                nn.init.kaiming_uniform_(projection.weight, a=1)
                if projection.bias is not None:
                    nn.init.constant_(projection.bias, 0)
                self.input_proj.append(projection)
            else:
                # Empty Sequential acts as identity when no projection is needed.
                self.input_proj.append(nn.Sequential())
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()
        for _ in range(dec_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(hidden_dim, nheads, dropout=0.0, normalize_before=pre_norm)
            )
            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(hidden_dim, nheads, dropout=0.0, normalize_before=pre_norm)
            )
            self.transformer_ffn_layers.append(
                FFNLayer(hidden_dim, dim_feedforward=dim_feedforward, dropout=0.0, normalize_before=pre_norm)
            )
        self.decoder_norm = nn.LayerNorm(hidden_dim)
        self.color_embed = MLP(hidden_dim, hidden_dim, color_embed_dim, 3)

    def forward(self, features, image_features):
        """Decode one map per query.

        Args:
            features: sequence of feature maps (N, C_i, H_i, W_i). At least
                ``num_scales`` entries are needed, because round ``i`` reads level
                ``i % num_scales``.
            image_features: per-pixel embeddings (N, color_embed_dim, H, W).

        Returns:
            Tensor (N, num_queries, H, W).
        """
        src = []
        pos = []
        for index, feature in enumerate(features):
            # Flatten (N, C, H, W) -> (H*W, N, C): the sequence-first layout used by the
            # attention layers.
            pos.append(self.pe_layer(feature).flatten(2).permute(2, 0, 1))
            src.append(
                (
                    self.input_proj[index](feature).flatten(2)
                    + self.level_embed.weight[index][None, :, None]
                ).permute(2, 0, 1)
            )
        _, batch_size, _ = src[0].shape
        # Broadcast the learned queries over the batch: (num_queries, N, hidden_dim).
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, batch_size, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, batch_size, 1)
        for index in range(self.num_layers):
            level_index = index % self.num_feature_levels  # round-robin over feature levels
            # Cross-attention first, then self-attention, then FFN.
            output = self.transformer_cross_attention_layers[index](
                output,
                src[level_index],
                memory_mask=None,
                memory_key_padding_mask=None,
                pos=pos[level_index],
                query_pos=query_embed,
            )
            output = self.transformer_self_attention_layers[index](
                output,
                tgt_mask=None,
                tgt_key_padding_mask=None,
                query_pos=query_embed,
            )
            output = self.transformer_ffn_layers[index](output)
        # (num_queries, N, C) -> (N, num_queries, C)
        decoder_output = self.decoder_norm(output).transpose(0, 1)
        color_embed = self.color_embed(decoder_output)
        # Dot each query embedding with every pixel embedding.
        return torch.einsum("bqc,bchw->bqhw", color_embed, image_features)
class DualDecoder(nn.Module):
    """Pixel decoder plus color-query decoder built on the hooked encoder features.

    Pixel path: the deepest hooked feature is upsampled by one UnetBlockWide per
    remaining hook (deepest skip first), then a 4x pixel-shuffle produces per-pixel
    embeddings. The intermediate block outputs are the multi-scale memory of a
    MultiScaleColorDecoder, whose queries are dotted with the pixel embeddings.

    Args:
        hooks: encoder hooks ordered shallow -> deep. Each must already hold a feature
            map, because its channel count sizes the layers.
        nf: width of the decoder blocks; the last block and the pixel embeddings use nf // 2.
        blur: blur flag for every pixel-shuffle upsampler.
        num_queries, num_scales, dec_layers: forwarded to MultiScaleColorDecoder.
    """

    def __init__(self, hooks, nf=512, blur=True, num_queries=100, num_scales=3, dec_layers=9):
        super().__init__()
        self.hooks = hooks
        self.nf = nf
        self.blur = blur
        self.layers = self._make_layers()
        embed_dim = nf // 2
        self.last_shuf = CustomPixelShuffleICNR(
            embed_dim,
            embed_dim,
            scale=4,
            blur=self.blur,
            norm_type=NormType.Spectral,
        )
        # Channel widths of the block outputs as built by _make_layers: nf for every
        # block except the last, which uses nf // 2. The query MLP width must equal the
        # pixel-embedding channels (embed_dim) for the final einsum. For nf=512 this is
        # [512, 512, 256] with color_embed_dim=256.
        block_channels = [nf] * (len(self.layers) - 1) + [embed_dim]
        self.color_decoder = MultiScaleColorDecoder(
            in_channels=block_channels,
            num_queries=num_queries,
            num_scales=num_scales,
            dec_layers=dec_layers,
            color_embed_dim=embed_dim,
        )

    def _make_layers(self):
        """Build one UnetBlockWide per skip hook, walking from the deepest skip to the shallowest."""
        layers = []
        in_channels = self.hooks[-1].feature.shape[1]
        out_channels = self.nf
        setup_hooks = self.hooks[-2::-1]
        for index, hook in enumerate(setup_hooks):
            skip_channels = hook.feature.shape[1]
            if index == len(setup_hooks) - 1:
                out_channels = out_channels // 2
            layers.append(
                UnetBlockWide(
                    in_channels,
                    skip_channels,
                    out_channels,
                    hook,
                    blur=self.blur,
                    norm_type=NormType.Spectral,
                )
            )
            in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self):
        """Decode from the features currently held by the hooks.

        Returns:
            Output of the color decoder: one map per query at the pixel-embedding resolution.
        """
        x = self.hooks[-1].feature
        multi_scale = []
        for block in self.layers:
            x = block(x)
            multi_scale.append(x)
        pixel_embeddings = self.last_shuf(x)
        return self.color_decoder(multi_scale, pixel_embeddings)
class DDColor(nn.Module):
    """DDColor colorization network: ConvNeXt encoder + dual (pixel / color-query) decoder.

    ``forward`` takes an image batch. A 3-channel input is normalised with the ImageNet
    mean/std buffers. The result is ``num_output_channels`` maps from a 1x1
    spectral-norm conv applied to the decoder's ``num_queries`` maps concatenated with
    the (normalised) input.

    Args:
        encoder_name: ConvNeXt variant, ``"convnext-t"`` or ``"convnext-l"``.
        decoder_name: only ``"MultiScaleColorDecoder"`` is supported.
        num_input_channels: channels of the dummy input used to probe the encoder.
        input_size: (H, W) of that dummy input.
        nf: decoder width (see DualDecoder).
        num_output_channels: channels produced by ``refine_net``.
        last_norm: only ``"Spectral"`` is supported.
        do_normalize: if True, ``forward`` applies ``denormalize`` to the output.
        num_queries, num_scales, dec_layers: color-decoder settings.

    Raises:
        NotImplementedError: for an unsupported ``decoder_name`` or ``last_norm``, or
            (raised by ImageEncoder) an unsupported ``encoder_name``.
    """

    def __init__(
        self,
        encoder_name="convnext-l",
        decoder_name="MultiScaleColorDecoder",
        num_input_channels=3,
        input_size=(256, 256),
        nf=512,
        num_output_channels=2,
        last_norm="Spectral",
        do_normalize=False,
        num_queries=100,
        num_scales=3,
        dec_layers=9,
    ):
        super().__init__()
        if decoder_name != "MultiScaleColorDecoder":
            raise NotImplementedError(f"Unsupported decoder: {decoder_name}")
        if last_norm != "Spectral":
            raise NotImplementedError(f"Unsupported last_norm: {last_norm}")
        # Hooks on the four per-stage norms provide the multi-scale skip features.
        self.encoder = ImageEncoder(encoder_name, ["norm0", "norm1", "norm2", "norm3"])
        self.encoder.eval()
        # Dummy forward pass so every hook holds a feature map: DualDecoder reads the
        # hooked channel counts while building its layers. This draws from the global
        # torch RNG and runs on the CPU at the full input_size.
        test_input = torch.randn(1, num_input_channels, *input_size)
        with torch.no_grad():
            self.encoder(test_input)
        self.decoder = DualDecoder(
            self.encoder.hooks,
            nf=nf,
            num_queries=num_queries,
            num_scales=num_scales,
            dec_layers=dec_layers,
        )
        # 1x1 conv fusing the num_queries decoder maps with the image channels.
        # NOTE(review): "+ 3" assumes a 3-channel input even though num_input_channels
        # is configurable.
        self.refine_net = nn.Sequential(
            _custom_conv_layer(
                num_queries + 3,
                num_output_channels,
                ks=1,
                use_activation=False,
                norm_type=NormType.Spectral,
            )
        )
        self.do_normalize = do_normalize
        # ImageNet channel statistics shaped to broadcast over (N, 3, H, W).
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def normalize(self, image):
        """Standardise a 3-channel batch with the mean/std buffers."""
        return (image - self.mean) / self.std

    def denormalize(self, image):
        """Inverse of ``normalize``."""
        return image * self.std + self.mean

    def forward(self, image):
        """Predict ``num_output_channels`` maps for an (N, C, H, W) batch."""
        if image.shape[1] == 3:
            image = self.normalize(image)
        # Result discarded: the call fills the encoder hooks that the decoder reads.
        self.encoder(image)
        decoded = self.decoder()
        # The normalised image (not the raw input) is concatenated with the decoder maps.
        coarse_input = torch.cat([decoded, image], dim=1)
        output = self.refine_net(coarse_input)
        if self.do_normalize:
            output = self.denormalize(output)
        return output
class ColorizationPipeline:
    """Run a DDColor-style model on a BGR image and return a colorized BGR uint8 image.

    The model sees a resized grayscale version of the image and predicts the a/b
    (chroma) channels of CIE Lab. Luminance comes from the original full-resolution
    image, and only the predicted chroma is resized back.

    Args:
        model: network mapping (1, 3, S, S) grayscale-as-RGB input to 2 chroma channels.
        input_size: square side S the image is resized to before inference.
        device: any spec accepted by ``_resolve_device``.
    """

    def __init__(self, model, input_size=512, device=None):
        self.input_size = int(input_size)
        self.device = _resolve_device(device)
        self.model = model.to(self.device)
        self.model.eval()

    def process(self, image_bgr):
        """Colorize one image.

        Args:
            image_bgr: H x W x 3 array in BGR order; values are divided by 255, so
                presumably uint8 input.

        Returns:
            H x W x 3 uint8 BGR array.

        Raises:
            ValueError: if ``image_bgr`` is None.
        """
        no_grad_context = getattr(torch, "inference_mode", torch.no_grad)
        with no_grad_context():
            if image_bgr is None:
                raise ValueError("image is None")
            out_h, out_w = image_bgr.shape[:2]
            unit_bgr = (image_bgr / 255.0).astype(np.float32)
            # Full-resolution L channel, kept for the final composition.
            full_res_l = cv2.cvtColor(unit_bgr, cv2.COLOR_BGR2Lab)[:, :, :1]

            side = self.input_size
            small_bgr = cv2.resize(unit_bgr, (side, side))
            small_l = cv2.cvtColor(small_bgr, cv2.COLOR_BGR2Lab)[:, :, :1]
            # Lab with a = b = 0 is the gray version; render it as RGB for the model.
            blank = np.zeros_like(small_l)
            gray_rgb = cv2.cvtColor(np.concatenate((small_l, blank, blank), axis=-1), cv2.COLOR_LAB2RGB)

            batch = torch.from_numpy(gray_rgb.transpose((2, 0, 1))).float().unsqueeze(0).to(self.device)
            predicted_ab = self.model(batch).cpu()

            # Resize chroma back to the original size and merge it with the original L.
            full_res_ab = F.interpolate(predicted_ab, size=(out_h, out_w))[0].float().numpy().transpose(1, 2, 0)
            merged_lab = np.concatenate((full_res_l, full_res_ab), axis=-1)
            colorized_bgr = cv2.cvtColor(merged_lab, cv2.COLOR_LAB2BGR)
            return (colorized_bgr * 255.0).round().astype(np.uint8)
def build_colorizer(repo_id=DEFAULT_REPO_ID, device=None):
    """Download DDColor's config and weights from the Hub and wrap them in a pipeline.

    Both "config.json" and "pytorch_model.bin" are fetched from ``repo_id``. The JSON
    config is passed as keyword arguments to ``DDColor``, and the weights are loaded
    strictly. The pipeline's square input side is ``config["input_size"][0]``
    (512 if the key is absent).

    Args:
        repo_id: Hugging Face Hub repository id.
        device: any spec accepted by ``_resolve_device``.

    Returns:
        A ready-to-use ``ColorizationPipeline``.
    """
    target_device = _resolve_device(device)
    config_file, weights_file = [
        hf_hub_download(repo_id=repo_id, filename=filename)
        for filename in ("config.json", "pytorch_model.bin")
    ]
    config = _load_model_config(config_file)
    network = DDColor(**config)
    network.load_state_dict(_load_checkpoint_state_dict(weights_file, map_location="cpu"), strict=True)
    network = network.to(target_device)
    network.eval()
    square_side = config.get("input_size", [512, 512])[0]
    return ColorizationPipeline(network, input_size=square_side, device=target_device)
def _get_colorizer():
    """Return the process-wide colorization pipeline, building it on first use.

    The repo id comes from the ``DDCOLOR_REPO_ID`` environment variable, falling
    back to ``DEFAULT_REPO_ID``. A failed build is not cached, so the next call retries.

    Raises:
        gr.Error: if the pipeline cannot be built. The message shows at most the first
            200 characters of the underlying error, which is chained as ``__cause__``.
    """
    if _COLORIZER_STATE["initialized"]:
        return _COLORIZER_STATE["pipeline"]
    repo_id = os.getenv("DDCOLOR_REPO_ID", DEFAULT_REPO_ID)
    try:
        colorizer = build_colorizer(repo_id=repo_id)
    except Exception as error:
        # Show a short message in the UI but keep the full traceback via chaining.
        raise gr.Error(
            "Failed to initialize the DDColor model from Hugging Face Hub. "
            f"Error: {str(error)[:200]}"
        ) from error
    _COLORIZER_STATE.update(
        {
            "initialized": True,
            "pipeline": colorizer,
        }
    )
    return colorizer
def _normalize_input_image(image):
    """Coerce an image array to three channels for the colorization pipeline.

    - (H, W) grayscale and (H, W, 1) single-channel images are replicated to 3 channels.
    - (H, W, 2) gray + alpha images drop the alpha and replicate the gray channel.
    - (H, W, 4) images drop the fourth (alpha) channel.
    - Anything else, e.g. (H, W, 3), is returned unchanged (same object).
    """
    if image.ndim == 2:
        return np.stack([image, image, image], axis=-1)
    channels = image.shape[-1]
    if channels == 1:
        return np.repeat(image, 3, axis=-1)
    if channels == 2:
        # Gray + alpha: the pipeline's BGR->Lab conversion needs 3 colour channels.
        return np.repeat(image[..., :1], 3, axis=-1)
    if channels == 4:
        return image[..., :3]
    return image
def color(image):
    """Gradio handler: colorize an uploaded image and return a (before, after) pair.

    The input is presumably RGB (as delivered by ``gr.Image(type="numpy")``). The
    pipeline works in BGR, so the channel axis is flipped on the way in and back on
    the way out.

    Raises:
        gr.Error: if no image was provided, or if the model cannot be initialised.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    rgb = _normalize_input_image(image)
    pipeline = _get_colorizer()
    colorized_rgb = pipeline.process(rgb[..., ::-1])[..., ::-1]
    print("infer finished!")
    return rgb, colorized_rgb
def clear_ui():
    """Reset both the input image and the comparison slider."""
    cleared = (None, None)
    return cleared
# Example inputs listed under the UI. With cache_examples=True and cache_mode="eager",
# their outputs are presumably computed ahead of time (confirm against the Gradio docs).
examples = [["./input.jpg"]]
with gr.Blocks(fill_width=True) as demo:
    with gr.Row():
        # Left column: the photo to colorize plus the action buttons.
        with gr.Column():
            input_image = gr.Image(
                type="numpy",
                label="Old Photo",
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Colorize", variant="primary")
        # Right column: before/after slider filled by `color`.
        with gr.Column():
            comparison_output = gr.ImageSlider(
                type="numpy",
                slider_position=50,
                label="Before / After",
            )
    gr.Examples(
        examples=examples,
        inputs=input_image,
        outputs=comparison_output,
        fn=color,
        cache_examples=True,
        cache_mode="eager",
        preload=0,
    )
    submit_btn.click(
        fn=color,
        inputs=input_image,
        outputs=comparison_output,
    )
    # Changing the input image clears the previous comparison result.
    input_image.input(
        fn=lambda: None,
        outputs=comparison_output,
    )
    clear_btn.click(
        fn=clear_ui,
        outputs=[input_image, comparison_output],
    )
if __name__ == "__main__":
    # Enable request queuing, then serve the app locally (no public share link).
    launch_options = {
        "share": False,
        "ssr_mode": False,
        "theme": "Nymbo/Nymbo_Theme",
        "footer_links": [],
    }
    demo.queue().launch(**launch_options)