"""
Neural network models for Khmer space injection
"""
import torch
import torch.nn as nn
class CRF(nn.Module):
    def __init__(self, num_tags):
        super().__init__()
        self.num_tags = num_tags
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))
        self.start_transitions = nn.Parameter(torch.randn(num_tags))
        self.end_transitions = nn.Parameter(torch.randn(num_tags))

    def forward(self, emissions, tags, mask):
        # mean negative log-likelihood of the gold tag sequences over the batch
        log_num = self._score_sentence(emissions, tags, mask)
        log_den = self._log_partition(emissions, mask)
        return torch.mean(log_den - log_num)

    def _score_sentence(self, emissions, tags, mask):
        # score of the gold tag sequence: start + emissions + transitions + end
        batch_idx = torch.arange(emissions.size(0), device=emissions.device)
        score = self.start_transitions[tags[:, 0]] + emissions[batch_idx, 0, tags[:, 0]]
        for t in range(1, emissions.size(1)):
            # mask out padded positions; include the emission at every real step
            score = score + emissions[batch_idx, t, tags[:, t]] * mask[:, t]
            score = score + self.transitions[tags[:, t - 1], tags[:, t]] * mask[:, t]
        last_idx = mask.sum(1).long() - 1
        last_tags = tags.gather(1, last_idx.unsqueeze(1)).squeeze(1)
        score = score + self.end_transitions[last_tags]
        return score

    def _log_partition(self, emissions, mask):
        # forward algorithm: log-sum-exp of the scores of all tag sequences
        alpha = self.start_transitions + emissions[:, 0]          # [B, C]
        for t in range(1, emissions.size(1)):
            emit = emissions[:, t].unsqueeze(1)                   # [B, 1, C], indexed by next tag j
            trans = self.transitions.unsqueeze(0)                 # [1, C, C], from tag i to tag j
            alpha_t = torch.logsumexp(alpha.unsqueeze(2) + trans + emit, dim=1)
            # carry alpha forward unchanged at padded positions instead of zeroing it
            keep = mask[:, t].bool().unsqueeze(1)
            alpha = torch.where(keep, alpha_t, alpha)
        return torch.logsumexp(alpha + self.end_transitions, dim=1)
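

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the uploaded module): a CRF trained with the
# loss above is normally decoded with the Viterbi algorithm at inference time.
# The helper below assumes batch-first emissions of shape [B, T, num_tags] and
# a 0/1 mask whose first position is always valid, matching the CRF above.
# ---------------------------------------------------------------------------
def viterbi_decode(crf, emissions, mask):
    """Return the highest-scoring tag sequence for each batch element."""
    B, T, _ = emissions.shape
    score = crf.start_transitions + emissions[:, 0]               # [B, C]
    history = []                                                  # back-pointers per step
    for t in range(1, T):
        # previous tag i on dim 1, next tag j on dim 2
        next_score = score.unsqueeze(2) + crf.transitions.unsqueeze(0) + emissions[:, t].unsqueeze(1)
        best_score, best_prev = next_score.max(dim=1)             # both [B, C]
        keep = mask[:, t].bool().unsqueeze(1)
        score = torch.where(keep, best_score, score)              # freeze padded positions
        history.append(best_prev)
    score = score + crf.end_transitions
    seq_lens = mask.long().sum(1)
    paths = []
    for b in range(B):
        best_tag = score[b].argmax().item()
        path = [best_tag]
        # follow back-pointers through the real (unpadded) steps only
        for ptrs in reversed(history[: seq_lens[b].item() - 1]):
            best_tag = ptrs[b, best_tag].item()
            path.append(best_tag)
        path.reverse()
        paths.append(path)
    return paths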
class RNN(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.Wxh = nn.Linear(input_dim, hidden_dim)
        self.Whh = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, x_t, h_prev):
        return torch.tanh(self.Wxh(x_t) + self.Whh(h_prev))
class GRU(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.z = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.r = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.h = nn.Linear(input_dim + hidden_dim, hidden_dim)

    def forward(self, x_t, h_prev):
        concat = torch.cat([x_t, h_prev], dim=-1)
        z_t = torch.sigmoid(self.z(concat))
        r_t = torch.sigmoid(self.r(concat))
        concat_reset = torch.cat([x_t, r_t * h_prev], dim=-1)
        h_tilde = torch.tanh(self.h(concat_reset))
        return (1 - z_t) * h_prev + z_t * h_tilde
class LSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.i = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.f = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.o = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.g = nn.Linear(input_dim + hidden_dim, hidden_dim)

    def forward(self, x_t, state):
        h_prev, c_prev = state
        concat = torch.cat([x_t, h_prev], dim=-1)
        i_t = torch.sigmoid(self.i(concat))
        f_t = torch.sigmoid(self.f(concat))
        o_t = torch.sigmoid(self.o(concat))
        g_t = torch.tanh(self.g(concat))
        c_t = f_t * c_prev + i_t * g_t
        h_t = o_t * torch.tanh(c_t)
        return h_t, c_t
class BiRecurrentLayer(nn.Module):
    def __init__(self, cell_cls, input_dim, hidden_dim, bidirectional=True):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.bidirectional = bidirectional
        self.fw = cell_cls(input_dim, hidden_dim)
        if bidirectional:
            self.bw = cell_cls(input_dim, hidden_dim)

    def forward(self, x):
        B, T, _ = x.shape
        device = x.device
        H = self.hidden_dim
        # ---------- Forward ----------
        h_fw = []
        h = torch.zeros(B, H, device=device)
        c = torch.zeros_like(h) if isinstance(self.fw, LSTM) else None
        for t in range(T):
            if c is not None:
                h, c = self.fw(x[:, t], (h, c))
            else:
                h = self.fw(x[:, t], h)
            h_fw.append(h)
        h_fw = torch.stack(h_fw, dim=1)
        if not self.bidirectional:
            return h_fw
        # ---------- Backward ----------
        h_bw = []
        h = torch.zeros(B, H, device=device)
        c = torch.zeros_like(h) if isinstance(self.bw, LSTM) else None
        for t in reversed(range(T)):
            if c is not None:
                h, c = self.bw(x[:, t], (h, c))
            else:
                h = self.bw(x[:, t], h)
            h_bw.append(h)
        h_bw.reverse()
        h_bw = torch.stack(h_bw, dim=1)
        return torch.cat([h_fw, h_bw], dim=-1)
class KhmerRNN(nn.Module):
    def __init__(
        self,
        vocab_size,
        embedding_dim=128,
        hidden_dim=256,
        num_layers=2,
        dropout=0.3,
        bidirectional=True,
        rnn_type="lstm",
        residual=True,
        use_crf=True,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.dropout = nn.Dropout(dropout)
        self.residual = residual
        self.use_crf = use_crf
        cell_map = {
            "rnn": RNN,
            "gru": GRU,
            "lstm": LSTM,
        }
        cell_cls = cell_map[rnn_type.lower()]
        self.layers = nn.ModuleList()
        input_dim = embedding_dim
        for _ in range(num_layers):
            layer = BiRecurrentLayer(
                cell_cls=cell_cls,
                input_dim=input_dim,
                hidden_dim=hidden_dim,
                bidirectional=bidirectional,
            )
            self.layers.append(layer)
            input_dim = hidden_dim * (2 if bidirectional else 1)
        self.fc = nn.Linear(input_dim, 2)
        if use_crf:
            self.crf = CRF(num_tags=2)

    def forward(self, x, tags=None, mask=None):
        out = self.embedding(x)
        for layer in self.layers:
            residual = out
            out = layer(out)
            # residual connection only once the feature sizes line up
            if self.residual and out.shape == residual.shape:
                out = out + residual
            out = self.dropout(out)
        emissions = self.fc(out)
        if self.use_crf and tags is not None:
            if mask is None:
                # padding index is 0, so non-zero tokens form a valid default mask
                mask = (x != 0).float()
            return self.crf(emissions, tags, mask)
        return emissions
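

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the uploaded module): one way to run a
# training step and an inference pass. The vocabulary size, sequence length,
# hyperparameters, and label convention below are illustrative assumptions,
# not project values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    B, T, V = 4, 12, 100                      # assumed batch size, length, vocab size
    model = KhmerRNN(vocab_size=V, embedding_dim=32, hidden_dim=64, num_layers=2)

    x = torch.randint(1, V, (B, T))           # token ids; 0 is reserved for padding
    tags = torch.randint(0, 2, (B, T))        # assumed convention: 1 = space follows this character
    mask = torch.ones(B, T)                   # all positions valid in this toy batch

    # training step: with use_crf=True and tags given, forward() returns the CRF loss
    loss = model(x, tags=tags, mask=mask)
    loss.backward()
    print("CRF loss:", loss.item())

    # inference: without tags, forward() returns per-position emissions of shape [B, T, 2]
    with torch.no_grad():
        emissions = model(x)
    print("emission shape:", tuple(emissions.shape))
    print("greedy tags:", emissions.argmax(-1)[0].tolist())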