EDEN / eden /model.py
Rybib's picture
Upload EDEN model and code
2f65125 verified
Raw
History Blame Contribute Delete
3.59 kB
"""The EDEN encoder-decoder Transformer (training/inference reference model)."""
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .config import TrainConfig
from .constants import *
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    A ``(1, max_len, d_model)`` table is built once at construction time. Even
    feature columns hold ``sin(position * freq_i)`` and odd columns hold
    ``cos(position * freq_i)``, where ``freq_i = 10000 ** (-2 * i / d_model)``
    for column pair ``i``. Both even and odd ``d_model`` are supported.

    Args:
        d_model: Size of the feature (last) dimension of the inputs.
        max_len: Number of positions stored in the table. Inputs to
            ``forward`` must not have more than ``max_len`` positions along
            dim 1, because only that many rows exist to be added.
        dropout: Dropout probability applied after the encodings are added.
    """

    def __init__(self, d_model: int, max_len: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        # One angle per (position, frequency): shape (max_len, ceil(d_model / 2)).
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so the last frequency has no cosine slot; slice it off. For an even
        # d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Registered as a buffer so it follows the module across devices;
        # persistent=False keeps it out of the state dict.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe[:, :x.size(1)])``.

        Args:
            x: Input whose dim 1 indexes positions and whose last dim is
                ``d_model``; the table broadcasts over dim 0.

        Returns:
            A tensor with the same shape as ``x``.
        """
        seq_len = x.size(1)
        # Cast the table to the input dtype so the addition does not promote x.
        x = x + self.pe[:, :seq_len, :].to(dtype=x.dtype)
        return self.dropout(x)
class EdenTransformer(nn.Module):
    """Encoder-decoder Transformer whose output projection shares the embedding matrix.

    Source and target token ids go through the same embedding table, are
    scaled by ``sqrt(d_model)``, receive sinusoidal position encodings, and
    are processed by a batch-first, pre-norm ``nn.Transformer`` with GELU
    activations. ``forward`` returns vocabulary logits for each target
    position.
    """

    def __init__(self, cfg: TrainConfig):
        """Build all sub-modules from ``cfg`` and initialise their weights.

        Reads ``vocab_size``, ``d_model``, ``max_len``, ``dropout``,
        ``n_heads``, ``n_layers`` and ``dim_feedforward`` from ``cfg``.
        """
        super().__init__()
        self.cfg = cfg
        self.pad_id, self.bos_id, self.eos_id = PAD_ID, BOS_ID, EOS_ID

        width = cfg.d_model
        depth = cfg.n_layers
        self.scale = math.sqrt(width)

        self.embedding = nn.Embedding(cfg.vocab_size, width, padding_idx=PAD_ID)
        # NOTE(review): the "+ 4" presumably leaves headroom for special
        # tokens added around a max_len sequence — confirm against the caller.
        self.pos = PositionalEncoding(width, cfg.max_len + 4, cfg.dropout)
        self.transformer = nn.Transformer(
            d_model=width,
            nhead=cfg.n_heads,
            num_encoder_layers=depth,
            num_decoder_layers=depth,
            dim_feedforward=cfg.dim_feedforward,
            dropout=cfg.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )

        # Weight tying: the output projection reuses the embedding parameter.
        self.lm_head = nn.Linear(width, cfg.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight

        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Xavier-initialise every non-embedding matrix, then initialise the embedding.

        The embedding (shared with ``lm_head``) is drawn from N(0, 0.02^2)
        and its padding row is zeroed afterwards.
        """
        for param_name, tensor in self.named_parameters():
            # Skip vectors/scalars (biases, norm weights) and the embedding.
            if tensor.dim() <= 1 or "embedding" in param_name:
                continue
            nn.init.xavier_uniform_(tensor)

        table = self.embedding.weight
        nn.init.normal_(table, mean=0.0, std=0.02)
        with torch.no_grad():
            table[PAD_ID].zero_()

    def parameter_count(self) -> int:
        """Return the total number of elements across trainable parameters."""
        total = 0
        for param in self.parameters():
            if param.requires_grad:
                total += param.numel()
        return total

    def encode(self, src: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder stack over source token ids.

        Args:
            src: Source token ids; positions equal to ``PAD_ID`` are masked.

        Returns:
            ``(memory, src_padding)``: the encoder output and the boolean
            padding mask (True where ``src`` is padding).
        """
        pad_mask = src == PAD_ID
        embedded = self.embedding(src) * self.scale
        encoded = self.transformer.encoder(
            self.pos(embedded),
            src_key_padding_mask=pad_mask,
        )
        return encoded, pad_mask

    def decode(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        src_padding: torch.Tensor,
    ) -> torch.Tensor:
        """Run the decoder stack over target token ids with causal masking.

        Args:
            tgt: Target token ids; positions equal to ``PAD_ID`` are masked.
            memory: Encoder output, as returned by ``encode``.
            src_padding: Boolean source padding mask, as returned by ``encode``.

        Returns:
            Decoder hidden states (before the vocabulary projection).
        """
        steps = torch.arange(tgt.size(1), device=tgt.device)
        # Entry [i, j] is True when j > i, i.e. strictly above the diagonal:
        # position i may not attend to later positions.
        future_mask = steps.unsqueeze(0) > steps.unsqueeze(1)

        embedded = self.embedding(tgt) * self.scale
        return self.transformer.decoder(
            self.pos(embedded),
            memory,
            tgt_mask=future_mask,
            tgt_key_padding_mask=tgt == PAD_ID,
            memory_key_padding_mask=src_padding,
        )

    def forward(self, src: torch.Tensor, tgt_in: torch.Tensor) -> torch.Tensor:
        """Encode ``src``, decode ``tgt_in`` against it, and return vocabulary logits."""
        encoded, pad_mask = self.encode(src)
        states = self.decode(tgt_in, encoded, pad_mask)
        return self.lm_head(states)