| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from model.mlp import MultiLayerPerceptron |
| from model.TMRB import TMRB |
|
|
class Basic_Model(nn.Module):
    """Forecasting backbone with an online branch and a momentum target branch.

    The model embeds the flattened input history with a 1x1 convolution,
    concatenates a learned node embedding, a 1-D convolution ("tcn") over the
    history embedding and, when ``args.is_TMRB`` is truthy, the output of the
    ``TMRB`` module fed with time-of-day / day-of-week embeddings.  The
    concatenated features go through a stack of ``MultiLayerPerceptron``
    blocks and a 1x1 convolutional projection head.

    Args:
        args: Configuration object.  The constructor reads ``args.dropout``,
            ``args.is_TMRB``, ``args.is_update``, ``args.select_k``,
            ``args.momentum`` and the mappings ``args.emb``, ``args.tcn`` and
            ``args.TMRB`` (see ``__init__`` for the exact keys).
    """

    def __init__(self, args):
        super().__init__()
        # Local import: the file header does not import ``copy``.
        import copy

        # NOTE(review): the device is fixed at construction time and is used
        # for TMRB and for the temporal embeddings; the rest of the module
        # follows whatever ``.to(...)`` the caller applies.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.dropout = args.dropout
        self.activation = nn.GELU()
        self.num_feat = args.emb["num_feat"]
        self.args = args
        self.num_layer = args.emb["num_layer"]

        self.embed_dim = args.emb["adaptive_emb_dim"]
        self.node_dim = args.emb["D^N"]
        self.temp_dim_tid = args.emb["D^D"]
        self.temp_dim_diw = args.emb["D^W"]
        self.output_len = args.emb["output_len"]
        self.tcn_dim = args.tcn["out_channel"]
        self.is_TMRB = args.is_TMRB
        self.is_update = args.is_update
        self.select_k = args.select_k
        self.TMRB_dropout = args.TMRB["dropout"]
        # ``calculate_similarity`` reads ``self.top_k``; it was never set,
        # which made that method raise AttributeError.
        # NOTE(review): args.TMRB["top_k"] is the only ``top_k`` in the
        # visible configuration - confirm it is the intended value.
        self.top_k = args.TMRB["top_k"]

        # Learned embeddings: one shared node vector, 288 time-of-day slots
        # and 7 day-of-week slots.
        self.node_embedding = nn.init.xavier_uniform_(
            nn.Parameter(torch.empty(1, self.node_dim))
        )
        self.T_i_D_emb = nn.init.xavier_uniform_(
            nn.Parameter(torch.empty(288, self.temp_dim_tid))
        )
        self.D_i_W_emb = nn.init.xavier_uniform_(
            nn.Parameter(torch.empty(7, self.temp_dim_diw))
        )

        self.emb_layer_history = nn.Conv2d(
            in_channels=args.emb["input_dim"] * args.emb["input_len"],
            out_channels=self.embed_dim,
            kernel_size=(1, 1),
            bias=True,
        )
        kernel_size = args.tcn["kernel_size"]
        dilation = args.tcn["dilation"]
        self.tcn = nn.Conv1d(
            in_channels=args.tcn["in_channel"],
            out_channels=args.tcn["out_channel"],
            kernel_size=kernel_size,
            dilation=dilation,
            padding=int((kernel_size - 1) * dilation / 2),
        )

        # The TMRB channels only count when the TMRB branch is enabled.
        self.hidden_dim = (
            self.embed_dim
            + self.node_dim
            + args.TMRB["out_channel"] * self.is_TMRB
            + self.tcn_dim
        )
        self.encoder = nn.Sequential(
            *[
                MultiLayerPerceptron(self.hidden_dim, self.hidden_dim)
                for _ in range(self.num_layer)
            ]
        )
        self.projection_head = nn.Conv2d(
            in_channels=self.hidden_dim,
            out_channels=self.output_len,
            kernel_size=(1, 1),
            bias=True,
        )

        # The online branch is an alias of the trainable encoder / head.
        self.online_backbone = self.encoder
        self.online_projection = self.projection_head
        # The target branch needs its OWN weights: when it aliased the online
        # modules, ``update_target_network`` blended every parameter with
        # itself and was a no-op.  The copies start identical to the online
        # weights and are only moved by the momentum update, so they are
        # excluded from gradient-based optimisation.
        self.target_backbone = copy.deepcopy(self.encoder)
        self.target_projection = copy.deepcopy(self.projection_head)
        for target_module in (self.target_backbone, self.target_projection):
            target_module.requires_grad_(False)

        self.momentum = args.momentum
        self.TMRB = TMRB(
            input_dim=args.TMRB["in_channel"],
            out_dim=args.TMRB["out_channel"],
            top_k=args.TMRB["top_k"],
            TMRB_dropout=self.TMRB_dropout,
            is_update=self.is_update,
            select_k=self.select_k,
        ).to(self.device)
        # Maps ``year`` -> TMRB output averaged over dims 0 and 2.
        self.hidden_states_per_year = {}

    def prepare_inputs(self, history_data):
        """Split ``history_data`` into raw features and embedding lookups.

        Args:
            history_data: 4-D tensor unpacked as
                ``(batch, in_steps, num_nodes, channels)``.  Channel
                ``num_feat`` of the last step is multiplied by the
                time-of-day table length and truncated to an index; channel
                ``num_feat + 1`` is used directly as a day-of-week index.

        Returns:
            Tuple ``(input_data, time_in_day_feat, day_in_week_feat,
            node_emb)`` where ``input_data`` holds the first ``num_feat``
            channels, the two temporal features are rows of the embedding
            tables (moved to ``self.device``) and ``node_emb`` is the node
            embedding broadcast to ``(batch, 1, num_nodes, node_dim)``.
        """
        batch_size, _, num_nodes, _ = history_data.shape

        # (1, node_dim) -> (batch, num_nodes, 1, node_dim) -> swap dims 1/2.
        node_emb = self.node_embedding.expand(
            batch_size, num_nodes, *self.node_embedding.shape
        ).transpose(1, 2)

        last_step = history_data[:, -1]
        slots_per_day = self.T_i_D_emb.shape[0]  # 288 rows in the table
        tid_index = (last_step[:, :, self.num_feat] * slots_per_day).long()
        diw_index = last_step[:, :, self.num_feat + 1].long()
        time_in_day_feat = self.T_i_D_emb[tid_index].to(self.device)
        day_in_week_feat = self.D_i_W_emb[diw_index].to(self.device)

        input_data = history_data[:, :, :, :self.num_feat]
        return input_data, time_in_day_feat, day_in_week_feat, node_emb

    def _combine_features(self, history_data, year):
        """Build the concatenated feature map shared by both branches.

        Args:
            history_data: 4-D tensor ``(batch, in_steps, num_nodes, channels)``.
            year: Key passed to ``TMRB`` and used to store its averaged
                output in ``self.hidden_states_per_year``.

        Returns:
            Tensor concatenated along dim 1 from the history embedding, the
            node embedding, the tcn output and (if enabled) the TMRB output.
        """
        batch_size, _, num_nodes, _ = history_data.shape
        _, time_in_day_feat, day_in_week_feat, node_emb = self.prepare_inputs(
            history_data
        )

        # Fold steps and channels together per node:
        # (B, T, N, C) -> (B, T*C, N, 1), the layout the 1x1 Conv2d consumes.
        flat_history = (
            history_data.transpose(1, 2)
            .contiguous()
            .view(batch_size, num_nodes, -1)
            .transpose(1, 2)
            .unsqueeze(-1)
        )
        emb_history = self.emb_layer_history(flat_history)
        tcn_emb = self.tcn(emb_history.squeeze(-1)).unsqueeze(-1)

        combined_features = torch.cat(
            [emb_history, node_emb.transpose(1, -1), tcn_emb], dim=1
        )

        if self.is_TMRB:
            tem_emb = torch.cat([time_in_day_feat, day_in_week_feat], dim=-1)
            hidden_state = self.TMRB(tem_emb, year, self.hidden_states_per_year)
            # NOTE(review): the stored tensor is not detached, so it keeps
            # its autograd graph alive - confirm TMRB relies on that before
            # changing it.
            self.hidden_states_per_year[year] = hidden_state.mean(dim=(0, 2))
            combined_features = torch.cat(
                (combined_features, hidden_state.unsqueeze(-1)), dim=1
            )
        return combined_features

    def forward(self, data, year):
        """Run the online branch.

        Args:
            data: Mapping whose ``'x'`` entry is the history tensor
                ``(batch, in_steps, num_nodes, channels)``.  It is used on
                whatever device it already lives on.
            year: Key forwarded to the TMRB memory.

        Returns:
            Output of the online projection head.
        """
        combined_features = self._combine_features(data['x'], year)
        online_features = self.online_backbone(combined_features)
        return self.online_projection(online_features)

    def update_target_network(self):
        """Move the target weights towards the online weights (EMA).

        Each target parameter becomes
        ``momentum * target + (1 - momentum) * online``.
        """
        branch_pairs = (
            (self.online_backbone, self.target_backbone),
            (self.online_projection, self.target_projection),
        )
        with torch.no_grad():
            for online_module, target_module in branch_pairs:
                pairs = zip(online_module.parameters(), target_module.parameters())
                for param_o, param_t in pairs:
                    # The blend is computed out of place before the copy, so
                    # this stays correct even if both names are one tensor.
                    blended = (
                        param_t * self.momentum
                        + param_o * (1.0 - self.momentum)
                    )
                    param_t.copy_(blended)

    def calculate_similarity(self, online_proj, target_proj):
        """Return the indices of the most similar nodes per (batch, step).

        Cosine similarity is computed over the last dimension between
        matching positions of the two projections, then the top ``k`` entries
        are taken along the node dimension.

        Args:
            online_proj: 4-D tensor unpacked as
                ``(batch, time_steps, num_nodes, feature_dim)``.
            target_proj: Tensor with the same number of elements and the same
                last dimension.

        Returns:
            Long tensor ``(batch, time_steps, k)`` with
            ``k = min(self.top_k, num_nodes)``.
        """
        batch_size, time_steps, num_nodes, feature_dim = online_proj.shape
        # ``reshape`` also accepts non-contiguous inputs, where ``view`` raises.
        flat_online = online_proj.reshape(-1, feature_dim)
        flat_target = target_proj.reshape(-1, feature_dim)
        similarity = F.cosine_similarity(flat_online, flat_target)
        similarity = similarity.view(batch_size, time_steps, num_nodes)

        # ``topk`` raises when k exceeds the size of the selected dimension.
        k = min(self.top_k, num_nodes)
        _, top_k_indices = torch.topk(similarity, k, dim=-1)
        return top_k_indices

    def target_branch(self, data, year):
        """Run the momentum (target) branch.

        Args:
            data: Mapping whose ``'x'`` entry is the history tensor; it is
                moved to ``self.device`` first.
            year: Key forwarded to the TMRB memory.

        Returns:
            Output of the target projection head.
        """
        history_data = data['x'].to(self.device)
        combined_features = self._combine_features(history_data, year)
        target_features = self.target_backbone(combined_features)
        return self.target_projection(target_features)

    def contrastive_loss(self, online_proj, target_proj):
        """Return the top-k similarity indices for the two projections.

        Despite its name this method does not compute a loss value; it
        forwards to ``calculate_similarity`` and returns its indices.
        """
        return self.calculate_similarity(online_proj, target_proj)