File size: 4,499 Bytes
2875fe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from torch import nn


def get_quantile(samples, q, dim=1):
    """Return the ``q``-quantile of ``samples`` along ``dim`` as a NumPy array.

    Args:
        samples: Tensor to reduce; passed straight to ``torch.quantile``.
        q: Quantile (or quantiles) to compute, as accepted by
            ``torch.quantile``.
        dim: Dimension that is reduced. Defaults to 1.

    Returns:
        ``numpy.ndarray`` holding the quantile values, with ``dim`` removed.
    """
    quantile = torch.quantile(samples, q, dim=dim)
    # detach() lets this work for tensors that are part of an autograd graph;
    # Tensor.numpy() refuses tensors that require grad.
    return quantile.detach().cpu().numpy()


def plot_sample(ori_data, gen_data, masks, sample_idx=0):
    """Plot every feature of one sample against the generated quantile band.

    One subplot is drawn per feature on a 4-column grid. Each subplot shows:

    * the 0.5 quantile of ``gen_data`` as a solid green line,
    * the 0.05-0.95 quantile range as a shaded green band,
    * original values where ``masks`` is zero as blue circles,
    * original values where ``masks`` is non-zero as red crosses.

    Wherever ``masks`` is 1 the quantile curves are replaced by the original
    value (``quantile * (1 - masks) + ori_data * masks``).

    Args:
        ori_data: Array whose shape unpacks as
            ``(sample_num, seq_len, feat_dim)``.
        gen_data: NumPy array of generated values; quantiles are taken over
            axis 0. NOTE(review): presumably shaped
            ``(n_draws, sample_num, seq_len, feat_dim)`` so that the reduced
            result lines up with ``masks`` -- confirm against the caller.
        masks: 0/1 array, assumed to have the same shape as ``ori_data``.
        sample_idx: Index of the sample (first axis) to plot.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    plt.rcParams["font.size"] = 12
    _, seq_len, feat_dim = ori_data.shape

    n_cols = 4
    # Keep the historical 7x4 grid as the minimum and add rows only when more
    # than 28 features would otherwise index past the grid.
    # -(-a // b) is integer ceil division.
    n_rows = max(7, -(-feat_dim // n_cols))
    fig, axes = plt.subplots(
        nrows=n_rows, ncols=n_cols, figsize=(12, 15 * n_rows / 7)
    )

    observed = ori_data * masks
    unobserved = 1 - masks
    gen_tensor = torch.from_numpy(gen_data)  # converted once, reused below

    # Median plus the 5% / 95% bounds; masked-in positions are pinned to the
    # original values.
    median, lower, upper = (
        get_quantile(gen_tensor, q, dim=0) * unobserved + observed
        for q in (0.5, 0.05, 0.95)
    )

    time_axis = np.arange(0, seq_len)

    for feat_idx in range(feat_dim):
        row, col = divmod(feat_idx, n_cols)
        ax = axes[row][col]

        # Points where the mask is set.
        df_x = pd.DataFrame(
            {
                "x": time_axis,
                "val": ori_data[sample_idx, :, feat_idx],
                "y": masks[sample_idx, :, feat_idx],
            }
        )
        df_x = df_x[df_x.y != 0]

        # Points where the mask is not set.
        df_o = pd.DataFrame(
            {
                "x": time_axis,
                "val": ori_data[sample_idx, :, feat_idx],
                "y": unobserved[sample_idx, :, feat_idx],
            }
        )
        df_o = df_o[df_o.y != 0]

        ax.plot(
            time_axis,
            median[sample_idx, :, feat_idx],
            color="g",
            linestyle="solid",
            label="Diffusion-TS",
        )
        ax.fill_between(
            time_axis,
            lower[sample_idx, :, feat_idx],
            upper[sample_idx, :, feat_idx],
            color="g",
            alpha=0.3,
        )

        ax.plot(df_o.x, df_o.val, color="b", marker="o", linestyle="None")
        ax.plot(df_x.x, df_x.val, color="r", marker="x", linestyle="None")

        if col == 0:
            ax.set_ylabel("value")

    # The previous in-loop check `row == -1` could never be true, so the
    # x-label was never drawn; label the bottom row of the grid explicitly.
    for bottom_ax in axes[-1]:
        bottom_ax.set_xlabel("time")

    plt.tight_layout()
    plt.show()


class MaskedLoss(nn.Module):
    """Masked MSE or L1 loss: only positions selected by a mask contribute."""

    def __init__(self, reduction: str = "mean", mode="mse"):
        """Build the underlying loss module.

        Args:
            reduction: Reduction forwarded to the wrapped torch loss
                (e.g. ``"mean"`` or ``"none"``).
            mode: ``"mse"`` selects ``nn.MSELoss``; any other value selects
                ``nn.L1Loss``.
        """
        super().__init__()

        self.reduction = reduction
        if mode == "mse":
            self.loss = nn.MSELoss(reduction=self.reduction)
        else:
            self.loss = nn.L1Loss(reduction=self.reduction)

    def forward(
        self, y_pred: torch.Tensor, y_true: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Compute the loss between a prediction and a target on masked-in entries.

        Args:
            y_pred: Estimated values.
            y_true: Target values.
            mask: Tensor with 0/False where values should be ignored and
                1/True where they should be considered. Non-boolean masks
                (such as the float masks returned by ``random_mask``) are
                converted to boolean, non-zero meaning "considered".

        Returns:
            If ``reduction == 'none'``: a ``(num_active,)`` tensor with the
            loss of each selected element, with gradient attached.
            If ``reduction == 'mean'``: a scalar mean loss over the selected
            elements, with gradient attached.
        """
        # torch.masked_select only accepts boolean masks; casting here keeps
        # 0/1 float masks usable. For a tensor that is already bool this is a
        # no-op.
        bool_mask = mask.to(torch.bool)

        # For this particular loss one could instead multiply y_pred and
        # y_true element-wise by the mask; selecting keeps the ignored
        # positions out of the reduction entirely.
        masked_pred = torch.masked_select(y_pred, bool_mask)
        masked_true = torch.masked_select(y_true, bool_mask)

        return self.loss(masked_pred, masked_true)


def random_mask(observed_values, missing_ratio=0.1, seed=1984):
    """Hide a reproducible random subset of the observed (non-NaN) entries.

    Args:
        observed_values: Numeric array in which NaN marks a missing value.
        missing_ratio: Fraction of the observed entries to hide; the number
            hidden is ``int(num_observed * missing_ratio)``.
        seed: Seed for the random selection, so the same entries are hidden
            on every call with the same input.

    Returns:
        Tuple of three float tensors shaped like ``observed_values``:

        * the values with NaN replaced via ``np.nan_to_num``,
        * the observed mask (1 where the input was not NaN, else 0),
        * the ground-truth mask: the observed mask with the randomly chosen
          entries additionally set to 0.

    Raises:
        ValueError: If more entries are requested than are observed
            (``missing_ratio`` > 1).
    """
    observed_masks = ~np.isnan(observed_values)

    # Work on a flat copy so the observed mask itself stays untouched.
    flat_masks = observed_masks.reshape(-1).copy()
    obs_indices = np.flatnonzero(flat_masks)
    num_to_hide = int(len(obs_indices) * missing_ratio)

    if num_to_hide > 0:
        # A private legacy RandomState seeded with `seed` produces the same
        # draws as seeding the global NumPy RNG, but never touches global
        # state -- so nothing needs restoring, even if `choice` raises.
        rng = np.random.RandomState(seed)
        miss_indices = rng.choice(obs_indices, num_to_hide, replace=False)
        flat_masks[miss_indices] = False

    gt_masks = flat_masks.reshape(observed_masks.shape)

    observed_values = np.nan_to_num(observed_values)
    return (
        torch.from_numpy(observed_values).float(),
        torch.from_numpy(observed_masks).float(),
        torch.from_numpy(gt_masks).float(),
    )