from typing import List, Tuple
import sys
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# ROOT_DIR = os.path.dirname(BASE_DIR)
# sys.path.append(ROOT_DIR)
sys.path.append(BASE_DIR)
import io
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from PIL import Image
from torch import Tensor
from torchvision import transforms
from torchvision.transforms import Lambda
from muse import VQGANModel
from losses.vgperceptual import VQLPIPSWithDiscriminator
class ForwardWrapper(nn.Module):
    """Exposes a single VQGAN method (e.g. 'encode' or 'decode') as this module's forward()."""
    def __init__(self, vq_model, func='encode'):
        super(ForwardWrapper, self).__init__()
        self.vq_model = vq_model
        self.func = func

    def forward(self, x):
        return getattr(self.vq_model, self.func)(x)
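

# A minimal usage sketch (not called anywhere): wrap an already loaded Train_VQGAN so that only
# its decoder is exposed through forward(), e.g. for tracing or exporting one half of the model.
# The latent shape below assumes the 256x256, f16, 64-channel config used by Train_VQGAN below.
def _example_wrap_decoder(vq_model):
    decoder = ForwardWrapper(vq_model, func='decode')
    z = torch.randn(1, 64, 16, 16)  # quantized latents: [batch, quantized_embed_dim, H/16, W/16]
    with torch.no_grad():
        return decoder(z)  # [1, 3, 256, 256] reconstruction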
class Train_VQGAN(VQGANModel):
    def __init__(self, num_embeddings):
# super().__init__(
# resolution=256,
# num_channels=3,
# hidden_channels=128,
# channel_mult=(1, 1, 2, 2, 4),
# num_res_blocks=2,
# attn_resolutions=[16],
# no_attn_mid_block=True,
# z_channels=256,
# num_embeddings=8192,
# quantized_embed_dim=256,
# dropout=0.0,
# resample_with_conv=True,
# commitment_cost=0.25,
# )
super().__init__(
resolution=256,
num_channels=3,
hidden_channels=128,
channel_mult=(1, 2, 2, 4, 6),
num_res_blocks=2,
attn_resolutions=[],
no_attn_mid_block=True,
z_channels=64,
num_embeddings=num_embeddings,
quantized_embed_dim=64,
dropout=0.0,
resample_with_conv=True,
commitment_cost=0.25,
)
# super().__init__(
# resolution=256,
# num_channels=3,
# hidden_channels=128,
# channel_mult=(1, 1, 2, 2, 4),
# num_res_blocks=2,
# attn_resolutions=[],
# no_attn_mid_block=True,
# z_channels=256,
# num_embeddings=1024,
# quantized_embed_dim=256,
# dropout=0.0,
# resample_with_conv=True,
# commitment_cost=0.25,
# )
self.loss = VQLPIPSWithDiscriminator(
disc_start=0,
disc_in_channels=3,
disc_conditional=False,
disc_weight=0.5,
disc_num_layers=2,
codebook_weight=1.0,
perceptual_weight=1.0,
disc_loss='hinge')
    def encode(self, pixel_values, return_loss=False):
        hidden_states = self.encoder(pixel_values)
        hidden_states = self.quant_conv(hidden_states)
        # self.quantize returns (quantized, indices, loss); encode reorders this to (quantized, loss, indices)
        quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states, return_loss)
        return quantized_states, codebook_loss, codebook_indices
def decode(self, quantized_states):
hidden_states = self.post_quant_conv(quantized_states)
reconstructed_pixel_values = self.decoder(hidden_states)
return reconstructed_pixel_values
def decode_code(self, codebook_indices):
quantized_states = self.quantize.get_codebook_entry(codebook_indices)
reconstructed_pixel_values = self.decode(quantized_states)
return reconstructed_pixel_values
def get_last_layer(self):
return self.decoder.conv_out.weight
def get_code(self, pixel_values):
hidden_states = self.encoder(pixel_values)
hidden_states = self.quant_conv(hidden_states)
codebook_indices = self.quantize.get_code(hidden_states)
return codebook_indices
def forward(self, pixel_values, return_loss=False):
hidden_states = self.encoder(pixel_values)
hidden_states = self.quant_conv(hidden_states)
quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states, return_loss)
reconstructed_pixel_values = self.decode(quantized_states)
return reconstructed_pixel_values, codebook_loss
def training_step(self, x, optimizer_idx, device, epoch_cnt):
xrec, qloss = self(x, return_loss=True)
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, epoch_cnt,
last_layer=self.get_last_layer(), split="train")
# self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
# self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return aeloss, log_dict_ae
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, epoch_cnt,
last_layer=self.get_last_layer(), split="train")
# self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
# self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return discloss, log_dict_disc
def configure_optimizers(self, lr):
# lr = self.learning_rate
opt_ae = torch.optim.AdamW(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quantize.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr)
opt_disc = torch.optim.AdamW(self.loss.discriminator.parameters(),
lr=lr)
return [opt_ae, opt_disc], []
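

# A minimal training-loop sketch (not called anywhere) showing how training_step and
# configure_optimizers fit together: optimizer index 0 updates the autoencoder, index 1 updates
# the discriminator inside the LPIPS/GAN loss. The `loader`, learning rate and device below are
# assumptions rather than values used elsewhere in this file.
def _example_vqgan_train_loop(model, loader, lr=1e-4, device='cuda', num_epochs=1):
    model = model.to(device).train()
    (opt_ae, opt_disc), _ = model.configure_optimizers(lr)
    for epoch in range(num_epochs):
        for x in loader:
            x = x.to(device)
            # autoencoder / generator update
            aeloss, _ = model.training_step(x, optimizer_idx=0, device=device, epoch_cnt=epoch)
            opt_ae.zero_grad()
            aeloss.backward()
            opt_ae.step()
            # discriminator update
            discloss, _ = model.training_step(x, optimizer_idx=1, device=device, epoch_cnt=epoch)
            opt_disc.zero_grad()
            discloss.backward()
            opt_disc.step()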
def load_model(path=None, num_embeddings=8192):
    # num_embeddings defaults to 8192 so the wrappers below (load_encoder/load_decoder/...) work without it.
# Load the pre-trained vq model from the hub
# vq_model = VQGANModel(
# resolution=256,
# num_channels=3,
# hidden_channels=128,
# channel_mult=(1, 2, 2, 4, 6),
# num_res_blocks=2,
# attn_resolutions=(),
# no_attn_mid_block=True,
# z_channels=64,
# num_embeddings=8192,
# quantized_embed_dim=64,
# dropout=0.0,
# resample_with_conv=True,
# commitment_cost=0.25,
# )
vq_model = Train_VQGAN(num_embeddings)
if path is not None:
ckpt = torch.load(path, map_location='cpu')
vq_model.load_state_dict(ckpt, strict=False)
# vq_model = VQGANModel.from_pretrained(path)
return vq_model
def load_model_no_disc(path='./vqgan_ckpt'):
vq_model = VQGANModel.from_pretrained(path)
return vq_model
def load_encoder(path):
vq_model = load_model(path)
encoder = ForwardWrapper(vq_model)
return encoder
def load_decoder(path):
vq_model = load_model(path)
decoder = ForwardWrapper(vq_model, func='decode')
return decoder
def load_decoder_code(path):
vq_model = load_model(path)
decoder = ForwardWrapper(vq_model, func='decode_code')
return decoder
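

# A minimal round-trip sketch (not called anywhere): encode an image to codebook indices and
# decode the indices back to pixels. The checkpoint path is a placeholder, and the codebook size
# of 8192 is an assumption matching the pretrained muse VQGAN.
def _example_encode_decode(ckpt_path='/path/to/ckpt.pth'):
    vq_model = load_model(ckpt_path, num_embeddings=8192).eval()
    x = torch.randn(1, 3, 256, 256)  # stand-in for a transformed image batch
    with torch.no_grad():
        codes = vq_model.get_code(x)         # codebook indices per spatial position
        recon = vq_model.decode_code(codes)  # [1, 3, 256, 256] reconstruction
    return codes, recon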
def convert_decode_to_pil(rec_image):
    # The scale / clamp / rescale sequence below is equivalent to clamping the reconstruction to [0, 1].
    rec_image = 2.0 * rec_image - 1.0
    rec_image = torch.clamp(rec_image, -1.0, 1.0)
    rec_image = (rec_image + 1.0) / 2.0
    rec_image *= 255.0
    rec_image = rec_image.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
    pil_images = [Image.fromarray(image) for image in rec_image]
    return pil_images
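

# A small sketch (not called anywhere): convert a batch of reconstructions to PIL images and
# save them. `vq_model` and `pixel_values` are assumed to come from the helpers above; the
# output filenames are placeholders.
def _example_save_reconstructions(vq_model, pixel_values):
    with torch.no_grad():
        recon, _ = vq_model(pixel_values)
    for i, img in enumerate(convert_decode_to_pil(recon)):
        img.save('recon_{}.jpg'.format(i))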
class SixCrop(torch.nn.Module):
def __init__(self, crop_size):
super().__init__()
self.crop_size = crop_size
# def get_dimensions(self, img):
# """Returns the dimensions of an image as [channels, height, width].
#
# Args:
# img (PIL Image or Tensor): The image to be checked.
#
# Returns:
# List[int]: The image dimensions.
# """
# if isinstance(img, torch.Tensor):
# return F_t.get_dimensions(img)
#
# return F_pil.get_dimensions(img)
    def get_dimensions(self, img) -> List[int]:
        # Return [channels, height, width] for a tensor ([..., C, H, W]) or a PIL-style image.
        if isinstance(img, torch.Tensor):
            return [img.shape[-3], img.shape[-2], img.shape[-1]]
        if hasattr(img, "getbands"):
            channels = len(img.getbands())
        else:
            channels = img.channels
        width, height = img.size
        return [channels, height, width]
    def six_crop(self, img: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Crop the given image into the four corners plus two crops along the longer side.

        If the image is a torch Tensor, it is expected to have [..., H, W] shape, where ...
        means an arbitrary number of leading dimensions.

        .. Note::
            This transform returns a tuple of six images; there may be a mismatch in the
            number of inputs and targets your ``Dataset`` returns.

        Args:
            img (PIL Image or Tensor): Image to be cropped. The crop size is taken from
                ``self.crop_size`` as (height, width).

        Returns:
            tuple: six crops. For portrait images: (tl, tr, cl, cr, bl, br), i.e. the four
            corners plus center-left and center-right. For landscape or square images:
            (tl, tr, ct, bl, br, cb), i.e. the four corners plus center-top and center-bottom.
        """
# if not torch.jit.is_scripting() and not torch.jit.is_tracing():
# _log_api_usage_once(five_crop)
crop_height, crop_width = self.crop_size
_, image_height, image_width = self.get_dimensions(img)
        # if crop_width > image_width or crop_height > image_height:
        #     msg = "Requested crop size {} is bigger than input size {}"
        #     raise ValueError(msg.format(self.crop_size, (image_height, image_width)))
        # Unlike torchvision's five_crop, an oversized crop request is not an error here:
        # the crop is clamped to a square whose side is the smaller image dimension.
        if crop_width > image_width:
            crop_width = image_width
            crop_height = image_width
        if crop_height > image_height:
            crop_width = image_height
            crop_height = image_height
tl = F.crop(img, 0, 0, crop_height, crop_width)
tr = F.crop(img, 0, image_width - crop_width, crop_height, crop_width)
bl = F.crop(img, image_height - crop_height, 0, crop_height, crop_width)
br = F.crop(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
if image_height > image_width:
center_top = int(round((image_height - crop_height) / 2.0))
cl = F.crop(img, center_top, 0, crop_height, crop_width)
cr = F.crop(img, center_top, image_width - crop_width, crop_height, crop_width)
return tl, tr, cl, cr, bl, br
else:
center_left = int(round((image_width - crop_width) / 2.0))
ct = F.crop(img, 0, center_left, crop_height, crop_width)
cb = F.crop(img, image_height - crop_height, center_left, crop_height, crop_width)
return tl, tr, ct, bl, br, cb
# center = center_crop(img, [crop_height, crop_width])
    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.
        Returns:
            tuple of PIL Images or Tensors: the six crops.
        """
        return self.six_crop(img)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.crop_size})"
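

# A small usage sketch (not called anywhere): apply SixCrop to a PIL image and inspect the
# six crop sizes. The image path and crop size are placeholders.
def _example_six_crop(img_path='/path/to/image.jpg'):
    img = Image.open(img_path).convert('RGB')
    crops = SixCrop((400, 400))(img)
    return [c.size for c in crops]  # six (width, height) tuples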
def six_crop_encode_transform(crop_size):
t = transforms.Compose(
[
SixCrop(crop_size),
# transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
Lambda(lambda crops:
[transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR)(crop) for crop
in crops]),
Lambda(lambda crops: [transforms.ToTensor()(crop) for crop in crops]),
]
)
return t
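

# Sketch (not called anywhere): six_crop_encode_transform returns a list of six tensors, so they
# are typically stacked into a single batch before being fed to the encoder. The image path and
# crop size are placeholders.
def _example_six_crop_batch(img_path='/path/to/image.jpg'):
    t = six_crop_encode_transform((400, 400))
    crops = t(Image.open(img_path).convert('RGB'))
    return torch.stack(crops)  # [6, 3, 256, 256]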
encode_transform = transforms.Compose(
[
transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(256),
transforms.ToTensor(),
# transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
encode_transform_no_crop = transforms.Compose(
[
transforms.Resize([256, 256], interpolation=transforms.InterpolationMode.BILINEAR),
transforms.ToTensor(),
]
)
encode_transform_2 = transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomRotation(180),
transforms.RandomResizedCrop(256, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.ToTensor(),
]
)
encode_transform_rain_random = transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomResizedCrop(256, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.ToTensor(),
]
)
encode_transform_rain_random_2 = transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomCrop(400),
transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.ToTensor(),
]
)
import random
from io import BytesIO
from PIL import Image
# def decode_image(data):
# # assert 'rgb.jpg' in data.keys()
# # print (data.keys())
# rgb = data['video_hand.pickle']
# rgb = random.sample(rgb, 1)[0]
# # rgb = data['rgb.jpg']
# rgb = Image.open(io.BytesIO(rgb))
# rgb = encode_transform(rgb)
# # rgb = (rgb * 2) - 1
# return rgb
def decode_image(data):
# assert 'rgb.jpg' in data.keys()
rgb = data['rgb.jpg']
if 'crop.pickle' in data.keys():
crop = data['crop.pickle']
else:
crop = None
rgb = Image.open(BytesIO(rgb)).convert("RGB")
    if crop is not None:
        # crop is ((top, left), (bottom, right)) in pixel coordinates
        rgb = np.array(rgb)
        rgb = rgb[crop[0][0]:crop[1][0], crop[0][1]:crop[1][1], :]
        rgb = Image.fromarray(rgb).convert("RGB")
rgb = encode_transform(rgb)
# rgb = (rgb * 2) - 1
# rgb = (rgb + 1) / 2
return rgb
def code_usage_stat(model, ds, save_dir, num_embeddings=8192):
    # Count how often each codebook entry is used across the dataset and dump the histogram to JSON.
    import json
    model.eval().cuda()
    stat = {i: 0 for i in range(num_embeddings)}
    cnt = 0
for d in ds:
rgb = d.unsqueeze(0).cuda()
with torch.no_grad():
idxs = model.get_code(rgb)[0]
idxs = idxs.cpu().numpy().tolist()
for idx in idxs:
stat[idx] = stat[idx] + 1
print (cnt)
cnt += 1
with open(save_dir, 'w') as f:
json.dump(stat, f)
def plot_code_stat(stat, out):
import json
import matplotlib.pyplot as plt
with open(stat, 'r') as f:
stat = json.load(f)
x = []
y = []
for k, v in stat.items():
x.append(int(k))
y.append(int(v))
plt.bar(x, y)
plt.savefig(out)
if __name__ == '__main__':
import numpy as np
from calvin_img import CalvinImgDataset
import matplotlib.pyplot as plt
import webdataset as wds
from tqdm import tqdm
from copy import deepcopy
def generator_to_webdataset(ds, save_path: str):
sink = wds.TarWriter(save_path)
for data_dict in tqdm(ds):
sink.write(data_dict)
sink.close()
# plot_code_stat('./calvin_vqgan_stat.json', 'calvin.jpg')
#
vq_model = load_model('/mnt/bn/roboicl-jirong/codebase/RoboICL/src/vqgan_muse_pretrained/vqgan_ckpt/ckpt.pth', 8192)
# vq_model = load_model('/mnt/bn/robotics-data-hl/jirong/git/DeLVM/data_generation/vqgan/vqgan_ckpt/ckpt.pth', 8192)
# # # vq_model = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_muse_cfg_vl_robot_1e-4_192_codebook_0.1/checkpoint_vq_epoch_70655.tar')
# vq_model = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_muse_finetune_calvin_datacomp_1e-5_192_codebook_0.1_const_lr/checkpoint_vq_epoch_64511.tar')
# # ds = "/mnt/bn/roboicllq-data1/calvin_img/calvin_img_{00000000..00000239}.tar"
# # ds = wds.WebDataset(ds).decode().map(decode_image)
# vq_model = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_192_real_data_muse_finetune_teacher_2048_resume_aug/checkpoint_vq_epoch_97499.tar', 2048)
# vq_model = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_192_real_data_muse_finetune_teacher_2048_aug_disc_0.5_resume/checkpoint_vq_epoch_43999.tar', 2048)
# # code_usage_stat(vq_model, ds, './cofinetune_vqgan_stat_24999.json')
# # plot_code_stat('./cofinetune_vqgan_stat_24999.json', 'calvin.jpg')
# vq_model = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_192_real_data_1024_fintune/checkpoint_vq_epoch_9999.tar')
# # vq_model2 = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_192_cofinetune_2500/checkpoint_vq_epoch_1999.tar')
# vq_model = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_192_muse_ego4d_llava_calvin_3e-6_disc_0/checkpoint_vq_epoch_66999.tar', 8192)
# vq_model = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_256_ego4d_calvin_real_aug_disc_0.5_3e-5/checkpoint_vq_epoch_57499.tar', 2048)
# model_state_dict = vq_model.state_dict()
# keys = list(model_state_dict.keys())
# for key in keys:
# if 'loss' in key:
# del model_state_dict[key]
# torch.save(model_state_dict, '/mnt/bn/roboicl-jirong/codebase/RoboICL/src/vqgan_muse_pretrained/ego4d_llava_real_calvin_2048_finetune/pytorch_model.bin')
# vq_model = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_256_datacomp_calvin_real_aug_disc_0.5/checkpoint_vq_epoch_8999.tar', 2048)
# vq_model = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_192_real_data_muse_finetune_teacher_2048_aug_disc_0.5_resume/checkpoint_vq_epoch_72999.tar', 2048)
# vq_model.eval()
# vq_model = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_256_ego4d_calvin_real_aug_disc_0.5_3e-5/checkpoint_vq_epoch_57499.tar', 2048)
# # vq_model = load_model_no_disc('/mnt/bn/roboicl-jirong/codebase/RoboICL/src/vqgan_muse_pretrained/muse_cofinetune_8192')
# # vq_model = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_192_muse_ego4d_llava_calvin_3e-6_disc_0/checkpoint_vq_epoch_66999.tar', 8192)
# # vq_model = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_192_muse_ego4d_llava_calvin_1e-5_disc_0_resume/checkpoint_vq_epoch_4999.tar', 8192)
vq_model.eval()
# # # # # trainset = CalvinImgDataset('/mnt/bn/robotics/manipulation_data/calvin_data', use_hand_observation=True)
ds = '/mnt/bn/roboicl-jirong/codebase/real_data_img_0122.tar'
# ds = '/mnt/bn/roboicllq/img_wds/ego4d/ego4d_0000_mLNrAABME3.tar'
# ds = '/mnt/bn/roboicllq-data1/video_img_subset/ego4d/ego4d_0003_6IzswDQ9HomKtv5.tar'
# # ds = '/mnt/bn/roboicllq/img_wds/ego4d/ego4d_0000_IdrgFUyeJM.tar'
# # ds ='/mnt/bn/roboicllq-data1/processed_real/hand_imgs/real_data_hand_img_0130.tar'
# # ds = '/mnt/bn/roboicllq-data1/video_img_subset/ssv2/ssv2_0000_2yW4Acq9GFz6Y1t.tar'
# ds = '/mnt/bn/roboicl-jirong/codebase/calvin_img_00000046.tar'
# # ds = '/mnt/bn/roboicllq-data1/llava_pretrain/LLaVA-Pretrain/wds/llava_pretrain_0060.tar'
# # # ds = '/mnt/bn/roboicllq-data1/aligned_robot_ds/calvin/view_0/calvin_00000011.tar'
# ds = '/mnt/bn/roboicllq-data1/datacomp10m/datacomp_0000_6s2NkdV2Ijp0JTTNG5EUpt6AJ95uV9.tar'
ds = wds.WebDataset(ds).decode().map(decode_image)
cnt = 0
for d in ds:
rgb = d.unsqueeze(0)
print (rgb.shape)
with torch.no_grad():
recon, _ = vq_model(rgb)
# recon2, _ = vq_model2(rgb)
recon = convert_decode_to_pil(recon)[0]
# recon2 = convert_decode_to_pil(recon2)[0]
rgb = transforms.ToPILImage()(rgb[0])
        fig, ax = plt.subplots(1, 2)
        ax[0].imshow(np.array(rgb))
        ax[0].set_title('rgb')
        ax[1].imshow(np.array(recon))
        ax[1].set_title('recon')
        # ax[2].imshow(np.array(recon2))
        # ax[2].set_title('rgb_2')
        fig.savefig('/mnt/bn/robotics-data-hl/jirong/git/DeLVM/data_generation/vqgan/recon/{}.jpg'.format(str(cnt).zfill(4)))
        plt.close(fig)  # avoid accumulating open figures across iterations
        cnt += 1
        print(cnt)
        if cnt > 100:
            break