import io
import os
import sys
from typing import List, Tuple

# Make sibling modules (muse, losses) importable before importing them below.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from PIL import Image
from torch import Tensor
from torchvision import transforms
from torchvision.transforms import Lambda

from muse import VQGANModel
from losses.vgperceptual import VQLPIPSWithDiscriminator
|
|
|
|
|
class ForwardWrapper(nn.Module):
    """Expose one named method of a wrapped model as this module's forward pass.

    Handy for tracing/exporting just the ``encode``/``decode``/``decode_code``
    path of a VQGAN as a standalone ``nn.Module``.
    """

    def __init__(self, vq_model, func='encode'):
        super().__init__()
        self.vq_model = vq_model
        # Name of the wrapped model's method to dispatch to in forward().
        self.func = func

    def forward(self, x):
        target = getattr(self.vq_model, self.func)
        return target(x)
|
|
|
class Train_VQGAN(VQGANModel):
    """VQGAN subclass that adds the training-time loss head and optimizers.

    Extends the base ``VQGANModel`` with a ``VQLPIPSWithDiscriminator`` loss
    (codebook + perceptual + hinge-GAN terms), a two-optimizer training step
    (autoencoder vs. discriminator), and codebook inspection helpers.
    """

    def __init__(
        self, num_embeddings,
    ):
        # Architecture hyper-parameters are pinned to a 256x256, 64-dim-latent
        # configuration; only the codebook size is configurable.
        super().__init__(
            resolution=256,
            num_channels=3,
            hidden_channels=128,
            channel_mult=(1, 2, 2, 4, 6),
            num_res_blocks=2,
            attn_resolutions=[],
            no_attn_mid_block=True,
            z_channels=64,
            num_embeddings=num_embeddings,
            quantized_embed_dim=64,
            dropout=0.0,
            resample_with_conv=True,
            commitment_cost=0.25,
        )

        # Combined loss: codebook + LPIPS perceptual + hinge discriminator,
        # with the discriminator active from step 0 (disc_start=0).
        self.loss = VQLPIPSWithDiscriminator(
            disc_start=0,
            disc_in_channels=3,
            disc_conditional=False,
            disc_weight=0.5,
            disc_num_layers=2,
            codebook_weight=1.0,
            perceptual_weight=1.0,
            disc_loss='hinge')

    def encode(self, pixel_values, return_loss=False):
        """Encode images to quantized latent states.

        Note the return order differs from ``self.quantize``'s output:
        returns (quantized_states, codebook_loss, codebook_indices).
        """
        hidden_states = self.encoder(pixel_values)
        hidden_states = self.quant_conv(hidden_states)
        quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states, return_loss)

        return quantized_states, codebook_loss, codebook_indices

    def decode(self, quantized_states):
        """Decode quantized latent states back to pixel values."""
        hidden_states = self.post_quant_conv(quantized_states)
        reconstructed_pixel_values = self.decoder(hidden_states)
        return reconstructed_pixel_values

    def decode_code(self, codebook_indices):
        """Decode directly from codebook indices (indices -> embeddings -> pixels)."""
        quantized_states = self.quantize.get_codebook_entry(codebook_indices)
        reconstructed_pixel_values = self.decode(quantized_states)
        return reconstructed_pixel_values

    def get_last_layer(self):
        # The decoder's output-conv weight; used by the loss to balance the
        # adaptive discriminator weight against the reconstruction gradient.
        return self.decoder.conv_out.weight

    def get_code(self, pixel_values):
        """Return the codebook indices for a batch of images (no decode)."""
        hidden_states = self.encoder(pixel_values)
        hidden_states = self.quant_conv(hidden_states)
        codebook_indices = self.quantize.get_code(hidden_states)
        return codebook_indices

    def forward(self, pixel_values, return_loss=False):
        """Full autoencode pass: returns (reconstructed_pixel_values, codebook_loss)."""
        hidden_states = self.encoder(pixel_values)
        hidden_states = self.quant_conv(hidden_states)
        quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states, return_loss)
        reconstructed_pixel_values = self.decode(quantized_states)

        return reconstructed_pixel_values, codebook_loss

    def training_step(self, x, optimizer_idx, device, epoch_cnt):
        """One training step for either the autoencoder or the discriminator.

        Args:
            x: input image batch.
            optimizer_idx: 0 -> autoencoder loss, 1 -> discriminator loss.
            device: unused here — TODO confirm whether callers rely on it.
            epoch_cnt: passed to the loss as the global step/epoch counter.

        Returns:
            (loss, log_dict) for the selected optimizer.
            NOTE(review): implicitly returns None for any other optimizer_idx.
        """
        xrec, qloss = self(x, return_loss=True)

        if optimizer_idx == 0:
            # Autoencoder update: reconstruction + perceptual + codebook terms.
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, epoch_cnt,
                                            last_layer=self.get_last_layer(), split="train")

            return aeloss, log_dict_ae

        if optimizer_idx == 1:
            # Discriminator update on real vs. reconstructed images.
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, epoch_cnt,
                                                last_layer=self.get_last_layer(), split="train")

            return discloss, log_dict_disc

    def configure_optimizers(self, lr):
        """Build the two AdamW optimizers (Lightning-style return signature).

        Returns ([autoencoder_optimizer, discriminator_optimizer], []) — the
        empty list is the (unused) scheduler slot.
        """
        # Autoencoder optimizer covers every module except the discriminator.
        opt_ae = torch.optim.AdamW(list(self.encoder.parameters())+
                                   list(self.decoder.parameters())+
                                   list(self.quantize.parameters())+
                                   list(self.quant_conv.parameters())+
                                   list(self.post_quant_conv.parameters()),
                                   lr=lr)
        opt_disc = torch.optim.AdamW(self.loss.discriminator.parameters(),
                                     lr=lr)
        return [opt_ae, opt_disc], []
|
|
|
def load_model(path = None, num_embeddings=None):
    """Build a ``Train_VQGAN`` and optionally restore weights from a checkpoint.

    NOTE(review): ``torch.load`` unpickles arbitrary objects — only load
    trusted checkpoint files. ``strict=False`` silently ignores missing or
    unexpected state-dict keys.
    """
    model = Train_VQGAN(num_embeddings)
    if path is None:
        return model
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state, strict=False)
    return model
|
|
|
def load_model_no_disc(path='./vqgan_ckpt'):
    """Load the inference-only pretrained VQGAN (no discriminator/loss head)."""
    return VQGANModel.from_pretrained(path)
|
|
|
def load_encoder(path):
    """Load a checkpointed VQGAN wrapped so ``encode`` is its forward pass."""
    return ForwardWrapper(load_model(path))
|
|
|
|
|
def load_decoder(path):
    """Load a checkpointed VQGAN wrapped so ``decode`` is its forward pass."""
    return ForwardWrapper(load_model(path), func='decode')
|
|
|
|
|
def load_decoder_code(path):
    """Load a checkpointed VQGAN wrapped so ``decode_code`` is its forward pass."""
    return ForwardWrapper(load_model(path), func='decode_code')
|
|
|
|
|
def convert_decode_to_pil(rec_image):
    """Convert a batch of decoder outputs into a list of PIL images.

    Args:
        rec_image: float tensor of shape (batch, channels, height, width);
            presumably valued in [0, 1] — the [-1, 1] round-trip below reduces
            to clamping into [0, 1]. TODO confirm against the decoder's range.

    Returns:
        One ``PIL.Image`` per batch element.
    """
    # Map to [-1, 1], clamp, map back to [0, 1]; works on a fresh tensor so
    # the caller's input is never mutated.
    scaled = 2.0 * rec_image - 1.0
    scaled = torch.clamp(scaled, -1.0, 1.0)
    scaled = (scaled + 1.0) / 2.0
    scaled = scaled * 255.0
    # NCHW -> NHWC uint8 on the CPU for PIL consumption.
    frames = scaled.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
    pil_images = [Image.fromarray(frame) for frame in frames]
    return pil_images
|
|
|
|
|
class SixCrop(torch.nn.Module):
    """Crop an image into six same-sized crops covering its full extent.

    Yields the four corner crops plus two crops centered along the longer
    image dimension. The requested crop size is clamped (as a square) to the
    image when it exceeds the image's width or height.
    """

    def __init__(self, crop_size):
        super().__init__()
        # (crop_height, crop_width) of each returned crop.
        self.crop_size = crop_size

    def get_dimensions(self, img) -> List[int]:
        """Return ``[channels, height, width]`` of *img*.

        Supports PIL images (``getbands``/``size``); any other input must
        expose ``channels`` and a PIL-style ``size`` attribute.
        NOTE(review): raw torch Tensors have neither — tensor inputs would
        need torchvision's ``F.get_dimensions`` instead; confirm callers.
        """
        if hasattr(img, "getbands"):
            channels = len(img.getbands())
        else:
            channels = img.channels
        width, height = img.size
        return [channels, height, width]

    def six_crop(self, img: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Crop the given image into the four corners plus two centered crops.

        If the image is a torch Tensor, it is expected to have [..., H, W]
        shape, where ... means an arbitrary number of leading dimensions.

        .. Note::
            This transform returns a tuple of images and there may be a
            mismatch in the number of inputs and targets your ``Dataset``
            returns.

        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple: six crops, ordered top-to-bottom then left-to-right.
            Portrait images -> (tl, tr, center-left, center-right, bl, br);
            landscape/square -> (tl, tr, center-top, bl, br, center-bottom).
        """
        crop_height, crop_width = self.crop_size
        _, image_height, image_width = self.get_dimensions(img)

        # Clamp the crop to the image, degrading to a square crop of the
        # limiting dimension.
        if crop_width > image_width:
            crop_width = image_width
            crop_height = image_width

        if crop_height > image_height:
            crop_width = image_height
            crop_height = image_height

        # Four corner crops.
        tl = F.crop(img, 0, 0, crop_height, crop_width)
        tr = F.crop(img, 0, image_width - crop_width, crop_height, crop_width)
        bl = F.crop(img, image_height - crop_height, 0, crop_height, crop_width)
        br = F.crop(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)

        if image_height > image_width:
            # Portrait: the two extra crops are vertically centered.
            center_top = int(round((image_height - crop_height) / 2.0))
            cl = F.crop(img, center_top, 0, crop_height, crop_width)
            cr = F.crop(img, center_top, image_width - crop_width, crop_height, crop_width)
            return tl, tr, cl, cr, bl, br
        else:
            # Landscape/square: the two extra crops are horizontally centered.
            center_left = int(round((image_width - crop_width) / 2.0))
            ct = F.crop(img, 0, center_left, crop_height, crop_width)
            cb = F.crop(img, image_height - crop_height, center_left, crop_height, crop_width)
            return tl, tr, ct, bl, br, cb

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple: the six crops produced by :meth:`six_crop`.
        """
        return self.six_crop(img)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.crop_size})"
|
|
|
|
|
def six_crop_encode_transform(crop_size):
    """Build a transform pipeline: six-crop, resize each crop to 256, tensorize.

    Returns a ``transforms.Compose`` mapping one image to a list of six
    tensors, each bilinearly resized to 256 on the shorter side.
    """
    # The resize/to-tensor transforms are stateless, so build them once and
    # reuse them across the six crops.
    resize = transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR)
    to_tensor = transforms.ToTensor()
    return transforms.Compose(
        [
            SixCrop(crop_size),
            Lambda(lambda crops: [resize(crop) for crop in crops]),
            Lambda(lambda crops: [to_tensor(crop) for crop in crops]),
        ]
    )
|
|
|
|
|
# Standard encoding transform: resize shorter side to 256, center-crop 256x256.
encode_transform = transforms.Compose(
    [
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(256),
        transforms.ToTensor(),

    ]
)

# Squash to exactly 256x256 (may distort aspect ratio) instead of cropping.
encode_transform_no_crop = transforms.Compose(
    [
        transforms.Resize([256, 256], interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
    ]
)

# Heavy augmentation: random flips, arbitrary rotation, random resized crop.
encode_transform_2 = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(180),
        transforms.RandomResizedCrop(256, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
    ]
)

# NOTE(review): "rain" in the two names below looks like a typo for "train";
# left unchanged because renaming would break importers.
# Augmentation without rotation: flips + random resized crop.
encode_transform_rain_random = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomResizedCrop(256, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
    ]
)

# Augmentation via a fixed 400px random crop, then downscale to 256.
encode_transform_rain_random_2 = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomCrop(400),
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
    ]

)
|
|
|
| import random
|
|
|
| from io import BytesIO
|
| from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def decode_image(data):
    """Decode one webdataset sample into a normalized RGB tensor.

    Args:
        data: sample dict with the encoded image bytes under ``'rgb.jpg'`` and
            an optional crop region under ``'crop.pickle'``. The crop is a
            pair of (row, col) points — ``crop[0]`` is the top-left and
            ``crop[1]`` the bottom-right corner, as used in the slicing below.

    Returns:
        Tensor produced by ``encode_transform`` (resize -> center-crop -> ToTensor).
    """
    rgb = data['rgb.jpg']
    # dict.get replaces the original `in data.keys()` membership test.
    crop = data.get('crop.pickle')

    rgb = Image.open(BytesIO(rgb)).convert("RGB")
    if crop is not None:
        # Crop via numpy row/col slicing, then return to PIL for the transforms.
        rgb = np.array(rgb)
        rgb = rgb[crop[0][0]:crop[1][0], crop[0][1]:crop[1][1], :]
        rgb = Image.fromarray(rgb).convert("RGB")

    rgb = encode_transform(rgb)

    return rgb
|
|
|
def code_usage_stat(model, ds, save_dir, num_codes=8192):
    """Count how often each codebook entry is used across a dataset.

    Args:
        model: VQGAN-style model exposing ``get_code``; put in eval mode and
            moved to CUDA.
        ds: iterable of image tensors of shape (C, H, W).
        save_dir: path of the JSON file the histogram is written to.
        num_codes: codebook size, i.e. number of histogram bins (previously
            hard-coded to 8192).
    """
    import json
    model.eval().cuda()

    # Pre-fill every code with zero so unused entries still appear in the
    # saved histogram.
    stat = {i: 0 for i in range(num_codes)}

    cnt = 0
    for d in ds:

        rgb = d.unsqueeze(0).cuda()

        with torch.no_grad():
            idxs = model.get_code(rgb)[0]

        for idx in idxs.cpu().numpy().tolist():
            stat[idx] += 1

        print(cnt)  # progress logging
        cnt += 1

    with open(save_dir, 'w') as f:
        json.dump(stat, f)
|
|
|
def plot_code_stat(stat, out):
    """Render a saved code-usage histogram (JSON file at *stat*) as a bar
    chart and save it to *out*."""
    import json
    import matplotlib.pyplot as plt

    with open(stat, 'r') as f:
        counts = json.load(f)

    # JSON keys/values come back as strings; coerce both axes to ints.
    codes = [int(k) for k in counts]
    freqs = [int(counts[k]) for k in counts]

    plt.bar(codes, freqs)
    plt.savefig(out)
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: run images from a webdataset tar through the VQGAN and save
    # side-by-side input/reconstruction figures. Paths are hard-coded to a
    # specific environment.
    import numpy as np
    from calvin_img import CalvinImgDataset
    import matplotlib.pyplot as plt
    import webdataset as wds
    from tqdm import tqdm
    from copy import deepcopy

    def generator_to_webdataset(ds, save_path: str):
        # Serialize an iterable of sample dicts into a webdataset tar shard.
        sink = wds.TarWriter(save_path)
        for data_dict in tqdm(ds):
            sink.write(data_dict)
        sink.close()

    # Restore a Train_VQGAN with an 8192-entry codebook from a fixed checkpoint.
    vq_model = load_model('/mnt/bn/roboicl-jirong/codebase/RoboICL/src/vqgan_muse_pretrained/vqgan_ckpt/ckpt.pth', 8192)

    vq_model.eval()

    ds = '/mnt/bn/roboicl-jirong/codebase/real_data_img_0122.tar'

    # Decode each tar sample into a normalized RGB tensor via decode_image.
    ds = wds.WebDataset(ds).decode().map(decode_image)

    cnt = 0
    for d in ds:

        rgb = d.unsqueeze(0)
        print (rgb.shape)

        # CPU inference; gradients disabled for the reconstruction pass.
        with torch.no_grad():
            recon, _ = vq_model(rgb)

        recon = convert_decode_to_pil(recon)[0]

        rgb = transforms.ToPILImage()(rgb[0])

        # Side-by-side figure: original input vs. VQGAN reconstruction.
        fig, ax = plt.subplots(1,2)

        ax[0].imshow(np.array(rgb))
        ax[0].set_title('rgb')

        ax[1].imshow(np.array(recon))
        ax[1].set_title('rgb_1')

        fig.savefig('/mnt/bn/robotics-data-hl/jirong/git/DeLVM/data_generation/vqgan/recon/{}.jpg'.format(str(cnt).zfill(4)))
        cnt += 1
        print (cnt)
        # Only process the first ~100 samples.
        if cnt > 100:
            break
|
|
|