plseng committed
Commit 0b75fc6 · verified · 1 Parent(s): 7364245

Upload net.py

Files changed (1):
src/net.py (+203, -0)
src/net.py ADDED
@@ -0,0 +1,203 @@
"""
Neural network models for Khmer space injection.
"""

import torch
import torch.nn as nn

class CRF(nn.Module):
    """Linear-chain CRF layer returning the negative log-likelihood."""

    def __init__(self, num_tags):
        super().__init__()
        self.num_tags = num_tags
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))
        self.start_transitions = nn.Parameter(torch.randn(num_tags))
        self.end_transitions = nn.Parameter(torch.randn(num_tags))

    def forward(self, emissions, tags, mask):
        # Negative log-likelihood: log Z(x) - score(x, y), averaged over the batch.
        log_num = self._score_sentence(emissions, tags, mask)
        log_den = self._log_partition(emissions, mask)
        return torch.mean(log_den - log_num)

    def _score_sentence(self, emissions, tags, mask):
        # emissions: (B, T, K), tags: (B, T), mask: (B, T) with 1 at real tokens.
        B, T, _ = emissions.shape
        batch = torch.arange(B, device=emissions.device)

        score = self.start_transitions[tags[:, 0]] + emissions[batch, 0, tags[:, 0]]

        for t in range(1, T):
            # Count emission and transition scores only at unpadded positions.
            score = score + emissions[batch, t, tags[:, t]] * mask[:, t]
            score = score + self.transitions[tags[:, t - 1], tags[:, t]] * mask[:, t]

        last_idx = mask.sum(1).long() - 1
        last_tags = tags.gather(1, last_idx.unsqueeze(1)).squeeze(1)
        score = score + self.end_transitions[last_tags]
        return score

    def _log_partition(self, emissions, mask):
        # Forward algorithm in log space; alpha: (B, K).
        alpha = self.start_transitions + emissions[:, 0]

        for t in range(1, emissions.size(1)):
            emit = emissions[:, t].unsqueeze(1)    # (B, 1, K): emission of next tag j
            trans = self.transitions.unsqueeze(0)  # (1, K, K): transition i -> j
            next_alpha = torch.logsumexp(alpha.unsqueeze(2) + trans + emit, dim=1)
            # Carry alpha through padded positions instead of zeroing it.
            keep = mask[:, t].unsqueeze(1).bool()
            alpha = torch.where(keep, next_alpha, alpha)

        return torch.logsumexp(alpha + self.end_transitions, dim=1)

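# NOTE: the CRF above computes only the training loss; inference additionally
# needs Viterbi decoding. A minimal sketch follows (hypothetical helper, not
# part of this upload), assuming the same emissions/mask layout as CRF.forward.
def viterbi_decode(crf, emissions, mask):
    """Return the best tag sequence per batch element as lists of ints."""
    B, T, _ = emissions.shape
    score = crf.start_transitions + emissions[:, 0]       # (B, K)
    history = []

    for t in range(1, T):
        # Best previous tag i for each current tag j.
        next_score = score.unsqueeze(2) + crf.transitions.unsqueeze(0)
        next_score, backptr = next_score.max(dim=1)       # (B, K)
        next_score = next_score + emissions[:, t]
        keep = mask[:, t].unsqueeze(1).bool()
        score = torch.where(keep, next_score, score)      # hold padded steps
        history.append(backptr)

    score = score + crf.end_transitions
    best_last = score.argmax(dim=1)

    # Backtrack pointers, respecting each sequence's true length.
    lengths = mask.sum(1).long()
    paths = []
    for b in range(B):
        length = int(lengths[b])
        tag = int(best_last[b])
        path = [tag]
        for backptr in reversed(history[: length - 1]):
            tag = int(backptr[b, tag])
            path.append(tag)
        paths.append(path[::-1])
    return paths
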
class RNN(nn.Module):
    """Vanilla (Elman) recurrent cell."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.Wxh = nn.Linear(input_dim, hidden_dim)
        self.Whh = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, x_t, h_prev):
        return torch.tanh(self.Wxh(x_t) + self.Whh(h_prev))

class GRU(nn.Module):
    """GRU cell with update gate z, reset gate r, and candidate state."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.z = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.r = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.h = nn.Linear(input_dim + hidden_dim, hidden_dim)

    def forward(self, x_t, h_prev):
        concat = torch.cat([x_t, h_prev], dim=-1)
        z_t = torch.sigmoid(self.z(concat))  # update gate
        r_t = torch.sigmoid(self.r(concat))  # reset gate

        concat_reset = torch.cat([x_t, r_t * h_prev], dim=-1)
        h_tilde = torch.tanh(self.h(concat_reset))  # candidate state

        return (1 - z_t) * h_prev + z_t * h_tilde

class LSTM(nn.Module):
    """LSTM cell with input (i), forget (f), output (o), and cell (g) gates."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.i = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.f = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.o = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.g = nn.Linear(input_dim + hidden_dim, hidden_dim)

    def forward(self, x_t, state):
        h_prev, c_prev = state
        concat = torch.cat([x_t, h_prev], dim=-1)

        i_t = torch.sigmoid(self.i(concat))
        f_t = torch.sigmoid(self.f(concat))
        o_t = torch.sigmoid(self.o(concat))
        g_t = torch.tanh(self.g(concat))

        c_t = f_t * c_prev + i_t * g_t
        h_t = o_t * torch.tanh(c_t)
        return h_t, c_t

class BiRecurrentLayer(nn.Module):
    """Runs a recurrent cell over time, optionally in both directions."""

    def __init__(self, cell_cls, input_dim, hidden_dim, bidirectional=True):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.bidirectional = bidirectional

        self.fw = cell_cls(input_dim, hidden_dim)
        if bidirectional:
            self.bw = cell_cls(input_dim, hidden_dim)

    def forward(self, x):
        B, T, _ = x.shape
        device = x.device
        H = self.hidden_dim

        # ---------- Forward pass ----------
        h_fw = []
        h = torch.zeros(B, H, device=device)
        c = torch.zeros_like(h) if isinstance(self.fw, LSTM) else None

        for t in range(T):
            if c is not None:
                h, c = self.fw(x[:, t], (h, c))
            else:
                h = self.fw(x[:, t], h)
            h_fw.append(h)

        h_fw = torch.stack(h_fw, dim=1)

        if not self.bidirectional:
            return h_fw

        # ---------- Backward pass ----------
        h_bw = []
        h = torch.zeros(B, H, device=device)
        c = torch.zeros_like(h) if isinstance(self.bw, LSTM) else None

        for t in reversed(range(T)):
            if c is not None:
                h, c = self.bw(x[:, t], (h, c))
            else:
                h = self.bw(x[:, t], h)
            h_bw.append(h)

        h_bw.reverse()
        h_bw = torch.stack(h_bw, dim=1)

        return torch.cat([h_fw, h_bw], dim=-1)

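# Quick shape sanity check for BiRecurrentLayer (illustrative only; the
# dimensions below are invented, not taken from the original code):
#   layer = BiRecurrentLayer(LSTM, input_dim=8, hidden_dim=16)
#   out = layer(torch.randn(4, 10, 8))
#   assert out.shape == (4, 10, 32)  # hidden_dim doubled by bidirectionality
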
class KhmerRNN(nn.Module):
    """Sequence tagger for Khmer space injection (2 tags per position)."""

    def __init__(
        self,
        vocab_size,
        embedding_dim=128,
        hidden_dim=256,
        num_layers=2,
        dropout=0.3,
        bidirectional=True,
        rnn_type="lstm",
        residual=True,
        use_crf=True,
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.dropout = nn.Dropout(dropout)
        self.residual = residual
        self.use_crf = use_crf

        cell_map = {
            "rnn": RNN,
            "gru": GRU,
            "lstm": LSTM,
        }
        cell_cls = cell_map[rnn_type.lower()]

        self.layers = nn.ModuleList()
        input_dim = embedding_dim

        for _ in range(num_layers):
            layer = BiRecurrentLayer(
                cell_cls=cell_cls,
                input_dim=input_dim,
                hidden_dim=hidden_dim,
                bidirectional=bidirectional,
            )
            self.layers.append(layer)
            input_dim = hidden_dim * (2 if bidirectional else 1)

        self.fc = nn.Linear(input_dim, 2)

        if use_crf:
            self.crf = CRF(num_tags=2)

    def forward(self, x, tags=None, mask=None):
        out = self.embedding(x)

        for layer in self.layers:
            residual = out
            out = layer(out)

            # Residual connection where layer input and output widths match.
            if self.residual and out.shape == residual.shape:
                out = out + residual

            out = self.dropout(out)

        emissions = self.fc(out)

        if self.use_crf and tags is not None:
            if mask is None:
                # Default mask derived from the embedding's padding index (0).
                mask = (x != 0).float()
            return self.crf(emissions, tags, mask)

        return emissions
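
# A minimal usage sketch (hypothetical; vocab size, batch, and sequence length
# are invented for illustration). With tags the model returns the CRF loss;
# without tags it returns per-position emission scores.
if __name__ == "__main__":
    model = KhmerRNN(vocab_size=100)
    x = torch.randint(1, 100, (4, 20))    # token ids; 0 is reserved for padding
    tags = torch.randint(0, 2, (4, 20))   # binary tags (assumed: space / no space)
    mask = torch.ones(4, 20)

    loss = model(x, tags=tags, mask=mask)  # CRF negative log-likelihood
    loss.backward()
    print(float(loss))

    emissions = model(x)                   # (4, 20, 2) scores at inference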