| | import sys, os |
| | root = os.sep + os.sep.join(__file__.split(os.sep)[1:__file__.split(os.sep).index("Recurrent-Parameter-Generation")+1]) |
| | sys.path.append(root) |
| | os.chdir(root) |
| |
|
| | |
| | import time |
| | import types |
| | import torch |
| | import copy |
| | from torch import nn |
| | from model.diffusion import DDIMSampler, DDPMSampler |
| | |
| | import importlib |
| | item = importlib.import_module(f"{sys.argv[1]}") |
| | Dataset = item.Dataset |
| | train_set = item.train_set |
| | config = item.config |
| | model = item.model |
| | assert config.get("tag") is not None, "Remember to set a tag." |
| |
|
| |
|
| |
|
| |
|
| | generate_config = { |
| | "device": "cuda", |
| | "num_generated": 1, |
| | "checkpoint": f"./checkpoint/{config['tag']}.pth", |
| | "generated_path": os.path.join(Dataset.generated_path.rsplit("/", 1)[0], "generated_{}_{}.pth"), |
| | "test_command": os.path.join(Dataset.test_command.rsplit("/", 1)[0], "generated_{}_{}.pth"), |
| | "need_test": True, |
| | |
| | "sampler": DDIMSampler, |
| | "steps": 60, |
| | } |
| | config.update(generate_config) |
| |
|
| |
|
| |
|
| |
|
| | |
| | print('==> Building model..') |
| | diction = torch.load(config["checkpoint"]) |
| | permutation_shape = diction["to_permutation_state.weight"].shape |
| | model.to_permutation_state = nn.Embedding(*permutation_shape) |
| | model.load_state_dict(diction) |
| | model.criteria.diffusion_sampler = config["sampler"]( |
| | model=model.criteria.diffusion_sampler.model, |
| | beta=config["model_config"]["beta"], |
| | T=config["model_config"]["T"], |
| | ) |
| | model.condi_embedder = copy.deepcopy(model.criteria.diffusion_sampler.model.condi_embedder) |
| | @torch.no_grad() |
| | def new_sample(self, x=None, condition=None): |
| | z = self.model([1, self.sequence_length, self.config["d_model"]], condition) |
| | z = self.condi_embedder(z) |
| | if x is None: |
| | x = torch.randn((1, self.sequence_length, self.config["model_dim"]), device=z.device) |
| | x = self.criteria.sample(x, z, steps=config["steps"]) |
| | return x |
| | model.sample = types.MethodType(new_sample, model) |
| | model.criteria.diffusion_sampler.model.condi_embedder = nn.Identity() |
| | model = model.to(config["device"]) |
| |
|
| |
|
| | |
| | print('==> Defining generate..') |
| | def generate(save_path=config["generated_path"], test_command=config["test_command"], need_test=True): |
| | print("\n==> Generating..") |
| | model.eval() |
| | with torch.cuda.amp.autocast(True, torch.bfloat16): |
| | with torch.no_grad(): |
| | start_time = time.time() |
| | prediction = model(sample=True) |
| | end_time = time.time() |
| | generated_norm = torch.nanmean(prediction.abs()) |
| | print("used time (seconds):", end_time - start_time) |
| | print("memory usage (GB):", torch.cuda.max_memory_allocated() / (1024 ** 3)) |
| | |
| | train_set.save_params(prediction, save_path=save_path) |
| | if need_test: |
| | os.system(test_command) |
| | print("\n") |
| |
|
| |
|
| |
|
| |
|
| | if __name__ == "__main__": |
| | for i in range(config["num_generated"]): |
| | index = str(i+1).zfill(3) |
| | print("Save to", config["generated_path"].format(config["tag"], index)) |
| | generate( |
| | save_path=config["generated_path"].format(config["tag"], index), |
| | test_command=config["test_command"].format(config["tag"], index), |
| | need_test=config["need_test"], |
| | ) |