Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| from torch import nn | |
def get_quantile(samples, q, dim=1):
    """Return the ``q``-th quantile of ``samples`` along ``dim`` as a NumPy array.

    Args:
        samples: tensor to reduce; ``torch.quantile`` needs a floating dtype.
        q: quantile level(s) in [0, 1], forwarded unchanged to ``torch.quantile``.
        dim: dimension that is reduced away (default 1).

    Returns:
        numpy.ndarray holding the quantile values, copied to host memory.
    """
    quantile_tensor = torch.quantile(samples, q, dim=dim)
    host_tensor = quantile_tensor.cpu()
    return host_tensor.numpy()
def plot_sample(ori_data, gen_data, masks, sample_idx=0, ncols=4):
    """Plot generated quantiles against ground truth for one sample, one axis per feature.

    For every feature of ``ori_data[sample_idx]`` this draws a green median
    line ("Diffusion-TS") with a shaded 5%-95% band. All three curves are
    computed over axis 0 of ``gen_data``. Where ``masks`` is nonzero, the curves
    are replaced by the observed value ``ori_data * masks``. Ground-truth points
    are overlaid: red ``x`` where ``masks`` is nonzero and blue ``o`` where
    ``1 - masks`` is nonzero.

    Side effects: sets ``plt.rcParams["font.size"] = 12`` globally and calls
    ``plt.show()``.

    Args:
        ori_data: 3-D array whose shape unpacks as ``(samples, seq_len, feat_dim)``.
        gen_data: NumPy array of generated values (converted with
            ``torch.from_numpy``). Quantiles are taken over axis 0, so presumably
            it is ``(num_draws, samples, seq_len, feat_dim)``; confirm with callers.
        masks: array combinable with ``ori_data``; nonzero entries are treated
            as observed.
        sample_idx: index of the sample to plot.
        ncols: number of subplot columns. Rows are added as needed so every
            feature gets its own axis.
    """
    plt.rcParams["font.size"] = 12
    _, seq_len, feat_dim = ori_data.shape

    # Use ceil division so every feature gets an axis; the former fixed 7x4 grid
    # raised IndexError past 28 features. For 7 rows x 4 cols this reproduces
    # the original 12x15 inch figure.
    nrows = max(1, -(-feat_dim // ncols))
    _, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(3 * ncols, 15 / 7 * nrows),
        squeeze=False,  # keep `axes` 2-D even for a single row
    )

    # Outside the mask show the generated quantile; inside it, the observed value.
    gen_tensor = torch.from_numpy(gen_data)
    observed = ori_data * masks
    median, lower, upper = (
        get_quantile(gen_tensor, q, dim=0) * (1 - masks) + observed
        for q in (0.5, 0.05, 0.95)
    )

    time_axis = np.arange(seq_len)
    for feat_idx in range(feat_dim):
        row, col = divmod(feat_idx, ncols)
        ax = axes[row, col]

        truth = np.asarray(ori_data[sample_idx, :, feat_idx])
        feat_mask = np.asarray(masks[sample_idx, :, feat_idx])
        is_observed = feat_mask != 0
        is_target = (1 - feat_mask) != 0

        ax.plot(
            time_axis,
            median[sample_idx, :, feat_idx],
            color="g",
            linestyle="solid",
            label="Diffusion-TS",
        )
        ax.fill_between(
            time_axis,
            lower[sample_idx, :, feat_idx],
            upper[sample_idx, :, feat_idx],
            color="g",
            alpha=0.3,
        )
        ax.plot(time_axis[is_target], truth[is_target], color="b", marker="o", linestyle="None")
        ax.plot(time_axis[is_observed], truth[is_observed], color="r", marker="x", linestyle="None")

        if col == 0:
            ax.set_ylabel("value")
        # The previous `row == -1` check could never be true, so no x-label was
        # ever drawn. Label the lowest populated axis in each column instead.
        if feat_idx + ncols >= feat_dim:
            ax.set_xlabel("time")

    # Hide grid cells left empty when feat_dim is not a multiple of ncols.
    for ax in axes.flat[feat_dim:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()
class MaskedLoss(nn.Module):
    """Element-wise regression loss (MSE or L1) restricted to masked-in positions.

    Wraps ``nn.MSELoss`` when ``mode == "mse"`` and ``nn.L1Loss`` for any other
    ``mode``. The loss is computed only over the elements where the mask is true.
    """

    def __init__(self, reduction: str = "mean", mode="mse"):
        """
        Args:
            reduction: forwarded to the wrapped torch loss ("mean", "sum" or "none").
            mode: "mse" selects squared error; every other value selects
                absolute (L1) error.
        """
        super().__init__()
        self.reduction = reduction
        if mode == "mse":
            self.loss = nn.MSELoss(reduction=self.reduction)
        else:
            self.loss = nn.L1Loss(reduction=self.reduction)

    def forward(
        self, y_pred: torch.Tensor, y_true: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Compute the loss between a prediction and a target at masked-in positions.

        Args:
            y_pred: estimated values.
            y_true: target values.
            mask: tensor broadcastable to ``y_pred``/``y_true``. Zero/False marks
                positions to ignore; nonzero/True marks positions to include.
                Bool masks are used as-is, and other dtypes (e.g. float 0/1
                masks) are converted with ``mask != 0``.

        Returns:
            If ``reduction == "none"``: a 1-D tensor with one loss value per
            selected element, in row-major order.
            If ``reduction == "mean"``: the scalar mean over the selected
            elements. This is NaN when the mask selects nothing.
            If ``reduction == "sum"``: the scalar sum over the selected elements.
            The result keeps the autograd graph of ``y_pred``.
        """
        # torch.masked_select only accepts a bool mask, so coerce numeric masks
        # instead of failing on them.
        if mask.dtype != torch.bool:
            mask = mask != 0
        # For this particular loss one could also multiply y_pred and y_true by the
        # mask element-wise; masked_select keeps the mean over selected items only.
        masked_pred = torch.masked_select(y_pred, mask)
        masked_true = torch.masked_select(y_true, mask)
        return self.loss(masked_pred, masked_true)
def random_mask(observed_values, missing_ratio=0.1, seed=1984):
    """Hide a reproducible random fraction of the observed (non-NaN) entries.

    Args:
        observed_values: NumPy array; NaN entries are treated as missing.
        missing_ratio: fraction in [0, 1] of the observed entries to hide. Exactly
            ``int(n_observed * missing_ratio)`` entries are hidden.
        seed: seed for the draw. It uses a private legacy ``RandomState``, so it
            selects the same entries as seeding the global NumPy RNG would,
            without touching global RNG state.

    Returns:
        Tuple of float32 tensors ``(values, observed_masks, gt_masks)``, each with
        the input's shape:
          * values: input passed through ``np.nan_to_num`` (NaN becomes 0).
          * observed_masks: 1 where the input is not NaN, else 0.
          * gt_masks: ``observed_masks`` with the randomly hidden entries set to 0.

    Raises:
        ValueError: if ``missing_ratio`` is not within [0, 1].
    """
    if not 0 <= missing_ratio <= 1:
        raise ValueError(f"missing_ratio must be in [0, 1], got {missing_ratio!r}")

    observed_masks = ~np.isnan(observed_values)
    masks = observed_masks.reshape(-1).copy()
    # flatnonzero keeps an integer dtype even when nothing is observed. The old
    # `.tolist()` produced `[]`, which choice() turned into a float array that
    # could not be used as an index.
    obs_indices = np.flatnonzero(masks)

    # A local RandomState seeded with `seed` yields the same draws as
    # np.random.seed(seed) followed by np.random.choice. Unlike the old
    # save/seed/restore dance, it cannot leave the global RNG reseeded if choice()
    # raises.
    rng = np.random.RandomState(seed)
    num_to_hide = int(len(obs_indices) * missing_ratio)
    miss_indices = rng.choice(obs_indices, num_to_hide, replace=False)
    masks[miss_indices] = False

    gt_masks = masks.reshape(observed_masks.shape)
    observed_values = np.nan_to_num(observed_values)
    return (
        torch.from_numpy(observed_values).float(),
        torch.from_numpy(observed_masks).float(),
        torch.from_numpy(gt_masks).float(),
    )