File size: 3,364 Bytes
4ee9870
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedUpdateCell(nn.Module):
    def __init__(self, input_dim, hidden_dim,TMRB_dropout):
        super(GatedUpdateCell, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.input_to_hidden = nn.Linear(input_dim, hidden_dim)
        
        self.W_r = nn.Linear(hidden_dim*2, hidden_dim)
        self.W_z = nn.Linear(hidden_dim*2, hidden_dim)
        self.W_t = nn.Linear(hidden_dim*2, hidden_dim)
        self.dropout = nn.Dropout(p=TMRB_dropout)

    def forward(self, x, h_prev):
        if x.size(-1) != self.hidden_dim:
            x = self.input_to_hidden(x)
        combined = self.dropout(torch.cat((x, h_prev), dim=-1))
        r_t = torch.sigmoid(self.W_r(combined))
        z_t = torch.sigmoid(self.W_z(combined))
        h_t = torch.tanh(self.W_t(torch.cat((x, h_prev * r_t), dim=-1)))
        h_next = z_t * h_t + (1 - z_t) * x
        return h_next

class TMRB(nn.Module):
    def __init__(self, input_dim, out_dim,top_k, TMRB_dropout,is_update,select_k,device="cuda:0"):
        super(TMRB,self).__init__()
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.device = device
        self.TMRB_dropout = TMRB_dropout
        self.is_update = is_update 
        self.GatedUpdateCell = GatedUpdateCell(out_dim, out_dim,self.TMRB_dropout).to(self.device)
        self.top_k = top_k
        self.select_k = select_k
        self.mlp = nn.Linear(self.top_k * self.input_dim, self.out_dim)
        self.init_hidden = nn.Parameter(
    nn.init.xavier_uniform_(torch.empty(input_dim, 1)))
        self.dorpout = nn.Dropout(TMRB_dropout)
        
    def forward(self, tem_emb, year, hidden_states_per_year):
        B, N, D = tem_emb.shape
        tem_emb = tem_emb.transpose(1,-1)
        
        if year - 1 in hidden_states_per_year.keys():
            prev_hidden = hidden_states_per_year[year - 1]
            prev_hidden = prev_hidden.expand(size = (N,*prev_hidden.shape)).transpose(0,1)
        else:
            prev_hidden = self.init_hidden
            prev_hidden = prev_hidden.expand(size = (N,*prev_hidden.shape)).transpose(0,1)

        prev_hidden = prev_hidden.expand(size = (B,*prev_hidden.shape))
        prev_hidden = prev_hidden.squeeze(-1)
        
        if self.select_k:
            time_step_diff = torch.abs(tem_emb - prev_hidden)
            _, top_k_indices = torch.topk(time_step_diff, k=self.top_k, dim=-1, largest=True, sorted=False)
            top_k_features = torch.gather(tem_emb, dim=2, index=top_k_indices)
            

        else:
            top_k_indices = torch.randint(0, N, (B, self.top_k), device=tem_emb.device)
            top_k_features = torch.gather(tem_emb, dim=2, index=top_k_indices.unsqueeze(-1).expand(-1, -1, D))

        top_k_features = top_k_features.view(B, -1)
        time_step_input = self.mlp(top_k_features)
        
        if self.is_update:
            prev_hidden_avg = torch.mean(prev_hidden, dim=2)
            prev_hidden_avg = prev_hidden_avg.view(B, -1)
            updated_hidden_pool = self.GatedUpdateCell(time_step_input,prev_hidden_avg)
        else:
            updated_hidden_pool = time_step_input

        updated_hidden_pool = updated_hidden_pool.unsqueeze(2).expand(B, self.out_dim, N)
        return updated_hidden_pool