CoMemNet / src /model /model.py
mei2333's picture
Upload src/model/model.py with huggingface_hub
b731740 verified
Raw
History Blame Contribute Delete
6.8 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.mlp import MultiLayerPerceptron
from model.TMRB import TMRB
class Basic_Model(nn.Module):
    """Forecasting network that fuses history, node, TCN and optional TMRB features.

    The input ``data['x']`` is a 4-D tensor unpacked as
    ``(batch, in_steps, num_nodes, num_channels)``. Along the last axis the
    first ``num_feat`` channels are the raw inputs, channel ``num_feat`` is a
    time-in-day value (scaled by ``TIME_IN_DAY_SIZE`` to index an embedding
    table) and channel ``num_feat + 1`` is a day-in-week index.

    NOTE(review): ``online_*`` and ``target_*`` are aliases of the very same
    ``encoder`` / ``projection_head`` modules, so the "target" branch shares
    all weights with the online branch and ``update_target_network`` has no
    separate target to move. If a real momentum target is intended, the
    target modules need their own (deep-copied, frozen) parameters; that is
    left unchanged here because it would alter training behaviour.
    """

    # Row counts of the time-in-day and day-in-week embedding tables.
    TIME_IN_DAY_SIZE = 288
    DAY_IN_WEEK_SIZE = 7

    def __init__(self, args):
        """Build all sub-modules from ``args``.

        Reads ``args.dropout``, ``args.momentum``, ``args.is_TMRB``,
        ``args.is_update``, ``args.select_k`` and the dict-like sections
        ``args.emb``, ``args.tcn`` and ``args.TMRB``.
        """
        super(Basic_Model, self).__init__()
        emb_cfg, tcn_cfg, tmrb_cfg = args.emb, args.tcn, args.TMRB

        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.dropout = args.dropout
        self.activation = nn.GELU()
        self.num_feat = emb_cfg["num_feat"]
        self.args = args
        self.num_layer = emb_cfg["num_layer"]
        self.embed_dim = emb_cfg["adaptive_emb_dim"]
        self.node_dim = emb_cfg["D^N"]
        self.temp_dim_tid = emb_cfg["D^D"]
        self.temp_dim_diw = emb_cfg["D^W"]
        self.output_len = emb_cfg["output_len"]
        self.tcn_dim = tcn_cfg["out_channel"]
        self.is_TMRB = args.is_TMRB
        self.is_update = args.is_update
        self.select_k = args.select_k
        self.TMRB_dropout = tmrb_cfg["dropout"]
        # `calculate_similarity` reads `self.top_k`, which was never assigned
        # before (AttributeError on first use). Prefer an explicit
        # `args.top_k`; otherwise fall back to the TMRB top-k, the only top-k
        # setting visible in the config -- TODO confirm this is the intended k.
        self.top_k = getattr(args, "top_k", tmrb_cfg["top_k"])

        # Learnable embeddings. The node embedding is a single row that is
        # broadcast to every node in `prepare_inputs`.
        self.node_embedding = nn.init.xavier_uniform_(
            nn.Parameter(torch.empty(1, self.node_dim))
        )
        self.T_i_D_emb = nn.init.xavier_uniform_(
            nn.Parameter(torch.empty(self.TIME_IN_DAY_SIZE, self.temp_dim_tid))
        )
        self.D_i_W_emb = nn.init.xavier_uniform_(
            nn.Parameter(torch.empty(self.DAY_IN_WEEK_SIZE, self.temp_dim_diw))
        )

        # 1x1 conv that embeds the flattened (steps * channels) history.
        self.emb_layer_history = nn.Conv2d(
            in_channels=emb_cfg["input_dim"] * emb_cfg["input_len"],
            out_channels=self.embed_dim,
            kernel_size=(1, 1),
            bias=True,
        )
        # 1-D conv applied to the history embedding; the padding expression
        # keeps the length unchanged when (kernel_size - 1) * dilation is even.
        self.tcn = nn.Conv1d(
            in_channels=tcn_cfg["in_channel"],
            out_channels=tcn_cfg["out_channel"],
            kernel_size=tcn_cfg["kernel_size"],
            dilation=tcn_cfg["dilation"],
            padding=int((tcn_cfg["kernel_size"] - 1) * tcn_cfg["dilation"] / 2),
        )

        # Channel count after concatenating every feature group; the TMRB
        # term vanishes when `is_TMRB` is falsy.
        self.hidden_dim = (
            self.embed_dim
            + self.node_dim
            + tmrb_cfg["out_channel"] * self.is_TMRB
            + self.tcn_dim
        )
        self.encoder = nn.Sequential(
            *[
                MultiLayerPerceptron(self.hidden_dim, self.hidden_dim)
                for _ in range(self.num_layer)
            ]
        )
        self.projection_head = nn.Conv2d(
            in_channels=self.hidden_dim,
            out_channels=self.output_len,
            kernel_size=(1, 1),
            bias=True,
        )

        # NOTE(review): these four names alias the two modules above; no
        # copy is made, so online and target always hold identical weights.
        self.online_backbone = self.encoder
        self.online_projection = self.projection_head
        self.target_backbone = self.encoder
        self.target_projection = self.projection_head
        self.momentum = args.momentum

        self.TMRB = TMRB(
            input_dim=tmrb_cfg["in_channel"],
            out_dim=tmrb_cfg["out_channel"],
            top_k=tmrb_cfg["top_k"],
            TMRB_dropout=self.TMRB_dropout,
            is_update=self.is_update,
            select_k=self.select_k,
        ).to(self.device)
        # Per-year cache of the batch/node-averaged TMRB hidden state.
        self.hidden_states_per_year = {}

    def prepare_inputs(self, history_data):
        """Split ``history_data`` into raw inputs and look up its embeddings.

        Args:
            history_data: tensor unpacked as
                ``(batch, in_steps, num_nodes, num_channels)``.

        Returns:
            Tuple ``(input_data, time_in_day_feat, day_in_week_feat, node_emb)``:
            the first ``num_feat`` channels; the time-in-day and day-in-week
            embeddings of the last step, each ``(batch, num_nodes, dim)``; and
            the node embedding broadcast to ``(batch, 1, num_nodes, node_dim)``.
        """
        batch_size, _, num_nodes, _ = history_data.shape

        # (1, D) -> (batch, nodes, 1, D) -> (batch, 1, nodes, D)
        node_emb = self.node_embedding.expand(
            batch_size, num_nodes, *self.node_embedding.shape
        ).transpose(1, 2)

        # Temporal indices are taken from the last input step only.
        last_step = history_data[:, -1, :, :]
        tid_index = (last_step[..., self.num_feat] * self.TIME_IN_DAY_SIZE).long()
        diw_index = last_step[..., self.num_feat + 1].long()
        time_in_day_feat = self.T_i_D_emb[tid_index].to(self.device)
        day_in_week_feat = self.D_i_W_emb[diw_index].to(self.device)

        input_data = history_data[:, :, :, :self.num_feat]
        return input_data, time_in_day_feat, day_in_week_feat, node_emb

    def _build_features(self, history_data, year):
        """Assemble the concatenated feature map shared by both branches.

        Moves ``history_data`` to ``self.device``, embeds it, and concatenates
        history, node, TCN and (if enabled) TMRB features along dim 1.
        """
        history_data = history_data.to(self.device)
        batch_size, _, num_nodes, _ = history_data.shape
        _, time_in_day_feat, day_in_week_feat, node_emb = self.prepare_inputs(history_data)

        # (B, T, N, C) -> (B, T*C, N, 1). Note the full channel axis is
        # flattened here, temporal channels included.
        flat_history = (
            history_data.transpose(1, 2)
            .contiguous()
            .view(batch_size, num_nodes, -1)
            .transpose(1, 2)
            .unsqueeze(-1)
        )
        emb_history = self.emb_layer_history(flat_history)
        tcn_emb = self.tcn(emb_history.squeeze(-1)).unsqueeze(-1)
        tem_emb = torch.cat([time_in_day_feat, day_in_week_feat], dim=-1)

        combined_features = torch.cat(
            [emb_history, node_emb.transpose(1, -1), tcn_emb], dim=1
        )
        if self.is_TMRB:
            # presumably (batch, out_channel, num_nodes) -- required for the
            # mean over dims (0, 2) and the dim-1 concat below; TODO confirm.
            hidden_state = self.TMRB(tem_emb, year, self.hidden_states_per_year)
            # NOTE(review): the cached mean is not detached here, so it may
            # keep this step's autograd graph alive -- confirm TMRB expects that.
            self.hidden_states_per_year[year] = hidden_state.mean(dim=(0, 2))
            combined_features = torch.cat(
                (combined_features, hidden_state.unsqueeze(-1)), dim=1
            )
        return combined_features

    def forward(self, data, year):
        """Run the online branch on ``data['x']`` and return its projection.

        Args:
            data: mapping with key ``'x'`` holding the history tensor.
            year: key used for the TMRB per-year hidden-state cache.
        """
        combined_features = self._build_features(data['x'], year)
        online_features = self.online_backbone(combined_features)
        return self.online_projection(online_features)

    def _ema_update(self, online, target):
        """Blend ``online`` parameters into ``target`` with ``self.momentum``."""
        for param_o, param_t in zip(online.parameters(), target.parameters()):
            # Kept out-of-place on purpose: online and target currently alias
            # the same tensors, and an in-place mul_/add_ would corrupt them.
            param_t.data = param_t.data * self.momentum + param_o.data * (1. - self.momentum)

    def update_target_network(self):
        """Apply the momentum (EMA) update to the target backbone and projection."""
        with torch.no_grad():
            self._ema_update(self.online_backbone, self.target_backbone)
            self._ema_update(self.online_projection, self.target_projection)

    def calculate_similarity(self, online_proj, target_proj):
        """Return the indices of the ``top_k`` most similar nodes.

        Both inputs are 4-D, unpacked as
        ``(batch, time_steps, num_nodes, feature_dim)``. Cosine similarity is
        computed over ``feature_dim`` between matching positions, and the
        top-``self.top_k`` indices along the node axis are returned with
        shape ``(batch, time_steps, top_k)``.
        """
        batch_size, time_steps, num_nodes, feature_dim = online_proj.shape
        # reshape (not view) so non-contiguous projections are accepted.
        online_flat = online_proj.reshape(-1, feature_dim)
        target_flat = target_proj.reshape(-1, feature_dim)
        similarity = F.cosine_similarity(online_flat, target_flat, dim=1)
        similarity = similarity.view(batch_size, time_steps, num_nodes)
        _, top_k_indices = torch.topk(similarity, self.top_k, dim=-1)
        return top_k_indices

    def target_branch(self, data, year):
        """Run the target branch on ``data['x']`` and return its projection.

        Uses the same feature pipeline as ``forward`` but the ``target_*``
        modules (which currently alias the online ones).
        """
        combined_features = self._build_features(data['x'], year)
        target_features = self.target_backbone(combined_features)
        return self.target_projection(target_features)

    def contrastive_loss(self, online_proj, target_proj):
        """Return the top-k similarity indices for the two projections.

        Despite its name, this does not compute a scalar loss; it forwards
        to ``calculate_similarity`` and returns its index tensor.
        """
        return self.calculate_similarity(online_proj, target_proj)