import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedUpdateCell(nn.Module):
    """GRU-style gated cell that mixes a candidate state with the input.

    Args:
        input_dim: Size of the last axis of ``x`` expected by the optional
            input projection.
        hidden_dim: Size of the hidden state and of the returned tensor.
        TMRB_dropout: Dropout probability applied to the concatenated
            ``[x, h_prev]`` vector that feeds the two gates.
    """

    def __init__(self, input_dim, hidden_dim, TMRB_dropout):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # Only used in forward() when x's last axis differs from hidden_dim.
        self.input_to_hidden = nn.Linear(input_dim, hidden_dim)
        self.W_r = nn.Linear(hidden_dim * 2, hidden_dim)  # reset gate
        self.W_z = nn.Linear(hidden_dim * 2, hidden_dim)  # update gate
        self.W_t = nn.Linear(hidden_dim * 2, hidden_dim)  # candidate state
        self.dropout = nn.Dropout(p=TMRB_dropout)

    def forward(self, x, h_prev):
        """Return the gated combination of ``x`` and ``h_prev``.

        Args:
            x: Input tensor; projected to ``hidden_dim`` when its last axis
                has a different size.
            h_prev: Previous hidden state. Its last axis must be
                ``hidden_dim`` for the concatenations below to match the
                ``2 * hidden_dim`` input size of the gate layers.

        Returns:
            Tensor with ``hidden_dim`` features on the last axis.
        """
        if x.size(-1) != self.hidden_dim:
            x = self.input_to_hidden(x)

        gate_input = self.dropout(torch.cat((x, h_prev), dim=-1))
        r_t = torch.sigmoid(self.W_r(gate_input))
        z_t = torch.sigmoid(self.W_z(gate_input))

        # The candidate sees the reset-scaled previous state (no dropout here).
        h_t = torch.tanh(self.W_t(torch.cat((x, h_prev * r_t), dim=-1)))

        # NOTE(review): the carry term is ``x`` rather than ``h_prev`` (as a
        # textbook GRU would use); kept as-is - confirm this is intentional.
        return z_t * h_t + (1 - z_t) * x


class TMRB(nn.Module):
    """Selects ``top_k`` positions per channel and summarises them.

    Args:
        input_dim: Channel size ``D`` of the incoming embedding.
        out_dim: Channel size of the returned tensor.
        top_k: Number of positions picked along the ``N`` axis.
        TMRB_dropout: Dropout probability forwarded to the gated cell.
        is_update: When true, the summary is passed through
            ``GatedUpdateCell`` together with the averaged previous hidden
            state; this requires the hidden state's channel size to equal
            ``out_dim``.
        select_k: When true, positions are chosen by largest absolute
            difference to the previous hidden state; otherwise they are
            sampled uniformly at random.
        device: Device the gated cell is moved to at construction time.
            Only that sub-module is moved here; the remaining parameters
            follow the usual ``module.to(...)`` call made by the owner.
    """

    def __init__(self, input_dim, out_dim, top_k, TMRB_dropout, is_update,
                 select_k, device="cuda:0"):
        super().__init__()
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.device = device
        self.TMRB_dropout = TMRB_dropout
        self.is_update = is_update
        self.GatedUpdateCell = GatedUpdateCell(
            out_dim, out_dim, self.TMRB_dropout
        ).to(self.device)
        self.top_k = top_k
        self.select_k = select_k
        self.mlp = nn.Linear(self.top_k * self.input_dim, self.out_dim)
        # Learned fallback hidden state, shape (input_dim, 1).
        self.init_hidden = nn.Parameter(
            nn.init.xavier_uniform_(torch.empty(input_dim, 1))
        )
        # Misspelled attribute name kept so existing references keep working;
        # it is not used in forward().
        self.dorpout = nn.Dropout(TMRB_dropout)

    def forward(self, tem_emb, year, hidden_states_per_year):
        """Compute the pooled representation for ``year``.

        Args:
            tem_emb: 3-D tensor unpacked as ``(B, N, D)``.
            year: Key of the current step; ``year - 1`` is looked up in
                ``hidden_states_per_year``.
            hidden_states_per_year: Mapping from year to a stored hidden
                tensor. Presumably shaped like ``init_hidden``, i.e.
                ``(D, 1)`` - TODO confirm against the caller.

        Returns:
            Tensor of shape ``(B, out_dim, N)``; every slice along the last
            axis is identical (it is an expanded view).
        """
        B, N, D = tem_emb.shape
        tem_emb = tem_emb.transpose(1, -1)  # (B, D, N)

        if year - 1 in hidden_states_per_year:
            prev_hidden = hidden_states_per_year[year - 1]
        else:
            prev_hidden = self.init_hidden

        # Broadcast the hidden state over N and then over B.
        # NOTE(review): the batch expansion and squeeze are applied to both
        # the stored and the fallback hidden state; this is the only reading
        # under which the stored state ends up (B, D, N) - confirm.
        prev_hidden = prev_hidden.expand(
            size=(N, *prev_hidden.shape)
        ).transpose(0, 1)
        prev_hidden = prev_hidden.expand(size=(B, *prev_hidden.shape))
        prev_hidden = prev_hidden.squeeze(-1)

        if self.select_k:
            # Per channel, keep the positions that differ most from the
            # previous hidden state.
            time_step_diff = torch.abs(tem_emb - prev_hidden)
            _, top_k_indices = torch.topk(
                time_step_diff, k=self.top_k, dim=-1,
                largest=True, sorted=False,
            )
        else:
            # tem_emb is (B, D, N) at this point, so the index must be
            # (B, D, top_k): the same sampled positions are read for every
            # channel, giving the same layout as the select_k branch.
            sampled = torch.randint(
                0, N, (B, self.top_k), device=tem_emb.device
            )
            top_k_indices = sampled.unsqueeze(1).expand(-1, D, -1)

        top_k_features = torch.gather(tem_emb, dim=2, index=top_k_indices)
        top_k_features = top_k_features.reshape(B, -1)  # (B, D * top_k)
        time_step_input = self.mlp(top_k_features)

        if self.is_update:
            prev_hidden_avg = torch.mean(prev_hidden, dim=2).view(B, -1)
            updated_hidden_pool = self.GatedUpdateCell(
                time_step_input, prev_hidden_avg
            )
        else:
            updated_hidden_pool = time_step_input

        return updated_hidden_pool.unsqueeze(2).expand(B, self.out_dim, N)