Spaces:
Runtime error
Runtime error
| import json | |
| import os | |
| import torch | |
| from diffusers import UNet1DModel | |
# Create the output layout for the converted hopper-medium-v2 checkpoints:
# one UNet folder per planning horizon plus one folder for the value function.
# exist_ok=True makes re-running the script harmless.
for _out_dir in (
    "hub/hopper-medium-v2/unet/hor32",
    "hub/hopper-medium-v2/unet/hor128",
    "hub/hopper-medium-v2/value_function",
):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir  # keep the module namespace identical to the unrolled version
def unet(hor, checkpoint_path=None, output_dir=None):
    """Convert a diffuser ``temporal_unet`` hopper-medium-v2 checkpoint to diffusers format.

    Loads the pickled original model, maps its parameters by position onto a
    freshly built ``UNet1DModel``, and writes ``diffusion_pytorch_model.bin``
    plus ``config.json`` into the output directory.

    Args:
        hor: Planning horizon of the checkpoint. Only 32 and 128 are supported;
            the horizon selects the block layout.
        checkpoint_path: Path of the original ``.torch`` checkpoint. Defaults to
            ``/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch``.
        output_dir: Directory that receives the converted files. Defaults to
            ``hub/hopper-medium-v2/unet/hor{hor}``. It is created if missing.

    Raises:
        ValueError: If ``hor`` is not 32 or 128, or if the original and the
            diffusers model do not expose the same number of state-dict entries.
    """
    if hor == 128:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D")
    elif hor == 32:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 64, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")
    else:
        # Without this branch the block-layout names above stay unbound and the
        # failure surfaces later as an unhelpful UnboundLocalError.
        raise ValueError(f"Unsupported horizon {hor!r}; expected 32 or 128.")

    if checkpoint_path is None:
        checkpoint_path = f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch"
    if output_dir is None:
        output_dir = f"hub/hopper-medium-v2/unet/hor{hor}"

    # The checkpoint holds a whole pickled model (``.state_dict()`` is called on
    # it), so loading it unpickles arbitrary objects: only use trusted files.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects full
    # pickled modules; pass weights_only=False there if this load fails.
    model = torch.load(checkpoint_path)
    state_dict = model.state_dict()

    config = dict(
        down_block_types=down_block_types,
        block_out_channels=block_out_channels,
        up_block_types=up_block_types,
        layers_per_block=1,
        use_timestep_embedding=True,
        out_block_type="OutConv1DBlock",
        norm_num_groups=8,
        downsample_each_block=False,
        in_channels=14,
        out_channels=14,
        extra_in_channels=0,
        time_embedding_type="positional",
        flip_sin_to_cos=False,
        freq_shift=1,
        sample_size=65536,
        mid_block_type="MidResTemporalBlock1D",
        act_fn="mish",
    )
    hf_unet = UNet1DModel(**config)
    hf_state_dict = hf_unet.state_dict()
    print(f"length of state dict: {len(state_dict)}")
    print(f"length of value function dict: {len(hf_state_dict)}")

    # Keys are matched purely by position. This assumes both models register
    # their tensors in the same order, so at minimum the counts must agree
    # (zip would otherwise silently truncate).
    if len(state_dict) != len(hf_state_dict):
        raise ValueError(
            f"State dict size mismatch: checkpoint has {len(state_dict)} entries, "
            f"UNet1DModel expects {len(hf_state_dict)}."
        )
    # Build a new dict rather than renaming in place. In-place renaming
    # overwrites a not-yet-renamed tensor whenever a target key equals a later
    # source key.
    renamed = {hf_key: state_dict[key] for key, hf_key in zip(state_dict, hf_state_dict)}
    hf_unet.load_state_dict(renamed)

    os.makedirs(output_dir, exist_ok=True)
    torch.save(hf_unet.state_dict(), os.path.join(output_dir, "diffusion_pytorch_model.bin"))
    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config, f)
def value_function(
    checkpoint_path="/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch",
    output_dir="hub/hopper-medium-v2/value_function",
):
    """Convert a diffuser hopper-medium-v2 value-function checkpoint to diffusers format.

    Loads the original checkpoint, maps its entries by position onto a
    ``UNet1DModel`` configured with a ``ValueFunction`` output block, and writes
    ``diffusion_pytorch_model.bin`` plus ``config.json`` into ``output_dir``.

    Args:
        checkpoint_path: Path of the original ``.torch`` checkpoint. It is used
            directly as a state dict, so it is expected to contain one rather
            than a pickled module.
        output_dir: Directory that receives the converted files. It is created
            if missing.

    Raises:
        ValueError: If the checkpoint and the diffusers model do not expose the
            same number of state-dict entries.
    """
    config = dict(
        in_channels=14,
        down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
        up_block_types=(),
        out_block_type="ValueFunction",
        mid_block_type="ValueFunctionMidBlock1D",
        block_out_channels=(32, 64, 128, 256),
        layers_per_block=1,
        downsample_each_block=True,
        sample_size=65536,
        out_channels=14,
        extra_in_channels=0,
        time_embedding_type="positional",
        use_timestep_embedding=True,
        flip_sin_to_cos=False,
        freq_shift=1,
        norm_num_groups=8,
        act_fn="mish",
    )

    # torch.load unpickles the file: only use trusted checkpoints.
    state_dict = torch.load(checkpoint_path)
    hf_value_function = UNet1DModel(**config)
    hf_state_dict = hf_value_function.state_dict()
    print(f"length of state dict: {len(state_dict)}")
    print(f"length of value function dict: {len(hf_state_dict)}")

    # Keys are matched purely by position. This assumes both sides list their
    # tensors in the same order, so the counts must at least agree (zip would
    # otherwise silently truncate).
    if len(state_dict) != len(hf_state_dict):
        raise ValueError(
            f"State dict size mismatch: checkpoint has {len(state_dict)} entries, "
            f"UNet1DModel expects {len(hf_state_dict)}."
        )
    # Build a new dict rather than renaming in place. In-place renaming
    # overwrites a not-yet-renamed tensor whenever a target key equals a later
    # source key.
    renamed = {hf_key: state_dict[key] for key, hf_key in zip(state_dict, hf_state_dict)}
    hf_value_function.load_state_dict(renamed)

    os.makedirs(output_dir, exist_ok=True)
    torch.save(hf_value_function.state_dict(), os.path.join(output_dir, "diffusion_pytorch_model.bin"))
    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config, f)
if __name__ == "__main__":
    # Convert the horizon-32 diffusion UNet, then the value function.
    # The horizon-128 UNet conversion is disabled; call unet(128) to run it.
    unet(hor=32)
    value_function()