File size: 6,799 Bytes
b731740
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.mlp import MultiLayerPerceptron
from model.TMRB import TMRB

class Basic_Model(nn.Module):
    """Spatio-temporal forecaster built from history, node, TCN and optional TMRB features.

    Pipeline of each branch:

    1. The history tensor (batch, in_steps, num_nodes, channels) is flattened into one
       channel vector per node.
    2. That vector is embedded with a 1x1 ``Conv2d``.
    3. A learned node embedding and a ``Conv1d`` feature computed along the node axis are
       concatenated to it.
    4. When ``is_TMRB`` is set, TMRB features derived from time-of-day / day-of-week
       embeddings are concatenated as well.
    5. The result passes through a stack of ``MultiLayerPerceptron`` blocks and a 1x1
       projection to ``output_len`` channels per node.

    ``args`` must expose ``dropout``, ``momentum``, ``is_TMRB``, ``is_update``,
    ``select_k`` and the dicts ``emb``, ``tcn`` and ``TMRB`` with the keys read in
    ``__init__``.
    """

    # Rows of the time-of-day embedding table; the time-of-day feature is multiplied by
    # this value to obtain a row index.
    STEPS_PER_DAY = 288
    # Rows of the day-of-week embedding table.
    DAYS_PER_WEEK = 7

    def __init__(self, args):
        super().__init__()
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # dropout / activation are stored but not used by the methods in this class.
        self.dropout = args.dropout
        self.activation = nn.GELU()
        self.num_feat = args.emb["num_feat"]
        self.args = args
        self.num_layer = args.emb["num_layer"]

        self.embed_dim = args.emb["adaptive_emb_dim"]
        self.node_dim = args.emb["D^N"]
        self.temp_dim_tid = args.emb["D^D"]
        self.temp_dim_diw = args.emb["D^W"]
        self.output_len = args.emb["output_len"]
        self.tcn_dim = args.tcn["out_channel"]
        self.is_TMRB = args.is_TMRB
        self.is_update = args.is_update
        self.select_k = args.select_k
        self.TMRB_dropout = args.TMRB["dropout"]

        # Construction order below is deliberate: it fixes the order in which random
        # initialisers consume the RNG, so seeded runs stay reproducible.
        self.node_embedding = nn.init.xavier_uniform_(nn.Parameter(torch.empty(1, self.node_dim)))
        self.T_i_D_emb = nn.init.xavier_uniform_(nn.Parameter(torch.empty(self.STEPS_PER_DAY, self.temp_dim_tid)))
        self.D_i_W_emb = nn.init.xavier_uniform_(nn.Parameter(torch.empty(self.DAYS_PER_WEEK, self.temp_dim_diw)))
        # 1x1 conv over the flattened (in_steps * channels) history of every node.
        self.emb_layer_history = nn.Conv2d(
            in_channels=args.emb["input_dim"] * args.emb["input_len"],
            out_channels=self.embed_dim,
            kernel_size=(1, 1),
            bias=True,
        )
        # Applied to the (batch, embed_dim, num_nodes) history embedding, so
        # tcn["in_channel"] is expected to equal adaptive_emb_dim. The padding keeps the
        # node axis length unchanged only when (kernel_size - 1) * dilation is even.
        self.tcn = nn.Conv1d(
            in_channels=args.tcn["in_channel"],
            out_channels=args.tcn["out_channel"],
            kernel_size=args.tcn["kernel_size"],
            dilation=args.tcn["dilation"],
            padding=int((args.tcn["kernel_size"] - 1) * args.tcn["dilation"] / 2),
        )
        # Channel width of the concatenated features; the TMRB part counts only when enabled.
        self.hidden_dim = self.embed_dim + self.node_dim + args.TMRB["out_channel"] * self.is_TMRB + self.tcn_dim
        self.encoder = nn.Sequential(
            *[MultiLayerPerceptron(self.hidden_dim, self.hidden_dim) for _ in range(self.num_layer)]
        )
        self.projection_head = nn.Conv2d(
            in_channels=self.hidden_dim, out_channels=self.output_len, kernel_size=(1, 1), bias=True
        )
        # NOTE(review): the online and target branches reference the *same* module
        # objects (no copy is made), so update_target_network leaves every weight
        # unchanged. If an EMA target is intended, the target modules must be separate
        # copies. Confirm before changing: it alters training dynamics and the set of
        # parameters returned by self.parameters().
        self.online_backbone = self.encoder
        self.online_projection = self.projection_head
        self.target_backbone = self.encoder
        self.target_projection = self.projection_head

        self.momentum = args.momentum
        self.TMRB = TMRB(
            input_dim=args.TMRB["in_channel"],
            out_dim=args.TMRB["out_channel"],
            top_k=args.TMRB["top_k"],
            TMRB_dropout=self.TMRB_dropout,
            is_update=self.is_update,
            select_k=self.select_k,
        ).to(self.device)
        # year -> mean of the latest TMRB hidden state over dims (0, 2); handed back to
        # TMRB on every later call.
        self.hidden_states_per_year = {}

    def prepare_inputs(self, history_data):
        """Split the history into value channels and look up node/time embeddings.

        Args:
            history_data: tensor of shape (batch, in_steps, num_nodes, channels).
                At the last step, channel ``num_feat`` is scaled by ``STEPS_PER_DAY`` to
                index the time-of-day table (presumably a day fraction in [0, 1)).
                Channel ``num_feat + 1`` indexes the day-of-week table.

        Returns:
            input_data: ``history_data[..., :num_feat]``.
            time_in_day_feat: (batch, num_nodes, D^D), on ``self.device``.
            day_in_week_feat: (batch, num_nodes, D^W), on ``self.device``.
            node_emb: (batch, 1, num_nodes, D^N), the node embedding broadcast over
                batch and nodes.
        """
        batch_size, _, num_nodes, _ = history_data.shape
        node_emb = self.node_embedding.expand(size=(num_nodes, *self.node_embedding.shape))
        node_emb = node_emb.expand(size=(batch_size, *node_emb.shape)).transpose(1, 2)

        # Wrap indices modulo the table size. For every index PyTorch accepted before
        # (including negative ones) this selects the same row. A time-of-day of exactly
        # 1.0 (index 288) or a day-of-week of 7 now wraps to row 0 instead of raising an
        # out-of-range error (a device-side assert on CUDA).
        tod_index = (history_data[:, -1, :, self.num_feat] * self.STEPS_PER_DAY).long() % self.STEPS_PER_DAY
        dow_index = history_data[:, -1, :, self.num_feat + 1].long() % self.DAYS_PER_WEEK
        time_in_day_feat = self.T_i_D_emb[tod_index].to(self.device)
        day_in_week_feat = self.D_i_W_emb[dow_index].to(self.device)

        input_data = history_data[:, :, :, :self.num_feat]
        return input_data, time_in_day_feat, day_in_week_feat, node_emb

    def _build_features(self, history_data, year):
        """Build the concatenated feature map consumed by the backbones.

        Shared by ``forward`` and ``target_branch``. When TMRB is enabled it also
        overwrites ``hidden_states_per_year[year]``.

        Returns:
            Tensor of shape (batch, hidden_dim, num_nodes, 1). This assumes the TCN
            preserves the node-axis length and TMRB returns
            (batch, TMRB out_channel, num_nodes).
        """
        batch_size, _, num_nodes, _ = history_data.shape
        _, time_in_day_feat, day_in_week_feat, node_emb = self.prepare_inputs(history_data)

        # (B, T, N, C) -> (B, T*C, N, 1): one flattened history vector per node.
        flat_history = (
            history_data.transpose(1, 2).contiguous().view(batch_size, num_nodes, -1).transpose(1, 2).unsqueeze(-1)
        )
        emb_history = self.emb_layer_history(flat_history)  # (B, embed_dim, N, 1)
        tcn_emb = self.tcn(emb_history.squeeze(-1)).unsqueeze(-1)  # (B, tcn_dim, N', 1)
        node_feat = node_emb.transpose(1, -1)  # (B, D^N, N, 1)

        combined_features = torch.cat([emb_history, node_feat, tcn_emb], dim=1)

        if self.is_TMRB:
            tem_emb = torch.cat([time_in_day_feat, day_in_week_feat], dim=-1)
            hidden_state = self.TMRB(tem_emb, year, self.hidden_states_per_year)
            # NOTE(review): stored without .detach(), so the summary keeps its autograd
            # history until overwritten. Confirm whether TMRB relies on gradients
            # flowing through it.
            self.hidden_states_per_year[year] = hidden_state.mean(dim=(0, 2))
            combined_features = torch.cat((combined_features, hidden_state.unsqueeze(-1)), dim=1)
        return combined_features

    def forward(self, data, year):
        """Run the online branch.

        Args:
            data: mapping whose ``'x'`` entry is the history tensor
                (batch, in_steps, num_nodes, channels). It is used on whatever device
                it is already on.
            year: key used for TMRB's per-year hidden-state summary.

        Returns:
            Projection of shape (batch, output_len, num_nodes, 1).
        """
        combined_features = self._build_features(data['x'], year)
        online_features = self.online_backbone(combined_features)
        return self.online_projection(online_features)

    def update_target_network(self):
        """Blend online weights into target weights: target <- m * target + (1 - m) * online.

        NOTE(review): because __init__ makes each online/target pair the same module,
        this currently leaves the weights unchanged (up to floating-point rounding).
        """
        module_pairs = (
            (self.online_backbone, self.target_backbone),
            (self.online_projection, self.target_projection),
        )
        with torch.no_grad():
            for online, target in module_pairs:
                for param_o, param_t in zip(online.parameters(), target.parameters()):
                    # Computed out of place on purpose. An in-place update would read an
                    # already-modified value when param_o and param_t are the same tensor.
                    param_t.data = param_t.data * self.momentum + param_o.data * (1. - self.momentum)

    def calculate_similarity(self, online_proj, target_proj, top_k=None):
        """Return the indices of the most similar nodes between two projections.

        Args:
            online_proj, target_proj: tensors of identical shape
                (batch, time_steps, num_nodes, feature_dim). Cosine similarity is taken
                over the last dim. Note that the outputs of ``forward`` / ``target_branch``
                have feature_dim == 1, in which case the similarity is just the sign of
                the product.
            top_k: nodes to keep per (batch, time_step). Defaults to ``self.top_k``,
                which ``__init__`` does not set, so either assign it or pass ``top_k``.
                Values above ``num_nodes`` are clamped.

        Returns:
            Long tensor of shape (batch, time_steps, min(top_k, num_nodes)), ordered by
            decreasing similarity.

        Raises:
            ValueError: if ``top_k`` is not given and ``self.top_k`` is unset.
        """
        if top_k is None:
            top_k = getattr(self, "top_k", None)
        if top_k is None:
            raise ValueError("top_k must be passed explicitly or set as self.top_k")

        batch_size, time_steps, num_nodes, feature_dim = online_proj.shape
        similarity = F.cosine_similarity(
            online_proj.reshape(-1, feature_dim), target_proj.reshape(-1, feature_dim)
        ).view(batch_size, time_steps, num_nodes)

        _, top_k_indices = torch.topk(similarity, min(top_k, num_nodes), dim=-1)
        return top_k_indices

    def target_branch(self, data, year):
        """Run the target branch.

        Args:
            data: mapping whose ``'x'`` entry is the history tensor; it is moved to
                ``self.device`` first.
            year: key used for TMRB's per-year hidden-state summary.

        Returns:
            Projection of shape (batch, output_len, num_nodes, 1).
        """
        history_data = data['x'].to(self.device)
        combined_features = self._build_features(history_data, year)
        target_features = self.target_backbone(combined_features)
        return self.target_projection(target_features)

    def contrastive_loss(self, online_proj, target_proj, top_k=None):
        """Return the top-k node indices from ``calculate_similarity``.

        NOTE(review): despite the name, no loss value is computed here.
        """
        return self.calculate_similarity(online_proj, target_proj, top_k=top_k)