Spaces:
Sleeping
Sleeping
| from .pix2pix_model import Pix2PixModel | |
| import torch | |
| from skimage import color # used for lab2rgb | |
| import numpy as np | |
class ColorizationModel(Pix2PixModel):
    """This is a subclass of Pix2PixModel for image colorization (black & white image -> colorful images).

    The model training requires the '--dataset_mode colorization' dataset.
    It trains a pix2pix model, mapping from L channel to ab channels in Lab color space.
    By default, the colorization dataset will automatically set '--input_nc 1' and '--output_nc 2'.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add
                               training-specific or test-specific options.

        Returns:
            the modified parser.

        By default, we use 'colorization' dataset for this model.
        See the original pix2pix paper (https://arxiv.org/pdf/1611.07004.pdf) and colorization
        results (Figure 9 in the paper).
        """
        # Start from the pix2pix option set. The parent's return value is ignored here,
        # so this relies on the parent modifying `parser` in place.
        Pix2PixModel.modify_commandline_options(parser, is_train)
        # Override only the dataset mode; everything else keeps the pix2pix defaults.
        parser.set_defaults(dataset_mode='colorization')
        return parser

    def __init__(self, opt):
        """Initialize the class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        For visualization, we set 'visual_names' as 'real_A' (input real image),
        'real_B_rgb' (ground truth RGB image), and 'fake_B_rgb' (predicted RGB image).
        We convert the Lab image 'real_B' (inherited from Pix2PixModel) to an RGB image 'real_B_rgb'.
        We convert the Lab image 'fake_B' (inherited from Pix2PixModel) to an RGB image 'fake_B_rgb'.
        """
        # reuse the pix2pix model
        Pix2PixModel.__init__(self, opt)
        # specify the images to be visualized; the *_rgb attributes are filled in by compute_visuals().
        self.visual_names = ['real_A', 'real_B_rgb', 'fake_B_rgb']

    def lab2rgb(self, L, AB):
        """Convert an Lab tensor image to an RGB numpy output.

        Only the first image of the batch (index 0) is converted; this is intended for
        visualization, not for batch processing.

        Parameters:
            L  (1-channel tensor array): L channel images (range: [-1, 1], torch tensor array)
            AB (2-channel tensor array): ab channel images (range: [-1, 1], torch tensor array)

        Returns:
            rgb (RGB numpy image): rgb output image of the first batch element, channels-last
                                   layout (range: [0, 255], float numpy array, not uint8)
        """
        # Undo the [-1, 1] normalization: ab -> [-110, 110], L -> [0, 100].
        AB2 = AB * 110.0
        L2 = (L + 1.0) * 50.0
        # Stack L and ab along dim 1; presumably tensors are (N, C, H, W) — the [0] index and
        # (1, 2, 0) transpose below assume a leading batch axis followed by channels.
        Lab = torch.cat([L2, AB2], dim=1)
        # Keep only the first image; detach from the autograd graph and move to CPU for skimage.
        Lab = Lab[0].detach().cpu().float().numpy()
        # Channels-first -> channels-last float64, the layout skimage's lab2rgb works on.
        Lab = np.transpose(Lab.astype(np.float64), (1, 2, 0))
        # skimage produces RGB in [0, 1]; scale to [0, 255].
        rgb = color.lab2rgb(Lab) * 255
        return rgb

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization.

        Both RGB images reuse the input L channel ('real_A'); only the ab channels differ
        (ground truth 'real_B' vs. prediction 'fake_B').
        """
        self.real_B_rgb = self.lab2rgb(self.real_A, self.real_B)
        self.fake_B_rgb = self.lab2rgb(self.real_A, self.fake_B)