Feature Extraction
Transformers
Safetensors
PyTorch
English
eden
text-enhancement
grammar-correction
text-rewriting
encoder-decoder
transformer
custom_code
Instructions to use Rybib/EDEN with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Rybib/EDEN with Transformers:
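```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Rybib/EDEN", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Rybib/EDEN", trust_remote_code=True, dtype="auto")
```
Note: the model's `forward` requires `decoder_input_ids` or `labels` and raises a `ValueError` when given only `input_ids`. A pipeline that passes only tokenizer outputs is therefore likely to fail. To rewrite text, call `model.enhance(tokenizer, text)` with the model's tokenizer.
- Notebooks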
- Google Colab
- Kaggle
"""EDEN model for Hugging Face Transformers.

EDEN (Encoder Decoder Enhancement Network) is a from-scratch encoder-decoder
Transformer that rewrites rough text into polished text. This module wraps the
original architecture in a ``PreTrainedModel`` so the model can be loaded with
``AutoModel.from_pretrained(..., trust_remote_code=True)`` and saved as
safetensors with ``save_pretrained``.

The layer structure (embedding, positional encoding, ``nn.Transformer``, tied
language-model head) matches the original training code exactly. Checkpoints
trained with the standalone trainer therefore load here without any key
remapping; only their model weights need to be loaded.
"""
| from __future__ import annotations | |
| import math | |
| import re | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import PreTrainedModel | |
| from transformers.modeling_outputs import Seq2SeqLMOutput | |
# When this file is loaded from the Hugging Face Hub it lives inside a package,
# so the sibling config is a relative import. When imported directly by a local
# script (for example the conversion script) it is a top-level module instead.
# The try/except supports both, and Transformers ignores imports inside a try
# block when checking dependencies, so keep this as a literal try/except.
try:
    from .configuration_eden import EdenConfig
except ImportError:
    # The relative import failed (e.g. no parent package when run as a plain
    # script); fall back to importing configuration_eden as a top-level module.
    from configuration_eden import EdenConfig
# Typographic single/double quotes folded to their ASCII forms by _normalise_text.
_QUOTE_TRANSLATION = str.maketrans(
    {
        "\u2018": "'",
        "\u2019": "'",
        "\u201c": '"',
        "\u201d": '"',
    }
)


def _normalise_text(text: str) -> str:
    """Return ``text`` with curly quotes made ASCII and whitespace collapsed.

    Falsy input (``None``, ``""``, ``0``) becomes ``""``; any other value is
    converted with ``str`` first. Every run of whitespace is replaced by a
    single space and the result is stripped at both ends.
    """
    raw = str(text) if text else ""
    ascii_quoted = raw.translate(_QUOTE_TRANSLATION)
    return re.sub(r"\s+", " ", ascii_quoted).strip()
def _sentence_split(text: str) -> list[str]:
    """Split ``text`` into sentences at whitespace following ``.``, ``!`` or ``?``.

    The punctuation stays attached to its sentence. Empty pieces are dropped,
    so blank input yields an empty list.
    """
    pieces = re.split(r"(?<=[.!?])\s+", text.strip())
    return list(filter(None, pieces))
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

    Args:
        d_model: embedding width (size of the input's last dimension). Odd
            widths are supported; the final column then holds a sine term.
        max_len: number of positions in the precomputed table. Inputs longer
            than this are not supported.
        dropout: dropout probability applied after the encodings are added.
    """

    def __init__(self, d_model: int, max_len: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There is one fewer odd (cosine) column than even (sine) column when
        # d_model is odd, so only use as many frequencies as there are cosine
        # slots. For even d_model this slice is the whole of div_term.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Persistent so the table is written to safetensors and restored on load.
        # Transformers initialises models on the meta device, which would leave a
        # non-persistent buffer uninitialised (NaN) after from_pretrained.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe[:, :seq_len])``.

        ``x`` is indexed as (batch, seq_len, d_model) against the
        (1, max_len, d_model) table. The table is cast to ``x``'s dtype so
        half-precision inputs stay half precision.
        """
        x = x + self.pe[:, : x.size(1), :].to(dtype=x.dtype)
        return self.dropout(x)
class EdenPreTrainedModel(PreTrainedModel):
    """Shared base for EDEN models: binds ``EdenConfig`` and initialises weights."""

    config_class = EdenConfig
    base_model_prefix = "eden"
    supports_gradient_checkpointing = False

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise a single submodule in place.

        Embeddings are drawn from N(0, 0.02) and their padding row (if any) is
        zeroed. Linear layers get Xavier-uniform weights and a zero bias. All
        other module types are left untouched.
        """
        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            pad_row = module.padding_idx
            if pad_row is not None:
                with torch.no_grad():
                    module.weight[pad_row].zero_()
            return
        if not isinstance(module, nn.Linear):
            return
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
class EdenForTextEnhancement(EdenPreTrainedModel):
    """Encoder-decoder Transformer with a tied language-model head.

    ``forward`` follows the Hugging Face seq2seq convention: ``input_ids`` for
    the encoder, plus ``decoder_input_ids`` or ``labels`` for the decoder.
    ``enhance`` is the high-level entry point that rewrites a string using beam
    search, greedy decoding or sampling.
    """

    # The output projection shares its weight with the input embedding.
    _tied_weights_keys = {"lm_head.weight": "embedding.weight"}

    def __init__(self, config: EdenConfig):
        super().__init__(config)
        self.pad_id = config.pad_token_id
        self.bos_id = config.bos_token_id
        self.eos_id = config.eos_token_id
        self.unk_id = config.unk_token_id
        # Token embeddings are multiplied by sqrt(d_model) before positions are added.
        self.scale = math.sqrt(config.d_model)
        self.embedding = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_token_id)
        # The positional table is a persistent buffer saved with the weights, so
        # its length (max_len + 4) determines the buffer shape in checkpoints.
        self.pos = PositionalEncoding(config.d_model, config.max_len + 4, config.dropout)
        self.transformer = nn.Transformer(
            d_model=config.d_model,
            nhead=config.n_heads,
            num_encoder_layers=config.n_layers,
            num_decoder_layers=config.n_layers,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    # ------------------------------------------------------------------ #
    # Hugging Face plumbing
    # ------------------------------------------------------------------ #
    def get_input_embeddings(self) -> nn.Module:
        return self.embedding

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.embedding = value

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.lm_head = new_embeddings

    def _tie_weights(self) -> None:
        # Share one Parameter between the embedding and the LM head.
        self.lm_head.weight = self.embedding.weight

    # ------------------------------------------------------------------ #
    # Core encoder-decoder
    # ------------------------------------------------------------------ #
    def encode(self, src: torch.Tensor):
        """Encode token ids ``src`` (batch, src_len).

        Returns ``(memory, src_padding)``, where ``src_padding`` is True at
        positions equal to ``pad_token_id``. Padding is derived from the ids
        themselves rather than from an attention mask.
        """
        src_padding = src.eq(self.pad_id)
        src_emb = self.pos(self.embedding(src) * self.scale)
        memory = self.transformer.encoder(src_emb, src_key_padding_mask=src_padding)
        return memory, src_padding

    def decode(self, tgt: torch.Tensor, memory: torch.Tensor, src_padding: torch.Tensor) -> torch.Tensor:
        """Run the decoder over ``tgt`` (batch, tgt_len) with a causal mask.

        Returns decoder hidden states (batch-first, as configured on
        ``nn.Transformer``). Pad positions in ``tgt`` and ``src_padding`` are
        masked out of attention.
        """
        tgt_padding = tgt.eq(self.pad_id)
        tgt_emb = self.pos(self.embedding(tgt) * self.scale)
        tgt_len = tgt.size(1)
        # True above the diagonal = position may not attend to later tokens.
        causal_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device),
            diagonal=1,
        )
        return self.transformer.decoder(
            tgt_emb,
            memory,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=tgt_padding,
            memory_key_padding_mask=src_padding,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        decoder_input_ids: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> Seq2SeqLMOutput:
        """Compute LM logits (and a loss when ``labels`` are given).

        Args:
            input_ids: encoder token ids.
            decoder_input_ids: decoder inputs. When omitted, they are built from
                ``labels`` with ``_shift_right``.
            labels: target ids; positions set to -100 are ignored by the loss.
            return_dict: return a ``Seq2SeqLMOutput`` (default from config) or a
                ``(loss, logits)`` / ``(logits,)`` tuple.
            **kwargs: accepted for API compatibility and ignored. For example,
                ``attention_mask`` is unused; padding comes from ``pad_token_id``.

        Raises:
            ValueError: if neither ``decoder_input_ids`` nor ``labels`` is given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if decoder_input_ids is None and labels is not None:
            decoder_input_ids = self._shift_right(labels)
        if decoder_input_ids is None:
            raise ValueError("Provide decoder_input_ids or labels to EdenForTextEnhancement.forward.")
        memory, src_padding = self.encode(input_ids)
        hidden = self.decode(decoder_input_ids, memory, src_padding)
        logits = self.lm_head(hidden)
        loss = None
        if labels is not None:
            # Loss in float32 for stability under reduced-precision logits.
            loss = F.cross_entropy(
                logits.float().reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )
        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output
        return Seq2SeqLMOutput(loss=loss, logits=logits)

    def _shift_right(self, labels: torch.Tensor) -> torch.Tensor:
        """Build decoder inputs from ``labels``: prepend BOS, drop the last
        position, and replace ignored (-100) ids with ``pad_token_id``."""
        shifted = labels.new_zeros(labels.shape)
        shifted[:, 1:] = labels[:, :-1].clone()
        shifted[:, 0] = self.bos_id
        shifted.masked_fill_(shifted == -100, self.pad_id)
        return shifted

    # ------------------------------------------------------------------ #
    # Generation (ported from the original trainer, no KV cache needed for
    # the short sequences this model handles)
    # ------------------------------------------------------------------ #
    def _banned_token_ids(self) -> list[int]:
        """Special ids that must never be generated (UNK, PAD, BOS), skipping unset ones.

        Filtering out ``None`` matters: ``logits[None] = -inf`` would mask the
        whole vocabulary.
        """
        return [tid for tid in (self.unk_id, self.pad_id, self.bos_id) if tid is not None]

    def _next_token_logits(self, tokens, memory, src_padding, device, repetition_penalty):
        """Return float32 logits for the token that follows ``tokens``.

        Only the last ``config.max_len`` tokens are fed to the decoder. Ids that
        have already been emitted get the repetition penalty, and the special
        ids from ``_banned_token_ids`` are set to -inf.
        """
        max_len = self.config.max_len
        tgt = torch.tensor([tokens[-max_len:]], dtype=torch.long, device=device)
        hidden = self.decode(tgt, memory, src_padding)
        logits = self.lm_head(hidden[:, -1, :]).float().squeeze(0)
        if repetition_penalty != 1.0:
            vocab = logits.numel()
            seen = sorted({tid for tid in tokens if 0 <= tid < vocab})
            if seen:
                idx = torch.tensor(seen, dtype=torch.long, device=logits.device)
                prev = logits[idx]
                # CTRL-style penalty: divide positive logits and multiply negative
                # ones, so a penalty > 1 always makes a repeated token less likely.
                # Plain division would push negative logits toward 0 and make them
                # *more* likely.
                logits[idx] = torch.where(prev > 0, prev / repetition_penalty, prev * repetition_penalty)
        banned = self._banned_token_ids()
        if banned:
            logits[banned] = -float("inf")
        return logits

    @torch.no_grad()
    def _beam_generate(self, src, beam_size, max_new_tokens, length_penalty, repetition_penalty):
        """Beam-search decode a single source row ``src`` (shape (1, src_len)).

        Each hypothesis is ``(tokens, summed_log_prob, finished)``. Hypotheses
        are ranked by ``score / length ** length_penalty``, where length
        excludes the leading BOS. Returns the best hypothesis' ids, cut at EOS,
        with special tokens removed.
        """
        self.eval()
        device = src.device
        memory, src_padding = self.encode(src)

        def rank(item):
            toks, sc, _ = item
            return sc / (max(1, len(toks) - 1) ** length_penalty)

        beams = [([self.bos_id], 0.0, False)]
        for _ in range(max_new_tokens):
            if all(done for _, _, done in beams):
                break
            candidates = []
            for tokens, score, done in beams:
                if done:
                    # Finished hypotheses are carried forward unchanged.
                    candidates.append((tokens, score, True))
                    continue
                logits = self._next_token_logits(tokens, memory, src_padding, device, repetition_penalty)
                log_probs = F.log_softmax(logits, dim=-1)
                values, indices = torch.topk(log_probs, k=min(beam_size, log_probs.numel()))
                for value, index in zip(values.tolist(), indices.tolist()):
                    candidates.append((tokens + [int(index)], score + float(value), int(index) == self.eos_id))
            candidates.sort(key=rank, reverse=True)
            beams = candidates[:beam_size]

        best = max(beams, key=rank)
        out = best[0][1:]
        if self.eos_id in out:
            out = out[: out.index(self.eos_id)]
        skip = {self.pad_id, self.bos_id, self.eos_id, self.unk_id}
        return [t for t in out if t not in skip]

    @torch.no_grad()
    def _sample_generate(self, src, strategy, max_new_tokens, temperature, top_k, top_p, repetition_penalty):
        """Greedy (``strategy != "sample"``) or top-k/top-p sampled decoding.

        Stops at EOS or after ``max_new_tokens`` steps. Returns the generated
        ids without BOS/EOS. If the filtered distribution is degenerate
        (non-finite or zero mass), sampling falls back to argmax.
        """
        self.eval()
        device = src.device
        memory, src_padding = self.encode(src)
        tokens = [self.bos_id]
        skip = {self.pad_id, self.bos_id, self.eos_id, self.unk_id}
        for _ in range(max_new_tokens):
            logits = self._next_token_logits(tokens, memory, src_padding, device, repetition_penalty)
            if strategy == "sample":
                # Temperature is floored at 0.05 to avoid dividing by ~0.
                logits = logits / max(0.05, temperature)
                logits = self._filter_top_k_top_p(logits, top_k, top_p)
                probs = F.softmax(logits, dim=-1)
                if not torch.isfinite(probs).all() or float(probs.sum().item()) <= 0:
                    next_id = int(torch.argmax(logits).item())
                else:
                    next_id = int(torch.multinomial(probs.detach().cpu(), 1).item())
            else:
                next_id = int(torch.argmax(logits).item())
            if next_id == self.eos_id:
                break
            if next_id not in skip:
                tokens.append(next_id)
        return tokens[1:]

    @staticmethod
    def _filter_top_k_top_p(logits, top_k, top_p):
        """Return a copy of 1-D ``logits`` with tokens outside top-k / nucleus set to -inf.

        ``top_k <= 0`` or ``top_k >= vocab`` disables top-k. ``top_p`` outside
        (0, 1) disables nucleus filtering. The most likely token is always kept.
        Must be a staticmethod: it is called as ``self._filter_top_k_top_p(...)``
        with three arguments.
        """
        filtered = logits.clone()
        if top_k > 0 and top_k < filtered.numel():
            threshold = torch.topk(filtered, top_k).values[-1]
            filtered[filtered < threshold] = -float("inf")
        if 0.0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(filtered, descending=True)
            probs = F.softmax(sorted_logits, dim=-1)
            cumulative = torch.cumsum(probs, dim=-1)
            remove = cumulative > top_p
            # Shift right so the token that crosses top_p is still kept.
            remove[1:] = remove[:-1].clone()
            remove[0] = False
            filtered[sorted_indices[remove]] = -float("inf")
        return filtered

    def _chunk_text(self, text, tokenizer):
        """Split ``text`` into pieces that each fit the encoder's source budget.

        The text is normalised first, and empty text yields no chunks. Text that
        fits in ``config.max_len - 2`` tokens is returned as one chunk.
        Otherwise whole sentences are packed greedily into chunks, and a single
        sentence longer than the budget is cut into token windows that are
        decoded back to text.
        """
        text = _normalise_text(text)
        if not text:
            return []
        max_src = self.config.max_len - 2
        if len(tokenizer.encode(text, add_special_tokens=False)) <= max_src:
            return [text]
        chunks, current, current_ids = [], [], []
        for sent in _sentence_split(text) or [text]:
            sent_ids = tokenizer.encode(sent, add_special_tokens=False)
            if current and len(current_ids) + len(sent_ids) > max_src:
                chunks.append(" ".join(current))
                current, current_ids = [], []
            if len(sent_ids) > max_src:
                for i in range(0, len(sent_ids), max_src):
                    chunks.append(tokenizer.decode(sent_ids[i : i + max_src], skip_special_tokens=True))
            else:
                current.append(sent)
                current_ids.extend(sent_ids)
        if current:
            chunks.append(" ".join(current))
        return chunks

    def enhance(
        self,
        tokenizer,
        text: str,
        strategy: str = "beam",
        beam_size: int | None = None,
        max_new_tokens: int | None = None,
        temperature: float = 0.7,
        top_k: int = 40,
        top_p: float = 0.9,
        length_penalty: float | None = None,
        repetition_penalty: float | None = None,
    ) -> str:
        """Rewrite ``text`` into polished text.

        This mirrors the original trainer's enhancement pipeline: long inputs
        are split into sentence-aware chunks, each chunk is rewritten, and the
        results are joined back together. (The repetition penalty here uses the
        sign-aware form, so it always discourages repeats.)

        Args:
            tokenizer: object providing ``encode(text, add_special_tokens=False)``
                and ``decode(ids, skip_special_tokens=True)``.
            text: input text. Empty or whitespace-only input returns ``""``.
            strategy: ``"beam"``, ``"greedy"`` or ``"sample"``; any other value
                falls back to ``"beam"``.
            beam_size: beam width for ``"beam"`` (default ``config.beam_size``).
            max_new_tokens: per-chunk token budget, clamped to
                ``[8, max(8, config.max_len - 1)]``; default ``min(256, cap)``.
            temperature, top_k, top_p: sampling controls, used by ``"sample"`` only.
            length_penalty: beam-score length exponent (default from config).
            repetition_penalty: values > 1 discourage already-generated tokens
                (default from config).

        Returns:
            The whitespace-normalised rewrite. A chunk whose output decodes to
            nothing is passed through unchanged. The model's train/eval mode is
            restored before returning.
        """
        strategy = strategy if strategy in {"beam", "greedy", "sample"} else "beam"
        beam = max(1, int(beam_size or self.config.beam_size))
        cap = max(8, self.config.max_len - 1)
        max_tokens = int(max_new_tokens) if max_new_tokens else min(256, cap)
        max_tokens = max(8, min(cap, max_tokens))
        len_penalty = self.config.length_penalty if length_penalty is None else float(length_penalty)
        rep_penalty = self.config.repetition_penalty if repetition_penalty is None else float(repetition_penalty)
        device = self.embedding.weight.device

        # The generators switch to eval mode; restore the caller's mode afterwards
        # so calling enhance() mid-training does not silently disable dropout.
        was_training = self.training
        outputs = []
        try:
            for chunk in self._chunk_text(text, tokenizer):
                src_ids = tokenizer.encode(chunk, add_special_tokens=False)[: self.config.max_len - 2]
                src = torch.tensor([[self.bos_id] + src_ids + [self.eos_id]], dtype=torch.long, device=device)
                if strategy == "beam":
                    out_ids = self._beam_generate(src, beam, max_tokens, len_penalty, rep_penalty)
                else:
                    out_ids = self._sample_generate(
                        src, strategy, max_tokens, temperature, top_k, top_p, rep_penalty
                    )
                decoded = _normalise_text(tokenizer.decode(out_ids, skip_special_tokens=True))
                outputs.append(decoded or chunk)
        finally:
            self.train(was_training)
        return _normalise_text(" ".join(outputs))