CoMemNet / src /model /TMRB.py
mei2333's picture
Upload src/model/TMRB.py with huggingface_hub
4ee9870 verified
Raw
History Blame Contribute Delete
3.36 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
class GatedUpdateCell(nn.Module):
    """Gated cell that mixes a candidate state with the (optionally projected) input.

    Two sigmoid gates are computed from the concatenation of the input and the
    previous hidden state.  The reset gate scales the previous hidden state
    before the tanh candidate is formed; the update gate then interpolates
    between that candidate and the input ``x`` (not ``h_prev``).

    Args:
        input_dim: Feature width accepted by the optional input projection.
        hidden_dim: Width of the hidden state and of the returned tensor.
        TMRB_dropout: Dropout probability applied to the gate inputs.
    """

    def __init__(self, input_dim, hidden_dim, TMRB_dropout):
        super().__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim

        # Only used in forward() when the input's last dim is not hidden_dim.
        self.input_to_hidden = nn.Linear(input_dim, hidden_dim)

        # Each gate / candidate layer consumes [input ; hidden] concatenated.
        gate_in_features = 2 * hidden_dim
        self.W_r = nn.Linear(gate_in_features, hidden_dim)
        self.W_z = nn.Linear(gate_in_features, hidden_dim)
        self.W_t = nn.Linear(gate_in_features, hidden_dim)

        self.dropout = nn.Dropout(p=TMRB_dropout)

    def forward(self, x, h_prev):
        """Return the updated state for input ``x`` and previous state ``h_prev``.

        ``x`` is projected with ``input_to_hidden`` whenever its last dimension
        differs from ``hidden_dim``; ``h_prev`` must have ``hidden_dim`` features
        in its last dimension for the concatenations to match the gate layers.
        """
        width_mismatch = x.size(-1) != self.hidden_dim
        inp = self.input_to_hidden(x) if width_mismatch else x

        # Dropout touches only the features feeding the two gates; the
        # candidate below is built from the un-dropped tensors.
        gate_features = self.dropout(torch.cat([inp, h_prev], dim=-1))
        reset_gate = torch.sigmoid(self.W_r(gate_features))
        update_gate = torch.sigmoid(self.W_z(gate_features))

        candidate_in = torch.cat([inp, h_prev * reset_gate], dim=-1)
        candidate = torch.tanh(self.W_t(candidate_in))

        return update_gate * candidate + (1 - update_gate) * inp
class TMRB(nn.Module):
    """Condense a 3-D embedding into one vector per sample, guided by last year's state.

    For every feature channel, ``top_k`` positions along the second input axis
    are selected (either those that differ most from the previous hidden state,
    or at random), flattened and projected by ``mlp``.  Optionally the result
    is passed through a ``GatedUpdateCell`` together with the averaged previous
    hidden state.  The final vector is broadcast back along the second axis.

    Args:
        input_dim: Size of the last axis of ``tem_emb`` (``mlp`` expects
            ``top_k * input_dim`` input features).
        out_dim: Output feature width of ``mlp`` and of the gated cell.
        top_k: Number of positions picked per feature channel.
        TMRB_dropout: Dropout probability forwarded to the gated cell.
        is_update: If truthy, refine the projected vector with the gated cell.
        select_k: If truthy, pick positions by largest absolute difference to
            the previous hidden state; otherwise pick positions at random.
        device: Device the gated cell is moved to at construction time.
    """

    def __init__(self, input_dim, out_dim, top_k, TMRB_dropout, is_update, select_k, device="cuda:0"):
        super().__init__()
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.device = device
        self.TMRB_dropout = TMRB_dropout
        self.is_update = is_update
        # NOTE: only this sub-module is moved to `device` here; `mlp` and
        # `init_hidden` stay on the default device until the caller moves
        # the whole module.
        self.GatedUpdateCell = GatedUpdateCell(out_dim, out_dim, self.TMRB_dropout).to(self.device)
        self.top_k = top_k
        self.select_k = select_k
        self.mlp = nn.Linear(self.top_k * self.input_dim, self.out_dim)
        # Learnable fallback state used when no entry exists for `year - 1`.
        self.init_hidden = nn.Parameter(
            nn.init.xavier_uniform_(torch.empty(input_dim, 1)))
        # (sic) misspelled attribute name kept for backward compatibility;
        # it is not used in forward().
        self.dorpout = nn.Dropout(TMRB_dropout)

    def forward(self, tem_emb, year, hidden_states_per_year):
        """Compute the pooled state for ``year``.

        Args:
            tem_emb: Tensor of shape (B, N, D); D must equal ``input_dim``
                for the flattened selection to fit ``mlp``.
            year: Key of the current step; the state stored under
                ``year - 1`` is used as the previous hidden state.
            hidden_states_per_year: Mapping from year to a stored hidden
                state.  Presumably each entry has shape (D,) or (D, 1),
                like ``init_hidden`` -- TODO confirm against the caller.

        Returns:
            Tensor of shape (B, out_dim, N): one vector per sample, expanded
            (not copied) along the last axis.
        """
        B, N, D = tem_emb.shape
        tem_emb = tem_emb.transpose(1, -1)  # -> (B, D, N)

        # Previous state: stored entry for last year, else the learned default.
        if year - 1 in hidden_states_per_year:
            prev_hidden = hidden_states_per_year[year - 1]
        else:
            prev_hidden = self.init_hidden
        # Broadcast to (B, D, N): add an N axis, swap it behind D, add B,
        # then drop the trailing singleton left over from a (D, 1) state.
        prev_hidden = prev_hidden.expand(size=(N, *prev_hidden.shape)).transpose(0, 1)
        prev_hidden = prev_hidden.expand(size=(B, *prev_hidden.shape))
        prev_hidden = prev_hidden.squeeze(-1)

        if self.select_k:
            # Per channel, keep the positions that moved furthest from the
            # previous state (order within the k picks is unspecified).
            time_step_diff = torch.abs(tem_emb - prev_hidden)
            _, top_k_indices = torch.topk(
                time_step_diff, k=self.top_k, dim=-1, largest=True, sorted=False)
            top_k_features = torch.gather(tem_emb, dim=2, index=top_k_indices)
        else:
            # Random positions (sampled with replacement), shared by all
            # channels of a sample.  The index must be (B, D, top_k) because
            # tem_emb is laid out as (B, D, N) and we gather along dim 2; the
            # previous (B, top_k, D) index read only the first top_k channels
            # and failed outright when top_k > D.
            top_k_indices = torch.randint(0, N, (B, self.top_k), device=tem_emb.device)
            gather_index = top_k_indices.unsqueeze(1).expand(-1, D, -1)
            top_k_features = torch.gather(tem_emb, dim=2, index=gather_index)

        # (B, D, top_k) -> (B, D * top_k) -> (B, out_dim)
        top_k_features = top_k_features.view(B, -1)
        time_step_input = self.mlp(top_k_features)

        if self.is_update:
            # Average the previous state over the last axis -> (B, D).  The
            # gated cell is built with hidden width out_dim, so this path
            # needs D == out_dim.
            prev_hidden_avg = torch.mean(prev_hidden, dim=2)
            prev_hidden_avg = prev_hidden_avg.view(B, -1)
            updated_hidden_pool = self.GatedUpdateCell(time_step_input, prev_hidden_avg)
        else:
            updated_hidden_pool = time_step_input

        updated_hidden_pool = updated_hidden_pool.unsqueeze(2).expand(B, self.out_dim, N)
        return updated_hidden_pool