| """ | |
| Neural network models for Khmer space injection | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import random | |

class CRF(nn.Module):
    """Linear-chain CRF: scores tag sequences over per-token emissions."""

    def __init__(self, num_tags):
        super().__init__()
        self.num_tags = num_tags
        # Learned tag-to-tag transition scores, plus start/end scores.
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))
        self.start_transitions = nn.Parameter(torch.randn(num_tags))
        self.end_transitions = nn.Parameter(torch.randn(num_tags))

    def forward(self, emissions, tags, mask):
        # Negative log-likelihood: log Z(x) - score(x, y), averaged over the batch.
        log_num = self._score_sentence(emissions, tags, mask)
        log_den = self._log_partition(emissions, mask)
        return torch.mean(log_den - log_num)

    def _score_sentence(self, emissions, tags, mask):
        B, T, _ = emissions.shape
        batch = torch.arange(B, device=emissions.device)
        # Start transition plus the emission for the first tag.
        score = self.start_transitions[tags[:, 0]] + emissions[batch, 0, tags[:, 0]]
        for t in range(1, T):
            step = (
                self.transitions[tags[:, t - 1], tags[:, t]]
                + emissions[batch, t, tags[:, t]]
            )
            score += step * mask[:, t]  # skip padded positions
        last_idx = mask.sum(1).long() - 1
        last_tags = tags.gather(1, last_idx.unsqueeze(1)).squeeze(1)
        score += self.end_transitions[last_tags]
        return score

    def _log_partition(self, emissions, mask):
        # Forward algorithm in log space.
        alpha = self.start_transitions + emissions[:, 0]
        for t in range(1, emissions.size(1)):
            emit = emissions[:, t].unsqueeze(1)    # (B, 1, tags): next tag j
            trans = self.transitions.unsqueeze(0)  # (1, tags, tags): i -> j
            alpha_next = torch.logsumexp(alpha.unsqueeze(2) + trans + emit, dim=1)
            # Carry the previous alpha through padded positions instead of zeroing it.
            step_mask = mask[:, t].unsqueeze(1).bool()
            alpha = torch.where(step_mask, alpha_next, alpha)
        return torch.logsumexp(alpha + self.end_transitions, dim=1)
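
# --- Added sketch (not in the original file): Viterbi decoding for the CRF ---
# The class above only defines the training loss; at inference time the
# emissions returned by KhmerRNN ignore the learned transitions unless they
# are decoded with something like this. Assumes the CRF parameters defined
# above and a (B, T) mask with 1s at real tokens.
def viterbi_decode(crf, emissions, mask):
    B, T, _ = emissions.shape
    score = crf.start_transitions + emissions[:, 0]  # (B, num_tags)
    history = []
    for t in range(1, T):
        # Score of every (previous tag -> current tag) move, then keep the best.
        broadcast = score.unsqueeze(2) + crf.transitions.unsqueeze(0)
        best_score, best_prev = broadcast.max(dim=1)  # both (B, num_tags)
        next_score = best_score + emissions[:, t]
        step_mask = mask[:, t].unsqueeze(1).bool()
        score = torch.where(step_mask, next_score, score)
        history.append(best_prev)
    score = score + crf.end_transitions
    lengths = mask.sum(1).long()
    paths = []
    for b in range(B):
        length = int(lengths[b])
        path = [int(score[b].argmax())]
        # Backtrack through the real (unpadded) timesteps only.
        for best_prev in reversed(history[: length - 1]):
            path.append(int(best_prev[b, path[-1]]))
        path.reverse()
        paths.append(path)
    return paths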

class RNN(nn.Module):
    """Vanilla (Elman) recurrent cell: h_t = tanh(W_xh x_t + W_hh h_{t-1})."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.Wxh = nn.Linear(input_dim, hidden_dim)
        self.Whh = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, x_t, h_prev):
        return torch.tanh(self.Wxh(x_t) + self.Whh(h_prev))

class GRU(nn.Module):
    """Gated recurrent unit cell with update (z), reset (r), and candidate gates."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.z = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.r = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.h = nn.Linear(input_dim + hidden_dim, hidden_dim)

    def forward(self, x_t, h_prev):
        concat = torch.cat([x_t, h_prev], dim=-1)
        z_t = torch.sigmoid(self.z(concat))  # update gate
        r_t = torch.sigmoid(self.r(concat))  # reset gate
        concat_reset = torch.cat([x_t, r_t * h_prev], dim=-1)
        h_tilde = torch.tanh(self.h(concat_reset))
        # Interpolate between the previous state and the candidate state.
        return (1 - z_t) * h_prev + z_t * h_tilde

class LSTM(nn.Module):
    """LSTM cell with input (i), forget (f), output (o), and candidate (g) gates."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.i = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.f = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.o = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.g = nn.Linear(input_dim + hidden_dim, hidden_dim)

    def forward(self, x_t, state):
        h_prev, c_prev = state
        concat = torch.cat([x_t, h_prev], dim=-1)
        i_t = torch.sigmoid(self.i(concat))
        f_t = torch.sigmoid(self.f(concat))
        o_t = torch.sigmoid(self.o(concat))
        g_t = torch.tanh(self.g(concat))
        c_t = f_t * c_prev + i_t * g_t  # cell state update
        h_t = o_t * torch.tanh(c_t)     # exposed hidden state
        return h_t, c_t

class BiRecurrentLayer(nn.Module):
    """Runs a recurrent cell over a sequence, optionally in both directions.

    Note: padding is not masked here, so for short sequences the backward
    pass consumes trailing pad tokens before the real ones.
    """

    def __init__(self, cell_cls, input_dim, hidden_dim, bidirectional=True):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.bidirectional = bidirectional
        self.fw = cell_cls(input_dim, hidden_dim)
        if bidirectional:
            self.bw = cell_cls(input_dim, hidden_dim)

    def forward(self, x):
        B, T, _ = x.shape
        device = x.device
        H = self.hidden_dim

        # ---------- Forward pass ----------
        h_fw = []
        h = torch.zeros(B, H, device=device)
        c = torch.zeros_like(h) if isinstance(self.fw, LSTM) else None
        for t in range(T):
            if c is not None:
                h, c = self.fw(x[:, t], (h, c))
            else:
                h = self.fw(x[:, t], h)
            h_fw.append(h)
        h_fw = torch.stack(h_fw, dim=1)
        if not self.bidirectional:
            return h_fw

        # ---------- Backward pass ----------
        h_bw = []
        h = torch.zeros(B, H, device=device)
        c = torch.zeros_like(h) if isinstance(self.bw, LSTM) else None
        for t in reversed(range(T)):
            if c is not None:
                h, c = self.bw(x[:, t], (h, c))
            else:
                h = self.bw(x[:, t], h)
            h_bw.append(h)
        h_bw.reverse()  # restore chronological order before stacking
        h_bw = torch.stack(h_bw, dim=1)
        return torch.cat([h_fw, h_bw], dim=-1)

class KhmerRNN(nn.Module):
    """Stacked (bi)recurrent tagger emitting per-character space/no-space scores."""

    def __init__(
        self,
        vocab_size,
        embedding_dim=128,
        hidden_dim=256,
        num_layers=2,
        dropout=0.3,
        bidirectional=True,
        rnn_type="lstm",
        residual=True,
        use_crf=True,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.dropout = nn.Dropout(dropout)
        self.residual = residual
        self.use_crf = use_crf
        cell_map = {
            "rnn": RNN,
            "gru": GRU,
            "lstm": LSTM,
        }
        cell_cls = cell_map[rnn_type.lower()]
        self.layers = nn.ModuleList()
        input_dim = embedding_dim
        for _ in range(num_layers):
            layer = BiRecurrentLayer(
                cell_cls=cell_cls,
                input_dim=input_dim,
                hidden_dim=hidden_dim,
                bidirectional=bidirectional,
            )
            self.layers.append(layer)
            input_dim = hidden_dim * (2 if bidirectional else 1)
        self.fc = nn.Linear(input_dim, 2)  # two tags: space / no space
        if use_crf:
            self.crf = CRF(num_tags=2)

    def forward(self, x, tags=None, mask=None):
        out = self.embedding(x)
        for layer in self.layers:
            residual = out
            out = layer(out)
            # Residual connections only apply where input/output widths match,
            # i.e. between stacked layers of equal size.
            if self.residual and out.shape == residual.shape:
                out = out + residual
            out = self.dropout(out)
        emissions = self.fc(out)
        if self.use_crf and tags is not None:
            return self.crf(emissions, tags, mask)  # training: CRF NLL loss
        return emissions  # inference: per-token tag scores
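
# --- Added usage sketch (not in the original file) ---
# Dummy smoke test: hypothetical 100-symbol vocabulary (index 0 reserved for
# padding, matching padding_idx above), batch of 4 sequences of length 16.
if __name__ == "__main__":
    model = KhmerRNN(vocab_size=100)
    x = torch.randint(1, 100, (4, 16))
    tags = torch.randint(0, 2, (4, 16))
    mask = torch.ones(4, 16)
    loss = model(x, tags=tags, mask=mask)  # CRF negative log-likelihood
    print("loss:", float(loss))
    emissions = model(x)  # (4, 16, 2) per-character tag scores
    print("first decoded path:", viterbi_decode(model.crf, emissions, mask)[0])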