import math
import os
import sys
import time

import numpy as np
import torch
from ema_pytorch import EMA
from pathlib import Path
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam
from tqdm.auto import tqdm

# Make the project root importable before pulling in project-local modules.
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))

from utils.io_utils import instantiate_from_config, get_model_parameters_info


def cycle(dl):
    """Yield batches from a dataloader indefinitely (for step-based training)."""
    while True:
        for data in dl:
            yield data


class Trainer(object):
    def __init__(self, config, args, model, dataloader, logger=None):
        super().__init__()
        if os.getenv("WANDB_ENABLED") == "true":
            import wandb

            self.run = wandb.init(project="tiffusion-revenue", config=config)
        else:
            self.run = None
        self.model = model
        self.device = self.model.betas.device
        # Despite the key name, "max_epochs" is used as the optimizer step budget.
        self.train_num_steps = config["solver"]["max_epochs"]
        self.gradient_accumulate_every = config["solver"]["gradient_accumulate_every"]
        self.save_cycle = config["solver"]["save_cycle"]
        self.dl = cycle(dataloader["dataloader"])
        self.step = 0
        self.milestone = 0
        self.args = args
        self.logger = logger

        self.results_folder = Path(
            config["solver"]["results_folder"] + f"_{model.seq_length}"
        )
        os.makedirs(self.results_folder, exist_ok=True)

        start_lr = config["solver"].get("base_lr", 1.0e-4)
        ema_decay = config["solver"]["ema"]["decay"]
        ema_update_every = config["solver"]["ema"]["update_interval"]

        self.opt = Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=start_lr,
            betas=(0.9, 0.96),
        )
        self.ema = EMA(self.model, beta=ema_decay, update_every=ema_update_every).to(
            self.device
        )

        # Inject the optimizer into the scheduler config before instantiation.
        sc_cfg = config["solver"]["scheduler"]
        sc_cfg["params"]["optimizer"] = self.opt
        self.sch = instantiate_from_config(sc_cfg)

        if self.logger is not None:
            self.logger.log_info(str(get_model_parameters_info(self.model)))
        self.log_frequency = 100
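
    # A minimal sketch of the "solver" config this constructor reads. The key
    # names mirror the lookups above; the values and the scheduler target are
    # illustrative assumptions, not the project's defaults:
    #
    #   config = {
    #       "solver": {
    #           "max_epochs": 10000,
    #           "gradient_accumulate_every": 2,
    #           "save_cycle": 1000,
    #           "base_lr": 1.0e-4,
    #           "results_folder": "./check_points/run",
    #           "ema": {"decay": 0.995, "update_interval": 10},
    #           "scheduler": {"target": "<scheduler class path>", "params": {...}},
    #       },
    #   }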

    def save(self, milestone, verbose=False):
        """Checkpoint model, EMA, and optimizer state for the given milestone."""
        if self.logger is not None and verbose:
            self.logger.log_info(
                "Save current model to {}".format(
                    str(self.results_folder / f"checkpoint-{milestone}.pt")
                )
            )
        data = {
            "step": self.step,
            "model": self.model.state_dict(),
            "ema": self.ema.state_dict(),
            "opt": self.opt.state_dict(),
        }
        torch.save(data, str(self.results_folder / f"checkpoint-{milestone}.pt"))

    def load(self, milestone, verbose=False, from_folder=None):
        """Restore training state from a milestone checkpoint."""
        if self.logger is not None and verbose:
            self.logger.log_info(
                "Resume from {}".format(
                    os.path.join(from_folder, f"checkpoint-{milestone}.pt")
                )
            )
        device = self.device
        ckpt_path = (
            os.path.join(from_folder, f"checkpoint-{milestone}.pt")
            if from_folder
            else str(self.results_folder / f"checkpoint-{milestone}.pt")
        )
        data = torch.load(ckpt_path, map_location=device, weights_only=True)
        self.model.load_state_dict(data["model"])
        self.step = data["step"]
        self.opt.load_state_dict(data["opt"])
        self.ema.load_state_dict(data["ema"])
        self.milestone = milestone
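
    # Hedged round-trip sketch (assumes a constructed `trainer` and that
    # milestone 3 was previously written by `save`; the folder is illustrative):
    #
    #   trainer.save(3, verbose=True)
    #   trainer.load(3, verbose=True)                  # from self.results_folder
    #   trainer.load(3, from_folder="./another_run")   # from a different run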

    def train(self):
        device = self.device
        step = 0
        if self.logger is not None:
            tic = time.time()
            self.logger.log_info(
                "{}: start training...".format(self.args.name), check_primary=False
            )

        with tqdm(initial=step, total=self.train_num_steps) as pbar:
            while step < self.train_num_steps:
                total_loss = 0.0
                # Accumulate gradients over several micro-batches before stepping.
                for _ in range(self.gradient_accumulate_every):
                    data = next(self.dl).to(device)
                    loss = self.model(data, target=data)
                    loss = loss / self.gradient_accumulate_every
                    loss.backward()
                    total_loss += loss.item()

                pbar.set_description(
                    f'loss: {total_loss:.6f} lr: {self.opt.param_groups[0]["lr"]:.6f}'
                )
                if self.run is not None:
                    # `wandb` is only imported when the run was created, so log
                    # through the run handle rather than the module name.
                    self.run.log(
                        {
                            "step": step,
                            "loss": total_loss,
                            "lr": self.opt.param_groups[0]["lr"],
                        },
                        step=self.step,
                    )

                clip_grad_norm_(self.model.parameters(), 1.0)
                self.opt.step()
                self.sch.step(total_loss)
                self.opt.zero_grad()
                self.step += 1
                step += 1
                self.ema.update()

                with torch.no_grad():
                    if self.step != 0 and self.step % self.save_cycle == 0:
                        self.milestone += 1
                        self.save(self.milestone)
                    if self.logger is not None and self.step % self.log_frequency == 0:
                        self.logger.add_scalar(
                            tag="train/loss",
                            scalar_value=total_loss,
                            global_step=self.step,
                        )
                pbar.update(1)

        print("training complete")
        if self.logger is not None:
            self.logger.log_info(
                "Training done, time: {:.2f}".format(time.time() - tic)
            )

    def sample(self, num, size_every, shape=None):
        """Draw at least `num` unconditional samples in batches of `size_every`."""
        if self.logger is not None:
            tic = time.time()
            self.logger.log_info("Begin to sample...")
        samples = np.empty([0, shape[0], shape[1]])
        # Round up so at least `num` samples are drawn (as in control_sample).
        num_cycle = math.ceil(num / size_every)

        for _ in range(num_cycle):
            sample = self.ema.ema_model.generate_mts(batch_size=size_every)
            # np.vstack replaces the deprecated np.row_stack alias.
            samples = np.vstack([samples, sample.detach().cpu().numpy()])
            torch.cuda.empty_cache()

        if self.logger is not None:
            self.logger.log_info(
                "Sampling done, time: {:.2f}".format(time.time() - tic)
            )
        return samples
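
    # Hedged usage sketch (the shape is an assumption about this model's
    # (seq_length, feature_dim)):
    #
    #   series = trainer.sample(num=1000, size_every=100, shape=[24, 1])
    #   # series.shape -> (1000, 24, 1)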

    def control_sample(
        self,
        num,
        size_every,
        shape=None,
        model_kwargs=None,
        target=None,
        partial_mask=None,
    ):
        """Draw conditional samples, infilling `target` where `partial_mask` is set."""
        model_kwargs = model_kwargs or {}  # avoid a shared mutable default
        samples = np.empty([0, shape[0], shape[1]])
        num_cycle = math.ceil(num / size_every)
        assert not (
            (target is None) ^ (partial_mask is None)
        ), "target and partial_mask must be provided together"
        if self.logger is not None:
            tic = time.time()
            self.logger.log_info("Begin to infill sample...")

        target = (
            torch.tensor(target).to(self.device)
            if target is not None
            else torch.zeros(shape).to(self.device)
        )
        target = target.repeat(size_every, 1, 1) if len(target.shape) == 2 else target
        partial_mask = (
            torch.tensor(partial_mask).to(self.device)
            if partial_mask is not None
            else torch.zeros(shape).to(self.device)
        )
        partial_mask = (
            partial_mask.repeat(size_every, 1, 1)
            if len(partial_mask.shape) == 2
            else partial_mask
        )

        for _ in range(num_cycle):
            sample = self.ema.ema_model.generate_mts_infill(
                target, partial_mask, model_kwargs=model_kwargs
            )
            samples = np.vstack([samples, sample.detach().cpu().numpy()])
            torch.cuda.empty_cache()

        if self.logger is not None:
            self.logger.log_info(
                "Sampling done, time: {:.2f}".format(time.time() - tic)
            )
        return samples

    def predict(
        self,
        observed_points: torch.Tensor,
        coef=1e-1,
        stepsize=1e-1,
        sampling_steps=50,
        **kwargs,
    ):
        """Impute a single (seq_length, feature_dim) series from its nonzero points."""
        model_kwargs = {"coef": coef, "learning_rate": stepsize, **kwargs}
        assert (
            len(observed_points.shape) == 2
        ), "observed_points should be 2D; batch size is fixed to 1"
        x = observed_points.unsqueeze(0)
        t_m = x != 0  # observed mask: True where a point was observed
        x = x * 2 - 1  # rescale [0, 1] -> [-1, 1]
        x, t_m = x.to(self.device), t_m.to(self.device)

        if sampling_steps == self.model.num_timesteps:
            print("normal sampling")
            sample = self.ema.ema_model.sample_infill(
                shape=x.shape,
                target=x * t_m,  # keep observed values, zero out the rest
                partial_mask=t_m,
                model_kwargs=model_kwargs,
            )
            # sample: (batch_size, seq_length, feature_dim)
        else:
            print("fast sampling")
            sample = self.ema.ema_model.fast_sample_infill(
                shape=x.shape,
                target=x * t_m,
                partial_mask=t_m,
                model_kwargs=model_kwargs,
                sampling_timesteps=sampling_steps,
            )

        sample = (sample + 1) / 2  # rescale back to [0, 1]
        return sample.squeeze(0).detach().cpu().numpy()
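
    # Hedged usage sketch: inputs are expected in [0, 1] with zeros marking
    # missing points (the length and probe values below are assumptions):
    #
    #   points = torch.zeros(24, 1)
    #   points[0, 0], points[10, 0] = 0.3, 0.7
    #   imputed = trainer.predict(points, sampling_steps=50)
    #   # imputed: np.ndarray of shape (24, 1)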

    def predict_weighted_points(
        self,
        observed_points: torch.Tensor,
        observed_mask: torch.Tensor,
        coef=1e-1,
        stepsize=1e-1,
        sampling_steps=50,
        **kwargs,
    ):
        """Impute a single series using per-point weights from `observed_mask`."""
        model_kwargs = {"coef": coef, "learning_rate": stepsize, **kwargs}
        assert (
            len(observed_points.shape) == 2
        ), "observed_points should be 2D; batch size is fixed to 1"
        x = observed_points.unsqueeze(0)
        # float_mask carries per-point weights; nonzero means observed.
        float_mask = observed_mask.unsqueeze(0)
        binary_mask = float_mask.clone()
        binary_mask[binary_mask > 0] = 1
        x = x * 2 - 1  # rescale [0, 1] -> [-1, 1]
        x, float_mask, binary_mask = (
            x.to(self.device),
            float_mask.to(self.device),
            binary_mask.to(self.device),
        )

        if sampling_steps == self.model.num_timesteps:
            print("normal sampling")
            # Normal sampling with a float mask is not implemented yet; the
            # call below is therefore unreachable.
            raise NotImplementedError
            sample = self.ema.ema_model.sample_infill_float_mask(
                shape=x.shape,
                target=x * binary_mask,  # 1 for observed, 0 for missing
                partial_mask=float_mask,
                model_kwargs=model_kwargs,
            )
        else:
            print("fast sampling")
            sample = self.ema.ema_model.fast_sample_infill_float_mask(
                shape=x.shape,
                target=x * binary_mask,  # 1 for observed, 0 for missing
                partial_mask=float_mask,
                model_kwargs=model_kwargs,
                sampling_timesteps=sampling_steps,
            )

        sample = (sample + 1) / 2  # rescale back to [0, 1]
        return sample.squeeze(0).detach().cpu().numpy()

    def restore(
        self,
        raw_dataloader,
        shape=None,
        coef=1e-1,
        stepsize=1e-1,
        sampling_steps=50,
        **kwargs,
    ):
        """Impute batches from `raw_dataloader`; returns (samples, reals, masks)."""
        if self.logger is not None:
            tic = time.time()
            self.logger.log_info("Begin to restore...")
        model_kwargs = {"coef": coef, "learning_rate": stepsize, **kwargs}
        test = kwargs.get("test", False)
        samples = np.empty([0, shape[0], shape[1]])  # (0, seq_length, feature_dim)
        reals = np.empty([0, shape[0], shape[1]])
        masks = np.empty([0, shape[0], shape[1]])

        for idx, (x, t_m) in enumerate(raw_dataloader):
            # Recompute the observed mask from the data, matching the
            # convention in `predict`: True (1) for observed, False (0) for missing.
            t_m = x != 0
            if test:
                t_m = t_m.type_as(x)
                binary_mask = t_m.clone()
                binary_mask[binary_mask > 0] = 1
            else:
                binary_mask = t_m
            # x = x * 2 - 1
            x, t_m = x.to(self.device), t_m.to(self.device)
            binary_mask = binary_mask.to(self.device)

            if sampling_steps == self.model.num_timesteps:
                print("normal sampling")
                sample = self.ema.ema_model.sample_infill(
                    shape=x.shape,
                    target=x * t_m,
                    partial_mask=t_m,
                    model_kwargs=model_kwargs,
                )
                # sample: (batch_size, seq_length, feature_dim)
            else:
                print("fast sampling")
                if test:
                    sample = self.ema.ema_model.fast_sample_infill_float_mask(
                        shape=x.shape,
                        target=x * binary_mask,  # 1 for observed, 0 for missing
                        partial_mask=t_m,
                        model_kwargs=model_kwargs,
                        sampling_timesteps=sampling_steps,
                    )
                else:
                    sample = self.ema.ema_model.fast_sample_infill(
                        shape=x.shape,
                        target=x * t_m,
                        partial_mask=t_m,
                        model_kwargs=model_kwargs,
                        sampling_timesteps=sampling_steps,
                    )

            samples = np.vstack([samples, sample.detach().cpu().numpy()])
            reals = np.vstack([reals, x.detach().cpu().numpy()])
            masks = np.vstack([masks, t_m.detach().cpu().numpy()])
            break  # NOTE: only the first batch is restored

        if self.logger is not None:
            self.logger.log_info(
                "Imputation done, time: {:.2f}".format(time.time() - tic)
            )
        return samples, reals, masks
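

# A hedged end-to-end sketch. It assumes a diffusion `model` exposing `betas`,
# `seq_length`, and `num_timesteps`, plus `config`, `args`, `train_dl`, and
# `logger` built elsewhere in the project (all names here are illustrative):
#
#   trainer = Trainer(config, args, model, {"dataloader": train_dl}, logger=logger)
#   trainer.train()
#   series = trainer.sample(num=256, size_every=64, shape=[model.seq_length, 1])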