Feature Extraction
Transformers
Safetensors
PyTorch
English
eden
text-enhancement
grammar-correction
text-rewriting
encoder-decoder
transformer
custom_code
Instructions to use Rybib/EDEN with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Rybib/EDEN with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Rybib/EDEN", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Rybib/EDEN", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """The EDEN encoder-decoder Transformer (training/inference reference model).""" | |
| from __future__ import annotations | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .config import TrainConfig | |
| from .constants import * | |
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    A ``(1, max_len, d_model)`` table is precomputed once and stored as a
    non-persistent buffer named ``pe`` (so it is not written to the state
    dict). Even feature columns hold sines and odd columns hold cosines.

    Args:
        d_model: Size of the feature dimension. May be odd.
        max_len: Number of positions to precompute.
        dropout: Dropout probability applied after the encodings are added.
    """

    def __init__(self, d_model: int, max_len: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so
        # trim the angles to match; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe[:, :x.size(1)])``.

        Args:
            x: Tensor whose dimension 1 indexes positions and whose last
                dimension matches ``d_model``.

        Raises:
            RuntimeError: If ``x.size(1)`` exceeds the number of precomputed
                positions.
        """
        seq_len = x.size(1)
        table_len = self.pe.size(1)
        if seq_len > table_len:
            # Fail loudly instead of relying on a broadcast error (or, for a
            # one-row table, a silent broadcast of position 0).
            raise RuntimeError(
                f"sequence length {seq_len} exceeds positional table length {table_len}"
            )
        # Cast the float32 table to the input dtype so reduced-precision
        # inputs are not promoted by the addition.
        x = x + self.pe[:, :seq_len, :].to(dtype=x.dtype)
        return self.dropout(x)
class EdenTransformer(nn.Module):
    """Encoder-decoder Transformer with one shared token embedding.

    The same embedding matrix feeds both the encoder and the decoder and is
    also tied to the output projection (``lm_head``). The core is a
    batch-first, pre-norm ``nn.Transformer`` with GELU activations and the
    same number of encoder and decoder layers.
    """

    def __init__(self, cfg: TrainConfig):
        """Build the model from the hyper-parameters held by ``cfg``.

        Reads ``d_model``, ``vocab_size``, ``max_len``, ``dropout``,
        ``n_heads``, ``n_layers`` and ``dim_feedforward`` from ``cfg``.
        """
        super().__init__()
        self.cfg = cfg
        self.pad_id, self.bos_id, self.eos_id = PAD_ID, BOS_ID, EOS_ID

        width = cfg.d_model
        # Embeddings are multiplied by sqrt(d_model) before positions are added.
        self.scale = math.sqrt(width)
        self.embedding = nn.Embedding(cfg.vocab_size, width, padding_idx=PAD_ID)
        # The positional table is built with 4 positions beyond cfg.max_len.
        self.pos = PositionalEncoding(width, cfg.max_len + 4, cfg.dropout)

        core_kwargs = dict(
            d_model=width,
            nhead=cfg.n_heads,
            num_encoder_layers=cfg.n_layers,
            num_decoder_layers=cfg.n_layers,
            dim_feedforward=cfg.dim_feedforward,
            dropout=cfg.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.Transformer(**core_kwargs)

        self.lm_head = nn.Linear(width, cfg.vocab_size, bias=False)
        # Weight tying: the output projection reuses the embedding matrix.
        self.lm_head.weight = self.embedding.weight
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Re-initialise weights: Xavier for matrices, N(0, 0.02) for the embedding."""
        for param_name, tensor in self.named_parameters():
            # Skip vectors/scalars and anything belonging to the embedding.
            if tensor.dim() <= 1 or "embedding" in param_name:
                continue
            nn.init.xavier_uniform_(tensor)

        shared_weight = self.embedding.weight
        nn.init.normal_(shared_weight, mean=0.0, std=0.02)
        # The normal init overwrote the padding row; put it back to zero.
        with torch.no_grad():
            shared_weight[PAD_ID].fill_(0.0)

    def parameter_count(self) -> int:
        """Return the total number of elements across trainable parameters."""
        total = 0
        for tensor in self.parameters():
            if tensor.requires_grad:
                total += tensor.numel()
        return total

    def encode(self, src: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder over ``src`` token ids.

        Returns:
            A pair ``(memory, src_padding)``: the encoder output and the
            boolean mask that is True wherever ``src`` equals ``PAD_ID``.
        """
        pad_mask = src.eq(PAD_ID)
        embedded = self.embedding(src) * self.scale
        encoded = self.transformer.encoder(
            self.pos(embedded),
            src_key_padding_mask=pad_mask,
        )
        return encoded, pad_mask

    def decode(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        src_padding: torch.Tensor,
    ) -> torch.Tensor:
        """Run the decoder over ``tgt`` token ids, attending to ``memory``.

        Args:
            tgt: Target token ids; dimension 1 is the sequence dimension.
            memory: Encoder output, as returned by ``encode``.
            src_padding: Source padding mask, as returned by ``encode``.

        Returns:
            Decoder hidden states (before the output projection).
        """
        steps = tgt.size(1)
        # Square boolean mask, True strictly above the diagonal, passed as
        # ``tgt_mask`` so a position cannot attend to later positions.
        future_mask = torch.ones(
            steps, steps, dtype=torch.bool, device=tgt.device
        ).triu(diagonal=1)
        pad_mask = tgt.eq(PAD_ID)
        embedded = self.pos(self.embedding(tgt) * self.scale)
        return self.transformer.decoder(
            embedded,
            memory,
            tgt_mask=future_mask,
            tgt_key_padding_mask=pad_mask,
            memory_key_padding_mask=src_padding,
        )

    def forward(self, src: torch.Tensor, tgt_in: torch.Tensor) -> torch.Tensor:
        """Encode ``src``, decode ``tgt_in`` against it, and project through ``lm_head``."""
        return self.lm_head(self.decode(tgt_in, *self.encode(src)))